# **Federated Variational Autoencoder**

### **i. Imports and subfunctions**

In [1]:
# Imports
import copy
import torch
from tqdm.auto import tqdm
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
from mcvae.models import Mcvae, ThreeLayersVAE, VAE

# Subfunctions
def split_iid(dataset, n_centers):
    """
    Split PyTorch dataset randomly into n_centers
    """
    n_obs_per_center = [len(dataset) // n_centers for _ in range(n_centers)]
    return random_split(dataset, n_obs_per_center)

def federated_averaging(models, n_obs_per_client):
    """
    Perform federated averaging.
    """
    # Error check inputs
    assert len(models) > 0
    assert len(n_obs_per_client) == len(models)

    # Compute proportions
    n_obs = sum(n_obs_per_client)
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Empty model parameter dictionary
    avg_params = models[0].state_dict()
    for key, val in avg_params.items():
        avg_params[key] = torch.zeros_like(val)

    # Compute average
    for model, proportion in zip(models, proportions):
        for key in avg_params.keys():
            avg_params[key] += proportion * model.state_dict()[key]

    # Copy one of the models and load trained params
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)

    return avg_model

def get_data(subset, shuffle=True):
    """
    Extracts data from a Subset torch dataset in the form of a tensor.
    """
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader)) 

  warn(f"Failed to load image Python extension: {e}")


                                           uuid                     name  \
index                                                                      
0      GPU-2a01c044-2e5a-72b5-e75b-f0c9cfaed71d  NVIDIA GeForce RTX 3080   

      temperature.gpu utilization.gpu memory.used memory.total  
index                                                           
0                  58              34         786        10240  


### **ii. General setup**

In [3]:
# Define federated parameters
N_CENTERS = 4
N_ROUNDS = 10

# Define learning parameters
N_EPOCHS = 15
BATCH_SIZE = 48
LR = 1e-3

# Define device to use for torch
if torch.cuda.is_available():
    use_cuda = True
else:
    use_cuda = False
use_cuda = False

### **1. Load dataset**

In [4]:
# Define data transforms and download
transform = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0,), (1,))])
dataset = datasets.MNIST('~/data/', train=True, download=True, transform=transform)

# Federate data
federated_dataset = split_iid(dataset, n_centers=N_CENTERS)
print('Number of centers:', len(federated_dataset))

Number of centers: 4


### **2. Create models**

In [5]:
# Feature dimensions and dummy data
N_FEATURES = 784
dummy_data = [torch.zeros(1, N_FEATURES)]

# Model architecture
lat_dim = 3
vae_class = ThreeLayersVAE

In [6]:
# VAE models
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class)
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
model.init_loss()
if use_cuda:
    model.cuda()
models = [copy.deepcopy(model) for _ in range(N_CENTERS)]
n_obs_per_client = [len(client_data) for client_data in federated_dataset]

In [31]:
def get_data(subset, shuffle=True):
    """ Extracts data from a Subset torch dataset in the form of a tensor"""
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    return next(iter(loader))

In [17]:
# Initialise paramters
init_params = model.state_dict()

# Loop over training rounds and clients
for round_i in range(N_ROUNDS):
    for client_dataset, client_model in zip(federated_dataset, models):
        # Load client data in the form of a tensor
        X, y = get_data(client_dataset)
        if use_cuda:
            X.cuda()
            y.cuda()
        client_model.data = X.view(-1, N_FEATURES)  # Set data attribute in client's model (list wraps the number of channels)
        print(client_model)

        # Load client's model parameters and train
        client_model.load_state_dict(init_params)
        client_model.optimize(epochs=N_EPOCHS, data=client_model.data)
        
    # Aggregate models using federated averaging
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)