In [57]:
#!/usr/bin/env python
import os
from math import ceil
from random import Random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.multiprocessing import Process

In [58]:
def run(rank, size):
    """Point-to-point demo: rank 0 sends a tensor holding 1.0 to every other rank.

    Must be called inside an initialized default process group (e.g. via
    ``init_process``). Afterwards every rank prints ``tensor(1.)``.

    Args:
        rank: rank of this process.
        size: world size (total number of processes).
    """
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send to every peer (not just rank 1) so that in worlds larger than 2
        # no rank is left blocked forever in recv(); with size == 1 nothing is sent.
        for dst in range(1, size):
            dist.send(tensor=tensor, dst=dst)
    else:
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])

def init_process(rank, size, fn, backend="gloo",
                 master_addr='127.0.0.1', master_port='29500'):
    """Join the default process group, run ``fn(rank, size)``, then tear the group down.

    Args:
        rank: rank of this process.
        size: total number of processes (world size).
        fn: callable invoked as ``fn(rank, size)`` once the group is initialized.
        backend: torch.distributed backend name.
        master_addr: rendezvous address shared by all ranks (MASTER_ADDR).
        master_port: rendezvous port shared by all ranks (MASTER_PORT);
            str or int.
    """
    os.environ["MASTER_ADDR"] = master_addr
    # os.environ only accepts strings; allow an int port for convenience.
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        fn(rank, size)
    finally:
        # Tear down the default process group even if fn raises.
        dist.destroy_process_group()

In [59]:
def dummy(fn, size=2):
    """Launch ``size`` local worker processes, each running ``init_process(rank, size, fn)``.

    Blocks until every worker has finished.

    Args:
        fn: callable run by each worker as ``fn(rank, size)``.
        size: number of worker processes (world size).

    Raises:
        RuntimeError: if any worker exits with a non-zero exit code, so worker
            failures are not silently ignored by the caller.
    """
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, fn))
        p.start()
        processes.append(p)

    try:
        for p in processes:
            p.join()
    except BaseException:
        # e.g. KeyboardInterrupt in the notebook: don't leave orphaned workers.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
        raise

    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(
            'Worker process(es) for rank(s) {} exited with code(s) {}'.format(
                failed, [processes[r].exitcode for r in failed]))

In [60]:
# all-reduce
def run_reduce(rank, size):
    """All-reduce demo: every rank contributes ``torch.ones(2)`` and receives the sum.

    After the SUM all-reduce each rank holds ``tensor([size, size])``
    (``tensor([2., 2.])`` for a 2-process launch) and prints it.

    Args:
        rank: rank of this process.
        size: world size; the reduction group spans ranks ``0 .. size-1``.
    """
    # Build the group from the actual world size instead of hard-coding [0, 1];
    # every rank calls new_group here with the same rank list.
    group = dist.new_group(list(range(size)))
    tensor = torch.ones(2)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, tensor)

In [61]:
# Launch the all-reduce demo on 2 local processes; each rank should print
# tensor([2., 2.]) (print order between ranks is not fixed).
dummy(run_reduce)

Rank  1 tensor([2., 2.])
Rank  0 tensor([2., 2.])


In [62]:
# Splitting Dataset - from PyTorch tutorial
# Dataset partitioning helper
class Partition(object):
    """Read-only view of ``data`` restricted to the positions listed in ``index``.

    Position ``i`` of the view maps to ``data[index[i]]``, and the view's
    length is ``len(index)``.
    """

    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        # Translate the view-local position into a position in the full dataset.
        return self.data[self.index[index]]


class DataPartitioner(object):
    """Shuffle dataset indices once (seeded) and cut them into consecutive partitions.

    Args:
        data: indexable dataset supporting ``len()``.
        sizes: fraction of the dataset assigned to each partition, in order.
            Each partition gets ``int(frac * len(data))`` indices, so rounding
            leftovers are dropped, and if the fractions sum to more than 1 the
            later partitions are truncated (possibly empty).
        seed: seed for the shuffle; the same seed yields the same split, which
            lets independent processes agree on it without communicating.
    """

    # Tuple default: an immutable default avoids the shared-mutable-default pitfall.
    def __init__(self, data, sizes=(0.7, 0.2, 0.1), seed=1234):
        self.data = data
        self.partitions = []
        rng = Random(seed)
        data_len = len(data)
        indexes = list(range(data_len))
        rng.shuffle(indexes)

        # Walk a running offset instead of re-slicing (copying) the remainder each time.
        start = 0
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[start:start + part_len])
            start += part_len

    def use(self, partition):
        """Return a ``Partition`` view of ``data`` for partition number ``partition``."""
        return Partition(self.data, self.partitions[partition])

