# Basic Example: Train a MOFA model

This notebook shows a simple way to train a MOFA [1] model with additional sparsity priors. We use a chronic lymphocytic leukaemia (CLL) data set that combines ex vivo drug response measurements with somatic mutation status, transcriptome profiling, and DNA methylation assays [2].

[1] Multi-Omics Factor Analysis: a framework for unsupervised integration of multi-omics data sets by Argelaguet, R. et al. (2018)  
[2] Drug-perturbation-based stratification of blood cancer by Dietrich et al. (2018)

In [1]:
import scanpy as sc
from data_loader import load_CLL

from famo.core import CORE
from famo.plotting import (
    plot_all_weights,
    plot_factor,
    plot_factor_correlation,
    plot_top_weights,
    plot_training_curve,
    plot_variance_explained,
    plot_weights,
)



## Load and Preprocess Data

In [2]:
# Load CLL data
mdata = load_CLL()

# Normalize and log transform mRNA counts
sc.pp.normalize_total(mdata["mrna"], target_sum=1e4)
sc.pp.log1p(mdata["mrna"])
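
As a rough illustration of what the two preprocessing calls above do, the sketch below reproduces the same arithmetic with NumPy on a made-up count matrix: each sample's counts are rescaled so that they sum to `target_sum`, and the result is transformed with `log(1 + x)`. The toy values and variable names are illustrative only and are not part of the CLL data.

```python
import numpy as np

# Toy count matrix: 3 samples (rows) x 4 genes (columns). Values are made up.
counts = np.array(
    [
        [10.0, 0.0, 5.0, 5.0],
        [100.0, 50.0, 25.0, 25.0],
        [1.0, 1.0, 1.0, 1.0],
    ]
)
target_sum = 1e4

# Rescale every sample so that its counts sum to target_sum.
row_totals = counts.sum(axis=1, keepdims=True)
normalized = counts / row_totals * target_sum

# Natural log of (1 + x); zero counts stay exactly zero.
logged = np.log1p(normalized)

print(normalized.sum(axis=1))
print(logged.round(2))
```

After this step, samples with very different sequencing depths are on a comparable scale, which is a common prerequisite for a Gaussian noise model.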




## Example 1: Run Factor Analysis

In a first experiment, we will just use all four views and run the factorization. Each experiment consists of two necessary steps:

1) Create a new model instance: `model = CORE()`
2) Train the model: `model.fit(n_factors=15, data=mdata)`

To keep things simple, only the number of factors is required. In Example 2, we will show how to deviate from the default parameters and customize the model. For simplicity, we model all views with a Gaussian noise model.
In contrast to the original MOFA definition, we place a Horseshoe sparsity prior [3] on the weights.

[3] Carvalho, Carlos M., Nicholas G. Polson, and James G. Scott. "Handling sparsity via the horseshoe." Artificial intelligence and statistics. PMLR, 2009.
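
To give some intuition for why a Horseshoe prior encourages sparse weights, the following sketch draws samples from a generic horseshoe prior as described in [3] and compares them with a plain Normal prior of the same global scale. It is an illustration under simplifying assumptions: the global scale `tau` is fixed rather than learned, and the parametrization used inside the model may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n_weights = 10_000
tau = 0.1  # global scale, fixed here for simplicity (an assumption of this sketch)

# Horseshoe: every weight gets its own local scale from a half-Cauchy distribution.
local_scales = np.abs(rng.standard_cauchy(size=n_weights))
w_horseshoe = rng.standard_normal(size=n_weights) * tau * local_scales

# Reference: plain Normal prior with the same global scale.
w_normal = rng.standard_normal(size=n_weights) * tau

# Fraction of weights that are very close to zero under each prior.
frac_small_horseshoe = np.mean(np.abs(w_horseshoe) < 0.01)
frac_small_normal = np.mean(np.abs(w_normal) < 0.01)

# Largest absolute weight under each prior.
max_horseshoe = np.abs(w_horseshoe).max()
max_normal = np.abs(w_normal).max()

print(f"near zero: horseshoe={frac_small_horseshoe:.3f}, normal={frac_small_normal:.3f}")
print(f"max |w|:   horseshoe={max_horseshoe:.2f}, normal={max_normal:.2f}")
```

The horseshoe draws concentrate more mass near zero while still producing occasional very large values, which is the behaviour that lets most weights shrink strongly while a few relevant features keep large weights.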

In [3]:
# Run model with MuData object
model = CORE(device="cuda")
model.fit(
    data=mdata,
    n_factors=15,
    likelihoods={
        "mrna": "Normal",
        "drugs": "Normal",
        "mutations": "Normal",
        "methylation": "Normal",
    },
    factor_prior="Normal",
    weight_prior="Horseshoe",
    max_epochs=10000,
    lr=0.025,
)

Setting up device...
- `cuda` not available...
- Running all computations on `cpu`
Fitting model...
- Checking compatibility of provided likelihoods with data.
  - mrna: Normal
  - drugs: Normal
  - mutations: Normal
  - methylation: Normal
- Centering group_1/drugs...
- Centering group_1/methylation...
- Centering group_1/mrna...
- Centering group_1/mutations...


Initializing factors using `random` method...
Epoch:       0 | Time:       0.84s | Loss: 3495839.20
Epoch:     500 | Time:      54.74s | Loss: 2320847.90
Epoch:    1000 | Time:     104.39s | Loss: 2293011.60
Epoch:    1500 | Time:     158.45s | Loss: 2285203.98
Epoch:    2000 | Time:     210.90s | Loss: 2276261.77
Epoch:    2500 | Time:     266.92s | Loss: 2273116.68
Training finished after 2537 steps.
Saving results...
- Model saved to model_20240625_140642/model.pkl
- Parameters saved to model_20240625_140642/params.save


## Plot Results

After training, we can inspect all learned parameters, visualize the loss curve, etc.


### ELBO

In [4]:
# Plot training curve
plot_training_curve(model)

### Plot a heatmap of the weights

To perform downstream analysis, we need to extract the factors and weights from the model. We can do this by calling the `get_factors()` and `get_weights()` methods. These return dictionaries containing the factors for each group and the weights for each modality, respectively.


In [5]:
# Get learned model parameters
factors = model.get_factors()
weights = model.get_weights()

However, we also provide functionality to plot the weights directly.

In [12]:
# Plot learned weights
plot_all_weights(model, clip=(-2, 2))

### Factor Correlation Matrix

In [7]:
plot_factor_correlation(model)
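
A factor correlation matrix of this kind can be understood as the pairwise Pearson correlation between factor columns. The sketch below computes one with NumPy on a synthetic factor matrix; it does not use the trained model, and the exact computation inside `plot_factor_correlation` is not shown in this notebook, so treat it as an assumption about what is being visualized.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical factor matrix: 100 samples x 4 factors.
Z = rng.standard_normal(size=(100, 4))
# Make the last factor a noisy copy of the first so one pair is clearly correlated.
Z[:, 3] = Z[:, 0] + 0.1 * rng.standard_normal(size=100)

# rowvar=False treats columns (factors) as the variables to correlate.
corr = np.corrcoef(Z, rowvar=False)

print(corr.round(2))
```

Large off-diagonal entries indicate factors that capture overlapping signal, which can be a hint that fewer factors would suffice.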

### Variance Explained Plot

In [8]:
plot_variance_explained(model)
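
For intuition, variance explained in factor models is commonly quantified with the coefficient of determination, R², computed per factor and per view. The sketch below applies that definition to synthetic data for a single view; all names and values are illustrative, and whether `plot_variance_explained` uses exactly this formula is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_factors = 200, 100, 3

# Synthetic factors (centered) and weights; factor 1 is strongest, factor 3 weakest.
Z = rng.standard_normal(size=(n_samples, n_factors))
Z -= Z.mean(axis=0)
scales = np.array([4.0, 2.0, 1.0])
W = rng.standard_normal(size=(n_factors, n_features)) * scales[:, None]

# One simulated view: low-rank signal plus Gaussian noise, centered per feature.
Y = Z @ W + 0.5 * rng.standard_normal(size=(n_samples, n_features))
Y -= Y.mean(axis=0)

total_ss = np.sum(Y**2)


def r2(Y_hat):
    """Fraction of the total sum of squares in Y captured by the reconstruction."""
    return 1.0 - np.sum((Y - Y_hat) ** 2) / total_ss


# R² of each factor on its own, and of all factors together.
r2_per_factor = np.array([r2(np.outer(Z[:, k], W[k])) for k in range(n_factors)])
r2_total = r2(Z @ W)

print(r2_per_factor.round(3), round(r2_total, 3))
```

In a multi-view setting the same quantity would be computed separately for every view, giving one value per factor and view.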

### Factor Values

In [9]:
plot_factor(model, factor=1)

### Top weights per factor

In [18]:
plot_top_weights(model, view="mutations", factor=[1, 4], orientation="horizontal")

In [14]:
plot_weights(model, view="mutations", factor=2, top_n_features=10)
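
The idea behind a "top weights" plot can be sketched in a few lines: rank the features of one factor by the absolute value of their weights and keep the first few. The feature names and weight values below are hypothetical placeholders, not results from the CLL model.

```python
import numpy as np

# Hypothetical weights of one factor for five features.
feature_names = np.array(["feature_a", "feature_b", "feature_c", "feature_d", "feature_e"])
weights = np.array([0.1, -2.0, 0.5, 1.5, -0.05])
top_n = 3

# Sort by absolute weight, largest first, and keep the top_n indices.
order = np.argsort(-np.abs(weights))[:top_n]
top_names = feature_names[order].tolist()
top_values = weights[order].tolist()

for name, value in zip(top_names, top_values):
    print(f"{name}: {value:+.2f}")
```

Ranking by absolute value keeps features with strong negative weights, which are as informative for interpreting a factor as strongly positive ones.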