In [1]:
# utils.py
# !pip install tensorboard
import os
import random
import numpy as np
import torch

def set_state(seed=42069):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

In [2]:
import os
import json 

def read_data(train_data_dir, test_data_dir=None):
    '''parses data in given train and test data directories
    assumes:
    - the data in the input directories are .json files with 
        keys 'users' and 'user_data'
    - the set of train set users is the same as the set of test set users
    
    Return:
        clients: list of client ids
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    '''
    group_ids = []
    train_data = {}
    test_data = {}

    train_files = os.listdir(train_data_dir)
    if 'train.json' in train_files:
        with open(os.path.join(train_data_dir, 'train.json'), 'r') as f:
            cdata = json.load(f)
        if 'hierarchies' in cdata:
            group_ids.extend(cdata['hierarchies'])
        train_data.update(cdata['user_data'])
    else:
        train_files = [f for f in train_files if f.endswith('.json')]
        for f in train_files:
            file_path = os.path.join(train_data_dir,f)
            with open(file_path, 'r') as inf:
                cdata = json.load(inf)
            if 'hierarchies' in cdata:
                group_ids.extend(cdata['hierarchies'])
            train_data.update(cdata['user_data'])

    test_files = os.listdir(test_data_dir)
    if test_data_dir is not None:
        if 'test.json' in test_files:
            with open(os.path.join(test_data_dir, 'test.json'), 'r') as f:
                cdata = json.load(f)
            test_data.update(cdata['user_data'])
        else:
            test_files = [f for f in test_files if f.endswith('.json')]
            for f in test_files:
                file_path = os.path.join(test_data_dir, f)
                with open(file_path, 'r') as inf:
                    cdata = json.load(inf)
                test_data.update(cdata['user_data'])

    client_ids = list(sorted(train_data.keys()))
    return client_ids, group_ids, train_data, test_data

In [3]:
# config.py
from torchvision import datasets, transforms

dataset_to_transform = {
    'mnist': {
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307), (0.3081))
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307), (0.3081))
        ]),
    },
    'cifar10': {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    },
    'cifar100': {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                 std=[0.267, 0.256, 0.276])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                 std=[0.267, 0.256, 0.276])
        ]),
    }
}

train_ds = datasets.MNIST('data/', train=True, download=True, transform=dataset_to_transform['mnist']['train'])
test_ds = datasets.MNIST('data/', train=False, download=True, transform=dataset_to_transform['mnist']['val'])

from torch.utils.data import DataLoader
test_dl = DataLoader(test_ds, batch_size=32)

In [5]:
train_emnist = datasets.EMNIST('data/', split='balanced', train=True, download=True, transform=dataset_to_transform['mnist']['train'])
test_emnist = datasets.EMNIST('data/', split='balanced', train=False, download=True, transform=dataset_to_transform['mnist']['val'])
train_cifar = datasets.CIFAR10('data/', train=True, download=True, transform=dataset_to_transform['cifar10']['train'])
val_cifar = datasets.CIFAR10('data/', train=False, download=True, transform=dataset_to_transform['cifar10']['train'])

Downloading and extracting zip archive
Downloading http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to data/EMNIST/raw/emnist.zip


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting data/EMNIST/raw/emnist.zip to data/EMNIST/raw
Processing byclass


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Processing bymerge
Processing balanced
Processing letters
Processing digits
Processing mnist
Done!
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


In [5]:
# train_data_dir = '/workspace/leaf/FedProx/data/nist/data/train/'
# test_data_dir = '/workspace/leaf/FedProx/data/nist/data/test/'
# client_ids, group_ids, train_data, test_data = read_data(train_data_dir, test_data_dir)


In [38]:
import torch
from torch.utils.data import DataLoader


def get_client_data(train_data, 
                    test_data=None,
                    batch_size=10,
                    shuffle=True,
                    num_workers=0,
                    **kwargs):
    client_to_data = {}
    for client_id in train_data.keys():
        X_train = torch.Tensor(train_data[client_id]['x']).type(torch.float32)
        # X_train = torch.Tensor(train_data[client_id]['x']).view.(-1, num_channels, img_size, img_size).type(torch.float32)
        y_train = torch.Tensor(train_data[client_id]['y']).type(torch.int64)
        train_dataset = [(x, y) for x, y in zip(X_train, y_train)]
        client_to_data[client_id] = {}
        client_to_data[client_id]['train'] = DataLoader(
            train_dataset, 
            batch_size=batch_size,
            shuffle=shuffle,
            **kwargs
        )
        if test_data is not None:
            X_test = torch.Tensor(test_data[client_id]['x']).type(torch.float32)
            y_test = torch.Tensor(test_data[client_id]['y']).type(torch.int64)
            test_dataset = [(x, y) for x, y in zip(X_test, y_test)]
            client_to_data[client_id]['eval'] = DataLoader(
                test_dataset, 
                batch_size=batch_size,
                **kwargs
            )
        else:
            client_to_data[client_id]['eval'] = None
            
    return client_to_data


def get_clients(train_data,
                test_data=None,
                client_cls=None,
                client_params=None,
                dataloader_params=None):
    client_cls = client_cls or Client
    dataloader_params = dataloader_params or {}
    client_params = client_params or {}
#     if isinstance(client_params, dict):
#         client_params = [client_params for _ in range(len(train_data))]
    client_to_data = get_client_data(
        train_data,
        test_data, 
        **dataloader_params
    )
    clients = {}
    for client_id in client_to_data.keys():
        clients[client_id] = client_cls(client_id, 
                                     client_to_data[client_id]['train'],
                                     eval_loader=client_to_data[client_id].get('eval'),
                                     **client_params)
    return clients

In [39]:
# # sampling.py
# import numpy as np
# from sklearn.model_selection import train_test_split
# from torch.utils.data import DataLoader


# def get_iid_shards(labels,
#                    num_clients=100,
#                    client_ids=None,
#                    seed=None):
#     """Returns a homogeneous (IID) mapping of client ID's to sample indices of a dataset.
    
#     Parameters
#     ----------
#     labels : list, np.ndarray, torch.Tensor
#         List of class labels for a dataset.
#     num_clients : int (default=100)
#         Number of clients used for training.
#     client_ids : list, np.ndarray (default=None)
#         List of client ID's. If `None`, defaults to a `range(num_clients)`.
#     seed : int (default=None)
#         Random state.
       
#     Notes
#     -----
#     ```
#     num_samples_per_client = len(labels) // num_clients # 60000 / 100 = 600 for MNIST
#     ```
#     """
#     random_state = np.random.RandomState(seed)
#     client_ids = client_ids or list(range(num_clients))
    
#     # randomly shuffle sample indices to generate IID (homogeneous) clients
#     indices_shuffled = random_state.choice(range(len(labels)), len(labels), replace=False)
#     num_samples_per_client = len(indices_shuffled) // num_clients

#     # assign `num_samples_per_client` random samples to each client
#     for i, client_id in enumerate(client_ids):
#         client_indices = indices_shuffled[i * num_samples_per_client : (i + 1) * num_samples_per_client]
#         client_to_shard[client_id] = client_indices
    
#     return client_to_shard
    
    
# def get_client_shards(labels, 
#                       is_iid=True,
#                       num_clients=100,
#                       shard_size=300,
#                       client_ids=None,
#                       seed=None):
#     """Returns a mapping of client ID's to sample indices of a dataset.
    
#     Parameters
#     ----------
#     labels : list, np.ndarray, torch.Tensor
#         List of class labels for a dataset.
#     is_iid : bool
#         Boolean whether to partition devices into homogenous (IID) or
#         heterogeneous (non-IID) samples.
#     num_clients : int (default=100)
#         Number of clients used for training.
#     client_ids : list, np.ndarray (default=None)
#         List of client ID's. If `None`, defaults to a `range(num_clients)`.
#     shard_size : int (default=300)
#         Size of each shard to split labels by.
#     seed : int (default=None)
#         Random state.
       
#     Notes
#     -----
#     ```
#     num_shards = len(labels) // shard_size # 60000 / 300 = 200 for MNIST
#     num_shards_per_client = num_shards // num_clients # 200 / 100 = 2 for MNIST
#     ```
#     """

