# Tutorial LipiMap
## Biologically-Informed VAE

In [None]:
# Imports
import warnings
warnings.simplefilter(action='ignore')
%reload_ext autoreload
%autoreload 2

# import scanpy as sc
import torch
import pandas as pd
import numpy as np
import os

# Optional scanpy figure defaults; kept disabled together with the scanpy import.
# sc.set_figure_params(frameon=False)
# sc.set_figure_params(dpi=200)
# sc.set_figure_params(figsize=(4, 4))
# Tensor printing: 3 digits of precision, no scientific notation, and 7 items
# shown at each end of a dimension when a tensor is summarized.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

### Folder Configuration

In [None]:
# Root names from which every path used in this notebook is derived.
project_name = "lipiMap"
project_path = ".."

# Package sources live in <project_path>/<project_name>.
src_path = os.path.join(project_path, project_name)

# One input folder per data source, plus the folder that holds the .lmt mask files.
LBA_path = os.path.join(src_path, "data/LBA_Data")
LION_path = os.path.join(src_path, "data/LION_Data")
Linex_path = os.path.join(src_path, "data/Linex_Data")
masks_path = os.path.join(src_path, "data/masks")

# Trained models are saved under this folder.
results_path = os.path.join(project_path, "results")

# Create every missing folder. `exist_ok=True` makes the call idempotent and avoids
# the check-then-create race of `os.path.exists` followed by `os.makedirs`.
directories = [src_path, results_path, LBA_path, LION_path, Linex_path, masks_path]
for directory in directories:
    os.makedirs(directory, exist_ok=True)

In [None]:
# Name (without extension) of the dataset file to load from the LBA data folder.
dataset_name = "20241213_LBA_brain2"

# Format the data starts in and the format it is converted to.
# Depending on the dataset, each can be 'log', 'exp' or 'norm_exp'.
initial_format, final_format = "exp", "norm_exp"

## Load MALDI-MSI data
MALDI-MSI data are loaded and processed by the `LBADataHandler` class ([LBADataHandler docs](../lipiMap/data/data_handlers/LBA_README.md)).

In [None]:
# Input Data Loading and Processing
from lipiMap.data_handlers.LBA import LBADataHandler

# Read the raw table for the selected dataset
df = pd.read_hdf(
    os.path.join(LBA_path, f"{dataset_name}.h5ad"),
    key='table',
)

# Hand the table to the LBA handler together with the format conversion to apply
data_handler = LBADataHandler(
    df=df,
    masks_path=masks_path,
    initial_format=initial_format,
    final_format=final_format,
)

# Write the .lmt file describing the lipid families. It is the source of the
# lipid-family binary mask of the latent space, and it is always saved in
# masks_path under the standard name 'LBA_lipid_families.lmt'.
LBA_mask_name = 'lipid_families'
data_handler.create_family_lmt()

# Convert to AnnData; the raw table is not needed afterwards
adata = data_handler.to_anndata()
del df

adata

## Process the Lipid Ontology database
The Lipid Ontology database is processed and integrated with the input data with the `LIONDataHandler` class ([LIONDataHandler docs](../lipiMap/dataset/data_handlers/LION_README.md))

In [None]:
# LION Data
from lipiMap.data_handlers.LION import LIONDataHandler

# The mask name must reflect the categories actually kept when filtering and
# aggregating the LION database: 'full' when all macro_categories and
# micro_categories are included, otherwise e.g. cellularcomponent, function,
# lipidclass, physicochemical, etc.
LION_mask_name = 'full'

LION_data_handler = LIONDataHandler(LION_path, masks_path)
data = LION_data_handler.tabular_LION_database()
cropped_data = LION_data_handler.filter_and_aggregate(data=data, refinement_level=4)

# The .lmt file is written to masks_path as 'LION_<LION_mask_name>.lmt'.
# To create several masks, change LION_mask_name and run this call again.
LION_data_handler.create_ontology_lmt(
    adata.var_names,
    cropped_data,
    save_name=LION_mask_name,
)

## Linear Decoder Mask Generation

The decoder mask is a binary matrix (containing only 0s and 1s) that maps the relationship between lipids and lipid programs.
The mask is generated from `.lmt` (Lipid Matrix Transposed) files, which store lipid program memberships.

For more details about .lmt files, refer to `LBADataHandler.create_family_lmt()`, `LIONDataHandler.create_ontology_lmt()`, and `LinexDataHandler.create_reactions_lmt()`

