In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
from scipy import sparse

In [None]:
import cell2location

In [None]:
# Sample label -> identifier string. The labels (keys) are what later cells
# compare against adata_vis.obs["Sample"]; the values are presumably the
# corresponding library / puck IDs — confirm.
samples_dict = {
    "Immature": "A0021_043",
    "0hr": "A0008_041",
    "1hr": "A0021_044",
    "4hr": "A0008_045",
    "4hr_replicate": "A0021_042",
    "6hr": "A0021_045",
    "8hr": "A0008_044",
    "8hr_replicate": "A0021_038",
    "11hr": "A0008_046",
    "12hr": "A0021_046",
}

In [None]:
# Scanpy session setup: logging level, a version report for reproducibility,
# and the default figure appearance used until a later cell overrides it.
# Verbosity levels: errors (0), warnings (1), info (2), hints (3).
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi=80,
    facecolor='white',
    frameon=True,
    figsize=(5, 5),
)

In [None]:
# Load the single-cell reference tables. Both CSVs use their first column as
# the row index; each shape is printed as a quick sanity check.
# Count table: its row labels are used as gene names further down.
ovary_raw = pd.read_csv(
    "./../scRNAseq_references/Ovary_subset_0_count_data.csv",
    sep=",",
    header=0,
    index_col=0,
)
print(ovary_raw.shape)

# Metadata table: passed as AnnData.obs further down.
ovary_meta_data = pd.read_csv(
    "./../scRNAseq_references/Ovary_subset_0_meta_data.csv",
    index_col=0,
)
print(ovary_meta_data.shape)


In [None]:
# Gene annotation table for AnnData.var: the row labels of the count table are
# kept as the index and repeated in a single column named "gene".
var_df = ovary_raw.index.to_frame(name="gene")

In [None]:
# Assemble the single-cell reference as an AnnData object. The count table is
# transposed so that its columns become observations, then stored sparsely.
# NOTE(review): X carries no labels, so this assumes the rows of
# ovary_meta_data are in the same order as the columns of ovary_raw — TODO confirm.
_counts_obs_by_var = sparse.csr_matrix(ovary_raw.T.to_numpy())
ovary_adata_sc = ad.AnnData(
    X=_counts_obs_by_var,
    obs=ovary_meta_data,
    var=var_df,
)
ovary_adata_sc

In [None]:
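# One-off caching step: uncomment to write the AnnData object built above to disk.
# The next cell reads this file, so it must exist before the notebook is run top to bottom.
# ovary_adata_sc.write_h5ad("./../pyobjs/ovary_adata_sc.h5ad")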


In [None]:
# Load the single-cell reference from its cached .h5ad file. This rebinds
# `ovary_adata_sc`, replacing the object built from the CSV tables, so the file
# must already exist and be up to date (the write_h5ad call that would create
# it is commented out in this notebook).
ovary_adata_sc = sc.read_h5ad("./../pyobjs/ovary_adata_sc.h5ad")
ovary_adata_sc

In [None]:
# sc.pp.normalize_total(ovary_adata_sc, inplace=True)
# sc.pp.log1p(ovary_adata_sc)
# sc.pp.highly_variable_genes(ovary_adata_sc, max_mean=3, min_disp=0.20)
# sc.pl.highly_variable_genes(ovary_adata_sc)
# ovary_adata_sc.raw = ovary_adata_sc
# ovary_adata_sc = ovary_adata_sc[:, ovary_adata_sc.var.highly_variable]
# print(ovary_adata_sc.shape)
# sc.pp.regress_out(ovary_adata_sc, ["total_counts"])
# sc.pp.scale(ovary_adata_sc, max_value=10)
# sc.pp.pca(ovary_adata_sc, random_state= 0)
# sc.pl.pca_variance_ratio(ovary_adata_sc, n_pcs = 50)
# sc.pp.neighbors(ovary_adata_sc, n_pcs=20)
# sc.tl.umap(ovary_adata_sc)
# sc.tl.leiden(ovary_adata_sc, key_added="leiden_1.0", resolution = 1.0)
# sc.tl.leiden(ovary_adata_sc, key_added="leiden_1.2", resolution = 1.2)

In [None]:
from cell2location.utils.filtering import filter_genes

# Choose which genes to keep in the reference. The meaning of the three
# cutoffs is defined by cell2location's filter_genes (see its documentation).
selected = filter_genes(
    ovary_adata_sc,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12,
)

# Keep only the selected genes; .copy() turns the sliced view into a
# standalone object. Re-running this cell filters the already-filtered object.
ovary_adata_sc = ovary_adata_sc[:, selected].copy()

In [None]:
# List the per-observation metadata columns available in the reference
# (the model setup below uses 'mouse' and 'Level1').
ovary_adata_sc.obs.columns

In [None]:
# Number of observations per value of the 'mouse' column, which is used as the
# batch key when the regression model is set up.
ovary_adata_sc.obs["mouse"].value_counts()

In [None]:
# prepare anndata for the regression model
cell2location.models.RegressionModel.setup_anndata(adata=ovary_adata_sc,
                        # 10X reaction / sample / batch
                        batch_key='mouse',
                        # cell type, covariate used for constructing signatures
                        labels_key='Level1')

In [None]:
# Create the reference regression model on the registered AnnData object.
# `mod` is rebound further down to the spatial Cell2location model.
from cell2location.models import RegressionModel
mod = RegressionModel(ovary_adata_sc)

# view anndata_setup as a sanity check
mod.view_anndata_setup()

In [None]:
# Fit the reference regression model for 250 epochs.
# NOTE(review): `use_gpu=0` is the integer 0, not False. scvi-tools-style
# `use_gpu` arguments presumably treat an int as a GPU device index, so this
# would select GPU 0 rather than the CPU — confirm against the installed version.
mod.train(max_epochs=250, use_gpu=0)

In [None]:
# Plot the training history of the regression model. The argument 20 is
# presumably the number of initial epochs left out of the plot — confirm
# against the cell2location API.
mod.plot_history(20)

In [None]:
# In this section, we export the estimated cell abundance (summary of the posterior distribution).
ovary_adata_sc = mod.export_posterior(
    ovary_adata_sc, sample_kwargs={'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
)

# Save model
mod.save("./../pyobjs/ovary_cell2location_sc_ref_mod_level1", overwrite=True)

# Save anndata object with results
ovary_adata_sc.write("./../pyobjs/ovary_cell2location_sc_ref_level1.h5ad")

In [None]:
# Diagnostic plots for the model currently bound to `mod`
# (the reference regression model at this point in the notebook).
mod.plot_QC()

In [None]:
# Reload the saved reference results and the matching regression model, so the
# cells below can run without retraining.
adata_ref = sc.read_h5ad("./../pyobjs/ovary_cell2location_sc_ref_level1.h5ad")
mod = cell2location.models.RegressionModel.load(
    "./../pyobjs/ovary_cell2location_sc_ref_mod_level1",
    adata_ref,
)

In [None]:
# export estimated expression in each cluster
if 'means_per_cluster_mu_fg' in adata_ref.varm.keys():
    inf_aver = adata_ref.varm['means_per_cluster_mu_fg'][[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
else:
    inf_aver = adata_ref.var[[f'means_per_cluster_mu_fg_{i}'
                                    for i in adata_ref.uns['mod']['factor_names']]].copy()
inf_aver.columns = adata_ref.uns['mod']['factor_names']
print(inf_aver.shape)
inf_aver.iloc[0:10, 0:10]

In [None]:
# Load the filtered raw-count spatial data, then keep only the observations
# that are also present in the processed object (matched by obs_names).
adata_vis = sc.read_h5ad("./../pyobjs/adata_ovary_combined_raw_counts_filtered.h5ad")
_processed_obs_names = sc.read_h5ad("./../pyobjs/adata_ovary_combined_processed0.h5ad").obs_names
adata_vis = adata_vis[_processed_obs_names]
adata_vis

In [None]:
# find shared genes and subset both anndata and reference signatures
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
print(len(intersect))
adata_vis = adata_vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

# prepare anndata for cell2location model
cell2location.models.Cell2location.setup_anndata(adata=adata_vis, batch_key="Sample")

In [None]:
# create and train the model
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=1,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection:
    detection_alpha=20
)
mod.view_anndata_setup()

In [None]:
# Train the spatial model.
#   batch_size=5000 : a fixed batch size is set here; batch_size=None would
#                     train on the full data in one batch
#   train_size=1.0  : use all data points in training because cell abundance
#                     must be estimated at all locations
# NOTE(review): `use_gpu=1` is an integer; if it is read as a device index it
# selects the second GPU, whereas the export cell uses `use_gpu=True` — confirm.
mod.train(
    max_epochs=1000,
    batch_size=5000,
    train_size=1.0,
    use_gpu=1,
)

In [None]:
# plot ELBO loss history during training, removing first 100 epochs from the plot
mod.plot_history(20)
plt.legend(labels=['full data training'])

In [None]:
# In this section, we export the estimated cell abundance (summary of the posterior distribution).
adata_vis = mod.export_posterior(
    adata_vis, sample_kwargs={'num_samples': 500, 'batch_size': 5000, 'use_gpu': True}
)

# Save model
mod.save("./../pyobjs/ovary_cell2location_ss_combined_mod_level1", overwrite=True)

# Save anndata object with results
adata_vis.write("./../pyobjs/ovary_cell2location_ss_combined_level1.h5ad")

In [None]:
# Diagnostic plots for the model currently bound to `mod`
# (the spatial Cell2location model at this point in the notebook).
mod.plot_QC()

In [None]:
# Reload the saved spatial results from disk.
# The model itself is NOT reloaded here (the load call below is commented
# out), so any later cell that uses `mod` relies on the trained model still
# being in memory from the training cell.
adata_vis = sc.read_h5ad("./../pyobjs/ovary_cell2location_ss_combined_level1.h5ad")
# mod = cell2location.models.Cell2location.load(f"./../pyobjs/ovary_cell2location_ss_combined_mod_level1", adata_vis)
adata_vis

In [None]:
print(adata_vis.obsm)

# 5% quantile of the posterior cell abundance ("at least this amount is
# present"), one column per factor, and its per-location total.
_abundance_q05 = adata_vis.obsm['q05_cell_abundance_w_sf']
_abundance_total = _abundance_q05.sum(axis=1)
_factor_names = adata_vis.uns['mod']['factor_names']

# Convert abundances to per-location proportions (each row divided by its total).
adata_vis.obsm['q05_cell_proportions'] = _abundance_q05.div(_abundance_total, axis=0)
adata_vis.obs["total_abundance"] = _abundance_total

# Copy the proportions into .obs under the plain factor names for plotting.
adata_vis.obs[_factor_names] = adata_vis.obsm['q05_cell_proportions']

# "Level1_"-prefixed factor names. Only the disabled clipping loop below uses
# them, and this cell does not create .obs columns with those prefixed names.
ct_list = ["Level1_" + x for x in list(_factor_names)]
# for ct in ct_list:
#     data = adata_vis.obs[ct].values
#     adata_vis.obs[ct] = np.clip(data,0, np.quantile(data, 0.90))

# Highest proportion at each location and the factor that attains it.
adata_vis.obs["Level1_" + "max_pred"] = adata_vis.obs[_factor_names].max(axis=1)
adata_vis.obs["Level1_" + "max_pred_celltype"] = adata_vis.obs[_factor_names].idxmax(axis=1)

In [None]:
# plot in spatial coordinates
sc.settings.set_figure_params(dpi_save= 400, fontsize=6, figsize=(3.0, 3.0), facecolor='white', frameon=False, transparent=True, vector_friendly = True, format="pdf")
for sample in samples_dict.keys():    
    sc.pl.spatial(adata_vis[adata_vis.obs["Sample"] == sample], cmap="Blues",
                      # show first 8 cell types
                      color=list(adata_vis.obs["Level1_max_pred_celltype"].value_counts().index), spot_size= 30, 
                  ncols=6, wspace = 0.2, hspace=0.2,
                      # limit color scale at 99.2% quantile of cell abundance
                     vmin=0, vmax='p99.2', legend_fontsize=5)

In [None]:
# Compute expected expression per cell type
expected_dict = mod.module.model.compute_expected_per_cell_type(
    mod.samples["post_sample_q05"], mod.adata_manager  
)

# Add to anndata layers
for i, n in enumerate(mod.factor_names_):
    adata_vis.layers["Level1_" + n] = expected_dict['mu'][i]

In [None]:
# Display the AnnData summary, e.g. to check the "Level1_*" layers.
adata_vis

In [None]:
# Save anndata object with results
adata_vis.write_h5ad("./../pyobjs/slideseq_cell2loc_RCTD_level1.h5ad")

In [None]:
# Reload the saved results so the plotting below can start from this file.
adata_vis = sc.read_h5ad("./../pyobjs/slideseq_cell2loc_RCTD_level1.h5ad")
adata_vis

In [None]:
# Overview figure: one panel per sample, each spot coloured by the cell type
# with the highest predicted proportion ("Level1_max_pred_celltype").
# The grid is derived from the number of samples (5 panels per row) instead of
# being fixed at 2x5, so adding a sample no longer raises an IndexError.
# With the current 10 samples this gives the same 2x5 grid and (30, 8) size.
_n_cols = 5
_n_rows = max(1, -(-len(samples_dict) // _n_cols))  # ceiling division
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(_n_rows, _n_cols, figsize=(30, 4 * _n_rows), squeeze=False)

# These settings are the same for every panel, so they are applied once
# rather than on every loop iteration. The call stays after plt.subplots so
# the figure is created under the same settings as before.
sc.settings.set_figure_params(dpi=200, dpi_save= 300, fontsize=10, facecolor='white', frameon=False, figsize=(2.0, 2.0), vector_friendly = False, transparent=True, format="pdf")

for i, sample in enumerate(samples_dict.keys()):
    # Fill the grid row by row.
    _row, _col = divmod(i, _n_cols)
    sc.pl.spatial(adata_vis[adata_vis.obs["Sample"] == sample], color = ["Level1_max_pred_celltype"], wspace= 1.0, spot_size = 30, frameon=False, title=sample, show=False, ax=axs[_row, _col])

# Hide grid cells that have no sample (none when the grid is exactly full).
for _unused_ax in axs.flat[len(samples_dict):]:
    _unused_ax.set_visible(False)