#     random_state = np.random.RandomState(seed)
#     client_ids = client_ids or list(range(num_clients))
#     client_to_shard = {c_id: [] for c_id in client_ids}
#     if is_iid:
#         # randomly shuffle sample indices if IID (homogoneous)
#         sample_indices = random_state.choice(range(len(labels)), len(labels), replace=False)
#     else:
#         # sort sample indices by by label if non-IID (heterogeneous)
#         sample_indices = np.argsort(labels).tolist()

#     num_shards = len(labels) // shard_size
#     num_shards_per_client = num_shards // num_clients # how many shards fit in each client
#     shard_indices = set(range(num_shards))

#     for i, client_id in enumerate(client_ids):
#         client_shard_indices = random_state.choice(list(shard_indices), 
#                                                    num_shards_per_client, 
#                                                    replace=False)
#         for shard_idx in client_shard_indices:
#             client_to_shard[client_id].extend(
#                 sample_indices[shard_idx*shard_size : (shard_idx+1)*shard_size]
#             )
#             shard_indices.remove(shard_idx)
            
#     return client_to_shard


# def get_client_data(dataset, 
#                     num_clients,
#                     client_ids=None,
#                     is_iid=True,
#                     shard_size=300,
#                     batch_size=32,
#                     shuffle=True,
#                     seed=None,
#                     dataloader_params=None):
#     """Returns a mapping of client ID's to their corresponding train & validation dataloaders.
    
#     Parameters
#     ----------
#     is_iid : bool (default=True)
#         Boolean whether data is IID (Homogeneous) or non-IID (Heterogeneous)
#     **kwargs
#         Additional parameters used when instantiating each clients dataloader 
        
#     """
#     dataloader_params = dataloader_params or {}
#     client_to_data = {}
#     labels = dataset.targets
#     client_to_shard = get_client_shards(
#         labels,
#         is_iid=is_iid,
#         num_clients=num_clients,
#         client_ids=client_ids,
#         shard_size=shard_size,
#         seed=seed
#     )
#     # iterate through each client and create train/val dataloaders using the shard indices
#     for k in client_to_shard.keys():
#         client_indices = client_to_shard[k]
#         client_dataset = SplitDataset(dataset, client_indices)
#         # reseed workers for reproducibility
# #         g = torch.Generator()
# #         g.manual_seed(0)
#         client_to_data[k] = DataLoader(
#             client_dataset, 
#             batch_size=batch_size,
#             shuffle=shuffle, 
# #             worker_init_fn=seed_worker,
# #             generator=g,
#             **dataloader_params
#         )
        
#     return client_to_data


# def get_clients(dataset,
#                 num_clients=100,
#                 client_cls=None,
#                 client_ids=None,
#                 is_iid=True,
#                 shard_size=300,
#                 batch_size=32,
#                 client_params=None,
#                 dataloader_params=None):
#     client_cls = client_cls or Client
#     client_params = client_params or {}
# #     if client_params is None:
# #         client_params = {}
# #     elif isinstance(client_params, dict):
# #         client_params = [client_params for _ in range(num_clients)]
#     client_to_data = get_client_data(
#         dataset,
#         num_clients=num_clients,
#         client_ids=client_ids,
#         is_iid=is_iid,
#         shard_size=shard_size,
#         batch_size=batch_size,
#         dataloader_params=dataloader_params
#     )
#     clients = {
#         k: client_cls(k, dl, **client_params)
#         for k, dl
#         in client_to_data.items()
#     }
#     return clients

In [55]:
# client.py


class Client:
    """Base client.
    
    Parameters
    ----------
    client_id : str
        Id of the client.
    dataloader : DataLoader
        Local dataset used for training on the client.
    """
    def __init__(self, client_id, train_loader, eval_loader=None, device='cpu'):
        self.client_id = client_id
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.device = device
        self._model = None
        self._device = None
        self._optimizer = None
        self._scheduler = None
        
        self._local_steps = 0
        
    @property
    def model(self):
        return self._model
    
    @model.setter
    def model(self, model):
        self._model = model

    @property
    def optimizer(self):
        return self._optimizer
    
    @optimizer.setter
    def optimizer(self, optimizer):
        self._optimizer = optimizer
    
    @property
    def scheduler(self):
        return self._scheduler
    
    @scheduler.setter
    def scheduler(self, scheduler):
        self._scheduler = scheduler
    
    @property
    def local_steps(self):
        return self._local_steps
    
    @local_steps.setter
    def local_steps(self, v):
        self._local_steps = v
        
    def __len__(self):
        return len(self.train_loader.dataset)
    
    def _update(self, criterion, num_epochs=1):
        """Algorithm 1 (ClientUpdate).
        
        Parameters
        ----------
        num_epochs (E) : int
            Number of epochs.
        criterion : 
        """
        self.model.train()
        self.model.to(self.device)
        self.local_steps = 0
        
        total_loss = np.zeros(num_epochs, dtype=np.float32)
        total_correct = np.zeros(num_epochs, dtype=np.float32)
        for i in range(num_epochs):
            for x, y in self.train_loader:
                x = x.to(self.device)
                y = y.to(self.device)
                
                self.optimizer.zero_grad()
                
                logits = self.model(x)
                loss = criterion(logits, y)
                loss.backward()
                self.optimizer.step()
                
                total_loss[i] += loss.item()
                total_correct[i] += (logits.argmax(-1) == y).sum().item()
                self.local_steps += 1
                
            # set this to a function we can call that way inherritance is easier
            if self.scheduler is not None:
                self.scheduler.step()
        
        metrics = {
            'loss': total_loss / len(self.train_loader),
            'accuracy': total_correct / len(self)
        }
#         if self.eval_loader is not None:
#             val_metrics = self.validate(criterion)
#             for k, v in val_metrics:
#                 metrics[f'val_{k}'] = v
                
        # move model back to cpu
        self.model.to('cpu')
        return metrics
    
    def update(self, criterion, num_epochs=1):
        return self._update(criterion, num_epochs)

    def validate(self, criterion):
        self.model.eval()
        loss = 0 
        correct = 0
        with torch.no_grad():
            for x, y in self.eval_loader:
                x = x.to(self.device)
                y = y.to(self.device)
                logits = self.model(x)
                correct += (logits.argmax(-1) == y).sum().item()
                loss += criterion(logits, y).item()
                
        metrics = {
            'loss': loss / len(self.eval_loader),
            'accuracy': correct / len(self.eval_loader.dataset)
        }
        return metrics
    
    def get_gradients(self, criterion):
        return get_gradients(self.model, self.train_loader, criterion, device=self.device)
    
    
class SCAFFOLDClient:
    
    def __init__(self, client_id, train_loader, eval_loader=None, option='II', device='cpu'):
        super().__init__(client_id, train_loader, eval_loader, eval_loader, device)
        self.option = option
        self.control = None
        self.control_new = None
        self.control_delta = None
        self._control_server = None
        
    @property
    def control_server(self):
        return self._control_server
    
    @control_server.setter
    def control_server(self, control_server):
        self._control_server = control_server
        
    def set_control_variates(self):
        # set control variates if first update
        if self.control is None:
            self.control = [torch.zeros_like(p.data) for p in self.model]
        if self.control_new is None:
            self.control_new = [torch.zeros_like(p.data) for p in self.model]
        if self.contrl_delta is None:
            self.control_delta = [torch.zeros_like(p.data) for p in self.model]
            
    def update(self, criterion, num_epochs=1):
        self.set_control_variates()
        model_server = deepcopy(self.model)
        results = self._update(criterion, num_epochs=num_epochs)
        
        # (4) updates to the local control variate
        if self.option == 'I':
            # gradients of global model w.r.t local data
            grads = get_gradients(model_server, self.dataset, criterion, device=self.device)
            for d_p, ci_new in zip(grads, self.control_new):
                ci_new.data = d_p.data
        elif self.option == 'II':
            grads = [torch.zeros_like(p.data) for p in self.model.parameters()]
            for p_server, p_client, d_p in zip(model_server.parameters(), zip(self.model.parameters()), grads):
                d_p.data = p_client.data.detach() - p_server.data.detach()
                
            lr = self.optimizer.param_groups[0]['lr']
            for ci, ci_new, c, d_p in zip(self.control, self.control_new, control_server, grads):
                ci_new.data = ci - c + 1 / (self.local_steps * lr) * d_p.data