At the end of the LBA, LION and Linex data pre-processing, the `masks_path` directory will contain multiple `.lmt` files, each independently describing the membership of lipids in a different set of lipid programs (lipid families, categories in the LION database and possible clusters based on biochemical reactions).

The final latent space can be constructed by *concatenating* these files according to the user's needs.

In [None]:
from lipiMap.utils.annotations import add_annotations

# Masks that make up the latent space. Each name must match an .lmt file
# present in the masks_path directory.
masks_included = ['LBA_lipid_families', 'LION_full']  # LBA_mask_name, LION_mask_name

# add_annotations builds a binary mask from the selected masks_included.
# Each single mask and the combined mask are saved in adata.varm, and the names
# of the lipid programs are added to adata.uns.
add_annotations(
    adata=adata,
    masks_path=masks_path,
    masks=masks_included,
    min_lipids=2,
)

adata

Storing both the single binary matrices and the merged one is obviously redundant, but it is done for future scenarios where "dynamical mask-wise training" could possibly be implemented.

Once all the binary masks are stored in the adata object (`adata.varm`), the user selects the actual `mask_framework` that is about to be used to design the latent space.

In [None]:
# 'all' means every mask block is included in the latent space
mask_framework = 'all'

# Lipids belonging to fewer than this many lipid programs are removed
min_lipid_programs = 2
adata._inplace_subset_var(
    adata.varm[mask_framework].sum(1) >= min_lipid_programs
)

# Report the size of the mask after filtering (rows: lipids, columns: lipid programs)
print(f"Number of lipids after filtering: {adata.varm[mask_framework].shape[0]}")
print(f"Number of lipid programs after filtering: {adata.varm[mask_framework].shape[1]}")
print(f"Shape of the binary mask: {adata.varm[mask_framework].shape}")

adata

### Additional processing of the latent space

In [None]:
from lipiMap.utils.annotations import manual_programs_removal

# Based on the biological context, the user can manually select the lipid programs
# to be removed. Names must match the lipid program names exactly.
programs_to_remove = [
    'C16:0',
    'C18:0',
    'C18:1',
    'C18:2',
    'C20:4',
    'fatty acid with 16 carbons',
    'fatty acid with 18 carbons',
    'fatty acid with 20 carbons',
    'fatty acid with 22 carbons',
    'saturated fatty acid',
    # Was 'monounsaturated fatty_acid': the underscore did not follow the naming of
    # the sibling entries, so this program would presumably never have been matched.
    'monounsaturated fatty acid',
    'fatty acid with 2 double bonds',
    'fatty acid with 3 double bonds',
    'fatty acid with 4 double bonds',
    'fatty acid with 5 double bonds',
    'fatty acid with 6 double bonds',
    'membrane component',
]

manual_programs_removal(adata, mask_framework, programs_to_remove)

In [None]:
# sections_to_remove = [6, 10, 13, 20, 25]
# manual_sections_removal(adata, sections_to_remove)

In [None]:
from lipiMap.utils.annotations import remove_latent_collinearity

# With few annotated lipids in the input data, distinct lipid programs (LPs) can end
# up containing exactly the same lipids. Such indistinguishable LPs are merged to
# avoid collinearity in the latent space. Two LPs count as collinear when their
# Hamming distance is zero.
remove_latent_collinearity(adata, mask_framework)

# Report the size of the mask after merging
print(
    f"Number of lipid programs after collinearity removal: {adata.varm[mask_framework].shape[1]}"
)
print(f"Shape of the binary mask: {adata.varm[mask_framework].shape}")

adata

### Representation Assessment Analysis

We assess lipid representation by comparing the annotated lipids in our dataset against the comprehensive LION database. This evaluation is quantified through an **overall importance score**, which is composed of three distinct weighted components:

- `representation_score` (75% weight): Measures the ratio of annotated lipids in a Lipid Program (LP) to the total lipids in that LP within the LION database. This score directly assesses how well each LP is represented in our dataset. Scores stored in `adata.uns['representation_score']`.

- `exclusivity_score` (15% weight): Indicates the significance of an LP in the reconstruction process of its lipids. LPs associated with fewer lipids receive higher scores, emphasizing their unique contributions to lipid reconstruction. Scores stored in `adata.uns['exclusivity_score']`.

- `density_score` (10% weight): Penalizes deviations from the ideal LP density, preset at 30%. This metric ensures that the contributions of LPs to lipid reconstruction are neither too sparse nor too concentrated, promoting an optimal spread across the dataset. Scores stored in `adata.uns['density_score']`.

