In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
import numpy as np
from copy import deepcopy
from typing import Optional, cast
from tqdm import tqdm
import random
import os
import wandb

from utilities.wandb_utils import save_checkpoint_to_wandb

# ==========================================
# 0. Configuration & Paths
# ==========================================
# Generic training hyper-parameters.
# NOTE(review): iCaRL.__init__ hard-codes its own lr / momentum / weight_decay / epochs
# instead of reading LR, MOMENTUM, WEIGHT_DECAY and EPOCHS below.
BATCH_SIZE = 64
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 50
# First CUDA device when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Total exemplar vectors to store across all classes; main() uses
# TOTAL_EXEMPLARS_VECTORS // 100 training images per CIFAR-100 class.
TOTAL_EXEMPLARS_VECTORS = 1000
# Local file for the centroid classifier weights; also passed as `filename`
# to the W&B save/load helpers.
FILENAME = "centroids_linear.pth"

# Weights & Biases run settings, passed to wandb.init in main()
# (the evaluation cell reuses all of them except WANDB_NOTES).
WANDB_ENTITY = "aml-fl-project"
WANDB_PROJECT = "fl-task-arithmetic"
WANDB_GROUP = "centroids"
WANDB_RUN_NAME = "icarl-centroids"
WANDB_NOTES = "Centroids computed over 100 CIFAR100 classes"
WANDB_MODE = "online"  # forwarded as wandb.init(mode=...)


# ==========================================
# 1. Model Definition (CustomDino)
# ==========================================
class CustomDino(nn.Module):
    """Backbone feature extractor followed by a linear classification head.

    ``forward`` returns both the logits and the backbone features, so the same
    model serves classification and feature-based methods (NME, centroids).

    Args:
        num_classes: number of output logits of the linear head.
        backbone: optional feature extractor. When ``None``, the pretrained
            ``dino_vits16`` model is fetched with ``torch.hub`` (needs network
            access or an already populated hub cache).
        feature_dim: width of the backbone output fed into the head. The
            default 384 matches ``dino_vits16``; pass a different value when
            supplying a backbone with another embedding size.
    """

    def __init__(
        self,
        num_classes: int = 100,
        backbone: Optional[nn.Module] = None,
        feature_dim: int = 384,
    ):
        super().__init__()
        if backbone is None:
            backbone = cast(
                nn.Module,
                torch.hub.load(
                    "facebookresearch/dino:main", "dino_vits16", pretrained=True
                ),
            )
        self.backbone: nn.Module = backbone
        self.num_classes = num_classes
        # Plain attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.feature_dim = feature_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, features)`` for the input batch ``x``."""
        features = self.backbone(x)  # assumed [batch, feature_dim]
        logits = self.classifier(features)
        return logits, features


