In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
from tqdm import tqdm
import wandb
from utilities.wandb_utils import load_checkpoint_from_wandb, save_checkpoint_to_wandb
from fl_task_arithmetic.model import CustomDino
import math

# Weights & Biases destination for every run started from this notebook.
ENTITY = "aml-fl-project"
PROJECT = "fl-task-arithmetic"
GROUP = "centralized-dino-cifar100"

# Default hyper-parameters. Callers of train() may pass different values.
BATCH_SIZE = 64
LR = 0.01            # Learning Rate (Tune this: 0.1, 0.01, 0.001)
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
EPOCHS = 100         # Increase to 100+ for final results

# Pick the best accelerator that is actually present instead of hard-coding
# "mps": Apple MPS first (the previously hard-coded choice), then CUDA, then CPU.
# getattr() guards against torch builds that have no `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Number of consecutive non-improving epochs tolerated before early stopping.
PATIENCE = 5

# Per-channel (mean, std) used to normalize CIFAR-100 images.
stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
cifar_mean, cifar_std = stats

# Training preprocessing: resize to 256, centre-crop to 224 (required for DINO),
# convert to a tensor and normalize. A random horizontal flip could be inserted
# before ToTensor() as an optional augmentation; it is currently disabled.
transform_train = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=cifar_mean, std=cifar_std),
    ]
)

# Evaluation preprocessing: same deterministic steps as training.
transform_test = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=cifar_mean, std=cifar_std),
    ]
)

# Both splits live under ./data and are downloaded on first use.
_cifar_location = {"root": "./data", "download": True}
trainset = datasets.CIFAR100(train=True, transform=transform_train, **_cifar_location)
testset = datasets.CIFAR100(train=False, transform=transform_test, **_cifar_location)

# Only the training loader shuffles; both use the same batch size and workers.
_loader_options = {"batch_size": BATCH_SIZE, "num_workers": 2}
trainloader = DataLoader(trainset, shuffle=True, **_loader_options)
testloader = DataLoader(testset, shuffle=False, **_loader_options)


