# Basic Example: Train a MOFA model

This notebook is a good introduction to using PRISMO: we train a MOFA model [1] with additional sparsity priors. We use a chronic lymphocytic leukaemia (CLL) data set that combines ex vivo drug response measurements with somatic mutation status, transcriptome profiling and DNA methylation assays [2].

We reproduce results from the paper [1] and explain in detail how you can extend the model and generate meaningful interpretations of the results.

[1] Argelaguet, R. et al. (2018). Multi-Omics Factor Analysis: a framework for unsupervised integration of multi-omics data sets.  
[2] Dietrich, S. et al. (2018). Drug-perturbation-based stratification of blood cancer.

```python
%load_ext autoreload
%autoreload 2
```

```python
import scanpy as sc
from data_loader import load_CLL

from famo.core import CORE
from famo.plotting import (
    plot_factor,
    plot_factor_correlation,
    plot_top_weights,
    plot_training_curve,
    plot_variance_explained,
    plot_weights,
)
```

## Load and Preprocess Data

```python
# Load CLL data
mdata = load_CLL()

# Normalize each sample to 10,000 total counts, then log-transform the mRNA view
sc.pp.normalize_total(mdata["mrna"], target_sum=1e4)
sc.pp.log1p(mdata["mrna"])
```
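For readers unfamiliar with scanpy, the two preprocessing calls can be sketched in plain NumPy: `normalize_total` rescales each sample so that its counts sum to `target_sum`, and `log1p` applies the natural logarithm of one plus each value. This is a simplified illustration on a dense toy matrix; the real scanpy functions also handle sparse matrices, AnnData objects and extra options such as excluding highly expressed genes.

```python
import numpy as np


def normalize_and_log1p(counts: np.ndarray, target_sum: float = 1e4) -> np.ndarray:
    """Scale each row (sample) to sum to `target_sum`, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Avoid division by zero for samples without any counts
    row_sums[row_sums == 0] = 1.0
    normalized = counts / row_sums * target_sum
    return np.log1p(normalized)


toy_counts = np.array([[1, 3], [2, 2]])
print(normalize_and_log1p(toy_counts, target_sum=4))  # [[log(2), log(4)], [log(3), log(3)]]
```

Normalizing first removes differences in sequencing depth between samples; the log transform then dampens the influence of highly expressed genes, which suits the Gaussian noise model used below.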


## Example 1: Train a Factor Analysis Model

In the first experiment, we use all four views and train a factor analysis model. Training a model consists of two steps:

1) Create a new model instance: `model = CORE()`
2) Train the model: `model.fit(data=mdata, n_factors=15)`

To simplify your life, you only need to specify the number of factors and provide the training data. In Example 2, we will show you how to deviate from the default parameters and customize your model. For simplicity, we model all views with a Gaussian noise model (`likelihoods="Normal"`).
In contrast to the experiments in the original MOFA paper, we place a Horseshoe sparsity prior [3] on the weights because of its better computational properties compared to the Spike-and-Slab prior.

[3] Carvalho, C. M., Polson, N. G. and Scott, J. G. (2009). Handling sparsity via the horseshoe. Artificial Intelligence and Statistics, PMLR.
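To build intuition for why the Horseshoe encourages sparsity, the sketch below samples weights from the Horseshoe hierarchy of [3]: each weight is Gaussian with standard deviation $\lambda_j \tau$, where the local scales $\lambda_j$ and the global scale $\tau$ follow half-Cauchy distributions. The global scale `tau0`, the relative threshold `rel` and the helper names are illustrative assumptions; PRISMO's exact parameterization of the prior (for example, separate scales per view or per factor) may differ.

```python
import numpy as np


def sample_horseshoe_weights(n_weights: int, tau0: float = 0.1, seed: int = 0) -> np.ndarray:
    """Draw weights from a Horseshoe prior (illustrative parameterization)."""
    rng = np.random.default_rng(seed)
    # Global shrinkage scale shared by all weights: tau ~ HalfCauchy(0, tau0)
    tau = tau0 * abs(rng.standard_cauchy())
    # One local scale per weight: lambda_j ~ HalfCauchy(0, 1)
    lam = np.abs(rng.standard_cauchy(size=n_weights))
    # Conditionally Gaussian weights: w_j ~ Normal(0, (lambda_j * tau)^2)
    return rng.normal(loc=0.0, scale=lam * tau)


def fraction_small(w: np.ndarray, rel: float = 0.01) -> float:
    """Fraction of weights whose magnitude is below `rel` times the largest magnitude."""
    magnitudes = np.abs(w)
    return float(np.mean(magnitudes < rel * magnitudes.max()))


horseshoe = sample_horseshoe_weights(10_000)
gaussian = np.random.default_rng(1).normal(size=10_000)
print(f"Horseshoe: {fraction_small(horseshoe):.2f} of weights are tiny")
print(f"Gaussian:  {fraction_small(gaussian):.2f} of weights are tiny")
```

Under the Horseshoe, most weights are shrunk almost to zero while a few escape shrinkage thanks to the heavy-tailed local scales, which is the behavior that makes the learned weights sparse and easier to interpret.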

```python
model = CORE(device="cuda")    # Run computations on the GPU
model.fit(
    data=mdata,                # Our training data
    n_factors=20,              # Number of factors
    likelihoods="Normal",      # Likelihood for all views
    factor_prior="Normal",     # Normal prior for the factors
    weight_prior="Horseshoe",  # Sparsity prior for the weights
    lr=0.005,                  # Learning rate
    early_stopper_patience=1000,  # Patience for early stopping
)
```

```text
2024-10-15 23:27:30 | famo.core | Setting up device...
2024-10-15 23:27:31 | famo.core | - No device id given. Using default device: 0
2024-10-15 23:27:31 | famo.core | - Running all computations on `cuda:0`
2024-10-15 23:27:31 | famo.core | - Using provided likelihood for all views.
2024-10-15 23:27:31 | famo.core |   - drugs: Normal
2024-10-15 23:27:31 | famo.core |   - methylation: Normal
2024-10-15 23:27:31 | famo.core |   - mrna: Normal
2024-10-15 23:27:31 | famo.core |   - mutations: Normal

2024-10-15 23:27:31 | famo.core | Initializing factors using `random` method...
2024-10-15 23:27:32 | famo.core | Decaying learning rate over 10000 iterations.
2024-10-15 23:27:32 | famo.core | Setting training seed to `2410152327`.
2024-10-15 23:27:32 | famo.core | Cleaning parameter store.
2024-10-15 23:27:33 | root | Guessed max_plate_nesting = 3
2024-10-15 23:27:33 | famo.core | Epoch:       0 | Time:       1.42s | Loss:   13775.04
2024-10-15 23:27:38 | famo.core | Epoch:     100 | Time:       6.02s | Loss:   10433.65
2024-10-15 23:27:42 | famo.core | Epoch:     200 | Time:      10.45s | Loss:    8912.48
2024-10-15 23:27:47 | famo.core | Epoch:     300 | Time:      14.98s | Loss:    8109.54
2024-10-15 23:27:51 | famo.core | Epoch:     400 | Time:      19.41s | Loss:    7595.90
2024-10-15 23:27:56 | famo.core | Epoch:     500 | Time:      23.95s | Loss:    7267.36
2024-10-15 23:28:00 | famo.core | Epoch:     600 | Time:      28.37s | Loss:    7067.02
2024-10-15 23:28:05 | famo.core
```

Looking at the logs, we can see the following:
- for computational efficiency, the model is trained on a GPU (`cuda:0`), as requested with `device="cuda"`
- by default, we get an overview of the missing samples/features in the training data (Missing Data Overview)
- in this run, the factors are initialized with the `random` method, and the learning rate is decayed over 10,000 iterations
- a training seed is set and logged, so the run can be repeated
- the model is automatically stored in the current working directory.
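A missing data overview like the one mentioned above can be approximated with a few lines of pandas. The sketch assumes, for illustration only, that each view is a DataFrame indexed by sample ID (in the notebook the views live in `mdata`, and PRISMO computes its own overview); `missing_data_overview` is a hypothetical helper that reports, per view, how many samples are absent entirely and the fraction of missing entries among the samples that are present.

```python
import numpy as np
import pandas as pd


def missing_data_overview(views: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Per-view summary of absent samples and missing values (hypothetical helper)."""
    # Union of sample IDs over all views
    all_samples = set().union(*(df.index for df in views.values()))
    rows = []
    for name, df in views.items():
        rows.append({
            "view": name,
            "n_features": df.shape[1],
            # Samples measured in some other view but absent from this one
            "n_missing_samples": len(all_samples.difference(df.index)),
            # Fraction of NaN entries among the samples that are present
            "frac_missing_values": float(df.isna().to_numpy().mean()),
        })
    return pd.DataFrame(rows).set_index("view")


views = {
    "drugs": pd.DataFrame({"d1": [0.1, np.nan, 0.3]}, index=["s1", "s2", "s3"]),
    "mutations": pd.DataFrame({"m1": [1.0, 0.0]}, index=["s1", "s3"]),
}
print(missing_data_overview(views))
```

Distinguishing entirely absent samples from scattered missing values is useful because factor models such as MOFA can train on incomplete views without imputing them first.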

## Plot Results

After training, we can inspect all learned parameters, visualize the factors/weights, etc.


### ELBO

If the learning curve (the training loss over epochs) flattens out, the model has likely converged.

```python
plot_training_curve(model)
```
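Whether the curve has flattened can also be checked numerically. The sketch below uses the loss values printed in the training log (epochs 0 to 600) and a hypothetical criterion: training counts as converged once the relative loss improvement between consecutive checkpoints drops below a tolerance. Both the tolerance and the helper names are illustrative assumptions, not PRISMO's internal stopping rule.

```python
def relative_improvements(losses: list[float]) -> list[float]:
    """Relative decrease of the loss between consecutive checkpoints."""
    return [(prev - curr) / abs(prev) for prev, curr in zip(losses, losses[1:])]


def has_converged(losses: list[float], tol: float = 1e-3) -> bool:
    """Hypothetical criterion: the most recent relative improvement is below `tol`."""
    if len(losses) < 2:
        return False
    return relative_improvements(losses)[-1] < tol


# Loss every 100 epochs, copied from the training log (epochs 0-600)
logged_losses = [13775.04, 10433.65, 8912.48, 8109.54, 7595.90, 7267.36, 7067.02]
print([round(r, 3) for r in relative_improvements(logged_losses)])
print(has_converged(logged_losses))  # False: the loss still drops by about 2.8% per 100 epochs
```

Measuring relative rather than absolute improvement makes the check independent of the scale of the loss, which depends on the number of samples and features.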

### Extract factors and weights

To perform downstream analysis, we need to extract the factors and the weights (loadings) from the model. We can do this by calling the `get_factors()` and `get_weights()` methods, which return dictionaries with the factors for each group and the weights for each view (modality), respectively.


```python
factors = model.get_factors()
weights = model.get_weights()
```
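Under the factor analysis model, each view is approximated by the product of the factor matrix and that view's weight matrix. The sketch below assumes, for illustration, that the factors of a group form an array of shape `(n_samples, n_factors)` and the weights of a view an array of shape `(n_factors, n_features)`; the group name `group_1` and the random arrays are hypothetical stand-ins, so check the actual orientation and container type returned by `get_factors()` and `get_weights()` before applying this to real output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_factors = 50, 3

# Hypothetical stand-ins for model.get_factors() and model.get_weights()
factors = {"group_1": rng.normal(size=(n_samples, n_factors))}
weights = {
    "drugs": rng.normal(size=(n_factors, 10)),
    "mrna": rng.normal(size=(n_factors, 200)),
}


def reconstruct(z: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Model-implied, noise-free data for one view: Y_hat = Z @ W."""
    return z @ w


reconstructions = {view: reconstruct(factors["group_1"], w) for view, w in weights.items()}
print({view: y.shape for view, y in reconstructions.items()})  # {'drugs': (50, 10), 'mrna': (50, 200)}
```

All views share the same factors, and only the weights are view-specific; this is what lets a single factor explain variation across several omics layers at once.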

We also provide functions to plot the weights directly, such as `plot_weights` and `plot_top_weights`.

### Factor Correlation Matrix
The correlation matrix shows that the factors are mostly uncorrelated, which suggests that they capture distinct sources of variation and that training worked well.

```python
plot_factor_correlation(model)
```
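The same check can be done numerically: compute the Pearson correlation between all pairs of factors and inspect the largest off-diagonal value. The sketch uses simulated factors instead of the model's output, and treating "mostly uncorrelated" as a small maximum absolute correlation is an illustrative assumption rather than a formal criterion.

```python
import numpy as np


def max_abs_factor_correlation(z: np.ndarray) -> float:
    """Largest absolute Pearson correlation between two distinct factors (columns of z)."""
    corr = np.corrcoef(z, rowvar=False)
    off_diagonal = corr[~np.eye(corr.shape[0], dtype=bool)]
    return float(np.max(np.abs(off_diagonal)))


rng = np.random.default_rng(0)
independent = rng.normal(size=(1000, 5))
# Turn the second column into a noisy copy of the first to create a redundant pair
redundant = independent.copy()
redundant[:, 1] = redundant[:, 0] + 0.1 * rng.normal(size=1000)

print(max_abs_factor_correlation(independent))  # close to 0
print(max_abs_factor_correlation(redundant))    # close to 1
```

A strongly correlated pair would suggest that two factors describe the same signal, for example because too many factors were requested.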

### Variance Explained Plot
The variance explained plot shows which factors explain variation in which views, and thus which factors might merit a more detailed analysis.

Here, factor 1 explains variation in all views, whereas factors 2 and 3 mainly drive variation in the drug view.

```python
plot_variance_explained(model)
```
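MOFA [1] defines the variance explained by factor $k$ in view $m$ as a coefficient of determination on centered data, $R^2_{m,k} = 1 - \sum_{n,d} (y^m_{nd} - z_{nk} w^m_{kd})^2 / \sum_{n,d} (y^m_{nd})^2$. The sketch below implements this formula for a single view on toy arrays, assuming factors of shape `(n_samples, n_factors)` and weights of shape `(n_factors, n_features)`; it ignores missing values, and PRISMO's `plot_variance_explained` may compute the quantity differently.

```python
import numpy as np


def variance_explained_per_factor(y: np.ndarray, z: np.ndarray, w: np.ndarray) -> np.ndarray:
    """R^2 of each factor for one view.

    y: centered data (n_samples, n_features); z: factors (n_samples, n_factors);
    w: weights (n_factors, n_features). Missing values are not handled here.
    """
    total = np.sum(y**2)
    r2 = np.empty(z.shape[1])
    for k in range(z.shape[1]):
        # Residual after subtracting the contribution of factor k alone
        residual = y - np.outer(z[:, k], w[k])
        r2[k] = 1.0 - np.sum(residual**2) / total
    return r2


# Toy example: two orthogonal, zero-mean factors and noise-free data
z = np.array([[1, 1], [-1, 1], [1, -1], [-1, -1]], dtype=float)
w = np.array([[2.0, 0.0], [0.0, 1.0]])
y = z @ w
print(variance_explained_per_factor(y, z, w))  # approximately [0.8, 0.2]
```

Because the two toy factors are orthogonal, their individual $R^2$ values add up to the total variance explained; with correlated factors this additivity no longer holds exactly.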

### Factor Values

If we are interested in the distribution of a specific factor's values across all samples, `plot_factor` gives us an intuitive visualisation. We can see that the factor values are rarely close to zero.

```python
plot_factor(model, factor=1)
```
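The observation that factor values are rarely close to zero can be quantified as the fraction of samples whose absolute factor value falls below a small threshold. The threshold `eps`, the helper name and the toy values are illustrative assumptions; in practice you would apply this to one factor taken from `model.get_factors()`.

```python
import numpy as np


def fraction_near_zero(values: np.ndarray, eps: float = 0.1) -> float:
    """Fraction of entries whose absolute value is below `eps`."""
    values = np.asarray(values, dtype=float)
    return float(np.mean(np.abs(values) < eps))


toy_factor = np.array([-2.0, -0.05, 0.01, 1.5, 3.0])
print(fraction_near_zero(toy_factor))  # 0.4
```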

### Top weights per factor

The `plot_weights` function shows the weights of a specific view for one factor and labels the top-n values (by magnitude).
Here, the IGHV mutation has a large weight, indicating that patients with a high value on factor 1 are more likely to carry this mutation. This reading assumes a positive weight; note that the signs of a factor and its weights are only defined jointly, since flipping both leaves the model unchanged.

```python
plot_weights(model, "mutations", factor=1)
```
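Labeling the top-n features by magnitude amounts to sorting the absolute weights. The sketch below does this for one factor with made-up feature names and weights (the helper `top_features` is hypothetical); it keeps the signed weight, because the direction of an association matters for interpretation.

```python
import numpy as np


def top_features(weights: np.ndarray, names: list[str], n: int = 3) -> list[tuple[str, float]]:
    """Return the n features with the largest absolute weight, with their signed weights."""
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(-np.abs(weights))[:n]
    return [(names[i], float(weights[i])) for i in order]


names = ["gene_a", "gene_b", "gene_c", "gene_d"]
w_factor = np.array([0.1, -3.0, 2.0, 0.5])
print(top_features(w_factor, names, n=2))  # [('gene_b', -3.0), ('gene_c', 2.0)]
```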

`plot_top_weights` provides a similar visualization but allows combining multiple factors in a single plot.

```python
plot_top_weights(model, views=["mutations"], factors=[1, 2, 3])
```
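Combining several factors in one plot essentially requires a long-format table with one row per (factor, feature) pair among each factor's top features. The sketch below builds such a table with pandas from a hypothetical weight matrix with factors as rows and features as columns; this orientation, the labels and the helper `top_weights_long` are assumptions for illustration and need not match the internals of `plot_top_weights`.

```python
import pandas as pd


def top_weights_long(weights: pd.DataFrame, factors: list, n: int = 2) -> pd.DataFrame:
    """Long table of the n largest-magnitude weights for each requested factor."""
    frames = []
    for factor in factors:
        row = weights.loc[factor]
        # Select the n features with the largest absolute weight, keeping their sign
        top = row.loc[row.abs().nlargest(n).index]
        frames.append(pd.DataFrame({"factor": factor, "feature": top.index, "weight": top.values}))
    return pd.concat(frames, ignore_index=True)


toy_weights = pd.DataFrame(
    [[0.1, -3.0, 2.0], [1.5, 0.2, -0.4], [0.0, 0.3, -2.5]],
    index=["Factor 1", "Factor 2", "Factor 3"],
    columns=["mut_a", "mut_b", "mut_c"],
)
print(top_weights_long(toy_weights, ["Factor 1", "Factor 2"]))
```

A long table like this maps directly onto a faceted or grouped bar chart, with one facet or color per factor.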

### Gene Set Enrichment Analysis

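Finally, we can use the mRNA weights to run a gene set enrichment analysis, which helps identify potential underlying gene programs. Here, we use the Reactome gene sets from the MSigDB `c2.cp.reactome` collection and test positive and negative weights separately.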

```python
import altair as alt
import numpy as np
import pandas as pd

from famo import feature_sets as fs
from famo.utils_downstream import test
```

```python
gene_sets = fs.from_gmt("c2.cp.reactome.v2023.2.Hs.symbols.gmt")
# Filter gene sets by their overlap with the mRNA features (min_fraction / min_count thresholds)
gene_sets = gene_sets.filter(
    model.feature_names["mrna"], min_fraction=0.1, min_count=15
)
gene_set_mask = gene_sets.to_mask(model.feature_names["mrna"])

# Test for enrichment separately among positive and negative weights
test_results_pos = test(model, "mrna", feature_sets=gene_set_mask, sign="pos")
test_results_neg = test(model, "mrna", feature_sets=gene_set_mask, sign="neg")

# Adjusted p-values (one row per factor); only the positive-weight results are plotted below
test_df = test_results_pos["p_adj"].copy()
alpha = 0.05
threshold_line = -np.log10(alpha)

for k in range(model.n_factors):
    p_adj = test_df.iloc[k]
    # Skip factors without any significantly enriched gene set
    if not (p_adj < alpha).any():
        continue

    # Compute -log10(adjusted p-value), sort, and keep the 10 most significant gene sets
    top_10_values = (-np.log10(p_adj)).sort_values(ascending=False).head(10)
    top_10_values.index = top_10_values.index.str.replace("REACTOME_", "", regex=False)

    data = pd.DataFrame({
        "gene_set": top_10_values.index,
        "neg_log10_p": top_10_values.values,
    })

    bar_chart = alt.Chart(data).mark_bar().encode(
        y=alt.Y("gene_set:N", sort="-x", axis=alt.Axis(labelLimit=600, title=None)),
        x=alt.X("neg_log10_p:Q", title="-log10(adjusted p-value)"),
    )

    # Red vertical line at the significance threshold
    line = alt.Chart(pd.DataFrame({"x": [threshold_line]})).mark_rule(color="red", strokeWidth=2).encode(
        x="x:Q",
    )

    (bar_chart + line).properties(
        title=f"Top 10 gene sets ({test_df.index[k]})",
        width=300,
        height=300,
    ).configure_view(
        strokeWidth=0,
    ).configure_axis(
        labelFontSize=12,
        titleFontSize=14,
    ).display()
```

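To make the idea of an enrichment test concrete, the sketch below implements one common variant: a one-sided hypergeometric (over-representation) test of whether a gene set is over-represented among a factor's top-weighted genes, followed by Benjamini-Hochberg adjustment across gene sets. This is a swapped-in, self-contained illustration and not necessarily the statistic used by `famo.utils_downstream.test`; the gene set names and counts are made up.

```python
from math import comb


def hypergeom_sf(k: int, total: int, in_set: int, drawn: int) -> float:
    """P(X >= k) for X ~ Hypergeometric(total genes, genes in set, genes drawn)."""
    upper = min(in_set, drawn)
    hits = sum(comb(in_set, i) * comb(total - in_set, drawn - i) for i in range(k, upper + 1))
    return hits / comb(total, drawn)


def benjamini_hochberg(p_values: list[float]) -> list[float]:
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotone adjusted values
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


# Hypothetical numbers: 1000 genes, the top 50 genes of one factor, three gene sets
overlaps = {"SET_A": (12, 40), "SET_B": (3, 60), "SET_C": (2, 30)}  # (overlap with top genes, set size)
p_raw = [hypergeom_sf(k, total=1000, in_set=size, drawn=50) for k, size in overlaps.values()]
p_adj = benjamini_hochberg(p_raw)
for name, p in zip(overlaps, p_adj):
    print(f"{name}: adjusted p = {p:.3g}")
```

Adjusting for multiple testing matters here because hundreds of gene sets are tested per factor; without it, several sets would pass the 0.05 threshold by chance alone.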
