In [1024]:
################################################################################
from copy import deepcopy
import torch


class CentralServer:
    """Central Server"""
    def __init__(self, 
                 model, 
#                  dataloader,
                 save_every_n=None,
                 device='cpu'):
        self.model = model
#         self.dataloader = dataloader
        self.save_every_n = save_every_n
        self.device = device
        
    def aggregate(self):
        raise NotImplementedError
    
    def send_model(self):
        """Sends a copy of the global models parameters."""
        return deepcopy(self.model)
    
    def save(self, filepath):
        """Save model parameters to disk."""
        torch.save(self.model.state_dict(), filepath)
    
    def update(self, global_state):
        """Updates the parameters of the global model."""
        self.model.load_state_dict(global_state)
    
    def predict(self, dataloader):
        with torch.no_grad():
            pass
        
    def validate(self, dataloader, criterion):
        self.model.to(self.device)
        self.model.eval()
        loss = 0 
        correct = 0
        with torch.no_grad():
            for x, y in dataloader:
                x = x.to(self.device)
                y = y.to(self.device)
                logits = self.model(x)
                correct += (logits.argmax(-1) == y).sum().item()
                loss += criterion(logits, y).item()
        results = {
            'loss': loss / len(dataloader),
            'accuracy': correct / len(dataloader.dataset)
        }
        return results
        
        
class FedAvgServer(CentralServer):
    """FederatedAveraging (FedAvg) central server as proposed in "Communication-Efficient Learning 
    of Deep Networks from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf)
    
    Parameters
    ----------
    model : 
    dataloader :
    device : str (default='cpu')
    """
    def __init__(self, model, dataloader=None, device='cpu'):
        super(FedAvgServer, self).__init__(model,
                                           dataloader,
                                           device=device)
        
    def aggregate(self, clients, client_weights):
        layer_names = self.model.state_dict().keys()
        global_state = {}
        for k, (client_id, client) in enumerate(clients.items()):
            local_state = client.model.state_dict()
            for layer_name in layer_names:
                if k == 0:
                    global_state[layer_name] = local_state[layer_name] * client_weights[k]
                else:
                    global_state[layer_name] += local_state[layer_name] * client_weights[k]

        self.update(global_state)
        

In [1180]:
# client.py


class Client:
    """Base client.
    
    Parameters
    ----------
    client_id : str
        Id of the client.
    dataloader : DataLoader
        Local dataset used for training on the client.
    device : str, torch.device (default='cpu')
        Device type.
    """
    def __init__(self, client_id, train_dl, val_dl=None, device='cpu'):
        self.client_id = client_id
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.device = device
        self._model = None
        
    @property
    def model(self):
        return self._model
    
    @model.setter
    def model(self, model):
        self._model = model
    
    def __len__(self):
        return len(self.train_dl.dataset)
    
    def update(self, optimizer, criterion, num_epochs=1):
        """Algorithm 1 (ClientUpdate).
        
        Parameters
        ----------
        optim_cls : 
        optim_params :
        num_epochs (E) : int
            Number of epochs.
        criterion : 
        """
        self.model.train()
        self.model.to(self.device)
#         optim = optim_cls(self.model.parameters(), **optim_params)
        total_loss = 0
        total_correct = 0
        for i in range(num_epochs):
            for i, (x, y) in enumerate(self.train_dl):
                x = x.to(self.device)
                y = y.to(self.device)
                
                optimizer.zero_grad()
                
                logits = self.model(x)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                total_correct += (logits.argmax(-1) == y).sum().item()
            
        self.model.to('cpu')
        results = {
            'loss': total_loss / len(self.train_dl),
            'accuracy': total_correct / len(self)
        }
        return results
        
    
class FedProxClient(Client):
    """Base client.
    
    Parameters
    ----------
    client_id : str
        Id of the client.
    dataloader : DataLoader
        Local dataset used for training on the client.
    device : str, torch.device (default='cpu')
        Device type.
    """
    def __init__(self, client_id, train_dl, **kwargs):
        super(FedProxClient, self).__init__(client_id, train_dl, **kwargs)
        
    def update(self, optimizer, criterion, num_epochs=1):
        """Algorithm 1 (ClientUpdate).
        
        Parameters
        ----------
        optimizer :
        num_epochs (E) : int
            Number of epochs.
        criterion : 
        """
        self.model.train()
        self.model.to(self.device)
#         optim = optim_cls(self.model.parameters(), **optim_params)
        total_loss = 0
        total_correct = 0
        total_xent_loss = 0
        total_weights_delta = 0
        for i in range(num_epochs):
            for i, (x, y) in enumerate(self.train_dl):
                x = x.to(self.device)
                y = y.to(self.device)
                
                optimizer.zero_grad()
                
                logits = self.model(x)
                loss, xent_loss, weights_delta = criterion(logits, y, self.model.state_dict())
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                total_correct += (logits.argmax(-1) == y).sum().item()
                total_xent_loss += xent_loss.item()
                total_weights_delta += weights_delta.item()
            
        self.model.to('cpu')
        results = {
            'loss': total_loss / len(self.train_dl),
            'accuracy': total_correct / len(self),
            'xent_loss': total_xent_loss / len(self.train_dl),
            'weights_delta': weights_delta / len(self.train_dl),
#             'mean_grad': 
        }
        return results

# class FedAvgClient(Client):
#     """FederatedAveraging (FedAvg) central server as proposed in "Communication-Efficient Learning 
#     of Deep Networks from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf)
    
#     Parameters
#     ----------
#     client_id : str
#         Id of the client.
#     dataloader : DataLoader
#         Local dataset used for training on the client.
#     device : str, torch.device (default='cpu')
#         Device type.
#     """
#     def update(self, optim_cls, optim_params, num_epochs, criterion):
#         """Algorithm 1 (ClientUpdate).
        
#         Parameters
#         ----------
#         optim_cls : 
#         optim_params :
#         num_epochs : int
#             Number of epochs, referred to as E.
#         criterion : 
#         """
#         self.model.train()
#         self.model.to(self.device)
#         optim = optim_cls(self.model.parameters(), **optim_params)
        
#         for i in range(num_epochs):
#             for i, (x, y) in enumerate(self.dataloader):
#                 x = x.to(self.device)
#                 y = y.to(self.device)
                
#                 optim.zero_grad()
                
#                 logits = self.model(x)
#                 loss = criterion(logits, y)
#                 loss.backward()
#                 optim.step()
            
#         # why set to device ?
#         self.model.to('cpu')

In [843]:
# dataset.py
from torch.utils.data import Dataset, DataLoader


class SplitDataset(Dataset):
    """Dataset for a client partitioned by a list of indices."""
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)
        
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

In [1181]:
# sampling.py
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader


