In [0]:
import os

In [0]:
# Create the dataset download directory; exist_ok avoids the check-then-create race.
os.makedirs("./data", exist_ok=True)

###DATA UTILS

In [0]:
import os
import random

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, Subset
from torchvision import datasets


def get_data_loaders(batch_size, num_clients, iid_split=True, percentage_val=0.2, full=False,
                     non_iid_mix=0):
    """
    Build one training DataLoader per client, an optional validation loader and a test loader.

    :param batch_size: batch size used by every DataLoader
    :param num_clients: number of client datasets/loaders to create
    :param iid_split: if True split the training data uniformly at random between clients,
        otherwise give each client a disjoint range of labels (see get_non_iid_datasets)
    :param percentage_val: fraction of the training set held out for validation
        (0/None disables the validation set)
    :param full: forwarded to load_data; False uses the reduced data set
    :param non_iid_mix: only used when iid_split is False; fraction of the training data
        that is dealt out uniformly at random on top of the label-based split
    :return: (list of train loaders in shuffled client order, val_loader or None, test_loader)
    """
    val_loader = None
    train_input, train_target, test_input, test_target = load_data(flatten=False, full=full)
    train_dataset = TensorDataset(train_input, train_target)

    # If validation set is needed randomly split training set
    if percentage_val:
        num_val = int(percentage_val * len(train_dataset))
        # Derive the train size from the total so the two lengths always sum to the dataset
        # size; truncating both products separately can drop a sample and make
        # random_split reject the lengths.
        num_train = len(train_dataset) - num_val
        val_dataset, train_dataset = torch.utils.data.random_split(train_dataset, [num_val, num_train])
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=batch_size,
                                shuffle=True)

    # Split data for each client
    if iid_split:
        # Random IID data split into equal shares. When the size is not divisible by
        # num_clients the remainder goes one sample each to the first clients, because
        # random_split requires the lengths to sum to the dataset size.
        share, remainder = divmod(len(train_dataset), num_clients)
        lengths = [share + 1 if idx < remainder else share for idx in range(num_clients)]
        client_datasets = torch.utils.data.random_split(train_dataset, lengths)
    elif non_iid_mix:
        non_iid_part, iid_part = get_non_iid_split(train_dataset, non_iid_mix)
        # Label-based split of the non-IID part ...
        client_datasets = get_non_iid_datasets(num_clients, non_iid_part)
        # ... then append an equal, disjoint slice of the IID part to every client
        # (samples left over after integer division are not used).
        chunk_size = len(iid_part) // num_clients
        for client_nr, client_dataset in enumerate(client_datasets):
            client_dataset.indices.extend(
                iid_part.indices[chunk_size * client_nr: chunk_size * (client_nr + 1)])
    else:
        # Each client has different set of non overlapping digits
        client_datasets = get_non_iid_datasets(num_clients, train_dataset)

    random.shuffle(client_datasets)
    train_loaders = [DataLoader(dataset=client_dataset, batch_size=batch_size, shuffle=True)
                     for client_dataset in client_datasets]

    test_loader = DataLoader(dataset=TensorDataset(test_input, test_target),
                             batch_size=batch_size)

    return train_loaders, val_loader, test_loader


def get_non_iid_split(train_dataset, non_iid_mix_p):
    """
    Randomly split ``train_dataset`` into a part to be split by label and a part shared IID.

    :param train_dataset: a dataset, or a ``Subset`` of one
    :param non_iid_mix_p: fraction of the samples that go to the IID part
    :return: (non_iid_part, iid_part) as ``Subset`` objects. When ``train_dataset`` is itself a
        ``Subset``, both parts index directly into ``train_dataset.dataset``.
    """
    total = len(train_dataset)
    iid_len = round(non_iid_mix_p * total)
    # Derive the second length from the first so the lengths always sum to ``total``;
    # rounding both fractions independently can be off by one (e.g. round(2.5) twice).
    iid_part, non_iid_part = torch.utils.data.random_split(train_dataset, [iid_len, total - iid_len])
    if isinstance(train_dataset, Subset):
        # random_split produced positions *inside* train_dataset; translate them into
        # indices of the underlying dataset before re-pointing .dataset at it, otherwise
        # the parts would select unrelated samples.
        parent_indices = train_dataset.indices
        for part in (iid_part, non_iid_part):
            part.indices = [int(parent_indices[i]) for i in part.indices]
            part.dataset = train_dataset.dataset
    return non_iid_part, iid_part


