<a href="https://colab.research.google.com/github/lucavgn/Project-5-Distributed-Learning/blob/main/code/Personal_contribution/DynamicSGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%writefile DynamicSGD.py

import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
from torch.utils.data import random_split, Subset, DataLoader
from google.colab import files
import torch.nn.functional as F
import time

# Wall-clock accumulators (seconds) that the training loop adds to.
total_computation_time = total_communication_time = 0

# ---- Training hyperparameters ----
# Worker counts (K) explored: 2, 4, 8.
# Initial local-step counts (J) explored: 4, 8, 16, 32, 64.

# Budget of local steps that makes up one epoch
# (presumably ceil(50000 / 64) mini-batches — TODO confirm).
iteration_per_epoch = 782
num_epochs = 150
batch_size = 64
base_learning_rate = 0.001
momentum = 0.9
weight_decay = 0.0004

# ---- Dynamic adjustment of the number of local steps ----
# Bounds between which the number of local steps is clamped.
H_min = 4
H_max = 64
# Scaling factor applied to the gradient-magnitude / variance ratio.
alpha = 1000
# Small constant added to the variance so the ratio never divides by zero.
epsilon = 1e-6

def compute_mean_std(dataset):
    """Compute the per-channel mean and standard deviation of an RGB dataset.

    Every item ``dataset[i]`` must be an ``(image, label)`` pair whose image
    converts to a 3-D array with three channels. Both layouts are accepted:
    channels-last ``(H, W, 3)`` and channels-first ``(3, H, W)``. The latter
    is what this script passes in, because it builds the dataset with
    ``transforms.ToTensor()``; indexing such tensors as ``[:, :, c]`` would
    select image columns rather than colour channels.

    The statistics are accumulated in a single pass over the dataset (so the
    dataset transform runs once per image) using float64 running sums.

    Args:
        dataset: Object supporting ``len()`` and integer indexing that yields
            ``(image, label)`` pairs.

    Returns:
        ``(mean, std)``: two 3-tuples of floats ordered like the channels.
        ``std`` is the population standard deviation (``ddof=0``), matching
        the ``np.std`` default.

    Raises:
        ValueError: If the dataset is empty or an image has no axis of
            length three in the first or last position.
    """
    num_channels = 3
    channel_sum = np.zeros(num_channels, dtype=np.float64)
    channel_sq_sum = np.zeros(num_channels, dtype=np.float64)
    pixels_per_channel = 0

    for i in range(len(dataset)):
        # Promote to float64 first so squaring uint8 pixels cannot overflow.
        image = np.array(dataset[i][0]).astype(np.float64)
        if image.ndim != 3:
            raise ValueError(
                f"expected a 3-D image, got an array of shape {image.shape}"
            )

        # Channels-last is checked first so inputs the (H, W, 3) indexing
        # handled correctly keep producing the same statistics.
        if image.shape[-1] == num_channels:
            per_channel = image.reshape(-1, num_channels).T
        elif image.shape[0] == num_channels:
            per_channel = image.reshape(num_channels, -1)
        else:
            raise ValueError(
                f"cannot find a 3-channel axis in an image of shape {image.shape}"
            )

        channel_sum += per_channel.sum(axis=1)
        channel_sq_sum += np.square(per_channel).sum(axis=1)
        pixels_per_channel += per_channel.shape[1]

    if pixels_per_channel == 0:
        raise ValueError("cannot compute statistics of an empty dataset")

    channel_mean = channel_sum / pixels_per_channel
    # Var[x] = E[x^2] - E[x]^2; clamp at zero to absorb rounding error.
    channel_var = np.maximum(
        channel_sq_sum / pixels_per_channel - np.square(channel_mean), 0.0
    )
    channel_std = np.sqrt(channel_var)

    mean = tuple(float(value) for value in channel_mean)
    std = tuple(float(value) for value in channel_std)
    return mean, std


