# Comparison of MEFISTO and PRISMO on Visium data from the MEFISTO paper

In [13]:
%load_ext autoreload
%autoreload 2

import os
import h5py
import pandas as pd
from plotnine import *

import numpy as np
from data_loader import load_mefisto_visium
from scipy.stats import pearsonr

from prismo import PRISMO, DataOptions, ModelOptions, TrainingOptions, SmoothOptions
from prismo.tl import match

# Make sure the output folders exist before any cell writes into them.
for out_dir in ("plots", "results"):
    os.makedirs(out_dir, exist_ok=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## MEFISTO

In [2]:
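# NOTE: MEFISTO training is left commented out; the next cell reads the saved model
# from "models/mefisto.hdf5". Uncomment this cell to retrain and rewrite that file.
# from mofapy2.run.entry_point import entry_point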
# from data_loader import load_mefisto_visium

# adata = load_mefisto_visium()

# ent = entry_point()
# ent.set_data_options(use_float32=True)
# ent.set_data_from_anndata(adata)
# ent.set_model_options(factors=4)
# ent.set_train_options()
# ent.set_train_options(seed=54321)
# n_inducing = 1000

# ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"])
# ent.set_smooth_options(sparseGP=True, frac_inducing=n_inducing / adata.n_obs, start_opt=10, opt_freq=10)
# ent.build()
# ent.run()

# expectations = ent.model.getExpectations()
# ent.save("models/mefisto.hdf5")

In [3]:
# Read the posterior expectations saved by the MEFISTO run.
# The file is opened in a `with` block so the HDF5 handle is released as soon as the
# arrays have been read; `[:]` copies each dataset into memory, so `z_mefisto` and
# `w_mefisto` stay valid after the file is closed.
with h5py.File("models/mefisto.hdf5", "r") as mefisto_model:
    # Transposed so that factors run along the columns, which is how the later
    # cells index these arrays (`z[:, factor]`).
    z_mefisto = mefisto_model["expectations"]["Z"]["group1"][:].T
    w_mefisto = mefisto_model["expectations"]["W"]["rna"][:].T

## PRISMO

In [9]:
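# NOTE: PRISMO training is left commented out; the next cell loads the saved model
# from "models/prismo" (the `save_path` used below). Uncomment this cell to retrain.
# adata = load_mefisto_visium()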

# prismo_model = PRISMO(
#     adata,
#     DataOptions(covariates_obsm_key="spatial", plot_data_overview=False),
#     ModelOptions(n_factors=4, weight_prior="SnS", factor_prior="GP", likelihoods="Normal"),
#     TrainingOptions(device="cuda:0", early_stopper_patience=500, lr=5e-2, save_path="models/prismo", seed=5432),
#     SmoothOptions(n_inducing=1000, kernel="RBF")
# )

INFO:prismo._core.prismo:Epoch:       0 | Time:       1.21s | Loss:    2364.87
INFO:prismo._core.prismo:Epoch:     100 | Time:      13.01s | Loss:     993.09
INFO:prismo._core.prismo:Epoch:     200 | Time:      24.87s | Loss:     971.00
INFO:prismo._core.prismo:Epoch:     300 | Time:      36.72s | Loss:     964.64
INFO:prismo._core.prismo:Epoch:     400 | Time:      48.56s | Loss:     965.92
INFO:prismo._core.prismo:Epoch:     500 | Time:      60.40s | Loss:     966.52
INFO:prismo._core.prismo:Epoch:     600 | Time:      72.23s | Loss:     964.83
INFO:prismo._core.prismo:Epoch:     700 | Time:      84.06s | Loss:     960.79
INFO:prismo._core.prismo:Epoch:     800 | Time:      95.90s | Loss:     960.72
INFO:prismo._core.prismo:Epoch:     900 | Time:     107.73s | Loss:     961.88
INFO:prismo._core.prismo:Epoch:    1000 | Time:     119.57s | Loss:     959.79
INFO:prismo._core.prismo:Epoch:    1100 | Time:     131.40s | Loss:     961.44
INFO:prismo._core.prismo:Epoch:    1200 | Time:     

In [11]:
# Restore the trained PRISMO model from disk and pull out its factors and weights.
prismo_model = PRISMO.load("models/prismo")

factors_by_group = prismo_model.get_factors()
weights_by_view = prismo_model.get_weights()

z_prismo = factors_by_group["group_1"].values
# Transposed so that factors run along the columns, like `z_prismo`
# (assumes get_weights() returns factors as rows — TODO confirm).
w_prismo = weights_by_view["view_1"].T.values

## Factor Matching

In [12]:
def _scale_by_range(values):
    """Divide every column by its peak-to-peak range (max - min)."""
    return values / np.ptp(values, axis=0, keepdims=True)


# Pair the PRISMO factors with the MEFISTO factors (MEFISTO is the reference).
# NOTE(review): `z_mefisto_inds` is never applied below — this presumably relies on
# `match` returning the reference factors in their original order; confirm.
z_mefisto_inds, z_prismo_inds, z_signs = match(z_mefisto, z_prismo, dim=-1)
signs_row = z_signs.reshape(1, -1)

# Bring the PRISMO columns into the matched order and apply the matched signs.
z_prismo = z_prismo[:, z_prismo_inds] * signs_row
w_prismo = w_prismo[:, z_prismo_inds] * signs_row

# Put both models on a comparable scale, column by column.
z_prismo = _scale_by_range(z_prismo)
z_mefisto = _scale_by_range(z_mefisto)

w_prismo = _scale_by_range(w_prismo)
w_mefisto = _scale_by_range(w_mefisto)

# Pearson correlation between each matched pair of factors.
corr = np.array(
    [pearsonr(z_prismo[:, k], z_mefisto[:, k])[0] for k in range(z_prismo.shape[1])]
)

# Sort the factors from highest to lowest correlation and reorder everything alike.
factor_order = corr.argsort()[::-1]
corr = corr[factor_order]

z_prismo, w_prismo = z_prismo[:, factor_order], w_prismo[:, factor_order]
z_mefisto, w_mefisto = z_mefisto[:, factor_order], w_mefisto[:, factor_order]

## Plots

In [15]:
# Spatial maps of the matched PRISMO factors, written to one PDF per factor.
factor_names = [f"Factor {i+1}" for i in range(z_prismo.shape[1])]

covariates_df = pd.DataFrame(prismo_model.covariates["group_1"], columns=["x", "y"])
factors_df = pd.DataFrame(z_prismo, columns=factor_names)
df = pd.concat([covariates_df, factors_df], axis=1)
df["y"] = -df["y"]  # negate y — presumably to match the image orientation; confirm

# Loop over the factors that actually exist instead of a hard-coded count, so the
# loop and the column names cannot get out of sync.
for factor in range(1, len(factor_names) + 1):
    plot = (
        ggplot(df, aes(x="x", y="y", color=f"Factor {factor}"))
        + geom_point(size=0.1)
        + theme(
            figure_size=(3, 3),
            axis_text_x=element_blank(),
            axis_text_y=element_blank(),
            axis_ticks_major_x=element_blank(),
            axis_ticks_major_y=element_blank(),
            legend_key_width=15,
        )
        + labs(x="", y="")
        + coord_fixed()
        + scale_color_gradient(low="white", high="#0571b0")
    )

    fig = plot.draw()
    # Rasterize the point layer before saving the figure as a PDF.
    points = fig.axes[0].collections[0]
    points.set_rasterized(True)
    fig.savefig(f"plots/prismo_factor_{factor}.pdf")

In [16]:
# Spatial maps of the matched MEFISTO factors, written to one PDF per factor.
# Column names are sized from `z_mefisto` itself (the array being plotted).
factor_names = [f"Factor {i+1}" for i in range(z_mefisto.shape[1])]

# Coordinates are taken from the PRISMO model.
# NOTE(review): assumes both models hold the observations in the same row order — confirm.
covariates_df = pd.DataFrame(prismo_model.covariates["group_1"], columns=["x", "y"])
factors_df = pd.DataFrame(z_mefisto, columns=factor_names)
df = pd.concat([covariates_df, factors_df], axis=1)
df["y"] = -df["y"]  # negate y — presumably to match the image orientation; confirm

# Loop over the factors that actually exist instead of a hard-coded count, so the
# loop and the column names cannot get out of sync.
for factor in range(1, len(factor_names) + 1):
    plot = (
        ggplot(df, aes(x="x", y="y", color=f"Factor {factor}"))
        + geom_point(size=0.1)
        + theme(
            figure_size=(3, 3),
            axis_text_x=element_blank(),
            axis_text_y=element_blank(),
            axis_ticks_major_x=element_blank(),
            axis_ticks_major_y=element_blank(),
            legend_key_width=15,
        )
        + labs(x="", y="")
        + coord_fixed()
        + scale_color_gradient(low="white", high="#67001f")
    )

    fig = plot.draw()
    # Rasterize the point layer before saving the figure as a PDF.
    points = fig.axes[0].collections[0]
    points.set_rasterized(True)
    fig.savefig(f"plots/mefisto_factor_{factor}.pdf")

In [17]:
# Scatter plot of PRISMO vs. MEFISTO weights, one panel per matched factor.

# The two long tables are joined by position below, so both weight matrices must
# have the same shape; otherwise rows would be paired incorrectly.
if w_prismo.shape != w_mefisto.shape:
    raise ValueError(
        f"Weight matrices differ in shape: PRISMO {w_prismo.shape} vs MEFISTO {w_mefisto.shape}"
    )

# Factor labels are derived from the data rather than a hard-coded count.
factor_names = [f"Factor {i+1}" for i in range(w_prismo.shape[1])]

weights_prismo_df = pd.DataFrame(w_prismo, columns=factor_names)
weights_prismo_df_long = weights_prismo_df.melt(var_name="Factor", value_name="prismo_weight")

weights_mefisto_df = pd.DataFrame(w_mefisto, columns=factor_names)
weights_mefisto_df_long = weights_mefisto_df.melt(var_name="Factor", value_name="mefisto_weight")["mefisto_weight"]

weights_df_long = pd.concat([weights_prismo_df_long, weights_mefisto_df_long], axis=1)
# Ordered categorical keeps the panels in numeric factor order rather than string order.
weights_df_long["Factor"] = pd.Categorical(weights_df_long["Factor"], categories=factor_names, ordered=True)

plot = (
    ggplot(weights_df_long, aes(x="prismo_weight", y="mefisto_weight"))
    + geom_point(size=0.1, color="#0571b0")
    + theme(figure_size=(15, 7))
    + labs(x="PRISMO weight", y="MEFISTO weight")
    # Dashed identity line: points on it have equal weight in both models.
    + geom_abline(intercept=0, slope=1, linetype="dashed", alpha=0.5)
    + facet_wrap("~Factor", ncol=5)
    + coord_equal()
)
plot.save("plots/weights_comparison.pdf")