def load_data(cifar=False, one_hot_labels=False, normalize=False, flatten=False, full=False,
              num_samples=5000):
    """
    Download (if needed) and load MNIST or CIFAR-10 as float tensors.

    :param cifar: load CIFAR-10 instead of MNIST
    :param one_hot_labels: return targets as one-hot float matrices instead of class indices
    :param normalize: standardize train and test inputs in place with the train-set mean/std
    :param flatten: reshape inputs to (N, features) instead of (N, C, H, W)
    :param full: keep every sample; otherwise keep only the first ``num_samples`` of each split
    :param num_samples: number of train and test samples kept when ``full`` is False
    :return: (train_input, train_target, test_input, test_target)
    """
    data_dir = './data'

    if cifar:
        print('* Using CIFAR')
        cifar_train_set = datasets.CIFAR10(data_dir + '/cifar10/', train=True, download=True)
        cifar_test_set = datasets.CIFAR10(data_dir + '/cifar10/', train=False, download=True)

        # The raw CIFAR arrays are (N, H, W, C); move channels first to get (N, C, H, W)
        train_input = torch.from_numpy(cifar_train_set.data).permute(0, 3, 1, 2).float()
        train_target = torch.tensor(cifar_train_set.targets, dtype=torch.int64)

        test_input = torch.from_numpy(cifar_test_set.data).permute(0, 3, 1, 2).float()
        test_target = torch.tensor(cifar_test_set.targets, dtype=torch.int64)

    else:
        print('* Using MNIST')
        mnist_train_set = datasets.MNIST(data_dir + '/mnist/', train=True, download=True)
        mnist_test_set = datasets.MNIST(data_dir + '/mnist/', train=False, download=True)

        train_input = mnist_train_set.data.view(-1, 1, 28, 28).float()
        train_target = mnist_train_set.targets
        test_input = mnist_test_set.data.view(-1, 1, 28, 28).float()
        test_target = mnist_test_set.targets

    if flatten:
        train_input = train_input.clone().reshape(train_input.size(0), -1)
        test_input = test_input.clone().reshape(test_input.size(0), -1)

    if not full:
        print('** Reducing the data-set, (use --full for the full thing)')
        train_input = train_input.narrow(0, 0, num_samples)
        train_target = train_target.narrow(0, 0, num_samples)
        test_input = test_input.narrow(0, 0, num_samples)
        test_target = test_target.narrow(0, 0, num_samples)

    print('** Use {:d} train and {:d} test samples'.format(train_input.size(0), test_input.size(0)))

    if one_hot_labels:
        train_target = convert_to_one_hot_labels(train_input, train_target)
        test_target = convert_to_one_hot_labels(test_input, test_target)

    if normalize:
        # Statistics come from the training set only and are applied to both splits
        mu, std = train_input.mean(), train_input.std()
        train_input.sub_(mu).div_(std)
        test_input.sub_(mu).div_(std)

    return train_input, train_target, test_input, test_target


def convert_to_one_hot_labels(input, target):
    """
    One-hot encode integer class labels.

    :param input: tensor whose dtype/device the encoding is allocated with (via new_zeros)
    :param target: integer label tensor; flattened to one label per row
    :return: (len(target), target.max() + 1) matrix with a 1 in each row's label column
    """
    row_count = target.size(0)
    class_count = int(target.max()) + 1
    encoded = input.new_zeros(row_count, class_count)
    # Write 1.0 at column target[i] of every row i
    encoded.scatter_(1, target.view(-1, 1), 1.0)
    return encoded


