In [14]:
import altair as alt
import pandas as pd
from data_loader import load_mefisto_visium
from mofapy2.run.entry_point import entry_point

In [2]:
# Load the MEFISTO Visium example dataset via the project's data_loader helper.
# The result is used below as an AnnData object (n_obs, obsm["spatial"], and
# the "highly_variable" feature subset).
adata = load_mefisto_visium()

In [3]:
# Configure a MEFISTO model: MOFA+ factors with a smoothness prior over the
# spatial coordinates of each spot.
ent = entry_point()
ent.set_data_options(use_float32=True)
# Restrict to the highly variable genes (presumably a boolean column in
# adata.var; the training log reports D=2000 features).
ent.set_data_from_anndata(adata, features_subset="highly_variable")
ent.set_model_options(factors=4)
# Training options are set once, with a fixed seed so runs are reproducible.
ent.set_train_options(seed=2021)

# We use 1000 inducing points to learn spatial covariance patterns
n_inducing = 1000
# frac_inducing is expressed as a fraction of the samples; cap it at 1 so a
# dataset with fewer than n_inducing spots never requests more inducing
# points than there are samples.
frac_inducing = min(1.0, n_inducing / adata.n_obs)

ent.set_covariates([adata.obsm["spatial"]], covariates_names=["imagerow", "imagecol"])
# start_opt/opt_freq: the training log shows the sigma node first being
# optimised at iteration 10; those iterations are much slower than the rest.
ent.set_smooth_options(sparseGP=True, frac_inducing=frac_inducing,
                       start_opt=10, opt_freq=10)


        #########################################################
        ###           __  __  ____  ______                    ### 
        ###          |  \/  |/ __ \|  ____/\    _             ### 
        ###          | \  / | |  | | |__ /  \ _| |_           ### 
        ###          | |\/| | |  | |  __/ /\ \_   _|          ###
        ###          | |  | | |__| | | / ____ \|_|            ###
        ###          |_|  |_|\____/|_|/_/    \_\              ###
        ###                                                   ### 
        ######################################################### 
       
 
        
use_float32 set to True: replacing float64 arrays by float32 arrays to speed up computations...

Loaded view='rna' group='group1' with N=2487 samples and D=2000 features...


Model options:
- Automatic Relevance Determination prior on the factors: False
- Automatic Relevance Determination prior on the weights: True
- Spike-and-slab prior on the factors: False
- Spike-and-slab pr

In [5]:
# Build the model from the options configured on `ent`, then train it.
# In the log below, iterations where the sigma node is optimised take far
# longer (~43 s) than ordinary iterations (<1 s).
ent.build()
ent.run()



######################################
## Training the model with seed 2021 ##
######################################


ELBO before training: -18666215.52 

Iteration 1: time=0.82, ELBO=-206209.44, deltaELBO=18460006.082 (98.89527986%), Factors=4
Iteration 2: time=0.80, ELBO=-89797.44, deltaELBO=116412.000 (0.62365079%), Factors=4
Iteration 3: time=0.84, ELBO=-75723.95, deltaELBO=14073.494 (0.07539554%), Factors=4
Iteration 4: time=0.81, ELBO=-72133.51, deltaELBO=3590.443 (0.01923498%), Factors=4
Iteration 5: time=0.81, ELBO=-69424.73, deltaELBO=2708.780 (0.01451167%), Factors=4
Iteration 6: time=0.80, ELBO=-67148.28, deltaELBO=2276.451 (0.01219557%), Factors=4
Iteration 7: time=0.79, ELBO=-65220.17, deltaELBO=1928.106 (0.01032939%), Factors=4
Iteration 8: time=0.78, ELBO=-63647.17, deltaELBO=1572.999 (0.00842698%), Factors=4
Iteration 9: time=0.79, ELBO=-62424.39, deltaELBO=1222.781 (0.00655077%), Factors=4
Optimising sigma node...
Iteration 10: time=43.25, ELBO=-15753.98, deltaELBO

In [23]:
# Inspect which model nodes expose parameters after training.
ent.model.getParameters().keys()

