# Precompute DINO features and their labels for the CIFAR-100 train and test splits

In [10]:
import os
from typing import cast

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR100

# Load the DINO ViT-S/16 backbone with pre-trained weights from torch.hub.
dino_model = cast(
    nn.Module,
    torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True),
)
# The backbone is only used for inference in this cell, so switch it to eval mode once.
dino_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dino_model.to(device=device)

# Preprocessing applied to every CIFAR-100 image before it reaches the backbone:
# resize to 256, center-crop to 224x224, convert to a tensor and normalize.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet mean/std: the statistics DINO was trained with.
        # In federated training, consider using the CIFAR-100 mean and std instead.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

train_dataset = CIFAR100(root="./data", train=True, download=True, transform=preprocess)
test_dataset = CIFAR100(root="./data", train=False, download=True, transform=preprocess)

# No shuffling: feature extraction does not depend on sample order, and a fixed order
# makes the cached feature files reproducible across runs (no random seed is set).
# Training on the cached features shuffles them later anyway.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)


# Function to extract features from a dataloader
def extract_features_and_labels(dataloader, model, device, log_every: int = 10):
    """Run ``model`` over every batch of ``dataloader`` and collect the outputs on the CPU.

    Gradients are disabled. The model's train/eval mode is NOT changed here, so the
    caller should put it in eval mode beforehand.

    Args:
        dataloader: Sized iterable yielding ``(images, labels)`` pairs of tensors.
        model: Callable mapping a batch of images (already moved to ``device``)
            to a batch of features.
        device: Device the images are moved to before calling ``model``.
        log_every: Print a progress line every ``log_every`` batches. The final
            batch is always reported. A value ``<= 0`` disables periodic lines.

    Returns:
        ``(features, labels)``: the per-batch model outputs and the labels,
        each concatenated along dim 0 and stored on the CPU.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    feature_chunks = []
    label_chunks = []
    total_batches = len(dataloader)
    with torch.no_grad():
        for step, (images, labels) in enumerate(dataloader, start=1):
            images = images.to(device)
            # Expected to return backbone features (no classification head);
            # this depends on the model passed in.
            features = model(images)
            # Move outputs off the accelerator right away so device memory
            # does not grow with the number of batches.
            feature_chunks.append(features.cpu())
            label_chunks.append(labels.cpu())
            periodic = log_every > 0 and step % log_every == 0
            if periodic or step == total_batches:
                print(
                    f"Batch {step}/{total_batches} ({step / total_batches:.1%}) completed"
                )
    if not feature_chunks:
        # torch.cat would otherwise fail with an opaque "non-empty list" error.
        raise RuntimeError("dataloader yielded no batches; nothing to extract")
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)


# torch.save does not create missing parent directories, so make sure the
# cache directory exists before writing into it.
os.makedirs("features", exist_ok=True)

# Extract features and labels for the train split and cache them to disk.
train_features, train_labels = extract_features_and_labels(
    train_loader, dino_model, device
)
torch.save(
    {"features": train_features, "labels": train_labels},
    "features/train_features.pt",
)

# Extract features and labels for the test split and cache them to disk.
test_features, test_labels = extract_features_and_labels(
    test_loader, dino_model, device
)
torch.save(
    {"features": test_features, "labels": test_labels}, "features/test_features.pt"
)

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Batch 10/1563 (0.6%) completed
Batch 20/1563 (1.3%) completed
Batch 30/1563 (1.9%) completed
Batch 40/1563 (2.6%) completed
Batch 50/1563 (3.2%) completed
Batch 60/1563 (3.8%) completed
Batch 70/1563 (4.5%) completed
Batch 80/1563 (5.1%) completed
Batch 90/1563 (5.8%) completed
Batch 100/1563 (6.4%) completed
Batch 110/1563 (7.0%) completed
Batch 120/1563 (7.7%) completed
Batch 130/1563 (8.3%) completed
Batch 140/1563 (9.0%) completed
Batch 150/1563 (9.6%) completed
Batch 160/1563 (10.2%) completed
Batch 170/1563 (10.9%) completed
Batch 180/1563 (11.5%) completed
Batch 190/1563 (12.2%) completed
Batch 200/1563 (12.8%) completed
Batch 210/1563 (13.4%) completed
Batch 220/1563 (14.1%) completed
Batch 230/1563 (14.7%) completed
Batch 240/1563 (15.4%) completed
Batch 250/1563 (16.0%) completed
Batch 260/1563 (16.6%) completed
Batch 270/1563 (17.3%) completed
Batch 280/1563 (17.9%) completed
Batch 290/1563 (18.6%) completed
Batch 300/1563 (19.2%) completed
Batch 310/1563 (19.8%) completed
B

In [10]:
from typing import Optional, cast
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm.notebook import tqdm

# Hyper-parameters for training the linear head on cached features.
num_epochs, batch_size, learning_rate = 100, 64, 1e-3

# Early-stopping state. These are placeholders: the training loop re-initializes
# them on its first epoch, which is also where the effective patience is set.
best_acc, patience_counter, patience = 0, 0, 0
best_model_state = {}

# DINO ViT-S/16 with pre-trained weights, fetched through torch.hub.
dino_pretrained = cast(
    nn.Module,
    torch.hub.load(
        "facebookresearch/dino:main",
        "dino_vits16",
        pretrained=True,
    ),
)


class CustomDino(nn.Module):
    """A feature backbone followed by a linear classification head.

    Args:
        num_classes: Number of outputs of the linear head.
        backbone: Module mapping an input batch to ``[batch, feature_dim]``
            embeddings. When ``None``, a DINO ViT-S/16 is loaded from torch.hub
            with ``pretrained=False`` (i.e. without pre-trained weights).
        feature_dim: Width of the backbone's output embedding. The default of 384
            is the CLS-token size of DINO ViT-S/16.
    """

    def __init__(
        self,
        num_classes: int = 100,
        backbone: Optional[nn.Module] = None,
        feature_dim: int = 384,
    ):
        super().__init__()
        if backbone is None:
            # Load DINO ViT-S/16 without pre-trained weights.
            # NOTE(review): no head is stripped here; this presumably relies on the
            # hub model's forward already returning the CLS embedding. Confirm this.
            backbone = cast(
                nn.Module,
                torch.hub.load(
                    "facebookresearch/dino:main", "dino_vits16", pretrained=False
                ),
            )
        self.backbone: nn.Module = backbone
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``[batch, num_classes]`` for the batch ``x``."""
        features = self.backbone(x)  # [batch, feature_dim]
        return self.classifier(features)