def _evaluate(model, criterion):
    """Return ``(mean_batch_loss, accuracy_percent)`` of *model* on ``testloader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return loss_sum / len(testloader), 100. * correct / total


def train(run_name, lr, momentum, weight_decay, epochs, scheduler_name, scheduler_fn,
          warmup_epochs=20):
    """Train ``CustomDino`` with SGD, logging to W&B and checkpointing as it goes.

    The run id is derived from the arguments, and the run is opened with
    ``resume="allow"``. If a checkpoint named after the run can be loaded, the
    model, optimizer and scheduler states plus the best metrics are restored
    and training continues from the epoch after the saved one.

    Args:
        run_name: Prefix of the W&B run name/id.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        epochs: Total number of epochs (exclusive upper bound of the epoch index).
        scheduler_name: Label of the scheduler, used only in the run name/id.
        scheduler_fn: Callable taking the optimizer and returning an LR scheduler.
        warmup_epochs: For epoch indices below this value a checkpoint is always
            written and the patience counter is reset, whatever the test loss.

    Returns:
        None. Metrics go to W&B and stdout; checkpoints go to W&B artifacts
        through ``save_checkpoint_to_wandb``.
    """
    run_id = (
        f"{run_name}-centralized-dino-icarl-cifar100"
        f"-lr{lr}-mom{momentum}-wd{weight_decay}-sched-{scheduler_name}"
    )
    run = wandb.init(
        entity=ENTITY,
        project=PROJECT,
        group=GROUP,
        name=run_id,
        id=run_id,
        resume="allow",
        mode="online",
    )
    checkpoint_name = f"model-{run.id}.pth"

    best_loss = math.inf
    best_accuracy = 0.0
    patience_counter = 0
    start_epoch = 0

    model = CustomDino().to(DEVICE)

    checkpoint = load_checkpoint_from_wandb(run, model, checkpoint_name)
    checkpoint_dict = None
    if checkpoint is not None:
        # Do not print the checkpoint itself: it holds every weight tensor.
        checkpoint_dict, _artifact = checkpoint
        model.load_state_dict(checkpoint_dict['model'])

    # Build the optimizer/scheduler once, after the weights are in place.
    # Only the backbone's parameters are handed to the optimizer.
    optimizer = optim.SGD(model.backbone.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = scheduler_fn(optimizer)

    if checkpoint_dict is not None:
        optimizer.load_state_dict(checkpoint_dict["optimizer_state"])
        scheduler.load_state_dict(checkpoint_dict["scheduler_state"])

        start_epoch = checkpoint_dict["epoch"] + 1
        best_accuracy = checkpoint_dict["best_accuracy"]
        # Restore the loss baseline so the first resumed epoch is not treated
        # as an automatic improvement. The patience counter stays at 0 because
        # a checkpoint is only ever written on an epoch that reset it.
        best_loss = checkpoint_dict.get("best_loss", math.inf)

        print(f"Resuming from epoch {start_epoch}")
    else:
        print("Starting from scratch")

    criterion = nn.CrossEntropyLoss()

    for epoch in range(start_epoch, epochs):
        model.train()
        running_loss = 0.0

        # Progress bar for training (total comes from the `epochs` argument).
        pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
        for images, labels in pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # read the scalar once per batch
            running_loss += batch_loss
            pbar.set_postfix({'loss': batch_loss})

        # The scheduler advances once per epoch.
        scheduler.step()

        avg_train_loss = running_loss / len(trainloader)
        avg_test_loss, acc = _evaluate(model, criterion)

        # Update before logging so "best_accuracy" includes the current epoch.
        best_accuracy = max(best_accuracy, acc)

        wandb.log(
            {
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "test_loss": avg_test_loss,
                "test_accuracy": acc,
                "best_accuracy": best_accuracy,
                # Read after scheduler.step(): the rate the next epoch will use.
                "learning_rate": optimizer.param_groups[0]["lr"],
            }
        )
        print(f"Epoch {epoch+1} Results: Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Test Acc: {acc:.2f}%")

        if epoch < warmup_epochs or avg_test_loss <= best_loss:
            best_loss = avg_test_loss
            # Reset first so the stored counter matches the stored model.
            patience_counter = 0
            save_checkpoint_to_wandb(run, {
                "epoch": epoch,
                'model': model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_loss": best_loss,
                "best_accuracy": best_accuracy,
                "patience_counter": patience_counter,
            }, checkpoint_name, {
                # NOTE(review): this passes the module object itself as the
                # "task" metadata value; confirm that is what the helper expects.
                "task": model,
                "accuracy": acc,
                "epoch": epoch
            })
            print(epoch, "Saved checkpoint model to WandB.")
        else:
            patience_counter += 1
            # Stops once more than PATIENCE consecutive epochs failed to improve.
            if patience_counter > PATIENCE:
                print("Early stopping triggered.")
                break

    # Close the run so a later train() call starts from a clean W&B state.
    run.finish()

In [None]:

def _cosine_annealing(opt):
    """Build a CosineAnnealingLR scheduler with T_max=10 for *opt*."""
    return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)


def _step_decay(opt):
    """Build a StepLR scheduler (step_size=30, gamma=0.1) for *opt*."""
    return torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)


def _unit_factor(epoch):
    """LambdaLR multiplier that leaves the learning rate unchanged."""
    return 1.0


def _no_schedule(opt):
    """Build a LambdaLR scheduler with a constant factor of 1.0 (no scheduling)."""
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=_unit_factor)


# (label, factory) pairs: the label goes into the run name, and the factory
# receives the optimizer and returns the scheduler.
SCHEDULERS = [
    ("CosineAnnealingLR", _cosine_annealing),
    ("StepLR", _step_decay),
    ("NoScheduler", _no_schedule),
]

# Overrides the W&B group defined earlier in the notebook. train() reads GROUP
# when it is called, so this takes effect for the run below.
GROUP = "adrien-centralized-dino-cifar100"

# Launch the run with the first registered scheduler (cosine annealing).
selected_name, selected_factory = SCHEDULERS[0]
train(
    run_name="adrien-9",
    lr=0.0001,  # 10^-4
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    epochs=EPOCHS,
    scheduler_name=selected_name,
    scheduler_fn=selected_factory,
)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
best_accuracy,▁█████████████████████████
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
learning_rate,█▇▇▆▅▃▂▂▁▁▁▂▂▃▅▆▇▇███▇▇▆▅▃
test_accuracy,█▃▂▁▁▂▂▂▂▂▂▂▂▂▃▂▃▄▃▃▄▅█▆▇█
test_loss,█▆▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁
train_loss,█▆▅▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁

0,1
best_accuracy,51.79
epoch,26.0
learning_rate,3e-05
test_accuracy,51.64
test_loss,4.23341
train_loss,4.22155


Using cache found in /Users/adrientrahan/.cache/torch/hub/facebookresearch_dino_main


Loading iCaRL classifier from local cache: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/utilities/trained/nearest_centroid_classifier.pth


[34m[1mwandb[0m: Downloading large artifact 'adrien-centralized-dino-cifar100-checkpoints:latest', 165.58MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:00.3 (544.0MB/s)


Loading model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/utilities/artifacts/adrien-centralized-dino-cifar100-checkpoints:v24/model-adrien-9-centralized-dino-icarl-cifar100-lr0.0001-mom0.9-wd0.0005-sched-CosineAnnealingLR.pth
Successfully loaded model from: /Users/adrientrahan/Documents/ecole/AML/project/FL-task-arithmetic/utilities/artifacts/adrien-centralized-dino-cifar100-checkpoints:v24/model-adrien-9-centralized-dino-icarl-cifar100-lr0.0001-mom0.9-wd0.0005-sched-CosineAnnealingLR.pth
({'epoch': 19, 'model': OrderedDict([('backbone.cls_token', tensor([[[ 2.9728e-02,  1.0140e-03, -3.6106e-04, -2.8377e-03,  1.3994e-03,
          -5.9266e-03, -6.1358e-03, -1.0877e-02,  2.6753e-03,  9.5717e-03,
          -1.4855e-03,  1.0759e-03, -1.2553e-02, -1.0791e-02,  3.0013e-02,
           1.5472e-03, -2.3848e-02, -9.9610e-03, -1.2490e-02, -1.0605e-02,
           2.7004e-02,  5.1257e-04, -1.1869e-03, -5.0511e-03, -6.4173e-03,
          -1.0725e-03, -2.2295e-03,  1.00

Epoch 21/100: 100%|██████████| 782/782 [10:38<00:00,  1.22it/s, loss=4.24]


Epoch 21 Results: Train Loss: 4.2361 | Test Loss: 4.2427 | Test Acc: 50.35%
Model saved to WandB as artifact 'adrien-centralized-dino-cifar100-checkpoints'.
20 Saved checkpoint model to WandB.


Epoch 22/100: 100%|██████████| 782/782 [10:42<00:00,  1.22it/s, loss=4.24]


Epoch 22 Results: Train Loss: 4.2325 | Test Loss: 4.2401 | Test Acc: 51.44%
Model saved to WandB as artifact 'adrien-centralized-dino-cifar100-checkpoints'.
21 Saved checkpoint model to WandB.


Epoch 23/100: 100%|██████████| 782/782 [10:37<00:00,  1.23it/s, loss=4.27]


Epoch 23 Results: Train Loss: 4.2292 | Test Loss: 4.2375 | Test Acc: 51.54%
Model saved to WandB as artifact 'adrien-centralized-dino-cifar100-checkpoints'.
22 Saved checkpoint model to WandB.


Epoch 24/100:  96%|█████████▋| 754/782 [10:05<00:22,  1.25it/s, loss=4.23]