In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tqdm import tqdm
import time
from threading import Thread
from multiprocessing import Process, Manager
from time import sleep
import random

# Create the model that will be instantiated for the workers and the master

In [2]:
class MnistModel(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Two 5x5 convolutions (padding=2, so the spatial size is preserved), each
    followed by ReLU and 2x2 max pooling, feed a two-layer fully connected
    head that produces 10 class scores.
    """

    def __init__(self):
        super(MnistModel, self).__init__()
        # input is 28x28
        # padding=2 for same padding
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        # feature map size is 14*14 by pooling
        # padding=2 for same padding
        self.conv2 = nn.Conv2d(32, 64, 5, padding=2)
        # feature map size is 7*7 by pooling
        self.fc1 = nn.Linear(64*7*7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: tensor that the first convolution accepts with one input
                channel; the flatten step below assumes 28x28 inputs.

        Returns:
            Tensor with one row per sample and 10 columns holding
            log-softmax scores (each row exponentiates to a distribution).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)   # flatten to (batch, features)
        x = F.relu(self.fc1(x))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class dimension explicitly; omitting `dim`
        # relies on a deprecated implicit choice and triggers a warning.
        return F.log_softmax(x, dim=1)

# Define the master model, number of workers and instantiate them
We also define the optimizers here; the batch size is set in a later cell.

In [3]:
# Master model and the SGD optimizer that steps its weights.
master_model = MnistModel()
optimizer = optim.SGD(master_model.parameters(), lr=0.01, momentum=0.9)

# One freshly initialised model replica and one SGD optimizer per simulated worker.
worker_size = 3
workers = [MnistModel() for _ in range(worker_size)]
optimizers = [
    optim.SGD(replica.parameters(), lr=0.01, momentum=0.9)
    for replica in workers
]

In [4]:
# Artificial delay in seconds; worker_forward sleeps for `waiting_time + index`
# before each worker's pass.
# Change the waiting to random.random() or 0 depending on Task 1 or Task 2
waiting_time = 0.0
waiting_time = random.random() # random value in [0, 1), drawn once here; overrides the 0.0 above

In [5]:
# Mini-batch size handed to both the training and the test DataLoader.
batch_size = 32

# Get the MNIST Dataset

In [6]:
# !wget www.di.ens.fr/~lelarge/MNIST.tar.gz
# !tar -zxvf MNIST.tar.gz

from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np

In [7]:
# Shuffled mini-batches over the MNIST training split (downloaded into ./ if absent),
# with images converted to tensors.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./',
        train=True,
        transform=transforms.ToTensor(),
        download=True,
    ),
    shuffle=True,
    batch_size=batch_size,
)

In [8]:
# Mini-batches over the MNIST test split read from ./ (no download requested, no shuffling),
# with images converted to tensors.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='./',
        train=False,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
)

# Worker Loop
For each worker, perform a forward pass over the batch and record the loss.
After getting the forward pass loss, calculate the gradients over the backward pass and store them.
Do the same steps for all the workers and sum the gradients of each layer of the model into one dictionary called params.
## The code is parallel since we're using a multiprocessing system to imitate several workers
This is achieved with Python's built-in `multiprocessing` library.

In [9]:
# Manager-backed dictionary handed to the worker processes; worker_forward
# accumulates gradients into it, keyed by the position of each parameter.
manager = Manager()
params = manager.dict()

In [10]:
def worker_forward(worker, index, params):
    """Run one forward/backward pass on a single batch and accumulate gradients.

    The call first sleeps for ``waiting_time + index`` seconds (used in Task 2
    to simulate slow workers), then takes the first batch yielded by
    ``train_loader``, back-propagates the NLL loss through ``worker`` and adds
    every parameter gradient into ``params`` under the parameter's position.

    Args:
        worker: model whose output is a valid input for ``F.nll_loss``.
        index: position of the worker; selects ``optimizers[index]`` and
            contributes to the simulated delay.
        params: mutable mapping ``{parameter position: summed gradient}``.

    Returns:
        None. If ``train_loader`` yields no batch, ``params`` is left untouched.
    """
    time.sleep(waiting_time+index) # Use in Task 2 (change param in beginning of notebook)

    # Only one batch is processed per call; a fresh iterator gives a new batch each time.
    batch = next(iter(train_loader), None)
    if batch is None:
        return
    data, target = batch

    optimizers[index].zero_grad()
    output = worker(data)
    loss = F.nll_loss(output, target)
    loss.backward()    # calc gradients

    # NOTE(review): when `params` is a manager proxy, the read-add-write of `+=`
    # is not atomic across processes - confirm whether workers can interleave here.
    for param_index, param in enumerate(worker.parameters()):
        if param_index not in params:
            # Store a detached copy: keeping the live `.grad` tensor would let a
            # later zero_grad()/backward() on this worker modify the running sum.
            params[param_index] = param.grad.detach().clone()
        else:
            params[param_index] += param.grad

In [11]:
def forward_pass_workers(wait=False):
    """Launch one gradient pass per worker and hand back the spawned processes.

    Worker 0 runs synchronously in the calling process so that the shared
    ``params`` dictionary gets its initial entries; this adds the latency of
    one device, which is acceptable for the sake of demonstration. Every other
    worker runs ``worker_forward`` in its own ``Process``.

    Args:
        wait: when True, block until every spawned process has finished, so
            that all contributions are in ``params`` on return. Defaults to
            False, which starts the processes and returns immediately.

    Returns:
        List of the ``Process`` objects that were started (empty when there is
        at most one worker), so callers can join or inspect them.
    """
    started = []
    for index, worker in enumerate(workers):
        if index == 0:
            worker_forward(worker, index, params)
        else:
            process = Process(target=worker_forward, args=(worker, index, params))
            process.start()
            started.append(process)

    if wait:
        for process in started:
            process.join()
    return started

# Master Model Loop
Once the gradients from the workers have been aggregated in `params`, we copy the summed gradients onto the parameters of the master model.

In [12]:
# The final evaluation in this notebook is run on a worker, not on the master model.
def update_master_gradients():
    """Install the aggregated gradients from ``params`` on the master model.

    Matching is positional: the i-th tensor yielded by
    ``master_model.parameters()`` gets ``params[i]`` assigned as its ``.grad``.
    """
    for slot, weight in enumerate(master_model.parameters()):
        weight.grad = params[slot]

Once we update the master model gradients, we perform one step of the optimizer to update the weights

In [13]:
def update_master_weights():
    """Put the master model in training mode and take one optimizer step.

    The step updates the master weights from whatever ``.grad`` values its
    parameters currently hold.
    """
    master_model.train()
    optimizer.step()  # apply the weight update
    

# Most Importantly update the worker weights by the aggregated gradient of the loss

In [14]:
def update_worker_weights():
    """Switch every worker to training mode and step its optimizer once.

    NOTE(review): this function does not read the shared ``params`` dict; each
    optimizer steps on whatever ``.grad`` its own worker's parameters hold.
    """
    for position, worker_optimizer in enumerate(optimizers):
        workers[position].train()
        worker_optimizer.step()  # one update per worker

# To do it properly we aggregate all the previous parts together and we do it over multiple iterations

In [15]:
# Start the training loop from a fresh Manager and an empty shared dict.
# This rebinds the `manager` and `params` globals created earlier in the notebook.
manager = Manager()
params = manager.dict()

In [16]:
def clear_temp_params(): # Only do it after master gradient has been updated
    """Empty the shared gradient dictionary so the next round starts from zero.

    The dictionary is cleared in place. Assigning a new dict to ``params``
    inside this function would only bind a local name and leave the
    module-level dictionary (and its accumulated gradients) untouched.
    """
    params.clear()

In [17]:
start_time = time.time()  # wall-clock start for the timing report printed below

# 200 rounds; in each round every worker processes a single batch.
# NOTE(review): update_master_weights() is never called here, so the master
# model's weights are not stepped by this loop.
for epoch in tqdm(range(200)):
    forward_pass_workers()     # worker 0 inline, the other workers in separate processes
    update_master_gradients()  # copy the contents of `params` onto the master's .grad
    update_worker_weights()    # one optimizer step per worker
    clear_temp_params()        # meant to reset the shared gradient dict for the next round
    
print("--- %s seconds ---" % (time.time() - start_time))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 200/200 [02:20<00:00,  1.42it/s]

--- 140.5531461238861 seconds ---





In [18]:
def test(worker):
    test_loss = 0
    correct = 0
    count = 0
    for data, target in test_loader:
        count += batch_size
        data, target = Variable(data), Variable(target)
        output = worker(data)
        test_loss += F.nll_loss(output, target, size_average=False).data # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
        

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, count,
        100. * correct / count))

In [19]:
# Evaluate the first worker replica on the test loader.
test(workers[0])




Test set: Average loss: 0.2250, Accuracy: 9282/10016 (93%)

