# DINO ViT-S/16 Feature Extraction on CIFAR-100

## Compute feature vectors

In [None]:
import torch
import torch.nn as nn
from typing import cast
from torchvision import transforms
from torchvision.datasets import CIFAR100
import os

In [None]:
# Pick the compute device first so the backbone can be moved right after loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-trained DINO ViT-S/16 from torch.hub, put in inference mode on `device`.
dino_model = cast(
    nn.Module,
    torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True),
)
dino_model.eval()
dino_model.to(device=device)

# Per-channel statistics used for normalisation (CIFAR-100).
# ImageNet alternative: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
_cifar100_mean = [0.5071, 0.4867, 0.4408]
_cifar100_std = [0.2675, 0.2565, 0.2761]

# Resize to 256, centre-crop to 224, convert to tensor, then normalise.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_cifar100_mean, std=_cifar100_std),
]
preprocess = transforms.Compose(_preprocess_steps)

# Both splits share the same root folder, download flag and preprocessing.
_cifar_kwargs = {"root": "./data", "download": True, "transform": preprocess}
train_dataset = CIFAR100(train=True, **_cifar_kwargs)
test_dataset = CIFAR100(train=False, **_cifar_kwargs)

# Batches of 32; the train split is shuffled, the test split is not.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False
)


# Function to extract features from a dataloader
def extract_features_and_labels(dataloader, model, device):
    """Run ``model`` over every batch of ``dataloader`` and collect the results.

    Gradients are disabled for the whole pass. Progress is printed every 10
    batches and once more after the final batch.

    Args:
        dataloader: Sized iterable yielding ``(images, labels)`` batches.
        model: Callable applied to each image batch once it is on ``device``.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        A ``(features, labels)`` tuple: the per-batch model outputs and the
        labels, each moved to the CPU and concatenated along dimension 0.
    """
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        num_batches = len(dataloader)
        for step, (batch_images, batch_labels) in enumerate(dataloader, start=1):
            batch_features = model(batch_images.to(device))
            feature_chunks.append(batch_features.cpu())
            label_chunks.append(batch_labels.cpu())

            is_last_batch = step == num_batches
            if step % 10 == 0 or is_last_batch:
                print(
                    f"Batch {step}/{num_batches} ({step / num_batches:.1%}) completed"
                )

    stacked_features = torch.cat(feature_chunks, dim=0)
    stacked_labels = torch.cat(label_chunks, dim=0)
    return stacked_features, stacked_labels


# Cached feature tensors live under ./features; make sure the folder is there.
os.makedirs("features", exist_ok=True)

# Train split: run the backbone only when no cached file is present.
if not os.path.exists("features/train_features.pt"):
    print("Computing train features...")
    train_features, train_labels = extract_features_and_labels(
        train_loader, dino_model, device
    )
    train_payload = {"features": train_features, "labels": train_labels}
    torch.save(train_payload, "features/train_features.pt")
    print("Train features saved.")
else:
    print("Train features already exist. Skipping computation.")

# Test split: same caching rule as above.
if not os.path.exists("features/test_features.pt"):
    print("Computing test features...")
    test_features, test_labels = extract_features_and_labels(
        test_loader, dino_model, device
    )
    test_payload = {"features": test_features, "labels": test_labels}
    torch.save(test_payload, "features/test_features.pt")
    print("Test features saved.")
else:
    print("Test features already exist. Skipping computation.")

## Custom Dino

In [None]:
from typing import Optional, cast
from torch import nn
import torch
from pathlib import Path


