In [None]:
import sys
import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.io import mmread
from scipy.io import mmwrite
from scipy.sparse import csr_matrix


In [None]:
# Modelling (cell2location / scvi) and plotting (seaborn) libraries for this workflow.
# NOTE(review): these imports were split into their own cell because of a jax/jaxlib
# issue (https://github.com/google/jax/issues/5501); no explicit workaround is visible
# here — confirm whether the import order still matters before reorganising.
import cell2location
import scvi

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text
import seaborn as sns


In [None]:
# Every output of this notebook is written below a single results folder.
results_folder = '../results/visium_axolotl_R12830_resequenced_20220308/cell2location_coarse_out'

# One sub-folder per model: the snRNA-seq reference regression and the
# spatial cell2location mapping.
ref_run_name = results_folder + '/reference_signatures'
run_name = results_folder + '/cell2location_map'


In [None]:
# Load the snRNA-seq count matrix as CSR and round it to whole counts.
# `np.round` has no sparse-aware implementation and only reaches a csr_matrix
# through numpy's generic object-array fallback; the matrix's own `rint()`
# rounds the stored values directly and keeps the result sparse.
original_counts = csr_matrix(mmread("../data/snRNAseq_countMatrix.mtx")).rint()
# Stored values below 0.5 round to 0; drop those explicit zeros from the structure.
original_counts.eliminate_zeros()


In [None]:
# Display the rounded sparse count matrix's summary as a quick sanity check.
original_counts

In [None]:
# Per-nucleus metadata; the first CSV column becomes the row index
# (presumably the cell barcodes matching the count-matrix rows — verify below).
original_meta = pd.read_csv(
    "../data/snRNAseq_countMatrix_metadata.csv",
    index_col=0,
)
original_meta

In [None]:
# Per-gene annotation; the first CSV column becomes the row index
# (presumably the gene names matching the count-matrix columns — verify below).
original_genes = pd.read_csv(
    "../data/snRNAseq_countMatrix_gene.csv",
    index_col=0,
)
original_genes

In [None]:
# Assemble the reference AnnData from the counts plus row (obs) and column (var) annotations.
adata_ref = ad.AnnData(
    X=original_counts,
    obs=original_meta,
    var=original_genes,
)

In [None]:
# Show the AnnData summary, then a small dense preview to double-check that the
# cell barcodes (rows) and gene names (columns) are correct.
# A bare `adata_ref` on a non-final line displays nothing, so `display` is used.
# `to_df()` on the full object would densify the entire count matrix, so only a
# 5x5 slice is converted.
display(adata_ref)
adata_ref[:5, :5].to_df()

In [None]:
# Optional cell filtering step — currently disabled.
# NOTE(review): the line below uses `pal_data` / `obs.cellclusters`, which do not appear
# to be defined elsewhere in this notebook; adapt it to `adata_ref` before enabling.
# pal_data = pal_data[pal_data.obs.cellclusters!="glut_SUBSET_23",]

In [None]:
from cell2location.utils.filtering import filter_genes

# Choose which genes to keep using cell2location's gene filter
# (thresholds chosen for this dataset); the result is used to index genes below.
selected = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12,
)

# Restrict the reference to those genes and materialise a real object, not a view.
adata_ref = adata_ref[:, selected].copy()


In [None]:
# Inspect the reference AnnData after gene filtering.
adata_ref

In [None]:
# Register the reference with the regression model:
#   - no batch (10X reaction / sample) key,
#   - `celltypes` labels define the signatures to estimate,
#   - `condition` is treated as a multiplicative technical covariate.
cell2location.models.RegressionModel.setup_anndata(
    adata=adata_ref,
    batch_key=None,
    labels_key='celltypes',
    categorical_covariate_keys=['condition'],
)


In [None]:
from cell2location.models import RegressionModel

# Instantiate the reference regression model on the registered AnnData ...
mod = RegressionModel(adata_ref)

# ... and print how the AnnData fields were registered, as a sanity check.
mod.view_anndata_setup()

In [None]:
# Train on all nuclei: train_size=1 means no held-out validation split.
# NOTE(review): an earlier note estimated ~3.5 h for 175 epochs and suggested
# 200-250 epochs; this run uses 100 — check the ELBO curve in the next cell for
# convergence before relying on the signatures.
mod.train(
    max_epochs=100,
    batch_size=2500,
    train_size=1,
    lr=0.002,
    use_gpu=False,
)


In [None]:
# plot ELBO loss history during training, removing first 20 epochs from the plot
# (the regression training cell above runs 100 epochs, so 80 remain visible)
mod.plot_history(20)


In [None]:
# Export the posterior summary of the reference model into adata_ref. For the
# regression model this is the estimated expression per cluster
# (`means_per_cluster_mu_fg`), which is read back when extracting the
# signatures — not cell abundance.
adata_ref = mod.export_posterior(
    adata_ref,
    sample_kwargs=dict(num_samples=1000, batch_size=2500, use_gpu=False),
)

# Persist the trained model.
mod.save(ref_run_name, overwrite=True)

# Write the annotated reference AnnData (this file name is used again when reloading).
adata_file = ref_run_name + "/sc_corase_celltypes_v3.h5ad"
adata_ref.write(adata_file)
adata_file

In [None]:
# Diagnostic QC plots for the trained reference regression model.
mod.plot_QC()

In [None]:
# Reload the gene signatures and model