# Wrap the pre-trained backbone with a freshly initialized linear head.
model = CustomDino(num_classes=100, backbone=dino_pretrained)

# Image preprocessing for running the full model on raw images. It is not used
# below, because training consumes the precomputed features.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet mean/std: the statistics DINO was trained with.
        # In federated training, consider using the CIFAR-100 mean and std instead.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load the cached features and labels.
# map_location="cpu" makes loading independent of the device used when saving.
# weights_only=True limits unpickling to tensors and plain containers.
train_data = torch.load(
    "features/train_features.pt", map_location="cpu", weights_only=True
)
test_data = torch.load(
    "features/test_features.pt", map_location="cpu", weights_only=True
)

train_features, train_labels = train_data["features"], train_data["labels"]
test_features, test_labels = test_data["features"], test_data["labels"]

# Datasets over the cached feature tensors.
train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
test_dataset = torch.utils.data.TensorDataset(test_features, test_labels)

# DataLoaders: shuffle only the training features.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

# Only the head's parameters are optimized; the backbone is not trained.
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------------
# Train only the linear head, starting from the precomputed features
# -----------------------------------

# Keep a handle on the full backbone + head model. `model` is rebound to the head
# alone, because the training loop feeds it features rather than images.
complete_model = model
model = model.classifier
model.to(device=device)

# Each epoch makes one training pass over the cached train features, then one full
# evaluation pass over the test features, then runs an early-stopping check on
# test accuracy.
# NOTE(review): model selection uses the test split, so the best test accuracy is
# optimistically biased. A held-out validation split would avoid this.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Epoch-level progress bar, updated once per batch with running metrics.
    epoch_desc = f"Epoch {epoch+1} Training"
    with tqdm(total=len(train_loader), desc=epoch_desc, leave=True) as progress:
        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the loss by batch size so the running mean stays exact when
            # the last batch is smaller.
            running_loss += loss.item() * features.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            batch_loss = running_loss / total if total > 0 else 0
            batch_acc = correct / total if total > 0 else 0
            progress.set_postfix(
                {"loss": f"{batch_loss:.4f}", "acc": f"{batch_acc:.4f}"}
            )
            progress.update(1)

    epoch_loss = running_loss / total if total > 0 else 0
    epoch_acc = correct / total if total > 0 else 0

    # Evaluation on the test set.
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    test_correct = 0
    test_total = 0

    eval_desc = f"Epoch {epoch+1} Evaluation"
    with torch.no_grad():
        with tqdm(total=len(test_loader), desc=eval_desc, leave=True) as test_progress:
            for features, labels in test_loader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * features.size(0)
                _, predicted = outputs.max(1)
                test_total += labels.size(0)
                test_correct += predicted.eq(labels).sum().item()
                test_progress.update(1)

            # Turn the accumulated sums into per-sample means.
            test_loss = test_loss / test_total if test_total > 0 else 0
            test_acc = test_correct / test_total if test_total > 0 else 0

            # test_loss is already a mean here, so it is not divided again. The bar
            # is not advanced further: it is complete after the loop above.
            test_progress.set_postfix(
                {"loss": f"{test_loss:.4f}", "acc": f"{test_acc:.4f}"}
            )

    # Early stopping. The first epoch (re)initializes the bookkeeping and sets a
    # patience of 3 epochs without improvement.
    if epoch == 0:
        patience = 3
    if epoch == 0 or test_acc > best_acc:
        best_acc = test_acc
        patience_counter = 0
        # state_dict() returns references to the live parameters, so clone them.
        # Otherwise later optimizer steps would silently overwrite the "best" snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

# Restore the best weights, whether training stopped early or ran every epoch.
if best_model_state:
    model.load_state_dict(best_model_state)

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Epoch 1 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 6 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 6 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 7 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 7 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 8 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 8 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 9 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 9 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 10 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 10 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 11 Training:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 11 Evaluation:   0%|          | 0/157 [00:00<?, ?it/s]

Early stopping triggered.
