# Simple decentralized training

We train a CNN (LeNet) on decntralized MNIST datset. 

The data is split either randomly (iid) or according to label (non-iid). The communication graph can be also arbitarily set. Each node can perform multiple updates before communicating.

To try on colab: https://colab.research.google.com/drive/1DT4EaeEk9AuaFWMRaNkFEyQ9e-SphlUG?usp=sharing

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def client_update(client_model, optimizer, train_loader, epoch=5):
    """Train a client_model on the train_loder data."""
    client_model.train()
    for e in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()


def diffuse_params(client_models, communication_matrix):
    """Diffuse the models with their neighbors."""
    if client_models:
      client_state_dicts = [model.state_dict() for model in client_models]
      keys = client_state_dicts[0].keys()
    for model, weights in zip(client_models, communication_matrix):
        neighbors = np.nonzero(weights)[0]
        model.load_state_dict(
            {
                key: torch.stack(
                    [weights[j]*client_state_dicts[j][key] for j in neighbors],
                    dim=0,
                ).sum(0) / weights.sum() 
                for key in keys
            }
        )


def average_models(global_model, client_models):
    """Average models across all clients."""
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k] for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)


def evaluate_model(model, data_loader):
    """Compute loss and accuracy of a single model on a data_loader."""
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss /= len(data_loader.dataset)
    acc = correct / len(data_loader.dataset)

    return loss, acc

def evaluate_many_models(models, data_loader):
  """Compute average loss and accuracy of multiple models on a data_loader."""
  num_nodes = len(models)
  losses = np.zeros(num_nodes)
  accuracies = np.zeros(num_nodes)
  for i in range(num_nodes):
    losses[i], accuracies[i] = evaluate_model(models[i], data_loader)
  return losses, accuracies

In [12]:
# IID case: all the clients have images of all the classes

# Hyperparameters

num_clients = 5
num_rounds = 5
epochs = 1
batch_size = 32

# Communication matrix
# For now restricted to doubly stochastic matrices.
comm_matrix = np.ones((num_clients, num_clients)) / num_clients
# comm_matrix = np.eye(num_clients)

# Creating decentralized datasets

traindata = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
                       )
traindata_split = torch.utils.data.random_split(traindata, [int(traindata.data.shape[0] / num_clients) for _ in range(num_clients)])
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
        ), batch_size=10*batch_size, shuffle=True)

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_clients)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]

# Runnining Decentralized training

for r in range(num_rounds):
    # client update
    loss = 0
    for i in range(num_clients):
        loss += client_update(client_models[i], opt[i], train_loader[i], epoch=epochs)
    
    # diffuse params across neighbors
    diffuse_params(client_models, comm_matrix)

    # evaluate
    test_losses, accuracies = evaluate_many_models(client_models, test_loader)
    
    print('%d-th round' % r)
    print('average train loss %0.3g | average test loss %0.3g | average test acc: %0.3f' % (loss / num_clients, test_losses.mean(), accuracies.mean()))

0-th round
average train loss 0.12 | average test loss 0.145 | average test acc: 0.957
1-th round
average train loss 0.0785 | average test loss 0.0896 | average test acc: 0.972
2-th round
average train loss 0.155 | average test loss 0.0651 | average test acc: 0.978
3-th round
average train loss 0.0995 | average test loss 0.0568 | average test acc: 0.981
4-th round
average train loss 0.0717 | average test loss 0.05 | average test acc: 0.984


In [13]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].

# Hyperparameters

num_clients = 5
num_rounds = 5
epochs = 1
batch_size = 32


# Communication matrix
# For now restricted to doubly stochastic matrices.
comm_matrix = np.ones((num_clients, num_clients)) / num_clients
# comm_matrix = np.eye(num_clients)

# Creating decentralized datasets

traindata = datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
                       )
target_labels = torch.stack([traindata.targets == i for i in range(10)])
target_labels_split = []
for i in range(5):
    target_labels_split += torch.split(torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0], int(60000 / num_clients))
traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split]
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
        ), batch_size=batch_size, shuffle=True)

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_clients)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]

# Runnining Decentralized training

for r in range(num_rounds):
    # client update
    loss = 0
    for i in range(num_clients):
        loss += client_update(client_models[i], opt[i], train_loader[i], epoch=epochs)   
    
    # diffuse params across neighbors
    diffuse_params(client_models, comm_matrix)

    # evaluate
    test_losses, accuracies = evaluate_many_models(client_models, test_loader)
    
    print('%d-th round' % r)
    print('average train loss %0.3g | average test loss %0.3g | average test acc: %0.3f' % (loss / num_clients, test_losses.mean(), accuracies.mean()))

0-th round
average train loss 0.465 | average test loss 2.27 | average test acc: 0.233
1-th round
average train loss 0.121 | average test loss 2.55 | average test acc: 0.263
2-th round
average train loss 0.0327 | average test loss 2.7 | average test acc: 0.353
3-th round
average train loss 0.0271 | average test loss 2.71 | average test acc: 0.418
4-th round
average train loss 0.0517 | average test loss 2.76 | average test acc: 0.427