In [None]:
from lipiMap.utils.annotations import compute_representation_score

# Representation score: ratio of annotated lipids in an LP to the total lipids of
# that LP in the LION database (`data` is the tabular LION database built earlier).
# Scores are stored in adata.uns['representation_score'].
compute_representation_score(adata, LION_data_handler.lipid_programs, data, mask_framework)
# adata.uns['representation_score']

In [None]:
from lipiMap.utils.annotations import compute_exclusivity_score

# Exclusivity score: LPs associated with fewer lipids receive higher scores.
# Scores are stored in adata.uns['exclusivity_score'].
compute_exclusivity_score(adata, mask_framework)
# adata.uns['exclusivity_score']

In [None]:
from lipiMap.utils.annotations import compute_density_score

# Density score: penalizes deviations from the ideal LP density (preset at 30%).
# Scores are stored in adata.uns['density_score'].
compute_density_score(adata, mask_framework)
# adata.uns['density_score']

In [None]:
from lipiMap.utils.annotations import rank_LPs

# Rank the LPs from the scores computed in the previous cells.
# Presumably the combined score ends up in adata.uns['final_score'] (see the
# inspection line below) - TODO confirm against rank_LPs.
rank_LPs(adata)
# adata.uns['final_score']

Although reasonable, these scores are entirely empirical. To come up with a more statistical analysis, we exploit bootstrap replicas and distributions of `representation_scores` to assess whether to keep or filter out LPs in the dataset.

In [None]:
from lipiMap.utils.annotations import representation_analysis

# Compute the distribution of the representation scores after n_replicas bootstrap sampling
# NOTE(review): this call receives the LION handler itself, whereas
# compute_representation_score receives `LION_data_handler.lipid_programs` -
# confirm both match the respective signatures.
distribution = representation_analysis(adata, 
                                       LION_data_handler,
                                       data, 
                                       mask_framework,
                                       n_replicas=10000, # ~ 33 mins
                                       seed=0)  # fixed seed so the bootstrap is reproducible

In [None]:
from lipiMap.utils.annotations import representation_filtering

# Select the lipid programs to keep based on the distribution of the representation scores
# Selection criterion: the representation score should be higher than the median
# of the corresponding distribution (ub=50, presumably a percentile - TODO confirm).
# NOTE(review): the role of lb=40 is not visible from this notebook; verify against
# representation_filtering.
to_keep = representation_filtering(adata, 
                                   mask_framework, 
                                   distribution,
                                   lb=40,
                                   ub=50,
                                   plot=True)

In [None]:
# Drop every lipid program that the representation filtering did not retain
to_remove = list(
    filter(lambda lp_name: lp_name not in to_keep, adata.uns[mask_framework])
)
manual_programs_removal(adata, mask_framework, to_remove)

### Enrichment Scores
Once the latent space is complete, the user can add LPs enrichment scores:
```
lp_enrichment = mean((X_lp - mean(X_lp)) / std(X_lp))
```

where `X_lp` represents the expression values of lipids associated with the lipid program, `mean(X_lp)` is the mean expression value of those lipids, and `std(X_lp)` is the standard deviation of those expression values.

In [None]:
from lipiMap.utils.annotations import add_lipid_program_enrichment

# Add the enrichment score of each lipid program of the selected mask to adata
add_lipid_program_enrichment(adata, mask_framework)

In [None]:
# Final summary of the designed latent space: the AnnData object and the shape
# of the binary mask that will be used for training.
print('--- LATENT SPACE DESIGN DONE!! Ready to train ---')
print(adata)
print(adata.varm[mask_framework].shape)

## Training LipiMap

In [None]:
# Define the architecture parameters

# Latent space definition
mask_key = mask_framework                       # Binary mask used to design the latent space
soft_mask = False                               # Allow soft membership of lipids to lipid programs
condition_key = None                            # In case of conditional VAE

# Encoder layout (the decoder is linear)
hidden_layer_sizes = [256, 256, 256, 128]
use_bn, use_ln = False, True                    # Batch normalization off, layer normalization on

# Portion of the data used
percentage = 1                                  # Full brain

In [None]:
# Define the training parameters

# Weighted MSE loss parameters
k, q = 1, 0.1

