In [1]:
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from math import ceil
from random import Random
from torch.multiprocessing import Process
from torch.autograd import Variable
from torchvision import datasets, transforms

In [55]:
class Partition(object):
    """Read-only view that exposes a chosen subset of an indexable dataset.

    Position ``i`` of the view maps to ``data[index[i]]``, so the view's
    length is the length of ``index``.

    Args:
        data: the underlying indexable dataset.
        index: sequence of positions into ``data`` that this view exposes.
    """

    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        # The view is exactly as long as its index list.
        return len(self.index)

    def __getitem__(self, index):
        # Translate the view position to a dataset position, then fetch it.
        return self.data[self.index[index]]


class DataPartitioner(object):
    """Randomly partition a dataset into disjoint chunks of indices.

    All indices of ``data`` are shuffled once with a private ``random.Random``
    seeded by ``seed``. The shuffled list is then cut into consecutive
    pieces, one per entry of ``sizes``. Building two partitioners with the
    same arguments therefore yields the same split.

    Args:
        data: dataset supporting ``len()`` and indexing.
        sizes: fraction of the dataset assigned to each partition, in order.
            Partition ``k`` receives ``int(sizes[k] * len(data))`` indices.
            Indices left over after truncation are not assigned to any
            partition.
        seed: seed for the shuffle.
    """

    def __init__(self, data, sizes=(0.7, 0.2, 0.1), seed=1234):
        # Tuple default avoids a shared mutable default argument.
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = list(range(data_len))
        rng.shuffle(indexes)

        # Walk a moving offset instead of re-slicing the remaining list on
        # every iteration (same result, without repeated copies).
        start = 0
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[start:start + part_len])
            start += part_len

    def use(self, partition):
        """Return a ``Partition`` view over the ``partition``-th chunk."""
        return Partition(self.data, self.partitions[partition])

class Net(nn.Module):
    """Small CNN classifier producing log-probabilities over 10 classes.

    The layer sizes are fixed for 28x28 single-channel inputs, i.e. shape
    ``(N, 1, 28, 28)``. Two 5x5 convolutions, each followed by 2x2 max-pooling,
    reduce the map to 20 channels of 4x4 = 320 features, which is what
    ``fc1`` consumes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return ``(N, 10)`` log-probabilities for a ``(N, 1, 28, 28)`` batch."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample; a wrongly sized input now fails at fc1 instead
        # of view(-1, 320) silently changing the batch dimension.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit class axis: the implicit-dim form is deprecated and warns.
        return F.log_softmax(x, dim=1)

def partition_dataset(root='data/', download=False, global_batch_size=128):
    """Build this rank's MNIST training DataLoader for data-parallel SGD.

    The MNIST training set is split into ``world_size`` equal shards with
    ``DataPartitioner``. Every rank uses the partitioner's default seed, so
    all ranks compute the same split. Each rank keeps the shard at its own
    rank index, so shards are disjoint across ranks.

    Args:
        root: directory holding (or receiving) the MNIST files.
        download: whether torchvision may download MNIST if it is missing.
        global_batch_size: total batch size summed over all ranks; each rank
            uses ``global_batch_size // world_size`` (at least 1).

    Returns:
        ``(train_set, bsz)``: the shuffled DataLoader over this rank's shard
        and the integer per-rank batch size it actually uses.
    """
    dataset = datasets.MNIST(
        root=root,
        train=True,
        download=download,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))
    size = dist.get_world_size()
    # Clamp to 1: with more ranks than global_batch_size the share would be 0,
    # which DataLoader rejects.
    bsz = max(1, global_batch_size // size)
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes).use(dist.get_rank())
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=bsz, shuffle=True)
    return train_set, bsz

def average_gradients(model):
    """Average the gradients of ``model`` across all ranks, in place.

    Each parameter's gradient is summed over the default (world) process
    group with ``all_reduce`` and then divided by the world size. Every rank
    must call this with the same set of parameters, or the collectives will
    not match up.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        # all_reduce defaults to op=SUM over the default group. The old
        # explicit `group=0` is not a valid process group in c10d, and
        # `dist.reduce_op` is deprecated.
        dist.all_reduce(param.grad)
        param.grad /= size

