# Basic Example: Train a MOFA model

This notebook introduces PRISMO by training a MOFA model [1] with additional sparsity priors. We use a chronic lymphocytic leukaemia (CLL) data set that combines ex vivo drug response measurements with somatic mutation status, transcriptome profiling, and DNA methylation assays [2].

We reproduce results from the MOFA paper [1] and explain in detail how to extend the model and interpret the results.

[1] Argelaguet, R. et al. "Multi-Omics Factor Analysis: a framework for unsupervised integration of multi-omics data sets." 2018.  
[2] Dietrich et al. "Drug-perturbation-based stratification of blood cancer." 2018.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import scanpy as sc
from data_loader import load_CLL

from famo.core import CORE
from famo.plotting import (
    plot_factor,
    plot_factor_correlation,
    plot_top_weights,
    plot_training_curve,
    plot_variance_explained,
    plot_weights,
)

## Load and Preprocess Data

In [3]:
# Load CLL data
mdata = load_CLL()

# Normalize and log transform mRNA counts
sc.pp.normalize_total(mdata["mrna"], target_sum=1e4)
sc.pp.log1p(mdata["mrna"])
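
The two Scanpy calls above rescale every sample to the same total count and then apply a log transform. As a rough illustration of that arithmetic, here is a minimal NumPy sketch on a made-up count matrix (not the CLL data); the variable names are chosen for this example only.

```python
import numpy as np

# Toy count matrix: 3 samples (rows) x 4 genes (columns); values are made up.
counts = np.array(
    [
        [10.0, 0.0, 30.0, 60.0],
        [1.0, 1.0, 1.0, 1.0],
        [0.0, 5.0, 0.0, 15.0],
    ]
)

target_sum = 1e4

# Step 1: rescale every sample so that its counts add up to `target_sum`.
library_sizes = counts.sum(axis=1, keepdims=True)
normalized = counts / library_sizes * target_sum

# Step 2: log(1 + x) compresses the dynamic range and keeps zeros at zero.
log_normalized = np.log1p(normalized)

print(normalized.sum(axis=1))
print(log_normalized.round(2))
```

After the first step, samples with very different sequencing depths become comparable; the second step reduces the influence of a few highly expressed genes.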


## Example 1: Train a Factor Analysis Model

In the first experiment, we will use all four views and train a factor analysis model. Each model training consists of two necessary steps:

1) Create a new model instance: `model = CORE()`
2) Train the model: `model.fit(data=mdata, n_factors=15)`

To keep things simple, only the training data and the number of factors are required. In Example 2, we will show you how to deviate from the default parameters and customize your model. Here, we model all views with a Gaussian noise model.
In contrast to the experiments shown in the original MOFA paper, we place a horseshoe sparsity prior [3] on the weights because of its superior computational properties compared with the spike-and-slab prior.

[3] Carvalho, Carlos M., Nicholas G. Polson, and James G. Scott. "Handling sparsity via the horseshoe." Artificial intelligence and statistics. PMLR, 2009.
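
To build some intuition for the horseshoe prior [3], the sketch below draws weights from its standard hierarchical form: a half-Cauchy local scale per weight, multiplied by a global scale. This is a plain NumPy illustration, not PRISMO's implementation, and the global scale of `0.1` is an arbitrary value picked for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features = 1000
global_scale = 0.1  # tau; illustrative value, not a PRISMO default

# Local scales lambda_j ~ HalfCauchy(0, 1): heavy-tailed, so a few are very large.
local_scales = np.abs(rng.standard_cauchy(n_features))

# Weights w_j ~ Normal(0, (tau * lambda_j)^2).
weights = rng.normal(loc=0.0, scale=global_scale * local_scales)

abs_weights = np.abs(weights)
print(f"median |w|: {np.median(abs_weights):.3f}")
print(f"max |w|:    {abs_weights.max():.3f}")
```

Most sampled weights sit close to zero while a handful are orders of magnitude larger, which is the behaviour that makes the prior useful for sparse weight matrices.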

In [4]:
model = CORE(device="cuda")
model.fit(
    data=mdata,                # Our training data
    n_factors=20,              # number of factors
    likelihoods="Normal",      # Likelihood for all views
    factor_prior="Normal",
    weight_prior="Horseshoe",  # Sparsity prior for the weights
    lr=0.005,
    early_stopper_patience=500,
)

2024-10-15 15:59:17 | famo.core | Setting up device...
2024-10-15 15:59:18 | famo.core | - No device id given. Using default device: 0
2024-10-15 15:59:18 | famo.core | - Running all computations on `cuda:0`
2024-10-15 15:59:18 | famo.core | - Using provided likelihood for all views.
2024-10-15 15:59:18 | famo.core |   - drugs: Normal
2024-10-15 15:59:18 | famo.core |   - methylation: Normal
2024-10-15 15:59:18 | famo.core |   - mrna: Normal
2024-10-15 15:59:18 | famo.core |   - mutations: Normal


2024-10-15 15:59:18 | famo.core | Initializing factors using `random` method...
2024-10-15 15:59:19 | famo.core | Decaying learning rate over 10000 iterations.
2024-10-15 15:59:19 | famo.core | Setting training seed to `2410151559`.
2024-10-15 15:59:19 | famo.core | Cleaning parameter store.
2024-10-15 15:59:20 | root | Guessed max_plate_nesting = 3
2024-10-15 15:59:20 | famo.core | Epoch:       0 | Time:       1.43s | Loss:   13753.84
2024-10-15 15:59:25 | famo.core | Epoch:     100 | Time:       6.10s | Loss:   10439.43
2024-10-15 15:59:30 | famo.core | Epoch:     200 | Time:      10.62s | Loss:    8908.00
2024-10-15 15:59:34 | famo.core | Epoch:     300 | Time:      15.30s | Loss:    8116.89
2024-10-15 15:59:39 | famo.core | Epoch:     400 | Time:      19.81s | Loss:    7597.86
2024-10-15 15:59:43 | famo.core | Epoch:     500 | Time:      24.42s | Loss:    7271.91
2024-10-15 15:59:48 | famo.core | Epoch:     600 | Time:      28.94s | Loss:    7078.71
2024-10-15 15:59:53 | famo.core 

Looking at the logs, we can see the following:
- for computational efficiency, we try to train the model on a GPU (if available)
- by default, we get an overview of the missing samples/features in our training data (Missing Data Overview)
- in this run, the factors are initialized using the `random` method
- the model is automatically stored in the current working directory.
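
The `fit` call above also passes `early_stopper_patience=500`, which this notebook does not explain further. Assuming it follows the usual patience convention (stop once the loss has not improved for that many consecutive epochs), the idea can be sketched as follows. The `EarlyStopper` class and its `min_delta` parameter are hypothetical and only serve this illustration.

```python
class EarlyStopper:
    """Hypothetical helper: stop when the loss has not improved for `patience` epochs."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, loss: float) -> bool:
        """Record one epoch's loss and return True if training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Toy loss sequence: improves for three epochs, then plateaus.
losses = [10.0, 8.0, 7.0, 7.0, 7.1, 7.0, 7.2]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)
```

A larger patience tolerates longer plateaus before stopping, at the cost of extra training time.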

## Plot Results

After training, we can inspect all learned parameters, visualize the factors/weights, etc.


### ELBO

If the learning curve flattens out, we know that our model has converged.

In [5]:
plot_training_curve(model)
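
Besides inspecting the plot, flattening can be checked numerically. The sketch below uses the loss values printed in the training log above and a simple, hypothetical criterion: the curve counts as flat once the relative improvement between the last two recorded losses drops below a tolerance. The function names and the tolerance of `1e-3` are assumptions made for this example.

```python
# Losses copied from the training log above (one value every 100 epochs).
logged_losses = [13753.84, 10439.43, 8908.00, 8116.89, 7597.86, 7271.91, 7078.71]


def relative_improvement(losses: list) -> float:
    """Relative decrease between the last two recorded losses."""
    previous, current = losses[-2], losses[-1]
    return (previous - current) / abs(previous)


def has_flattened(losses: list, tolerance: float = 1e-3) -> bool:
    """Hypothetical criterion: flat if the last step improved by less than `tolerance`."""
    return relative_improvement(losses) < tolerance


print(f"relative improvement: {relative_improvement(logged_losses):.4f}")
print(f"flattened: {has_flattened(logged_losses)}")
```

At epoch 600 the loss still drops by roughly 2.7% per 100 epochs, so by this criterion the excerpt of the log shown above has not yet flattened.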

### Plot a heatmap of the weights

To perform downstream analysis, we need to extract the factors and weights from the model by calling the `get_factors()` and `get_weights()` methods. These return dictionaries containing the factors for each group and the weights for each modality, respectively.


In [6]:
factors = model.get_factors()
weights = model.get_weights()

However, we also provide functionality to plot the weights directly.
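
To see how factors and weights relate to the data, recall that a factor model approximates each view as the product of a shared factor matrix (samples x factors) and a view-specific weight matrix (factors x features). The sketch below uses random NumPy arrays with hypothetical view names rather than the objects returned by `get_factors()` and `get_weights()`, whose exact types are not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_factors = 6, 2
n_features = {"view_a": 5, "view_b": 3}  # hypothetical views

# One factor matrix shared across views: samples x factors.
Z = rng.normal(size=(n_samples, n_factors))

# One weight matrix per view: factors x features.
W = {view: rng.normal(size=(n_factors, d)) for view, d in n_features.items()}

# Each view is reconstructed as the product of factors and that view's weights.
reconstruction = {view: Z @ W[view] for view in W}

for view, values in reconstruction.items():
    print(view, values.shape)
```

Because the factors are shared, a single factor can explain variation in several views at once, which is what the plots in the following sections explore.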

### Factor Correlation Matrix
The correlation matrix shows us that our factors are mostly uncorrelated, which is an indicator that our training worked well.

In [7]:
plot_factor_correlation(model)
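
The quantity behind such a plot is the pairwise Pearson correlation between factor columns. As a minimal sketch on synthetic factor scores (not the trained model), where one factor is deliberately made redundant:

```python
import numpy as np

rng = np.random.default_rng(0)
factors = rng.normal(size=(200, 3))  # synthetic scores: samples x factors

# Make the third factor a noisy copy of the first to create a visible correlation.
factors[:, 2] = factors[:, 0] + 0.1 * rng.normal(size=200)

# rowvar=False treats columns (factors) as variables.
correlation = np.corrcoef(factors, rowvar=False)

print(correlation.round(2))
```

Large off-diagonal entries, such as the one between the first and third factor here, would suggest that two factors capture the same source of variation.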

### Variance Explained Plot
The Variance Explained plot allows us to identify which factors in which views explain variation in our input data and might be of greater interest for a more detailed analysis.

Here we can see that factor 1 explains variation in all views, whereas factors 2 and 3 mostly drive variation in the drug view, etc.

In [8]:
plot_variance_explained(model)
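
One common definition of variance explained per factor, used in the MOFA paper [1], is the coefficient of determination of the reconstruction that uses only that factor. The sketch below applies this definition to synthetic, centered data; whether `plot_variance_explained` computes exactly this quantity is an assumption.

```python
import numpy as np


def variance_explained(Y, Z, W):
    """Fraction of variance in a centered view Y explained by each factor alone."""
    total = np.sum(Y**2)
    r2 = np.empty(Z.shape[1])
    for k in range(Z.shape[1]):
        # Reconstruction that uses only factor k: outer product of scores and weights.
        residual = Y - np.outer(Z[:, k], W[k, :])
        r2[k] = 1.0 - np.sum(residual**2) / total
    return r2


rng = np.random.default_rng(0)
Z = rng.normal(size=(100, 2))
Z -= Z.mean(axis=0)  # center the factor scores

# Factor 0 has large weights, factor 1 has small ones (made-up values).
W = np.array([[2.0, -2.0, 2.0, -2.0], [0.1, 0.1, 0.1, 0.1]])
Y = Z @ W + 0.1 * rng.normal(size=(100, 4))
Y -= Y.mean(axis=0)  # center the view

r2 = variance_explained(Y, Z, W)
print(r2.round(3))
```

The first factor accounts for almost all of the variance in this toy view, while the second contributes very little.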

### Factor Values

In [49]:
plot_factor(model, factor=1)

### Top weights per factor

In [None]:
plot_weights(model, "mutations", factor=1)

In [None]:
plot_weights(model, "mutations", factor=5)

In [None]:
plot_top_weights(model, view="mutations", factor=[1, 2, 3, 4, 5], orientation="horizontal")
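
Selecting the top weights of a factor usually means ranking features by absolute weight while keeping the sign. The sketch below shows this with placeholder feature names and made-up values; it assumes, but does not confirm, that `plot_top_weights` ranks features in a similar way.

```python
import numpy as np

feature_names = np.array(["feat_a", "feat_b", "feat_c", "feat_d", "feat_e"])
factor_weights = np.array([0.05, -1.20, 0.30, 0.90, -0.01])  # made-up values

n_top = 3

# Sort by absolute value, largest first, and keep the sign for plotting.
order = np.argsort(-np.abs(factor_weights))[:n_top]
top_features = [(str(feature_names[i]), float(factor_weights[i])) for i in order]

print(top_features)
```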

In [None]:
import famo

In [None]:
from matplotlib import pyplot as plt

In [None]:
model._cache["factors"]["group_1"].obs_names

In [None]:
mdata.obs_names

In [None]:
model._cache["factors"]["group_1"].obs = mdata.obs.loc[model._cache["factors"]["group_1"].obs_names, :].copy()
model._cache["factors"]["group_1"].obs["IGHV"] = model._cache["factors"]["group_1"].obs["IGHV"].astype(str).astype("category")
model._cache["factors"]["group_1"].obs["trisomy12"] = model._cache["factors"]["group_1"].obs["trisomy12"].astype(str).astype("category")
model._cache["factors"]["group_1"].obs.keys()

In [None]:
model.factor_names

In [None]:
famo.plotting.violinplot(model, factor_idx=1, groupby="IGHV")

In [None]:
famo.plotting.violinplot(model, factor_idx=2, groupby="trisomy12")

In [None]:
famo.plotting.violinplot(model, factor_idx=7, groupby="trisomy12")

In [None]:
from famo.utils_downstream import test

In [None]:
from famo import feature_sets as fs

In [None]:
gene_sets = fs.from_gmt("c2.cp.reactome.v2023.2.Hs.symbols.gmt")
gene_sets

In [None]:
gene_sets = gene_sets.filter(
    model.feature_names["mrna"], min_fraction=0.1, min_count=15
)
gene_sets

In [None]:
gene_set_mask = gene_sets.to_mask(model.feature_names["mrna"])
gene_set_mask
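
The filtering and masking steps above can be pictured with plain Python sets. The sketch below assumes that `min_count` is the minimum number of a set's genes found among the model features and that `min_fraction` is the minimum share of the set they must represent. These semantics, the helper functions, and all names are assumptions for illustration and not the `feature_sets` API.

```python
import numpy as np

# Hypothetical gene sets and model features; all names are placeholders.
gene_sets = {
    "set_1": {"g1", "g2", "g3", "g4"},
    "set_2": {"g1", "g9"},
    "set_3": {"g7", "g8", "g9", "g10"},
}
features = ["g1", "g2", "g3", "g4", "g5"]


def filter_sets(sets, features, min_fraction, min_count):
    """Keep sets with enough members among `features` (assumed semantics)."""
    observed = set(features)
    kept = {}
    for name, genes in sets.items():
        overlap = genes & observed
        if len(overlap) >= min_count and len(overlap) / len(genes) >= min_fraction:
            kept[name] = genes
    return kept


def to_mask(sets, features):
    """Boolean matrix (sets x features): True where the feature belongs to the set."""
    return np.array([[f in genes for f in features] for genes in sets.values()])


kept = filter_sets(gene_sets, features, min_fraction=0.5, min_count=2)
mask = to_mask(kept, features)

print(list(kept))
print(mask)
```

Only `set_1` survives the filter here, because the other two sets share at most one gene with the feature list.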

In [None]:
test_results_pos = test(model, "mrna", feature_sets=gene_set_mask, sign="pos")
test_results_neg = test(model, "mrna", feature_sets=gene_set_mask, sign="neg")

In [None]:
test_df = test_results_pos["p_adj"].copy()

In [None]:
for k in range(model.n_factors):
    print(model.factor_names[k])
    print(test_df.columns[test_df.iloc[k, :] < 0.05])
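
The loop above thresholds the `p_adj` values at 0.05. How `test` adjusts its p-values is not shown in this notebook; a widely used choice for controlling the false discovery rate is the Benjamini-Hochberg procedure, sketched here in NumPy on made-up p-values as an illustration only.

```python
import numpy as np


def benjamini_hochberg(p_values):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(p_values, dtype=float)
    n = p.size
    order = np.argsort(p)
    # Scale the i-th smallest p-value by n / i.
    ranked = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity, starting from the largest p-value.
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.clip(ranked, 0.0, 1.0)
    return adjusted


raw_p_values = [0.001, 0.04, 0.03, 0.5]  # made-up values
adjusted_p_values = benjamini_hochberg(raw_p_values)

print(adjusted_p_values.round(4))
print(adjusted_p_values < 0.05)
```

In this toy example, three raw p-values are below 0.05 but only one remains significant after adjustment, which is why the adjusted values are the ones to threshold.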

In [None]:
import gseapy as gp
from gseapy import barplot

In [None]:
relevant_factors = model.factor_names.tolist()[:5]

In [None]:
model._cache["weights"]["mrna"].to_df().head()

In [None]:
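top = 200
# Rank by the largest positive ("pos") or most negative ("neg") weights.
direction = "pos"
for factor_idx in relevant_factors:
    if factor_idx not in model.factor_names:
        continue
    # Order the mRNA features of this factor by weight, following `direction`.
    factor_weights = model._cache["weights"]["mrna"].to_df().loc[factor_idx, :]
    gl = factor_weights.sort_values(ascending=(direction == "neg")).index.tolist()
    if top is not None:
        gl = gl[:top]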
    enr = gp.enrichr(
        gene_list=gl,
        gene_sets=[
            # "MSigDB_Hallmark_2020",
            # "KEGG_2021_Human",
            "Reactome_2022",
            # "GO_Biological_Process_2023",
        ],
        organism="human",
        outdir=None,
    )

    try:
        # Bar plot of the most significant terms
        ax = barplot(
            enr.results,
            column="Adjusted P-value",
            group="Gene_set",  # set group, so you could do a multi-sample/library comparison
            size=10,
            top_term=5,
            figsize=(3, 5),
            color=["#1b9e77", "#d95f02", "#7570b3", "#e7298a"],  # set colors for group
            # color = {'MSigDB_Hallmark_2020':'blue', 'KEGG_2021_Human': 'salmon', 'Reactome_2022':'red'}
        )
        plt.title(f"{factor_idx} (top {top} loadings)")
        plt.show()
    except ValueError as e:
        print(e)