In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, random_split

from FederatedLearningProject.data.cifar100_loader import get_cifar100
import FederatedLearningProject.checkpoints.checkpointing as checkpointing
from FederatedLearningProject.training.centralized_training import train_and_validate

In [None]:
import numpy as np
import wandb
from google.colab import drive
# Mount Google Drive under /content/drive (force_remount=True is passed explicitly)
drive.mount('/content/drive', force_remount=True)
import shutil
import os                              # Used right below to change the working directory
os.chdir('/content/drive/MyDrive/FL')  # Make the FL folder on the mounted Drive the working directory

In [None]:
wandb.login() # Log in to Weights & Biases; asks for your API key.

In [None]:
# Load the CIFAR-100 splits: train_set, val_set, test_set.
# The transforms are applied inside the project module before the datasets are returned.

valid_split_perc = 0.2    # fraction of the 50000 training images held out for validation
train_set, val_set, test_set = get_cifar100(valid_split_perc)

In [None]:
# Build the DataLoaders for the training, validation and test splits.

# batch_size is a hyperparameter (64, 128, ...); so is num_workers (2 or 4 suggested on Colab).
loader_kwargs = dict(batch_size=64, num_workers=2)

# Only the training split is shuffled; validation and test keep their order.
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)


### Possible Models
|                        | **Simple linear head**                               | **MLP head w/ Dropout**                                                      |
| :--------------------- | :--------------------------------------------------- | :--------------------------------------------------------------------------- |
| **Definition**         | `nn.Linear(384 → 100)`                               | `Dropout → Linear(384 → 256) → ReLU → Dropout → Linear(256 → 100)`           |
| **# trainable params** | 384×100 + 100 = **38 500**                           | 384×256+256 + 256×100+100 ≈ **124 000**                                      |
| **Regularization**     | none                                                 | dropout on both layers                                                       |
| **Expressive power**   | low  – just a single hyperplane on the CLS embedding | higher – small nonlinear bottleneck can learn more complex features in heads |
| **Compute / memory**   | minimal                                              | \~3× more weights, a bit more forward/backward cost                          |

---

**Appunto sui layer di testa:**

1. **`self.classifier`**

   * **Cosa contiene?** Un singolo `nn.Linear(embed_dim → num_classes)`.
   * **Quando usarlo?** Se vuoi un *linear probe* puro: un solo layer che prende il CLS token e mappa direttamente alle classi.
   * **Pro:** estremamente leggero (∼38 K parametri), veloce da addestrare e da inferire.
   * **Contro:** capacità espressiva minima (è solo un’iper‐superficie lineare sullo spazio degli embedding).

2. **`self.head`**

   * **Cosa contiene?** Una piccola sequenza (`nn.Sequential`) di layer:

     * Dropout
     * Linear (embed\_dim → hidden\_dim)
     * ReLU
     * Dropout
     * Linear (hidden\_dim → num\_classes)
   * **Quando usarlo?** Se vuoi dare al tuo “probe” un po’ più di capacità espressiva, trasformando non-linearmente il CLS prima della classificazione.
   * **Pro:** maggiore capacità di apprendere rappresentazioni complesse nella testa, un minimo di regolarizzazione via dropout.
   * **Contro:** più pesante (∼3× parametri in più rispetto al solo `classifier`), leggermente più lento da addestrare e inferire.

---

### Perché una piuttosto che l’altra?

* **Vincoli di risorse** (GPU/RAM, tempo d’addestramento):

  * Se sei sotto forte pressione computazionale o vuoi risultati rapidi, opti per `self.classifier`.
* **Prestazioni** (accuratezza su dataset piccolo/mediamente grande come CIFAR-100):

  * Se noti che il linear probe raggiunge un plateau basso, un piccolo MLP (`self.head`) può guadagnare qualche punto percentuale in più.
* **Semplicità vs flessibilità**:

  * Con una sola `classifier` hai un codice più pulito e diretto.
  * Con `head` puoi sperimentare — cambiare `hidden_dim`, aggiungere altro dropout, batchnorm o ulteriori layer.

