# Steps
1. Define Model
2. Pass copies of this Model parameters to Clients
3. Clients run x epochs on this model on their data
4. Send back the model parameters
5. Aggregate (FedAVG) the model parameters
6. Repeat from step 2 for y rounds.

## Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, datasets, transforms
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import copy
from numpy.random import default_rng

from collections import OrderedDict, defaultdict

## Basic Model

In [2]:
class BasicNet(nn.Module):
    def __init__(self):
        super(BasicNet, self).__init__()
        # conv layers 1,2
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # FC layers 1, 2, after Max pooling applied 3 times the size will be 4x4x128
        self.fc1 = nn.Linear(4*4*128, 500)
        self.fc2 = nn.Linear(500, 10)
        # drop out layer with p=0.2
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        # passing through Convolution and max pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flattening the image
        x = x.view(-1, 4*4*128)
        # drop out layer
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        # final class scores are sent as it is
        x = self.fc2(x)
        return x
        

## ResNet Model Pre-trained
We'll freeze other layer except the last 2

In [3]:
resnet18_model_pre = models.resnet18(pretrained=True)

In [4]:
resnet18_model_pre

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
# first freezing all params
for param in resnet18_model_pre.parameters():
    param.requires_grad = False

# replacing all last fc layer with the classifier, which
# itself consist of 2 layers, 'final_layer' is our final layer now
classifier = nn.Sequential(OrderedDict({
    'fc1' : nn.Linear(512, 256),
    'relu1': nn.ReLU(),
    'final_layer' : nn.Linear(256, 10)
}))

resnet18_model_pre.fc = classifier



### Resnet model Raw
Without Pre-trained params

In [6]:
resnet18_model = models.resnet18(pretrained=False)
resnet18_model.fc = nn.Sequential(OrderedDict({'final_layer' : nn.Linear(512, 10)}))

## Simulation of Client Training and Server Model Aggregation


### Defining Client Dataset class

In [7]:
class ClientDataset(Dataset):
    def __init__(self, img_tensors, lbl_tensors, transform=None):
        self.img_tensors = img_tensors
        self.lbl_tensors = lbl_tensors
        self.transform = transform
    
    def __len__(self):
        return self.lbl_tensors.shape[0]
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.img_tensors[idx], self.lbl_tensors[idx] 
    
def create_client_data_loaders(total_clients, data_folder, batch_size, random_mode=False):
    data_loaders = []
    for idx in range(total_clients):
        # loading data to tensors
        img_tensor_file = data_folder + f'client_{idx}_img.pt'
        lbl_tensor_file = data_folder + f'client_{idx}_lbl.pt'
        img_tensors = torch.load(img_tensor_file) # this contains 494 images, currently 76
        lbl_tensors = torch.load(lbl_tensor_file)

        # creating a dataset which can be fed to dataloader
        client_dataset = ClientDataset(img_tensors, lbl_tensors)
        data_loaders.append(DataLoader(client_dataset, batch_size=batch_size, shuffle=random_mode))
    return data_loaders

### Client Training Logic

In [8]:
def train_on_client(idx, model, data_loader, optimizer, loss_fn, local_epochs, device):
    model.train()
    client_training_losses = []
    for epoch in range(local_epochs):
        train_loss = 0.0
        for data, target in data_loader:
            # move tensors to GPU device if CUDA is available
            data, target = data.to(device), target.to(device)
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the batch loss
            loss = loss_fn(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update training loss
            train_loss += loss.item()
        
        epoch_train_loss = train_loss/len(data_loader)
        client_training_losses.append(epoch_train_loss)
        print('Client: {}\t Epoch: {} \tTraining Loss: {:.6f}'.format(idx, epoch, epoch_train_loss))
    return client_training_losses

### Server Aggregation Logic

In [9]:
def fed_avg(server_model, client_models, client_weights):
    # Safety lock, to not update model params accidentally
    with torch.no_grad():
        # need to take avg key-wise
        for key in server_model.state_dict().keys():
            temp = torch.zeros_like(server_model.state_dict()[key], dtype=torch.float32)
            for idx in range(len(client_weights)):
                temp += client_weights[idx]*client_models[idx].state_dict()[key]
            # update the new value of this key in the server model
            server_model.state_dict()[key].data.copy_(temp)
            # update this key value in all the client models as well
            for idx in range(len(client_weights)):
                client_models[idx].state_dict()[key].data.copy_(temp)
    return server_model, client_models
    

### Testing the server model on Test Set

In [28]:
def run_test(model, test_loader, loss_fn, device):
    test_loss = 0.0
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    # specify the image classes
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    model.eval()
    # iterate over test data
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # calculate the batch loss
        loss = loss_fn(output, target)
        # update test loss 
        test_loss += loss.item()
        # convert output probabilities to predicted class
        top_p, pred_class = torch.max(output, 1)    
        # compare predictions to true label
        correct_tensor = pred_class.eq(target.data.view_as(pred_class))
        correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())
        # calculate test accuracy for each object class
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1
    # average test loss
    test_loss = test_loss/len(test_loader.dataset)
    print('Test Loss: {:.6f}\n'.format(test_loss))
    
    for i in range(10):
        if class_total[i] > 0:
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
                classes[i], (100 * class_correct[i]) / class_total[i],
                np.sum(class_correct[i]), np.sum(class_total[i])))
        else:
            print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))

    print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
        100. * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct), np.sum(class_total)))
    test_acc = np.sum(class_correct) / np.sum(class_total)
    return test_loss, test_acc

