In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Optional, cast
import wandb
from models.dino_icarl import DinoIcarlModel
from utilities.wandb_utils import load_checkpoint_from_wandb, save_checkpoint_to_wandb

# --- Weights & Biases destination for every run started from this notebook ---
ENTITY = "aml-fl-project"
PROJECT = "fl-task-arithmetic"
GROUP = "centralized-dino-cifar100"

# --- Optimisation hyper-parameters (defaults used by the sweeps below) ---
BATCH_SIZE = 64
LR  = 0.01           # Learning Rate (Tune this: 0.1, 0.01, 0.001)
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 100         # Increase to 100+ for final results

# Pick the best accelerator that is actually present instead of hard-coding
# "mps": Apple MPS first (the previous fixed choice), then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Early-stopping budget: number of non-improving epochs tolerated by train().
PATIENCE = 3

# Per-channel (mean, std) pairs used to normalise CIFAR-100 images.
# NOTE(review): the backbone is loaded from the facebookresearch/dino hub
# cache; such checkpoints are usually paired with ImageNet statistics —
# confirm that CIFAR-100 statistics are the intended choice here.
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))


def _build_preprocessing():
    """Return a fresh pipeline: resize to 256, centre-crop to 224, tensorise, normalise."""
    channel_mean, channel_std = stats
    steps = [
        Resize(256),
        CenterCrop(224),  # 224x224 input, required for DINO
        # A random horizontal flip could be inserted here as optional augmentation.
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)


# Train and test currently share the same deterministic preprocessing.
transform_train = _build_preprocessing()
transform_test = _build_preprocessing()

# CIFAR-100 splits, downloaded into ./data on first use.
_cifar_common = dict(root='./data', download=True)
trainset = datasets.CIFAR100(train=True, transform=transform_train, **_cifar_common)
testset = datasets.CIFAR100(train=False, transform=transform_test, **_cifar_common)

# Only the training loader shuffles; evaluation order stays fixed.
_loader_common = dict(batch_size=BATCH_SIZE, num_workers=2)
trainloader = DataLoader(trainset, shuffle=True, **_loader_common)
testloader = DataLoader(testset, shuffle=False, **_loader_common)


def train(lr, momentum, weight_decay, epochs, scheduler_name, scheduler_fn):
    """Train a DinoIcarlModel on CIFAR-100 for one hyper-parameter configuration.

    A W&B run is opened whose id/name encode the hyper-parameters, so calling
    this function again with the same arguments resumes the same run
    (``resume="allow"``). The best model (by test accuracy) is checkpointed
    through ``save_checkpoint_to_wandb`` and training stops early when the
    accuracy has not improved for more than ``PATIENCE`` consecutive epochs.

    Args:
        lr: Learning rate passed to SGD.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        epochs: Total number of epochs (exclusive upper bound of the epoch index).
        scheduler_name: Label of the scheduler; only used in the run id/name.
        scheduler_fn: Callable taking the optimizer and returning an LR
            scheduler that is stepped once per epoch.

    Returns:
        dict with keys ``"train_loss"``, ``"test_loss"`` and ``"test_acc"``,
        each a list with one entry per epoch executed during this call.
    """
    import warnings

    run_name = f"centralized-dino-icarl-cifar100-lr{lr}-mom{momentum}-wd{weight_decay}-sched-{scheduler_name}"
    run = wandb.init(
        entity=ENTITY,
        project=PROJECT,
        group=GROUP,
        name=run_name,
        id=f"run-1-{run_name}",
        resume="allow",
        mode="online",
    )

    history = {"train_loss": [], "test_loss": [], "test_acc": []}

    # try/finally guarantees the run is closed even when training fails or is
    # interrupted; train() is called repeatedly in sweeps, so a dangling run
    # would otherwise leak into the next call.
    try:
        model = DinoIcarlModel(device=DEVICE).to(DEVICE)

        checkpoint = load_checkpoint_from_wandb(
            run,
            model,
            "model.pth"
        )
        start_epoch = 0
        best_accuracy = 0.0
        if checkpoint is not None:
            checkpoint_dict, artifact = checkpoint
            model.load_state_dict(checkpoint_dict['model'])
            metadata = artifact.metadata or {}
            # Checkpoints written before "epoch" was stored lack the key;
            # fall back to epoch 0 instead of raising KeyError.
            start_epoch = metadata.get("epoch", -1) + 1
            # Keep the bar set by the stored checkpoint so a worse model from
            # the first resumed epoch cannot overwrite it.
            best_accuracy = float(metadata.get("accuracy", 0.0))
            print(f"Resuming from epoch {start_epoch}")
        else:
            print("Starting from scratch")

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        scheduler = scheduler_fn(optimizer)

        if start_epoch > 0:
            # Replay the schedule so a resumed run continues at the learning
            # rate it would have reached. PyTorch warns about scheduler.step()
            # preceding optimizer.step(); that is intentional here.
            # NOTE: optimizer momentum buffers are not checkpointed, so they
            # still restart from zero on resume.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                for _ in range(start_epoch):
                    scheduler.step()

        patience_counter = 0

        for epoch in range(start_epoch, epochs):
            # ---- Training pass ----
            model.train()
            running_loss = 0.0

            # Show the requested `epochs`, not the global EPOCHS default.
            pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
            for images, labels in pbar:
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()  # read once; each .item() forces a device sync
                running_loss += batch_loss
                pbar.set_postfix({'loss': batch_loss})

            scheduler.step()

            avg_train_loss = running_loss / len(trainloader)

            # ---- Evaluation pass ----
            model.eval()
            test_loss = 0.0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in testloader:
                    images, labels = images.to(DEVICE), labels.to(DEVICE)
                    outputs = model(images)
                    test_loss += criterion(outputs, labels).item()
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

            avg_test_loss = test_loss / len(testloader)
            acc = 100. * correct / total

            history["train_loss"].append(avg_train_loss)
            history["test_loss"].append(avg_test_loss)
            history["test_acc"].append(acc)

            print(f"Epoch {epoch+1} Results: Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Test Acc: {acc:.2f}%")
            run.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "test_loss": avg_test_loss,
                "test_acc": acc,
            })

            if acc > best_accuracy:
                best_accuracy = acc
                patience_counter = 0
                # Metadata must be JSON-serializable and must carry "epoch",
                # which the resume path above reads back.
                # NOTE(review): the last recorded execution of this cell ended
                # with "AttributeError: 'collections.OrderedDict' object has
                # no attribute 'cpu'" right after the first epoch, presumably
                # inside this save call — confirm which payload shape
                # save_checkpoint_to_wandb expects.
                save_checkpoint_to_wandb(run, {
                    'model': model.state_dict(),
                }, "model.pth", {
                    "task": type(model).__name__,
                    "epoch": epoch,
                    "accuracy": acc,
                })
                print(epoch, "Saved checkpoint model to WandB.")
            else:
                patience_counter += 1
                # Strict '>' means training stops after PATIENCE + 1
                # consecutive epochs without improvement.
                if patience_counter > PATIENCE:
                    print("Early stopping triggered.")
                    break
    finally:
        run.finish()

    return history