In [63]:
# Partitioning MNIST
def partition_dataset(root='./data', global_batch_size=128):
    """Load the MNIST training set and return this rank's shard in a shuffling DataLoader.

    The dataset is split into ``world_size`` equal random partitions via
    ``DataPartitioner`` (fixed default seed, so every rank computes the same
    split), and rank ``r`` keeps partition ``r``. The global batch size is
    floor-divided across ranks.

    Must be called on every rank of an initialized process group, because it
    contains a barrier.

    Args:
        root: directory MNIST is stored in / downloaded to.
        global_batch_size: total batch size summed over all ranks.

    Returns:
        (train_set, bsz): this rank's DataLoader and its per-rank batch size.
    """
    size = dist.get_world_size()
    rank = dist.get_rank()

    def load():
        return torchvision.datasets.MNIST(
            root, train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]))

    # Only rank 0 may download; otherwise every rank writes the same files in
    # `root` concurrently. Non-zero ranks wait at the barrier until rank 0 has
    # finished, then load the already-present copy. Each rank hits the barrier once.
    if rank != 0:
        dist.barrier()
    dataset = load()
    if rank == 0:
        dist.barrier()

    # Guard against bsz == 0 (DataLoader rejects it) when size > global_batch_size.
    bsz = max(1, global_batch_size // size)
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes).use(rank)
    train_set = torch.utils.data.DataLoader(partition, batch_size=bsz, shuffle=True)
    return train_set, bsz

In [29]:
class Net(nn.Module):
    """Small CNN for 1x28x28 (MNIST) images; returns log-probabilities over 10 classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D conv feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D fc activations. nn.Dropout2d on a
        # 2-D input is deprecated (newer PyTorch treats it as unbatched
        # channel dropout); nn.Dropout keeps the intended per-element behavior.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 (28 -> 26 -> 24 after two 3x3 convs, -> 12 after 2x2 pool).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 1, 28, 28) to (N, 10) log-probabilities."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [52]:
# Gradient averaging.
def average_gradients(model):
    """Average every parameter's gradient across all ranks, in place.

    Sums each ``param.grad`` over the default process group with all_reduce
    and divides by the world size, so afterwards every rank holds the mean
    gradient. Must be called on all ranks, after ``backward()`` and before
    ``optimizer.step()``.
    """
    size = float(dist.get_world_size())
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is None:
                # NOTE(review): assumes the same parameters lack gradients on
                # every rank; otherwise ranks would issue mismatched all_reduce calls.
                continue
            # dist.ReduceOp replaces the deprecated dist.reduce_op alias.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(size)

def distributed_training(rank, size, epochs=10):
    """Synchronous data-parallel training of ``Net`` on this rank's MNIST shard.

    Each rank trains its own replica on its partition. Gradients are averaged
    across ranks after every backward pass, so replicas that start from the
    same weights stay in sync. Prints this rank's mean training loss per epoch.

    Args:
        rank: rank of this process.
        size: world size (not used directly; part of the ``fn(rank, size)``
            signature expected by ``init_process``).
        epochs: number of passes over the local partition.
    """
    # Same seed on every rank -> identical initial weights across replicas.
    torch.manual_seed(1000)
    train_set, bsz = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            # Accumulate a Python float: `epoch_loss += loss` would keep every
            # batch's autograd graph alive for the whole epoch.
            epoch_loss += loss.item()
            average_gradients(model)
            optimizer.step()
        print('Rank ', rank, ' dist rank ', dist.get_rank(), ' epoch ',
              epoch, ' epoch_loss ', epoch_loss / num_batches)

In [68]:
# Launch data-parallel MNIST training on 2 local processes.
# The saved output below shows this run being stopped manually
# (KeyboardInterrupt) before any epoch loss was printed.
dummy(distributed_training)

Process Process-62:
Process Process-61:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/gaurav/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/gaurav/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()


KeyboardInterrupt: 

  File "/home/gaurav/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/gaurav/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-58-07c81a862869>", line 15, in init_process
    fn(rank,size)
  File "<ipython-input-58-07c81a862869>", line 15, in init_process
    fn(rank,size)
  File "<ipython-input-52-d409a395a41c>", line 17, in distributed_training
    optimizer.zero_grad()
  File "<ipython-input-52-d409a395a41c>", line 18, in distributed_training
    output = model(data)
  File "/home/gaurav/anaconda3/lib/python3.7/site-packages/torch/optim/optimizer.py", line 165, in zero_grad
    p.grad.zero_()
  File "/home/gaurav/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
KeyboardInterrupt
  File "<ipython-input-29-82b1960216b5>", line 14, in for