In [None]:
from data.icarl_dataset import iCaRLDataset, get_data_for_classes, extract_images_from_subset
from time import sleep
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
from torchvision import datasets
import torch
from models.icarl_head import IcarlModel, Icarl
import wandb
from utilities.wandb_utils import load_checkpoint_from_wandb, save_checkpoint_to_wandb
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
from copy import deepcopy
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim


# --- Weights & Biases run identity -------------------------------------------
RUN_ID = "run-1-icarl_cifar100"
ENTITY = "aml-fl-project"
PROJECT = "fl-task-arithmetic"
GROUP = "icarl-cifar100"

# --- Hardware and continual-learning protocol --------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Total exemplar budget; handed to Icarl as `memory_size`.
TOTAL_EXEMPLARS_VECTORS = 1000
# Number of classes in the dataset (CIFAR-100), named instead of a bare 100.
NUM_CLASSES = 100
# Number of incremental tasks the classes are split into.
TASKS = 20
if NUM_CLASSES % TASKS != 0:
    # Floor division below would otherwise silently drop the trailing classes.
    raise ValueError(
        f"NUM_CLASSES ({NUM_CLASSES}) must be divisible by TASKS ({TASKS})"
    )
CLASSES_PER_TASK = NUM_CLASSES // TASKS

# --- Optimisation hyper-parameters (SGD, per task) ---------------------------
EPOCHS = 20
LR = 0.01
WEIGHT_DECAY = 1e-5
MOMENTUM = 0.9


# Per-channel (mean, std) used by the Normalize step below.
stats = (
    (0.5071, 0.4867, 0.4408),
    (0.2675, 0.2565, 0.2761),
)

# Preprocessing applied to every image: upscale, take the central 224x224 crop,
# convert to a tensor and normalise with `stats`.
transform = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=stats[0], std=stats[1]),
    ]
)

