In [None]:
# Automatically reload edited Python modules (e.g. the local cellij package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule
from cellij.core.models import MOFA


In [None]:
# Load the CLL metadata and omics tables into a MuData object
import os
import anndata
import muon as mu
import numpy as np
import pandas as pd

from importlib import resources
from collections import UserDict


# Per-sample metadata table; its "Sample" column becomes the row index so it can be
# joined onto the MuData observations below.
obs = pd.read_csv(
    filepath_or_buffer=os.fspath("../data/cll_metadata.csv"),
    encoding="utf-8",
    index_col="Sample",
    sep=",",
)


# One AnnData object per omics layer, keyed by layer name.
# Other layers mentioned as candidates: "drugs", "methylation", "mutations".
modalities = {}
for ome in ("mrna",):
    table = pd.read_csv(
        filepath_or_buffer=os.fspath(f"../data/cll_{ome}.csv"),
        encoding="utf-8",
        index_col=0,
        sep=",",
    )
    # Transposed so that rows become observations; presumably the CSV stores one
    # sample per column (the rows must match the metadata "Sample" index) — confirm.
    modalities[ome] = anndata.AnnData(table.T)

mdata = mu.MuData(modalities)
# Left-join the metadata columns onto the observations by index.
mdata.obs = mdata.obs.join(obs)

In [None]:
# Display the mRNA modality (an AnnData) for a quick look at its dimensions and annotations.
mdata.mod["mrna"]

In [None]:
# Seed Python, NumPy and torch RNGs (via Pyro) before building and fitting the model so that
# parameter initialisation and stochastic optimisation are reproducible across
# Restart & Run All.
pyro.set_rng_seed(42)
model = MOFA(n_factors=50)

In [None]:
# Hand the MuData container (one modality per loaded omics layer) to the model before fitting.
model.add_data(data=mdata)

In [None]:
# Fit with a Normal (Gaussian) likelihood for 1000 epochs; `verbose_epochs=50` presumably
# sets the logging interval — confirm in cellij.
model.fit(
    likelihood="Normal",
    epochs=1000,
    verbose_epochs=50,
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Heatmap of the guide's location parameters for the loadings `w`.
# Assumes w squeezes to (n_factors, n_features), since it is right-multiplied onto z to
# reconstruct the data — TODO confirm.
fig, ax = plt.subplots(1, 1, figsize=(30, 10))
sns.heatmap(model._guide.locs.w.detach().numpy().squeeze(), center=0, cmap='RdBu_r', ax=ax)
ax.set(title="Loadings w (guide locs)", xlabel="feature", ylabel="factor");

In [None]:
# Compare the low-rank reconstruction z @ w against the observed mRNA matrix.
w = model._guide.locs.w.detach().numpy().squeeze()
z = model._guide.locs.z.detach().numpy().squeeze()

print(w.shape, z.shape)

xhat = np.matmul(z, w)
x_obs = np.asarray(mdata["mrna"].X)

# Use one symmetric colour range for both panels. Without vmin/vmax, seaborn scales each
# heatmap to its own data range, so identical colours would mean different values.
# nanmax because the observed matrix may contain missing entries.
vlim = float(np.nanmax(np.abs(np.concatenate([xhat.ravel(), x_obs.ravel()]))))

fig, ax = plt.subplots(1, 2, figsize=(30, 10))
sns.heatmap(xhat, center=0, cmap='RdBu_r', vmin=-vlim, vmax=vlim, ax=ax[0])
sns.heatmap(x_obs, center=0, cmap='RdBu_r', vmin=-vlim, vmax=vlim, ax=ax[1])
ax[0].set_title("Reconstruction (z @ w)")
ax[1].set_title("Observed mRNA");

In [None]:
# Residuals of the low-rank fit: observed mRNA minus the reconstruction z @ w.
# Recomputed here so the cell does not depend on variables from earlier cells.
# Missing observed entries stay NaN, and seaborn leaves them blank.
residuals = np.asarray(mdata["mrna"].X) - np.matmul(
    model._guide.locs.z.detach().numpy().squeeze(),
    model._guide.locs.w.detach().numpy().squeeze(),
)

fig, ax = plt.subplots(1, 1, figsize=(30, 10))
sns.heatmap(residuals, center=0, cmap='RdBu_r', ax=ax)
ax.set_title("Residuals: observed mRNA - z @ w");