# ==========================================
# 2. iCaRL Logic Class
# ==========================================
class iCaRL:
    """Minimal iCaRL learner (Rebuffi et al., 2017) built on ``CustomDino``.

    State:
        model: network being trained.
        old_model: frozen copy of ``model`` taken at the end of the previous
            ``update_representation`` call; used as the distillation teacher.
        exemplar_sets: ``exemplar_sets[y]`` is the list of stored image tensors
            of class ``y`` (list position == class id).
    """

    def __init__(
        self,
        num_classes=100,
        memory_size=TOTAL_EXEMPLARS_VECTORS,
        feature_dim=384,
        device="cuda",
    ):
        self.device = device
        self.num_classes = num_classes
        self.memory_size = memory_size  # Total exemplar budget across all classes
        self.feature_dim = feature_dim

        # Initialize Model
        self.model = CustomDino(num_classes=num_classes).to(self.device)
        self.old_model = None  # Snapshot of model before current task

        # Memory (Exemplars)
        self.exemplar_sets = []  # List of lists (images per class)
        self.exemplar_means = []  # Class prototypes for NME (not filled in by this class)

        # Training Parameters
        self.lr = 0.01
        self.weight_decay = 1e-5
        self.momentum = 0.9
        self.epochs = 20  # Reduced for demo speed (standard is often higher)

    @torch.no_grad()
    def _normalized_features(self, images, batch_size=128):
        """Embed a list of image tensors in mini-batches; return L2-normalised features [N, D]."""
        chunks = []
        for start in range(0, len(images), batch_size):
            batch = torch.stack(list(images[start : start + batch_size])).to(self.device)
            _, feats = self.model(batch)
            chunks.append(F.normalize(feats, p=2, dim=1))
        return torch.cat(chunks, dim=0)

    def update_representation(self, train_loader, new_classes):
        """Step 1: train the model with classification + sigmoid distillation loss.

        Args:
            train_loader: iterable of ``(images, labels)`` batches. Rehearsal is
                expected to be handled by the loader (e.g. one built on
                ``iCaRLDataset``); exemplars are not mixed in here.
            new_classes: class ids introduced by this task. Logit columns below
                ``min(new_classes)`` are treated as old classes and distilled
                from ``self.old_model``.

        Side effects: trains ``self.model`` in place, then replaces
        ``self.old_model`` with a frozen copy of it.
        """
        print(f"--- Updating Representation for classes {new_classes} ---")

        optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        # Scheduler helps convergence
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)

        # Columns [0, start_new_task) belong to classes seen in earlier tasks
        # (the classifier has a fixed width, so old and new classes share it).
        start_new_task = min(new_classes) if len(new_classes) > 0 else 0
        distill = self.old_model is not None and start_new_task > 0

        self.model.train()
        if self.old_model is not None:
            self.old_model.eval()

        for epoch in range(self.epochs):
            for images, labels in tqdm(
                train_loader, desc=f"Epoch {epoch+1}/{self.epochs}", leave=False
            ):
                images = images.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                logits, _ = self.model(images)

                # A. Classification loss (cross-entropy over all output classes)
                loss = F.cross_entropy(logits, labels)

                # B. Sigmoid distillation on the old classes only (T=1, as in iCaRL)
                if distill:
                    with torch.no_grad():
                        old_logits, _ = self.old_model(images)
                    # BCE-with-logits is the numerically stable form of
                    # BCE(sigmoid(new_logits), sigmoid(old_logits)).
                    loss = loss + F.binary_cross_entropy_with_logits(
                        logits[:, :start_new_task],
                        torch.sigmoid(old_logits[:, :start_new_task]),
                    )

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Freeze a snapshot of the updated model as the next task's teacher.
        self.old_model = deepcopy(self.model)
        self.old_model.eval()
        for p in self.old_model.parameters():
            p.requires_grad = False

    def reduce_exemplar_sets(self, m):
        """
        Step 2: Shrink stored exemplars to fit memory budget.

        Keeps the first ``m`` exemplars of every class (herding stores them in
        priority order). ``m`` may be a float such as
        ``memory_size / num_classes_seen_so_far``; it is truncated to an int.
        """
        m = max(int(m), 0)
        print(f"Reducing exemplars to {m} per class...")
        for y, exemplars in enumerate(self.exemplar_sets):
            self.exemplar_sets[y] = exemplars[:m]

    def construct_exemplar_sets(self, images, m, transform, class_id):
        """
        Step 3: select up to ``m`` exemplars of one class by herding.

        At step k = 1..m, picks the not-yet-selected sample whose feature brings
        the average of the selected (L2-normalised) features closest to the
        class mean. The resulting list is appended to ``exemplar_sets``, so
        classes must be processed in class-id order.

        Args:
            images: list of image tensors of this class, all of the same shape.
            m: number of exemplars; clamped to ``len(images)``.
            transform: unused; kept for interface compatibility.
            class_id: only used in the log message.

        Raises:
            ValueError: if ``images`` is empty.
        """
        print(f"Constructing {m} exemplars vectors per class number {class_id}")
        if len(images) == 0:
            raise ValueError(f"No images given for class {class_id}")
        self.model.eval()

        with torch.no_grad():
            features = self._normalized_features(images)  # [N, D]
            class_mean = features.mean(dim=0)  # [D]

            n_select = min(int(m), features.size(0))
            running_sum = torch.zeros_like(class_mean)  # sum of selected features
            available = torch.ones(
                features.size(0), dtype=torch.bool, device=features.device
            )
            exemplar_set = []

            for k in range(1, n_select + 1):
                # || mu - (sum_{j<k} phi(p_j) + phi(x)) / k || for every candidate x
                dists = torch.norm(class_mean - (running_sum + features) / k, dim=1)
                dists[~available] = float("inf")  # never pick the same sample twice
                best_idx = int(torch.argmin(dists).item())

                exemplar_set.append(images[best_idx])
                running_sum += features[best_idx]
                available[best_idx] = False

        self.exemplar_sets.append(exemplar_set)

    def classify_nme(self, x):
        """
        Step 4: Classification using Nearest Mean of Exemplars.

        Each class prototype is the mean of the L2-normalised exemplar features,
        re-normalised; every query is assigned to the closest prototype.

        Returns:
            Long tensor ``[B]`` of class ids on ``self.device`` (all zeros when
            no exemplars are stored yet).
        """
        self.model.eval()
        with torch.no_grad():
            if len(self.exemplar_sets) == 0:
                return torch.zeros(x.size(0), dtype=torch.long, device=self.device)

            # 1. Normalised features of the images to classify
            _, query_features = self.model(x.to(self.device))
            query_features = F.normalize(query_features, p=2, dim=1)

            # 2. Prototypes: mean of normalised exemplar features, re-normalised
            means = torch.stack(
                [
                    F.normalize(
                        self._normalized_features(exemplars).mean(dim=0), p=2, dim=0
                    )
                    for exemplars in self.exemplar_sets
                ]
            )  # [Num_Classes_Seen, Feature_Dim]

            # 3. Nearest prototype: y* = argmin_y || phi(x) - mu_y ||
            dists = torch.cdist(query_features, means)  # [Batch, Num_Classes]
            return torch.argmin(dists, dim=1)


