# **Federated Variational Autoencoder**

In [2]:
# Force CPU execution: an empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA.
# This cell is deliberately placed ahead of the cell that imports torch, so the
# variable is already set when the framework first looks for devices.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

environ({'ELECTRON_RUN_AS_NODE': '1', 'GJS_DEBUG_TOPICS': 'JS ERROR;JS LOG', 'LESSOPEN': '| /usr/bin/lesspipe %s', 'CONDA_PROMPT_MODIFIER': '(base) ', 'LANGUAGE': 'en_GB:en', 'USER': 'ic-ai4health-fri', 'SSH_AGENT_PID': '3686', 'XDG_SESSION_TYPE': 'x11', 'SHLVL': '1', 'HOME': '/home/ic-ai4health-fri', 'CONDA_SHLVL': '1', 'OLDPWD': '/home/ic-ai4health-fri/Projects/DL_CW_1_hd119', 'DESKTOP_SESSION': 'ubuntu', 'GNOME_SHELL_SESSION_MODE': 'ubuntu', 'GTK_MODULES': 'gail:atk-bridge', 'MANAGERPID': '3479', 'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1000/bus', 'COLORTERM': 'truecolor', '_CE_M': '', 'IM_CONFIG_PHASE': '1', 'LOGNAME': 'ic-ai4health-fri', 'JOURNAL_STREAM': '8:190685', '_': '/usr/bin/code', 'XDG_SESSION_CLASS': 'user', 'USERNAME': 'ic-ai4health-fri', 'TERM': 'xterm-color', 'GNOME_DESKTOP_SESSION_ID': 'this-is-deprecated', '_CE_CONDA': '', 'WINDOWPATH': '2', 'PATH': '/home/ic-ai4health-fri/miniconda3/envs/pytorch/bin:/home/ic-ai4health-fri/miniconda3/bin:/home/ic-ai4health-fr

### **i. Imports and subfunctions**

In [4]:
# Imports
import copy
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
from mcvae.models import Mcvae, ThreeLayersVAE, VAE

# Subfunctions
def split_iid(dataset, n_centers):
    """
    Split PyTorch dataset randomly into n_centers.

    The split is as even as possible: every center receives
    ``len(dataset) // n_centers`` observations and the first
    ``len(dataset) % n_centers`` centers receive one extra, so each
    observation is assigned to exactly one center even when the dataset
    size is not a multiple of ``n_centers``.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``
            (it must support ``len`` and integer indexing).
        n_centers: Number of centers (clients) to create; must be positive.

    Returns:
        List of ``n_centers`` ``Subset`` objects that partition ``dataset``.

    Raises:
        ValueError: If ``n_centers`` is not a positive number.
    """
    if n_centers <= 0:
        raise ValueError(f"n_centers must be a positive integer, got {n_centers}")

    # Distribute the remainder one sample at a time so the lengths always
    # sum to len(dataset), which random_split requires.
    base_size, n_leftover = divmod(len(dataset), n_centers)
    n_obs_per_center = [
        base_size + (1 if center_i < n_leftover else 0)
        for center_i in range(n_centers)
    ]
    return random_split(dataset, n_obs_per_center)

def federated_averaging(models, n_obs_per_client):
    """
    Perform federated averaging.

    Every entry of the models' state dicts is replaced by the average over
    clients, weighted by each client's share of the total number of
    observations. Floating-point (and complex) entries are averaged in their
    own dtype. Integer or boolean entries (e.g. step counters kept as buffers)
    are averaged in float64, rounded and cast back to their original dtype,
    because an in-place float accumulation into an integer tensor is rejected
    by PyTorch.

    Args:
        models: Non-empty sequence of modules sharing the same state-dict keys.
        n_obs_per_client: Number of observations held by each client, in the
            same order as ``models``.

    Returns:
        A deep copy of ``models[0]`` loaded with the averaged parameters. The
        input models are left unmodified.

    Raises:
        AssertionError: If ``models`` is empty, the two inputs differ in
            length, or the total number of observations is not positive.
    """
    # Error check inputs
    assert len(models) > 0
    assert len(n_obs_per_client) == len(models)

    # Compute proportions
    n_obs = sum(n_obs_per_client)
    assert n_obs > 0, "total number of observations must be positive"
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Build each client's state dict once instead of once per key
    client_states = [model.state_dict() for model in models]

    # Start from a fresh state dict of the first model so the container type
    # (and any metadata it carries) matches what load_state_dict expects;
    # every entry is overwritten below.
    avg_params = models[0].state_dict()
    for key, reference in avg_params.items():
        keeps_dtype = reference.is_floating_point() or reference.is_complex()
        if keeps_dtype:
            accumulator = torch.zeros_like(reference)
        else:
            # Integer/bool entries cannot accumulate fractional values in place
            accumulator = torch.zeros_like(reference, dtype=torch.float64)

        # Compute weighted average (clients visited in input order)
        for state, proportion in zip(client_states, proportions):
            accumulator += proportion * state[key]

        if keeps_dtype:
            avg_params[key] = accumulator
        else:
            avg_params[key] = accumulator.round().to(reference.dtype)

    # Copy one of the models and load trained params
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)

    return avg_model

def get_data(subset, shuffle=True):
    """
    Return the full content of a torch dataset/Subset as one collated batch.

    A DataLoader whose batch size equals the dataset length is built and its
    first (and only) batch is returned, so each field of the samples comes
    back stacked along a leading sample dimension. With ``shuffle=True`` the
    sample order is randomised.
    """
    n_samples = len(subset)
    single_batch_iter = iter(
        DataLoader(subset, batch_size=n_samples, shuffle=shuffle)
    )
    return next(single_batch_iter)

  warn(f"Failed to load image Python extension: {e}")


### **ii. General setup**

In [7]:
# Define federated parameters
# N_CENTERS: number of clients the dataset is split into (one model copy per client).
N_CENTERS = 4
# N_ROUNDS: number of communication rounds; each round trains every client
# locally and then aggregates the client models with federated averaging.
N_ROUNDS = 10

# Define learning parameters
# N_EPOCHS: value passed as `epochs` to each client's optimize() call in every round.
N_EPOCHS = 15
# NOTE(review): BATCH_SIZE is not referenced by the cells visible in this section;
# get_data() loads each client's data as a single full batch.
BATCH_SIZE = 48
# LR: Adam learning rate; matches the value the optimizer is built with in section 2.
LR = 1e-3

# Disabled GPU selection, kept for reference only. The first cell clears
# CUDA_VISIBLE_DEVICES, so the notebook currently runs on CPU.
# # Define GPU
# USE_GPU = True
# dtype = torch.float32 
# if USE_GPU and torch.cuda.is_available():
#     device = torch.device('cuda:0')
# else:
#     device = torch.device('cpu')

### **1. Load dataset**

In [8]:
# Preprocessing pipeline: convert images to tensors, then apply a
# normalisation with mean 0 and std 1 (which leaves the values as they are).
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,), std=(1,)),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split, downloaded to the data directory if not already present
dataset = datasets.MNIST(
    root='~/data/',
    train=True,
    transform=transform,
    download=True,
)