def get_iid_shards(labels,
                   num_clients=100,
                   client_ids=None,
                   seed=None):
    """Returns a homogeneous (IID) mapping of client ID's to sample indices of a dataset.
    
    Parameters
    ----------
    labels : list, np.ndarray, torch.Tensor
        List of class labels for a dataset.
    num_clients : int (default=100)
        Number of clients used for training.
    client_ids : list, np.ndarray (default=None)
        List of client ID's. If `None`, defaults to a `range(num_clients)`.
    seed : int (default=None)
        Random state.
       
    Notes
    -----
    ```
    num_samples_per_client = len(labels) // num_clients # 60000 / 100 = 600 for MNIST
    ```
    """
    random_state = np.random.RandomState(seed)
    client_ids = client_ids or list(range(num_clients))
    
    # randomly shuffle sample indices to generate IID (homogeneous) clients
    indices_shuffled = random_state.choice(range(len(labels)), len(labels), replace=False)
    num_samples_per_client = len(indices_shuffled) // num_clients

    # assign `num_samples_per_client` random samples to each client
    for i, client_id in enumerate(client_ids):
        client_indices = indices_shuffled[i * num_samples_per_client : (i + 1) * num_samples_per_client]
        client_to_shard[client_id] = client_indices
    
    return client_to_shard
    
    
def get_non_iid_shards(labels, 
                       num_clients=100,
                       client_ids=None,
                       shard_size=300,
                       drop_last=False,
                       seed=None):
    """Returns a heterogeneous (non-IID) mapping of client ID's to sample indices of a dataset.
    
    Parameters
    ----------
    labels : list, np.ndarray, torch.Tensor
        List of class labels for a dataset.
    num_clients : int (default=100)
        Number of clients used for training.
    client_ids : list, np.ndarray (default=None)
        List of client ID's. If `None`, defaults to a `range(num_clients)`.
    shard_size : int (default=300)
        Size of each shard to split labels by.
    seed : int (default=None)
        Random state.
       
    Notes
    -----
    ```
    num_shards = len(labels) // shard_size # 60000 / 300 = 200 for MNIST
    num_shards_per_client = num_shards // num_clients # 200 / 100 = 2 for MNIST
    ```
    """
    random_state = np.random.RandomState(seed)
    client_ids = client_ids or list(range(num_clients))
    classes = np.unique(labels)
    num_classes = len(classes)

    label_to_indices = {}
    client_to_shard = {}
    shards = []
    shard_labels = []
    # map each class to the sample indices in the dataset
    for label in classes:
        label_to_indices[label] = np.where(labels == label)[0]

    # split each class labels indices to shards of size `shard_size`
    for label in classes:
        indices = label_to_indices[label]
        num_extra = len(indices) % shard_size
        # (num_shards_per_label, shard_size)
        label_shards = indices[:-num_extra].reshape(-1, shard_size).tolist()

        if num_extra > 0 and not drop_last:
            label_shards.append(indices[-num_extra:])

        # store the shards in a list to resample and assign to a client
        shard_labels.extend([label] * len(label_shards))
        shards.extend(label_shards)

    num_shards = len(shard_labels) # number of shards of size `shard_size`
    num_shards_per_client = num_shards // num_clients # how many shards fit in each client

    # shuffle the shards and assign `num_shards_per_client` to each client
    # each client should ideally have `num_shards_per_client` from a different class lablel
    shard_indices_shuffled = random_state.choice(range(num_shards), 
                                                 num_shards,
                                                 replace=False)
    for i, client_id in enumerate(client_ids):
        client_to_shard[client_id] = []
        shard_indices = shard_indices_shuffled[i * num_shards_per_client : (i + 1) * num_shards_per_client]
        for idx in shard_indices:
            client_to_shard[client_id].extend(shards[idx])
            
    return client_to_shard


def get_client_data(dataset, 
                    num_clients,
                    client_ids=None,
                    is_iid=True,
                    shard_size=300,
                    drop_last=True,
                    seed=None,
                    **kwargs):
    """Returns a mapping of client ID's to their corresponding train & validation dataloaders.
    
    Parameters
    ----------
    is_iid : bool (default=True)
        Boolean whether data is IID (Homogeneous) or non-IID (Heterogeneous)
    **kwargs
        Additional parameters used when instantiating each clients dataloader 
        
    """
    client_to_data = {}
    labels = dataset.targets
    if is_iid:
        client_to_shard = get_iid_shards(
            labels,
            num_clients=num_clients,
            client_ids=client_ids,
            seed=seed
        )
    else: # non-iid
        client_to_shard = get_non_iid_shards(
            labels,
            num_clients=num_clients,
            client_ids=client_ids,
            shard_size=shard_size,
            drop_last=drop_last,
            seed=seed
        )
        
    # iterate through each client and create train/val dataloaders using the shard indices
    for k in client_to_shard.keys():
        client_indices = client_to_shard[k]
#         if val_size > 0:
#             train_indices, val_indices = train_test_split(client_indices, test_size=val_size,  random_state=seed)
#             train_dataset = SplitDataset(dataset, train_indices)
#             val_dataset = SplitDataset(dataset, val_indices)
#         else:
#             train_dataset = SplitDataset(dataset, client_indices)
#             val_dataset = None
        client_dataset = SplitDataset(dataset, client_indices)
        client_to_data[k] = DataLoader(client_dataset, drop_last=drop_last, **kwargs)
#             'val': DataLoader(train_dataset, drop_last=drop_last, **kwargs)
        
    return client_to_data