In definitiva, **il nome** (`classifier` vs `head`) è arbitrario: serve a rendere più chiaro nel codice di che “peso” stiamo parlando. Se hai un solo layer, chiamalo `classifier`; se invece è un blocco più articolato, chiamalo `head` o `projection_head`, per tener separata la parte “feature extractor” (backbone) dalla parte “feature consumer” (testa di classificazione).


In [None]:
# # --- MLP head w/ Dropout ---
# # Load DINO ViT-S/16 backbone and freeze it
# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
# backbone.eval()

# for p in backbone.parameters():
#     p.requires_grad = False

# # Define the classifier head with optional dropout/MLP
# class DinoClassifier(nn.Module):
#     def __init__(self, backbone, num_classes=100, hidden_dim=256, drop=0.5):
#         super().__init__()
#         self.backbone = backbone
#         embed_dim = backbone.embed_dim  # 384 for ViT-S/16
#         self.head = nn.Sequential(
#             nn.Dropout(drop),
#             nn.Linear(embed_dim, hidden_dim),
#             nn.ReLU(inplace=True),
#             nn.Dropout(drop),
#             nn.Linear(hidden_dim, num_classes)
#         )

#     def forward(self, x):
#         # get CLS token from the frozen backbone
#         with torch.no_grad():
#             cls = self.backbone(x)   # DINO's forward() returns the CLS embedding -> (B, embed_dim)
#         return self.head(cls)

# model = DinoClassifier(backbone, num_classes=100).to(device)

# # ensure backbone stays in eval, head in train
# model.backbone.eval()
# model.head.train()

In [None]:
# --- Simple linear head ---
# Fetch the DINO ViT-S/16 backbone through torch.hub, switch it to eval mode
# and turn off gradient tracking for every one of its parameters (frozen backbone).
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
backbone.eval()
backbone.requires_grad_(False)

# Define the classifier: frozen backbone + trainable linear head
class DinoClassifier(nn.Module):
    """Linear probe on top of a frozen feature-extractor backbone.

    Args:
        backbone: module exposing an ``embed_dim`` attribute and whose
            ``forward(x)`` returns one embedding per sample, shape
            ``(B, embed_dim)``. The DINO hub ViT returns the CLS token
            embedding from ``forward``.
        num_classes: number of output classes of the linear head.
    """

    def __init__(self, backbone, num_classes=100):
        super().__init__()
        self.backbone = backbone
        embed_dim = backbone.embed_dim  # 384 for ViT-S/16
        self.classifier = nn.Linear(embed_dim, num_classes)

    def train(self, mode=True):
        """Set the head's train/eval mode while keeping the backbone in eval mode.

        Training loops usually call ``model.train()``, which would otherwise
        also switch the frozen backbone back to training mode.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for a batch ``x``."""
        # Extract the embedding with the frozen backbone. The DINO hub model has
        # no `forward_features` method (that is a timm API); calling the backbone
        # itself yields the CLS embedding. no_grad avoids building a graph for it.
        with torch.no_grad():
            cls = self.backbone(x)  # -> (B, embed_dim)
        return self.classifier(cls)

# Device selection: prefer the GPU when PyTorch can see one, otherwise use the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the linear-probe model and move it to the selected device
model = DinoClassifier(backbone, num_classes=100).to(device)

In [None]:
# Show which device was selected above
print(device)

In [None]:
# Show the module structure of the model (backbone + classifier)
print(model)

In [None]:
# --- OPTIMIZER AND LOSS FUNCTION ---
# optimizer = optim.Adam(model.parameters(), lr=1e-4)
optimizer = optim.SGD(model.parameters(), lr=1e-4)  # momentum=0.9, weight_decay=5e-4 -> recommended optimizer settings

# Cross-entropy loss for the 100-class classification task
criterion = nn.CrossEntropyLoss()

In [None]:
# wandb.init() prepares the tracking of hyperparameters/metrics for later recording performance using wandb.log()

