# Comparison of NSF and PRISMO on the Slide-seq Data from the NSF Paper

In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle as pkl

import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from plotnine import *
from matplotlib.colors import ListedColormap
from matplotlib.colors import TwoSlopeNorm, ListedColormap
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable

from prismo import (
    PRISMO,
    DataOptions,
    ModelOptions,
    TrainingOptions,
    SmoothOptions,
)
from prismo.tl import match
from prismo.pl import plot_covariates_factor_scatter

from data_loader import load_nsf_slideseq



Importing the dtw module. When using in academic works please cite:
  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.
  J. Stat. Soft., doi:10.18637/jss.v031.i07.





## PRISMO

In [3]:
adata = load_nsf_slideseq()

# Where PRISMO writes the trained model (TrainingOptions.save_path) and where it
# is reloaded from below.
PRISMO_MODEL_PATH = "models/prismo"
# Training takes roughly 20 minutes (see the training log), so a saved model is
# reused by default. Set to True to force a fresh fit.
RETRAIN_PRISMO = False

if RETRAIN_PRISMO or not os.path.exists(PRISMO_MODEL_PATH):
    # Judging from the training log, constructing PRISMO fits the model and saves
    # the results to `save_path`; the returned object is not needed because the
    # model is reloaded from disk below.
    PRISMO(
        adata,
        DataOptions(
            covariates_obsm_key="spatial",
            plot_data_overview=False,
        ),
        ModelOptions(
            n_factors=10,
            weight_prior="Horseshoe",
            factor_prior="GP",
            likelihoods="GammaPoisson",
            nonnegative_weights=True,
            nonnegative_factors=True,
        ),
        TrainingOptions(
            device="cuda:1",
            batch_size=1000,
            max_epochs=500,
            lr=1e-2,
            early_stopper_patience=30,
            print_every=50,
            save_path=PRISMO_MODEL_PATH,
            seed=5432,
        ),
        SmoothOptions(
            n_inducing=1000,
            kernel="Matern",
        ),
    )

# Always work with the model as stored on disk, so freshly trained and cached runs
# go through the same code path.
prismo_model = PRISMO.load(PRISMO_MODEL_PATH)

INFO:prismo._core.prismo:Setting up device...
INFO:prismo._core.prismo:- Using provided likelihood for all views.
INFO:prismo._core.prismo:  - view_1: GammaPoisson
INFO:prismo._core.prismo:Initializing factors using `random` method...
INFO:prismo._core.prismo:Decaying learning rate over 18000 iterations.
INFO:prismo._core.prismo:Setting training seed to `5432`.
INFO:prismo._core.prismo:Cleaning parameter store.
INFO:prismo._core.prismo:Epoch:       0 | Time:       4.61s | Loss:   88628.41
INFO:prismo._core.prismo:Epoch:      50 | Time:     232.04s | Loss:   35546.44
INFO:prismo._core.prismo:Epoch:     100 | Time:     459.87s | Loss:   34576.23
INFO:prismo._core.prismo:Epoch:     150 | Time:     687.68s | Loss:   34152.32
INFO:prismo._core.prismo:Epoch:     200 | Time:     915.48s | Loss:   33956.84
INFO:prismo._core.prismo:Epoch:     250 | Time:    1143.28s | Loss:   33851.67
INFO:prismo._core.prismo:Training finished after 283 steps.
INFO:prismo._core.prismo:Saving results...


In [4]:
# Inferred factor values (GP outputs) for the single group.
z_prismo = prismo_model.get_gps()["group_1"].values
# Transposed so that columns are factors, as the scale post-processing below
# expects (it multiplies the weights by per-factor sums along the last axis).
w_prismo = prismo_model.get_weights()["view_1"].T.values

# Persist the latents in the same {"z", "w"} format as the NSF run.
os.makedirs("lvs", exist_ok=True)
with open("lvs/prismo.pkl", "wb") as f:
    pkl.dump({"z" : z_prismo, "w" : w_prismo}, f)

## NSF

In [5]:
# # requires the nsf-paper conda environment to be activated

# import pickle as pkl

# import numpy as np
# import spatial_factorization as sf
# from data_loader import load_nsf_slideseq
# from tensorflow.data import Dataset

# adata = load_nsf_slideseq()