# Load the FULL train and test splits (downloaded into ./data if missing);
# the per-task class filtering happens later.
train_ds, test_ds = (
    datasets.CIFAR100(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Settings for the experiment-tracking run. A fixed `id` together with
# resume="allow" is what lets a re-executed cell attach to the same run
# (NOTE(review): resume semantics per the wandb.init documentation — confirm).
_wandb_init_kwargs = {
    "entity": ENTITY,
    "project": PROJECT,
    "group": GROUP,
    "name": "iCaRL_CIFAR100",
    "id": RUN_ID,
    "resume": "allow",
    "mode": "online",
}
run = wandb.init(**_wandb_init_kwargs)

# for artifact in wandb.Api().run(f"{ENTITY}/{PROJECT}/{RUN_ID}").logged_artifacts():
#     artifact.delete()


# Initialize iCaRL
icarl = Icarl(
    num_classes=100,
    memory_size=TOTAL_EXEMPLARS_VECTORS,
    device=DEVICE,
)

# Look for a previously saved checkpoint on this W&B run. The helper returns
# None when there is nothing to resume from, otherwise a
# (checkpoint_dict, artifact) pair.
checkpoint = load_checkpoint_from_wandb(
    run,
    icarl,
    "model.pth",
)

start_task = 0
if checkpoint is None:
    print("Starting from scratch")
else:
    checkpoint_dict, artifact = checkpoint
    icarl.load_state_dict(checkpoint_dict['model'])
    # Exemplars are stored on CPU in the checkpoint; move them to the model device.
    icarl.exemplar_sets = [
        [img.to(icarl.device) for img in class_set]
        for class_set in checkpoint_dict['exemplar_sets']
    ]

    # The checkpoint is written right after `old_model = deepcopy(model)` at the
    # end of a task, so the distillation teacher is exactly the restored model.
    # Rebuild it explicitly instead of trusting whatever `old_model` held before
    # the load; otherwise `is_old_model_usable = True` could point distillation
    # at an untrained teacher.
    icarl.old_model = deepcopy(icarl.model)
    icarl.old_model.eval()
    icarl.is_old_model_usable = True

    # The artifact metadata stores the index of the last completed task.
    start_task = artifact.metadata["task"] + 1
    print(f"Resuming from task {start_task}")

# Class-incremental training loop. Each task:
#   1. selects CLASSES_PER_TASK new classes,
#   2. builds a dataset of the new-class samples plus the stored exemplars,
#   3. trains the classifier head with cross-entropy + sigmoid distillation,
#   4. updates the exemplar memory,
#   5. evaluates on every class seen so far and saves a checkpoint to W&B.
for task_id in range(start_task, TASKS):
    # 1. Define Classes for this Task
    start_class = task_id * CLASSES_PER_TASK
    end_class = start_class + CLASSES_PER_TASK
    new_classes = list(range(start_class, end_class))

    print(f"\n================ TASK {task_id+1}/{TASKS} : Classes {new_classes} ================")

    # 2. Prepare Training Data (New Data + Exemplars)
    # Get subset of ONLY new classes
    task_data_subset = get_data_for_classes(train_ds, new_classes)

    # Cache the transformed (img, label) pairs once so the transform is not
    # re-applied every epoch (RAM intensive but simpler code). Only __len__ and
    # __getitem__ are assumed on the subset.
    new_data_list = [tuple(task_data_subset[i]) for i in range(len(task_data_subset))]

    # The DataLoader below uses worker processes, which requires the dataset to
    # be picklable/shareable. Tensors living on an accelerator cannot be shared
    # that way ("_share_filename_: only available on CPU"), so the dataset gets
    # detached CPU copies of the exemplars; batches are moved to the device
    # inside the training loop anyway.
    cpu_exemplar_sets = [
        [img.detach().cpu() for img in class_set]
        for class_set in icarl.exemplar_sets
    ]

    # Create Hybrid Dataset
    train_dataset = iCaRLDataset(new_data_list, cpu_exemplar_sets)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

    # 3. Train (Update Representation)
    # Only train the classifier parameters
    optimizer = optim.SGD(
        icarl.model.classifier.parameters(),
        lr=LR,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
    )
    # Scheduler helps convergence
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    icarl.model.train()
    if icarl.is_old_model_usable:
        icarl.old_model.eval()

    # Output columns [0, num_old_classes) belong to classes from earlier tasks.
    # Distillation only makes sense when a previous model exists AND there is
    # at least one old class; decide once instead of per batch.
    num_old_classes = start_class
    use_distillation = icarl.is_old_model_usable and num_old_classes > 0

    for epoch in range(EPOCHS):
        total_loss = 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False):
            images = images.to(icarl.device)
            labels = labels.to(icarl.device)

            optimizer.zero_grad()

            # Forward Pass
            logits, _ = icarl.model(images)

            # A. Classification Loss (Cross Entropy over all output nodes)
            loss = F.cross_entropy(logits, labels)

            # B. Distillation Loss (on OLD classes only)
            # Sigmoid distillation (Rebuffi et al. 2017), T=1: binary
            # cross-entropy between the new model's outputs and the old model's
            # sigmoid outputs, restricted to the old-class columns. The output
            # layer is assumed to have a fixed size, so old and new logits
            # share column indices.
            if use_distillation:
                with torch.no_grad():
                    old_logits, _ = icarl.old_model(images)
                soft_targets = torch.sigmoid(old_logits[:, :num_old_classes])
                # The *_with_logits variant fuses sigmoid + BCE and is
                # numerically more stable than applying sigmoid first.
                loss = loss + F.binary_cross_entropy_with_logits(
                    logits[:, :num_old_classes], soft_targets
                )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        # 1-based epoch number, consistent with the progress-bar label.
        # The reported loss is the sum over batches.
        print(f"Epoch {epoch+1}: Loss {total_loss:.4f}")

    # Training for this task is done: switch to inference mode for exemplar
    # construction and evaluation, and snapshot the model as the distillation
    # teacher for the next task.
    icarl.model.eval()
    icarl.old_model = deepcopy(icarl.model)
    icarl.is_old_model_usable = True

    # 4. Exemplar Management
    # A. Reduce old sets
    # NOTE(review): m is constant across tasks here, so this call presumably
    # never shrinks an existing set. The iCaRL paper divides the budget by the
    # number of classes seen so far; confirm which behaviour is intended.
    m = TOTAL_EXEMPLARS_VECTORS // 100
    icarl.reduce_exemplar_sets(m)

    # B. Construct new sets
    for c in new_classes:
        # Extract images for specific class c
        # (Re-extract from subset for clean separation)
        class_subset = get_data_for_classes(train_ds, [c])
        images_c = extract_images_from_subset(class_subset)
        icarl.construct_exemplar_sets(images_c, m, transform, c)

    # 5. Evaluate on ALL classes seen so far
    print("Evaluating...")
    seen_classes = list(range(0, end_class))
    test_subset = get_data_for_classes(test_ds, seen_classes)
    test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader):
            imgs = imgs.to(icarl.device)
            lbls = lbls.to(icarl.device)
            logits, _ = icarl(imgs)
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(lbls).sum().item()
            total += lbls.size(0)

    acc = 100. * correct / total
    print(f"Task {task_id+1} Accuracy (NME): {acc:.2f}%")

    # Persist model weights and exemplars (moved to CPU so the checkpoint is
    # device-independent); the metadata records the last completed task so a
    # later run can resume from the next one.
    save_checkpoint_to_wandb(run, {
        'model': icarl.state_dict(),
        'exemplar_sets': [[img.cpu() for img in class_set] for class_set in icarl.exemplar_sets],
    }, "model.pth", {
        "task": task_id,
        "accuracy": acc,
    })
    print(task_id, "Saved checkpoint model to WandB.")




[34m[1mwandb[0m: Currently logged in as: [33madrientrahan[0m ([33maml-fl-project[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
[34m[1mwandb[0m: Downloading large artifact 'icarl-cifar100-checkpoints:latest', 223.13MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:00.5 (481.8MB/s)


Loading model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v1/model.pth
Successfully loaded model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v1/model.pth
Resuming from task 2



                                                  

RuntimeError: _share_filename_: only available on CPU