# Precompute features and associated labels for the CIFAR-100 train and test splits

In [10]:
import torch
import torch.nn as nn
from typing import cast
from torchvision import transforms
from torchvision.datasets import CIFAR100

# Load DINO ViT-S/16 pre-trained from torch.hub and switch it to eval mode,
# since it is only used for inference here.
dino_model = cast(
    nn.Module,
    torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True),
)
dino_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dino_model.to(device=device)

# Preprocessing applied to every CIFAR-100 image before it is fed to the model:
# resize, center-crop to 224x224, convert to a tensor and normalize.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # in federated training, we should consider to use mean and std of the cifar100
        # these are the parameters on which dino was trained
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

train_dataset = CIFAR100(root="./data", train=True, download=True, transform=preprocess)
test_dataset = CIFAR100(root="./data", train=False, download=True, transform=preprocess)

# Do not shuffle: the features are computed once and cached, so keeping the
# dataset order makes the saved tensors reproducible and lets row i of the
# cache correspond to sample i of the dataset. Any shuffling needed for
# training can be done when the cached features are loaded.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)


# Function to extract features from a dataloader
def extract_features_and_labels(dataloader, model, device):
    """Run ``model`` over every batch of ``dataloader`` and collect the results.

    Gradients are disabled for the whole pass. A progress line is printed
    every 10 batches and after the final batch.

    Args:
        dataloader: Sized iterable yielding ``(images, labels)`` pairs.
        model: Callable applied to each image batch after it is moved to
            ``device``. Its output is treated as the feature tensor
            (assumes the batch dimension comes first -- confirm for the
            model passed in).
        device: Device the image batches are moved to before the forward pass.

    Returns:
        A ``(features, labels)`` tuple of CPU tensors, each the concatenation
        along dimension 0 of the per-batch outputs / labels.
    """
    total_batches = len(dataloader)
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(dataloader):
            # Forward pass on the target device; results are kept on the CPU.
            batch_features = model(images.to(device))
            feature_chunks.append(batch_features.cpu())
            label_chunks.append(labels.cpu())

            reached_end = (batch_idx + 1) == total_batches
            if reached_end or (batch_idx + 1) % 10 == 0:
                print(
                    f"Batch {batch_idx + 1}/{total_batches} ({(batch_idx + 1) / total_batches:.1%}) completed"
                )

    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)


import os

# torch.save does not create missing parent directories. Create the output
# folder up front so a missing "features/" directory is not discovered only
# after the (long) feature extraction has already finished.
os.makedirs("features", exist_ok=True)

# Extract features and labels for train
train_features, train_labels = extract_features_and_labels(
    train_loader, dino_model, device
)
torch.save(
    {"features": train_features, "labels": train_labels},
    "features/train_features.pt",
)

# Extract features and labels for test
test_features, test_labels = extract_features_and_labels(
    test_loader, dino_model, device
)
torch.save(
    {"features": test_features, "labels": test_labels}, "features/test_features.pt"
)

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Batch 10/1563 (0.6%) completed
Batch 20/1563 (1.3%) completed
Batch 30/1563 (1.9%) completed
Batch 40/1563 (2.6%) completed
Batch 50/1563 (3.2%) completed
Batch 60/1563 (3.8%) completed
Batch 70/1563 (4.5%) completed
Batch 80/1563 (5.1%) completed
Batch 90/1563 (5.8%) completed
Batch 100/1563 (6.4%) completed
Batch 110/1563 (7.0%) completed
Batch 120/1563 (7.7%) completed
Batch 130/1563 (8.3%) completed
Batch 140/1563 (9.0%) completed
Batch 150/1563 (9.6%) completed
Batch 160/1563 (10.2%) completed
Batch 170/1563 (10.9%) completed
Batch 180/1563 (11.5%) completed
Batch 190/1563 (12.2%) completed
Batch 200/1563 (12.8%) completed
Batch 210/1563 (13.4%) completed
Batch 220/1563 (14.1%) completed
Batch 230/1563 (14.7%) completed
Batch 240/1563 (15.4%) completed
Batch 250/1563 (16.0%) completed
Batch 260/1563 (16.6%) completed
Batch 270/1563 (17.3%) completed
Batch 280/1563 (17.9%) completed
Batch 290/1563 (18.6%) completed
Batch 300/1563 (19.2%) completed
Batch 310/1563 (19.8%) completed
B

In [None]:
from typing import cast
import torch
import torch.nn as nn
from torchvision import transforms
from fl_task_arithmetic.task import CustomDino
from torchvision.datasets import CIFAR100

# Hyperparameters for training the classification head
num_epochs, batch_size, learning_rate = 10, 64, 1e-3

# Early-stopping bookkeeping; these values are updated by the training loop
best_acc, patience_counter, patience = 0, 0, 0
best_model_state = dict()

# Build the model from the pretrained DINO ViT-S/16 backbone (100 classes)
dino_pretrained = cast(
    nn.Module,
    torch.hub.load("facebookresearch/dino:main", "dino_vits16", pretrained=True),
)
model = CustomDino(num_classes=100, backbone=dino_pretrained)
print(model)

# Image preprocessing pipeline: resize, center-crop, tensor conversion, normalization
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # in federated training, we should consider to use mean and std of the cifar100
    # these are the parameters on which dino was trained
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(_preprocess_steps)

# Download (if needed) and load the CIFAR-100 train split, then the test split
train_dataset, test_dataset = (
    CIFAR100(root="./data", train=is_train, download=True, transform=preprocess)
    for is_train in (True, False)
)

# Only the classifier head is trainable; the backbone stays frozen
for param in model.backbone.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True

# Batched access to the data: the training split is shuffled, the test split is not
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

# Loss and optimizer; the optimizer only sees the head's parameters
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=learning_rate)

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# Train for up to `num_epochs` epochs. After every epoch the model is evaluated
# on the test split; training stops early once test accuracy has failed to
# improve for `patience` consecutive epochs, and the best weights seen so far
# are restored.
# NOTE(review): the test split is used for model selection here; consider a
# separate validation split if an unbiased test score is needed.
# NOTE(review): if all epochs run without early stopping being triggered, the
# final-epoch weights are kept, not the best ones.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(train_loader):
        print(f"Epoch {epoch+1}, Iteration {batch_idx+1}")
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch loss is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(
        f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Acc: {epoch_acc:.4f}"
    )

    # Evaluation on test set
    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()
    test_loss /= test_total
    test_acc = test_correct / test_total
    print(f"Test Loss: {test_loss:.4f} - Test Acc: {test_acc:.4f}")

    # Early stopping
    # state_dict() returns references to the live parameter tensors, which the
    # optimizer keeps updating in place. Clone every tensor so the snapshot
    # really holds the weights of the best epoch.
    if epoch == 0:
        best_acc = test_acc
        patience = 3
        patience_counter = 0
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        if test_acc > best_acc:
            best_acc = test_acc
            patience_counter = 0
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                model.load_state_dict(best_model_state)
                break

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


CustomDino(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): Layer

100%|██████████| 169M/169M [00:09<00:00, 18.0MB/s] 
  tar.extractall(to_path)


KeyboardInterrupt: 