
# CS242: Assignment 2





> Harvard CS 242: Computing at Scale (Spring 2020)
> 
> Instructor: Professor HT Kung


### **Assignment Instructions**

Read the following instructions carefully before starting the assignment and again before submitting your work:

* This programming assignment must be completed with the same group used in Assignment 1.  **If you have any issues with this arrangement, please email Marcus.**
* We expect this assignment to take more time than Assignment 1 (there is a more significant programming element as well as training time required for the models). Again, we suggest that you start your effort right away.
* The assignment consists of two files: this Google Colab file (an .ipynb file) and a latex answer template ([download here](https://drive.google.com/open?id=1I-D2CCBZRICzCnEB881Ublp3tORN6Dq5)).
* The Google Colab contains all assignment instructions and *Code Cells* that you will use to implement the programming components of the assignment (in Python).
* We provide a significant amount of the code to make it easier to get started. In the *Code Cells*, please add comments to explain the purpose of each line of code in your implementation. **You will not receive credit for implementations that are not well documented.**
* <font color='red'>**Deliverables highlighted in red**</font> are given in this Google Colab file. Use the latex answer template to write down answers for these deliverables.
* Each group will submit both a PDF of your latex answers and your Google Colab file (.ipynb file) containing all completed *Code Cells* to "Programming Assignment 2" on Canvas. Only one submission per group. Check your .ipynb file using this [tool](https://htmtopdf.herokuapp.com/ipynbviewer/) before submitting to ensure that you completed all *Code Cells* (including detailed comments).
* The assignment is due on April 13, 2020 at Noon EST.
* Each part that you are asked to implement is relatively small in isolation and should be easy to test. We strongly recommend that you test each of these parts before training the large models, so that you do not waste time training models while your implementation may still have bugs. For example, make sure that your sampling is done correctly: with incorrect sampling the model will still train, but your results will not be correct. For a number of sections, we have provided checks you can run to verify correctness before training the large model.

-----
The outline of this assignment with point values and training estimates is given below. Note that these training estimates represent a minimum running time which assume that your implementation is correct.

1. Exploring Federated Learning (FL) [25 points] [Training Estimate: 2 hours]

2. Non-IID Federated Learning and Fairness [30 points] [Training Estimate: 3 hours]

3. Quantization of Local Models for Reduced Communication Cost [25 points] [Training Estimate: 3 hours]

4. Extreme Anomaly Detection [20 points] [Training Estimate: 1.5 hours]


---

### **1. Exploring Federated Learning (FL)**

---
For consistency, this assignment uses the exact same dataset (CIFAR-10) and CNN model (`ConvNet`) as Assignment 1. *Code Cell 1.1* creates the CIFAR-10 training and testing datasets and defines the CNN (`ConvNet`) that will be used throughout the assignment.

In [0]:
## Code Cell 1.1

import time
import copy
import sys
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


# Using CIFAR-10 again as in Assignment 1
# Load training data
transform_train = transforms.Compose([                                   
    transforms.RandomCrop(32, padding=4),                                       
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                        download=True,
                                        transform=transform_train)

# Load testing data
transform_test = transforms.Compose([                                           
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True,
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False,
                                         num_workers=2)


# Using same ConvNet as in Assignment 1
def conv_block(in_channels, out_channels, kernel_size=3, stride=1,
               padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                  bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
        )

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.model = nn.Sequential(
            conv_block(3, 32),
            conv_block(32, 32),
            conv_block(32, 64, stride=2),
            conv_block(64, 64),
            conv_block(64, 64),
            conv_block(64, 128, stride=2),
            conv_block(128, 128),
            conv_block(128, 256),
            conv_block(256, 256),
            nn.AdaptiveAvgPool2d(1)
            )

        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        h = self.model(x)
        B, C, _, _ = h.shape
        h = h.view(B, C)
        return self.classifier(h)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz



Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


**Federated Learning Overview**

Federated Learning (FL) distributes the task of training a deep neural network (such as our CNN `ConvNet`) across multiple clients (each with a device). Each client has their own private data that they do not want to share with a central server. Therefore, instead of transmitting data, clients perform training locally and send the updated model parameters (e.g., convolutional weights) to the server. The server averages these parameters across multiple clients to update the centralized model. Finally, after the centralized model has been updated, the server sends the new version of the model to all clients.
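The loop just described (local training on private data, server-side averaging, redistribution of the averaged model) can be sketched end to end on a toy problem. This is a minimal illustration under stated assumptions, not the assignment's implementation: it assumes ten hypothetical clients that each hold a private least-squares dataset, trains a two-parameter linear model with plain NumPy gradient descent instead of `ConvNet`, and lets a random 30% of the clients participate in each round. The helper names `local_train` and `fedavg_round` are illustrative.

```python
import numpy as np

def local_train(w, X, y, lr=0.1, steps=10):
    """Run gradient descent on the mean squared error using only this client's data."""
    w = w.copy()
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def fedavg_round(w_global, clients, participation, rng):
    """One round: select clients, train locally starting from the global model, average."""
    k = max(1, round(participation * len(clients)))
    chosen = rng.choice(len(clients), size=k, replace=False)             # selection
    local_models = [local_train(w_global, *clients[i]) for i in chosen]  # local training
    return np.mean(local_models, axis=0)                                 # server-side averaging

rng = np.random.default_rng(0)
w_true = np.array([2.0, -1.0])
clients = []  # each client keeps its own (X, y); only model weights leave the client
for _ in range(10):
    X = rng.normal(size=(50, 2))
    clients.append((X, X @ w_true + 0.01 * rng.normal(size=50)))

w_global = np.zeros(2)  # the server's initial model, sent to the clients every round
for _ in range(20):
    w_global = fedavg_round(w_global, clients, participation=0.3, rng=rng)
print(w_global)  # close to w_true = [2, -1]
```

Swapping the toy pieces for `train` and `average_weights` gives roughly the structure of *Code Cell 1.6*.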

The figure below depicts this Federated Learning paradigm (taken from [Towards Federated Learning at Scale: System Design](https://arxiv.org/pdf/1902.01046.pdf)). At the beginning of a training round in the selection phase, a percentage of devices (i.e., clients) agree to participate.  By agreeing to participate, the client agrees to perform local training on its own dataset that resides on the device. During the configuration phase, the up-to-date centralized model is sent to the participating clients, which then perform local training. In the reporting phase, each client sends their updated model (trained on local data) to the server for aggregation. Note that in the figure, one of the clients fails to report back to the central server (either due to device or network failure). In this programming assignment, for simplicity we will assume that this type of device/network failure is not possible. 

<figure>
<center>
<img src='https://drive.google.com/uc?id=1y8HAIxtNaZVLWetXHEzJ4UXWJ0yX_Jo0' />
</center>
</figure>


**Simulating Federated Learning**

In this assignment, we will simulate this distributed Federated Learning environment on a single machine (a Colab instance). Each `device` will own a non-overlapping subset (or partition) of the dataset (e.g., 10% of the CIFAR-10 training set) and use it to train a local version of the model. The main difference between this simulated environment and a real system is the lack of networking between devices.

You will use the `DatasetSplit` class in *Code Cell 1.2* to create subsets of the full training dataset. The `create_device` function creates a private copy of the given model (e.g., `ConvNet`), a `DataLoader` over a `DatasetSplit` of the training set, and an optimizer and learning-rate scheduler for training. This function will be called once per device to create all the device instances used for Federated Learning. The `train` and `test` functions are modified versions of those in Assignment 1 that take a device (the output of `create_device`) as an argument. The training batch size is 128 throughout the assignment; it is passed to `create_device` as a default parameter value (i.e., `batch_size=128`).
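`create_device` attaches a `MultiStepLR` scheduler, which multiplies the learning rate by `gamma` each time the number of `scheduler.step()` calls reaches a milestone. Part 1.1 steps the scheduler once per epoch and Part 1.3 steps it once per round, so the milestones are counted in those units. The pure-Python helper below (a hypothetical `multistep_lr`, not part of PyTorch) is a sketch that reproduces this schedule so the decay points can be checked without a GPU.

```python
from bisect import bisect_right

def multistep_lr(base_lr, milestones, gamma, num_steps):
    """Learning rate after `num_steps` calls to scheduler.step() under a MultiStepLR-style schedule."""
    # bisect_right counts the milestones already reached (milestone m is reached once num_steps >= m)
    return base_lr * gamma ** bisect_right(sorted(milestones), num_steps)

# Default schedule in create_device: lr=0.1, milestones=[25, 50, 75], gamma=0.1
for step in (0, 24, 25, 74, 75):
    print(step, multistep_lr(0.1, [25, 50, 75], 0.1, step))
```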

In [0]:
## Code Cell 1.2

class DatasetSplit(torch.utils.data.Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, torch.tensor(label)
  
def create_device(net, device_id, trainset, data_idxs, lr=0.1,
                  milestones=None, batch_size=128):
    # Default learning-rate decay points, counted in scheduler.step() calls
    if milestones is None:
        milestones = [25, 50, 75]

    device_net = copy.deepcopy(net)
    optimizer = torch.optim.SGD(device_net.parameters(), lr=lr, momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=milestones,
                                                     gamma=0.1)
    device_trainset = DatasetSplit(trainset, data_idxs)
    device_trainloader = torch.utils.data.DataLoader(device_trainset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=2)
    return {
        'net': device_net,
        'id': device_id,
        'dataloader': device_trainloader, 
        'optimizer': optimizer,
        'scheduler': scheduler,
        'train_loss_tracker': [],
        'train_acc_tracker': [],
        'test_loss_tracker': [],
        'test_acc_tracker': [],
        }
  
def train(epoch, device):
    device['net'].train()  # training mode for this device's own model copy
    train_loss, correct, total = 0, 0, 0
    for batch_idx, (inputs, targets) in enumerate(device['dataloader']):
        inputs, targets = inputs.cuda(), targets.cuda()
        device['optimizer'].zero_grad()      # clear gradients from the previous step
        outputs = device['net'](inputs)      # forward pass
        loss = criterion(outputs, targets)   # criterion is defined globally in later cells
        loss.backward()                      # backpropagate
        device['optimizer'].step()           # SGD update of device['net']
        train_loss += loss.item()
        device['train_loss_tracker'].append(loss.item())
        loss = train_loss / (batch_idx + 1)  # running mean loss for this epoch
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total         # running accuracy in percent
        dev_id = device['id']
        sys.stdout.write(f'\r(Device {dev_id}/Epoch {epoch}) ' + 
                         f'Train Loss: {loss:.3f} | Train Acc: {acc:.3f}')
        sys.stdout.flush()
    device['train_acc_tracker'].append(acc)
    sys.stdout.flush()

def test(epoch, device):
    device['net'].eval()  # evaluation mode: BatchNorm uses its running statistics
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = device['net'](inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            device['test_loss_tracker'].append(loss.item())
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    loss = test_loss / (batch_idx + 1)  # mean loss per test batch
    acc = 100. * correct / total        # test accuracy in percent
    sys.stdout.write(f' | Test Loss: {loss:.3f} | Test Acc: {acc:.3f}\n')
    sys.stdout.flush()
    device['test_acc_tracker'].append(acc)

**Single Device Scenario**

Before implementing Federated Learning, we will train a model on a single client device using only its local data, without sending updates to a central server. Because the device only sees a small percentage of the CIFAR-10 training set (10% in this case), it should perform poorly compared with a model trained on the full dataset.

---
<font color='red'>**PART 1.1:**</font> [5 points]

<font color='red'>**Deliverables**</font>
1. In *Code Cell 1.3*, implement the `iid_sampler` function to generate **iid** (independent and identically distributed) samples from the CIFAR-10 training set. We will use this function to generate training subsets for multiple devices in PART 1.2.
2. In *Code Cell 1.4*, create a single device using the `create_device` function. This device should have 10% of the CIFAR-10 training set using the `iid_sampler` function.
3. Train the model for 1000 epochs using the parameters specified in *Code Cell 1.4* (similar to Assignment 1). The number of epochs is 10x greater than in Assignment 1 because the single device has only one tenth of the data. Plot the test accuracy (`device['test_acc_tracker']`) and comment on the classification accuracy compared to using 100% of the dataset as in Assignment 1.
---

In [0]:
## Code Cell 1.3
import random #to use the random.sample method

def iid_sampler(dataset, num_devices, data_pct):
    '''
    dataset: PyTorch Dataset (e.g., CIFAR-10 training set)
    num_devices: integer number of devices to create subsets for
    data_pct: percentage of training samples to give each device
              e.g., 0.1 represents 10%

    return: a dictionary of the following format:
      {
        0: [3, 65, 2233, ..., 22] // device 0 sample indexes
        1: [0, 2, 4, ..., 583] // device 1 sample indexes
        ...
      }

    iid (independent and identically distributed) means that the indexes
    should be drawn independently in a uniformly random fashion.
    '''

    # total number of samples in the dataset
    total_samples = len(dataset)

    # Part 1.1: Implement!
    d = {}  # maps device id -> list of sample indexes
    for i in range(num_devices):  # draw a subset for every device
        # Sample round(data_pct * total_samples) distinct indexes uniformly at random
        # (without replacement within a device; different devices may share indexes)
        d[i] = random.sample(range(total_samples), k=round(data_pct * total_samples))
    return d  # return the dictionary
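Because an incorrect sampler still lets the model train, it can be worth checking the output of `iid_sampler` before any long run. The checker below is a hedged sketch (the name `check_iid_split` is hypothetical): it verifies that each device received `round(data_pct * len(dataset))` distinct, in-range indexes, and prints a class histogram, which under iid sampling should be roughly uniform. It is demonstrated on a toy split drawn the same way as *Code Cell 1.3*; with the notebook loaded, one could call `check_iid_split(iid_sampler(trainset, 3, 0.1), len(trainset), 0.1)`. Devices may legitimately share indexes, since each subset is drawn independently.

```python
import random
from collections import Counter

def check_iid_split(device_idxs, dataset_size, data_pct):
    """Raise AssertionError if any device's index list has the wrong size, duplicates, or bad indexes."""
    expected = round(data_pct * dataset_size)
    for dev, idxs in device_idxs.items():
        assert len(idxs) == expected, f'device {dev}: {len(idxs)} samples, expected {expected}'
        assert len(set(idxs)) == len(idxs), f'device {dev}: duplicate indexes'
        assert all(0 <= i < dataset_size for i in idxs), f'device {dev}: index out of range'
    return True

# Toy split drawn like iid_sampler: 3 devices, 10% each of a 1,000-example dataset
rng = random.Random(0)
toy_split = {d: rng.sample(range(1000), k=100) for d in range(3)}
print(check_iid_split(toy_split, 1000, 0.1))  # True

# Toy labels with 10 balanced classes; an iid sample should hold roughly 10 of each
toy_labels = [i % 10 for i in range(1000)]
print(Counter(toy_labels[i] for i in toy_split[0]))
```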

Now, perform training using a single device on a subset of the training dataset using your `iid_sampler`.

In [0]:
## Code Cell 1.4

data_pct = 0.1
epochs = 1000
num_devices = 1
device_pct = 0.1
net = ConvNet().cuda()
criterion = nn.CrossEntropyLoss()

# Part 1.1: Use iid_sampler to generate data_idxs for create_device
data_idxs = iid_sampler(trainset, num_devices, data_pct)

# Part 1.1: Create the device
device = create_device(net, 0, trainset, data_idxs[0],
                       milestones=[250, 500, 750])

# Part 1.1: Train the device model for 1000 epochs and plot the result
# Standard Training Loop
start_time = time.time()
for epoch in range(epochs):
    train(epoch, device)
    # To speed up running time, only evaluate the test set every 10 epochs
    if epoch > 0 and epoch % 10 == 0:
        test(epoch, device)
    device['scheduler'].step()


total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

(Device 0/Epoch 5) Train Loss: 1.664 | Train Acc: 36.719

KeyboardInterrupt: ignored

**Implementing Components for Federated Learning**

In PART 1.1, you implemented an `iid_sampler`, created a 10% subset of the CIFAR-10 training set, and used it to train a single client device model. Since the client only had a 10% subset of the full CIFAR-10 training set, it performed significantly worse than the same model trained on 100% of the training set.  

Federated Learning aims to improve the performance of these client devices by averaging the updates from multiple clients over the course of training. This way, the centralized model can be updated using training data stored on local devices without the server ever accessing that data. With more client devices, the union of their local data approaches the entire original dataset. In a sense, this simulates traditional minibatch gradient descent, except that each client performs additional local epochs (with minibatches of size 128) before averaging. An additional benefit of Federated Learning is that the centralized server does not require large compute resources, as most of the training computation is performed on local devices. This makes the training computation "free" for the centralized server, since the clients pay the compute cost on their local devices.
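The comparison with gradient descent can be made precise in a simple case. Assuming equal-sized clients and a single full-batch gradient step per client, averaging the client models gives exactly one centralized gradient step, because $\frac{1}{K}\sum_k (w - \eta g_k) = w - \eta \frac{1}{K}\sum_k g_k$ and the average of the clients' mean gradients $g_k$ equals the mean gradient over all their data. With several local steps the two procedures generally drift apart. The NumPy sketch below checks both statements on a toy least-squares problem; the helpers `mse_grad` and `gd` are illustrative.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of the mean squared error of a linear model."""
    return 2.0 * X.T @ (X @ w - y) / len(y)

def gd(w, X, y, lr, steps):
    """Full-batch gradient descent for a fixed number of steps."""
    for _ in range(steps):
        w = w - lr * mse_grad(w, X, y)
    return w

rng = np.random.default_rng(1)
X = rng.normal(size=(400, 3))
y = X @ np.array([1.0, 0.5, -2.0]) + 0.1 * rng.normal(size=400)
shards = np.array_split(np.arange(400), 4)  # 4 clients with 100 samples each
w0, lr = rng.normal(size=3), 0.05

for local_steps in (1, 4):
    # Average of the client models vs. the same number of centralized steps on all 400 samples
    fed = np.mean([gd(w0, X[s], y[s], lr, local_steps) for s in shards], axis=0)
    central = gd(w0, X, y, lr, local_steps)
    print(local_steps, np.abs(fed - central).max())  # round-off only for 1 step; visibly larger for 4
```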

---
<font color='red'>**PART 1.2:**</font> [10 points]

Before implementing Federated Learning, you must implement two functions which will be used during the training process.

The `average_weights` function takes in multiple device models, and computes the average for each model parameter across all models. This function will be called by the centralized server to aggregate the training performed by the end user devices.  This averaging is done in 32-bit floating point.
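As a point of comparison for `average_weights`, the sketch below averages PyTorch `state_dict`s by stacking each tensor in float32. It assumes every model has the same architecture, and it includes a BatchNorm layer to show one subtlety: `state_dict()` also contains buffers such as the integer `num_batches_tracked`, so this sketch casts each averaged tensor back to its original dtype before `load_state_dict`. The helper name `average_state_dicts` and the toy model are hypothetical.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

def average_state_dicts(state_dicts):
    """Element-wise mean of matching tensors, computed in float32 and cast back to each tensor's dtype."""
    avg = OrderedDict()
    for key in state_dicts[0]:
        stacked = torch.stack([sd[key].float() for sd in state_dicts])  # float32 copy of every client's tensor
        avg[key] = stacked.mean(dim=0).to(state_dicts[0][key].dtype)    # e.g. int64 for num_batches_tracked
    return avg

# Three toy "clients" whose parameters are all filled with 0, 1 and 2 respectively
models = [nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2)) for _ in range(3)]
for i, m in enumerate(models):
    with torch.no_grad():
        for p in m.parameters():
            p.fill_(float(i))

avg = average_state_dicts([m.state_dict() for m in models])
print(avg['0.weight'])            # every entry is 1.0 = mean(0, 1, 2)
models[0].load_state_dict(avg)    # dtypes and shapes match, so loading succeeds
```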

The `get_devices_for_round` function will be used to simulate the device rejection phase shown earlier in the figure in the **Federated Learning Overview** section. This function will select a percentage of devices to participate in each training round.
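A selection function of this kind only needs to sample devices without replacement. One detail worth checking when trying several `device_pct` values: Python 3's `round` rounds halves to the nearest even integer, so a small fraction of a small pool can round to zero participating devices. The sketch below (hypothetical name `select_devices`) clamps the count to at least one; whether that clamp suits the assignment is an assumption, not a requirement stated in the instructions.

```python
import random

def select_devices(devices, device_pct, rng=random):
    """Sample distinct devices for one round; the count is clamped to at least one (an assumption)."""
    k = max(1, round(device_pct * len(devices)))  # Python 3 round(): halves go to the even neighbour
    return rng.sample(devices, k)                 # sampling without replacement

devices = [{'id': i} for i in range(50)]
rng = random.Random(0)
chosen = select_devices(devices, 0.1, rng)
print(len(chosen), len({d['id'] for d in chosen}))  # 5 5

print(round(0.25 * 2))                              # 0: round(0.5) goes down to the even integer
print(len(select_devices(devices[:2], 0.25, rng)))  # 1, because of the clamp
```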

<font color='red'>**Deliverables**</font>
1. In *Code Cell 1.5*, implement the `average_weights` function. We have provided test code which you can use to validate your implementation. This test code will also be useful for the full implementation of Federated Learning in PART 1.3.
2. In *Code Cell 1.5*, implement the `get_devices_for_round` function. Try multiple `device_pct` settings to ensure that it is working properly.
---

In [0]:
## Code Cell 1.5


def average_weights(devices):
    '''
    devices: a list of devices generated by create_device
    Returns the average of the weights.
    '''
    # Part 1.2: Implement!
    # Hint: device['net'].state_dict() will return an OrderedDict of all
    #       tensors in the model. Return the average of each tensor using
    #       an OrderedDict so that you can update the global model using
    #       device['net'].load_state_dict(w_avg), where w_avg is the 
    #       averaged OrderedDict over all devices
    global_tensors = copy.deepcopy(devices[0]['net'].state_dict()) #initialize a global tensor with the weights of the first device
    for i in range(1, len(devices)):#iterate over the remaining devices
        #for easy/ less complicated referencing, store the device and the state_dict
        d = devices[i]
        d_tensors = d['net'].state_dict()
        for j in global_tensors.keys(): #add the tensors together by the key they are indexed by
            global_tensors[j] += d_tensors[j]
    for j in global_tensors.keys(): #average each tensor by the number of devices
        global_tensors[j] = global_tensors[j]/len(devices)
    return global_tensors #return the averaged weights


def get_devices_for_round(devices, device_pct):
    '''
    This function will select a percentage of devices to participate in each training round.
    '''
    # Part 1.2: Implement!
    #randomly choose device_pct*len(devices) devices from the devices array without replacement
    arr = random.sample(devices, k=round(device_pct*len(devices)))
    return arr

# Test code for average_weights
# Hint: This test may be useful for Part 1.3!
class TestNetwork(nn.Module):
    '''
    A simple 2 layer MLP used for testing your average_weights implementation.
    '''
    def __init__(self):
        super(TestNetwork, self).__init__()
        self.layer1 = nn.Linear(2, 2)
        self.layer2 = nn.Linear(2, 4)
    
    def forward(self, x):
        h = F.relu(self.layer1(x))
        return self.layer2(h)

data_pct = 0.05
num_devices = 2
net = TestNetwork()
data_idxs = iid_sampler(trainset, num_devices, data_pct)
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

# Fixed seeding to compare against precomputed correct_weight_averages below
torch.manual_seed(0)
devices[0]['net'].layer1.weight.data.normal_()
devices[0]['net'].layer1.bias.data.normal_()
devices[0]['net'].layer2.weight.data.normal_()
devices[0]['net'].layer2.bias.data.normal_()
devices[1]['net'].layer1.weight.data.normal_()
devices[1]['net'].layer1.bias.data.normal_()
devices[1]['net'].layer2.weight.data.normal_()
devices[1]['net'].layer2.bias.data.normal_()

# Precomputed correct averages
correct_weight_averages = OrderedDict(
    [('layer1.weight', torch.tensor([[ 0.3245, -0.9013], [-0.9042,  1.0125]])),
     ('layer1.bias', torch.tensor([-0.0724, -0.3119])),
     ('layer2.weight', torch.tensor([[0.2976,  1.0509], [-1.0048, -0.5972],
                                     [-0.3088, -0.2682], [-0.1690, -0.1060]])),
     ('layer2.bias', torch.tensor([-0.4396,  0.3327, -1.3925,  0.3160]))
    ])

# Computed weight averages
computed_weight_averages = average_weights(devices)

mismatch_found = False
for correct, computed in zip(correct_weight_averages.items(),
                             computed_weight_averages.items()):
    if not torch.allclose(correct[1], computed[1], atol=1e-2):
        mismatch_found = True
        print('Mismatch in tensor:', correct[0])
        print(correct[1], computed[1]) #our debugging

if not mismatch_found:
    print('Implementation output matches!')

Implementation output matches!


**Federated Learning Training**

---
<font color='red'>**PART 1.3:**</font> [10 points]

We will now run federated learning in the iid setting using the functions you wrote previously in this section. The parameters are given to you in the code. You will use 100 rounds of federated learning updates. In each round, every participating device completes four epochs of local training. The participating devices, 10% of all devices, are selected with the `get_devices_for_round` function you wrote previously. Note that we use a static initialization for the models between all parts of the assignment.
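A quick back-of-the-envelope check of the compute implied by these parameters, assuming the 50,000-image CIFAR-10 training set and counting the final partial minibatch as one step. Note that 50 devices at 10% each request five times the training set, so the iid subsets necessarily overlap.

```python
import math

train_size = 50_000  # CIFAR-10 training images
num_devices, device_pct, data_pct = 50, 0.1, 0.1
rounds, local_epochs, batch_size = 100, 4, 128

devices_per_round = round(num_devices * device_pct)             # 5 devices participate per round
samples_per_device = round(data_pct * train_size)               # 5000 images per device
batches_per_epoch = math.ceil(samples_per_device / batch_size)  # 40 minibatches (39 full + one of 8)
steps_per_round = devices_per_round * local_epochs * batches_per_epoch
images_per_round = devices_per_round * local_epochs * samples_per_device

print('SGD steps per round:', steps_per_round)                       # 800
print('SGD steps in total:', steps_per_round * rounds)               # 80000
print('epoch-equivalents:', images_per_round * rounds / train_size)  # 200.0
print('Part 1.1 single-device steps:', 1000 * batches_per_epoch)     # 40000
```

In a real deployment the per-round steps would run in parallel across the five devices; in this single-Colab simulation they run one after another.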

<font color='red'>**Deliverables**</font>
1. In *Code Cell 1.6*, train a global model via federated learning.  Much of the code has been given to you, but you will need to fill in the parts using calls to the functions you wrote above.  

2. Graph the accuracy of the global model over 100 rounds.  Discuss the accuracy difference between the global model trained here and the individual local model you trained in PART 1.1.  

In [0]:
## Code Cell 1.6

# use these parameters
rounds = 100
local_epochs = 4
num_devices = 50
device_pct = 0.1
data_pct = 0.1
net = ConvNet().cuda()
criterion = nn.CrossEntropyLoss()

data_idxs = iid_sampler(trainset, num_devices, data_pct)

# Part 1.3: Implement device creation here
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

## IID Federated Learning
start_time = time.time()
for round_num in range(rounds):
  
    # Part 1.3: Implement getting devices for each round here
    round_devices = get_devices_for_round(devices, device_pct)

    print('Round: ', round_num)
    for device in round_devices:
        for local_epoch in range(local_epochs):
            train(local_epoch, device)

    # Part 1.3: Implement weight averaging here
    w_avg = average_weights(round_devices)

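    # Send the averaged weights to every device, participating or not
    for device in devices:
        device['net'].load_state_dict(w_avg)
        # Only the scheduler is stepped (once per round, so milestones are in rounds);
        # calling optimizer.step() with zeroed gradients would still let SGD's momentum
        # and weight decay move the freshly averaged weights
        device['scheduler'].step()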

    # test accuracy after aggregation
    test(round_num, devices[0])


total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

Round:  0
(Device 11/Epoch 1) Train Loss: 2.122 | Train Acc: 21.094

KeyboardInterrupt: ignored

---

### **2. Non-IID Federated Learning and Fairness**

---
**Overview**

In PART 1, you implemented a Federated Learning pipeline that operated over iid data.  While this iid assumption may hold in some applications, it does not hold in many other settings.  For example, a group of similar users may have data that is fundamentally different from that of another group of users.  As a result, the data as a whole that federated learning operates over will be non-iid in nature.

In PART 2 of the assignment, you will explore Federated Learning in this non-iid setting. You will create groups of devices such that the data is non-iid across groups and iid within each group. To do this, you will re-implement many of the functions from PART 1 for this grouped, non-iid setting.

For all experiments in this section, we fix the number of groups at three. Each group is assigned a different subset of classes in the dataset (Group 0 is assigned data from classes 0-3, Group 1 from classes 4-6, and Group 2 from classes 7-9). We also fix each group at 20 devices, although you will vary the per-group participation rates in each round in PART 2.4.
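Because CIFAR-10 is class-balanced (5,000 training and 1,000 test images per class), the fixed class-to-group assignment determines how much data each group can draw from. It also bounds what a single group can teach the model: a model that only ever predicts group 0's four classes can be correct on at most 40% of the full test set. The short computation below spells this out in plain Python; the names are illustrative.

```python
group_classes = {0: [0, 1, 2, 3], 1: [4, 5, 6], 2: [7, 8, 9]}
train_per_class, test_per_class = 5000, 1000  # CIFAR-10 is class-balanced

class_to_group = {c: g for g, classes in group_classes.items() for c in classes}
train_pool = {g: len(cs) * train_per_class for g, cs in group_classes.items()}
test_pool = {g: len(cs) * test_per_class for g, cs in group_classes.items()}

print(class_to_group[5])  # 1
print(train_pool)         # {0: 20000, 1: 15000, 2: 15000}
print(test_pool)          # {0: 4000, 1: 3000, 2: 3000}
# Best possible full-test-set accuracy for a model that only predicts group 0's classes
print(test_pool[0] / sum(test_pool.values()))  # 0.4
```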


---
**Non-iid Sampling**

<font color='red'>**PART 2.1:**</font> [10 points]

We will first start by implementing `noniid_group_sampler`, a new, non-iid group version of the `iid_sampler` you implemented in PART 1.  We will use this function to generate training subsets for multiple devices in PART 2.4.  

<font color='red'>**Deliverables**</font>
1. In *Code Cell 2.1*, implement the `noniid_group_sampler` function to generate **non-iid** samples from the CIFAR-10 training set.  As input, the function should take the dataset and number of training samples each device should be assigned.  As was the case in the `iid_sampler` you implemented previously, the function should return a dictionary where each key is the ID number of the particular device and each value is a `set` of the indexes of training samples in the dataset assigned to the device.  You may want to have the function return other data as well, depending on how you implement functions in later parts of the assignment.  We have provided the mapping stating which classes are mapped to each group.  Within a given group, you should sample the data in an iid fashion.
---

In [0]:
## Code Cell 2.1

# creates noniid TRAINING datasets for each group
def noniid_group_sampler(dataset, num_items_per_device):
    '''
    dataset: PyTorch Dataset (e.g., CIFAR-10 training set)
    num_items_per_device: how many samples to assign to each device

    return: a dictionary of the following format:
      {
        0: [3, 65, 2233, ..., 22] // device 0 sample indexes
        1: [0, 2, 4, ..., 583] // device 1 sample indexes
        ...
      }

    Devices 0-19 belong to group 0, 20-39 to group 1, and 40-59 to group 2.
    '''

    # how many devices per non-iid group
    devices_per_group = [20, 20, 20]

    # label assignment per group
    dict_group_classes = {}
    dict_group_classes[0] = [0,1,2,3]
    dict_group_classes[1] = [4,5,6]
    dict_group_classes[2] = [7,8,9]

    # Part 2.1: Implement!
    #label dict stores the indexes of the dataset examples that fall into the ith group
    label_dict = {}
    label_dict[0] = []
    label_dict[1] = []
    label_dict[2] = []
    for i in range(len(dataset)):
        label = dataset[i][1]
        #find which group the label belongs to and append the index of the example to the correct array
        if(label in dict_group_classes[0]):
            label_dict[0].append(i)
        elif(label in dict_group_classes[1]):
            label_dict[1].append(i)
        else:
            label_dict[2].append(i)
    
    num_devices = sum(devices_per_group)

    final_dict = {} #final dict is to be returned
    #device_group[i] returns the group of the ith device
    device_group = [0 for i in range(devices_per_group[0])]+[1 for i in range(devices_per_group[1])]+[2 for i in range(devices_per_group[2])]
    for i in range(num_devices):
        group = device_group[i] #determine the group of the ith device
        final_dict[i] = random.sample(label_dict[group], num_items_per_device) #randomly sample num_items_per_device examples without replacement
    return final_dict
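As with the iid case, it may be worth verifying the non-iid split before training. The hedged sketch below defines a hypothetical `check_noniid_split` that asserts each device's samples come only from its own group's classes and contain no duplicates, assuming devices are numbered group by group (0-19, 20-39, 40-59) as in `noniid_group_sampler`. It is demonstrated on synthetic labels; with the notebook loaded, a call such as `check_noniid_split(noniid_group_sampler(trainset, 500), trainset.targets)` should also pass, since torchvision's `CIFAR10` exposes its labels as `targets`.

```python
import random

GROUP_CLASSES = {0: [0, 1, 2, 3], 1: [4, 5, 6], 2: [7, 8, 9]}

def check_noniid_split(device_idxs, labels, devices_per_group=(20, 20, 20)):
    """Raise AssertionError unless every device holds distinct samples from its own group's classes only."""
    # device_group[d] is the group of device d, assuming devices are numbered group by group
    device_group = [g for g, n in enumerate(devices_per_group) for _ in range(n)]
    for dev, idxs in device_idxs.items():
        allowed = set(GROUP_CLASSES[device_group[dev]])
        seen = {labels[i] for i in idxs}
        assert seen <= allowed, f'device {dev} has foreign labels {seen - allowed}'
        assert len(set(idxs)) == len(idxs), f'device {dev} has duplicate indexes'
    return True

# Synthetic stand-in for CIFAR-10 labels: 100 examples of each class
labels = [c for c in range(10) for _ in range(100)]
pools = {g: [i for i, y in enumerate(labels) if y in cs] for g, cs in GROUP_CLASSES.items()}
rng = random.Random(0)
toy_split = {dev: rng.sample(pools[dev // 20], 50) for dev in range(60)}
print(check_noniid_split(toy_split, labels))  # True
```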

---
**Group-based Device Rejection**

<font color='red'>**PART 2.2:**</font> [5 points]

We will now implement `get_devices_for_round_GROUP`, a new group-based version of the `get_devices_for_round` you implemented in PART 1.  We will use this function in PART 2.4 to simulate the device rejection phase shown earlier on a per-group basis.  

<font color='red'>**Deliverables**</font>
1. In *Code Cell 2.2*, implement the `get_devices_for_round_GROUP` function to generate a list of devices that will participate in each round of federated learning.  The function should take as input 1) the list of all devices, and 2) how many devices from each group should participate in a given round.  It should return a list of devices that will participate in a given round.  You may want to add additional parameters to the function definition based on your implementation strategy.  

In [0]:
## Code Cell 2.2

# get which devices in each group should participate in a current round
# by explicitly saying number of each devices desired for each group 
def get_devices_for_round_GROUP(devices, device_nums, user_group_idxs=None):
    '''
    devices: list of all devices, ordered so that devices 0-19 are group 0,
             20-39 are group 1, and 40-59 are group 2
    device_nums: device_nums[g] is how many devices from group g participate
    user_group_idxs: optional list whose g-th entry lists the device indexes in group g
    '''
    #  Part 2.2: Implement!
    if user_group_idxs is None:
        user_group_idxs = [list(range(0, 20)), list(range(20, 40)), list(range(40, 60))]
    arr = []  # devices participating in this round
    for group in range(len(device_nums)):  # iterate over all the groups
        # randomly sample (without replacement) device_nums[group] device indexes from this group
        choices = random.sample(user_group_idxs[group], device_nums[group])
        for i in choices:
            arr.append(devices[i])  # append the chosen devices to the return list
    return arr
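The selection above hard-codes which index ranges belong to which group. A variant that keeps an explicit mapping from each group to its devices avoids depending on device order; the sketch below is one such hypothetical alternative (`select_per_group` and the `'group'` key are assumptions, not fields produced by `create_device`). The returned list is ordered by group, which makes it easy to check that each group contributed exactly the requested number of devices.

```python
import random

def select_per_group(devices_by_group, counts, rng=random):
    """Return counts[g] distinct devices from each group g, concatenated in group order."""
    chosen = []
    for g, k in enumerate(counts):
        chosen.extend(rng.sample(devices_by_group[g], k))  # without replacement within group g
    return chosen

# 60 hypothetical device records, numbered group by group
devices = [{'id': i, 'group': i // 20} for i in range(60)]
devices_by_group = {g: [d for d in devices if d['group'] == g] for g in range(3)}

round_devices = select_per_group(devices_by_group, [2, 5, 1], random.Random(0))
print([d['group'] for d in round_devices])  # [0, 0, 1, 1, 1, 1, 1, 2]
```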

---
**Group-based Testing**

<font color='red'>**PART 2.3:**</font> [5 points]

We will now implement the testing functions needed to evaluate the global model learned via Federated Learning on a per-group basis.  This will require two functions.

The `cifar_noniid_group_test` function divides the test dataset into three sub-datasets, one for each group.  

The `test_group` function gets per-group classification accuracy for the global model.  You will likely want to start with the `test` function in Code Cell 1.2 and then modify it to work on a per-group basis. 

<font color='red'>**Deliverables**</font>
1. In *Code Cell 2.3*, implement the `cifar_noniid_group_test` function to create a test dataset for each group.  It should take the full CIFAR-10 test dataset as input, and return a dictionary where each key is a group ID, and each value is a `set` of the indexes for all test samples for that group. 

2.  In *Code Cell 2.3*, implement the `test_group` function to output the per-group classification accuracy of the global model.
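An alternative to building one test loader per group is to run the full test set once and split the accuracy afterwards by the group of each sample's true class. The NumPy sketch below (hypothetical `per_group_accuracy`, hand-written toy predictions rather than real model output) shows that bookkeeping; inside `test_group` one would first collect `predicted` and `targets` over all batches, e.g. with `.cpu().numpy()`. It assumes every group has at least one test sample, which holds for CIFAR-10.

```python
import numpy as np

# CLASS_TO_GROUP[c] is the group of CIFAR-10 class c (classes 0-3, 4-6, 7-9)
CLASS_TO_GROUP = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])

def per_group_accuracy(predictions, targets, num_groups=3):
    """Accuracy in percent over the samples whose true class falls in each group."""
    predictions, targets = np.asarray(predictions), np.asarray(targets)
    groups = CLASS_TO_GROUP[targets]  # group of every sample's true class
    correct = predictions == targets  # boolean hit/miss per sample
    return {g: 100.0 * correct[groups == g].mean() for g in range(num_groups)}

targets = [0, 1, 4, 5, 7, 8, 9, 3]
predictions = [0, 2, 4, 4, 7, 8, 0, 3]
acc = per_group_accuracy(predictions, targets)
print({g: round(float(a), 1) for g, a in acc.items()})  # {0: 66.7, 1: 50.0, 2: 66.7}
```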



In [0]:
## Code Cell 2.3

# creates noniid TEST datasets for each group
def cifar_noniid_group_test(dataset):

    dict_group_classes = {}
    dict_group_classes[0] = [0,1,2,3]
    dict_group_classes[1] = [4,5,6]
    dict_group_classes[2] = [7,8,9]

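    # Part 2.3: Implement!
    # label_dict[g] stores the indexes of the test samples whose class belongs to group g
    label_dict = {0: [], 1: [], 2: []}
    # Read the labels directly (torchvision's CIFAR10 stores them in .targets) when possible,
    # falling back to indexing the dataset otherwise
    labels = dataset.targets if hasattr(dataset, 'targets') else [dataset[i][1] for i in range(len(dataset))]
    for i, label in enumerate(labels):
        # determine which group the label is in using dict_group_classes
        # and append the index of the data point to the correct list
        if label in dict_group_classes[0]:
            label_dict[0].append(i)
        elif label in dict_group_classes[1]:
            label_dict[1].append(i)
        else:
            label_dict[2].append(i)
    return label_dict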

# gets per-group accuracy of global model
def test_group(epoch, device, group_idxs_dict):
    # Part 2.3: Implement!
    # Hint: refer to test function in PART 1
    net.eval() #turn the net into evaluaton mode
    sys.stdout.write(' | Test accuracy: ')
    with torch.no_grad():
        # Iterate through the groups
        for group in range (0, len(group_idxs_dict.keys())): #iterate over all of the groups
            # Initialize the counters at the beginning of each group
            test_loss, correct, total = 0, 0, 0
            # Use DatasetSplit to generate new testloader for each group
            new_testset = DatasetSplit(testset, group_idxs_dict[group])
            testloader = torch.utils.data.DataLoader(new_testset, batch_size=128, shuffle=False,
                                            num_workers=2)
            # Code from PART 1
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs = device['net'](inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                device['test_loss_tracker'].append(loss.item())
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            # Compute and print loss and accuracy at the end of the group
            loss = test_loss / (batch_idx + 1)
            acc = 100.* correct / total
            sys.stdout.write(f'{acc:.3f} | ')
    sys.stdout.write('\n')
    sys.stdout.flush()  


---
**Federated Learning Results in Non-IID Setting**

<font color='red'>**PART 2.4:**</font> [10 points]

We will now run federated learning in the non-IID setting using the functions you wrote previously in this section. We will examine two different settings.

**Fair Device Participation:** Run federated learning on the CIFAR-10 dataset with the three groups.  Each group should have exactly one device participate in each round.  

**Unfair Device Participation:** Run federated learning on the CIFAR-10 dataset with the three groups.  Group 0 should have five devices participate in each round, and Groups 1 and 2 should each have one device participate in each round.
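To see how the two settings weight the groups, assume that `average_weights` (defined earlier and not shown in this section) takes an unweighted mean over the devices sampled in a round. Each group's share of the average is then its share of the sampled devices: 1/3 each in the fair setting, but 5/7 for Group 0 in the unfair one. The toy below makes the arithmetic explicit by assuming that every device of group g ends the round at the one-hot vector e_g; `fedavg` and `toy_round` are hypothetical helpers, not the assignment's functions.

```python
from fractions import Fraction

def fedavg(vectors):
    """Unweighted element-wise mean of per-device parameter vectors."""
    n = len(vectors)
    return [sum(col, Fraction(0)) / n for col in zip(*vectors)]

def toy_round(device_nums):
    """Average one round in which every device of group g reports the one-hot vector e_g."""
    one_hot = {0: [1, 0, 0], 1: [0, 1, 0], 2: [0, 0, 1]}  # hypothetical "updates"
    vectors = [one_hot[g] for g, n in enumerate(device_nums) for _ in range(n)]
    return fedavg(vectors)

print(toy_round([1, 1, 1]))  # fair:   [Fraction(1, 3), Fraction(1, 3), Fraction(1, 3)]
print(toy_round([5, 1, 1]))  # unfair: [Fraction(5, 7), Fraction(1, 7), Fraction(1, 7)]
```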

<font color='red'>**Deliverables**</font>
1. In *Code Cell 2.4*, train a global model via federated learning for the group-based non-iid setting.  Much of the code has been given to you, but you will need to fill in the parts using calls to the group-based, non-iid functions you wrote above.  You will likely be able to re-use parts of the code you wrote in Part 1.3.

2. Graph the per-group test accuracy over 100 rounds in the Fair Device Participation setting. Each group should have its own line in the graph.

3. Graph the per-group test accuracy over 100 rounds in the Unfair Device Participation setting. Each group should have its own line in the graph.

4.  Describe the differences you see between the two scenarios.  How can you explain what you are seeing?
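Code Cell 2.4 prints the per-group accuracies each round but does not store them. A sketch of turning a stored history into one line per group, assuming the three values from each round are appended to a list `history` (collected however you like, for example by having `test_group` return its per-group accuracies). The names `to_group_series` and `plot_group_accuracy` are hypothetical; matplotlib is imported inside the plotting function so the data-shaping helper runs without it. The example history uses the values printed for Rounds 0 to 2 in the output below.

```python
def to_group_series(history):
    """history[r] = [acc_g0, acc_g1, acc_g2] for round r -> one list per group."""
    return [list(series) for series in zip(*history)]

def plot_group_accuracy(history, title):
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting
    plt.figure()  # new figure so the two settings are not drawn on top of each other
    for g, series in enumerate(to_group_series(history)):
        plt.plot(range(len(series)), series, label=f"Group {g}")
    plt.xlabel("Round")
    plt.ylabel("Test accuracy (%)")
    plt.title(title)
    plt.legend()
    plt.show()

history = [[59.4, 0.0, 0.0], [40.675, 0.0, 0.0], [68.3, 0.0, 0.0]]
print(to_group_series(history))
```

With 100 rounds recorded for each setting, one call per setting, such as `plot_group_accuracy(fair_history, "Fair participation")`, would produce the two requested graphs (`fair_history` is a hypothetical name).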

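## Code Cell 2.4

rounds = 100  # number of federated rounds
local_epochs = 4  # local training epochs per sampled device per round
num_items_per_device = 5000  # number of training samples per device
device_nums = [5,1,1]  # devices per group per round: [5,1,1] = Unfair, [1,1,1] = Fair
net = ConvNet().cuda()  # initial global model
criterion = nn.CrossEntropyLoss()  # classification loss

# Part 2.4: Implement non-iid sampling
# give each device a group-based, non-IID shard of the training set
data_idxs = noniid_group_sampler(trainset, num_items_per_device)

# Part 2.4: Implement device creation here
num_devices = 60  # number of devices to create
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

# the per-group test split never changes, so build it once instead of every round
group_idxs_dict = cifar_noniid_group_test(testset)

## Non-IID Federated Learning
start_time = time.time()
for round_num in range(rounds):

    # Part 2.4: Implement getting devices for each round here
    # sample device_nums[g] devices from each group g
    round_devices = get_devices_for_round_GROUP(devices, device_nums)

    print('Round: ', round_num)
    for device in round_devices:  # local training on each sampled device
        for local_epoch in range(local_epochs):
            train(local_epoch, device)

    # Part 2.4: Implement weight averaging here
    w_avg = average_weights(round_devices)  # aggregate the sampled devices' weights

    for device in devices:  # broadcast the new global weights to every device
        device['net'].load_state_dict(w_avg)
        device['optimizer'].zero_grad()
        device['optimizer'].step()  # precedes scheduler.step(), the order PyTorch expects
        device['scheduler'].step()  # advance each device's learning-rate schedule

    # Part 2.4: Implement test accuracy here
    # every device now holds the global weights, so evaluating devices[0] suffices
    test_group(round_num, devices[0], group_idxs_dict)

total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))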

Round:  0
(Device 42/Epoch 3) Train Loss: 0.544 | Train Acc: 79.520 | Test accuracy: 59.400 | 0.000 | 0.000 | 
Round:  1
(Device 50/Epoch 3) Train Loss: 0.586 | Train Acc: 75.380 | Test accuracy: 40.675 | 0.000 | 0.000 | 
Round:  2
(Device 53/Epoch 3) Train Loss: 0.406 | Train Acc: 83.940 | Test accuracy: 68.300 | 0.000 | 0.000 | 

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f6570616630>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process

Round:  3
(Device 13/Epoch 0) Train Loss: 0.777 | Train Acc: 72.266

(Device 13/Epoch 0) Train Loss: 0.754 | Train Acc: 73.438

(Device 13/Epoch 1) Train Loss: 0.516 | Train Acc: 82.031

(Device 13/Epoch 1) Train Loss: 0.598 | Train Acc: 76.953

(Device 13/Epoch 2) Train Loss: 0.478 | Train Acc: 82.812

(Device 13/Epoch 2) Train Loss: 0.492 | Train Acc: 81.250

(Device 13/Epoch 2) Train Loss: 0.592 | Train Acc: 76.719

(Device 13/Epoch 2) Train Loss: 0.590 | Train Acc: 76.580

(Device 8/Epoch 0) Train Loss: 0.812 | Train Acc: 68.750

(Device 8/Epoch 1) Train Loss: 0.604 | Train Acc: 74.862

(Device 8/Epoch 1) Train Loss: 0.631 | Train Acc: 74.438

(Device 8/Epoch 1) Train Loss: 0.639 | Train Acc: 74.120

(Device 8/Epoch 3) Train Loss: 0.654 | Train Acc: 71.484

(Device 8/Epoch 3) Train Loss: 0.669 | Train Acc: 71.094

(Device 8/Epoch 3) Train Loss: 0.592 | Train Acc: 76.300

(Device 16/Epoch 0) Train Loss: 0.648 | Train Acc: 74.540

(Device 16/Epoch 2) Train Loss: 0.518 | Train Acc: 79.688

(Device 16/Epoch 2) Train Loss: 0.603 | Train Acc: 76.220

(Device 16/Epoch 3) Train Loss: 0.592 | Train Acc: 76.300

(Device 19/Epoch 1) Train Loss: 0.657 | Train Acc: 75.781

(Device 19/Epoch 1) Train Loss: 0.634 | Train Acc: 78.125

(Device 19/Epoch 1) Train Loss: 0.598 | Train Acc: 76.000

(Device 19/Epoch 2) Train Loss: 0.612 | Train Acc: 75.860

(Device 5/Epoch 0) Train Loss: 0.744 | Train Acc: 69.922

(Device 5/Epoch 0) Train Loss: 0.733 | Train Acc: 70.833

(Device 5/Epoch 0) Train Loss: 0.670 | Train Acc: 73.740

(Device 5/Epoch 1) Train Loss: 0.682 | Train Acc: 72.240

(Device 5/Epoch 3) Train Loss: 0.684 | Train Acc: 73.438

(Device 5/Epoch 3) Train Loss: 0.745 | Train Acc: 71.615

(Device 5/Epoch 3) Train Loss: 0.644 | Train Acc: 75.360

(Device 38/Epoch 0) Train Loss: 1.234 | Train Acc: 49.120

(Device 38/Epoch 1) Train Loss: 0.767 | Train Acc: 67.220

(Device 38/Epoch 2) Train Loss: 0.701 | Train Acc: 70.920

(Device 38/Epoch 3) Train Loss: 0.644 | Train Acc: 73.420

(Device 49/Epoch 0) Train Loss: 0.747 | Train Acc: 75.060

(Device 49/Epoch 1) Train Loss: 0.390 | Train Acc: 84.600

(Device 49/Epoch 2) Train Loss: 0.335 | Train Acc: 87.460

(Device 49/Epoch 3) Train Loss: 0.331 | Train Acc: 87.240 | Test accuracy: 71.700 | 0.000 | 0.000 | 
Round:  4
(Device 43/Epoch 3) Train Loss: 0.380 | Train Acc: 85.780 | Test accuracy: 74.550 | 0.000 | 0.000 | 
Round:  5
(Device 44/Epoch 3) Train Loss: 0.336 | Train Acc: 87.120 | Test accuracy: 73.050 | 0.000 | 0.000 | 
Round:  6
(Device 57/Epoch 3) Train Loss: 0.254 | Train Acc: 90.440 | Test accuracy: 74.625 | 0.000 | 0.000 | 
Round:  7
(Device 44/Epoch 3) Train Loss: 0.238 | Train Acc: 90.820 | Test accuracy: 77.000 | 0.000 | 0.000 | 
Round:  8
(Device 53/Epoch 3) Train Loss: 0.239 | Train Acc: 92.300 | Test accuracy: 79.100 | 0.000 | 0.000 | 
Round:  9
(Device 42/Epoch 3) Train Loss: 0.232 | Train Acc: 92.340 | Test accuracy: 79.725 | 0.100 | 0.000 | 
Round:  10
(Device 48/Epoch 3) Train Loss: 0.144 | Train Acc: 94.520 | Test accuracy: 83.775 | 0.000 | 0.000 | 
Round:  11
(Device 54/Epoch 3) Train Loss: 0.154 | Train Acc: 94.560 | Test accuracy: 82.700 | 0.167 | 0.033 | 
Round:  1

----
### **3. Quantization of Local Models for Reduced Communication Cost**
-----
Quantization refers to the process of reducing the number of bits used to represent a number. In deep learning, the predominant numerical format in research and deployment has been full precision (32-bit floating point, [IEEE 754 Format](https://en.wikipedia.org/wiki/Single-precision_floating-point_format)). However, the desire for smaller models and cheaper computation has motivated research on representing the numbers in deep learning models with fewer bits. This affects several parts of the pipeline, including computation, communication, and storage requirements. For example, in federated learning, quantizing a client model from full precision to 8-bit precision reduces its size, and therefore its storage requirements, by ~4×. Because the model is smaller, the communication needed to upload a client model also drops by ~4×.
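As a rough check of the ~4× figure, the sketch below computes upload sizes under a hypothetical cost model; the function `upload_bytes` and the 1,000,000-parameter, 10-layer model are illustrative assumptions, not part of the assignment. It assumes quantized values are actually packed as n-bit integer codes and that one 32-bit scale value is sent per quantized layer (the adaptive scale introduced below), which is why the reduction comes out slightly below 4×. Note that the notebook's `quantize_model` writes quantized values back into float32 tensors, so it simulates the accuracy effect of quantization rather than the bandwidth savings.

```python
def upload_bytes(num_params, nbit, num_quantized_layers=0, scale_bits=32):
    """Bytes to upload a model whose parameters are packed as nbit-bit codes.

    Hypothetical cost model: each parameter costs nbit bits, plus one
    scale value of scale_bits bits for every quantized layer.
    """
    total_bits = num_params * nbit + num_quantized_layers * scale_bits
    return total_bits / 8


# hypothetical model: 1,000,000 parameters in 10 quantized layers
full_precision = upload_bytes(1_000_000, 32)                     # 32-bit floats
eight_bit = upload_bytes(1_000_000, 8, num_quantized_layers=10)  # 8-bit codes + scales
print(full_precision, eight_bit, round(full_precision / eight_bit, 4))
# 4000000.0 1000040.0 3.9998
```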

However, quantization comes with trade-offs. To see this, consider a full-precision (32-bit floating-point) representation. It has a large dynamic range (from $-3.4\times 10^{38}$ to $+3.4\times10^{38}$) and high precision (about $7$ decimal digits), so for practical purposes a full-precision number can be treated as continuous. In contrast, an n-bit fixed-point representation has only $2^n$ discrete values, and every number it represents must be one of them. n-bit quantization generally refers to projecting a full-precision weight onto the nearest of these $2^n$ discrete values.

--------
<font color='red'>**PART 3.1:**</font> [5 points]

In this part, we will write a function to project full-precision numbers into n-bit fixed-point numbers. For example, suppose we want to project full-precision numbers in the range $[0, 1]$ into an 8-bit fixed-point representation, $\frac{1}{2^8-1}\times(0, 1, 2, 3, \dots, 253, 254, 255)$, where $\frac{1}{2^8-1}$ is the **scale factor** of the 8-bit fixed-point representation.
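The nearest-neighbor projection does not require searching the grid: when the levels are spaced $\frac{1}{2^n-1}$ apart, the nearest level is obtained by scaling by $2^n-1$, rounding to the nearest integer, and scaling back. The standalone plain-Python sketch below (`nearest_level` and `rounded_level` are illustrative names) checks this equivalence for $n = 4$ on the same 11 test points used in *Code Cell 3.1*.

```python
def nearest_level(x, nbit):
    """Brute force: the closest of the 2**nbit levels k / (2**nbit - 1)."""
    levels = 2 ** nbit - 1
    grid = [k / levels for k in range(levels + 1)]
    return min(grid, key=lambda g: abs(g - x))


def rounded_level(x, nbit):
    """Closed form: scale by 2**nbit - 1, round, scale back."""
    levels = 2 ** nbit - 1
    return round(x * levels) / levels


xs = [i / 11 for i in range(11)]
# both approaches pick the same level for every test point (no exact ties here)
assert all(nearest_level(x, 4) == rounded_level(x, 4) for x in xs)
print([round(rounded_level(x, 4), 4) for x in xs])
# [0.0, 0.0667, 0.2, 0.2667, 0.3333, 0.4667, 0.5333, 0.6667, 0.7333, 0.8, 0.9333]
```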

<font color='red'>**Deliverables**</font>
1. In *Code Cell 3.1*, implement a function that projects full-precision numbers in the range $[0, 1]$ into n-bit fixed-point numbers. If your implementation is correct, the test code should print *'Output of Quantization Matches!'*.

In [0]:
## Code Cell 3.1

def quantizer(input, nbit):
    '''
    input: full precision tensor in the range [0, 1]
    return: quantized tensor
    '''
    scale_factor = 1 / (2**nbit - 1)

    # scale input by inverse of scale_factor and round to nearest integer
    output = input / scale_factor
    output = torch.round(output)

    # scale rounded output back and return
    output *= scale_factor
    return output

# Test Code
test_data = torch.tensor([i/11 for i in range(11)])

# ground truth results of 4-bit quantization
ground_truth = torch.tensor([0.0000, 0.0667, 0.2000, 0.2667, 0.3333, 0.4667,
                             0.5333, 0.6667, 0.7333, 0.8000, 0.9333])

# output of your quantization function
quantizer_output = quantizer(test_data, 4)

if torch.allclose(quantizer_output, ground_truth, atol=1e-04):
    print('Output of Quantization Matches!')
else:
    print('Output of Quantization DOES NOT Match!')

Output of Quantization Matches!


**Quantize Weights of Neural Networks**

The quantizer in PART 3.1 quantizes any full-precision number in the range $[0, 1]$ into an n-bit fixed-point number. However, neural network weights $w$ are not necessarily in the range $[0, 1]$.

To use the quantizer from PART 3.1, we first apply a scaling function that transforms the weights into the range $[0, 1]$:
$$\tilde{w} = \frac{w}{2 \max(|w|)} + \frac{1}{2}$$ 
where $2 \max(|w|)$ is the **adaptive scale**. (In *Code Cell 3.2*, the variable `adaptive_scale` stores $\max(|w|)$, and the factor of 2 is applied explicitly.)

Then, we quantize the transformed weights:
$$\hat{w} = \text{quantizer}_{\text{n-bit}}(\tilde{w})$$
After quantization, a reverse scaling function is applied to $\hat{w}$ to restore the weights' original scale:
$$w_q = 2\max(|w|)\times(\hat{w}-\frac{1}{2})$$

Combining these three equations, the expression we will use to get the quantized weights $w_q$ is as follows:
$$w_q = 2\max(|w|)\times[\text{quantizer}_{\text{n-bit}}(\frac{w}{2\max(|w|)} + \frac{1}{2}) - \frac{1}{2}]$$

This equation is the **deterministic quantization function**. 
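A small worked example may help. The plain-Python sketch below applies the deterministic quantization function to a hypothetical four-element weight vector with $n = 2$ (four levels). Here $\max(|w|) = 0.8$, so $\tilde{w} = (0, 0.4375, 0.625, 1)$, which rounds to $(0, \frac{1}{3}, \frac{2}{3}, 1)$ and maps back to $w_q \approx (-0.8, -0.2667, 0.2667, 0.8)$. The function names are illustrative, not part of the assignment code.

```python
def quantize_unit(values, nbit):
    """Project values in [0, 1] onto the levels k / (2**nbit - 1)."""
    levels = 2 ** nbit - 1
    # Python's round() breaks ties to the nearest even integer, as torch.round does
    return [round(v * levels) / levels for v in values]


def deterministic_quantize(w, nbit):
    """Deterministic quantization (no noise); assumes max(|w|) > 0."""
    scale = max(abs(v) for v in w)                  # max(|w|)
    w_tilde = [v / (2 * scale) + 0.5 for v in w]    # map into [0, 1]
    w_hat = quantize_unit(w_tilde, nbit)            # n-bit projection
    return [2 * scale * (v - 0.5) for v in w_hat]   # restore original scale


w = [-0.8, -0.1, 0.2, 0.8]   # hypothetical weights
print([round(v, 4) for v in deterministic_quantize(w, 2)])
# [-0.8, -0.2667, 0.2667, 0.8]
```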

Following the method proposed by [DoReFa-Net](https://arxiv.org/abs/1606.06160), we enable *stochastic quantization* by adding extra noise $N(n) = \frac{\sigma}{2^n-1}$ to the transformed weights $\tilde{w}$, where $\sigma \sim \text{Uniform}(-0.5, 0.5)$ and $n$ is the number of bits. Because $\frac{1}{2^n-1}$ is the spacing between adjacent quantization levels, this noise shifts $\tilde{w}$ by at most half a level before rounding, which randomizes which of the two neighboring levels is chosen. Such noise can encourage the model to explore more of the loss surface, may help it escape local minima, and can improve generalization.

The final **stochastic quantization function** we will use to quantize layers of local models is:

$$w_q = 2\max(|w|)\times[\text{quantizer}_{\text{n-bit}}[\frac{w}{2 \times\max(|w|)} + \frac{1}{2} + N(n)] - \frac{1}{2}]$$
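To see what the noise does at a single point, write $\tilde{w}(2^n-1) = k + f$ with integer $k$ and fraction $f \in [0, 1)$. Adding $\sigma \sim \text{Uniform}(-0.5, 0.5)$ before rounding yields $k+1$ exactly when $\sigma \ge 0.5 - f$, which happens with probability $f$; otherwise it yields $k$. The quantized value therefore equals $\tilde{w}$ on average (unbiased stochastic rounding), whereas deterministic rounding always picks the same level. The plain-Python sketch below checks this empirically for $n = 2$; the function name, seed, and input value 0.3 are illustrative choices.

```python
import random


def stochastic_quantize_unit(x, nbit, rng):
    """Quantize one value in [0, 1] after adding N(n) = sigma / (2**nbit - 1)."""
    levels = 2 ** nbit - 1
    sigma = rng.uniform(-0.5, 0.5)   # sigma ~ Uniform(-0.5, 0.5)
    noisy = x + sigma / levels       # add the noise term N(n)
    return round(noisy * levels) / levels


rng = random.Random(0)
x, nbit = 0.3, 2                     # 0.3 lies between the levels 0 and 1/3
samples = [stochastic_quantize_unit(x, nbit, rng) for _ in range(100_000)]

print(round(x * 3) / 3)              # deterministic rounding: always 1/3
print(sum(samples) / len(samples))   # stochastic rounding: averages close to 0.3
```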


<font color='red'>**PART 3.2:**</font> [10 points]

<font color='red'>**Deliverables**</font>
1. In *Code Cell 3.2*, implement `dorefa_g(w, nbit, adaptive_scale=None)` using the **stochastic quantization function** shown above. Again, if your implementation is correct, the test code should print *'Output of Quantization Matches!'*.


In [0]:
## Code Cell 3.2

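import torch
import torch.nn as nn

def quantize_model(model, nbit):
    '''
    Used in Code Cell 3.3 to quantize the ConvNet model in place
    '''
    for m in model.modules():
        # only convolutional and fully connected layers carry weights to quantize
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            # quantize the raw weight data (outside autograd) and keep the layer's adaptive scale
            m.weight.data, m.adaptive_scale = dorefa_g(m.weight.data, nbit)
            if m.bias is not None:
                # reuse the weight's adaptive scale for the bias
                m.bias.data, _ = dorefa_g(m.bias.data, nbit, m.adaptive_scale)

def dorefa_g(w, nbit, adaptive_scale=None):
    '''
    w: a floating-point weight tensor to quantize
    nbit: the number of bits in the quantized representation
    adaptive_scale: the maximum scale value. If None, it is set to the
                    maximum absolute value in w.
    '''
    if adaptive_scale is None:
        adaptive_scale = torch.max(torch.abs(w))

    # sigma ~ Uniform(-0.5, 0.5), one sample per element of w
    sigma = torch.rand(w.shape) - 0.5
    # N(n) = sigma / (2^n - 1): at most half a quantization step
    noise = sigma / (2**nbit - 1)
    # match w's dtype and device (e.g., move the CPU noise onto the GPU)
    noise = noise.to(dtype=w.dtype, device=w.device)
    # map w into [0, 1] and add the noise
    inp = w / (2*adaptive_scale) + 0.5 + noise
    # quantize, then undo the [0, 1] mapping to restore the original scale
    w_q = 2*adaptive_scale * (quantizer(inp, nbit) - 0.5)

    return w_q, adaptive_scale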


# Test Code
test_data = torch.tensor([i/11 for i in range(11)])

# ground truth results of 4-bit quantization
ground_truth = torch.tensor([-0.0606, 0.0606, 0.1818, 0.3030, 0.3030, 0.4242,
                             0.5455, 0.5455, 0.7879, 0.7879, 0.9091])

# output of your quantization function
torch.manual_seed(43)
quantizer_output, adaptive_scale = dorefa_g(test_data, 4)

if torch.allclose(quantizer_output, ground_truth, atol=1e-04):
    print('Output of Quantization Matches!')
else:
    print('Output of Quantization DOES NOT Match!')

Output of Quantization Matches!


**Reduce the Communication Overhead with Quantization**

We will now explore how quantization affects the performance of federated learning, using the IID setting from PART 1. You will run the same federated learning code, but will quantize each local model with the `quantize_model` function above before uploading it to the central server (in *Code Cell 3.3*, right after each device's local training and before averaging).

<font color='red'>**PART 3.3:**</font> [10 points]

<font color='red'>**Deliverables**</font>
1. In *Code Cell 3.3*, run federated learning with the following two quantization settings: `nbit = 16, 4`. Graph the accuracy of the global model over 100 rounds for each bit-width: 32 (the full-precision baseline you ran previously), 16, and 4.
2. Discuss the accuracy difference between the global models under the three different settings of bit-width (32, 16, and 4). 

In [0]:
## Code Cell 3.3

# Part 3.3: Train two settings, nbit=16 and nbit=4 (run this cell once
#           per setting), and compare against the floating-point
#           performance of the final FL model trained in Part 1.3.
nbit = 16

rounds = 100
local_epochs = 4
num_devices = 50
device_pct = 0.1
data_pct = 0.1
net = ConvNet().cuda()
criterion = nn.CrossEntropyLoss()

data_idxs = iid_sampler(trainset, num_devices, data_pct)
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

## IID Federated Learning
start_time = time.time()
for round_num in range(rounds):
    # Part 3.3: Implement!
    # Hint: you can use your federated learning code from PART 1

    # following PART 1
    round_devices = get_devices_for_round(devices, device_pct)

    print('Round: ', round_num)

    # iterate over all devices
    for device in round_devices:
        for local_epoch in range(local_epochs):
            train(local_epoch, device)
        # after training, quantize the learned model
        quantize_model(device['net'], nbit)

    # take average over devices with already quantized models
    w_avg = average_weights(round_devices)

    for device in devices:
        device['net'].load_state_dict(w_avg)
        device['optimizer'].zero_grad()
        device['optimizer'].step()
        device['scheduler'].step()

    # test accuracy after aggregation
    test(round_num, devices[0])


total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

Round:  0
(Device 43/Epoch 1) Train Loss: 2.232 | Train Acc: 21.429

KeyboardInterrupt: ignored

---

### **4. Extreme Anomaly Detection**

In this part, you will explore the impact of malicious clients on federated learning (FL). These malicious clients send fake weight updates to the central server to severely degrade the classification accuracy of the global model and make convergence difficult during training. To mitigate their impact, you will implement a secure version of the weight averaging scheme from PART 1, called `secure_average_weights`.

In PART 4.1, you will simulate a group of malicious clients that participate in FL alongside regular (non-malicious) clients. You will implement the fake weight updates (`generate_fake_weights`) that these malicious clients use to destabilize FL, and then observe their impact on the classification accuracy of the global model.

In PART 4.2, you will implement a secure weight update scheme (`secure_average_weights`) at the central server to detect the fake weights sent by malicious clients. Clients identified as malicious are excluded from the averaging step, so their fake weights do not degrade the performance of the global model.


**Simulating Malicious Devices**

In this part, you will write code to simulate malicious clients. We have slightly modified the FL training loop from PART 1 to support two types of devices: normal (i.e., non-malicious) devices and malicious devices. We provide a function, `get_malicious_devices`, that selects which devices are malicious. When a malicious device is selected for a round of training, it sends fake weights drawn from a Gaussian distribution instead of performing local training.

---
<font color='red'>**PART 4.1:**</font> [10 points]

Implement the `generate_fake_weights` function to sample fake weight values from a Gaussian distribution with a mean of zero and a standard deviation of 0.5. When a malicious device is selected for a round of training, it will use this function to send fake weights to the central server.
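Why one fake client out of five is so damaging: with plain averaging, the global weights are $\frac{1}{5}(4b + f)$ when four benign clients report $b$ and one malicious client reports $f$, so every parameter is shifted by $\frac{1}{5}(f - b)$. With $f \sim N(0, 0.5^2)$, this is noise with a standard deviation of roughly $0.1$ per parameter, injected in every round. The plain-Python sketch below simulates a single round; the benign weight scale of 0.05 and the assumption that all benign clients report identical weights are illustrative simplifications, not values from the assignment.

```python
import random
import statistics

rng = random.Random(0)
num_params = 10_000

# hypothetical benign update: all four benign clients report these weights
benign = [rng.gauss(0.0, 0.05) for _ in range(num_params)]
# malicious update: i.i.d. N(0, 0.5^2) fake weights, as described above
fake = [rng.gauss(0.0, 0.5) for _ in range(num_params)]

# plain averaging over 4 benign clients and 1 malicious client
global_w = [(4 * b + f) / 5 for b, f in zip(benign, fake)]

# per-parameter perturbation injected by the single malicious client
shift = [g - b for g, b in zip(global_w, benign)]
print(statistics.pstdev(shift))   # about 0.1, i.e. roughly 0.5 / 5
```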

<font color='red'>**Deliverables**</font>

1. In *Code Cell 4.1*, implement the `generate_fake_weights` to generate fake weights for each layer in `ConvNet` which are used by malicious clients. 
2. Train FL using the same federated learning training settings as in PART 1.3, but with 1 out of 5 client devices used in each round being malicious.
3. Graph the accuracy of the global model over 50 rounds (instead of 100 as in prior parts). Discuss the accuracy difference between this model, trained in the presence of malicious clients, and the model from PART 1.3.

---

In [0]:
## Code Cell 4.1
import torch.distributions as tdist

def get_malicious_devices(devices):
    # Creates a group of malicious devices
    num_malicious_devices = 1
    device_idxs = np.random.permutation(len(devices))[:num_malicious_devices]
    return [devices[i] for i in device_idxs]

def generate_fake_weights(device):
    # build a fake state dict with the same keys, shapes, dtypes, and devices as the real model
    w_mal = {}
    for name, tensor in device['net'].state_dict().items():
        if tensor.is_floating_point():
            # weights, biases, and float buffers: sample i.i.d. from N(mean=0, sd=0.5)
            w_mal[name] = torch.empty_like(tensor).normal_(mean=0.0, std=0.5)
        else:
            # integer buffers (e.g., BatchNorm's num_batches_tracked) are kept unchanged
            w_mal[name] = tensor.clone()
    device['net'].load_state_dict(w_mal)  # overwrite the device's model with the fake weights
    return device  # return the device carrying the malicious weights


rounds = 50
local_epochs = 4
num_devices = 50
device_pct = 0.1
data_pct = 0.1  # same data_pct as in Code Cell 3.3 (used by iid_sampler below)
net = ConvNet().cuda()
criterion = nn.CrossEntropyLoss()

data_idxs = iid_sampler(trainset, num_devices, data_pct)
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

# Part 4.1: Implement!
# Hint: base this off of your federated learning code in PART 1
# modified for the malicious case here.
# We give you part of the code below, but you will likely need to plug
# in code throughout below.

start_time = time.time()
for round_num in range(rounds):
    
    # 1/5 of the devices in each round are malicious
    round_devices = get_devices_for_round(devices, device_pct)
    malicious_devices = get_malicious_devices(round_devices)

    print('Round: ', round_num)    
    for device in round_devices:
        if device not in malicious_devices:
            for local_epoch in range(local_epochs):
                train(local_epoch, device)
        else:   # the device is malicious
            # instead of training, get fake weights
            device = generate_fake_weights(device) 
    
    w_avg = average_weights(round_devices)

    for device in devices:
        device['net'].load_state_dict(w_avg)
        device['optimizer'].zero_grad()
        device['optimizer'].step()
        device['scheduler'].step()

    # test accuracy after aggregation
    test(round_num, devices[0])

total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

Round:  0
(Device 23/Epoch 0) Train Loss: 2.107 | Train Acc: 21.245

KeyboardInterrupt: ignored

**Securing the Server against Malicious Clients**

To prevent the malicious clients from degrading global model performance as observed in PART 4.1, you will implement a secure weight update scheme (`secure_average_weights`) that the central server uses to detect the fake weights sent by malicious clients. You will then ensure that these fake weights are excluded from the averaging step of the global model. This approach should significantly improve the performance of the global model in the presence of malicious clients.


---
<font color='red'>**PART 4.2:**</font> [10 points]

In *Code Cell 4.2*, implement `secure_average_weights`, a modified version of the `average_weights` function you implemented earlier in *Code Cell 1.5*. `secure_average_weights` adds a simple anomaly detection step based on the $l_{2}$ distance between the weights reported by each client and the average weights over all clients. In each training round, the central server receives updated weights from $5$ clients. Let $w_{i}^{l}$ denote the layer $l$ weight tensor for client $i$. The detailed steps to implement `secure_average_weights` are as follows:


1. Compute the normalized $l_{2}$ distance between the weights of each device ($w_i$) and the average weights across all devices ($w_{avg}$). This distance is computed across all layers $$a_{i} = \sqrt{\frac{\sum_{l} ||w_{i}^{l} - w_{avg}^{l}||^{2}}{N}}$$ for each device $i$, where $N$ is the total number of parameters in the CNN, and $w_{avg}$ is defined as the output of the `average_weights` function you implemented in PART 1.2. Note that the resulting distance for a single device ($a_{i}$) is a scalar value.  
2. Compute the average of the device weight distances $$a_{avg} = \frac{\sum_{i=1}^5 a_{i}}{5}$$
3. If $a_{i} > a_{avg} + \epsilon$, where $\epsilon = 0.3$, mark device $i$ as malicious; otherwise, mark it as non-malicious.
We observe through experimentation that devices with an $l_{2}$ distance greater than $a_{avg} + \epsilon$ $(\epsilon=0.3)$ are likely to be malicious.
4. Recompute the average weights by including only the non-malicious devices.
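The four steps can be traced on toy data. The plain-Python sketch below works on flattened weight vectors (one list of floats per client) rather than PyTorch state dicts; `flag_malicious`, `secure_average`, and the toy numbers are illustrative assumptions, not the reference solution.

```python
import math


def flag_malicious(client_weights, epsilon=0.3):
    """Steps 1-3 on flattened weight vectors (lists of floats), one per client."""
    num_clients = len(client_weights)
    num_params = len(client_weights[0])  # N
    # w_avg: plain average over all clients
    w_avg = [sum(ws) / num_clients for ws in zip(*client_weights)]
    # Step 1: a_i = sqrt(sum of squared differences / N)
    a = [math.sqrt(sum((w - m) ** 2 for w, m in zip(ws, w_avg)) / num_params)
         for ws in client_weights]
    # Step 2: average distance over all clients
    a_avg = sum(a) / num_clients
    # Step 3: flag clients farther than a_avg + epsilon
    return [a_i > a_avg + epsilon for a_i in a], a


def secure_average(client_weights, epsilon=0.3):
    """Step 4: re-average using only the clients that were not flagged."""
    flags, _ = flag_malicious(client_weights, epsilon)
    kept = [ws for ws, bad in zip(client_weights, flags) if not bad]
    return [sum(ws) / len(kept) for ws in zip(*kept)]


# toy round: four benign clients and one client sending large fake weights
clients = [[0.1, 0.1, 0.1, 0.1]] * 4 + [[2.0, -2.0, 2.0, -2.0]]
flags, a = flag_malicious(clients)
print(flags)                    # [False, False, False, False, True]
print(secure_average(clients))  # approximately [0.1, 0.1, 0.1, 0.1]
```

Because the malicious client's weights are included in $w_{avg}$, it also pulls the benign clients' distances upward. In the idealized case where the four benign clients report identical weights $b$ and one client reports $f$, each benign client has $a_i = 0.2d$ and the malicious one has $a_i = 0.8d$, where $d$ is the root-mean-square of $f - b$. Then $a_{avg} = 0.32d$, and the malicious client is flagged only when $0.48d > \epsilon$, i.e., $d > \epsilon / 0.48 = 0.625$ for $\epsilon = 0.3$.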

<font color='red'>**Deliverables**</font>

1. In *Code Cell 4.2*, implement the `secure_average_weights` function using the detailed steps described above. 
2. Train FL using the same settings as in Part 4.1, but with `secure_average_weights` instead of `average_weights`.
3. Graph the accuracy of the global model over 50 rounds. Compare against the model performance in PART 4.1.

---

In [0]:
## Code Cell 4.2
import math
import torch

def secure_average_weights(devices, epsilon=0.3):
    """
    Returns the average of the weights of the devices that are not
    flagged as malicious by the l2-distance anomaly test.
    """
    # Part 4.2: Implement!
    num_devices = len(devices)

    # w_avg: average over all devices, including any malicious ones
    avg_weight = average_weights(devices)

    # Step 1: normalized l2 distance a_i between each device and w_avg
    a = []
    for device in devices:
        device_state_dict = device['net'].state_dict()
        sq_dist = 0.0  # sum over layers of ||w_i^l - w_avg^l||^2
        N = 0          # total number of floating-point values compared
        for k in avg_weight.keys():
            # skip integer buffers such as BatchNorm's num_batches_tracked
            if not device_state_dict[k].is_floating_point():
                continue
            diff = device_state_dict[k].float() - avg_weight[k].float()
            sq_dist += torch.sum(diff ** 2).item()  # squared l2 norm of this layer
            N += diff.numel()
        a.append(math.sqrt(sq_dist / N))

    # Step 2: average of the device weight distances
    a_avg = sum(a) / num_devices

    # Steps 3-4: keep devices within a_avg + epsilon (epsilon = 0.3 as given in
    # the assignment); at least one device always passes since min(a) <= a_avg
    new_devices = [d for d, a_i in zip(devices, a) if a_i <= a_avg + epsilon]
    if len(new_devices) != num_devices - 1:
        # the simulation uses exactly one malicious device per round
        print('Warning: {} of {} devices flagged as malicious'.format(
            num_devices - len(new_devices), num_devices))

    # return the averaged weights of the accepted (non-malicious) devices
    return average_weights(new_devices)

rounds = 50
local_epochs = 4
num_devices = 50
device_pct = 0.1
data_pct = 0.1  # same data_pct as in Code Cell 3.3 (used by iid_sampler below)
net = ConvNet().cuda()
criterion = nn.CrossEntropyLoss()

# Part 4.2: Implement!
# You can reuse the training code you wrote in PART 4.1 but with replacing
# average_weights with secure_average_weights

data_idxs = iid_sampler(trainset, num_devices, data_pct)
devices = [create_device(net, i, trainset, data_idxs[i])
           for i in range(num_devices)]

start_time = time.time()
for round_num in range(rounds):
    
    # 1/5 of the devices in each round are malicious
    round_devices = get_devices_for_round(devices, device_pct)
    malicious_devices = get_malicious_devices(round_devices)

    print('Round: ', round_num)    
    for device in round_devices:
        if device not in malicious_devices:
            for local_epoch in range(local_epochs):
                train(local_epoch, device)
        else:   # the device is malicious
            # instead of training, get fake weights
            device = generate_fake_weights(device) 
    
    # use secure_average_weights to drop outlier (likely malicious) devices before averaging
    w_avg = secure_average_weights(round_devices)

    for device in devices:
        device['net'].load_state_dict(w_avg)
        device['optimizer'].zero_grad()
        device['optimizer'].step()
        device['scheduler'].step()

    # test accuracy after aggregation
    test(round_num, devices[0])

total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

Round:  0
(Device 20/Epoch 1) Train Loss: 1.871 | Train Acc: 29.120

KeyboardInterrupt: ignored