# Optimization schedule
n_epochs = 500                                  # Number of epochs
batch_size = 256                                # Batch size
initial_lr = 0.01                               # Initial learning rate
weight_decay = 0.0                              # Weight decay
dr_rate = 0.01                                  # Dropout rate

# Loss term weights
alpha_kl = 0.7                                  # KL divergence weight
alpha_gl = 0.1                                  # Group Lasso weight (see note below)
alpha_l1 = None                                 # L1 regularization weight (in case of soft membership)

# Decoder handling
initialization = 'pca'                          # Initialization of the decoder weights
clipping_decoder_weights = True                 # Clip the decoder weights between 0 and 1 to ensure interpretability

# Note on `alpha_gl`: it regulates the strength of the group lasso regularization of LPs.
# A higher value means that a larger number of LPs will be deactivated during training,
# depending on their contribution to the reconstruction loss.
# Group Lasso regularization is meaningful when the number of LPs is large
# (not the case of lipiMap).

In [None]:
# Set up the model to prepare for training
# All arguments come from the architecture/training parameter cells above, except
# the reconstruction loss which is fixed to 'mse' here.

from lipiMap.models import LIPIMAP

intr_cvae = LIPIMAP(
    adata=adata,                                        # Anndata object (input data)
    lipids_format=final_format,                         # Final format of the lipidomics data (for saving and loading purposes)
    mask_key=mask_key,                                  # Mask key (name of the binary mask)
    soft_mask=soft_mask,                                # Soft mask
    condition_key=condition_key,                        # Condition key (in case of conditional VAE)
    hidden_layer_sizes=hidden_layer_sizes,              # Hidden layer sizes
    use_bn=use_bn,                                      # Batch normalization
    use_ln=use_ln,                                      # Layer normalization

    recon_loss='mse',                                   # Reconstruction loss
    dr_rate=dr_rate,                                    # Dropout rate

    clipping_decoder_weights=clipping_decoder_weights,  # Clip the decoder weights between 0 and 1
)

In [None]:
# Early stopping configuration, used to prevent overfitting.
# The early stopping class is documented in ./lipiMap/utils/monitor.py

early_stopping_kwargs = dict(
    early_stopping_metric="val_unweighted_loss",  # val_unweighted_loss
    threshold=0,
    patience=100,
    reduce_lr=True,
    lr_patience=13,
    lr_factor=0.1,
)

In [None]:
# Train the model
# Hyperparameters are taken from the training parameter cell; the commented-out
# keyword arguments are optional settings left at their defaults.

intr_cvae.train(
    q = q,                          # Weighted MSE loss parameter
    k = k,                          # Weighted MSE loss parameter
    n_epochs=n_epochs,
    alpha_kl=alpha_kl,              # KL divergence weight
    alpha_gl=alpha_gl,              # Group Lasso weight
    alpha_epoch_anneal=None,
    # alpha_l1=0.,

    lr=initial_lr,
    
    # train_frac=0.9, 
    batch_size=batch_size,
    weight_decay=weight_decay, 
    use_early_stopping=True,
    early_stopping_kwargs=early_stopping_kwargs,

    monitor_only_val=False,
    seed=2024,                      # fixed seed for reproducibility
    save_logs=True,
    # print_stats=True,

    initialization=initialization,  # Initialization of the decoder weights
)

In [None]:
# Save the model
# The directory name is produced by LIPIMAP.create_model_directory from the dataset
# name and the architecture/training hyperparameters. The load cell must pass the
# same values to find the model again.
save_path = LIPIMAP.create_model_directory(
    dataset=dataset_name,
    lipids_format=final_format,
    mask_key=mask_key,
    soft_mask=soft_mask,
    condition_key=condition_key,
    hidden_layer_sizes=hidden_layer_sizes,
    use_bn=use_bn,

    percentage=percentage,
    n_epochs=n_epochs,
    batch_size=batch_size,
    dr_rate=dr_rate,
    initial_lr=initial_lr,
    weight_decay=weight_decay,
    alpha_kl=alpha_kl,
    alpha_gl=alpha_gl,
    alpha_l1=alpha_l1,
    initialization=initialization,
    clipping_decoder_weights=clipping_decoder_weights,
)

# add manually additional labels to the folder if needed
# save_path = save_path + '_additional_labels'
save_path = os.path.join(results_path, save_path)

# `exist_ok=True` makes directory creation idempotent and avoids the
# check-then-create race of `os.path.exists` followed by `os.makedirs`.
os.makedirs(save_path, exist_ok=True)