#                 ci_delta.data = - c + 1 / (self.local_steps * lr) * d_p.data
        
        # store the control correction used in (5) and update the local control variate
        for ci, ci_new, ci_delta in zip(self.control, self.control_new, self.control_delta):
            ci_delta.data = ci_new.data - ci.data
            ci.data = ci_new.data
            
        return results
        

In [41]:
def get_gradients(model, dataset, criterion, device='cpu'):
    """Returns a list of gradients of `model` w.r.t. `dataset`"""
    model.eval()
    model.to(device)
    # clear gradients
    for p in model.parameters():
        if p.grad is not None and p.requires_grad:
            p.grad.zero_()
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = criterion(logits, y)

        # accumulate the average gradient of each batch
        loss.backward()

    # normalize the accumulated gradient across batches
    grads = []
    for p in model.parameters():
        # what to do when the model has layers that don't require gradients?
        grads.append(p.grad / len(dataloader))

    return grads

In [42]:
# dataset.py
from torch.utils.data import Dataset, DataLoader


class SplitDataset(Dataset):
    """Dataset for a client partitioned by a list of indices."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)
        
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

In [43]:
# model.py
import torch.nn as nn


class CNN(nn.Module):
    """CNN described in "Communication-Efficient Learning of Deep Networks
    from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf).
    
    Parameters
    ----------
    in_features : int (default=1)
        Number of channels in the input image.
    num_classes : int (default=10)
        Number of class labels.
    """
    def __init__(self, in_features=1, num_classes=10):
        super(CNN, self).__init__()
        self._in_features = in_features
        self._num_classes = num_classes
        
        self.conv1 = nn.Conv2d(self._in_features, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, self._num_classes)
        
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = self.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    
class MLP(nn.Module):
    """CNN described in "Communication-Efficient Learning of Deep Networks
    from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf).
    
    Parameters
    ----------
    in_features : int (default=784)
        Number input features.
    hidden_dim : int (cdefault=200)
        Number of hidden units.
    num_classes : int (default=10)
        Number of class labels.
    """
    def __init__(self, in_features=784, hidden_dim=200, num_classes=10):
        super(MLP, self).__init__()
        self._in_features = in_features
        self._hidden_dim = hidden_dim
        self._num_classes = num_classes
        
        self.fc1 = nn.Linear(self._in_features, self._hidden_dim)
        self.fc2 = nn.Linear(self._hidden_dim, self._num_classes)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    

class LogisticRegression(nn.Module):
    """Multinomial Logistic Regression"""
    def __init__(self, in_features, num_classes):
        super(LogisticRegression, self).__init__()
        self._in_features = in_features
        self._num_classes = num_classes
        self.fc = nn.Linear(self._in_features, self._num_classes)
        
    def forward(self, x):
        x = self.fc(x)
        return x
    
    
MODEL_MAP = {
    'cnn': CNN,
    'mlp': MLP,
    'lr': LogisticRegression
}

In [44]:
# servers (aggregates parameters from local solvers)
    # base.py
    # fedavg.py
    # fedprox.py
    # feddane.py
    # fednova.py
    # fedopt.py
    # scaffold.py
# optimizers (local solvers)
    # fedprox.py
    # feddane.py
    # fedopt.py
    # scaffold.py
    # fednova.py
# utils
    # client.py
    # sampling.py
# models
    # mnist
        # cnn.py
        # mlp.py

In [45]:
# create a server for each algo
# create a client for each algo
# create an optimizer for each algo

In [46]:
# optimizers/fedprox.py
from torch.optim.optimizer import Optimizer


class FedProxSolver(Optimizer):
    """Implements FedProx local solver.
    
    This adds a proximal term to any clients optimizer.
    
    This wrapper allows us to pass in any torch.optim.Optimizer for
    a given client, not limited to SGD as originally proposed.
    
    Args:
        optimizer (torch.optim.Optimizer): local optimizer.
        mu (float): proximal term weight (default: 0)

    __ https://arxiv.org/pdf/1812.06127.pdf
        
    Example:
        >>> # train a model locally for a client
        >>> client_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> client_optimizer = FedProxLocal(client_optimizer, mu=0.1)
        >>> client_optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> client_optimizer.step()
    
    """
    
    def __init__(self, 
                 optimizer,
                 mu=0):
        if mu < 0.0:
            raise ValueError(f'Invalid mu value: {mu}')
        self.optimizer = optimizer
        self.mu = mu
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
    
    def _update(self, group):
        """Applies a proximal update to a parameter group."""
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            p.data.add_(state['proximal'], alpha=-group['lr'])
        
    def step(self, closure=None):
        """Performs a single optimization step.
        
        Parameters
        ----------
        closure : bool
            A closure that reevaluates the model and returns the loss.
        """
        # set the initial (global) weights and proximal term before we update the client optimizer
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'initial_weights' not in state:
                    state['initial_weights'] = torch.clone(p.data).detach()
                state['proximal'] = self.mu * (p.data - state['initial_weights'])

        loss = self.optimizer.step(closure=closure)
        for group in self.param_groups:
            self._update(group)
        return loss
    
    
class FedDaneSolver(Optimizer):
    """Implements FedDane local solver.
    
    Args:
        optimizer (torch.optim.Optimizer): local optimizer.
        mu (float): proximal term weight (default: 0)
    
        
    Example:
        >>> # train a model locally for a client
        >>> client_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> client_optimizer = FedDaneLocal(client_optimizer, average_gradients, mu=0.1)
        >>> client_optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> client_optimizer.step()
    
    __ https://arxiv.org/pdf/2001.01920.pdf
    """
    
    def __init__(self, 
                 optimizer,
                 average_gradients,
                 mu=0):
        if mu < 0.0:
            raise ValueError(f'Invalid mu value: {mu}')
        self.optimizer = optimizer
        self.average_gradients = average_gradients
        self.mu = mu
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
    
    def _update(self, group):
        """Applies a proximal update to a parameter group."""
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            p.data.add_(state['proximal'] + state['grad_delta'], alpha=-group['lr'])
        
    def step(self, closure=None):
        """Performs a single optimization step.
        
        Parameters
        ----------
        closure : bool
            A closure that reevaluates the model and returns the loss.
        """
        # set the initial (global) weights and proximal term before we update the client optimizer
        for group in self.param_groups:
            for i, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'initial_weights' not in state:
                    state['initial_weights'] = torch.clone(p.data).detach()
                state['proximal'] = self.mu * (p.data - state['initial_weights'])
                if 'average_gradient' not in state:
                    state['average_gradient'] = torch.clone(self.average_gradients[i]).detach()
                state['grad_delta'] = state['average_gradient'] - p.grad.data

        loss = self.optimizer.step(closure=closure)
        for group in self.param_groups:
            self._update(group)
        return loss
    
    
class FedNovaSolver(Optimizer):
    """Implements FedNova local solver.
    
    Args:
        optimizer (torch.optim.Optimizer): local optimizer.
        mu (float): proximal term weight (default: 0)

    __ https://arxiv.org/pdf/2007.07481.pdf
        
    Example:
        >>> # train a model locally for a client
        >>> client_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> client_optimizer = FedNovaLocal(client_optimizer, mu=0.1)
        >>> client_optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> client_optimizer.step()
    """
    
    def __init__(self, 
                 optimizer,
                 mu=0):
        if mu < 0.0:
            raise ValueError(f'Invalid mu value: {mu}')
        self.optimizer = optimizer
        self.mu = mu
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
    
    def _update(self, group):
        """Applies a proximal update to a parameter group."""
        for p in group['params']:
            if p.grad is None:
                continue
            state = self.state[p]
            p.data.add_(state['proximal'], alpha=-group['lr'])
        
    def step(self, closure=None):
        """Performs a single optimization step.
        
        Parameters
        ----------
        closure : bool
            A closure that reevaluates the model and returns the loss.
        """
        # set the initial (global) weights and proximal term before we update the client optimizer
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'initial_weights' not in state:
                    state['initial_weights'] = torch.clone(p.data).detach()
                state['proximal'] = self.mu * (p.data - state['initial_weights'])

        # update the weights and gradient with the client optimizer
        loss = self.optimizer.step(closure=closure)
        
        # update the weights by adding the (negative) proximal term
        for group in self.param_groups:
            self._update(group)
        
        # accumualte gradients after calculating loss
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                momentum = group.get('momentum', 0)
                
                state = self.state[p]
                if 'local_step' not in state:
                    state['local_step'] = 0
                state['local_step'] += 1
                
                # momentum (1 - p^t) / (1 - p)
                a = (1 - momentum ** state['local_step']) / (1 - momentum)
                # proximal (1 - lr * mu)^t
                a *= (1 - group['lr'] * self.mu) ** (state['local_step']-1)
                # record the norm factor (a) to divide the l1-norm during aggregation
                if 'norm_factor' not in state:
                    state['norm_factor'] = []
                state['norm_factor'].append(a)
                
                if 'cgrad' not in state:
                    state['cgrad'] = torch.clone(p.grad.data).detach()
                    state['cgrad'].mul_(group['lr']) # do we need the lr ?
                    state['cgrad'].mul_(a) # G * a
                else:
                    state['cgrad'].add_(p.grad.data, alpha=group['lr'])
                    state['cgrad'].mul_(a) # G * a
                    
        return loss
    
    
class SCAFFOLDSolver(Optimizer):
    """Implements SCAFFOLD local solver.
    
    Args:
        optimizer (torch.optim.Optimizer): local optimizer.

    __ https://arxiv.org/pdf/1910.06378.pdf
        
    Example:
        >>> # train a model locally for a client
        >>> client_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        >>> client_optimizer = SCAFFOLDSolver(client_optimizer, mu=0.1)
        >>> client_optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> client_optimizer.step()
    """
    
    def __init__(self, optimizer, control_global, control_local):
        if isinstance(control_global, torch.Tensor):
            raise TypeError("control_global argument given to the optimizer should be "
                            "an iterable of Tensors or lists, but got " +
                            torch.typename(control_global))
        if isinstance(control_global[0], torch.Tensor):
            control_global = [control_global]
        if isinstance(control_local, torch.Tensor):
            raise TypeError("control_local argument given to the optimizer should be "
                            "an iterable of Tensors or lists, but got " +
                            torch.typename(control_local))
        if isinstance(control_local[0], torch.Tensor):
            control_local = [control_local]
            
        self.optimizer = optimizer
        self.control_global = control_global
        self.control_local = control_local
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        
    def step(self, closure=None):
        """Performs a single optimization step.
        
        Parameters
        ----------
        closure : bool
            A closure that reevaluates the model and returns the loss.
        """
        # update the weights and gradient with the client optimizer
        loss = self.optimizer.step(closure=closure)
        
        # (3) update the weights by adding the control variate correction term
        for group, group_global, group_local in zip(self.param_groups, self.control_global, self.control_local):
            for p, c, ci in zip(group, group_global, group_local):
                if p.grad is None:
                    continue
                p.data.add_(c - ci, -group['lr'])
                    
        return loss

In [47]:

# servers/base.py
import time
from copy import deepcopy
from collections import defaultdict
from torch.utils.tensorboard.writer import SummaryWriter


class BaseFederater:
    """Base Federater.
    
    Parameters
    ----------
    model : nn.Module
    dataset : torch.utils.data.Dataset
    num_clients (K) : int (default=100)
        Number of clients to partition `dataset`.
    batch_size (B) : int, dict[str, int] (defualt=32)
        Number of samples per batch to load on each client. 
        Can be a dictionary mapping each client ID to it's corresponding batch size
        to allow for various batch sizes across clients.
    shard_size : int (default=300)
    is_iid : bool (default=False)
    drop_last : bool (default=True)
    num_workers : int (default=0)
    device : str (default='cpu')
    """
    def __init__(self, 
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 server_optimizer=None,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 server_scheduler=None,
                 seed=None,
                 writer=None):
        self.model = model
        self.clients = clients
        self.client_optimizer_cls = client_optimizer_cls
        self.client_optimizer_params = client_optimizer_params
        self.server_optimizer = server_optimizer
        self.client_scheduler_cls = client_scheduler_cls
        self.client_scheduler_params = client_scheduler_params
        self.server_scheduler = server_scheduler
        self.writer = writer or SummaryWriter()
        
        self.client_ids = list(self.clients.keys())
        self.num_clients = len(self.clients)
        self.num_samples = sum([len(c) for c in self.clients.values()]) # n
        self.client_weights = [len(c) / self.num_samples for c in self.clients.values()]
        
        self.device = next(self.model.parameters()).device
        self._global_round = 0
        self._random_state = np.random.RandomState(seed)
        
    @property
    def global_round(self):
        return self._global_round

    @global_round.setter
    def global_round(self, global_round):
        self._global_round = global_round
    
    def aggregate(self):
        raise NotImplementedError
        
    def update(self, 
               client_ids, 
               criterion,
               num_epochs, 
               straggler_rate=0):
        """Performs a full communication round.
        
        Parameters
        ----------
        client_ids (S_t): list, np.ndarray
            List of client ID's to train.
        criterion : nn.Module
            Loss function to optimize on each client.
        num_epochs (E): int
            Number of epochs to train on each client.
            
        Returns
        -------
        metrics_dict : dict
            Dictionary mapping each metric to the average score across `client_ids`
        """
        # send the global model parameters to each client
        self.send_model()
        
        metrics_dict = defaultdict(lambda: 0)
        for k in client_ids:
            start_time = time.time()
            # instantiate client optimizer and scheduler
            client = self.clients[k]
            client.optimizer = self.get_client_optimizer(client)
            client.scheduler = self.get_client_scheduler(client.optimizer)
            
            # for heterogeneity experiments we can train clients for varying epochs (stragglers)
            if self._random_state.random() < straggler_rate:
                client_epochs = self._random_state.choice(range(1, num_epochs+1))
            else:
                client_epochs = num_epochs
                
            # update the client weights and record the local training metrics
            client_metrics_dict = client.update(
                criterion,
                num_epochs=client_epochs,
            )
            
            # update the summary writer and record loss/acc from the client
            elapsed_time = time.time() - start_time
            self.writer.add_scalar(f'clients/{k}/elapsed_time', elapsed_time, self.global_round)
            for metric, values in client_metrics_dict.items():
                for i, value in enumerate(values):
                    self.writer.add_scalar(f'client/{k}/round_{self.global_round}/{metric}', 
                                           value,
                                           self.global_round)
                metrics_dict[metric] += values[-1] / len(client_ids)
        
        # aggregate the parameters of the local solvers
        self.aggregate()
        if self.server_scheduler is not None:
            self.server_scheduler.step()
        
        return metrics_dict
        
    def fit(self, 
            num_rounds,
            criterion, 
            num_epochs,
            val_dl=None,
            C=0.1,
            straggler_rate=0,
            eval_every_n=1):
        """Train loop."""
        start_time = time.time()
        # subset a sample of `m` clients each round
        m = max(int(np.ceil(self.num_clients * C)), 1)
        for t in range(num_rounds):
            self.global_round += 1
            
            # update a subset of clients with the local solver
            S = self._random_state.choice(self.client_ids, m, replace=False)
            train_metrics = self.update(client_ids=S, 
                                        criterion=criterion,
                                        num_epochs=num_epochs, 
                                        straggler_rate=straggler_rate)
            
            # log train summary metrics
            elapsed_time = round(time.time() - start_time)
            self.writer.add_scalar('train/elapsed_time', elapsed_time, self.global_round)
            template_str = f'round {self.global_round} - {elapsed_time}s'
            for metric, value in train_metrics.items():
                self.writer.add_scalar(f'train/{metric}', value, self.global_round)
                template_str += f' - train_{metric} : {value:0.4f}'
                
            # log validation summary metrics
            if eval_every_n is not None and t % eval_every_n == 0:
                val_metrics = self.validate(criterion)
                for metric, value in val_metrics.items():
                    self.writer.add_scalar(f'val/{metric}', value, self.global_round)
                    template_str += f' - val_{metric} : {value:0.4f}'
                
            print(template_str)
    
    def get_client_optimizer(self, client):
        """Returns a client optimizer (local solver).
        
        Parameters
        ----------
        params : iterable
            Client parameters to optimize
        optimizer_params : dict
            Client optimizer hyperparameters
        
        Returns
        -------
        torch.optim.Optimizer
        """
        optimizer_params = self.client_optimizer_params or {}
        return self.client_optimizer_cls(client.model.parameters(), **optimizer_params)
    
    def get_client_scheduler(self, optimizer):
        """Returns a LR scheduler for a client optimizer.
        
        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Client optimizer
        
        Returns
        -------
        torch.optim.lr_scheduler._LRScheduler or None
            Client LR scheduler, or None if not specified
        """
        if self.client_scheduler_cls is not None:
            scheduler_params = self.client_scheduler_params or {}
            return self.client_scheduler_cls(optimizer, **scheduler_params)
        else:
            return None
        
    def validate(self, criterion, client_ids=None):
        eval_metrics = defaultdict(lambda: 0)
        
        # send server model to each client for validation
        client_ids = client_ids or self.client_ids
        self.send_model(client_ids)
        
        # validate on each client
        for client_id in client_ids:
            client = self.clients[client_id]
            client_metrics = client.validate(criterion=criterion)
            for metric, value in client_metrics.items():
                eval_metrics[metric] += value / len(client_ids)
        return eval_metrics
    
#     def validate(self, val_dl, criterion):
# #         self.model.to(self.device)
#         self.model.eval()
#         loss = 0 
#         correct = 0
#         with torch.no_grad():
#             for x, y in val_dl:
#                 x = x.to(self.device)
#                 y = y.to(self.device)
#                 logits = self.model(x)
#                 correct += (logits.argmax(-1) == y).sum().item()
#                 loss += criterion(logits, y).item()
                
#         results = {
#             'loss': loss / len(val_dl),
#             'accuracy': correct / len(val_dl.dataset)
#         }
#         return results
    
    def send_model(self, client_ids=None):
        """Send the current state of the global model to each client."""
        if client_ids is None:
            client_ids = self.client_ids
        for client_id in client_ids:
            self.clients[client_id].model = deepcopy(self.model)
            
    def get_gradients(self, client_ids, criterion):
        self.send_model(client_ids)
        grads = []
        for k, client_id in enumerate(client_ids):
            client = self.clients[client_id]
            client_grads = client.get_gradients(criterion)
            grads.append(client_grads)
        return grads

In [48]:
class FedAvg(BaseFederater):
    """Federated Averaging (FedAvg)
    
    https://arxiv.org/pdf/1602.05629.pdf
    """
    def __init__(self, 
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls=client_optimizer_cls,
                         client_optimizer_params=client_optimizer_params,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
    
    def aggregate(self):
        global_state = {} # self.model.state_dict()
        for k, (client_id, client) in enumerate(self.clients.items()):
            local_state = client.model.state_dict()
            for layer_name, param in local_state.items():
                if k == 0:
                    global_state[layer_name] = self.client_weights[k] * param
                else:
                    global_state[layer_name] += self.client_weights[k] * param

        self.model.load_state_dict(global_state)
                    

class FedProx(BaseFederater):
    """FedProx
    
    https://arxiv.org/pdf/1812.06127.pdf
    """
    def __init__(self, 
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 server_optimizer,
                 mu=0,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 server_scheduler=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls,
                         client_optimizer_params,
                         server_optimizer=server_optimizer,
                         server_scheduler=server_scheduler,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
        self.mu = mu
            
    def get_client_optimizer(self, client):
        optimizer_params = self.client_optimizer_params or {}
        client_optimizer = self.client_optimizer_cls(client.model.parameters(), **optimizer_params)
        return FedProxSolver(client_optimizer, mu=self.mu)
    
    def aggregate(self):
        self.server_optimizer.zero_grad()
        for k, client in enumerate(self.clients.values()):
            for p_server, p_client in zip(self.model.parameters(), client.model.parameters()):
                if p_server.requires_grad:
                    if k == 0:
                        p_server.grad = self.client_weights[k] * (p_server.data - p_client.data)
                    else:
                        p_server.grad.data.add_(p_server.data - p_client.data, alpha=self.client_weights[k])
        
        self.server_optimizer.step()

    
class FedOpt(BaseFederater):
    
    def __init__(self,
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 server_optimizer,
                 server_scheduler=None,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls,
                         client_optimizer_params,
                         server_optimizer=server_optimizer,
                         server_scheduler=server_scheduler,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
    
    def aggregate(self):
        """
        
        """
        self.server_optimizer.zero_grad()
        # iterate through each client
        for k, client in enumerate(self.clients.values()):
            for p_server, p_client in zip(self.model.parameters(), client.model.parameters()):
                if p_server.requires_grad:
                    if k == 0:
                        p_server.grad = client_weights[k] * (p_server.data - p_client.data)
                    else:
                        p_server.grad.add_(p_server.data - p_client.data, alpha=self.client_weights[k])
        
        self.server_optimizer.step()
        
        
class FedNova(BaseFederater):
    """FedNova
    
    https://arxiv.org/pdf/2007.07481.pdf
    """
    def __init__(self,
                 model,
                 clients,
                 server_optimizer,
                 client_optimizer_cls,
                 client_optimizer_params,
                 mu=0,
                 server_scheduler=None,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls=client_optimizer_cls,
                         client_optimizer_params=client_optimizer_params,
                         server_optimizer=server_optimizer,
                         server_scheduler=server_scheduler,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
        self.mu = mu
        
    def get_client_optimizer(self, client):
        optimizer_params = self.client_optimizer_params or {}
        client_optimizer = self.client_optimizer_cls(client.model.parameters(), **optimizer_params)
        client_optimizer = FedNovaSolver(client_optimizer, mu=self.mu)
        return client_optimizer
    
    def aggregate(self):
        """ """
        
        self.server_optimizer.zero_grad()
        # iterate through each client and set gradients
        for k, client in enumerate(self.clients.values()):
            # skip clients with no optimizer
            # we may want to use the weights of the local model instead
            if client.optimizer is None:
                continue
            for group_server, group_client in zip(self.server_optimizer.param_groups, 
                                                  client.optimizer.param_groups):
                for p_server, p_client in zip(group_server['params'], group_client['params']):
                    if p_server.requires_grad:
                        state = client.optimizer.state[p_client]
                        w = self.client_weights[k]
                        G_a = state['cgrad']
                        a = torch.tensor(state['norm_factor'])
                        d = G_a / a.abs().sum()
                        tau_eff = client.local_steps
                        if p_server.grad is None:
                            p_server.grad = tau_eff * w * d  # need to take lr off of G ? jk lr is necessary for client (local)
                        else:
                            p_server.grad.data.add_(d, alpha=tau_eff * w)

        self.server_optimizer.step()
        
        
class FedDane(BaseFederater):
    def __init__(self, 
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 server_optimizer,
                 mu=0,
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 server_scheduler=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls,
                         client_optimizer_params,
                         server_optimizer=server_optimizer,
                         server_scheduler=server_scheduler,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
        self.mu = mu
        self.average_gradients = None
            
    def get_client_optimizer(self, client):
        optimizer_params = self.client_optimizer_params or {}
        client_optimizer = self.client_optimizer_cls(client.model.parameters(), **optimizer_params)
        client_optimizer = FedDaneSolver(client_optimizer, 
                                         average_gradients=self.average_gradients,
                                         mu=self.mu)
        return client_optimizer
    
    def fit(self, 
            num_rounds,
            criterion, 
            num_epochs,
            val_dl=None,
            C=0.1,
            straggler_rate=0,
            eval_every_n=1):
        # subset a sample of `m` clients each round
        m = max(int(np.ceil(self.num_clients * C)), 1)
        
        for t in range(num_rounds):
            self.global_round += 1
            
            # calculate the average gradient on a subset of clients
            S_grad = self._random_state.choice(self.client_ids, m, replace=False)
            self.set_average_gradients(S_grad, criterion)
            
            # update a subset of clients with the local solver
            S = self._random_state.choice(self.client_ids, m, replace=False)
            train_metrics = self.update(client_ids=S, 
                                        criterion=criterion,
                                        num_epochs=num_epochs, 
                                        straggler_rate=straggler_rate)
            
            if eval_every_n is not None and t % eval_every_n == 0:# and val_dl is not None:
                template_str = f'round {self.global_round}'
                val_metrics = self.validate(val_dl, criterion)
                for metric, value in train_metrics.items():
                    self.writer.add_scalar(f'train/{metric}', value, self.global_round)
                    template_str += f' - train_{metric} : {value:0.4f}'
                for metric, value in val_metrics.items():
                    self.writer.add_scalar(f'val/{metric}', value, self.global_round)
                    template_str += f' - val_{metric} : {value:0.4f}'
                
                print(template_str)
    
    def aggregate(self):
        self.server_optimizer.zero_grad()
        for k, client in enumerate(self.clients.values()):
            for p_server, p_client in zip(self.model.parameters(), client.model.parameters()):
                if p_server.requires_grad:
                    if k == 0:
                        p_server.grad = self.client_weights[k] * (p_server.data - p_client.data)
                    else:
                        p_server.grad.add_(p_server.data - p_client.data, alpha=self.client_weights[k])
        
        self.server_optimizer.step()
        
    def set_average_gradients(self, client_ids, criterion):
        grads = self.get_gradients(client_ids, criterion)
        average_gradients = [0] * len(grads[0])
        for client_grads in grads:
            for i, g in enumerate(client_grads):
                average_gradients[i] += g
        self.average_gradients = [g / len(grads) for g in average_gradients]
        
        
        
# scaffold.py
class SCAFFOLD(BaseFederater):
    """SCAFFOLD
    
    https://arxiv.org/pdf/1910.06378.pdf
    """
    def __init__(self, 
                 model,
                 clients,
                 client_optimizer_cls,
                 client_optimizer_params,
                 server_optimizer,
                 option='II',
                 client_scheduler_cls=None,
                 client_scheduler_params=None,
                 server_scheduler=None,
                 seed=None,
                 writer=None):
        super().__init__(model,
                         clients,
                         client_optimizer_cls,
                         client_optimizer_params,
                         server_optimizer=server_optimizer,
                         server_scheduler=server_scheduler,
                         client_scheduler_cls=client_scheduler_cls,
                         client_scheduler_params=client_scheduler_params,
                         seed=seed,
                         writer=writer)
        self.option = option
        self.control_server = [torch.zeros_like(p.data) for p in model.parameters()]

    def send_model(self, client_ids=None):
        """Send the current state of the global model to each client."""
        if client_ids is None:
            client_ids = self.client_ids
        for client_id in client_ids:
            self.clients[client_id].model = deepcopy(self.model)
            self.clients[client_id].control_server = self.control_server
            
    def get_client_optimizer(self, client):
        optimizer_params = self.client_optimizer_params or {}
        client_optimizer = self.client_optimizer_cls(client.model.parameters(), **optimizer_params)
        client_optimizer = SCAFFOLDSolver(client_optimizer, 
                                          control_global=self.control_server, 
                                          control_local=client.control)
        return client_optimizer
    
    def aggregate(self):
        # (5) update global parameters
        self.server_optimizer.zero_grad()
        for k, client in enumerate(self.clients.values()):
            for p_server, p_client, c_client in zip(self.model.parameters(), client.model.parameters(), client.control):
                if p_server.requires_grad:
                    if k == 0:
                        p_server.grad = self.client_weights[k] * (p_server.data - p_client.data)
                    else:
                        p_server.grad.data.add_(p_server.data - p_client.data, alpha=self.client_weights[k])
        
        self.server_optimizer.step()
        
        # (5) update global control variate
        for client in self.clients.values():
            for c, ci_delta in zip(self.control_server, client.control_delta):
                c.data.add_(ci_delta, 1/self.num_clients)

In [56]:
class Config:
    
    fed_avg = {
        'mnist': {
            'clients': {
                'num_clients': 100,
                'shard_size': 300,
                'batch_size': 10,
                'is_iid': False,
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.1,
            },
            'federater': {
                'C': 0.1
            },
            'fit': {
                'num_rounds': 240,
                'num_epochs': 20
            }
        }
    }
    fed_prox = {
        'mnist': {
            'clients': {
                'num_clients': 100,
                'shard_size': 300,
                'batch_size': 10,
                'is_iid': False,
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.03,
            },
            'federater': {
                'C': 0.1,
                'mu': 0.1
            },
            'fit': {
                'num_rounds': 1,
                'num_epochs': 1,
#                 'straggler_rate': 0.5,
            }
        }
    }

# config = Config()
config = {
    'fedavg': {
        'mnist': {
            'clients': {
                'num_clients': 100,
                'shard_size': 300,
                'batch_size': 10,
                'is_iid': True,
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.1,
            },
            'federater': {
                'C': 0.1
            },
            'fit': {
                'num_rounds': 50,
                'num_epochs': 20
            }
        }
    },
    'fedprox': {
        'femnist': {
            'input': {
                'train': '/workspace/leaf/FedProx/data/nist/data/train/train.json',
                'test': '/workspace/leaf/FedProx/data/nist/data/test/test.json',
            },
            'data': {
                'batch_size': 10,
                'num_workers': 0,
            },
            'client': {
                'device': 'cpu'
            },
            'model': {
                'name': 'lr',
                'params': {
                    'in_features': 784,
                    'num_classes': 10
                }
            },
            'server_optimizer': 'SGD',
            'server_optimizer_params': {
                'lr': 1
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.003,
            },
            'federater': {
                'mu': 0.1
            },
            'fit': {
                'num_rounds': 100,
                'num_epochs': 20,
                'C': 0.1,
#                 'straggler_rate': 0.5,
            }
        },
        'mnist': {
            'clients': {
                'num_clients': 100,
                'shard_size': 300,
                'batch_size': 10,
                'is_iid': False,
            },
            'server_optimizer': 'SGD',
            'server_optimizer_params': {
                'lr': 1
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.03,
            },
            'federater': {
                'C': 0.1,
                'mu': 0.1
            },
            'fit': {
                'num_rounds': 100,
                'num_epochs': 20,
#                 'straggler_rate': 0.5,
            }
        }
    },
    'fedadam': {
        'mnist': {
            'clients': {
                'num_clients': 100,
                'shard_size': 300,
                'batch_size': 10,
                'is_iid': False,
            },
            'client_optimizer': 'SGD',
            'client_optimizer_params': {
                'lr': 0.01,
            },
            'server_optimizer': 'adam',
            'server_optimizer_params': {
                'lr': 1,
            },
            'federater': {
                'C': 0.1
            },
            'fit': {
                'num_rounds': 240,
                'num_epochs': 20
            }
        }
    },
}

In [57]:
# mnist = datasets.fetch_openml('mnist_784', data_home='tmp')
# from tqdm import trange
# import numpy as np
# import random
# import json
# import os

# mu = np.mean(mnist.data.astype(np.float32), 0)
# sigma = np.std(mnist.data.astype(np.float32), 0)
# mnist.data = (mnist.data.astype(np.float32) - mu)/(sigma+0.001)
# mnist.target = mnist.target.astype(np.int32)
# mnist_data = []
# for i in trange(10):
#     idx = mnist.target==i
#     mnist_data.append(mnist.data[idx])

# print([len(v) for v in mnist_data])

# ###### CREATE USER DATA SPLIT #######
# # Assign 10 samples to each user
# X = [[] for _ in range(1000)]
# y = [[] for _ in range(1000)]
# idx = np.zeros(10, dtype=np.int64)
# for user in range(1000):
#     for j in range(2):
#         l = (user+j)%10
#         X[user] += mnist_data[l][idx[l]:idx[l]+5].tolist()
#         y[user] += (l*np.ones(5)).tolist()
#         idx[l] += 5
# print(idx)

In [58]:
import datetime


def create_experiment_name(method, dataset, params=None):
    params = params or {}
    experiment_name = f'{method}_{dataset}'
    if params.get('clients'):
        if params['clients'].get('is_iid'):
            if params['clients']['is_iid']:
                experiment_name += "_iid"
            else:
                experiment_name += "_noniid"
        if params['clients'].get('num_clients'):
            experiment_name += f"_K={params['clients']['num_clients']}"
        if params['clients'].get('batch_size'):
            experiment_name += f"_B={params['clients']['batch_size']}"
    if params.get('fit'):
        if params['fit'].get('num_rounds'):
            experiment_name += f"_T={params['fit']['num_rounds']}"
        if params['fit'].get('num_epochs'):
            experiment_name += f"_E={params['fit']['num_epochs']}"
    if params.get('server_optimizer'):
            experiment_name += f"_SOPT={params['server_optimizer']}"
    if params.get('client_optimizer'):
        experiment_name += f"_COPT={params['client_optimizer']}"
    
    experiment_name += f"_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}"
    return experiment_name

In [59]:
def load_data(train_path, test_path=None):
    with open(train_path, 'r') as f:
        cdata = json.load(f)
    train_data = cdata['user_data']
    if test_path is not None:
        with open(test_path, 'r') as f:
            cdata = json.load(f)
        test_data = cdata['user_data']
    else:
        test_data = None
    return train_data, test_data

In [63]:
# Fedprox baseline
method = 'fedprox'
dataset = 'femnist'
seed = 42069
device = 'cpu'
num_workers = 0

experiment_config = config[method][dataset]

# load train, test data
train_data, test_data = load_data(experiment_config['input']['train'], experiment_config['input'].get('test'))

# setup clients
client_params = config[method][dataset]['client']
data_params = experiment_config.get('data', {})
data_params['num_workers'] = num_workers
clients = get_clients(
    train_data,
    test_data=test_data,
    dataloader_params=data_params,
    client_params=client_params
)
client_ids = list(clients.keys())

# model
model_cls = MODEL_MAP[experiment_config['model']['name']]
model_params = experiment_config['model'].get('params', {})
model = model_cls(**model_params)
# model = MODEL_MAP[experiment_config['model']['name']](**experiment_config['model'].get('params', {}))

# local and global optimizers
client_optimizer_cls = getattr(torch.optim, experiment_config['client_optimizer'])
client_optimizer_params = experiment_config['client_optimizer_params']
server_optimizer = getattr(torch.optim, experiment_config['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **experiment_config['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = experiment_config['federater']
fed_params['seed'] = seed
num_rounds = experiment_config['fit']['num_rounds']
num_epochs = experiment_config['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)

In [64]:
federater.fit(criterion=criterion, **experiment_config['fit'])

round 1 - 2s - train_loss : 0.9055 - train_accuracy : 0.6258 - val_loss : 2.3749 - val_accuracy : 0.0942
round 2 - 4s - train_loss : 0.9401 - train_accuracy : 0.6233 - val_loss : 2.3417 - val_accuracy : 0.0942
round 3 - 6s - train_loss : 0.9697 - train_accuracy : 0.5578 - val_loss : 2.3286 - val_accuracy : 0.0942
round 4 - 7s - train_loss : 0.9068 - train_accuracy : 0.6143 - val_loss : 2.3082 - val_accuracy : 0.0942
round 5 - 9s - train_loss : 0.9643 - train_accuracy : 0.5669 - val_loss : 2.2878 - val_accuracy : 0.0942
round 6 - 11s - train_loss : 0.9021 - train_accuracy : 0.6570 - val_loss : 2.2488 - val_accuracy : 0.1047
round 7 - 13s - train_loss : 0.8931 - train_accuracy : 0.6553 - val_loss : 2.2412 - val_accuracy : 0.1011
round 8 - 14s - train_loss : 0.9007 - train_accuracy : 0.6290 - val_loss : 2.2362 - val_accuracy : 0.1360
round 9 - 15s - train_loss : 0.9278 - train_accuracy : 0.6106 - val_loss : 2.2199 - val_accuracy : 0.1467
round 10 - 17s - train_loss : 0.8751 - train_accura

round 78 - 125s - train_loss : 0.6233 - train_accuracy : 0.8489 - val_loss : 1.7860 - val_accuracy : 0.3670
round 79 - 126s - train_loss : 0.6170 - train_accuracy : 0.8196 - val_loss : 1.7646 - val_accuracy : 0.4316
round 80 - 127s - train_loss : 0.5955 - train_accuracy : 0.8890 - val_loss : 1.7808 - val_accuracy : 0.3480
round 81 - 129s - train_loss : 0.6225 - train_accuracy : 0.8641 - val_loss : 1.7519 - val_accuracy : 0.3818
round 82 - 131s - train_loss : 0.5551 - train_accuracy : 0.8749 - val_loss : 1.7238 - val_accuracy : 0.5116
round 83 - 133s - train_loss : 0.6195 - train_accuracy : 0.8415 - val_loss : 1.7247 - val_accuracy : 0.4604
round 84 - 134s - train_loss : 0.5894 - train_accuracy : 0.8436 - val_loss : 1.7186 - val_accuracy : 0.4660
round 85 - 136s - train_loss : 0.5940 - train_accuracy : 0.8680 - val_loss : 1.7182 - val_accuracy : 0.4800
round 86 - 137s - train_loss : 0.6244 - train_accuracy : 0.8406 - val_loss : 1.7088 - val_accuracy : 0.5405
round 87 - 139s - train_loss

In [65]:
federater.fit(criterion=criterion, **experiment_config['fit'])

round 101 - 1s - train_loss : 0.5449 - train_accuracy : 0.8582 - val_loss : 1.6673 - val_accuracy : 0.4689
round 102 - 3s - train_loss : 0.5924 - train_accuracy : 0.8196 - val_loss : 1.6721 - val_accuracy : 0.4950
round 103 - 5s - train_loss : 0.5501 - train_accuracy : 0.8806 - val_loss : 1.6710 - val_accuracy : 0.4557
round 104 - 6s - train_loss : 0.6096 - train_accuracy : 0.8425 - val_loss : 1.6386 - val_accuracy : 0.5502
round 105 - 8s - train_loss : 0.6015 - train_accuracy : 0.8306 - val_loss : 1.6216 - val_accuracy : 0.6038
round 106 - 10s - train_loss : 0.5222 - train_accuracy : 0.8915 - val_loss : 1.6298 - val_accuracy : 0.5561
round 107 - 12s - train_loss : 0.5766 - train_accuracy : 0.8553 - val_loss : 1.6109 - val_accuracy : 0.6095
round 108 - 13s - train_loss : 0.5466 - train_accuracy : 0.8688 - val_loss : 1.6053 - val_accuracy : 0.6026
round 109 - 15s - train_loss : 0.5252 - train_accuracy : 0.8836 - val_loss : 1.6141 - val_accuracy : 0.5506
round 110 - 16s - train_loss : 0.

round 177 - 124s - train_loss : 0.4672 - train_accuracy : 0.8706 - val_loss : 1.4157 - val_accuracy : 0.6344
round 178 - 125s - train_loss : 0.4340 - train_accuracy : 0.9003 - val_loss : 1.4183 - val_accuracy : 0.6065
round 179 - 127s - train_loss : 0.4423 - train_accuracy : 0.9041 - val_loss : 1.4109 - val_accuracy : 0.6416
round 180 - 128s - train_loss : 0.4482 - train_accuracy : 0.8942 - val_loss : 1.4088 - val_accuracy : 0.6482
round 181 - 130s - train_loss : 0.4544 - train_accuracy : 0.9050 - val_loss : 1.4189 - val_accuracy : 0.6317
round 182 - 131s - train_loss : 0.4611 - train_accuracy : 0.8915 - val_loss : 1.4138 - val_accuracy : 0.6486
round 183 - 133s - train_loss : 0.4256 - train_accuracy : 0.9108 - val_loss : 1.4036 - val_accuracy : 0.6434
round 184 - 135s - train_loss : 0.4633 - train_accuracy : 0.8869 - val_loss : 1.4102 - val_accuracy : 0.6365
round 185 - 136s - train_loss : 0.4486 - train_accuracy : 0.9046 - val_loss : 1.4067 - val_accuracy : 0.6600
round 186 - 138s - 

In [None]:
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)

In [185]:
model = MODEL_MAP[experiment_config['model']['name']](**experiment_config['model'].get('params', {}))
client_optimizer_cls = getattr(torch.optim, experiment_config['client_optimizer'])
client_optimizer_params = experiment_config['client_optimizer_params']
server_optimizer = getattr(torch.optim, experiment_config['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **experiment_config['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = experiment_config['federater']
fed_params['seed'] = seed
num_rounds = experiment_config['fit']['num_rounds']
num_epochs = experiment_config['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)c

__main__.LogisticRegression

In [None]:
model = CNN()
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
server_optimizer = getattr(torch.optim, config[method][dataset]['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **config[method][dataset]['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
num_rounds = config[method][dataset]['fit']['num_rounds']
num_epochs = config[method][dataset]['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)c

In [132]:
# Fedprox baseline
method = 'fedprox'
dataset = 'femnist'
seed = 42069
device = 'cpu'
num_workers = 0

# experiment_name = create_experiment_name(method, dataset, config[method][dataset])
experiment_name = 'tmp_prox'
print(f'Experiment : {experiment_name}')

# writer = SummaryWriter(os.path.join('logs', experiment_name))
writer = SummaryWriter()

set_state(seed)
0/0
client_params = config[method][dataset]['clients']
clients = get_clients(
    train_ds, 
    num_workers=num_workers,
    seed=seed,
    device=device,
    **client_params,
)

model = CNN()
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
server_optimizer = getattr(torch.optim, config[method][dataset]['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **config[method][dataset]['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
num_rounds = config[method][dataset]['fit']['num_rounds']
num_epochs = config[method][dataset]['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)

Experiment : tmp_prox


ZeroDivisionError: division by zero

In [132]:
# Fed avg baseline
method = 'fedavg'
dataset = 'mnist'
seed = 42069
device = 'cuda:0'
num_workers = 0

experiment_name = create_experiment_name(method, dataset, config[method][dataset])
experiment_name = 'tmp'
print(f'Experiment : {experiment_name}')

writer = SummaryWriter(os.path.join('logs', experiment_name))

set_state(seed)
client_params = config[method][dataset]['clients']
clients = get_clients(
    train_ds, 
    num_workers=num_workers,
    seed=seed,
    device=device,
    **client_params,
)

model = CNN()
# client optimizer
# client_optimizer_cls = torch.optim.SGD
# client_optimizer_params = {
#     'lr': 0.1,
# }
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
# fed_params = {
#     'seed': seed,
#     'C': 0.1,
# }
num_rounds = 1#config[method][dataset]['fit']['num_rounds']
num_epochs = 1#config[method][dataset]['fit']['num_epochs']

federater = FedAvg(model,
                    clients=clients,
                    client_optimizer_cls=client_optimizer_cls,
                    client_optimizer_params=client_optimizer_params,
                    **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)

Experiment : tmp


RuntimeError: Given groups=1, weight of size [32, 784, 5, 5], expected input[10, 1, 28, 28] to have 784 channels, but got 1 channels instead

In [128]:
# Fedprox baseline
method = 'fedprox'
dataset = 'mnist'
seed = 42069
device = 'cuda:0'
num_workers = 0

experiment_name = create_experiment_name(method, dataset, config[method][dataset])
experiment_name = 'tmp_prox'
print(f'Experiment : {experiment_name}')

writer = SummaryWriter(os.path.join('logs', experiment_name))

set_state(seed)
client_params = config[method][dataset]['clients']
clients = get_clients(
    train_ds, 
    num_workers=num_workers,
    seed=seed,
    device=device,
    **client_params,
)

model = CNN()
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
server_optimizer = getattr(torch.optim, config[method][dataset]['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **config[method][dataset]['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
num_rounds = config[method][dataset]['fit']['num_rounds']
num_epochs = config[method][dataset]['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)

In [None]:
# Fed avg baseline
method = 'fedavg'
dataset = 'mnist'
seed = 42069
device = 'cuda:0'
num_workers = 0

experiment_name = create_experiment_name(method, dataset, config[method][dataset])
experiment_name = 'tmp'
print(f'Experiment : {experiment_name}')

writer = SummaryWriter(os.path.join('logs', experiment_name))

set_state(seed)
client_params = config[method][dataset]['clients']
clients = get_clients(
    train_ds, 
    num_workers=num_workers,
    seed=seed,
    device=device,
    **client_params,
)

model = CNN()
# client optimizer
# client_optimizer_cls = torch.optim.SGD
# client_optimizer_params = {
#     'lr': 0.1,
# }
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
# fed_params = {
#     'seed': seed,
#     'C': 0.1,
# }
num_rounds = 1#config[method][dataset]['fit']['num_rounds']
num_epochs = 1#config[method][dataset]['fit']['num_epochs']

federater = FedAvg(model,
                    clients=clients,
                    client_optimizer_cls=client_optimizer_cls,
                    client_optimizer_params=client_optimizer_params,
                    **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)

In [162]:
import emnist

In [178]:
import emnist
import numpy as np
from tqdm import trange
import random
import json
import os
import argparse
from os.path import dirname

similarity = 1
num_of_users = 100
samples_num = 20
dataset = 'balanced'
images, train_labels = emnist.extract_training_samples(dataset)  # TODO: add test samples
images = np.reshape(images, (images.shape[0], -1))
images = images.astype(np.float32)
train_labels = train_labels.astype(np.int)
num_of_labels = len(set(train_labels))

emnist_data = []
for i in range(min(train_labels), num_of_labels + min(train_labels)):
    idx = train_labels == i
    emnist_data.append(images[idx])

iid_samples = int(similarity * samples_num)
X = [[] for _ in range(num_of_users)]
y = [[] for _ in range(num_of_users)]
idx = np.zeros(num_of_labels, dtype=np.int64)

# create %similarity of iid data
for user in range(num_of_users):
    labels = np.random.randint(0, num_of_labels, iid_samples)
    for label in labels:
        X[user].append(emnist_data[label][idx[label]].tolist())
        y[user] += (label * np.ones(1)).tolist()
        idx[label] += 1

print(idx)

# fill remaining data
for user in range(num_of_users):
    label = user % num_of_labels
    X[user] += emnist_data[label][idx[label]:idx[label] + samples_num - iid_samples].tolist()
    y[user] += (label * np.ones(samples_num - iid_samples)).tolist()
    idx[label] += samples_num - iid_samples

print(idx)

train_data = {'users': [], 'user_data': {}, 'num_samples': []}
test_data = {'users': [], 'user_data': {}, 'num_samples': []}

for i in trange(num_of_users, ncols=120):
    uname = 'f_{0:05d}'.format(i)

    combined = list(zip(X[i], y[i]))
    random.shuffle(combined)
    X[i][:], y[i][:] = zip(*combined)
    num_samples = len(X[i])
    train_len = int(0.9 * num_samples)
    test_len = num_samples - train_len

    train_data['users'].append(uname)
    train_data['user_data'][uname] = {'x': X[i][:train_len], 'y': y[i][:train_len]}
    train_data['num_samples'].append(train_len)
    test_data['users'].append(uname)
    test_data['user_data'][uname] = {'x': X[i][train_len:], 'y': y[i][train_len:]}
    test_data['num_samples'].append(test_len)

100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 73403.99it/s]

[57 46 35 28 40 41 44 40 40 47 35 40 39 36 33 50 40 42 41 41 45 54 51 45
 41 50 35 32 49 43 44 42 39 55 46 46 53 38 41 50 28 39 46 46 46 45 36]
[57 46 35 28 40 41 44 40 40 47 35 40 39 36 33 50 40 42 41 41 45 54 51 45
 41 50 35 32 49 43 44 42 39 55 46 46 53 38 41 50 28 39 46 46 46 45 36]





In [186]:
np.unique(train_labels).shape

(47,)

In [181]:
len(train_data['user_data']['f_00000']['y'])

18

In [None]:
# Fedprox baseline
method = 'fedopt'
dataset = 'mnist'
seed = 42069
device = 'cuda:0'
num_workers = 0

experiment_name = create_experiment_name(method, dataset, config[method][dataset])
experiment_name = 'tmp'
print(f'Experiment : {experiment_name}')

writer = SummaryWriter(os.path.join('logs', experiment_name))

set_state(seed)
client_params = config[method][dataset]['clients']
clients = get_clients(
    train_ds, 
    num_workers=num_workers,
    seed=seed,
    device=device,
    **client_params,
)

model = CNN()
client_optimizer_cls = getattr(torch.optim, config[method][dataset]['client_optimizer'])
client_optimizer_params = config[method][dataset]['client_optimizer_params']
server_optimizer = getattr(torch.optim, config[method][dataset]['server_optimizer'])
server_optimizer = server_optimizer(model.parameters(), **config[method][dataset]['server_optimizer_params'])
criterion = nn.CrossEntropyLoss()

fed_params = config[method][dataset]['federater']
fed_params['seed'] = seed
num_rounds = config[method][dataset]['fit']['num_rounds']
num_epochs = config[method][dataset]['fit']['num_epochs']

federater = FedProx(model,
                   clients=clients,
                    server_optimizer=server_optimizer,
                   client_optimizer_cls=client_optimizer_cls,
                   client_optimizer_params=client_optimizer_params,
                   **fed_params)
federater.fit(num_rounds=num_rounds, criterion=criterion, num_epochs=num_epochs, val_dl=test_dl)