def get_non_iid_datasets(num_clients, train_dataset, num_classes=10):
    """
    Divide samples so that each client has non-overlapping classes,
    e.g. client 1 has only digits 0 and 1 while client 2 has only digits 2 and 3.

    Labels are sorted once and every client's range is located with bisection.
    Classes that do not appear in the data simply yield empty (or smaller) ranges.

    :param num_clients: number of clients; must be between 1 and ``num_classes``
    :param train_dataset: a ``TensorDataset`` or a ``Subset`` of one (labels in ``tensors[1]``)
    :param num_classes: number of label values; each client gets ``num_classes // num_clients``
        consecutive labels and the last client also takes any leftover labels
    :return: list of ``Subset`` objects, one per client. For a ``Subset`` input they index
        directly into ``train_dataset.dataset``.
    :raises ValueError: if ``num_clients`` is outside ``[1, num_classes]``
    """
    import bisect

    if not 1 <= num_clients <= num_classes:
        raise ValueError(f"num_clients must be between 1 and {num_classes}, got {num_clients}")

    # if we have validation set then train is a Subset type
    if isinstance(train_dataset, Subset):
        labels = train_dataset.dataset.tensors[1][train_dataset.indices]
    else:
        labels = train_dataset.tensors[1]
    labels, sorted_indices = torch.sort(labels)
    sorted_labels = labels.tolist()
    digits_per_client = num_classes // num_clients

    client_datasets = []
    for client in range(num_clients):
        first_digit = client * digits_per_client
        start = bisect.bisect_left(sorted_labels, first_digit)
        if client == num_clients - 1:
            # Last client takes everything from its first digit onwards
            end = len(sorted_labels)
        else:
            end = bisect.bisect_right(sorted_labels, first_digit + digits_per_client - 1)
        chosen = sorted_indices[start:end]
        if isinstance(train_dataset, Subset):
            # Map positions within the Subset back to indices of the underlying dataset
            new_indices = np.array(train_dataset.indices)[chosen.numpy()]
            client_dataset = Subset(train_dataset.dataset, new_indices.tolist())
        else:
            client_dataset = Subset(train_dataset, chosen.tolist())
        client_datasets.append(client_dataset)
    return client_datasets


# binary search functions to retrieve first and last index of label in sorted labels array
def first_index(array, low, high, item):
    """
    Return the index of the first occurrence of ``item`` in the sorted ``array``.

    :param array: ascending-sorted indexable sequence (list or 1-D tensor)
    :param low: first index of the search range
    :param high: inclusive last index of the search range; values past the end are
        clamped to ``len(array) - 1`` (callers pass ``len(array)``)
    :param item: value to look for
    :return: index of the first occurrence, or -1 (with a message) if not found
    """
    # Without the clamp, searching for a value larger than every element (or in an
    # empty array) reads array[len(array)] and raises IndexError.
    high = min(high, len(array) - 1)
    while low <= high:
        mid = low + (high - low) // 2
        if array[mid] == item and (mid == 0 or item > array[mid - 1]):
            return mid
        if item > array[mid]:
            low = mid + 1
        else:
            high = mid - 1
    print(f"This label {item} was not found")
    return -1


def last_index(array, low, high, item):
    """
    Return the index of the last occurrence of ``item`` in the sorted ``array``.

    :param array: ascending-sorted indexable sequence (list or 1-D tensor)
    :param low: first index of the search range
    :param high: inclusive last index of the search range; values past the end are
        clamped to ``len(array) - 1`` (callers pass ``len(array)``)
    :param item: value to look for
    :return: index of the last occurrence, or -1 (with a message) if not found
    """
    # Without the clamp, searching for a value larger than every element (or in an
    # empty array) reads past the end and raises IndexError.
    high = min(high, len(array) - 1)
    while low <= high:
        mid = low + (high - low) // 2
        if array[mid] == item and (mid == len(array) - 1 or item < array[mid + 1]):
            return mid
        if item < array[mid]:
            high = mid - 1
        else:
            low = mid + 1
    print(f"This label {item} was not found")
    return -1


def get_model_bits(state_dict):
    """
    Measure the serialized size of a state dict.

    :param state_dict: state dict (or any object torch.save accepts) to measure
    :return: model_size - number of bits of the torch.save output for ``state_dict``
    """
    import tempfile

    # Serialize into a private temporary directory instead of the CWD: no clobbering of
    # an existing "temp.p", no collisions between concurrent runs, and the file is
    # removed even if torch.save raises. The basename is kept as "temp.p" so the
    # serialized archive (and hence the measured size) matches the previous behaviour.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(state_dict, path)
        # Multiply by 8 to go from bytes to bits
        model_size = os.path.getsize(path) * 8
    return model_size


###QUANTIZATION

In [0]:
import torch