intr_cvae.save(save_path, overwrite=True)
print(f"Model saved in {save_path}")

In [None]:
# Load the model
# The directory name is rebuilt with the same arguments used when saving, so the
# parameter cells must hold the values of the run to be loaded.
from lipiMap.models import LIPIMAP

load_path = LIPIMAP.create_model_directory(
                    dataset=dataset_name,
                    lipids_format=final_format,
                    mask_key=mask_framework,
                    soft_mask=soft_mask,
                    condition_key=condition_key,
                    hidden_layer_sizes=hidden_layer_sizes,
                    use_bn=use_bn,
                    
                    percentage=percentage, 
                    n_epochs=n_epochs,
                    batch_size=batch_size,
                    dr_rate=dr_rate,
                    initial_lr=initial_lr,
                    weight_decay=weight_decay,
                    alpha_kl=alpha_kl,
                    alpha_gl=alpha_gl,
                    alpha_l1=alpha_l1,
                    initialization=initialization,
                    clipping_decoder_weights=clipping_decoder_weights,)

path = os.path.join(results_path, load_path)
# Alternatively, point `path` directly at an existing model directory, e.g.:
# path = '../results/norm_exp_lipids/mask_all/hard/no/encoder_l_256_256_256_128/
#           layer_norm/LBA_Brain2_20240611_100percent_epochs500_bs256_dr0.01_
#               lr0.01_wd0.0_kl0.7_gl0.1_l1None_initpca_clipTrue'
model = LIPIMAP.load(dir_path=path, adata=adata)

## Evaluating LipiMap

In [None]:
from lipiMap.plotting.lipimap_eval import LIPIMAP_EVAL
# Evaluation/plotting helper built on the loaded model and the AnnData attached to it.
# NOTE(review): confirm that model.adata is the intended AnnData object to evaluate.
lipimap_eval = LIPIMAP_EVAL(model=model, 
                         adata=model.adata) # adata = model.adata to be sure? check

### Input Data Visualization

In [None]:
# Plot the lipid 'LPC 18:1' in the input space; savepath=None means no file path is given.
lipimap_eval.plot_spatial(name='LPC 18:1',
                       space='input',
                       cmap='inferno',
                       sym_colorscale=False,
                       savepath=None)
# Optional: export the input space to a PDF file
# lipimap_eval.to_pdf(space='input',
#                  savepath='input.pdf',
#                  cmap='inferno',
#                  sym_colorscale=False)

### Output Data Visualization (Reconstruction and Residuals)

In [None]:
# Plot the same lipid in the output (reconstruction) space, with the same colormap
# settings as the input plot.
lipimap_eval.plot_spatial(name='LPC 18:1',
                       space='output',
                       cmap='inferno',
                       sym_colorscale=False,
                       savepath=None)
# Optional: export the output space to a PDF file
# lipimap_eval.to_pdf(space='output',
#                  savepath='output.pdf',
#                  cmap='inferno',
#                  sym_colorscale=False)

In [None]:
# Plot the residuals of the same lipid; a diverging colormap with a symmetric
# color scale is requested here.
lipimap_eval.plot_spatial(name='LPC 18:1',
                       space='residual',
                       cmap='RdBu',
                       sym_colorscale=True,
                       savepath=None)
# Optional: export the residual space to a PDF file
# lipimap_eval.to_pdf(space='residual',
#                  savepath='residual.pdf',
#                  cmap='RdBu',
#                  sym_colorscale=True)

### Latent Space Visualization

In [None]:
# Plot the lipid program 'mitochondrion' in the latent space.
# NOTE(review): 'PuOr' is a diverging colormap but sym_colorscale is False here
# (the residual plot pairs its diverging colormap with sym_colorscale=True) -
# confirm this is intended.
lipimap_eval.plot_spatial(name='mitochondrion',
                       space='latent',
                       cmap='PuOr',
                       sym_colorscale=False,
                       savepath=None)
# Optional: export the latent space to a PDF file
# lipimap_eval.to_pdf(space='latent',
#                  savepath='latent.pdf',
#                  cmap='PuOr',
#                  sym_colorscale=False)

### Decoder Weights

In [None]:
# Plot the weights of the (linear) decoder with the evaluator's default settings
lipimap_eval.plot_decoder_weights()

In [None]:
# Plot the lipid programs against the 'level_6' feature.
# presumably 'level_6' is an annotation available in the AnnData object - TODO confirm
lipimap_eval.plot_LPs_vs_feature(feature='level_6')