# Federated Variational Autoencoders
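We are going to study an example of federated latent variable modeling that combines federated learning with variational autoencoders (VAEs). This example illustrates an IID scenario: the data are split uniformly at random across centers, so every center's data are independent and identically distributed.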

In [1]:
import copy
from tqdm.auto import tqdm

import torch
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms

In [2]:
N_CENTERS = 4
N_ROUNDS = 10   # Number of federated rounds (local training on every center, then aggregation)

N_EPOCHS = 15   # Number of local epochs on each center before aggregating
BATCH_SIZE = 48
LR = 1e-3       # Learning rate

We define two helper functions: one to distribute the dataset across multiple centers (`split_iid`) and one to compute the federated average of the locally trained models (`federated_averaging`).

In [3]:
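def split_iid(dataset, n_centers):
    """ Split a PyTorch dataset randomly into n_centers subsets of (almost) equal size """
    base, remainder = divmod(len(dataset), n_centers)
    # Spread leftover samples over the first centers so the sizes sum to len(dataset),
    # which random_split requires when given integer lengths
    n_obs_per_center = [base + (1 if i < remainder else 0) for i in range(n_centers)]
    return random_split(dataset, n_obs_per_center)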
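When given integer lengths, `random_split` requires the subset sizes to add up to the dataset length, and each returned `Subset` keeps the indices it draws from the parent dataset. A small check on a toy dataset of 10 items split over 3 centers, using a hypothetical `balanced_sizes` helper that hands the remainder to the first centers, shows that the subsets are disjoint and together cover every sample:

```python
import torch
from torch.utils.data import TensorDataset, random_split

def balanced_sizes(n_items, n_centers):
    """Hypothetical helper: equal shares, remainder spread over the first centers."""
    base, rest = divmod(n_items, n_centers)
    return [base + (1 if i < rest else 0) for i in range(n_centers)]

toy = TensorDataset(torch.arange(10))        # 10 items, not divisible by 3
sizes = balanced_sizes(len(toy), 3)
subsets = random_split(toy, sizes, generator=torch.Generator().manual_seed(0))

# Each Subset stores the indices it takes from the parent dataset
all_idx = [int(i) for s in subsets for i in s.indices]
print(sizes, [len(s) for s in subsets])
```

Because the indices come from a single random permutation, no sample is assigned to two centers.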

In [4]:
def federated_averaging(models, n_obs_per_client):
    """ Average model parameters, weighting each client by its number of observations """
    assert len(models) > 0, 'An empty list of models was passed.'
    assert len(n_obs_per_client) == len(models), 'The list with the number of observations must have ' \
                                                 'the same number of elements as the list of models.'

    # Compute each client's weight n_k / n
    n_obs = sum(n_obs_per_client)
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Dictionary of zero tensors with the same keys and shapes as the model parameters
    avg_params = {key: torch.zeros_like(val) for key, val in models[0].state_dict().items()}

    # Accumulate the weighted sum of parameters
    for model, proportion in zip(models, proportions):
        state = model.state_dict()
        for key in avg_params:
            avg_params[key] += proportion * state[key]

    # Copy one of the models and load the averaged parameters
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)

    return avg_model
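Federated averaging weights client k by n_k / n, the share of all observations it holds. The sketch below re-implements that weighting as a hypothetical `weighted_average` on two tiny `nn.Linear` models with constant parameters, so the expected result can be checked by hand: with 300 and 100 observations the weights are 0.75 and 0.25.

```python
import copy
import torch
from torch import nn

def weighted_average(models, n_obs):
    """Hypothetical sample-size-weighted average of state_dicts (FedAvg-style)."""
    total = sum(n_obs)
    states = [m.state_dict() for m in models]
    avg = {k: sum((n / total) * s[k] for n, s in zip(n_obs, states)) for k in states[0]}
    out = copy.deepcopy(models[0])
    out.load_state_dict(avg)
    return out

a, b = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    for p in a.parameters():
        p.fill_(1.0)
    for p in b.parameters():
        p.fill_(4.0)

avg = weighted_average([a, b], n_obs=[300, 100])
print(avg.weight.data)  # every entry is 0.75 * 1 + 0.25 * 4 = 1.75
```

With equal shard sizes, as produced by `split_iid` here, this reduces to a plain mean of the parameters.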

## Federating the dataset

In [5]:
transform = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0,), (1,))])
dataset = datasets.MNIST('~/data/', train=True, download=True, transform=transform)

Now, `federated_dataset` is a list of subsets of the main dataset.

In [6]:
federated_dataset = split_iid(dataset, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))

Number of centers: 4


## Defining and distributing a model: Variational Autoencoder
In this exercise we will use the Multi-channel Variational Autoencoder (MCVAE) proposed by _Antelmi et al. (ICML 2019)_.
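The training log printed further down reports `Loss`, `LL` and `KL`, and the numbers satisfy Loss = -LL + KL (for example 549.0015 + 0.0044 ≈ 549.0060): the negative evidence lower bound. A minimal single-channel sketch of that objective with `torch.distributions` follows, assuming a Gaussian encoder q(z|x), a standard normal prior p(z) and a unit-variance Gaussian likelihood p(x|z). The layers are hypothetical stand-ins, not the mcvae implementation; the multi-channel model of Antelmi et al. extends this to several channels.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
n_features, lat_dim = 784, 3

# Hypothetical single-layer encoder/decoder (assumption; mcvae's layers differ)
enc_mu = torch.nn.Linear(n_features, lat_dim)
enc_logvar = torch.nn.Linear(n_features, lat_dim)
dec = torch.nn.Linear(lat_dim, n_features)
sigma_x = torch.ones(n_features)  # fixed unit-variance likelihood (assumption)

def negative_elbo(x):
    """Return (loss, ll, kl) averaged over the batch, with loss = -ll + kl."""
    q = Normal(enc_mu(x), torch.exp(0.5 * enc_logvar(x)))        # q(z|x)
    prior = Normal(torch.zeros(lat_dim), torch.ones(lat_dim))    # p(z)
    z = q.rsample()                                              # reparameterization trick
    ll = Normal(dec(z), sigma_x).log_prob(x).sum(-1).mean()      # one-sample E_q[log p(x|z)]
    kl = kl_divergence(q, prior).sum(-1).mean()                  # closed-form Gaussian KL
    return -ll + kl, ll, kl

x = torch.rand(8, n_features)
loss, ll, kl = negative_elbo(x)
loss.backward()
print(f"Loss: {loss.item():.4f}  LL: {ll.item():.4f}  KL: {kl.item():.4f}")
```

This also explains the first log lines: at initialization the posterior is close to the prior, so KL is near 0 and the loss is dominated by the reconstruction term.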

In [7]:
!pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git

In [8]:
from mcvae.models import Mcvae, ThreeLayersVAE, VAE

First, it is necessary to define a model.

In [9]:
N_FEATURES = 784  # Number of pixels in MNIST

In [10]:
dummy_data = [torch.zeros(1, N_FEATURES)]  # Dummy data to initialize the input layer size
lat_dim = 3  # Size of the latent space for this autoencoder
vae_class = ThreeLayersVAE  # Architecture of the autoencoder (use VAE for a single-layer network)

In [11]:
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
model.init_loss()

Now replicate the model across the centers, one copy per center, and record how many observations each center holds.

In [12]:
models = [copy.deepcopy(model) for _ in range(N_CENTERS)]
n_obs_per_client = [len(client_data) for client_data in federated_dataset]

## Training in a federated fashion

In [13]:
def get_data(subset, shuffle=True):
    """ Extracts data from a Subset torch dataset in the form of a tensor"""
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader))  # The loader yields a single batch containing the whole subset
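With `batch_size=len(subset)` the `DataLoader` yields exactly one batch, so `get_data` hands each center's entire data to the model as a single tensor (the `BATCH_SIZE` constant defined at the top is not used by this notebook). A toy check with a hypothetical 6-sample dataset and a 3-sample `Subset`:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, Subset

# Toy stand-in for one center's Subset: 6 "images" of 4 pixels, with labels 0..5
data = TensorDataset(torch.arange(24.).view(6, 4), torch.arange(6))
subset = Subset(data, indices=[0, 2, 4])

def get_data(subset, shuffle=True):
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader))  # the only batch is the whole subset

X, y = get_data(subset, shuffle=False)
print(X.shape, y.tolist())  # torch.Size([3, 4]) [0, 2, 4]
```

Calling `next(iter(loader))` is the idiomatic form; the `.next()` method is not available on DataLoader iterators in current PyTorch.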

In [14]:
init_params = model.state_dict()  # Global parameters sent to every center at the start of each round
for round_i in range(N_ROUNDS):
    for client_dataset, client_model in zip(federated_dataset, models):
        # Load client data in the form of a tensor
        X, y = get_data(client_dataset)
        client_model.data = [X.view(-1, N_FEATURES)]  # Set data attribute in client's model (list wraps the number of channels)

        # Load client's model parameters and train
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=client_model.data)
        
    # Aggregate models using federated averaging
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()
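The loop above follows the usual federated averaging pattern: broadcast the global parameters, train locally on each center, then aggregate with sample-size weights. The sketch below reproduces that pattern on a toy linear regression, swapping in `nn.Linear` and plain SGD for MCVAE; the data, round counts and learning rate are illustrative assumptions, not the notebook's settings.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
N_CENTERS, N_ROUNDS, N_EPOCHS = 4, 5, 10

# Toy IID data: y = 3x - 1 plus noise, split evenly across centers
x = torch.randn(400, 1)
y = 3 * x - 1 + 0.1 * torch.randn(400, 1)
shards = list(zip(x.chunk(N_CENTERS), y.chunk(N_CENTERS)))
n_obs = [len(xs) for xs, _ in shards]

global_model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()

def global_loss(model):
    with torch.no_grad():
        return loss_fn(model(x), y).item()

start = global_loss(global_model)
for _ in range(N_ROUNDS):
    local_states = []
    for xs, ys in shards:
        local = copy.deepcopy(global_model)              # start from the global parameters
        opt = torch.optim.SGD(local.parameters(), lr=0.1)
        for _ in range(N_EPOCHS):                        # local full-batch epochs
            opt.zero_grad()
            loss_fn(local(xs), ys).backward()
            opt.step()
        local_states.append(local.state_dict())
    # Aggregate: sample-size-weighted average of the local parameters
    total = sum(n_obs)
    avg = {k: sum(n / total * s[k] for n, s in zip(n_obs, local_states))
           for k in local_states[0]}
    global_model.load_state_dict(avg)

end = global_loss(global_model)
print(f"global MSE: {start:.3f} -> {end:.3f}")
```

Because the shards are IID and the model is linear, averaging the local updates behaves much like centralized gradient descent here; with non-IID shards the local models can drift apart between aggregations.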

====> Epoch:    0/15 (0%)	Loss: 549.0060	LL: -549.0015	KL: 0.0044	LL/KL: -123752.7716
====> Epoch:   10/15 (67%)	Loss: 82.7334	LL: -78.5644	KL: 4.1690	LL/KL: -18.8448
====> Epoch:    0/15 (0%)	Loss: 551.0641	LL: -551.0596	KL: 0.0045	LL/KL: -123185.0004
====> Epoch:   10/15 (67%)	Loss: 83.1423	LL: -78.9854	KL: 4.1570	LL/KL: -19.0006
====> Epoch:    0/15 (0%)	Loss: 553.0543	LL: -553.0498	KL: 0.0045	LL/KL: -123165.6171
====> Epoch:   10/15 (67%)	Loss: 83.6635	LL: -79.2979	KL: 4.3656	LL/KL: -18.1644
====> Epoch:    0/15 (0%)	Loss: 551.7684	LL: -551.7640	KL: 0.0045	LL/KL: -123816.6904
====> Epoch:   10/15 (67%)	Loss: 82.1891	LL: -77.8627	KL: 4.3264	LL/KL: -17.9972
====> Epoch:   20/30 (67%)	Loss: 52.9532	LL: -41.7303	KL: 11.2229	LL/KL: -3.7183
====> Epoch:   20/30 (67%)	Loss: 53.4088	LL: -42.1911	KL: 11.2177	LL/KL: -3.7611
====> Epoch:   20/30 (67%)	Loss: 54.1244	LL: -43.0145	KL: 11.1099	LL/KL: -3.8717
====> Epoch:   20/30 (67%)	Loss: 52.5903	LL: -41.4866	KL: 11.1037	LL/KL: -3.7363

## Results visualization
Using the final aggregated parameters, we can evaluate the model qualitatively by projecting the test set onto the latent space.

In [15]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

In [16]:
dataset_test = datasets.MNIST('~/data/', train=False, download=True)
# The raw uint8 pixels must be scaled to [0, 1], as ToTensor did for the training set
X_test, y_test = [dataset_test.data.view(-1, N_FEATURES).float() / 255.], dataset_test.targets

In [17]:
Z_test = np.hstack([z.loc.detach().numpy() for z in trained_model.encode(X_test)])
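The cell above takes the mean (`.loc`) of each channel's posterior distribution and stacks the channels side by side. A stand-in with two hypothetical channels (the MNIST model here has a single channel, so `Z_test` simply has `lat_dim` columns) shows the resulting shape; the list of `Normal` objects imitates what `trained_model.encode` is assumed to return.

```python
import numpy as np
import torch
from torch.distributions import Normal

torch.manual_seed(0)
n_test, lat_dim, n_channels = 5, 3, 2

# Stand-in for trained_model.encode(X_test): one Normal posterior per channel (assumption)
qzx = [Normal(torch.randn(n_test, lat_dim, requires_grad=True), torch.ones(n_test, lat_dim))
       for _ in range(n_channels)]

# .loc is the posterior mean; detach() drops the autograd graph so .numpy() can be called
Z = np.hstack([q.loc.detach().numpy() for q in qzx])
print(Z.shape)  # (5, 6): one row per sample, lat_dim columns per channel
```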

### EXERCISE

In [18]:
### CREATE A PANDAS DATAFRAME WITH FEATURES: LATENT COORDINATES + LABEL

In [19]:
###PLOT THE LATENT SPACE WITH A PAIRPLOT USING THE CREATED DATAFRAME