Centralized Baseline

To quantify the performance of the federated models, we first train a standard (centralized) model on CIFAR-100, so as to have a benchmark for the ideal performance our federated models should approach.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, Normalize, ToTensor, Resize, CenterCrop
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import your models from the uploaded file
from fl_task_arithmetic.task import Net, CustomDino

# ---------------------------------------------------------------------------
# Training configuration for the centralized baseline
# ---------------------------------------------------------------------------
BATCH_SIZE = 128       # samples per mini-batch for both loaders
LR = 0.1               # initial learning rate (search candidates: 0.1, 0.01, 0.05)
MOMENTUM = 0.9         # SGD momentum
WEIGHT_DECAY = 0.0005  # weight decay passed to the SGD optimizer
EPOCHS = 100           # number of training epochs (also try 150 or 200)

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_cifar100_loaders():
    """Create the CIFAR-100 training and test data loaders.

    Both splits are stored under ``./data`` (downloaded if missing) and go
    through the same preprocessing: resize to 256, center-crop to 224,
    convert to a tensor and normalize per channel.

    Returns:
        A ``(trainloader, testloader)`` pair. Both use ``BATCH_SIZE`` and two
        worker processes; only the training loader shuffles.
    """
    # Per-channel mean / std commonly used to normalize CIFAR-100.
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    def build_pipeline():
        # Resize + center-crop is the preprocessing intended for the DINO
        # model. For Net (CNN), the train split would instead use
        # transforms.RandomCrop(32, padding=4) and
        # transforms.RandomHorizontalFlip() before ToTensor().
        return Compose([
            Resize(256),
            CenterCrop(224),
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ])

    loaders = []
    for is_train in (True, False):
        split = datasets.CIFAR100(
            root='./data',
            train=is_train,
            download=True,
            transform=build_pipeline(),
        )
        # Shuffle the training split only; keep the test order fixed.
        loaders.append(
            DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_train, num_workers=2)
        )

    trainloader, testloader = loaders
    return trainloader, testloader

def _evaluate(model, loader, criterion):
    """Return ``(mean per-sample loss, accuracy in percent)`` of *model* on *loader*."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            batch_len = labels.size(0)
            # The criterion yields the batch mean; scale it back to a sum so
            # that a short final batch is not over-weighted in the average.
            loss_sum += criterion(outputs, labels).item() * batch_len
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_len
    return loss_sum / total, 100. * correct / total


def _plot_history(train_losses, test_losses, test_accs):
    """Show the loss curves and the test-accuracy curve side by side."""
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title('Loss vs Epochs')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(test_accs, label='Test Accuracy')
    plt.title('Accuracy vs Epochs')
    plt.legend()
    plt.show()


def train_centralized():
    """Train the centralized baseline on CIFAR-100 and plot its history.

    Builds the loaders with ``get_cifar100_loaders()``, trains
    ``CustomDino(num_classes=100)`` for ``EPOCHS`` epochs using SGD
    (``LR``, ``MOMENTUM``, ``WEIGHT_DECAY``) with a cosine-annealing
    schedule, evaluates on the test loader after every epoch, prints the
    epoch metrics and finally shows the loss / accuracy plots.

    Reported losses are averaged per sample (not per batch), so the smaller
    last batch of each loader does not bias the numbers.
    """
    trainloader, testloader = get_cifar100_loaders()

    # Initialize Model (Choose Net() or CustomDino())
    # model = Net().to(DEVICE)
    model = CustomDino(num_classes=100).to(DEVICE)  # Requires the DINO transform

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
    )
    # Cosine annealing over the full run; stepped once per epoch below.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    train_losses, test_losses, test_accs = [], [], []

    for epoch in range(EPOCHS):
        model.train()
        loss_sum = 0.0
        seen = 0
        for images, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size (see docstring).
            batch_len = labels.size(0)
            loss_sum += loss.item() * batch_len
            seen += batch_len

        scheduler.step()
        avg_train_loss = loss_sum / seen
        train_losses.append(avg_train_loss)

        avg_test_loss, acc = _evaluate(model, testloader, criterion)
        test_losses.append(avg_test_loss)
        test_accs.append(acc)

        print(f"Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f} | Test Acc: {acc:.2f}%")

    _plot_history(train_losses, test_losses, test_accs)

# Run the baseline: builds the CIFAR-100 loaders, trains for EPOCHS epochs
# (evaluating after each one) and then shows the loss / accuracy plots.
train_centralized()

Using cache found in C:\Users\marco/.cache\torch\hub\facebookresearch_dino_main
Epoch 1/100:   3%|▎         | 12/391 [03:00<1:35:04, 15.05s/it]


KeyboardInterrupt: 