dict_keys(['U', 'Z', 'W', 'Tau', 'Y', 'AlphaW', 'Sigma', 'ThetaW'])

In [26]:
# Posterior mean of the "U" node; the displayed array has one column per factor (4 here).
# NOTE(review): with sparseGP=True this node may hold inducing-point values
# rather than one row per spot — confirm `.shape[0] == adata.n_obs` before
# plotting it against adata.obsm["spatial"].
ent.model.getParameters()["U"]["mean"]

array([[-1.30020598e+00, -2.89240197e+00,  7.54751502e+00,
         1.05253046e+01],
       [-1.97068658e+00, -2.47724139e+00,  6.86922113e+00,
         9.69276762e-01],
       [-1.81670868e+00, -3.01132671e+00,  7.48438878e+00,
         3.04116581e+00],
       ...,
       [-2.93948787e+00, -4.19927774e-01,  6.34242775e-01,
         8.80119092e-01],
       [-4.04150650e+00, -2.53989757e-02,  9.16735291e-03,
         6.18074512e-01],
       [-4.11376374e+00,  6.58559220e-01, -2.60403635e-01,
        -1.24541259e+00]])

In [24]:
# Expectations are keyed by the same node names as getParameters() (compare outputs).
ent.model.getExpectations().keys()

dict_keys(['U', 'Z', 'W', 'Tau', 'Y', 'AlphaW', 'Sigma', 'ThetaW'])

In [17]:
def plot_factors_covariate_2d(factors, covariates, show=True):
    """Plot each latent factor as a colour map over 2-D covariate coordinates.

    One scatter panel is drawn per factor, with covariate columns 0 and 1 on
    the x and y axes and the factor value as a diverging colour scale centred
    on 0. Panels are laid out side by side.

    Args:
        factors: 2-D array-like of shape (n_samples, n_factors). Column labels
            (if any) are converted to strings and used as field names.
        covariates: 2-D numpy array with one row per sample; columns 0 and 1
            are used as the x and y coordinates.
        show: If True (default), display the chart and return None, as before.
            If False, return the chart without displaying it, e.g. to save it
            or combine it with other charts.

    Returns:
        None when ``show`` is True, otherwise the horizontally concatenated chart.
    """
    # Build a fresh frame instead of editing one in place: pd.DataFrame(df)
    # may share storage with a caller's DataFrame, so in-place column edits
    # could leak back. rename/assign always return new objects.
    plot_df = pd.DataFrame(factors).rename(columns=str).assign(
        **{f"covariate_{i}": covariates[:, i] for i in range(covariates.shape[-1])}
    )
    # The first n_factors columns are the factors, whatever their labels.
    factor_fields = list(plot_df.columns[: factors.shape[1]])

    # Coordinates are continuous, so encode them as quantitative: an ordinal
    # encoding spaces every distinct value equally and distorts the geometry.
    # The y scale is reversed so small values sit at the top, matching the
    # orientation the former ordinal y axis produced.
    factor_charts = [
        alt.Chart(plot_df)
        .mark_point(filled=True)
        .encode(
            x=alt.X(
                "covariate_0:Q",
                title="Covariate dim 1",
                axis=alt.Axis(labels=False),
                scale=alt.Scale(zero=False),
            ),
            y=alt.Y(
                "covariate_1:Q",
                title="Covariate dim 2",
                axis=alt.Axis(labels=False),
                scale=alt.Scale(zero=False, reverse=True),
            ),
            color=alt.Color(f"{field}:Q", scale=alt.Scale(scheme="redblue", domainMid=0)),
        )
        .properties(width=300, height=300, title=f"Factor {idx + 1} with covariates")
        .interactive()
        for idx, field in enumerate(factor_fields)
    ]

    # Place the panels side by side (horizontal concatenation).
    final_chart = alt.hconcat(*factor_charts)

    if not show:
        return final_chart
    # Display explicitly and return None so a call on a cell's last line does
    # not render the chart a second time via its rich repr.
    final_chart.display()
    return None