# # prepare data for SpatialFactorization model
# data = {
#     "X": adata.obsm["spatial"].copy().astype("float32"),
#     "Y": adata.layers["counts"].toarray().astype("float32"),
#     "sz" : np.ones((adata.n_obs, 1), dtype="float32"),
#     "idx" : np.arange(adata.n_obs)
#     }
# data_tf = Dataset.from_tensor_slices(data)
# data_tf = data_tf.batch(adata.n_obs)
# inducing_locations = sf.misc.kmeans_inducing_pts(data["X"], 1000)

# # setup and train SpatialFactorization model
# nsf_model = sf.SpatialFactorization(
#     J=adata.n_vars,
#     L=10,
#     Z=inducing_locations,
#     lik="poi",
#     nonneg=True,
# )
# nsf_model.init_loadings(data["Y"])
# trainer = sf.ModelTrainer(nsf_model)
# trainer.train_model(data_tf, adata.n_obs, None)

# # obtain inferred latent variables
# z_nsf = np.exp(nsf_model.sample_latent_GP_funcs(data["X"], S=100).numpy().mean(axis=0).T)
# w_nsf = nsf_model.get_loadings()

# pkl.dump({"z" : z_nsf, "w" : w_nsf}, open("lvs/nsf.pkl", "wb"))

## Factor Matching

In [18]:
def load_latents(path):
    """Load the factors ``z`` and weights ``w`` pickled by the PRISMO and NSF cells.

    The pickle holds a dict ``{"z": ..., "w": ...}``. Keys are read by name
    instead of relying on dict ordering.

    Returns
    -------
    (z, w)
    """
    # NOTE: pickle can execute arbitrary code; only load these locally produced files.
    with open(path, "rb") as f:
        latents = pkl.load(f)
    return latents["z"], latents["w"]


def normalize_scales(z, w, scale=1e4):
    """Make factor and weight scales comparable across models.

    Each factor column of ``z`` is divided by its sum, and that sum is moved
    into the matching weight column. Each row of the rescaled weights is then
    normalized to sum to 1.

    Parameters
    ----------
    z : array, (observations, factors)
    w : array, (features, factors) -- its last axis must be the factor axis
    scale : multiplier applied to the column-normalized factors

    Returns
    -------
    (z_norm, w_norm) : ``z_norm`` columns sum to ``scale``; ``w_norm`` rows sum to 1.
    """
    z_sum = np.sum(z, axis=0)
    z_norm = z / z_sum.reshape(1, -1)
    w_scaled = w * z_sum
    w_norm = w_scaled / np.sum(w_scaled, axis=1, keepdims=True)
    return z_norm * scale, w_norm


z_nsf, w_nsf = normalize_scales(*load_latents("lvs/nsf.pkl"))
z_prismo, w_prismo = normalize_scales(*load_latents("lvs/prismo.pkl"))

In [19]:
# Pair each NSF factor with a PRISMO factor. Both index arrays returned by `match`
# are applied (the i-th selected NSF factor pairs with the i-th selected PRISMO
# factor), so the alignment holds even when the NSF side is not returned in its
# original order.
z_nsf_inds, z_prismo_inds, _ = match(z_nsf, z_prismo, dim=-1)

z_nsf = z_nsf[:, z_nsf_inds]
w_nsf = w_nsf[:, z_nsf_inds]
z_prismo = z_prismo[:, z_prismo_inds]
w_prismo = w_prismo[:, z_prismo_inds]

# Pearson correlation of each matched factor pair across observations.
n_factors = z_prismo.shape[1]
corr = np.array([pearsonr(z_prismo[:, k], z_nsf[:, k])[0] for k in range(n_factors)])

# Reorder all matched factors from most to least correlated.
factor_order = corr.argsort()[::-1]
corr = corr[factor_order]

z_prismo = z_prismo[:, factor_order]
w_prismo = w_prismo[:, factor_order]

z_nsf = z_nsf[:, factor_order]
w_nsf = w_nsf[:, factor_order]

## Plots

In [15]:
# Spatial maps of the matched and ordered PRISMO factors, saved as one PDF per factor.
covariates_df = pd.DataFrame(prismo_model.covariates["group_1"], columns=["x", "y"])
factors_df = pd.DataFrame(z_prismo, columns=[f"Factor {i+1}" for i in range(z_prismo.shape[1])])
df = pd.concat([covariates_df, factors_df], axis=1)
# Negate y, which flips the map vertically for display.
df["y"] = -df["y"]

