In [2]:
%load_ext autoreload
%autoreload 2

import warnings

from data_loader import load_mouse_development

from famo.core import CORE
from famo.plotting import (
    plot_all_weights,
    plot_factor_correlation,
    plot_factor_covariate,
    plot_top_weights,
    plot_training_curve,
    plot_variance_explained,
    plot_weights,
)

warnings.simplefilter(action="ignore", category=FutureWarning)


In [3]:
# Load the mouse development MuData. Each view's UMAP coordinates (the obs
# columns UMAP1/UMAP2) are copied into `.obsm["umap"]`, which the model cell
# below reads through `covariates_key="umap"`.
mdata = load_mouse_development()

for view_name in ("RNA", "motif_met", "motif_acc"):
    view = mdata[view_name]
    view.obsm["umap"] = view.obs[["UMAP1", "UMAP2"]].to_numpy()


In [4]:
# Fit a 7-factor model on the MuData object: GP prior on the factors with
# covariates taken from `.obsm["umap"]`, and a horseshoe prior on the weights.
# NOTE(review): the saved run of this cell raised `KeyError: 'z_group_1_0'`
# right after "Cleaning parameter store."; every cell below needs a fitted
# `model`, so resolve that before trusting their outputs.
# NOTE(review): the log shows an automatically chosen training seed; pass an
# explicit seed if `CORE.fit` supports one (TODO confirm) for reproducibility.
model = CORE(device="cuda")
model.fit(
    data=mdata,
    n_factors=7,
    likelihoods=dict(RNA="Normal", motif_met="Normal", motif_acc="Normal"),
    # priors / covariates
    factor_prior="GP",
    weight_prior="Horseshoe",
    covariates_key="umap",
    # preprocessing
    center_groups=True,
    scale_views=False,
    scale_groups=True,
    # initialisation
    init_factors="random",
    init_scale=0.1,
    # optimisation
    lr=0.005,
    max_epochs=10000,
    early_stopper_patience=500,
    print_every=10,
    # persistence
    save=False,
    save_path=None,
)

Setting up device...
- No device id given. Using default device: 0
- Running all computations on `cuda:0`
Fitting model...
- Checking compatibility of provided likelihoods with data.
  - RNA: Normal
  - motif_met: Normal
  - motif_acc: Normal
- Centering group_1/RNA...
- Centering group_1/motif_met...
- Centering group_1/motif_acc...


Initializing factors using `random` method...
{'RNA': 0.794912559618442, 'motif_met': 1.102543720190779, 'motif_acc': 1.102543720190779}
Decaying learning rate over 10000 iterations.
Setting training seed to `2408131328`.
Cleaning parameter store.


KeyError: 'z_group_1_0'

In [None]:
# Training diagnostics for the model fitted above (presumably loss per epoch;
# confirm in famo.plotting). Requires the fit cell to have succeeded.
plot_training_curve(model)

In [None]:
# Variance explained by the factors (presumably broken down per view/group).
plot_variance_explained(model)

In [None]:
# Overview of factor weights across views; the per-view plots follow below.
plot_all_weights(model)

In [None]:
# Factor weights for the "RNA" view.
plot_weights(model, "RNA")

In [None]:
# Factor weights for the "motif_met" view.
plot_weights(model, "motif_met")

In [None]:
# Factor weights for the "motif_acc" view.
plot_weights(model, "motif_acc")

In [None]:
# Correlation between the inferred factors; strongly correlated pairs may
# indicate redundant factors (interpretation, not checked here).
plot_factor_correlation(model)

In [None]:
# Features with the largest weights in the "RNA" view (per famo.plotting).
plot_top_weights(model, view="RNA")

In [None]:
# One covariate (UMAP) plot per factor. plot_factor_covariate is called with
# 1-based factor numbers 1..n_factors (presumably how famo labels factors).
for factor_number in range(1, model.n_factors + 1):
    plot_factor_covariate(model, factor_number)

In [None]:
# Inspect the fitted GP kernel output scale for "group_1". The saved output
# has 7 rows, presumably one per factor (n_factors=7 in the fit cell).
# NOTE(review): this cell's execution count is None and the fit cell failed,
# so the saved output is likely stale from an earlier run.
model.gps["group_1"].covar_module.outputscale

tensor([[0.0961],
        [0.0848],
        [0.1090],
        [0.0971],
        [0.1097],
        [0.1149],
        [0.0948]], device='cuda:0', grad_fn=<AddBackward0>)