[34m[1mwandb[0m: Currently logged in as: [33madrientrahan[0m ([33maml-fl-project[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main
[34m[1mwandb[0m: Downloading large artifact 'icarl-cifar100-checkpoints:latest', 596.51MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:03:10.8 (3.1MB/s)


Loading model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v14/model.pth
Successfully loaded model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/notebooks/artifacts/icarl-cifar100-checkpoints:v14/model.pth
Run (run-1-centralized-dino-icarl-cifar100-lr0.01-mom0.9-wd0.0005-sched-CosineAnnealingLR) is finished. The call to `use_artifact` will be ignored. Please make sure that you are using an active run.
Model checkpoint not found on WandB. Run (run-1-centralized-dino-icarl-cifar100-lr0.01-mom0.9-wd0.0005-sched-CosineAnnealingLR) is finished. The call to `use_artifact` will be ignored. Please make sure that you are using an active run.
Starting from scratch


Epoch 1/100: 100%|██████████| 782/782 [10:35<00:00,  1.23it/s, loss=4.77]


Epoch 1 Results: Train Loss: 8.9377 | Test Loss: 5.7398 | Test Acc: 2.02%


AttributeError: 'collections.OrderedDict' object has no attribute 'cpu'

In [None]:

def _cosine_annealing(opt):
    """Cosine annealing over the default EPOCHS horizon."""
    return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)


def _step_decay(opt):
    """Multiply the learning rate by 0.1 every 30 scheduler steps."""
    return torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)


def _constant_factor(epoch):
    """LR multiplier that never changes the learning rate."""
    return 1.0


def _no_schedule(opt):
    """Scheduler-shaped object that keeps the learning rate constant."""
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=_constant_factor)


# (label, factory) pairs; each factory takes an optimizer and returns a scheduler.
SCHEDULERS = [
    ("CosineAnnealingLR", _cosine_annealing),
    ("StepLR", _step_decay),
    ("NoScheduler", _no_schedule),
]
    

# Scheduler sweep: one training run per entry of SCHEDULERS, all sharing the
# default optimisation hyper-parameters.
_sweep_defaults = dict(lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, epochs=EPOCHS)

for scheduler_name, scheduler_fn in SCHEDULERS:
    train(scheduler_name=scheduler_name, scheduler_fn=scheduler_fn, **_sweep_defaults)


In [None]:
# Learning-rate sweep using the first scheduler in SCHEDULERS.
# NOTE(review): 0.01 equals the default LR, so that iteration yields the same
# W&B run id as the first run of the scheduler sweep and (resume="allow")
# continues that run instead of creating a new one — confirm this is intended.
LEARNING_RATES = [0.1, 0.01, 0.001]

_lr_sweep_sched_name, _lr_sweep_sched_fn = SCHEDULERS[0]
_lr_sweep_fixed = dict(
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    epochs=EPOCHS,
    scheduler_name=_lr_sweep_sched_name,
    scheduler_fn=_lr_sweep_sched_fn,
)

for lr in LEARNING_RATES:
    train(lr=lr, **_lr_sweep_fixed)