def quantize_float16(model_dict):
    """
    Cast the floating-point tensors of a state dict to 16-bit floats (in place).

    :param model_dict: model's state dict with default 32-bit float parameters
    :return: the same dict with every floating-point tensor converted to float16;
        non-floating tensors (e.g. integer buffers) are left unchanged, since casting
        them to half would lose precision
    """
    for name, param in model_dict.items():
        if param.is_floating_point():
            model_dict[name] = param.half()
    return model_dict


def quantize_int8(model_dict):
    """
    Symmetrically quantize every tensor of a state dict to int8 with one shared scale (in place).

    The largest absolute value over all tensors is mapped to 127 and every value is
    rounded to the nearest integer level.

    :param model_dict: model's state dict
    :return: (model_dict with int8 tensors, multiplier); dividing the int8 values by
        ``multiplier`` (see decode_quantized_model_int8) recovers approximate floats
    """
    # Find maximum absolute parameter value across all tensors
    max_param = 0
    for param in model_dict.values():
        new_max = param.abs().max()
        if new_max > max_param:
            max_param = new_max
    # Scale the maximum value to the max of an int8. For an all-zero (or empty) dict any
    # scale is exact, so use 1 instead of dividing by zero.
    multiplier = 127 / max_param if max_param > 0 else 1.0
    for name, param in model_dict.items():
        # Round to the nearest level; a plain int8 cast truncates toward zero, which biases
        # every weight toward 0 and doubles the worst-case error. The clamp guards against
        # float rounding pushing the max just past 127.
        model_dict[name] = torch.round(param * multiplier).clamp(-127, 127).to(torch.int8)
    return model_dict, multiplier


def decode_quantized_model_int8(model_dict, multiplier):
    """
    Turn int8-quantized tensors back into float32 values (in place).

    :param model_dict: state dict whose tensors were scaled by ``multiplier`` (quantize_int8)
    :param multiplier: the scale returned by quantize_int8
    :return: the same dict object, every tensor now float32 and divided by ``multiplier``
    """
    decoded = {key: tensor.to(torch.float32) / multiplier for key, tensor in model_dict.items()}
    model_dict.update(decoded)
    return model_dict


def no_quantization(model_dict):
    """
    Identity "quantization": hand the state dict back untouched.

    It has the same one-argument signature as quantize_float16, so it can be plugged in
    as the experiment's ``quantization`` function to disable compression.

    :param model_dict: model's state dict
    :return: the very same dict object, unmodified
    """
    unchanged = model_dict
    return unchanged


###MODEL

In [0]:
from torch import nn as nn
from torch.nn import functional as F


# Architecture after https://arxiv.org/pdf/1602.05629.pdf:
# two 5x5 convolution layers (the first with 32 channels, the second with 64,
# each followed by 2x2 max pooling), a fully connected layer with 512 units and
# ReLU activation, and a final 10-way output layer.
# NOTE: the convolutions below are unpadded, so a 1x28x28 input yields a 4x4x64
# feature map and this implementation has 582,026 parameters (the paper quotes
# 1,663,370). forward() returns raw logits; the softmax is left to the loss
# (Client uses CrossEntropyLoss).