def run(rank, size, epochs=10):
    """Distributed synchronous SGD on this rank's MNIST shard.

    Every rank builds its own ``Net`` replica and trains on its own shard.
    After each backward pass the gradients are averaged across ranks, so all
    replicas apply the same update. At the end of each epoch this rank
    prints its mean training loss.

    Args:
        rank: rank this process was started with (``init_processes`` passes
            the same value it gave to ``init_process_group``).
        size: world size; not used directly, kept for the ``fn(rank, size)``
            calling convention of ``init_processes``.
        epochs: number of passes over the local shard.
    """
    # Same seed on every rank -> identical initial weights, which gradient
    # averaging needs in order to keep the replicas in sync.
    torch.manual_seed(1234)
    train_set, _ = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # len(DataLoader) is the exact number of batches it yields. Guard against
    # an empty shard so the mean below never divides by zero.
    num_batches = max(1, len(train_set))
    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()  # accumulate as a plain Python float
            loss.backward()
            average_gradients(model)
            optimizer.step()
        # epoch_loss is already a float (the old `.data.numpy()` raised here).
        print('CPU ',
              dist.get_rank(), ', epoch ', epoch, ', ',
              'train loss ', epoch_loss / num_batches)

In [58]:
def init_processes(rank, size, fn, backend='gloo'):
    """Join the default process group as ``rank`` and then run ``fn(rank, size)``.

    Args:
        rank: rank of this process, in ``[0, size)``.
        size: total number of processes (world size).
        fn: callable invoked as ``fn(rank, size)`` once the group is set up.
        backend: ``torch.distributed`` backend name. ``'gloo'`` works with CPU
            tensors. The previous default ``'tcp'`` backend no longer exists
            in PyTorch 1.0+.
    """
    # Rendezvous address for the default env:// initialization. setdefault
    # lets a caller override these through the environment.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

In [59]:
if __name__ == "__main__":
    # Launch one training process per rank. Each joins the same process group
    # through init_processes and then runs `run`.
    # NOTE(review): the Process target is defined in this notebook's __main__,
    # so this presumably relies on the "fork" start method (Linux default);
    # under "spawn" (macOS/Windows default) the children may not find it — confirm.
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A crash in a child only prints its traceback in that child; surface the
    # failure in the parent as well instead of finishing "successfully".
    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError('training failed on rank(s) %s' % failed)



CPU  0 , epoch  0 ,  trian loss  1.3098469878564765
CPU  1 , epoch  0 ,  trian loss  1.306282482675906
CPU  1 , epoch  1 ,  trian loss  0.5369825546167044
CPU  0 , epoch  1 ,  trian loss  0.5427322225021655
CPU  1 , epoch  2 ,  trian loss  0.42255146213686035
CPU  0 , epoch  2 ,  trian loss  0.43399929542785515
CPU  0 , epoch  3 ,  trian loss  0.36311379373709024
CPU  1 , epoch  3 ,  trian loss  0.3568102871177039
CPU  0 , epoch  4 ,  trian loss  0.31704591535556037
CPU  1 , epoch  4 ,  trian loss  0.3190948876744903
CPU  1 , epoch  5 ,  trian loss  0.2866672930686967
CPU  0 , epoch  5 ,  trian loss  0.2919410103673874
CPU  0 , epoch  6 ,  trian loss  0.2659025049921292
CPU  1 , epoch  6 ,  trian loss  0.2675688200667977
CPU  0 , epoch  7 ,  trian loss  0.2538931873053122
CPU  1 , epoch  7 ,  trian loss  0.2519818694352595
CPU  0 , epoch  8 ,  trian loss  0.24273240795013493
CPU  1 , epoch  8 ,  trian loss  0.23500510014450626
CPU  0 , epoch  9 ,  trian loss  0.23236726545321662
CPU  1