def get_clients(dataset,
                method='fedavg',
                num_clients=100,
                client_ids=None,
                is_iid=True,
                shard_size=300,
                drop_last=True,
                batch_size=32,
                num_workers=0,
                device='cpu',
                **kwargs):
    if method == 'fedavg':
        client_cls = Client
    elif method == 'fedprox':
        client_cls = FedProxClient
    else:
        raise ValueError(f'Unrecognized `method` {method}')
    client_to_data = get_client_data(
        dataset,
        num_clients=num_clients,
        client_ids=client_ids,
        is_iid=is_iid,
        shard_size=shard_size,
        drop_last=drop_last,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
    clients = {
        k: client_cls(k, dl, device=device)
        for k, dl
        in client_to_data.items()
    }
    return clients

In [1140]:
from collections import defaultdict
from torch.utils.tensorboard.writer import SummaryWriter


class BaseFederater:
    """Base Federater.
    
    Parameters
    ----------
    model : nn.Module
    dataset : torch.utils.data.Dataset
    criterion : nn.Module
        Loss function to optimize on each client.
    num_clients (K) : int (default=100)
        Number of clients to partition `dataset`.
    batch_size (B) : int, dict[str, int] (defualt=32)
        Number of samples per batch to load on each client. 
        Can be a dictionary mapping each client ID to it's corresponding batch size
        to allow for various batch sizes across clients.
    shard_size : int (default=300)
    is_iid : bool (default=False)
    drop_last : bool (default=True)
    num_workers : int (default=0)
    device : str (default='cpu')
    """
    def __init__(self, 
                 model,
                 clients,
                 criterion,
                 optim_cls,
                 optim_params=None,
                 C=0.1,
                 eval_every_n=1,
                 num_workers=0,
                 device='cpu',
                 output_dir=None):
        self.model = model
        self.clients = clients
        self.criterion = criterion
        self.optim_cls = optim_cls
        self.optim_params = optim_params
        self.C = C
        self.eval_every_n = eval_every_n
        self.num_workers = num_workers
        self.device = device
        self.output_dir = output_dir
        self.writer = SummaryWriter(self.output_dir)
        self.server = None
        
        self.num_clients = len(self.clients)
        self.num_samples = sum([len(c) for c in self.clients.values()]) # n
        self.client_ids = list(self.clients.keys())
        
        self._global_round = 0
        
    @property
    def global_round(self):
        return self._global_round

    @global_round.setter
    def global_round(self, global_round):
        self._global_round = global_round
    
    def update(self, client_ids, num_epochs, *args, **kwargs):
        raise NotImplementedError
        
    def get_optimizer(self, model, optim_params):
        raise NotImplementedError
        
    def fit(self, num_rounds, num_epochs, val_dl=None):
        m = max(int(np.ceil(self.num_clients * self.C)), 1)
        for t in range(num_rounds):
            self.global_round += 1
            S = np.random.choice(self.client_ids, m, replace=False)
            train_metrics = self.update(client_ids=S, num_epochs=num_epochs)
            
            if self.eval_every_n is not None and t % self.eval_every_n == 0:
                template_str = f'round {self.global_round}'
                val_metrics = self.validate(val_dl)
                for metric, value in train_metrics.items():
                    self.writer.add_scalar(f'train/{metric}', value, self.global_round)
                    template_str += f' - train_{metric} : {value:0.4f}'
                for metric, value in val_metrics.items():
                    self.writer.add_scalar(f'val/{metric}', value, self.global_round)
                    template_str += f' - val_{metric} : {value:0.4f}'
                
                print(template_str)
        
    def validate(self, val_dl):
        return self.server.validate(val_dl, self.criterion)
        
    def send_model(self):
        """Send the current state of the global model to each client."""
        for client_id, client in self.clients.items():
            client.model = self.server.send_model()
            
    def get_gradients(self):
        pass
            
            
class FedAvg(BaseFederater):
    """FederatedAveraging (FedAvg) as proposed in "Communication-Efficient Learning 
    of Deep Networks from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf)
    
    Parameters
    ----------
    
    """
    def __init__(self,
                 model,
                 clients,
                 criterion,
                 optim_cls,
                 optim_params=None,
                 C=0.1,
                 eval_every_n=1,
                 device='cpu',
                 num_workers=0,
                 output_dir=None):
        super(FedAvg, self).__init__(model, 
                                     clients,
                                     criterion,
                                     optim_cls=optim_cls,
                                     optim_params=optim_params,
                                     C=C,
                                     eval_every_n=eval_every_n,
                                     num_workers=num_workers,
                                     device=device,
                                     output_dir=output_dir)
        
        self.server = FedAvgServer(self.model, device=self.device)
        self.client_weights = [len(c) / self.num_samples for c in self.clients.values()] # n_k / n
        
    def get_client_optimizer(self, model, optim_params):
        optim_params = optim_params or {}
        return self.optim_cls(model.parameters(), **optim_params)
        
    def update(self, client_ids, num_epochs):
        """Performs a full round of training on each client for E epochs.
        
        Parameters
        ----------
        client_ids : list, np.ndarray
            List of client ID's to train on the current round (S_t).
        E : int
            Number of epochs to train on each client.
            
        Returns
        -------
        metrics_dict : dict
            Dictionary mapping each metric to the average score across `client_ids`
        """
        # send the global model parameters to each client
        self.send_model()
        # instantiate the client optimizer
        optimizer = self.get_client_optimizer(self.clients[k].model, self.optim_params)
        
        metrics_dict = defaultdict(lambda: 0)
        for k in client_ids:
            
            # update the client weights and record the local training metrics
            client_metrics_dict = self.clients[k].update(
                optimizer,
                criterion=self.criterion,
                num_epochs=num_epochs
            )
            
            # update the summary writer and record loss/acc from the client
            for metric, value in client_metrics_dict.items():
                self.writer.add_scalar(f'client/{k}/{metric}', value, self.global_round)
                metrics_dict[metric] += value / len(client_ids)
        
        # note that we use all clients as shown in the "Server executes" section of "Algorithm 1"
        # when aggregating client weights, not just the `m` clients sampled for the current round
        self.server.aggregate(self.clients, self.client_weights)
        
        return metrics_dict
    
    
class FedProx(BaseFederater):
    """FederatedAveraging (FedAvg) as proposed in "Communication-Efficient Learning 
    of Deep Networks from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf)
    
    Parameters
    ----------
    
    """
    def __init__(self,
                 model,
                 clients,
                 criterion,
                 optim_cls,
                 C,
                 mu=0,
                 optim_params=None,
                 eval_every_n=1,
                 device='cpu',
                 num_workers=0,
                 output_dir=None):
        super(FedProx, self).__init__(model, 
                                     clients,
                                     criterion,
                                      optim_cls,
                                     C=C,
                                      optim_params=optim_params,
                                     eval_every_n=eval_every_n,
                                     num_workers=num_workers,
                                     device=device,
                                     output_dir=output_dir)
        self.mu = mu
        self.server = FedAvgServer(self.model, device=self.device)
        self.client_weights = [len(c) / self.num_samples for c in self.clients.values()] # n_k / n
        
    def get_client_optimizer(self, model, optim_params):
        optim_params = optim_params or {}
        return self.optim_cls(model.parameters(), **optim_params)
        
    def update(self, client_ids, num_epochs):
        """Performs a full round of training on each client for E epochs.
        
        Parameters
        ----------
        client_ids : list, np.ndarray
            List of client ID's to train on the current round (S_t).
        E : int
            Number of epochs to train on each client.
            
        Returns
        -------
        metrics_dict : dict
            Dictionary mapping each metric to the average score across `client_ids`
        """
        # send the global model parameters to each client
        self.send_model()
        
        metrics_dict = defaultdict(lambda: 0)
        for k in client_ids:
            # instantiate a client optimizer
            client_optimizer = self.get_client_optimizer(self.clients[k].model, self.optim_params)
            criterion = FedProxLoss(self.mu, self.model.state_dict())
            # update the client weights and record the local training metrics
            client_metrics_dict = self.clients[k].update(
                client_optimizer,
                criterion=criterion,
                num_epochs=num_epochs
            )
            
            # update the summary writer and record loss/acc from the client
            for metric, value in client_metrics_dict.items():
                self.writer.add_scalar(f'client/{k}/{metric}', value, self.global_round)
                metrics_dict[metric] += value / len(client_ids)
        
        # note that we use all clients as shown in the "Server executes" section of "Algorithm 1"
        # when aggregating client weights, not just the `m` clients sampled for the current round
        self.server.aggregate(self.clients, self.client_weights)
        
        return metrics_dict

In [1194]:
import torch.nn.functional as F


class FedProxLoss(nn.Module):
    """
    Empirically, we observe that increasing µ leads to smaller dissimilarity among local functions
    Fk, and that the dissimilarity metric is consistent with the
    training loss.
    """
    def __init__(self, mu, weights_initial):
        super(FedProxLoss, self).__init__()
        self.mu = mu
        self.weights_initial = weights_initial
        self._layer_names = list(weights_initial.keys())
        
        
    def forward(self, pred, target, weights_new):
        loss = 0
        xent_loss = F.cross_entropy(pred, target)
#         loss = xent_loss
        weights_delta = []
        for layer in self._layer_names:
            layer_delta = torch.sum(torch.pow(self.weights_initial[layer] - weights_new[layer], 2))
            loss += self.mu / 2 * layer_delta
            weights_delta.append(layer_delta)
#             weights_delta.append(
#                 torch.sum(torch.pow(self.weights_initial[layer] - weights_new[layer], 2))
#             )
        weights_delta = torch.sum(torch.stack(weights_delta))
        loss += xent_loss
#         print(xent_loss)
#         loss = xent_loss + self.mu / 2. * weights_delta
        return loss, xent_loss, weights_delta
#         loss += 

In [1075]:
# model.py
import torch.nn as nn


class CNN(nn.Module):
    """CNN described in "Communication-Efficient Learning of Deep Networks
    from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf).
    
    Parameters
    ----------
    in_features : int (default=1)
        Number of channels in the input image.
    num_classes : int (default=10)
        Number of class labels.
    """
    def __init__(self, in_features=1, num_classes=10):
        super(CNN, self).__init__()
        self._in_features = in_features
        self._num_classes = num_classes
        
        self.conv1 = nn.Conv2d(self._in_features,
                               32,
                               kernel_size=5)
        self.conv2 = nn.Conv2d(32,
                               64,
                               kernel_size=5)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, self._num_classes)
        
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2,2))
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        x = self.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.reshape(len(x), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
    
class MLP(nn.Module):
    """CNN described in "Communication-Efficient Learning of Deep Networks
    from Decentralized Data" (https://arxiv.org/pdf/1602.05629.pdf).
    
    Parameters
    ----------
    in_features : int (default=784)
        Number input features.
    hidden_dim : int (default=200)
        Number of hidden units.
    num_classes : int (default=10)
        Number of class labels.
    """
    def __init__(self, in_features=784, hidden_dim=200, num_classes=10):
        super(MLP, self).__init__()
        self._in_features = in_features
        self._hidden_dim = hidden_dim
        self._num_classes = num_classes
        
        self.fc1 = nn.Linear(self._in_features, self._hidden_dim)
        self.fc2 = nn.Linear(self._hidden_dim, self._num_classes)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [898]:
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

In [1191]:
device = 'cpu'
optim_cls = torch.optim.SGD
optim_params = {
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 1e-5
}

criterion = nn.CrossEntropyLoss()

In [1197]:
for x,y in train_dl:
    break

In [1200]:
p = model(x[None,...])

In [1233]:
from torch.optim.optimizer import Optimizer, required

class FedProx(Optimizer):
    r"""Implements FedAvg and FedProx. Local Solver can have momentum.
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        ratio (float): relative sample size of client
        gmf (float): global/server/slow momentum factor
        mu (float): parameter for proximal local SGD
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v
        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
             v = \rho * v + lr * g \\
             p = p - v
        The Nesterov version is analogously modified.
    """

    def __init__(self, params, ratio, gmf, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, variance=0, mu=0):
        
        self.gmf = gmf
        self.ratio = ratio
        self.itr = 0
        self.a_sum = 0
        self.mu = mu


        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, variance=variance)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(FedProx, self).__init__(params, defaults)


    def __setstate__(self, state):
        super(FedProx, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        loss = None
        if closure is not None:
            loss = closure()

        print(self.state)
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                
                param_state = self.state[p]
                print('param_state')
                if 'old_init' not in param_state:
                    print('here')
                    param_state['old_init'] = torch.clone(p.data).detach()

                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                # apply proximal update
                d_p.add_(self.mu, p.data - param_state['old_init'])
                p.data.add_(-group['lr'], d_p)

        return loss

    def average(self):
        param_list = []
        for group in self.param_groups:
            for p in group['params']:
                p.data.mul_(self.ratio)
                param_list.append(p.data)

        communicate(param_list, dist.all_reduce)

        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['old_init'] = torch.clone(p.data).detach()
                # Reinitialize momentum buffer
                if 'momentum_buffer' in param_state:
                    param_state['momentum_buffer'].zero_()

In [None]:
class FedProx(Optimizer):
    def __init__(self, 
                 params,
                 lr=required,
                 mu=0,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False):
        self.mu = mu
        
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(FedProx, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(FedProx, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        
        Parameters
        ----------
        closure : callable (default=NOne)
            A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])
                    
#                     # add the initial (global) weights on the first step call
#                     if 'weights_initial' not in state:
#                         state['weights_initial'] = p.data.clone().detach()

            for i, param in enumerate(params_with_grad):
                d_p = d_p_list[i]
                if weight_decay != 0:
                    d_p = d_p.add(param, alpha=weight_decay)

                if momentum != 0:
                    buf = momentum_buffer_list[i]
                    if buf is None:
                        buf = torch.clone(d_p).detach()
                        momentum_buffer_list[i] = buf
                    else:
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                
                # add the proximal term shown in equation (2)
                d_p.add_(p.data - param['initial_weights'], alpha=self.mu)
                # gradient descent
                param.add_(d_p, alpha=-lr)

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss
    
    def set_initial_weights(self, params):
        for p in params:
            for p in group['params']:
                # add the initial (global) weights to each parameter state
                state = self.state[p]
                if 'initial_weights' not in state:
                    state['initial_weights'] = p.data.clone().detach()


In [1314]:
p = tmp.param_groups[0]['params'][0]#[0]
p

Parameter containing:
tensor([[[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
          [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
          [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
          [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
          [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],


        [[[-0.0625,  0.1570,  0.0061,  0.1393,  0.1369],
          [-0.1410,  0.1502,  0.0581,  0.1548, -0.0272],
          [-0.1242, -0.0803, -0.0600, -0.1837, -0.0414],
          [ 0.0703, -0.0023,  0.0040, -0.1584, -0.1531],
          [ 0.0853,  0.2344,  0.2251,  0.0915,  0.1444]]],


        [[[ 0.0254,  0.0159, -0.1718,  0.0243,  0.0359],
          [ 0.1028,  0.0580,  0.1281, -0.1112,  0.0933],
          [ 0.0104,  0.1602, -0.0586,  0.0608, -0.0146],
          [-0.0350, -0.0496,  0.1493, -0.1740, -0.0495],
          [-0.0526, -0.0541, -0.1318,  0.1012,  0.0460]]],


        [[[ 0.0523, -0.1491, -0.1451, -0.1036,  0.1483],
          [ 0.1593, -0.1866,  0.0807, -0.1690,  0.1838

In [1315]:
tmp.state[p]

{'old_init': tensor([[[[-0.0152,  0.0386, -0.1857, -0.1259,  0.0449],
           [ 0.0759,  0.0275,  0.0646, -0.0870, -0.1584],
           [-0.1323,  0.0762,  0.0831,  0.1156,  0.0071],
           [ 0.1702,  0.0632,  0.0493,  0.0004,  0.1970],
           [ 0.1274, -0.1625, -0.1157,  0.0324,  0.1151]]],
 
 
         [[[-0.0372,  0.1566,  0.0055,  0.1552,  0.1695],
           [-0.1298,  0.1493,  0.0488,  0.1698, -0.0087],
           [-0.1463, -0.0848, -0.0599, -0.1671,  0.0008],
           [ 0.0395, -0.0326, -0.0025, -0.1419, -0.1266],
           [ 0.0562,  0.1891,  0.1971,  0.0958,  0.1444]]],
 
 
         [[[-0.0024, -0.0018, -0.1555,  0.0271,  0.0251],
           [ 0.0881,  0.0476,  0.1273, -0.0854,  0.1086],
           [ 0.0024,  0.1491, -0.0637,  0.0970,  0.0140],
           [-0.0280, -0.0347,  0.1469, -0.1742, -0.0478],
           [-0.0275, -0.0190, -0.0903,  0.1244,  0.0652]]],
 
 
         [[[ 0.0477, -0.1364, -0.1304, -0.0934,  0.1813],
           [ 0.1439, -0.1863,  0.0850, -0.

In [1312]:
tmp.state[p]

{}

In [1302]:
tmp.param_groups[0]['params'][0][0]

tensor([[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
         [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
         [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
         [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
         [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],
       grad_fn=<SelectBackward>)

In [1304]:
for k,v in tmp.state.items():
    break

In [1308]:
k

Parameter containing:
tensor([[[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
          [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
          [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
          [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
          [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],


        [[[-0.0625,  0.1570,  0.0061,  0.1393,  0.1369],
          [-0.1410,  0.1502,  0.0581,  0.1548, -0.0272],
          [-0.1242, -0.0803, -0.0600, -0.1837, -0.0414],
          [ 0.0703, -0.0023,  0.0040, -0.1584, -0.1531],
          [ 0.0853,  0.2344,  0.2251,  0.0915,  0.1444]]],


        [[[ 0.0254,  0.0159, -0.1718,  0.0243,  0.0359],
          [ 0.1028,  0.0580,  0.1281, -0.1112,  0.0933],
          [ 0.0104,  0.1602, -0.0586,  0.0608, -0.0146],
          [-0.0350, -0.0496,  0.1493, -0.1740, -0.0495],
          [-0.0526, -0.0541, -0.1318,  0.1012,  0.0460]]],


        [[[ 0.0523, -0.1491, -0.1451, -0.1036,  0.1483],
          [ 0.1593, -0.1866,  0.0807, -0.1690,  0.1838

In [1301]:
tmp.param_groups[0]['params'][0][0] in tmp.state

False

In [1290]:
tmp.param_groups[0]

{'params': [Parameter containing:
  tensor([[[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
            [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
            [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
            [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
            [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],
  
  
          [[[-0.0625,  0.1570,  0.0061,  0.1393,  0.1369],
            [-0.1410,  0.1502,  0.0581,  0.1548, -0.0272],
            [-0.1242, -0.0803, -0.0600, -0.1837, -0.0414],
            [ 0.0703, -0.0023,  0.0040, -0.1584, -0.1531],
            [ 0.0853,  0.2344,  0.2251,  0.0915,  0.1444]]],
  
  
          [[[ 0.0254,  0.0159, -0.1718,  0.0243,  0.0359],
            [ 0.1028,  0.0580,  0.1281, -0.1112,  0.0933],
            [ 0.0104,  0.1602, -0.0586,  0.0608, -0.0146],
            [-0.0350, -0.0496,  0.1493, -0.1740, -0.0495],
            [-0.0526, -0.0541, -0.1318,  0.1012,  0.0460]]],
  
  
          [[[ 0.0523, -0.1491, -0.1451, -0.1036,  0.1483]

In [1280]:
aa = torch.tensor(3)

In [1285]:
aa.add()

tensor(15)

In [1273]:
torch.clone(pp[0][0]) == pp[0][0].clone()

tensor([[[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]]])

In [1234]:
tmp = FedProx(model.parameters(), 1, 1, 1)

In [1247]:
p = model(x[None,...])
loss = criterion(p, torch.tensor([y]))

In [1248]:
loss.backward()

In [1246]:
tmp.zero_grad()

In [1260]:
for group in tmp.param_groups:
    break

In [1261]:
for p in group['params']:
    break

In [1249]:
tmp.step()

defaultdict(<class 'dict'>, {Parameter containing:
tensor([[[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
          [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
          [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
          [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
          [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],


        [[[-0.0625,  0.1570,  0.0061,  0.1393,  0.1369],
          [-0.1410,  0.1502,  0.0581,  0.1548, -0.0272],
          [-0.1242, -0.0803, -0.0600, -0.1837, -0.0414],
          [ 0.0703, -0.0023,  0.0040, -0.1584, -0.1531],
          [ 0.0853,  0.2344,  0.2251,  0.0915,  0.1444]]],


        [[[ 0.0254,  0.0159, -0.1718,  0.0243,  0.0359],
          [ 0.1028,  0.0580,  0.1281, -0.1112,  0.0933],
          [ 0.0104,  0.1602, -0.0586,  0.0608, -0.0146],
          [-0.0350, -0.0496,  0.1493, -0.1740, -0.0495],
          [-0.0526, -0.0541, -0.1318,  0.1012,  0.0460]]],


        [[[ 0.0523, -0.1491, -0.1451, -0.1036,  0.1483],
          [ 0.1593, -0.18

###### a = defaultdict(dict)

In [1250]:
tmp.state

defaultdict(dict,
            {Parameter containing:
             tensor([[[[ 0.0135,  0.0834, -0.1635, -0.1240,  0.0150],
                       [ 0.1076,  0.0894,  0.1221, -0.0690, -0.1478],
                       [-0.1232,  0.1202,  0.1382,  0.1460,  0.0215],
                       [ 0.1577,  0.0469,  0.0594,  0.0215,  0.1838],
                       [ 0.1121, -0.2022, -0.1584,  0.0095,  0.0925]]],
             
             
                     [[[-0.0625,  0.1570,  0.0061,  0.1393,  0.1369],
                       [-0.1410,  0.1502,  0.0581,  0.1548, -0.0272],
                       [-0.1242, -0.0803, -0.0600, -0.1837, -0.0414],
                       [ 0.0703, -0.0023,  0.0040, -0.1584, -0.1531],
                       [ 0.0853,  0.2344,  0.2251,  0.0915,  0.1444]]],
             
             
                     [[[ 0.0254,  0.0159, -0.1718,  0.0243,  0.0359],
                       [ 0.1028,  0.0580,  0.1281, -0.1112,  0.0933],
                       [ 0.0104,  0.1602, -0.05

In [1244]:
tmp.step()

defaultdict(<class 'dict'>, {})
param_state
here
param_state
here
param_state
here
param_state
here
param_state
here
param_state
here
param_state
here
param_state
here


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  d_p.add_(self.mu, p.data - param_state['old_init'])


In [1245]:
tmp.state[]

{}

In [1243]:
tmp.param_groups[0]['params'][0][0].grad

  tmp.param_groups[0]['params'][0][0].grad


In [1227]:
pp[0][0].grad

  pp[0][0].grad


In [None]:
class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}
        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}
        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            lr = group['lr']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])

            F.sgd(params_with_grad,
                  d_p_list,
                  momentum_buffer_list,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening,
                  nesterov=nesterov)

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss

In [1230]:
optimizer.state['the']

{}

In [None]:
for i, param in enumerate(params):

    d_p = d_p_list[i]
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)

    if momentum != 0:
        buf = momentum_buffer_list[i]

        if buf is None:
            buf = torch.clone(d_p).detach()
            momentum_buffer_list[i] = buf
        else:
            buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

        if nesterov:
            d_p = d_p.add(buf, alpha=momentum)
        else:
            d_p = buf

    param.add_(d_p, alpha=-lr)

In [1207]:
pp = list(model.parameters())

In [1212]:
pp[0]

torch.nn.parameter.Parameter

In [1221]:
param_group = {'params': pp[0]}
params = param_group['params']
if isinstance(params, torch.Tensor):
    print('here')
    param_group['params'] = [params]

here


In [None]:
param_group['']

In [1217]:
param_group['params']

IndexError: too many indices for tensor of dimension 4

In [None]:
if not isinstance(param_groups[0], dict):
    param_groups = [{'params': param_groups}]

In [1195]:
model = CNN()
clients = get_clients(
    train_ds, 
    method='fedprox',
    num_clients=100,
    is_iid=False,
    shard_size=300,
    batch_size=50,
    num_workers=0,
    device=device
)
fed_cls = FedProx
federater = fed_cls(model, 
                   clients=clients,
                   optim_cls=optim_cls,
                   optim_params=optim_params,
                   criterion=criterion, 
                   C=0.1,
                    mu=1000,
                   output_dir='mu1'
                  )

In [1196]:
# 1
federater.fit(num_rounds=10, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 78.3555 - train_accuracy : 0.8820 - train_xent_loss : 0.9399 - train_weights_delta : 0.0808 - val_loss : 2.3053 - val_accuracy : 0.1034
round 2 - train_loss : 81.7870 - train_accuracy : 0.8213 - train_xent_loss : 0.9839 - train_weights_delta : 0.0840 - val_loss : 2.3001 - val_accuracy : 0.1042
round 3 - train_loss : 93.3500 - train_accuracy : 0.8510 - train_xent_loss : 0.8035 - train_weights_delta : 0.0922 - val_loss : 2.2993 - val_accuracy : 0.1529
round 4 - train_loss : 101.7852 - train_accuracy : 0.8333 - train_xent_loss : 0.7750 - train_weights_delta : 0.0989 - val_loss : 2.3005 - val_accuracy : 0.1253
round 5 - train_loss : 106.8303 - train_accuracy : 0.8857 - train_xent_loss : 0.6514 - train_weights_delta : 0.1007 - val_loss : 2.3055 - val_accuracy : 0.1460
round 6 - train_loss : 114.3917 - train_accuracy : 0.8423 - train_xent_loss : 0.6932 - train_weights_delta : 0.1076 - val_loss : 2.2933 - val_accuracy : 0.1507
round 7 - train_loss : 126.1735 - train_acc

In [1163]:
# 1
federater.fit(num_rounds=50, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 1.0243 - train_accuracy : 0.8240 - train_xent_loss : 0.9379 - train_weights_delta : 0.0888 - val_loss : 2.3261 - val_accuracy : 0.1138
round 2 - train_loss : 0.9458 - train_accuracy : 0.8500 - train_xent_loss : 0.8507 - train_weights_delta : 0.0953 - val_loss : 2.3223 - val_accuracy : 0.1138
round 3 - train_loss : 0.8591 - train_accuracy : 0.8500 - train_xent_loss : 0.7551 - train_weights_delta : 0.1012 - val_loss : 2.3194 - val_accuracy : 0.1162
round 4 - train_loss : 0.7957 - train_accuracy : 0.8503 - train_xent_loss : 0.6841 - train_weights_delta : 0.1058 - val_loss : 2.3144 - val_accuracy : 0.1673
round 5 - train_loss : 0.7910 - train_accuracy : 0.8477 - train_xent_loss : 0.6673 - train_weights_delta : 0.1152 - val_loss : 2.3000 - val_accuracy : 0.1278
round 6 - train_loss : 0.7412 - train_accuracy : 0.8677 - train_xent_loss : 0.6166 - train_weights_delta : 0.1148 - val_loss : 2.2938 - val_accuracy : 0.1147
round 7 - train_loss : 0.6914 - train_accuracy : 0.8

In [1153]:
# 0.1
federater.fit(num_rounds=50, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 1.1173 - train_accuracy : 0.7953 - train_xent_loss : 1.1106 - train_weights_delta : 0.0723 - val_loss : 2.3110 - val_accuracy : 0.0953
round 2 - train_loss : 1.0041 - train_accuracy : 0.8437 - train_xent_loss : 0.9962 - train_weights_delta : 0.0832 - val_loss : 2.3044 - val_accuracy : 0.1025
round 3 - train_loss : 0.8556 - train_accuracy : 0.8340 - train_xent_loss : 0.8465 - train_weights_delta : 0.0915 - val_loss : 2.2992 - val_accuracy : 0.0984
round 4 - train_loss : 0.8237 - train_accuracy : 0.8620 - train_xent_loss : 0.8142 - train_weights_delta : 0.0939 - val_loss : 2.2952 - val_accuracy : 0.1003
round 5 - train_loss : 0.7290 - train_accuracy : 0.8500 - train_xent_loss : 0.7186 - train_weights_delta : 0.0998 - val_loss : 2.2968 - val_accuracy : 0.1227
round 6 - train_loss : 0.6908 - train_accuracy : 0.8523 - train_xent_loss : 0.6795 - train_weights_delta : 0.1068 - val_loss : 2.2913 - val_accuracy : 0.1278
round 7 - train_loss : 0.6580 - train_accuracy : 0.8

In [1134]:
federater.fit(num_rounds=10, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 0.8697 - train_accuracy : 0.8507 - val_loss : 2.3002 - val_accuracy : 0.1090
round 2 - train_loss : 0.8123 - train_accuracy : 0.8583 - val_loss : 2.2975 - val_accuracy : 0.1299
round 3 - train_loss : 0.7712 - train_accuracy : 0.8477 - val_loss : 2.2965 - val_accuracy : 0.1081
round 4 - train_loss : 0.6790 - train_accuracy : 0.8703 - val_loss : 2.2955 - val_accuracy : 0.1208
round 5 - train_loss : 0.6348 - train_accuracy : 0.8413 - val_loss : 2.2920 - val_accuracy : 0.1671
round 6 - train_loss : 0.6197 - train_accuracy : 0.8447 - val_loss : 2.2789 - val_accuracy : 0.2053
round 7 - train_loss : 0.5574 - train_accuracy : 0.8797 - val_loss : 2.2858 - val_accuracy : 0.1840
round 8 - train_loss : 0.5512 - train_accuracy : 0.8767 - val_loss : 2.2874 - val_accuracy : 0.1009
round 9 - train_loss : 0.4433 - train_accuracy : 0.9333 - val_loss : 2.3899 - val_accuracy : 0.1009
round 10 - train_loss : 0.5090 - train_accuracy : 0.8667 - val_loss : 2.3651 - val_accuracy : 0.1009

In [1096]:
federater.fit(num_rounds=10, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 0.8465 - train_accuracy : 0.8743 - val_loss : 2.2897 - val_accuracy : 0.2361
round 2 - train_loss : 0.7888 - train_accuracy : 0.8497 - val_loss : 2.2847 - val_accuracy : 0.1830
round 3 - train_loss : 0.7610 - train_accuracy : 0.8487 - val_loss : 2.2787 - val_accuracy : 0.1567
round 4 - train_loss : 0.6402 - train_accuracy : 0.9013 - val_loss : 2.2836 - val_accuracy : 0.1010
round 5 - train_loss : 0.6472 - train_accuracy : 0.8333 - val_loss : 2.2681 - val_accuracy : 0.1503
round 6 - train_loss : 0.5997 - train_accuracy : 0.8433 - val_loss : 2.2516 - val_accuracy : 0.2674
round 7 - train_loss : 0.5411 - train_accuracy : 0.8953 - val_loss : 2.2446 - val_accuracy : 0.1829
round 8 - train_loss : 0.5276 - train_accuracy : 0.8637 - val_loss : 2.2356 - val_accuracy : 0.1947
round 9 - train_loss : 0.4706 - train_accuracy : 0.9150 - val_loss : 2.2518 - val_accuracy : 0.0958
round 10 - train_loss : 0.4918 - train_accuracy : 0.8500 - val_loss : 2.2306 - val_accuracy : 0.1027

In [1093]:
federater.fit(num_rounds=10, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 0.5057 - train_accuracy : 0.8840 - val_loss : 2.3208 - val_accuracy : 0.0936
round 2 - train_loss : 0.4780 - train_accuracy : 0.8343 - val_loss : 2.3035 - val_accuracy : 0.2034
round 3 - train_loss : 0.4342 - train_accuracy : 0.8767 - val_loss : 2.3096 - val_accuracy : 0.2068
round 4 - train_loss : 0.4736 - train_accuracy : 0.8670 - val_loss : 2.3189 - val_accuracy : 0.3255
round 5 - train_loss : 0.4287 - train_accuracy : 0.9023 - val_loss : 2.3270 - val_accuracy : 0.2267
round 6 - train_loss : 0.4493 - train_accuracy : 0.8537 - val_loss : 2.2895 - val_accuracy : 0.3200
round 7 - train_loss : 0.4289 - train_accuracy : 0.8857 - val_loss : 2.2831 - val_accuracy : 0.2452
round 8 - train_loss : 0.4404 - train_accuracy : 0.8850 - val_loss : 2.2737 - val_accuracy : 0.2514
round 9 - train_loss : 0.4510 - train_accuracy : 0.8660 - val_loss : 2.2353 - val_accuracy : 0.3152
round 10 - train_loss : 0.4347 - train_accuracy : 0.8877 - val_loss : 2.2077 - val_accuracy : 0.2248

In [1082]:
federater.fit(num_rounds=10, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 1.0563 - train_accuracy : 0.8353 - val_loss : 2.2984 - val_accuracy : 0.1088
round 2 - train_loss : 0.9254 - train_accuracy : 0.8500 - val_loss : 2.2924 - val_accuracy : 0.1308
round 3 - train_loss : 0.8273 - train_accuracy : 0.8697 - val_loss : 2.2908 - val_accuracy : 0.1035
round 4 - train_loss : 0.7899 - train_accuracy : 0.8347 - val_loss : 2.2940 - val_accuracy : 0.1139
round 5 - train_loss : 0.6584 - train_accuracy : 0.8687 - val_loss : 2.2940 - val_accuracy : 0.1279
round 6 - train_loss : 0.6520 - train_accuracy : 0.8427 - val_loss : 2.3053 - val_accuracy : 0.1472
round 7 - train_loss : 0.5746 - train_accuracy : 0.8800 - val_loss : 2.3093 - val_accuracy : 0.1566
round 8 - train_loss : 0.5822 - train_accuracy : 0.8597 - val_loss : 2.3179 - val_accuracy : 0.0980
round 9 - train_loss : 0.5821 - train_accuracy : 0.8333 - val_loss : 2.3079 - val_accuracy : 0.1019
round 10 - train_loss : 0.5713 - train_accuracy : 0.8333 - val_loss : 2.2820 - val_accuracy : 0.1566

In [1000]:
federater.fit(num_rounds=50, num_epochs=1, val_dl=test_dl)

round 1 - train_loss : 0.1825 - train_accuracy : 0.9713 - val_loss : 2.2786 - val_accuracy : 0.1963
round 2 - train_loss : 0.1276 - train_accuracy : 0.9763 - val_loss : 2.2947 - val_accuracy : 0.1642
round 3 - train_loss : 0.1219 - train_accuracy : 0.9753 - val_loss : 2.2815 - val_accuracy : 0.0967
round 4 - train_loss : 0.1064 - train_accuracy : 0.9667 - val_loss : 2.3481 - val_accuracy : 0.1166
round 5 - train_loss : 0.0966 - train_accuracy : 0.9700 - val_loss : 2.2797 - val_accuracy : 0.2585
round 6 - train_loss : 0.1048 - train_accuracy : 0.9667 - val_loss : 2.2043 - val_accuracy : 0.1887
round 7 - train_loss : 0.0805 - train_accuracy : 0.9743 - val_loss : 2.3510 - val_accuracy : 0.1010
round 8 - train_loss : 0.0762 - train_accuracy : 0.9733 - val_loss : 2.2810 - val_accuracy : 0.2122
round 9 - train_loss : 0.0737 - train_accuracy : 0.9820 - val_loss : 2.2808 - val_accuracy : 0.1925
round 10 - train_loss : 0.0878 - train_accuracy : 0.9667 - val_loss : 2.2817 - val_accuracy : 0.2073

In [902]:
federater.fit(num_rounds=20, num_epochs=1, val_dl=test_dl)

round 2 - train_loss : 0.0716 - train_accuracy : 0.0871 - val_loss : 2.3012 - val_accuracy : 0.1034
round 3 - train_loss : 0.0684 - train_accuracy : 0.0834 - val_loss : 2.3043 - val_accuracy : 0.1396
round 4 - train_loss : 0.0649 - train_accuracy : 0.0868 - val_loss : 2.2986 - val_accuracy : 0.1518
round 5 - train_loss : 0.0638 - train_accuracy : 0.0863 - val_loss : 2.2956 - val_accuracy : 0.2013
round 6 - train_loss : 0.0578 - train_accuracy : 0.0880 - val_loss : 2.2905 - val_accuracy : 0.2041
round 7 - train_loss : 0.0524 - train_accuracy : 0.0866 - val_loss : 2.2830 - val_accuracy : 0.2260
round 8 - train_loss : 0.0504 - train_accuracy : 0.0857 - val_loss : 2.2845 - val_accuracy : 0.3227
round 9 - train_loss : 0.0526 - train_accuracy : 0.0879 - val_loss : 2.2593 - val_accuracy : 0.3334
round 10 - train_loss : 0.0469 - train_accuracy : 0.0882 - val_loss : 2.2494 - val_accuracy : 0.3007
round 11 - train_loss : 0.0439 - train_accuracy : 0.0894 - val_loss : 2.2341 - val_accuracy : 0.317

In [903]:
federater.fit(num_rounds=20, num_epochs=1, val_dl=test_dl)

round 22 - train_loss : 0.0373 - train_accuracy : 0.0887 - val_loss : 2.1633 - val_accuracy : 0.4132
round 23 - train_loss : 0.0390 - train_accuracy : 0.0909 - val_loss : 2.1414 - val_accuracy : 0.3046
round 24 - train_loss : 0.0336 - train_accuracy : 0.0916 - val_loss : 2.1610 - val_accuracy : 0.2358
round 25 - train_loss : 0.0373 - train_accuracy : 0.0882 - val_loss : 2.1217 - val_accuracy : 0.3317
round 26 - train_loss : 0.0323 - train_accuracy : 0.0925 - val_loss : 2.1613 - val_accuracy : 0.2572
round 27 - train_loss : 0.0400 - train_accuracy : 0.0866 - val_loss : 2.1418 - val_accuracy : 0.3099
round 28 - train_loss : 0.0343 - train_accuracy : 0.0908 - val_loss : 2.1510 - val_accuracy : 0.2642
round 29 - train_loss : 0.0366 - train_accuracy : 0.0862 - val_loss : 2.1381 - val_accuracy : 0.2546
round 30 - train_loss : 0.0349 - train_accuracy : 0.0878 - val_loss : 2.1111 - val_accuracy : 0.2814
round 31 - train_loss : 0.0339 - train_accuracy : 0.0880 - val_loss : 2.1247 - val_accuracy

In [890]:
0 % 1

0

In [866]:
from collections import defaultdict

In [876]:
tmp = defaultdict(lambda: 0)

In [880]:
tmp['a'] += 3

In [881]:
tmp

defaultdict(<function __main__.<lambda>()>, {'a': 6})

In [882]:
for k, v in tmp.items():
    print(k,v)

a 6


In [838]:
federater.writer.

<torch.utils.tensorboard.writer.SummaryWriter at 0x7f0610288580>

In [749]:
type(federater.clients[0].dataloader)

__main__.ClientDataset

In [672]:
a,b = federater.validate(test_dl)

tensor([ 0.0078, -0.1566,  0.0084,  0.0535, -0.1603, -0.1524, -0.1051, -0.0275,
         0.0413, -0.1712])

In [682]:
criterion(torch.stack(a), b)

tensor(2.2917)

In [None]:
fe

In [657]:
for x,y in test_dl:
    pass

In [575]:
len(dl.dataset)

60000

In [573]:
dl.dataset[299]

(tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           -0.4242, -0.4242, -0.424

In [561]:
dl = federater.clients[0].dataloader

In [562]:
len(dl)

1200

In [563]:
i = 0
for x,y in dl:
    print(i)
    i+= 1

0
1
2
3
4
5


IndexError: list index out of range

In [547]:

for x, y in federater.clients[0].dataloader:
    break

In [518]:
federater.clients[0].dataloader

<__main__.ClientDataset at 0x7f06671b4df0>

In [517]:
x

torch.Size([1, 28, 28])

In [502]:
model.train()

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)

In [413]:
x = torch.zeros((1, 1, 28, 28))

In [425]:
torch.flatten(p)

tensor([0.0948, 0.0948, 0.0948,  ..., 0.0884, 0.0884, 0.0884])

In [423]:
p.flatten()

tensor([0.0948, 0.0948, 0.0948,  ..., 0.0884, 0.0884, 0.0884])

In [414]:
model = CNN(1, 10)

In [415]:
with torch.no_grad():
    p = model(x)

In [67]:
from torchvision import datasets, transforms

dataset_to_transform = {
    'mnist': {
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307), (0.3081))
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307), (0.3081))
        ]),
    },
    'cifar10': {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    },
    'cifar100': {
        'train': transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                 std=[0.267, 0.256, 0.276])
        ]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                 std=[0.267, 0.256, 0.276])
        ]),
    }
}

train_ds = datasets.MNIST('data/', train=True, download=True, transform=dataset_to_transform['mnist']['train'])
test_ds = datasets.MNIST('data/', train=False, download=True, transform=dataset_to_transform['mnist']['val'])

In [72]:
local_train_ds = ClientDataset(train_ds, indices=[0,3])