class CNN(nn.Module):
    """Two-conv-layer MNIST classifier; expects (N, 1, 28, 28) input, returns (N, 10) logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(4 * 4 * 64, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten per sample. view(-1, 1024) would silently regroup features across
        # samples for wrongly-sized inputs; flatten(1) keeps the batch dimension and
        # lets fc1 raise a shape error instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


###TRAINING


In [0]:
import torch
import torch.optim as optim


class Client:
    """
    One federated-learning participant: a CNN, its SGD optimizer, its loss and its data loader.

    The same class is also used for the central server, with the test loader as its data.
    """

    def __init__(self, data_loader, epochs=5):
        # Local data and number of local epochs per communication round
        self.data_loader, self.epochs = data_loader, epochs

        # Static settings / bookkeeping
        self.model_name = "mnist_cnn"
        self.save_model = False
        self.gradient_compression = None
        self.log_interval = 5
        self.lr = 0.001
        self.criterion = torch.nn.CrossEntropyLoss()

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Re-seed the global RNG right before building the network so every Client
        # constructed this way starts from the same initial weights.
        self.seed = 42
        torch.manual_seed(self.seed)
        network = CNN()
        self.model = network.to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)


def train(client, epoch, logging=True):
    """
    Run one epoch of local training for ``client``, then measure accuracy on its own data.

    :param client: object providing ``model``, ``data_loader``, ``optimizer``, ``criterion``
        and ``device`` (a Client instance)
    :param epoch: epoch number, only used in the log message
    :param logging: print a one-line summary (last batch loss, train accuracy) when True
    :return: (sum of the per-batch losses over the epoch, accuracy in percent on the
        client's data loader as computed by test())
    """
    # put model in train mode, we need gradients
    client.model.train()
    train_loss = 0
    # Stays NaN if the loader yields no batches (previously an UnboundLocalError)
    last_batch_loss = float("nan")
    for data, target in client.data_loader:
        # Use the client's device (CPU fallback) instead of assuming CUDA is present
        data, target = data.to(client.device), target.to(client.device)
        client.optimizer.zero_grad()
        output = client.model(data)
        # get the basic loss for our main task
        loss = client.criterion(output, target)
        loss.backward()
        client.optimizer.step()
        last_batch_loss = loss.item()
        train_loss += last_batch_loss
    _, train_accuracy = test(client, logging=False)
    if logging:
        print(f'Train Epoch: {epoch} Loss: {last_batch_loss:.6f}, Train accuracy: {train_accuracy}')
    return train_loss, train_accuracy


def test(client, logging=True):
    # put model in eval mode, disable dropout etc.
    client.model.eval()
    test_loss = 0
    correct = 0
    test_loader = client.data_loader
    # disable grad to perform testing quicker
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            data, target = data.to(client.device), target.to(client.device)
            output = client.model(data)
            test_loss += client.criterion(output, target).item()
            # prediction is an output with maximal probability
            pred = output.argmax(1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    if logging:
        print(f'Test set: Average loss: {test_loss:.4f}, '
              f'Test accuracy: {correct} / {len(test_loader.dataset)} '
              f'({test_accuracy:.0f}%)\n')
    return test_loss, test_accuracy


def average_client_models(clients_dicts):
    """
    Unweighted element-wise average of several state dicts.

    Every tensor is promoted to float32 before averaging, because the inputs may be
    float16 (e.g. after quantize_float16). The input dicts are not modified.

    :param clients_dicts: non-empty list of client state dicts sharing the same keys
    :return: dict mapping each key of the first state dict to the float32 average
    :raises ValueError: if ``clients_dicts`` is empty
    """
    if not clients_dicts:
        raise ValueError("clients_dicts must contain at least one state dict")
    num_clients = len(clients_dicts)
    final_dict = {}
    for key in clients_dicts[0]:
        # Average model parameters
        stacked = torch.stack([state[key].float() for state in clients_dicts], dim=0)
        final_dict[key] = stacked.sum(0).div(num_clients)
    return final_dict


### EXPERIMENTS

In [0]:
# Create outputs directory when running in colab for the first time
# (exist_ok avoids the check-then-create race).
os.makedirs("./outputs", exist_ok=True)

In [None]:
import os
import pickle
from functools import reduce


##############################
# Experiment configuration -- edit these values to select a specific experiment.
batch_size = 25
num_clients = 5
target_accuracy = 93      # training stops once the server's test accuracy reaches this (%)
iid_split = True          # False: each client gets a disjoint range of digits
non_iid_mix = 0           # fraction of data dealt out IID when iid_split is False
percentage_val = 0        # fraction held out for validation (0 disables it)
full = True               # use full MNIST with 60000 train and 10000 test samples
# Local epochs per client, one entry per client (default: 5 for each of the 5
# clients). Edit the list to give different clients different numbers of epochs.
client_epochs = 5
epochs_per_client = [client_epochs] * num_clients
quantization = quantize_int8
##############################


# Load data
train_loaders, _, test_loader = get_data_loaders(
    batch_size,
    num_clients,
    iid_split=iid_split,
    percentage_val=percentage_val,
    full=full,
    non_iid_mix=non_iid_mix,
)

# Initialize all clients (each Client re-seeds torch before building its model)
clients = list(map(Client, train_loaders, epochs_per_client))

# Set seed for the script
torch.manual_seed(clients[0].seed)

# Round bookkeeping for the training loop
testing_accuracy = 0
num_rounds = 1

# The server reuses the Client class, with the test set as its data
central_server = Client(test_loader)

filename = f"iid_split_{iid_split}_quantization_{quantization.__name__}_num_epochs_{client_epochs}_mix_{non_iid_mix}.pkl"
# Per-round summary pickled after every round; each list gets one entry per round
experiment_state = {"num_rounds": 0}
experiment_state.update(
    (key, [])
    for key in ("test_accuracies",
                "conserved_bits_from_server",
                "conserved_bits_from_clients",
                "transferred_bits_from_server",
                "transferred_bits_from_clients",
                "original_bits_from_server",
                "original_bits_from_clients")
)

# Multiplier for int8 quantization
multiplier = 0

while testing_accuracy < target_accuracy:
    print("Communication Round {0}".format(num_rounds))

    if num_rounds > 1:
        # Load server weights onto clients. Every client receives the same server model,
        # so measure, quantize and decode it once per round instead of once per client;
        # the per-client values recorded below are identical to the previous averages.
        with torch.no_grad():
            # Calculate number of bits in full server model
            float_model_bits = get_model_bits(central_server.model.state_dict())
            # Quantize server's model
            if quantization is quantize_int8:
                quantized_model, multiplier = quantization(central_server.model.state_dict())
            else:
                quantized_model = quantization(central_server.model.state_dict())
            bits_transferred = get_model_bits(quantized_model)
            # Calculate how many bits we saved
            bits_conserved = float_model_bits - bits_transferred
            # If quantization method is int8, decode the weights
            if quantization is quantize_int8:
                quantized_model = decode_quantized_model_int8(quantized_model, multiplier)
            # Distribute quantized model on clients; load_state_dict copies the values,
            # so sharing one dict between clients is safe
            for client in clients:
                client.model.load_state_dict(quantized_model)

        # Add to our summary (per-client values)
        experiment_state["conserved_bits_from_server"].append(bits_conserved)
        experiment_state["transferred_bits_from_server"].append(bits_transferred)
        experiment_state["original_bits_from_server"].append(float_model_bits)

    # Perform E local training steps for each client
    for client_idx, client in enumerate(clients):
        print("Training client {0}".format(client_idx))
        for epoch in range(1, client.epochs + 1):
            train(client, epoch)

    with torch.no_grad():
        # Get number of bits in all clients' models before quantization
        clients_bits = sum(get_model_bits(client.model.state_dict()) for client in clients)
        # Quantize clients models
        if quantization is quantize_int8:
            quantized_clients_models = []
            multipliers = []
            for client in clients:
                client_model, multiplier = quantization(client.model.state_dict())
                quantized_clients_models.append(client_model)
                multipliers.append(multiplier)
        else:
            quantized_clients_models = [quantization(client.model.state_dict()) for client in clients]
        quantized_clients_bits = sum(get_model_bits(model) for model in quantized_clients_models)
        bits_conserved = (clients_bits - quantized_clients_bits) // num_clients
        # Add to summary (per-client averages)
        experiment_state["conserved_bits_from_clients"].append(bits_conserved)
        experiment_state["transferred_bits_from_clients"].append(quantized_clients_bits // num_clients)
        experiment_state["original_bits_from_clients"].append(clients_bits // num_clients)
        # Decode bits on central server side:
        if quantization is quantize_int8:
            quantized_clients_models = [
                decode_quantized_model_int8(model, client_multiplier)
                for model, client_multiplier in zip(quantized_clients_models, multipliers)
            ]
        # Send quantized models to server and average them
        averaged_model = average_client_models(quantized_clients_models)
        central_server.model.load_state_dict(averaged_model)
    # Defensive dtype reset: load_state_dict copies into the existing float32
    # parameters, so this is expected to be a no-op.
    central_server.model.to(torch.float32)
    # Test the aggregated model
    test_loss, testing_accuracy = test(central_server)
    experiment_state['test_accuracies'].append(testing_accuracy)
    # NOTE(review): this stores num_rounds + 1, one more than the rounds completed so
    # far; kept unchanged for compatibility with existing result files -- confirm intent.
    experiment_state['num_rounds'] = num_rounds + 1

    # Save experiment states (rewritten every round, so partial runs are preserved)
    with open(os.path.join("./outputs", filename), "wb") as f:
        pickle.dump(experiment_state, f)

    num_rounds += 1
    

# Save the final aggregated model, only when saving is enabled on the server
# (Client initializes save_model to False).
if central_server.save_model:
    torch.save(
        central_server.model.state_dict(),
        f"{central_server.model_name}.pt",
    )