In [None]:
# Reload the exported reference AnnData, e.g. in a fresh kernel without retraining.
# `ref_run_name` must point to the folder the reference model was saved to.
adata_file = "{}/sc_corase_celltypes_v3.h5ad".format(ref_run_name)
adata_ref = sc.read_h5ad(adata_file)

In [None]:
# Display the reloaded reference AnnData to confirm it was read correctly.
adata_ref

In [None]:
# Restore the trained regression model from disk and attach it to the reloaded AnnData.
mod = cell2location.models.RegressionModel.load(ref_run_name, adata=adata_ref)

In [None]:
# Collect the estimated expression of every gene in each cluster as a DataFrame
# (rows: genes, columns: the model's factor names). The values are read from
# `varm['means_per_cluster_mu_fg']` when present, otherwise from the
# `means_per_cluster_mu_fg_<factor>` columns of `var`.
_factor_names = adata_ref.uns['mod']['factor_names']
_signature_cols = [f'means_per_cluster_mu_fg_{name}' for name in _factor_names]
if 'means_per_cluster_mu_fg' in adata_ref.varm.keys():
    _signature_source = adata_ref.varm['means_per_cluster_mu_fg']
else:
    _signature_source = adata_ref.var
inf_aver = _signature_source[_signature_cols].copy()
inf_aver.columns = _factor_names
inf_aver.iloc[0:5, 0:5]

In [None]:
# load visium data (spatial counts to be deconvolved with the reference signatures)
vis = sc.read_loom("../data/visium_Amex_all.loom") # all 4 slices merged


In [None]:
from scipy.sparse import issparse

# Round the spatial counts to whole numbers, as done for the reference counts.
# The loom matrix may be sparse: `np.round` is not sparse-aware, so sparse input
# is rounded with the matrix's own `rint()` (kept as CSR) and explicit zeros
# created by rounding are dropped; dense input still uses `np.round`.
if issparse(vis.X):
    _rounded_X = vis.X.tocsr().rint()
    _rounded_X.eliminate_zeros()
    vis.X = _rounded_X
else:
    vis.X = np.round(vis.X)

In [None]:
# Keep only genes present in both the spatial data and the reference signatures.
# `np.intersect1d` returns the shared names sorted, so both objects end up in the
# same gene order.
intersect = np.intersect1d(vis.var_names, inf_aver.index)
# Fail fast on an empty overlap (e.g. mismatched gene identifiers) instead of
# passing zero-gene objects into the model setup.
if intersect.size == 0:
    raise ValueError(
        "No genes shared between the Visium data and the reference signatures; "
        "check that both use the same gene identifiers."
    )
vis = vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()



In [None]:
# Register the spatial AnnData with cell2location: `nCount_Spatial` as a
# continuous covariate and `condition` as a categorical covariate.
# NOTE(review): no batch_key is set although the loom file merges several
# slices — confirm this is intended.
cell2location.models.Cell2location.setup_anndata(
    adata=vis,
    continuous_covariate_keys=["nCount_Spatial"],
    categorical_covariate_keys=['condition'],
)


In [None]:
# Build the spatial mapping model from the Visium counts and the reference
# signatures in `inf_aver` (training happens in the next cell).
mod = cell2location.models.Cell2location(
    vis,
    cell_state_df=inf_aver,
    # Expected average cell abundance per location: a tissue-dependent
    # hyper-prior that can be estimated from paired histology.
    N_cells_per_location=10,
    # Hyperparameter controlling normalisation of within-experiment
    # variation in RNA detection.
    detection_alpha=20,
)
mod.view_anndata_setup()

In [None]:
# Fit cell2location with full-batch training (batch_size=None) on all locations
# (train_size=1), because cell abundance must be estimated everywhere.
# NOTE(review): confirm 200 epochs is enough — inspect the ELBO curve below.
mod.train(
    max_epochs=200,
    batch_size=None,
    train_size=1,
    use_gpu=False,
)

In [None]:
# Plot the ELBO loss history, removing the first 100 epochs from the plot.
# The number skipped must stay below the training cell's max_epochs (200):
# skipping 200 removed every recorded epoch and left the plot empty.
mod.plot_history(100)
plt.legend(labels=['full data training']);


In [None]:
# Show where the spatial model and its outputs are saved.
run_name # folder to save the result

In [None]:
# Export the estimated cell abundance (summary of the posterior distribution)
# into `vis`, sampling all observations in a single batch.
vis = mod.export_posterior(
    vis,
    sample_kwargs=dict(
        num_samples=1000,
        batch_size=mod.adata.n_obs,
        use_gpu=False,
    ),
)



In [None]:
# Persist the trained spatial model.
mod.save(run_name, overwrite=True)
# To restore it later: mod = cell2location.models.Cell2location.load(f"{run_name}", vis)

# Write the spatial AnnData together with the exported posterior.
adata_file = run_name + "/sp.h5ad"
vis.write(adata_file)
adata_file

In [None]:
# Copy the 5% posterior quantile of cell abundance ("at least this amount is
# present") into obs, one column per factor name, then export obs as CSV.
_q05_abundance = vis.obsm['q05_cell_abundance_w_sf']
vis.obs[vis.uns['mod']['factor_names']] = _q05_abundance
vis.obs.to_csv(f"{results_folder}/predictions_cell2loc_v1.csv")

In [None]:
# Diagnostic QC plots for the trained spatial cell2location model.
mod.plot_QC()