In [None]:
from data.icarl_dataset import iCaRLDataset, get_data_for_classes, extract_images_from_subset
from time import sleep
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
from torchvision import datasets
import torch
from models.icarl_head import IcarlModel, Icarl
import wandb
from utilities.wandb_utils import load_checkpoint_from_wandb, save_checkpoint_to_wandb
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
from copy import deepcopy
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim


# --- Weights & Biases run identity ---------------------------------------
# The fixed id is what wandb.init(id=RUN_ID, resume="allow") uses to
# reattach to the same run when the notebook is restarted.
RUN_ID = "run-1-icarl_cifar100"
ENTITY = "aml-fl-project"
PROJECT = "fl-task-arithmetic"
GROUP = "icarl-cifar100"

# --- Class-incremental setup ------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 100  # CIFAR-100
TOTAL_EXEMPLARS_VECTORS = 1000  # exemplar budget, passed to Icarl as memory_size
TASKS = 20
# Classes are split into equal contiguous chunks. If TASKS does not divide
# NUM_CLASSES, the integer division would silently drop the last classes
# (e.g. TASKS=3 -> 33 per task, class 99 never seen), so fail loudly instead.
if NUM_CLASSES % TASKS != 0:
    raise ValueError(
        f"TASKS={TASKS} must divide NUM_CLASSES={NUM_CLASSES} evenly; "
        "otherwise the trailing classes are never trained"
    )
CLASSES_PER_TASK = NUM_CLASSES // TASKS

# --- Per-task optimisation hyper-parameters ---------------------------------
EPOCHS = 20
LR = 0.01
WEIGHT_DECAY = 1e-5
MOMENTUM = 0.9


# Per-channel normalisation statistics as (mean, std); the values look like
# the usual CIFAR-100 dataset statistics.
# NOTE(review): if the backbone is a pretrained DINO model it was presumably
# trained with ImageNet statistics -- confirm which normalisation is intended.
stats = (
    (0.5071, 0.4867, 0.4408),  # mean (R, G, B)
    (0.2675, 0.2565, 0.2761),  # std  (R, G, B)
)

# Upscale the 32x32 images to a 224x224 centre crop, convert to a tensor,
# then normalise each channel with `stats`.
transform = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(*stats),
    ]
)

# Load the complete CIFAR-100 train and test splits (downloading into ./data
# when missing); both splits use the same `transform`.
train_ds, test_ds = (
    datasets.CIFAR100(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Attach to the W&B run identified by RUN_ID. With resume="allow" an existing
# run with that id is continued, otherwise a new one is created.
run = wandb.init(
    id=RUN_ID,
    resume="allow",
    entity=ENTITY,
    project=PROJECT,
    group=GROUP,
    name="iCaRL_CIFAR100",
    mode="online",
)

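# DANGER: uncommenting this permanently deletes every artifact (all saved
# checkpoints) logged by this run; presumably the next execution then finds no
# checkpoint and starts again from task 0.
# for artifact in wandb.Api().run(f"{ENTITY}/{PROJECT}/{RUN_ID}").logged_artifacts():
#     artifact.delete()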


# Build the iCaRL learner: 100 output classes, TOTAL_EXEMPLARS_VECTORS as its
# memory_size (presumably the total exemplar budget), placed on DEVICE.
icarl = Icarl(num_classes=100, memory_size=TOTAL_EXEMPLARS_VECTORS, device=DEVICE)

# Resume from the most recent W&B checkpoint if one exists; otherwise the
# task loop starts at task 0.
checkpoint = load_checkpoint_from_wandb(run, icarl, "model.pth")
start_task = 0
if checkpoint is None:
    print("Starting from scratch")
else:
    checkpoint_dict, artifact = checkpoint
    icarl.load_state_dict(checkpoint_dict['model'])
    # Exemplar images are checkpointed on CPU; move each one back to the
    # learner's device.
    icarl.exemplar_sets = [
        [exemplar.to(icarl.device) for exemplar in per_class]
        for per_class in checkpoint_dict['exemplar_sets']
    ]

    # The artifact metadata stores the last completed task; continue after it.
    start_task = artifact.metadata["task"] + 1
    # A checkpoint is only written after a task has finished, so the old
    # (teacher) model -- presumably restored by load_state_dict -- is usable.
    icarl.is_old_model_usable = True
    print(f"Resuming from task {start_task}")

# Class-incremental training loop. Each task introduces CLASSES_PER_TASK new
# classes and, for each task:
#   1. selects the new class range,
#   2. builds a training set of new-class images plus stored exemplars,
#   3. trains the classifier with cross-entropy + sigmoid distillation,
#   4. updates the exemplar memory,
#   5. evaluates on every class seen so far,
#   6. checkpoints to W&B so the loop can resume at `start_task`.
for task_id in range(start_task, TASKS):
    # 1. Define classes for this task (a contiguous label range).
    start_class = task_id * CLASSES_PER_TASK
    end_class = (task_id + 1) * CLASSES_PER_TASK
    new_classes = list(range(start_class, end_class))

    print(f"\n================ TASK {task_id+1}/{TASKS} : Classes {new_classes} ================")

    # 2. Prepare training data (new data + exemplars).
    task_data_subset = get_data_for_classes(train_ds, new_classes)

    # Materialise every transformed (img, label) pair once and move the image to
    # the device: trades memory for not re-running the transform every epoch.
    new_data_list = []
    for i in range(len(task_data_subset)):
        img, target = task_data_subset[i]
        new_data_list.append((img.to(icarl.device), target))

    # Hybrid dataset: new-class samples plus the current exemplar sets.
    train_dataset = iCaRLDataset(new_data_list, icarl.exemplar_sets)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)

    # 3. Train. Only the classifier's parameters are given to the optimizer.
    optimizer = optim.SGD(icarl.model.classifier.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    icarl.model.train()
    if icarl.is_old_model_usable:
        icarl.old_model.eval()

    # Output units [0, start_class) belong to earlier tasks; these are the ones
    # the distillation term must keep close to the old model's outputs.
    num_old_classes = start_class

    for epoch in range(EPOCHS):
        total_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False):
            images = images.to(icarl.device)
            labels = labels.to(icarl.device)

            optimizer.zero_grad()
            logits, _ = icarl.model(images)

            # A. Classification loss (cross-entropy over all output units).
            loss_cls = F.cross_entropy(logits, labels)

            # B. Sigmoid distillation (Rebuffi et al., 2017) on old classes only:
            # the old model's sigmoid outputs are soft targets for the new
            # model's logits. binary_cross_entropy_with_logits is the fused,
            # numerically stable form of sigmoid + BCE; with the unfused form,
            # large logits saturate sigmoid to exactly 1.0 in float32 and the
            # gradient through it becomes zero.
            loss_dist = torch.zeros((), device=icarl.device)
            if icarl.is_old_model_usable and num_old_classes > 0:
                with torch.no_grad():
                    old_logits, _ = icarl.old_model(images)
                dist_target = torch.sigmoid(old_logits[:, :num_old_classes])
                loss_dist = F.binary_cross_entropy_with_logits(
                    logits[:, :num_old_classes], dist_target
                )

            loss = loss_cls + loss_dist
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        # 1-based epoch index to match the progress bar; the sum is kept for
        # comparability with earlier logs, the per-batch mean is easier to read.
        print(
            f"Epoch {epoch + 1}/{EPOCHS}: Loss {total_loss:.4f} "
            f"(mean per batch {total_loss / len(train_loader):.4f})"
        )

    # Switch to inference mode before snapshotting the teacher and before any
    # feature extraction (exemplar selection, evaluation); the next task calls
    # icarl.model.train() again.
    icarl.model.eval()
    icarl.old_model = deepcopy(icarl.model)
    icarl.is_old_model_usable = True

    # 4. Exemplar management. iCaRL spreads the fixed budget of
    # TOTAL_EXEMPLARS_VECTORS exemplars evenly over all classes seen so far
    # (m = K / t), trimming older sets as new classes arrive. A constant
    # K // 100 would make the reduction a no-op and leave most of the budget
    # unused until the final task.
    # NOTE(review): resuming a run checkpointed with 10 exemplars per class keeps
    # those (smaller) old sets, presumably since reducing can only shrink them.
    m = TOTAL_EXEMPLARS_VECTORS // end_class
    icarl.reduce_exemplar_sets(m)

    for c in new_classes:
        # Re-extract the images of class c on their own for a clean separation.
        class_subset = get_data_for_classes(train_ds, [c])
        images_c = extract_images_from_subset(class_subset)
        icarl.construct_exemplar_sets(images_c, m, transform, c)

    # 5. Evaluate on all classes seen so far. no_grad: nothing here needs
    # gradients, so do not build autograd graphs.
    print("Evaluating...")
    seen_classes = list(range(end_class))
    test_subset = get_data_for_classes(test_ds, seen_classes)
    test_loader = DataLoader(test_subset, batch_size=64, shuffle=False, num_workers=0)

    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader):
            imgs = imgs.to(icarl.device)
            lbls = lbls.to(icarl.device)
            logits, _ = icarl(imgs)
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(lbls).sum().item()
            total += lbls.size(0)

    acc = 100. * correct / total
    print(f"Task {task_id+1} Accuracy (NME): {acc:.2f}%")

    # 6. Checkpoint: model weights plus exemplar images (moved to CPU so the
    # checkpoint is device-independent); metadata records the finished task.
    save_checkpoint_to_wandb(run, {
        'model': icarl.state_dict(),
        'exemplar_sets': [[img.cpu() for img in class_set] for class_set in icarl.exemplar_sets],
    }, "model.pth", {
        "task": task_id,
        "accuracy": acc,
    })
    print(task_id, "Saved checkpoint model to WandB.")




[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
[34m[1mwandb[0m: Downloading large artifact 'icarl-cifar100-checkpoints:latest', 223.13MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:00.2 (1035.7MB/s)


Loading model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v1/model.pth
Successfully loaded model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v1/model.pth
Resuming from task 2



                                                           

Epoch 0: Loss 1423.5848


                                                           

Epoch 1: Loss 314.4528


                                                           

Epoch 2: Loss 56.0591


                                                           

Epoch 3: Loss 12.5427


                                                           

Epoch 4: Loss 6.9069


                                                           

Epoch 5: Loss 4.8334


                                                           

Epoch 6: Loss 1.6332


                                                           

Epoch 7: Loss 1.4345


                                                           

Epoch 8: Loss 0.6945


                                                            

Epoch 9: Loss 0.2994


                                                            

Epoch 10: Loss 0.1001


                                                            

Epoch 11: Loss 0.0805


                                                            

Epoch 12: Loss 0.0221


                                                            

Epoch 13: Loss 0.0266


                                                            

Epoch 14: Loss 0.0089


                                                            

Epoch 15: Loss 0.0088


                                                            

Epoch 16: Loss 0.0081


                                                            

Epoch 17: Loss 0.0079


                                                            

Epoch 18: Loss 0.0077


                                                            

Epoch 19: Loss 0.0076
Reducing exemplars to 10 per class...
Constructing 10 exemplars vectors per class number 10
Constructing 10 exemplars vectors per class number 11
Constructing 10 exemplars vectors per class number 12
Constructing 10 exemplars vectors per class number 13
Constructing 10 exemplars vectors per class number 14
Evaluating...


100%|██████████| 24/24 [00:48<00:00,  2.02s/it]


Task 3 Accuracy (NME): 77.73%
Model saved to WandB as artifact 'icarl-cifar100-checkpoints'.
2 Saved checkpoint model to WandB.



                                                           

Epoch 0: Loss 1539.5116


                                                           

Epoch 1: Loss 340.4018


                                                           

Epoch 2: Loss 33.1563


                                                           

Epoch 3: Loss 12.8993


                                                           

Epoch 4: Loss 7.6644


                                                           

Epoch 5: Loss 4.5369


                                                           

Epoch 6: Loss 2.5794


                                                                

Epoch 7: Loss 1.6307


                                                           

Epoch 8: Loss 1.3382


                                                            

Epoch 9: Loss 0.3821


                                                            

Epoch 10: Loss 0.0608


                                                            

Epoch 11: Loss 0.1368


                                                            

Epoch 12: Loss 0.0175


                                                            

Epoch 13: Loss 0.0095


                                                            

Epoch 14: Loss 0.0088


                                                            

Epoch 15: Loss 0.0085


                                                            

Epoch 16: Loss 0.0086


                                                            

Epoch 17: Loss 0.0085


                                                            

Epoch 18: Loss 0.0081


                                                            

Epoch 19: Loss 0.0081
Reducing exemplars to 10 per class...
Constructing 10 exemplars vectors per class number 15
Constructing 10 exemplars vectors per class number 16
Constructing 10 exemplars vectors per class number 17
Constructing 10 exemplars vectors per class number 18
Constructing 10 exemplars vectors per class number 19
Evaluating...


100%|██████████| 32/32 [01:00<00:00,  1.90s/it]


Task 4 Accuracy (NME): 72.65%
Model saved to WandB as artifact 'icarl-cifar100-checkpoints'.
3 Saved checkpoint model to WandB.



                                                           

Epoch 0: Loss 1397.1029


Epoch 2/20:  53%|█████▎    | 23/43 [02:03<01:50,  5.53s/it]