# DINO ViT-S/16 Feature Extraction on CIFAR-100

In [None]:
import os
from typing import cast

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR100

# Load the pre-trained DINO ViT-S/16 from torch.hub and switch it to eval mode:
# it is only used for inference (feature extraction) in this cell.
dino_model = cast(
    nn.Module,
    torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True),
)
dino_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dino_model.to(device=device)

# Preprocessing shared by both CIFAR-100 splits below: resize (shorter side to
# 256), center-crop to 224x224, convert to a tensor and normalize.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # These are the normalization statistics DINO was trained with; for
        # federated training, consider using CIFAR-100's own mean and std instead.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

train_dataset = CIFAR100(root="./data", train=True, download=True, transform=preprocess)
test_dataset = CIFAR100(root="./data", train=False, download=True, transform=preprocess)

# shuffle=False for both splits: feature extraction gains nothing from shuffling.
# A fixed order also makes the saved feature files, and the seeded exemplar
# sampling that later reads them, reproducible across runs.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)


# Function to extract features from a dataloader
def extract_features_and_labels(dataloader, model, device, log_every: int = 10):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs.

    Args:
        dataloader: sized iterable yielding ``(images, labels)`` batches.
        model: module applied to each image batch; its output is stored as the
            feature tensor. The caller is responsible for putting it in eval mode.
        device: device the image batches are moved to before the forward pass.
        log_every: print a progress line every ``log_every`` batches. A value
            <= 0 disables the periodic lines. The final batch is always reported.

    Returns:
        ``(features, labels)``: CPU tensors concatenated along dim 0, in the
        order the dataloader yielded the batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    all_features = []
    all_labels = []
    total_batches = len(dataloader)
    with torch.no_grad():
        for step, (images, labels) in enumerate(dataloader, start=1):
            images = images.to(device)
            # Whatever the model returns is treated as the feature vector batch
            features = model(images)
            all_features.append(features.cpu())
            all_labels.append(labels.cpu())
            if (log_every > 0 and step % log_every == 0) or step == total_batches:
                print(
                    f"Batch {step}/{total_batches} ({step / total_batches:.1%}) completed"
                )
    # torch.cat on an empty list raises an opaque RuntimeError; fail clearly instead
    if not all_features:
        raise ValueError("dataloader yielded no batches; nothing to extract")
    return torch.cat(all_features, dim=0), torch.cat(all_labels, dim=0)


# Extracted features are cached under ./features; create the directory so that
# torch.save does not fail with FileNotFoundError on a fresh checkout.
os.makedirs("features", exist_ok=True)

# Extract features and labels for train
train_features, train_labels = extract_features_and_labels(
    train_loader, dino_model, device
)
torch.save(
    {"features": train_features, "labels": train_labels},
    "features/train_features.pt",
)

# Extract features and labels for test
test_features, test_labels = extract_features_and_labels(
    test_loader, dino_model, device
)
torch.save(
    {"features": test_features, "labels": test_labels}, "features/test_features.pt"
)

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Batch 10/1563 (0.6%) completed
Batch 20/1563 (1.3%) completed
Batch 30/1563 (1.9%) completed
Batch 40/1563 (2.6%) completed
Batch 50/1563 (3.2%) completed
Batch 60/1563 (3.8%) completed
Batch 70/1563 (4.5%) completed
Batch 80/1563 (5.1%) completed
Batch 90/1563 (5.8%) completed
Batch 100/1563 (6.4%) completed
Batch 110/1563 (7.0%) completed
Batch 120/1563 (7.7%) completed
Batch 130/1563 (8.3%) completed
Batch 140/1563 (9.0%) completed
Batch 150/1563 (9.6%) completed
Batch 160/1563 (10.2%) completed
Batch 170/1563 (10.9%) completed
Batch 180/1563 (11.5%) completed
Batch 190/1563 (12.2%) completed
Batch 200/1563 (12.8%) completed
Batch 210/1563 (13.4%) completed
Batch 220/1563 (14.1%) completed
Batch 230/1563 (14.7%) completed
Batch 240/1563 (15.4%) completed
Batch 250/1563 (16.0%) completed
Batch 260/1563 (16.6%) completed
Batch 270/1563 (17.3%) completed
Batch 280/1563 (17.9%) completed
Batch 290/1563 (18.6%) completed
Batch 300/1563 (19.2%) completed
Batch 310/1563 (19.8%) completed
B

# Linear Classifier Training on DINO Features

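This section trains a linear classifier on top of the precomputed DINO ViT-S/16 features extracted from the CIFAR-100 dataset. The classifier is evaluated on the test split after every epoch, and early stopping is used to stop training. Note that the same test split drives early stopping (model selection), so the reported best accuracy is an optimistic estimate; a held-out validation split would give an unbiased one.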

In [1]:
from typing import Optional, cast
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm.notebook import tqdm

# ---- Training hyper-parameters ----
learning_rate = 1e-3
num_epochs = 10000  # upper bound; early stopping is expected to end training first
batch_size = 10000
test_batch_size = 1000

# ---- Early-stopping state ----
patience = 10  # epochs without test-accuracy improvement before stopping
best_acc = 0
patience_counter = 0
best_model_state = {}

# Pre-trained DINO ViT-S/16 from torch.hub, used as the CustomDino backbone below
dino_pretrained = cast(
    nn.Module,
    torch.hub.load(
        "facebookresearch/dino:main",
        "dino_vits16",
        pretrained=True,
    ),
)


class CustomDino(nn.Module):
    """A backbone followed by a single linear classification layer.

    Args:
        num_classes: number of output logits.
        backbone: module mapping an input batch to a ``[batch, feature_dim]``
            embedding. When ``None``, a DINO ViT-S/16 without pretrained
            weights is fetched from torch.hub.
        feature_dim: width of the backbone output. The default, 384, is the
            CLS-token size of DINO ViT-S/16.
    """

    def __init__(
        self,
        num_classes: int = 100,
        backbone: Optional[nn.Module] = None,
        feature_dim: int = 384,
    ):
        super().__init__()
        if backbone is None:
            # Load DINO ViT-S/16 without pretrained weights. NOTE(review): this
            # relies on the hub model's output already being the embedding
            # (no classification head) — TODO confirm.
            backbone = cast(
                nn.Module,
                torch.hub.load(
                    "facebookresearch/dino:main", "dino_vits16", pretrained=False
                ),
            )
        self.backbone: nn.Module = backbone
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``[batch, num_classes]`` for input ``x``."""
        features = self.backbone(x)  # [batch, feature_dim]
        return self.classifier(features)


model = CustomDino(num_classes=100, backbone=dino_pretrained)

# Preprocessing for raw input images. NOTE: not used in this cell, because the
# training below runs on precomputed features; kept for end-to-end inference.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # These are the normalization statistics DINO was trained with; for
        # federated training, consider using CIFAR-100's own mean and std instead.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load precomputed features and labels
train_data = torch.load("features/train_features.pt")
test_data = torch.load("features/test_features.pt")

train_features, train_labels = train_data["features"], train_data["labels"]
test_features, test_labels = test_data["features"], test_data["labels"]

# Wrap the feature tensors so they can be batched by a DataLoader
train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
test_dataset = torch.utils.data.TensorDataset(test_features, test_labels)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size)

# Only the linear head is optimized; the backbone is never called below
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _snapshot_state(module: nn.Module) -> dict:
    """Return a detached copy of ``module``'s state dict.

    ``state_dict()`` returns references to the live tensors, which
    ``optimizer.step()`` keeps updating in place; cloning freezes the values.
    """
    return {name: t.detach().clone() for name, t in module.state_dict().items()}


# -----------------------------------
# Train only the linear head, starting from the precomputed features
# -----------------------------------

complete_model = model
model = model.classifier
model.to(device=device)
best_model = model

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Epoch-level progress bar, updated once per batch
    epoch_desc = f"Epoch {epoch+1} Training"
    with tqdm(total=len(train_loader), desc=epoch_desc, leave=True) as progress:
        for batch_idx, (features, labels) in enumerate(train_loader):
            features, labels = features.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted loss so the running mean is exact
            running_loss += loss.item() * features.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            batch_loss = running_loss / total if total > 0 else 0
            batch_acc = correct / total if total > 0 else 0
            progress.set_postfix(
                {"loss": f"{batch_loss:.4f}", "acc": f"{batch_acc:.4f}"}
            )
            progress.update(1)

    epoch_loss = running_loss / total if total > 0 else 0
    epoch_acc = correct / total if total > 0 else 0

    # Evaluation on test set
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    test_correct = 0
    test_total = 0

    eval_desc = f"Epoch {epoch+1} Evaluation"
    with torch.no_grad():
        with tqdm(total=len(test_loader), desc=eval_desc, leave=True) as test_progress:
            for features, labels in test_loader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * features.size(0)
                _, predicted = outputs.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(labels).sum().item()
                test_progress.update(1)

            # test_loss becomes the per-sample mean here; do not divide again below
            test_loss = test_loss / test_total if test_total > 0 else 0
            test_acc = test_correct / test_total if test_total > 0 else 0

            test_progress.set_postfix(
                {"loss": f"{test_loss:.4f}", "acc": f"{test_acc:.4f}"}
            )

    # Early stopping. NOTE(review): selection is done on the test split, so
    # best_acc is an optimistic estimate; a validation split would avoid this.
    if epoch == 0:
        best_acc = test_acc
        if patience == 0:
            print("Default patience = 3")
            patience = 3
        patience_counter = 0
        best_model_state = _snapshot_state(model)
    else:
        if test_acc > best_acc:
            best_acc = test_acc
            patience_counter = 0
            best_model_state = _snapshot_state(model)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

# Restore the best weights whether training stopped early or ran all epochs
if best_model_state:
    model.load_state_dict(best_model_state)

torch.save(best_model_state, "./linear_classifier.pth")

print(
    f"Best model statistics:\nAccuracy: {best_acc:.4f}\nPatience reached: {patience_counter}\nModel state dict keys: {list(best_model_state.keys())}"
)

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Epoch 1 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 2 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 3 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 4 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 5 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 6 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 7 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 8 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 9 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 10 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 10 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 11 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 11 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 12 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 12 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 13 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 13 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 14 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 14 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 15 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 15 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 16 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 16 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 17 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 17 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 18 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 18 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 19 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 19 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 20 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 20 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 21 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 21 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 22 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 22 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 23 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 23 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 24 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 24 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 25 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 25 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 26 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 26 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 27 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 27 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 28 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 28 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 29 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 29 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 30 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 30 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 31 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 31 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 32 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 32 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 33 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 33 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 34 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 34 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 35 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 35 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 36 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 36 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 37 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 37 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 38 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 38 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 39 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 39 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 40 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 40 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 41 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 41 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 42 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 42 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 43 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 43 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 44 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 44 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 45 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 45 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 46 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 46 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 47 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 47 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 48 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 48 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 49 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 49 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 50 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 50 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 51 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 51 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 52 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 52 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 53 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 53 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 54 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 54 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 55 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 55 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 56 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 56 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 57 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 57 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 58 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 58 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 59 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 59 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 60 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 60 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 61 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 61 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 62 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 62 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 63 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 63 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 64 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 64 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 65 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 65 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 66 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 66 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 67 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 67 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 68 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 68 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 69 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 69 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 70 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 70 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 71 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 71 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 72 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 72 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 73 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 73 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 74 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 74 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 75 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 75 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 76 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 76 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 77 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 77 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 78 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 78 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 79 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 79 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 80 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 80 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 81 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 81 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 82 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 82 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 83 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 83 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 84 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 84 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 85 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 85 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 86 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 86 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 87 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 87 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 88 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 88 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 89 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 89 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 90 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 90 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 91 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 91 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 92 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 92 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 93 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 93 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 94 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 94 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 95 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 95 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 96 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 96 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 97 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 97 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 98 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 98 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 99 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 99 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 100 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 100 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 101 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 101 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 102 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 102 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 103 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 103 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 104 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 104 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 105 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 105 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 106 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 106 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 107 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 107 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 108 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 108 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 109 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 109 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 110 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 110 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 111 Training:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 111 Evaluation:   0%|          | 0/10 [00:00<?, ?it/s]

Early stopping triggered.
Best model statistics:
Accuracy: 0.7719
Patience reached: 10
Model state dict keys: ['weight', 'bias']


# Class Centroid Computation and Exemplar Selection

This cell computes L2-normalized class centroids from the DINO training features. Each centroid is the mean of a fixed number of exemplars selected at random (with a fixed seed) from that class. The centroids are saved for use in downstream tasks such as nearest-centroid classification or incremental learning.

In [2]:
import torch
import torch.nn.functional as F

# Configuration: how many randomly drawn exemplars are averaged into each centroid
num_exemplars = 100

# Load precomputed features and keep them on the CPU
train_data = torch.load("features/train_features.pt")
train_features, train_labels = train_data["features"].cpu(), train_data["labels"].cpu()

unique_classes = torch.unique(train_labels).tolist()
rng = torch.Generator().manual_seed(42)


def _exemplar_centroid(class_rows, max_exemplars, generator):
    """Average up to ``max_exemplars`` random rows of ``class_rows``; return the L2-normalized mean."""
    count = class_rows.size(0)
    chosen = torch.randperm(count, generator=generator)[: min(max_exemplars, count)]
    mean_row = class_rows[chosen].mean(dim=0, keepdim=True)
    return F.normalize(mean_row, p=2, dim=1).squeeze(0)


# Classes are visited in ascending order so the seeded generator is consumed
# deterministically: one randperm call per class.
centroids = {
    cls: _exemplar_centroid(train_features[train_labels == cls], num_exemplars, rng)
    for cls in unique_classes
}

# Persist the label list and the {label: centroid} mapping together
torch.save(
    {"class": unique_classes, "centroid": centroids},
    "./class_centroids.pth",
)

## Nearest Centroid Classification Evaluation

In [None]:
import torch
import torch.nn.functional as F


def nearest_neighbor(class_centroids, feature_vector):
    """
    Return the label of the centroid closest (Euclidean distance) to a feature.

    Args:
        class_centroids: dict {class_label: centroid_tensor[384]}
        feature_vector: torch.Tensor of shape [384]

    Returns:
        predicted_class: the dict key of the closest centroid. Ties go to the
        key that comes first in dict order; an empty dict, or one where no
        distance is finite, yields None.
    """
    scored = [
        (label, torch.linalg.vector_norm(feature_vector - center).item())
        for label, center in class_centroids.items()
    ]
    # The (None, inf) sentinel mirrors a strict "<" scan: min() replaces its
    # running best only on a strictly smaller distance.
    winner, _ = min([(None, float("inf")), *scored], key=lambda pair: pair[1])
    return winner

In [4]:
import torch
from tqdm.notebook import tqdm

# Load test features
test_data = torch.load("features/test_features.pt")
test_features = test_data["features"].cpu()
test_labels = test_data["labels"].cpu()

# Load class centroids ({class_label: centroid}) and the saved list of class labels
centroids_data = torch.load("./class_centroids.pth")
centroids = centroids_data["centroid"]
centroid_labels = centroids_data["class"]
# The centroid dict is keyed by class label, so its keys must match the label list
assert set(centroid_labels) == set(centroids), "centroid keys do not match saved class labels"

# Evaluate accuracy using the nearest centroid
correct = 0
total = test_features.size(0)

for i in tqdm(range(total), desc="Nearest Centroid Evaluation"):
    # nearest_neighbor returns the winning dict key, which already IS the class
    # label. Indexing centroid_labels with it only worked while the labels were
    # exactly 0..N-1 in sorted order.
    pred = nearest_neighbor(centroids, test_features[i])
    if pred == test_labels[i].item():
        correct += 1

accuracy = 100.0 * correct / total if total > 0 else 0.0
print(f"Nearest Neighbor Accuracy: {accuracy:.2f}% ({correct}/{total})")

Nearest Centroid Evaluation:   0%|          | 0/10000 [00:00<?, ?it/s]

Nearest Neighbor Accuracy: 60.46% (6046/10000)


# Convert the centroids into linear layers

In [None]:
# TODO