class CustomDino(nn.Module):
    """Backbone feature extractor followed by a classification head.

    Args:
        num_classes: Output size of the default linear head. Stored on the
            instance; it does not reshape a head passed via ``frozen_head``.
        backbone: Feature extractor. When ``None``, the pre-trained DINO
            ViT-S/16 is fetched from ``torch.hub``.
        frozen_head: Optional ready-made head, used as-is. This class does not
            touch its ``requires_grad`` flags; freezing is left to the caller.
    """

    # Feature width assumed when the backbone does not expose ``embed_dim``
    # (384 is the value the head was originally hard-coded to).
    DEFAULT_FEATURE_DIM = 384

    def __init__(
        self,
        num_classes: int = 100,
        backbone: Optional[nn.Module] = None,
        frozen_head: Optional[nn.Module] = None,
    ):
        super().__init__()
        if backbone is None:
            backbone = cast(
                nn.Module,
                torch.hub.load(
                    "facebookresearch/dino:main", "dino_vits16", pretrained=True
                ),
            )
        self.backbone: nn.Module = backbone
        self.num_classes = num_classes

        # Attach the head
        if frozen_head is not None:
            self.head = frozen_head
        else:
            # Size the head from the backbone when it advertises its width,
            # so other backbones work without editing this class.
            feature_dim = getattr(backbone, "embed_dim", self.DEFAULT_FEATURE_DIM)
            if not isinstance(feature_dim, int):
                feature_dim = self.DEFAULT_FEATURE_DIM
            self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output (logits) for the backbone features of ``x``."""
        features = self.backbone(x)

        # NOTE: features are deliberately NOT L2-normalised here. Enabling
        # torch.nn.functional.normalize(features, p=2, dim=1) would make a
        # head with unit-norm rows behave as a cosine-similarity classifier.

        logits = self.head(features)

        return logits

## Compute Centroids function


In [None]:
import torch.nn.functional as F


def compute_centroid(class_features, num_exemplars=None, rng_seed=42):
    """Return the unit-norm centroid of a class, optionally from a random subset.

    Every selected feature row is L2-normalised, the normalised rows are
    averaged, and the average is L2-normalised again (the three steps the
    original notes attribute to the iCaRL paper).

    Args:
        class_features: 2-D tensor with one feature vector per row.
        num_exemplars: Number of rows to sample without replacement. ``None``,
            or any value not smaller than the row count, uses every row.
        rng_seed: Seed of the generator that drives the random selection.

    Returns:
        1-D tensor holding the normalised centroid.
    """
    total = class_features.size(0)
    use_subset = num_exemplars is not None and num_exemplars < total

    if use_subset:
        # A freshly seeded generator makes the selection reproducible.
        generator = torch.Generator().manual_seed(rng_seed)
        picked = torch.randperm(total, generator=generator)[:num_exemplars]
        selected = class_features[picked]
    else:
        selected = class_features

    unit_rows = F.normalize(selected, p=2, dim=1)
    mean_direction = unit_rows.mean(dim=0)

    # Normalise as a single-row matrix, then drop the extra dimension again.
    return F.normalize(mean_direction[None, :], p=2, dim=1)[0]

In [None]:
%pip install wandb --quiet
import wandb

# Optional: log in right away if you have not done so already
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdaviderandino[0m ([33mdaviderandino-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
from torch.nn import Module


def save_checkpoint_to_wandb(
    run: wandb.Run,
    checkpoint: dict,
    filename: str = "model.pth",
    metadata: dict = None,
) -> None:
    """Write ``checkpoint`` to ``filename`` and log it to ``run`` as an artifact.

    Args:
        run: W&B run the artifact is logged to.
        checkpoint: Object serialised with ``torch.save``.
        filename: Local path the checkpoint is written to and added from.
        metadata: Metadata attached to the artifact; a falsy value becomes ``{}``.

    The artifact is named ``"<group>-checkpoints"`` when the run has a
    non-empty ``group`` and ``"checkpoints"`` otherwise; its type is ``"model"``.
    """
    # Namespace the artifact by the run's group when there is one.
    group = getattr(run, "group", None)
    artifact_name = f"{group}-checkpoints" if group else "checkpoints"

    # The file has to exist on disk before it can be added to the artifact.
    torch.save(checkpoint, filename)

    artifact_metadata = metadata if metadata else {}
    artifact = wandb.Artifact(
        name=artifact_name, type="model", metadata=artifact_metadata
    )
    artifact.add_file(filename)
    run.log_artifact(artifact)

    print(f"Model saved to WandB as artifact '{artifact_name}'.")


def load_checkpoint_from_wandb(
    run: wandb.Run, model: Module, filename: str = "model.pth", version: str = "latest"
) -> tuple[dict, wandb.Artifact] | None:
    """Download a checkpoint artifact and return its deserialised contents.

    The checkpoint is NOT loaded into ``model``; ``model`` is only inspected to
    decide which device the stored tensors are mapped to. Applying the weights
    (e.g. ``model.load_state_dict``) is left to the caller.

    Args:
        run: W&B run used to resolve and download the artifact.
        model: Module whose device is used as ``map_location``.
        filename: Name of the checkpoint file inside the artifact.
        version: Artifact version or alias to fetch.

    Returns:
        ``(checkpoint, artifact)`` on success, or ``None`` when anything goes
        wrong (a missing artifact included); the error is printed, not raised.
    """
    try:
        artifact_name = (
            f"{run.group}-checkpoints"
            if hasattr(run, "group") and run.group
            else "checkpoints"
        )
        artifact = run.use_artifact(f"{artifact_name}:{version}", type="model")
        artifact_dir = artifact.download()
        model_path = Path(artifact_dir) / filename
        print(f"Loading model from: {model_path}")

        # A plain nn.Module has no ``device`` attribute, so fall back to the
        # device of its first parameter. Without this, a checkpoint saved on
        # CUDA could not be opened on a CPU-only host.
        map_location = getattr(model, "device", None)
        if map_location is None:
            first_param = next(model.parameters(), None)
            if first_param is not None:
                map_location = first_param.device

        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use this with artifacts from a trusted W&B project.
        checkpoint = torch.load(
            model_path,
            map_location=map_location,
            weights_only=False,
        )
        print(f"Successfully loaded model from: {model_path}")
        return checkpoint, artifact
    except Exception as e:
        # Best-effort by design: callers treat None as "start from scratch".
        print(f"Model checkpoint not found on WandB. {e}")
        return None


def save_centroids_to_wandb(
    centroids: torch.Tensor,
    entity: str = "aml-fl-project",
    project: str = "fl-task-arithmetic",
    artifact_name: str = "nearest_centroid_classifier",
    filename: str = "nearest_centroid_classifier.pth",
    metadata: dict = None,
) -> None:
    """Save centroids as a linear-layer style state dict and log it to W&B.

    The file written to ``filename`` holds ``{"weight": centroids, "bias": 0}``
    so it can be loaded straight into an ``nn.Linear`` with one row per
    centroid. The artifact is logged to the currently active run
    (``wandb.run``); when there is none, only the local file is written and a
    warning is printed.

    Args:
        centroids: 2-D tensor with one centroid per row.
        entity: Unused; kept for interface compatibility. The artifact goes to
            whatever entity the active run belongs to.
        project: Unused; kept for interface compatibility (see ``entity``).
        artifact_name: Name of the W&B artifact.
        filename: Local path of the serialised state dict.
        metadata: Metadata attached to the artifact; a falsy value becomes ``{}``.
    """
    # Store CPU tensors so the file can be opened on hosts without a GPU, and
    # keep the bias dtype aligned with the weights.
    weight = centroids.detach().cpu()
    state_dict = {
        "weight": weight,
        "bias": torch.zeros(weight.size(0), dtype=weight.dtype),
    }
    torch.save(state_dict, filename)

    # Logging needs an active run; no separate API client is required.
    if wandb.run is None:
        print("Warning: No active wandb run. Artifact not logged.")
        print("Please log artifact manually or within a wandb.init() context.")
        return

    artifact = wandb.Artifact(
        name=artifact_name,
        type="model",
        metadata=metadata or {},
    )
    artifact.add_file(filename)
    wandb.run.log_artifact(artifact)
    print(f"Centroids saved to WandB as artifact '{artifact_name}'.")


def load_centroids_from_wandb(
    entity: str = "aml-fl-project",
    project: str = "fl-task-arithmetic",
    artifact_name: str = "nearest_centroid_classifier",
    version: str = "latest",
) -> torch.Tensor | None:
    """Download the centroid artifact from W&B and return its ``weight`` tensor.

    Args:
        entity: W&B entity that owns the artifact.
        project: W&B project that contains the artifact.
        artifact_name: Name of the artifact to fetch.
        version: Artifact version or alias.

    Returns:
        The tensor stored under ``"weight"``, mapped to the CPU, or ``None``
        when anything goes wrong (a missing artifact or file included); the
        error is printed, not raised.
    """
    try:
        api = wandb.Api()
        artifact_path = f"{entity}/{project}/{artifact_name}:{version}"
        print(f"Loading centroids from W&B: {artifact_path}")
        artifact = api.artifact(artifact_path)
        artifact_dir = artifact.download()

        # Sort so the chosen file does not depend on directory listing order.
        pth_files = sorted(Path(artifact_dir).glob("*.pth"))
        if not pth_files:
            raise FileNotFoundError(f"No .pth file found in artifact {artifact_path}")

        centroids_path = pth_files[0]
        print(f"Loading centroids from: {centroids_path}")

        # Map to CPU so a file saved from a GPU session also opens on a
        # CPU-only host; move the result to another device afterwards if needed.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use this with artifacts from a trusted W&B project.
        state_dict = torch.load(
            centroids_path, map_location="cpu", weights_only=False
        )
        centroids = state_dict["weight"]

        print(f"Successfully loaded centroids with shape: {centroids.shape}")
        return centroids
    except Exception as e:
        # Best-effort by design: callers treat None as "compute them locally".
        print(f"Centroids not found on WandB: {e}")
        return None

## Train

In [None]:
# --- Optimisation hyper-parameters ---
BATCH_SIZE = 64
EPOCHS = 30
PATIENCE = 5
LR = 1e-4  # 0.0001; an earlier revision used 10e-4 (0.001)
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# --- Hardware ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Weights & Biases run identification ---
ENTITY = "aml-fl-project"
PROJECT = "fl-task-arithmetic"
GROUP = "baseline_Test_TA_changes"
NAME = (
    "centralized-dino-icarl-cifar100"
    f"-lr{LR}"
    f"-mom{MOMENTUM}"
    f"-wd{WEIGHT_DECAY}"
)
RUN_ID = "Test-TA-help"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
from tqdm import tqdm
import wandb


# Per-channel (mean, std) used to normalise CIFAR-100 images.
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))


def _build_transform():
    """Return a new resize(256) -> centre-crop(224) -> tensor -> normalise pipeline."""
    channel_mean, channel_std = stats
    return Compose(
        [
            Resize(256),
            CenterCrop(224),  # Required for DINO
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ]
    )


# Train and test currently use the same deterministic pipeline; a
# RandomHorizontalFlip could be added to the train one as optional augmentation.
transform_train = _build_transform()
transform_test = _build_transform()

# Raw images (not the precomputed features) are loaded here because
# fine-tuning needs to backpropagate through the backbone.
train_dataset = datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR100(
    root="./data", train=False, download=True, transform=transform_test
)

# Only the training loader shuffles; everything else is shared.
_loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": 2}
trainloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
testloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


def get_or_compute_centroids(run: wandb.Run, device):
    """Return the class centroids on ``device``, loading or computing them.

    The centroids are first requested from W&B. When that fails they are
    computed from the cached features in ``features/train_features.pt`` (one
    centroid per class, using every sample of the class) and uploaded to W&B.

    Args:
        run: Currently unused; kept for interface compatibility.
        device: Device the returned tensor is moved to.

    Returns:
        Tensor with one centroid per row, on ``device`` for both the loaded
        and the computed path.
    """
    # Try loading from WandB first
    centroids = load_centroids_from_wandb()
    if centroids is not None:
        print("Loaded centroids from WandB.")
        return centroids.to(device)

    # Compute centroids if not found
    print("Centroids not found on WandB. Computing...")
    train_data = torch.load("features/train_features.pt")
    saved_features, saved_labels = train_data["features"], train_data["labels"]

    num_classes = 100
    feature_dim = saved_features.shape[1]
    centroids = torch.zeros((num_classes, feature_dim))

    for c in range(num_classes):
        class_features = saved_features[saved_labels == c]
        # A class without samples keeps its all-zero row.
        if len(class_features) > 0:
            centroids[c] = compute_centroid(class_features, num_exemplars=None)

    # Save to WandB for future use
    save_centroids_to_wandb(
        centroids,
        metadata={"num_classes": num_classes, "feature_dim": feature_dim},
    )

    print("Centroids calculated and saved to WandB.")
    # Match the loaded path: always hand back a tensor on the requested device.
    return centroids.to(device)


def train(lr, momentum, weight_decay, epochs):
    """Fine-tune the DINO backbone on CIFAR-100 against a frozen centroid head.

    The classification head is a linear layer whose weights are the class
    centroids and whose bias is zero. It is frozen, so only the backbone is
    optimised. After every epoch the model is evaluated on ``testloader`` and
    the metrics are logged to W&B. A checkpoint is uploaded whenever the test
    accuracy improves, and training stops early once more than ``PATIENCE``
    consecutive epochs bring no improvement.

    Args:
        lr: Initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        epochs: Total number of epochs; also used as the cosine-annealing period.

    Notes:
        When a checkpoint artifact exists, the model weights, the epoch counter
        and the best accuracy are restored from it. Optimizer and scheduler
        state are not part of the checkpoint, so both restart from their
        initial values on resume.
    """
    run = wandb.init(
        entity=ENTITY,
        project=PROJECT,
        group=GROUP,
        name=NAME,
        id=RUN_ID,
        resume="allow",
        mode="online",
    )

    # Get or compute centroids
    centroids = get_or_compute_centroids(run, DEVICE)
    num_classes, feature_dim = centroids.shape

    # Build the fixed linear layer from the centroids and freeze it.
    centroid_head = nn.Linear(feature_dim, num_classes)
    with torch.no_grad():
        centroid_head.weight.copy_(centroids)
        centroid_head.bias.zero_()
    for param in centroid_head.parameters():
        param.requires_grad = False

    # Initialize model with the frozen centroid head
    model = CustomDino(num_classes=num_classes, frozen_head=centroid_head).to(DEVICE)

    # Load checkpoint (if any)
    best_accuracy = 0.0
    patience_counter = 0
    start_epoch = 0
    checkpoint = load_checkpoint_from_wandb(run, model, "model.pth")
    if checkpoint is not None:
        checkpoint_dict, artifact = checkpoint
        model.load_state_dict(checkpoint_dict["model"])
        start_epoch = artifact.metadata["epoch"] + 1
        # Restore the best score too, otherwise the first resumed epoch would
        # always overwrite the stored checkpoint, even with a worse model.
        best_accuracy = float(artifact.metadata.get("accuracy") or 0.0)
        print(f"Resuming from epoch {start_epoch}")
    else:
        print("Starting from scratch")

    criterion = nn.CrossEntropyLoss()
    # The head is frozen, so only the backbone parameters are optimised.
    optimizer = optim.SGD(
        list(model.backbone.parameters()),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    # Anneal over the whole run; a shorter fixed period would make the
    # learning rate climb back up part-way through training.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(start_epoch, epochs):
        # --- TRAINING ---
        model.train()
        running_loss = 0.0

        pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()

            # The model returns logits (one score per class), not features.
            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})

        scheduler.step()
        avg_train_loss = running_loss / len(trainloader)

        # --- EVALUATION ---
        model.eval()
        test_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                # The frozen head scores each sample against every centroid.
                outputs = model(images)

                test_loss += criterion(outputs, labels).item()

                # The highest score (dot product with a centroid) is the prediction.
                predicted = outputs.argmax(dim=1)

                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        avg_test_loss = test_loss / len(testloader)
        acc = 100.0 * correct / total
        print(
            f"Epoch {epoch+1} Results: Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Test Acc: {acc:.2f}%"
        )

        # Update the best score before logging so the logged value includes
        # the current epoch.
        improved = acc > best_accuracy
        if improved:
            best_accuracy = acc

        wandb.log(
            {
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "test_loss": avg_test_loss,
                "test_accuracy": acc,
                "best_accuracy": best_accuracy,
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )

        # Checkpointing
        if improved:
            save_checkpoint_to_wandb(
                run,
                {
                    "model": model.state_dict(),
                },
                "model.pth",
                # NOTE(review): "task" stores the module object itself in the
                # artifact metadata; confirm this is intended rather than a
                # task name.
                {"task": model, "accuracy": acc, "epoch": epoch},
            )
            patience_counter = 0
            print(epoch, "Saved checkpoint model to WandB.")
        else:
            patience_counter += 1
            if patience_counter > PATIENCE:
                print("Early stopping triggered.")
                break


# Launch training with the hyper-parameters configured in the constants cell.
_train_hyperparams = {
    "lr": LR,
    "momentum": MOMENTUM,
    "weight_decay": WEIGHT_DECAY,
    "epochs": EPOCHS,
}
train(**_train_hyperparams)

Calculating Centroids for initialization...
Centroids calculated.


Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


artifact membership 'baseline-Davide-collab-test-checkpoints:latest' not found in 'aml-fl-project/fl-task-arithmetic'
Model checkpoint not found on WandB. artifact membership 'baseline-Davide-collab-test-checkpoints:latest' not found in 'aml-fl-project/fl-task-arithmetic'
Starting from scratch


Epoch 1/10: 100%|██████████| 782/782 [08:11<00:00,  1.59it/s, loss=1.19]


Epoch 1 Results: Train Loss: 1.7899 | Test Loss: 1.0647 | Test Acc: 68.96%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
0 Saved checkpoint model to WandB.


Epoch 2/10: 100%|██████████| 782/782 [08:10<00:00,  1.59it/s, loss=0.794]


Epoch 2 Results: Train Loss: 0.7885 | Test Loss: 0.8146 | Test Acc: 75.53%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
1 Saved checkpoint model to WandB.


Epoch 3/10: 100%|██████████| 782/782 [08:10<00:00,  1.59it/s, loss=0.764]


Epoch 3 Results: Train Loss: 0.5058 | Test Loss: 0.7781 | Test Acc: 77.26%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
2 Saved checkpoint model to WandB.


Epoch 4/10: 100%|██████████| 782/782 [08:10<00:00,  1.59it/s, loss=0.385]


Epoch 4 Results: Train Loss: 0.3082 | Test Loss: 0.6239 | Test Acc: 81.12%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
3 Saved checkpoint model to WandB.


Epoch 5/10: 100%|██████████| 782/782 [08:10<00:00,  1.59it/s, loss=0.234]


Epoch 5 Results: Train Loss: 0.1650 | Test Loss: 0.5716 | Test Acc: 83.31%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
4 Saved checkpoint model to WandB.


Epoch 6/10: 100%|██████████| 782/782 [08:11<00:00,  1.59it/s, loss=0.0493]


Epoch 6 Results: Train Loss: 0.0756 | Test Loss: 0.5314 | Test Acc: 85.26%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
5 Saved checkpoint model to WandB.


Epoch 7/10: 100%|██████████| 782/782 [08:10<00:00,  1.59it/s, loss=0.0167]


Epoch 7 Results: Train Loss: 0.0296 | Test Loss: 0.5356 | Test Acc: 85.42%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
6 Saved checkpoint model to WandB.


Epoch 8/10: 100%|██████████| 782/782 [08:11<00:00,  1.59it/s, loss=0.00712]


Epoch 8 Results: Train Loss: 0.0142 | Test Loss: 0.5436 | Test Acc: 85.78%
Model saved to WandB as artifact 'baseline-Davide-collab-test-checkpoints'.
7 Saved checkpoint model to WandB.


Epoch 9/10:  94%|█████████▍| 736/782 [07:42<00:28,  1.60it/s, loss=0.00573]