# Federate data: one random, equally sized subset per center
federated_dataset = split_iid(dataset, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))

Number of centers: 4


### **2. Create model and train**

In [9]:
# Feature dimensions and dummy data
# N_FEATURES: length of one flattened image; the training loop reshapes each
# client's images with X.view(-1, N_FEATURES). MNIST images are 28 x 28 = 784 pixels.
N_FEATURES = 784
# One-element list holding a (1, N_FEATURES) zero tensor, passed as `data` when
# the Mcvae model is constructed; the real client data is assigned later.
# NOTE(review): presumably the model only uses it to infer input dimensions — confirm.
dummy_data = [torch.zeros(1, N_FEATURES)]

# Model architecture
# lat_dim: passed as `lat_dim` to Mcvae (latent dimension).
lat_dim = 3
# vae_class: VAE class passed as `vaeclass` to Mcvae.
vae_class = ThreeLayersVAE

In [10]:
# VAE models
# Template model built on the dummy data; it is never trained directly and
# only serves as the common starting point copied to every center.
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
# Use the LR constant from the general-setup cell rather than a hard-coded
# literal, so changing the configuration actually changes the optimizer.
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
model.init_loss()
# model = model.to(device=device)

# One independent deep copy of the template per center
models = [copy.deepcopy(model) for _ in range(N_CENTERS)]
# Number of observations held by each center (used as averaging weights)
n_obs_per_client = [len(client_data) for client_data in federated_dataset]