for factor in range(1, z_prismo.shape[1] + 1):
    plot = (
        ggplot(df, aes(x="x", y="y", color=f"Factor {factor}"))
        + geom_point(size=0.1)
        + theme(
            figure_size=(3, 3),
            axis_text_x=element_blank(),
            axis_text_y=element_blank(),
            axis_ticks_major_x=element_blank(),
            axis_ticks_major_y=element_blank(),
            legend_key_width=15,
        )
        + labs(x="", y="")
        + coord_fixed()
        + scale_color_gradient(low="white", high="#0571b0")
    )

    fig = plot.draw()
    # Rasterize the scatter layer so the PDF does not store every point as a vector shape.
    fig.axes[0].collections[0].set_rasterized(True)
    fig.savefig(f"plots/prismo_factor_{factor}.pdf")
    # Release the figure; otherwise one figure per factor stays open in the kernel.
    plt.close(fig)

In [16]:
# Spatial maps of the matched and ordered NSF factors, saved as one PDF per factor.
covariates_df = pd.DataFrame(prismo_model.covariates["group_1"], columns=["x", "y"])
# Column names are derived from z_nsf itself (previously from z_prismo, which only
# worked because both models use the same number of factors).
factors_df = pd.DataFrame(z_nsf, columns=[f"Factor {i+1}" for i in range(z_nsf.shape[1])])
df = pd.concat([covariates_df, factors_df], axis=1)
# Negate y, which flips the map vertically for display.
df["y"] = -df["y"]

for factor in range(1, z_nsf.shape[1] + 1):
    plot = (
        ggplot(df, aes(x="x", y="y", color=f"Factor {factor}"))
        + geom_point(size=0.1)
        + theme(
            figure_size=(3, 3),
            axis_text_x=element_blank(),
            axis_text_y=element_blank(),
            axis_ticks_major_x=element_blank(),
            axis_ticks_major_y=element_blank(),
            legend_key_width=15,
        )
        + labs(x="", y="")
        + coord_fixed()
        + scale_color_gradient(low="white", high="#67001f")
    )

    fig = plot.draw()
    # Rasterize the scatter layer so the PDF does not store every point as a vector shape.
    fig.axes[0].collections[0].set_rasterized(True)
    fig.savefig(f"plots/nsf_factor_{factor}.pdf")
    # Release the figure; otherwise one figure per factor stays open in the kernel.
    plt.close(fig)

In [20]:
# Scatter of matched PRISMO vs. NSF weights, one facet per factor.
factor_names = [f"Factor {i+1}" for i in range(w_prismo.shape[1])]
# The long-format frames below are joined column-wise by row position, which only
# pairs the right values if both weight matrices share the same layout.
assert w_prismo.shape == w_nsf.shape, (w_prismo.shape, w_nsf.shape)

weights_prismo_df = pd.DataFrame(w_prismo, columns=factor_names)
weights_prismo_df_long = weights_prismo_df.melt(var_name="Factor", value_name="prismo_weight")

weights_nsf_df = pd.DataFrame(w_nsf, columns=factor_names)
weights_nsf_df_long = weights_nsf_df.melt(var_name="Factor", value_name="nsf_weight")["nsf_weight"]

weights_df_long = pd.concat([weights_prismo_df_long, weights_nsf_df_long], axis=1)
# Ordered categorical so facets follow factor order (1, 2, ..., 10) rather than
# string sort order (1, 10, 2, ...).
weights_df_long["Factor"] = pd.Categorical(weights_df_long["Factor"], categories=factor_names, ordered=True)

plot = (
    ggplot(weights_df_long, aes(x="prismo_weight", y="nsf_weight"))
    + geom_point(size=0.1, color="#0571b0")
    + theme(figure_size=(15, 7))
    + labs(x="PRISMO weight", y="NSF weight")
    + geom_abline(intercept=0, slope=1, linetype="dashed", alpha=0.5)
    + facet_wrap("~Factor", ncol=5)
    + coord_equal()
)
plot.save("plots/weights_comparison.pdf")

