<a href="https://colab.research.google.com/github/lucavgn/AML_Project5/blob/main/Distributed_LocalSGD_command_line.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%writefile localSGD.py
'''
prompt: Perform an iid sharding of the training set in K={2,4,8} chunks,
and train using the LocalSGD algorithm. Explore performing multiple local
steps J={4,8,16,32,64}, scaling accordingly the number of iterations
(total number of batches processed).
I have to use cifar-100 dataset
For the scheduler:
- learning rate scaled by k
- warmup epochs (5)
'''

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
from torch.utils.data import random_split, Subset, DataLoader
from google.colab import files

# Define hyperparameters
# K_values = [2, 4, 8]
J_values = [4, 8, 16, 32, 64]
batch_size = 32
num_epochs = 150
base_learning_rate = 0.01
momentum = 0.9
weight_decay = 4e-4
warmup_epochs = 5

def compute_mean_std(dataset):
    """Compute the mean and std of CIFAR-100 dataset.

    Args:
        dataset: A dataset derived from `torch.utils.data.Dataset`,
                 such as `cifar100_training_dataset` or `cifar100_test_dataset`.

    Returns:
        A tuple containing (mean, std) for the entire dataset.
    """

    # Extract images and labels
    data_r = np.stack([np.array(dataset[i][0])[:, :, 0] for i in range(len(dataset))])
    data_g = np.stack([np.array(dataset[i][0])[:, :, 1] for i in range(len(dataset))])
    data_b = np.stack([np.array(dataset[i][0])[:, :, 2] for i in range(len(dataset))])

    # Compute mean and std
    mean = np.mean(data_r), np.mean(data_g), np.mean(data_b)
    std = np.std(data_r), np.std(data_g), np.std(data_b)

    return mean, std

# Define LeNet-5 architecture
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(384, 192)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.fc3(x)
        return x

# Command-line arguments
parser = argparse.ArgumentParser(description='Train with Local SGD')
parser.add_argument('--k', type=int, default=2, help='choose a K value for local SGD')
args = parser.parse_args()
K = args.k

# Device setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42) # Set the seed for reproducibility
torch.cuda.manual_seed_all(42) # Set the seed for reproducibility on GPU

# use the same mean and std to add consistency to all datasets
data = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transforms.ToTensor())
mean, std = compute_mean_std(data)

# Load and split CIFAR-100 dataset
train_transform = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(mean, std)
])

val_transform = transforms.Compose([
    transforms.ToTensor(), # Convert into tensor
    transforms.Normalize(mean, std)  # Normalization
])

test_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to PyTorch tensor
    transforms.Normalize(mean, std)
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
valset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=val_transform)
indices = torch.randperm(len(trainset))
val_size = int(0.2*len(trainset))
trainset = torch.utils.data.Subset(trainset, indices[:-val_size])
valset = torch.utils.data.Subset(valset, indices[-val_size:])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


# Shard the training set
shard_size = len(trainset) // K
shards = [Subset(trainset, range(i * shard_size, (i + 1) * shard_size)) for i in range(K)]

# Initialize models for each worker
worker_models = [LeNet5().to(device) for _ in range(K)]  # Create a model for each worker

# Training loop with LocalSGD
for J in J_values:

    print(f"Training with K={K}, J={J}")

    # Initialize model and optimizer
    net = LeNet5().to(device)

    # Scale the learning rate by k
    learning_rate = base_learning_rate * K / warmup_epochs
    print(f"Base learning rate: {learning_rate:.5f}")

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    best_val_accuracy = 0  # Track the best validation accuracy

    # Calculate warm-up step size
    warmup_step = learning_rate / warmup_epochs

    for epoch in range(num_epochs):
        # Learning rate warm-up phase for first 5 epochs
        if epoch < warmup_epochs:
            for param_group in optimizer.param_groups:
                param_group['lr'] += warmup_step
        print(f"Learning rate at epoch {epoch + 1}: {param_group['lr']:.5f}")

        net.train()  # Set the model to training mode
        for k in range(K):
            shard_loader = DataLoader(shards[k], batch_size=batch_size, shuffle=True, num_workers=2)

            for j in range(J):
                for images, labels in shard_loader:
                    images, labels = images.to(device), labels.to(device)  # Move to device
                    optimizer.zero_grad()
                    outputs = worker_models[k](images)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

        # Perform model synchronization
        with torch.no_grad():
            global_parameters = [param.clone() for param in worker_models[0].parameters()]  # Start with the first worker
            for k in range(1, K):
                for global_param, worker_param in zip(global_parameters, worker_models[k].parameters()):
                    global_param += worker_param
            for global_param in global_parameters:
                global_param /= K

            # Update each worker model with the global parameters
            for worker_model in worker_models:
                for worker_param, global_param in zip(worker_model.parameters(), global_parameters):
                    worker_param.data.copy_(global_param)


        # Perform validation after processing all shards in the current epoch
        net.eval()  # Set the model to evaluation mode
        correct, total = 0, 0
        with torch.no_grad():  # No need to track gradients during validation
            for images, labels in valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_accuracy = 100 * correct / total
        print(f"Validation Accuracy after epoch {epoch + 1}: {val_accuracy:.2f}%")

        # Save the model if validation accuracy improves
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Optionally, save the model checkpoint here
            #torch.save(net.state_dict(), f"best_model_K{K}_J{J}.pth")

    # Final Evaluation on Test Set
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)  # Move to device
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    test_accuracy = 100 * correct / total
    print(f"Test Accuracy for K={K}, J={J}: {test_accuracy:.2f}%")


Writing localSGD.py


In [2]:
%run localSGD.py --k 2

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:04<00:00, 40.2MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Training with K=2, J=4
Base learning rate: 0.00400
Learning rate at epoch 1: 0.00480


KeyboardInterrupt: 