In [11]:
# Initialise parameters: the first round starts every client from the
# template model's current (untrained) weights.
init_params = model.state_dict()

# Loop over training rounds and clients
# NOTE(review): if N_ROUNDS is 0 the loop body never runs and trained_model
# is never assigned.
for round_i in range(N_ROUNDS):
    for client_dataset, client_model in zip(federated_dataset, models):
        # Load client data in the form of a tensor
        # (get_data returns the whole subset as one batch, reshuffled on each call;
        # the labels y are not used for training in this cell)
        X, y = get_data(client_dataset)
        # X = X.to(device=device)
        # y = y.to(device=device)
        client_model.data = [X.view(-1, N_FEATURES)]  # Set data attribute in client's model (list wraps the number of channels)

        # Load client's model parameters and train
        # Each client restarts from the current global parameters, then runs
        # N_EPOCHS of local optimisation on its own data.
        # NOTE(review): the recorded log shows the epoch counter continuing in
        # later rounds (e.g. 20/30), so optimize() appears to keep a running
        # epoch count across calls — confirm against the mcvae implementation.
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=client_model.data)
        
    # Aggregate models using federated averaging
    # The averaged parameters become the starting point for the next round;
    # after the final round trained_model holds the aggregated model.
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()

====> Epoch:    0/15 (0%)	Loss: 542.5519	LL: -542.5479	KL: 0.0041	LL/KL: -133920.3389
====> Epoch:   10/15 (67%)	Loss: 74.6212	LL: -70.1510	KL: 4.4702	LL/KL: -15.6930
====> Epoch:    0/15 (0%)	Loss: 548.2899	LL: -548.2858	KL: 0.0040	LL/KL: -135640.7928
====> Epoch:   10/15 (67%)	Loss: 76.3847	LL: -71.9886	KL: 4.3960	LL/KL: -16.3758
====> Epoch:    0/15 (0%)	Loss: 541.0313	LL: -541.0273	KL: 0.0041	LL/KL: -133540.9558
====> Epoch:   10/15 (67%)	Loss: 73.7192	LL: -69.3528	KL: 4.3664	LL/KL: -15.8833
====> Epoch:    0/15 (0%)	Loss: 545.7291	LL: -545.7250	KL: 0.0041	LL/KL: -133587.2160
====> Epoch:   10/15 (67%)	Loss: 76.5488	LL: -72.1437	KL: 4.4051	LL/KL: -16.3773
====> Epoch:   20/30 (67%)	Loss: 45.8481	LL: -33.8719	KL: 11.9762	LL/KL: -2.8283
====> Epoch:   20/30 (67%)	Loss: 47.9345	LL: -35.9531	KL: 11.9814	LL/KL: -3.0007
====> Epoch:   20/30 (67%)	Loss: 45.4779	LL: -33.4726	KL: 12.0053	LL/KL: -2.7882
====> Epoch:   20/30 (67%)	Loss: 48.0424	LL: -36.1164	KL: 11.9260	LL/KL: -3.0284
====> Ep

### **3. Visualise results**