model_name = "dino_vits16"
project_name = "FederatedProject"
run_name = f"{model_name}_run"

# INITIALIZE W&B
wandb.init(
    project=project_name,
    name=run_name,
    config={
        "model": model_name,
        # Must match the num_epochs value passed to train_and_validate (20),
        # otherwise the logged config misreports the run.
        "epochs": 20,
        "batch_size": train_loader.batch_size,
        "learning_rate": optimizer.param_groups[0]['lr'],
        "architecture": model.__class__.__name__,
})

# Keep a handle on the run config so later cells can read the logged values
config = wandb.config


In [None]:
#  CHECKPOINT PATH
checkpoint_dir = "/content/drive/MyDrive/FL/FederatedLearningProject/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)  # create the folder if it does not exist yet
checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_checkpoint.pth")    # we predefine the name of the file inside the specified folder (dir)

In [None]:
# RECOVER CHECKPOINT
# load_checkpoint returns the epoch to resume from and the checkpoint payload.
start_epoch, model_data = checkpointing.load_checkpoint(model, optimizer, checkpoint_dir)

print()
if not model_data:
    # Nothing to restore (no payload returned): keep the freshly initialised state.
    print("No checkpoint data available: model and optimizer keep their current state.")
else:
    try:
        print(f"The 'model_data' dictionary contains the following keys: {list(model_data.keys())}")
        model.load_state_dict(model_data["model_state_dict"])
        optimizer.load_state_dict(model_data["optimizer_state_dict"])
    except (AttributeError, KeyError, TypeError, RuntimeError, ValueError) as err:
        # Still best-effort (the notebook keeps running), but the failure is reported
        # instead of silently swallowed: resuming at start_epoch with weights that
        # were not restored would otherwise go unnoticed.
        print(f"WARNING: could not restore model/optimizer state from the checkpoint: {err!r}")



In [None]:
# --- TRAINING LOOP ---
# Run the project's centralized training/validation routine, resuming from
# start_epoch and writing checkpoints to checkpoint_path.
train_and_validate(
    start_epoch,
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    checkpoint_path,
    num_epochs=20,
    checkpoint_interval=10,
)

In [None]:
## Display some information ##

def _describe_loader(title, loader):
    """Print the batch count, batch size and first-batch tensor shapes of a DataLoader."""
    print(title)
    print(f"  Number of batches: {len(loader)}")
    print(f"  Batch size: {loader.batch_size}")
    # Only the first batch is needed to report the tensor dimensions;
    # an empty loader simply skips the shape lines.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        images, labels = first_batch
        print(f"  Dimension of 1 batch (images): {images.shape}")
        print(f"  Dimension of 1 batch (labels): {labels.shape}")
    print()


print("Model:", model_name)
print("Train set size:", len(train_set))
print("Validation set size:", len(val_set))
print("Batch size:", train_loader.batch_size)
print("Number of epochs:", config.epochs)
print("DataLoader: ")
print("Learning rate:", optimizer.param_groups[0]['lr'])
print("Architecture:", model.__class__.__name__)
print("Device:", device)
print("Optimizer:", optimizer)
print("Loss function:", criterion)
print("Checkpoint directory:", checkpoint_dir)
print("Checkpoint path:", checkpoint_path)
# `epoch` is never defined in this notebook; the only epoch counter available
# here is the one returned by the checkpoint loader.
print("Start epoch:", start_epoch)
print()

_describe_loader("Train Loader Information:", train_loader)
_describe_loader("\nValidation Loader Information:", val_loader)

# Check for CUDA availability
print("CUDA AVAIABILITY:")
if torch.cuda.is_available():
    print("CUDA is available. Using GPU.")
    print("Number of GPUs:", torch.cuda.device_count())
    print("Current GPU:", torch.cuda.current_device())
    print("GPU Name:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA is not available. Using CPU.")

# Print model architecture summary
print("\nMODEL ARCHITECTURE:")
print(model)