# Define LeNet-5 architecture
class LeNet5(nn.Module):
    """LeNet-5-style CNN producing 100 class logits.

    Two ``5x5`` convolution + ReLU + ``2x2`` max-pool stages are followed by
    three fully connected layers. The flattened feature size of
    ``64 * 5 * 5`` matches 3x32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=384)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=384, out_features=192)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=192, out_features=100)

    def forward(self, x):
        """Map a batch of images ``x`` to a ``(batch, 100)`` tensor of logits."""
        features = self.conv1(x)
        features = self.relu1(features)
        features = self.pool1(features)

        features = self.conv2(features)
        features = self.relu2(features)
        features = self.pool2(features)

        # Collapse channel and spatial dimensions into one feature vector.
        flat = features.view(-1, 64 * 5 * 5)

        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# Command-line arguments: K is the number of workers, J the number of local
# steps used for the first round (later rounds derive it dynamically).
parser = argparse.ArgumentParser(description='Train with Local SGD')
parser.add_argument('--k', type=int, default=2, help='choose a K value for local SGD: [2, 4, 8]')
parser.add_argument('--j', type=int, default=4, help='choose a J value for local SGD: [4, 8, 16, 32, 64]')
args = parser.parse_args()

# Fail fast on values that would otherwise only crash after the dataset has
# been downloaded and the first round has been trained.
if args.k < 2:
    # The gradient-variance estimate divides by K - 1.
    parser.error('--k must be at least 2')
if args.j < 1:
    # With no local step in the first round no gradient exists to measure.
    parser.error('--j must be at least 1')

K = args.k
J = args.j
print(f"Training with K={K}, J={J}")


# Device setup: first CUDA GPU when one is available, CPU otherwise
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Fix the CPU and GPU seeds for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# Normalization statistics are computed once, on the training images
# converted with ToTensor, and shared by the train and test pipelines
data = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mean, std = compute_mean_std(data)

# Training pipeline: random crop / flip augmentation, then normalization
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Test pipeline: no augmentation, only tensor conversion and normalization
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=train_transform
)

# Shard the training set into K contiguous, equally sized index ranges
# (samples beyond K * shard_size are not assigned to any worker)
shard_size = len(trainset) // K
shards = [
    Subset(trainset, range(w * shard_size, w * shard_size + shard_size))
    for w in range(K)
]
# One shuffled batch iterator per worker; the training loop rebuilds an
# iterator whenever it is exhausted
shard_iterators = [
    iter(DataLoader(shard, batch_size=batch_size, shuffle=True))
    for shard in shards
]


# Test set, read in a fixed order
testset = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=test_transform
)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Scale the learning rate by the number of workers K. The message reports
# both values: the configured base rate and the scaled rate actually used.
learning_rate = base_learning_rate * K
print(f"Base learning rate: {base_learning_rate:.5f}, scaled by K={K}: {learning_rate:.5f}")

# One model replica per worker
nets = [LeNet5().to(device) for _ in range(K)]

# Start every replica from the weights of nets[0], the reference model
with torch.no_grad():
    reference_params = list(nets[0].parameters())
    for replica in nets[1:]:
        for target, source in zip(replica.parameters(), reference_params):
            target.data.copy_(source.data)

# One SGD optimizer and one cosine-annealing schedule per replica
optimizers = []
schedulers = []
for worker_net in nets:
    worker_optimizer = optim.SGD(
        worker_net.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    optimizers.append(worker_optimizer)
    schedulers.append(
        optim.lr_scheduler.CosineAnnealingLR(optimizer=worker_optimizer, T_max=num_epochs)
    )

criterion = nn.CrossEntropyLoss()

# Per-worker training history: one list of per-epoch values for each worker
train_losses_per_worker = [[] for _ in range(K)]
train_accuracies_per_worker = [[] for _ in range(K)]

# Per-epoch history of the test-set loss and accuracy
test_losses, test_accuracies = [], []

# Compute gradient magnitude and variance
def compute_gradient_stats(nets, optimizers=None):
    """Estimate gradient magnitude and variance across the worker models.

    The ``.grad`` tensors currently stored on each model's parameters are
    flattened into one vector ``g_i`` per worker. With ``n = len(nets)`` and
    ``g_bar`` the mean of those vectors:

        sigma_t^2 = sum_i ||g_i||^2 / (n - 1) - n / (n - 1) * ||g_bar||^2
        G_t       = ||g_bar||^2 - sigma_t^2 / n

    The first expression is algebraically the sample variance
    ``sum_i ||g_i - g_bar||^2 / (n - 1)``.

    Args:
        nets: Sequence of models whose parameters hold gradients.
        optimizers: Unused; accepted so existing call sites keep working.

    Returns:
        ``(G_t, sigma_t_squared)`` as Python floats. Either value can come
        out slightly negative (``G_t`` by construction, ``sigma_t_squared``
        through rounding).

    Raises:
        ValueError: If fewer than two models are given (the variance
            estimate divides by ``n - 1``) or a model has no gradients.
    """
    # One flat gradient vector per worker; parameters without a gradient
    # are skipped.
    gradients = []
    for net in nets:
        grad_parts = [
            param.grad.detach().reshape(-1)
            for param in net.parameters()
            if param.grad is not None
        ]
        if not grad_parts:
            raise ValueError(
                "a worker model has no gradients; run backward() before "
                "computing gradient statistics"
            )
        gradients.append(torch.cat(grad_parts))

    # Use the number of models actually passed in rather than a global
    # worker count, so the function is correct for any caller.
    num_workers = len(gradients)
    if num_workers < 2:
        raise ValueError(
            "at least two workers are required to estimate the gradient variance"
        )

    # Average gradient (g_t) and the squared norms entering both formulas.
    avg_gradient = torch.stack(gradients).mean(dim=0)
    squared_norms = [torch.norm(g).item() ** 2 for g in gradients]
    avg_squared_norm = torch.norm(avg_gradient).item() ** 2

    # Gradient variance (sigma_t^2).
    sigma_t_squared = (
        sum(squared_norms) / (num_workers - 1)
        - (num_workers / (num_workers - 1)) * avg_squared_norm
    )

    # Gradient magnitude (G_t).
    G_t = avg_squared_norm - (1 / num_workers) * sigma_t_squared

    return G_t, sigma_t_squared


# Update H dynamically based on gradient stats
def update_H(G_t, sigma_t_squared, H_min, H_max, alpha, epsilon):
    """Derive the next number of local steps from the gradient statistics.

    The raw value ``alpha * G_t / (sigma_t_squared + epsilon)`` is clamped to
    ``[H_min, H_max]`` and truncated to an integer.

    Args:
        G_t: Gradient-magnitude estimate.
        sigma_t_squared: Gradient-variance estimate. A variance cannot be
            negative, so negative inputs (rounding error when the workers'
            gradients nearly coincide) are treated as zero.
        H_min: Smallest number of local steps allowed.
        H_max: Largest number of local steps allowed.
        alpha: Scaling factor applied to the ratio.
        epsilon: Small positive constant keeping the denominator non-zero.

    Returns:
        The number of local steps, as an ``int``.
    """
    # Without this clamp a slightly negative estimate could cancel epsilon
    # (division by zero) or flip the sign of the ratio and select H_min
    # exactly when the variance is smallest.
    variance = max(sigma_t_squared, 0.0)
    ratio = alpha * G_t / (variance + epsilon)

    # NaN statistics (e.g. after a diverged step) fall back to the most
    # conservative choice instead of relying on NaN comparison quirks.
    if math.isnan(ratio):
        return int(H_min)

    H_t = min(H_max, max(H_min, ratio))
    return int(H_t)  # Ensure H_t is an integer

# Training loop (dynamic local SGD): in every round each worker takes H_t
# local SGD steps on its own shard, the worker models are averaged, and H_t
# is re-derived from the workers' gradient statistics.


def _timestamp():
    """Return the wall-clock time once queued GPU work has finished.

    CUDA kernels are launched asynchronously; without this barrier the cost
    of the parameter averaging would be charged to the following computation
    phase instead of to communication.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.time()