# ==========================================
# 3. Checkpointing Functions
# ==========================================
def save_icarl_checkpoint(
    icarl_instance, task_id, acc, checkpoint_dir, include_models=False
):
    """Save the iCaRL exemplar memory (and optionally model weights) for one task.

    Writes ``icarl_task_{task_id:02d}_acc_{acc:.2f}.pth`` into ``checkpoint_dir``,
    creating the directory if needed. It then deletes every other
    ``icarl_task*.pth`` file there, so only the latest checkpoint is kept.

    Args:
        icarl_instance: object exposing ``exemplar_sets``, plus ``model`` and
            ``old_model`` when ``include_models`` is True.
        task_id: index of the task just completed.
        acc: accuracy in percent; used only in the file name and log message.
        checkpoint_dir: destination directory.
        include_models: also store ``model_state_dict`` and
            ``old_model_state_dict``. Off by default, presumably to keep files
            small, since the state dicts include the whole backbone.

    Returns:
        Path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    filename = f"icarl_task_{task_id:02d}_acc_{acc:.2f}.pth"
    save_path = os.path.join(checkpoint_dir, filename)

    checkpoint_data = {
        "task_id": task_id,
        "accuracy": acc,
        # Tensors moved to CPU for portability/storage.
        "exemplar_sets": [
            [img.cpu() for img in class_set]
            for class_set in icarl_instance.exemplar_sets
        ],
    }
    if include_models:
        old_model = icarl_instance.old_model
        checkpoint_data["model_state_dict"] = icarl_instance.model.state_dict()
        checkpoint_data["old_model_state_dict"] = (
            old_model.state_dict() if old_model is not None else None
        )

    print(f"\n Saving checkpoint for Task {task_id} (Acc: {acc:.2f}%) to: {save_path}")
    torch.save(checkpoint_data, save_path)

    # Keep only the latest checkpoint to save space (only after a successful save).
    for name in os.listdir(checkpoint_dir):
        if name != filename and name.startswith("icarl_task") and name.endswith(".pth"):
            os.remove(os.path.join(checkpoint_dir, name))

    return save_path


def load_icarl_checkpoint(icarl_instance, checkpoint_dir):
    """
    Load the latest ``icarl_task_*.pth`` checkpoint and restore the iCaRL state.

    The exemplar sets are always restored. Model and old-model weights are
    restored only when the file contains them. Checkpoints written without
    model weights carry only the exemplar memory, and the current model is
    kept in that case.

    Returns:
        The task id to resume from: the saved ``task_id + 1`` (e.g. 3 if task 2
        was the last saved), or 0 when the directory is missing or holds no
        parseable checkpoint.
    """

    def _task_id(name):
        # "icarl_task_03_acc_55.20.pth" -> 3; -1 when the name does not parse.
        try:
            return int(name.split("_")[2])
        except (IndexError, ValueError):
            return -1

    # 1. Find the checkpoint with the highest task id
    try:
        names = os.listdir(checkpoint_dir)
    except FileNotFoundError:
        names = []
    checkpoint_files = [
        f
        for f in names
        if f.startswith("icarl_task") and f.endswith(".pth") and _task_id(f) >= 0
    ]
    if not checkpoint_files:
        print(" No checkpoint found. Starting from Task 1.")
        return 0  # Start from task 0 (which becomes task 1 in the loop)

    latest_file = max(checkpoint_files, key=_task_id)
    load_path = os.path.join(checkpoint_dir, latest_file)
    print(f"Loading checkpoint from: {load_path}")

    # 2. Load the state
    checkpoint = torch.load(load_path, map_location=icarl_instance.device)

    # 3. Restore iCaRL state
    model_state = checkpoint.get("model_state_dict")
    if model_state is not None:
        icarl_instance.model.load_state_dict(model_state)
    else:
        print("Checkpoint contains no model weights; keeping the current model.")

    old_state_dict = checkpoint.get("old_model_state_dict")
    if old_state_dict:
        # A fresh instance is needed to receive the frozen teacher's weights
        icarl_instance.old_model = CustomDino(
            num_classes=icarl_instance.num_classes
        ).to(icarl_instance.device)
        icarl_instance.old_model.load_state_dict(old_state_dict)
        icarl_instance.old_model.eval()
        for p in icarl_instance.old_model.parameters():
            p.requires_grad = False

    # Exemplar images are tensors; move them back to the instance's device
    icarl_instance.exemplar_sets = [
        [img.to(icarl_instance.device) for img in class_set]
        for class_set in checkpoint["exemplar_sets"]
    ]

    # Determine next task
    last_completed_task = checkpoint["task_id"]
    print(
        f"Resuming from after Task {last_completed_task} (Acc: {checkpoint['accuracy']:.2f}%)."
    )
    return last_completed_task + 1


# ==========================================
# 4. Data Utilities
# ==========================================
class iCaRLDataset(Dataset):
    """
    Dataset that combines new task data with stored exemplars.

    Indices ``[0, len(new_data))`` address ``new_data``. The remaining indices
    address the exemplars, which are labelled by their position in
    ``exemplars`` (so ``exemplars[y]`` must hold images of class ``y``).
    Samples are fetched lazily, so ``new_data`` may be any indexable dataset
    (a list of ``(image, label)`` tuples, a ``Subset``, ...).
    """

    def __init__(self, new_data, exemplars, transform=None):
        self.new_data = new_data  # Indexable collection of (image, label) pairs
        self.exemplars = exemplars  # List of lists of images; outer index = label
        # Optional callable applied to every image in __getitem__. The image must
        # already be in a form it accepts (e.g. a tensor if stored that way).
        self.transform = transform

        # Flatten exemplars into a list of (img, label)
        self.exemplar_data = [
            (img, label)
            for label, img_list in enumerate(exemplars)
            for img in img_list
        ]

    def __getitem__(self, index):
        n_new = len(self.new_data)
        total = n_new + len(self.exemplar_data)
        if index < 0:  # support negative indices like a list
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index {index} out of range for dataset of size {total}")

        if index < n_new:
            img, label = self.new_data[index]
        else:
            img, label = self.exemplar_data[index - n_new]

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.new_data) + len(self.exemplar_data)


def get_data_for_classes(dataset, classes):
    """
    Extracts all samples belonging to specific classes.

    Args:
        dataset: object exposing a ``targets`` sequence (or 1-D tensor) with one
            label per sample, e.g. torchvision's CIFAR100.
        classes: iterable of class ids to keep.

    Returns:
        ``Subset`` of ``dataset`` whose indices preserve the original order.
    """
    targets = dataset.targets
    if isinstance(targets, torch.Tensor):
        # 0-d tensor elements hash by identity, so compare on plain Python values.
        targets = targets.tolist()
    # Build the lookup set once: O(1) membership per sample instead of a list scan.
    wanted = set(classes)
    indices = [i for i, label in enumerate(targets) if label in wanted]
    return Subset(dataset, indices)


def extract_images_from_subset(subset):
    """Return the image part of every ``(image, label)`` sample in *subset*.

    Samples are read one by one through ``subset[idx]``, so any object with
    ``len`` and integer indexing works (e.g. a ``torch.utils.data.Subset``).
    Each access goes through the underlying dataset (including its transform),
    which makes this slow for large subsets.
    """
    samples = (subset[idx] for idx in range(len(subset)))
    return [image for image, _label in samples]


# ==========================================
# 5. Main Experiment Loop (single round over all 100 classes)
# ==========================================
def main():
    """Compute one centroid per CIFAR-100 class with the DINO backbone and publish it.

    For each class, the first ``TOTAL_EXEMPLARS_VECTORS // 100`` training images
    (in dataset order) are embedded. Their backbone features are averaged and the
    mean is L2-normalised. The stacked centroids are saved as an
    ``nn.Linear``-compatible state dict (``weight`` = centroids, ``bias`` = zeros)
    to ``FILENAME`` and uploaded with ``save_checkpoint_to_wandb``.
    """
    print("Preparing Data...")
    num_classes = 100
    feature_dim = 384  # output width of dino_vits16 (see CustomDino)
    transform = Compose(
        [
            Resize(256),
            CenterCrop(224),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    run = wandb.init(
        entity=WANDB_ENTITY,
        project=WANDB_PROJECT,
        group=WANDB_GROUP,
        name=WANDB_RUN_NAME,
        notes=WANDB_NOTES,
        mode=WANDB_MODE,
    )

    try:
        # Load full CIFAR-100 (train split)
        train_ds = datasets.CIFAR100(
            root="./data", train=True, download=True, transform=transform
        )

        # Model (no incremental steps); only the backbone features are used
        model = CustomDino(num_classes=num_classes).to(DEVICE)
        model.eval()

        images_per_class = TOTAL_EXEMPLARS_VECTORS // num_classes
        print(f"Using {images_per_class} images per class for centroid computation")

        # Group indices by label in one pass instead of rescanning every target per class.
        indices_by_class = {}
        for idx, label in enumerate(train_ds.targets):
            indices_by_class.setdefault(label, []).append(idx)

        centroids_list = []
        with torch.no_grad():
            for class_id in tqdm(range(num_classes), desc="Computing centroids per class"):
                selected_indices = indices_by_class.get(class_id, [])[:images_per_class]
                if not selected_indices:
                    centroids_list.append(torch.zeros(feature_dim, device=DEVICE))
                    continue

                # Embed in mini-batches instead of one forward pass per image.
                class_features = []
                for start in range(0, len(selected_indices), BATCH_SIZE):
                    batch_indices = selected_indices[start : start + BATCH_SIZE]
                    imgs = torch.stack([train_ds[i][0] for i in batch_indices]).to(DEVICE)
                    _, feats = model(imgs)
                    class_features.append(feats)

                # Mean of raw features, then L2-normalise (DINO style)
                class_mean = torch.cat(class_features, dim=0).mean(dim=0, keepdim=True)
                centroids_list.append(F.normalize(class_mean, p=2, dim=1).squeeze(0))

        centroids = torch.stack(centroids_list).cpu()

        # Save centroids as a Linear state_dict-compatible payload
        state_dict = {
            "weight": centroids,
            "bias": torch.zeros(num_classes),
        }
        torch.save(state_dict, FILENAME)
        print(f"Saved centroids as linear weights to {FILENAME}")

        save_checkpoint_to_wandb(
            run=run,
            checkpoint=state_dict,
            filename=FILENAME,
            metadata={
                "type": "centroids",
                "num_classes": num_classes,
                "feature_dim": feature_dim,
                "images_per_class": images_per_class,
            },
        )
    except BaseException:
        # Close the W&B run as failed instead of leaving it dangling.
        run.finish(exit_code=1)
        raise
    run.finish()


# Entry point: compute the class centroids, save them locally and upload them to W&B.
if __name__ == "__main__":
    main()

Preparing Data...


[34m[1mwandb[0m: Currently logged in as: [33merikscolaro31[0m ([33maml-fl-project[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main


Using 10 images per class for centroid computation


Computing centroids per class: 100%|██████████| 100/100 [00:28<00:00,  3.51it/s]


Saved centroids as linear weights to centroids_linear.pth


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Model saved to WandB as artifact 'centroids-checkpoints'.


In [2]:
# Load centroids from WandB and evaluate on CIFAR-100 test
from utilities.wandb_utils import load_checkpoint_from_wandb
import wandb

# Init run to reuse artifact
run = wandb.init(
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    group=WANDB_GROUP,
    name=f"{WANDB_RUN_NAME}-eval",
    notes="Eval centroids on CIFAR100 test",
    mode=WANDB_MODE,
)

try:
    # Prepare model with pretrained backbone; its classifier is replaced below
    model = CustomDino(num_classes=100).to(DEVICE)

    # Download checkpoint artifact; as used here it yields (checkpoint, artifact)
    # or None, where checkpoint holds the 'weight' and 'bias' entries.
    result = load_checkpoint_from_wandb(
        run=run,
        model=model,  # Needed for device detection
        filename=FILENAME,
        version="latest",
    )

    if result is None:
        print("No checkpoint found on WandB; aborting eval.")
    else:
        checkpoint, _artifact = result
        # Load the centroids into the classifier layer (logit = centroid . feature)
        model.classifier.load_state_dict(checkpoint)
        model.eval()

        # CIFAR-100 test loader (same preprocessing as the centroid computation)
        transform = Compose(
            [
                Resize(256),
                CenterCrop(224),
                ToTensor(),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        test_ds = datasets.CIFAR100(
            root="./data", train=False, download=True, transform=transform
        )
        test_loader = DataLoader(
            test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
        )

        # Evaluate top-1 accuracy
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in tqdm(test_loader, desc="Evaluating"):
                imgs = imgs.to(DEVICE)
                labels = labels.to(DEVICE)
                logits, _ = model(imgs)
                correct += (logits.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)

        acc = correct / max(total, 1)
        print(f"Test accuracy with WandB centroids: {acc:.4f}")
        run.log({"test_accuracy": acc})
except BaseException:
    # Close the W&B run as failed instead of leaving it open.
    run.finish(exit_code=1)
    raise
run.finish()

Using cache found in /home/einrich99/.cache/torch/hub/facebookresearch_dino_main
[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loading model from: /home/einrich99/Progetti/FL-task-arithmetic/notebooks/artifacts/centroids-checkpoints:v3/centroids_linear.pth
Successfully loaded model from: /home/einrich99/Progetti/FL-task-arithmetic/notebooks/artifacts/centroids-checkpoints:v3/centroids_linear.pth


Evaluating: 100%|██████████| 157/157 [06:22<00:00,  2.43s/it]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Test accuracy with WandB centroids: 0.5149


0,1
test_accuracy,▁

0,1
test_accuracy,0.5149
