# Multi-Channel Variational Autoencoder

## Goal of this Tutorial
This tutorial shows how to train a model with a custom number of channels. Specifically, it uses a Multi-Channel Variational Autoencoder (MCVAE) to encode and decode medical data with 5 different modalities.

## VAE
The Variational Autoencoder (VAE) is a latent variable model composed of one encoder and one decoder associated with a single channel.
The latent distribution and the decoding distribution are implemented as follows:

$$q(\mathbf{z|x}) = \mathcal{N}(\mathbf{z|\mu_x; \Sigma_x})$$

$$p(\mathbf{x|z}) = \mathcal{N}(\mathbf{x|\mu_z; \Sigma_z})$$

They are Gaussians with moments parametrized by Neural Networks (or a linear transformation layer in a simple case).
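
As a rough illustration of the simple linear case, the sketch below builds a Gaussian encoder and decoder in NumPy. The sizes, the weight matrices `W_mu`, `W_logvar` and `W_out`, and the helper names are assumptions made for this example; they are not taken from the mcvae source.

```python
import numpy as np

rng = np.random.default_rng(0)

n_features, lat_dim = 5, 2   # hypothetical sizes, chosen for illustration

# Hypothetical linear encoder: one matrix for the mean, one for the log-variance
W_mu = rng.normal(size=(n_features, lat_dim))
W_logvar = rng.normal(size=(n_features, lat_dim)) * 0.1

# Hypothetical linear decoder for the mean of p(x|z)
W_out = rng.normal(size=(lat_dim, n_features))

def encode(x):
    """Return the moments (mean, log-variance) of q(z|x)."""
    return x @ W_mu, x @ W_logvar

def sample(mu, logvar):
    """Draw z ~ N(mu, diag(exp(logvar))) with the reparameterization trick."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def decode(z):
    """Return the mean of p(x|z)."""
    return z @ W_out

x = rng.normal(size=(4, n_features))   # a batch of 4 samples
mu_z, logvar_z = encode(x)
z = sample(mu_z, logvar_z)
x_hat = decode(z)
print(z.shape, x_hat.shape)
```

Only the decoder mean is modelled here to keep the sketch short; a full VAE would also parametrize the decoder variance.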

<img src="https://gitlab.inria.fr/epione/flhd/-/raw/master/heterogeneous_data/img/vae.svg" alt="img/vae.svg">

For the output of the variance networks, it is more common and convenient to use $\log{\sigma^2}$. Neural networks can output any real number, while the variance is strictly positive ($\sigma^2 > 0$).
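
The short example below illustrates this point with arbitrary values: whatever real number the network outputs, reading it as $\log{\sigma^2}$ yields a valid variance.

```python
import numpy as np

# Raw network outputs can be any real number, including negative ones
raw_output = np.array([-3.0, 0.0, 2.5])

# Interpreted as log(sigma^2), every value maps to a strictly positive variance
variance = np.exp(raw_output)
std = np.exp(0.5 * raw_output)   # sigma = exp(log(sigma^2) / 2)

print(variance)
print(std)
```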

## MCVAE
This tutorial uses the *multi-channel variational autoencoder* (MCVAE), a more advanced method for the joint analysis and prediction of several modalities.

The MultiChannel VAE is built by stacking multiple VAEs and allowing the decoding distributions to be computed from every input channel.
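
A minimal sketch of this cross-channel decoding is shown below, assuming linear encoders and decoders that only model the mean. The channel sizes and variable names are hypothetical and the code does not use the mcvae library.

```python
import numpy as np

rng = np.random.default_rng(0)

channel_dims = [5, 3, 7]   # hypothetical number of features per channel
lat_dim = 2
n_samples = 4

# One hypothetical linear encoder and decoder per channel (means only)
encoders = [rng.normal(size=(d, lat_dim)) for d in channel_dims]
decoders = [rng.normal(size=(lat_dim, d)) for d in channel_dims]

# One input array per channel, all describing the same subjects
x = [rng.normal(size=(n_samples, d)) for d in channel_dims]

# Encode every channel into the shared latent space
z = [x_j @ enc_j for x_j, enc_j in zip(x, encoders)]

# p[i][j]: channel i reconstructed from the latent code of channel j
p = [[z_j @ dec_i for z_j in z] for dec_i in decoders]

print(p[0][2].shape)   # channel 0 predicted from channel 2
```

With C channels this produces a C x C grid of reconstructions, which is what makes it possible to predict one modality from another.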

The source code can be found here: https://gitlab.inria.fr/epione_ML/mcvae

<img src="https://gitlab.inria.fr/epione/flhd/-/raw/master/heterogeneous_data/img/mcvae.svg" alt="img/mcvae.svg">


## Installing the Requirements

Below, we install the mcvae model, which is necessary for this tutorial. The seaborn library is optionally used for plotting and illustration purposes.

In [None]:
%pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
%pip install seaborn

In [None]:
import pandas as pd
import os
import torch

## Downloading the data

The data contains 5 different modalities which are: 

- Volume: Structural MRI Brain Volumes of the patient.
- Demographics: Age, sex and years of education of the patient.
- Cognition: Cognitive scores of the patient. It contains scores from Clinical Dementia Rating, Alzheimer's Disease Assessment Scale, Mini-Mental State Examination, Rey Auditory Verbal Learning Test and Functional Activities Questionnaire.
- Apoe (Genetic risk): The count of APOE ε4 alleles (0, 1 or 2) of the patient, where a higher count indicates a higher risk.
- Fluid: Cerebrospinal fluid (CSF) biomarkers of the patient, namely the baseline values of the Amyloid-beta 42 (ABETA), Tau (TAU) and Phospho-tau (PTAU) proteins.

In [None]:
adni = pd.read_csv('https://gitlab.inria.fr/ssilvari/flhd/-/raw/master/heterogeneous_data/pseudo_adni.csv?inline=false')

print(f'Loaded {len(adni)} samples.')

normalize = lambda x: (x - x.mean(0))/x.std(0)

volume_cols = ['WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl', 'MidTemp.bl', 'Entorhinal.bl']
demog_cols = ['SEX', 'AGE', 'PTEDUCAT']
cognition_cols = ['CDRSB.bl', 'ADAS11.bl', 'MMSE.bl', 'RAVLT.immediate.bl', 'RAVLT.learning.bl', 'RAVLT.forgetting.bl', 'FAQ.bl']
apoe_cols = ['APOE4']
fluid_cols = ['ABETA.MEDIAN.bl', 'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl']

adni_cols = [volume_cols, demog_cols, cognition_cols, apoe_cols, fluid_cols]

for cols in adni_cols:
  adni[cols] = (adni[cols] - adni[cols].mean())/adni[cols].std()

# Creating a list with multimodal data
data_adni = [adni[cols].values for cols in adni_cols]

# Transform as a pytorch Tensor for compatibility
data_adni = [torch.Tensor(_) for _ in data_adni]

print(f'We have {len(data_adni)} channels in total as an input for the model')


The utility function below divides the data among n data centers (hospitals) and sets aside a given ratio of each center's samples as a holdout set for later validation.

In [None]:
train_data_path = f'./data/train'
holdout_data_path = f'./data/holdout'

def prepare_data_nth_center(n: int, offset: int, n_samples_train: int, n_samples_holdout):
  os.makedirs(train_data_path, exist_ok=True)
  os.makedirs(holdout_data_path, exist_ok=True)
  train_data_df = adni.iloc[offset:offset+n_samples_train,:]
  train_data_df.to_csv(train_data_path + '/dataset.csv')
  test_data_df = adni.iloc[offset+n_samples_train:offset+n_samples_train+n_samples_holdout,:]
  test_data_df.to_csv(holdout_data_path + '/dataset.csv')

# Number of centers to divide the data
n_centers = 2
n_samples_total = len(adni)
n_samples_per_center = n_samples_total // n_centers

# Holdout ratio
holdout_ratio = 0.1

n_holdout_samples_per_center = int(n_samples_per_center*holdout_ratio)
n_train_samples_per_center = n_samples_per_center - n_holdout_samples_per_center
last_offset = 0
for i in range(n_centers-1):
  prepare_data_nth_center(n=i,
                          offset=last_offset,
                          n_samples_train=n_train_samples_per_center,
                          n_samples_holdout=n_holdout_samples_per_center)
  last_offset += n_train_samples_per_center+n_holdout_samples_per_center
  print(f'Center {i}: {n_train_samples_per_center} train samples')

prepare_data_nth_center(n=n_centers-1,
                        offset=last_offset,
                        n_samples_train=n_samples_total - last_offset - n_holdout_samples_per_center,
                        n_samples_holdout=n_holdout_samples_per_center)
print(f'Center {n_centers-1}: {n_samples_total - last_offset - n_holdout_samples_per_center} train samples')
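
The split arithmetic above can be checked in isolation. The helper below is a hypothetical stand-alone version that returns the offsets and sizes instead of writing files, and the sample count of 1001 is made up for the example.

```python
def center_splits(n_samples_total, n_centers, holdout_ratio):
    """Return (offset, n_train, n_holdout) for each center, covering all samples."""
    per_center = n_samples_total // n_centers
    n_holdout = int(per_center * holdout_ratio)
    n_train = per_center - n_holdout

    splits, offset = [], 0
    for _ in range(n_centers - 1):
        splits.append((offset, n_train, n_holdout))
        offset += n_train + n_holdout
    # The last center absorbs the remainder of the integer division
    splits.append((offset, n_samples_total - offset - n_holdout, n_holdout))
    return splits

# Hypothetical sample count, chosen only for illustration
for center, (offset, n_train, n_holdout) in enumerate(center_splits(1001, 2, 0.1)):
    print(f'Center {center}: offset={offset}, train={n_train}, holdout={n_holdout}')
```

Because the last center takes whatever is left after the integer division, its training set can be slightly larger than the others.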

Add a dataset to the first node (hospital) with the following command:
```shell
fedbiomed node -p CUSTOM/PATH/TO/NODE dataset add
```

When prompted for the data type, select `1) csv`:
```shell
Please select the data type that you're configuring:
        1) csv
        2) default
        3) mednist
        4) images
        5) medical-folder
        6) flamby
select: 1
```

For name and description you may input whatever you want.

For `tags`, it is **VERY important** to input `adni-train`.
The Experiment will later search for the available data using the tag(s) provided.

For the path of the file, input

```shell
/PATH/TO/NODE/data/train/dataset.csv
```

Likewise, repeat the same steps for each of the N nodes that you want to add.

```shell
fedbiomed node -p CUSTOM/PATH/TO/NODE_N dataset add
```
```shell
/PATH/TO/NODE_N/data/train/dataset.csv
```

Finally, start the nodes using the command:

```shell
fedbiomed node -p CUSTOM/PATH/TO/NODE start
```

## Creating the Training Plan

To train our custom MCVAE, we initialize the model in the `init_model` function and wrap our data in a `Dataset` class in the `training_data` function. To support both, we define a helper function `get_channels` that specifies the channels our data has.

Next, we define a second helper function that creates the 5 channels as Torch tensors. To initialize the model and its parameters, we create dummy data, again with 5 channels, which sets the dimensionality of the model.

For the `training_data` function, we inherit from the `Dataset` class to create our own dataset class. This lets us override the `__getitem__` function, which is fundamental for our training plan: it defines which data item is retrieved for each sample during the training loop. These samples are then batched according to the `batch_size` parameter.
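
To make the role of `__getitem__` and of batching concrete, here is a plain-Python stand-in for the map-style dataset protocol. It deliberately avoids PyTorch: `RowDataset` and `batches` are hypothetical names, and the real training plan relies on `torch.utils.data.Dataset` and a data loader instead.

```python
import numpy as np

class RowDataset:
    """Plain-Python stand-in for a map-style dataset (not torch.utils.data.Dataset)."""

    def __init__(self, table, channel_columns):
        self._table = table                      # 2D array: rows are samples
        self._channel_columns = channel_columns  # column indices of each channel

    def __len__(self):
        return len(self._table)

    def __getitem__(self, idx):
        # Split one row into its channels; the empty list mirrors the unused target
        row = self._table[idx]
        return [row[cols] for cols in self._channel_columns], []

def batches(dataset, batch_size):
    """Group consecutive samples channel by channel, as a data loader would."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i][0] for i in range(start, stop)]
        yield [np.stack(channel) for channel in zip(*items)]

table = np.arange(40, dtype=float).reshape(10, 4)              # 10 samples, 4 features
dataset = RowDataset(table, channel_columns=[[0, 1], [2, 3]])  # 2 hypothetical channels

first_batch = next(batches(dataset, batch_size=4))
print([channel.shape for channel in first_batch])
```

Each batch is a list with one array per channel, which is the input layout a multi-channel model expects.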

Finally, the training step computes the loss by using:

- q: The approximate posterior $q(z|x)$ that the encoder computes, a (generally Gaussian) distribution over the latent variable z given the input data.
- x: The input data in tensor format.
- p: The likelihood $p(x|z)$ that the decoder computes by reconstructing the data from z.
- KL: The KL divergence between the approximate posterior and the prior.
- LL: The log-likelihood of the input data under $p(x|z)$.

The loss is calculated as the KL term minus the log-likelihood term (`kl - ll`), with the KL term scaled by the model's `beta` factor. Their formulas are given below:

$$\mathcal{L}_{\text{KL}} = \frac{1}{2} \sum_{i=1}^c \left( \mu_i^2 + \sigma_i^2 - \ln \sigma_i^2 - 1 \right)$$

$$\mathcal{L}_{\text{LL}} = -\frac{1}{2\sigma^2} \| x - \hat{x} \|^2 + \text{const}$$
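
The two formulas can be evaluated numerically as in the following NumPy sketch. The function names, the fixed decoder variance `sigma2` and the toy values are assumptions for illustration; the `compute_kl` and `compute_ll` methods of the mcvae model may reduce over samples and channels differently.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - logvar - 1.0, axis=-1)

def gaussian_log_likelihood(x, x_hat, sigma2=1.0):
    """log N(x | x_hat, sigma2 * I), summed over features, constant included."""
    d = x.shape[-1]
    const = -0.5 * d * np.log(2.0 * np.pi * sigma2)
    return -0.5 * np.sum((x - x_hat) ** 2, axis=-1) / sigma2 + const

mu = np.array([[0.0, 0.0], [1.0, -1.0]])
logvar = np.zeros((2, 2))            # sigma^2 = 1 for every latent dimension
x = np.array([[1.0, 2.0], [0.0, 0.0]])
x_hat = np.array([[1.0, 2.0], [1.0, 1.0]])

beta = 1.0                           # hypothetical weight on the KL term
kl = kl_to_standard_normal(mu, logvar)
ll = gaussian_log_likelihood(x, x_hat)
loss = np.mean(beta * kl - ll)
print(kl, ll, loss)
```

The first sample has a posterior equal to the prior and a perfect reconstruction, so its KL term is zero and its log-likelihood reduces to the constant.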

### Important Warning

The mcvae module tries to detect and use the GPU in the system by default. If that is not desired, the `DEVICE` variable can be set to the CPU as shown below.

In [None]:
from mcvae.gpu import DEVICE
print(DEVICE)

# DEVICE = torch.device('cpu')	

In [None]:
from fedbiomed.common.training_plans import TorchTrainingPlan

class MCVAETrainingPlan(TorchTrainingPlan):

    @staticmethod
    def get_channels():
      channel_1 = ['WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl', 'MidTemp.bl', 'Entorhinal.bl']
      channel_2 = ['SEX', 'AGE', 'PTEDUCAT']
      channel_3 = ['CDRSB.bl', 'ADAS11.bl', 'MMSE.bl', 'RAVLT.immediate.bl', 'RAVLT.learning.bl', 'RAVLT.forgetting.bl', 'FAQ.bl']
      channel_4 = ['APOE4']
      channel_5 = ['ABETA.MEDIAN.bl', 'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl']
      return channel_1, channel_2, channel_3, channel_4, channel_5

    @staticmethod
    def get_data_as_multichannel_tensor_dataset(df):
      """Takes a dataframe, splits it into multiple channels and parse each channel as a tensor"""
      channel_1, channel_2, channel_3, channel_4, channel_5 = MCVAETrainingPlan.get_channels()

      df = (df - df.mean())/df.std()
      def as_tensor(cols):
          tensor = torch.tensor(df[cols].values).float()
          return tensor

      return [as_tensor(channel_1), as_tensor(channel_2), as_tensor(channel_3), as_tensor(channel_4), as_tensor(channel_5)]

    def init_model(self, model_args):
      channels = MCVAETrainingPlan.get_channels()
      dummy_data = [torch.zeros((1, len(ch))).to('cpu') for ch in channels]
      # print(dummy_data[0].device)
      vaeclass = VAE
      return Mcvae(data=dummy_data,
                   lat_dim=model_args.get('lat_dim', 1),
                   vaeclass=vaeclass,
                   sparse=model_args.get('sparse', False))

    def init_optimizer(self, optimizer_args):
        optimizer = Adam(self.model().parameters(), lr=optimizer_args.get('lr', 0.001))
        return optimizer

    def init_dependencies(self):
        deps = [
            'from mcvae.models import Mcvae, ThreeLayersVAE, VAE',
            'from torch.optim import Adam',
            'from torchvision import datasets, transforms',
            'from torch.utils.data import Dataset',
            'from fedbiomed.common.logger import logger',
            'import numpy as np',
            'import pandas as pd']
        return deps

    def training_data(self):
        df = pd.read_csv(self.dataset_path)
        class myDataset(Dataset):
          def __init__(self, data):
            self._data = data
          def __len__(self):
            return len(self._data)
          def __getitem__(self, idx):
            df_ = self._data.iloc[idx,:]
            return MCVAETrainingPlan.get_data_as_multichannel_tensor_dataset(df_), []
        return DataManager(myDataset(df))

    def training_step(self, data, target):
      output = self.model().forward(data)
      q = output['q']
      x = output['x']
      p = output['p']
      
      kl = self.model().compute_kl(q)
      kl *= self.model().beta
      ll = self.model().compute_ll(p=p, x=x)
      
      return kl - ll


We initialize the model arguments for the MCVAE and the training arguments.

In [None]:
model_args = {
    'lat_dim': 1,
    'sparse': False
}

training_args = {
    'loader_args': { 'batch_size': 64, },
    'optimizer_args': {'lr': 1e-4},
    'num_updates': 50,
    'log_interval': 25,
    'test_ratio': 0.0,
    'test_on_global_updates': False,
    'test_on_local_updates': False,
    'random_seed': 424242,
}

We create an experiment. We select Federated Averaging as aggregator method and use the tags that we initially used on our dataset.

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['adni-train']
num_rounds = 50

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=MCVAETrainingPlan,
                 training_args=training_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 tensorboard=True
                )
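
For intuition about the aggregation step, the sketch below shows federated averaging on toy parameters, assuming each node's contribution is weighted by its number of training samples. The function and the values are hypothetical and do not reproduce the Fed-BioMed `FedAverage` implementation.

```python
import numpy as np

def federated_average(node_params, node_sample_counts):
    """Average each parameter across nodes, weighted by local sample counts."""
    total = float(sum(node_sample_counts))
    weights = [n / total for n in node_sample_counts]
    return {
        name: sum(w * params[name] for w, params in zip(weights, node_params))
        for name in node_params[0]
    }

# Hypothetical parameters returned by two nodes after one round
node_a = {'W_out.weight': np.array([[1.0, 2.0]]), 'W_out.bias': np.array([0.0])}
node_b = {'W_out.weight': np.array([[3.0, 6.0]]), 'W_out.bias': np.array([4.0])}

aggregated = federated_average([node_a, node_b], node_sample_counts=[300, 100])
print(aggregated)
```

Here the first node holds three times as many samples as the second, so its parameters count for 75% of the aggregate.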


In [None]:
%load_ext tensorboard

from fedbiomed.researcher.config import config
tensorboard_dir = './tensorboard_results'

%tensorboard --logdir "$tensorboard_dir"


In [None]:
exp.run()

We import some additional libraries for plotting.

In [None]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

In [None]:
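# Assumption: after exp.run(), the training plan's model holds the aggregated parameters
aggregated_model = exp.training_plan().model()
decoding_weights_dict = {k: w.detach().cpu().numpy() for k, w in aggregated_model.state_dict().items() if 'W_out.weight' in k}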

We collect the decoding weights of each latent dimension $Z$ for the biomarkers of every channel, then plot them per channel.

In [None]:
lat_dim_names = [f'$Z_{{{i}}}$' for i in range(model_args['lat_dim'])]
col_names = lat_dim_names + ["biomarker"]
weights = pd.DataFrame()

channels = MCVAETrainingPlan.get_channels()

for channel_i, weights_i in enumerate(decoding_weights_dict.values()):

    channel_df = pd.DataFrame(np.concatenate((weights_i, np.array(channels[channel_i]).reshape(-1, 1)), axis=1),
        columns=lat_dim_names + ["biomarker"])
    channel_df['channel'] = channel_i + 1


    weights = pd.concat((weights, channel_df))

weights["$Z_{0}$"] = weights["$Z_{0}$"].astype('float32')
weights.head()

In [None]:
weights_melt = weights.melt(id_vars=['biomarker', 'channel'], var_name='latent_var')
weights_melt.sample()

In [None]:
sns.catplot(data=weights_melt, x='biomarker', y='value', hue='latent_var', kind='bar', col='channel', col_wrap=1, aspect=2.5, sharex=False, palette='Blues_r')
plt.show()

We present two alternative methods for prediction.

The first one is to predict a channel/modality from the whole data.

The second is to predict a channel from a specific channel.

In [None]:
# Predict volumes (channel 0) from cognition (channel 2)

# Solution 1

with torch.no_grad():
  # Encode everything
  q = aggregated_model.encode(data_adni)
  # Take the mean of every encoded distribution q
  z = [qi.loc for qi in q]
  # Decode all
  p = aggregated_model.decode(z)
  # Extract what you need: p[decoder output channel][encoder input channel]
  decoding_volume_from_cognition = p[0][2].loc.data.numpy()

In [None]:
plt.figure(figsize=(12, 28))

for i in range(len(volume_cols)):
    plt.subplot(5,1,i+1)
    plt.scatter(decoding_volume_from_cognition[:,i], data_adni[0][:,i])
    plt.title('reconstruction ' + volume_cols[i])
    plt.xlabel('predicted')
    plt.ylabel('target')
plt.show()

Predict the Volume from Cognition.

In [None]:
# Solution 2

# Encode the cognition (ch 2)
q2 = aggregated_model.vae[2].encode(data_adni[2])
# Take the mean of q (location in pytorch jargon)
z2 = q2.loc
# Decode through the brain volumes decoder (ch 0)
p0 = aggregated_model.vae[0].decode(z2)
# Take the mean
decoding_volume_from_cognition = p0.loc.data.numpy()

In [None]:
plt.figure(figsize=(12, 28))

for i in range(len(volume_cols)):
    plt.subplot(5,1,i+1)
    plt.scatter(decoding_volume_from_cognition[:,i], data_adni[0][:,i])
    plt.title('reconstruction ' + volume_cols[i])
    plt.xlabel('predicted')
    plt.ylabel('target')
plt.show()

Save the model.

In [None]:
torch.save({
    # Save the weights of the aggregated model used in the cells above
    'model_state_dict': aggregated_model.state_dict(),
    'training_args': {
        # num_rounds and model_args are defined outside training_args
        'num_rounds': num_rounds,
        'num_updates': training_args['num_updates'],
        'loader_args': training_args['loader_args'],
        'optimizer_args': training_args['optimizer_args'],
        'model_args': model_args,
    }
}, "model.pth")