H_t = J  # Initialize H_t with the default J value


for epoch in range(num_epochs):
    # Per-worker running sums for this epoch
    train_loss_for_iteration = [0.0 for _ in range(K)]
    correct_train_for_iteration = [0 for _ in range(K)]
    total_train_for_iteration = [0 for _ in range(K)]

    # Local-step budget of the epoch; every round consumes H_t * K of it
    remaining_local_steps = iteration_per_epoch
    iteration = 0 # Number of global iteration performed
    while remaining_local_steps > 0:
        computation_start_time = _timestamp()
        # Train each worker for H_t steps
        for k in range(K):
            nets[k].train()
            correct_train, total_train, train_loss = 0, 0, 0.0
            for _ in range(H_t):  # Perform H_t local steps
                try:
                    images, labels = next(shard_iterators[k])
                except StopIteration:
                    # Shard exhausted: start a new shuffled pass over it
                    shard_iterators[k] = iter(DataLoader(shards[k], batch_size=batch_size, shuffle=True))
                    images, labels = next(shard_iterators[k])
                images, labels = images.to(device), labels.to(device)

                optimizers[k].zero_grad()
                outputs = nets[k](images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizers[k].step()

                # Calculate train loss and accuracy (loss is a batch mean,
                # so weight it by the batch size)
                train_loss += loss.item() * labels.size(0)
                _, predicted = outputs.max(1)
                total_train += labels.size(0)
                correct_train += predicted.eq(labels).sum().item()

            train_loss_for_iteration[k] += train_loss
            correct_train_for_iteration[k] += correct_train
            total_train_for_iteration[k] += total_train

        computation_end_time = _timestamp()
        total_computation_time += (computation_end_time - computation_start_time)

        communication_start_time = _timestamp()
        # Synchronize models: replace every worker's parameters with their
        # average (gradients and optimizer state are left untouched)
        with torch.no_grad():
            global_parameters = [torch.zeros_like(param) for param in nets[0].parameters()]
            for net in nets:
                for global_param, local_param in zip(global_parameters, net.parameters()):
                    global_param += local_param
            for global_param in global_parameters:
                global_param /= K
            for net in nets:
                for local_param, global_param in zip(net.parameters(), global_parameters):
                    local_param.data.copy_(global_param)

        communication_end_time = _timestamp()
        total_communication_time += (communication_end_time - communication_start_time)

        # update the remaining local-step budget and number of global iteration
        remaining_local_steps -= (H_t * K)
        iteration += 1
        # Compute gradient stats and update H_t; the gradients read here are
        # the ones left by each worker's last local step
        G_t, sigma_t_squared = compute_gradient_stats(nets, optimizers)
        H_t = update_H(G_t, sigma_t_squared, H_min, H_max, alpha, epsilon)
        print(f"Epoch {epoch+1}, Iteration {iteration}: H_t = {H_t}, G_t = {G_t:.4f}, σ_t^2 = {sigma_t_squared:.4f}")

    # Store training metrics for all workers after each epoch
    for k in range(K):
        # Store train loss and accuracy for this worker
        train_loss_for_iteration[k] /= total_train_for_iteration[k]
        train_losses_per_worker[k].append(train_loss_for_iteration[k])
        tot_accuracy = 100. * correct_train_for_iteration[k] / total_train_for_iteration[k]
        train_accuracies_per_worker[k].append(tot_accuracy)
        print(f"    Worker {k}: Train Loss: {train_losses_per_worker[k][-1]:.4f}, Train Accuracy: {train_accuracies_per_worker[k][-1]:.2f}%")

    # Update learning rate
    for scheduler in schedulers:
        scheduler.step()

    # Evaluate on test set; after synchronization all workers hold the same
    # parameters, so nets[0] serves as the global model
    global_model = nets[0]
    global_model.eval()
    correct_test, total_test, test_loss = 0, 0, 0.0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = global_model(inputs)
            loss = criterion(outputs, labels)
            # Weight each batch mean by its size so a shorter final batch is
            # not over-represented (same convention as the train loss)
            test_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()

    test_loss /= total_test
    test_losses.append(test_loss)
    test_accuracy = 100. * correct_test / total_test
    test_accuracies.append(test_accuracy)
    print(f"Epoch {epoch + 1}: Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

    print(f"Total Computation Time: {total_computation_time:.2f} seconds")
    print(f"Total Communication Time: {total_communication_time:.2f} seconds")


def _plot_curves(curves, figsize, title, y_label, out_file):
    """Draw labelled per-epoch curves on a new figure, save it and show it.

    ``curves`` is a list of ``(legend_label, values)`` pairs.
    """
    plt.figure(figsize=figsize)
    for legend_label, values in curves:
        plt.plot(values, label=legend_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_file)
    plt.show()


# Train loss, one curve per worker
_plot_curves(
    [(f'Worker {w} Train Loss', train_losses_per_worker[w]) for w in range(K)],
    figsize=(8, 6),
    title='Train Loss per Worker',
    y_label='Loss',
    out_file='train_loss_per_worker.png',
)

# Train accuracy, one curve per worker
_plot_curves(
    [(f'Worker {w} Train Accuracy', train_accuracies_per_worker[w]) for w in range(K)],
    figsize=(8, 6),
    title='Train Accuracy per Worker',
    y_label='Accuracy (%)',
    out_file='train_accuracy_per_worker.png',
)

# Test loss of the global model
_plot_curves(
    [('Test Loss', test_losses)],
    figsize=(6, 4),
    title='Test Loss',
    y_label='Loss',
    out_file='test_loss.png',
)

# Test accuracy of the global model
_plot_curves(
    [('Test Accuracy', test_accuracies)],
    figsize=(6, 4),
    title='Test Accuracy',
    y_label='Accuracy (%)',
    out_file='test_accuracy.png',
)


In [None]:
%run DynamicSGD.py --k 4 --j 32