# run_test(server_model, test_loader, loss_fn, device)

## Main Configuration Cell

In [29]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Computing Device:{device}")
seed = 42
rng = default_rng(seed)
np.random.seed(seed)
torch.manual_seed(seed)     
torch.cuda.manual_seed_all(seed)

# specify which model to use
server_model = resnet18_model_pre
# server_model = resnet18_model
# server_model = BasicNet()

# using gpu for computations if available
server_model = server_model.to(device)

# specify loss function (categorical cross-entropy)
loss_fn = nn.CrossEntropyLoss()

# choose how many clients you want send model to
client_frac = 0.3
total_clients = 10
client_training_losses = [ [] for i in range(total_clients)]
client_training_accs = [ [] for i in range(total_clients)]
client_weights = [1/total_clients for i in range(total_clients)] # need to check about this
client_models = [copy.deepcopy(server_model).to(device) for idx in range(total_clients)]

fed_rounds = 12
local_epochs = 3
batch_size = 32

# Save the tensor of images and labels for clients
username = 'fnx11'
data_folder = f'/home/{username}/thesis/codes/Playground/data/fed_data/'
logs_folder = f'/home/{username}/thesis/codes/Playground/logs/'
writer = SummaryWriter(logs_folder+'fed_cifar10_experiment')

# specify learning rate to be used
learning_rate = 0.01 # change this according to our model, tranfer learning use 0.001, basic model use 0.01
optimizers = [optim.SGD(params=client_models[idx].parameters(), lr=learning_rate) for idx in range(total_clients)]
client_data_loaders = create_client_data_loaders(total_clients, data_folder, batch_size)

# Define a transform to normalize the data
data_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])])
testset = datasets.CIFAR10('~/.pytorch/CIFAR10_data/', train=False, download=True, transform=data_transforms)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size)

Computing Device:cuda:0
Files already downloaded and verified


In [31]:
###Training and testing model every kth round
k = 4
testing_losses = []
testing_accs = []
for i in range(fed_rounds):
    clients_selected = rng.choice(total_clients, size=int(total_clients*client_frac), replace=False)
    print(clients_selected)
    for j in clients_selected:
        training_losses = train_on_client(j, client_models[j], client_data_loaders[j], optimizers[j], loss_fn, local_epochs, device)
        client_training_losses[j].extend(training_losses)
#         client_training_accs[j].extend(training_accs)
    # aggregate to update server_model and client_models
    print(f"Round {i} complete")
    server_model, client_models = fed_avg(server_model, client_models, client_weights)
    # Testing Model every kth round
    if (i+1)%k==0:
        test_loss, test_acc = run_test(server_model, test_loader, loss_fn, device)
        testing_losses.append(test_loss)
        testing_accs.append(test_acc)

###Logging of losses
for i in range(total_clients):
    for j in range(len(client_training_losses[i])):
        writer.add_scalar(f'training_loss_client_{i}', client_training_losses[i][j], j+1)

###Logging of Testing losses and testing accs
for e in range(len(testing_losses)):
    writer.add_scalar('testing loss', testing_losses[e], e+1)
    writer.add_scalar('model accuracy', testing_accs[e], e+1)


[0 9 6]
Client: 0	 Epoch: 0 	Training Loss: 1.811408
Client: 0	 Epoch: 1 	Training Loss: 1.793063
Client: 0	 Epoch: 2 	Training Loss: 1.770396
Client: 9	 Epoch: 0 	Training Loss: 1.818543
Client: 9	 Epoch: 1 	Training Loss: 1.797690
Client: 9	 Epoch: 2 	Training Loss: 1.776153
Client: 6	 Epoch: 0 	Training Loss: 1.802555
Client: 6	 Epoch: 1 	Training Loss: 1.777556
Client: 6	 Epoch: 2 	Training Loss: 1.754153
Round 0 complete
[4 8 7]
Client: 4	 Epoch: 0 	Training Loss: 1.816725
Client: 4	 Epoch: 1 	Training Loss: 1.795713
Client: 4	 Epoch: 2 	Training Loss: 1.772180
Client: 8	 Epoch: 0 	Training Loss: 1.816946
Client: 8	 Epoch: 1 	Training Loss: 1.798205
Client: 8	 Epoch: 2 	Training Loss: 1.776107
Client: 7	 Epoch: 0 	Training Loss: 1.796422
Client: 7	 Epoch: 1 	Training Loss: 1.770936
Client: 7	 Epoch: 2 	Training Loss: 1.747796
Round 1 complete
[4 6 1]
Client: 4	 Epoch: 0 	Training Loss: 1.809625
Client: 4	 Epoch: 1 	Training Loss: 1.788172
Client: 4	 Epoch: 2 	Training Loss: 1.7649