In [None]:
# Load packages

import sys
import os

import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import cell2location
import scvi

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text
import seaborn as sns

def slurm_task_index(env=None):
    """Return the zero-based sample index for the current SLURM array task.

    SLURM array task IDs are expected to start at 1, so the index is ``task_id - 1``.

    :param env: mapping to read ``SLURM_ARRAY_TASK_ID`` from (defaults to ``os.environ``)
    :return: zero-based row index into the sample metadata table
    :raises ValueError: if the task ID is not an integer >= 1
    """
    env = os.environ if env is None else env
    raw_id = env.get('SLURM_ARRAY_TASK_ID')
    if raw_id is None:
        # Interactive run outside a SLURM array job: fall back to the first sample
        # instead of failing with an opaque TypeError from int(None).
        print("SLURM_ARRAY_TASK_ID is not set; using the first sample (sid = 0).",
              file=sys.stderr)
        return 0
    task_id = int(raw_id)
    if task_id < 1:
        # task_id 0 would give sid = -1, which is not a valid row label downstream
        raise ValueError(f"SLURM_ARRAY_TASK_ID must be >= 1, got {raw_id!r}")
    return task_id - 1


sid = slurm_task_index()

In [None]:
# Read the sample metadata table. Later cells read its "SampleID" and "Number"
# columns, using `sid` to pick the entry for this run.
ID_map = pd.read_csv("/edgehpc/dept/compbio/projects/TST11523/mouse_cpz/ID.map.csv")
# Preview the first five rows.
ID_map.head(5)

In [None]:
# Sample ID for this run: the "SampleID" entry selected by `sid`.
# It names the output folder and the Visium input folder used in later cells.
sample_id = ID_map["SampleID"][sid]
sample_id

In [None]:
# Matching "Number" entry for the same `sid`; used below to select this
# sample's cells from the snRNA-seq reference (compared against snrna.obs['Number']).
ref_id = ID_map["Number"][sid]
ref_id

In [None]:
# Define input folders and the per-sample result folders
sp_data_folder = "/edgehpc/dept/compbio/projects/TST11523/mouse_cpz/data-raw/"
sc_data_folder = "/edgehpc/dept/compbio/projects/TST11523/mouse_cpz/data-raw/snRNAseq/cuprizonemouse_brain_alllevels_comb_clean_splitsamples/"
results_folder = "/edgehpc/dept/compbio/projects/TST11523/mouse_cpz/out/cell2location/" + sample_id
# Both sub-folders are derived from results_folder so the three paths cannot drift apart.
run_name = os.path.join(results_folder, 'cell2location_map')
ref_run_name = os.path.join(results_folder, 'reference_signatures')
# makedirs(..., exist_ok=True) creates results_folder (and any missing parents) on the way,
# and also repairs a partially created tree: previously the sub-folders were only
# created when results_folder itself did not exist yet.
os.makedirs(run_name, exist_ok=True)
os.makedirs(ref_run_name, exist_ok=True)

In [9]:
# Read the snRNA-seq reference AnnData from disk; it is subset to a single
# sample (by obs['Number']) in the next cell.
snrna=anndata.read_h5ad("/edgehpc/dept/compbio/users/msheehan/msheehan/sc_sn_RNAseq/20220713_TST11621_cuprizone_finalLT/20220720_cuprizonemousebrain_cell2loc_and_OL_labels_added_clean.h5ad")

In [10]:
# Subset the snRNA-seq reference to the cells whose obs['Number'] equals ref_id,
# i.e. the reference sample paired with this run's spatial sample in ID_map.
adata_snrna = snrna[snrna.obs['Number']==ref_id]

In [14]:
# Build a fresh AnnData for the reference model from the subset's `.raw` matrix,
# reusing the subset's obs and var tables and requesting dtype int32.
# NOTE(review): X comes from adata_snrna.raw but var comes from adata_snrna.var —
# this assumes .raw and the main object hold the same genes in the same order; confirm.
# NOTE(review): presumably .raw holds untransformed integer counts — verify.
X = adata_snrna.raw.X
adata_ref = anndata.AnnData(X, obs=adata_snrna.obs, var=adata_snrna.var, dtype='int32')

In [None]:
# Rename the var column 'gene_ids-2G-2' to 'gene_ids'; the next cell re-indexes var by it.
adata_ref.var.rename(columns = {'gene_ids-2G-2':'gene_ids'}, inplace = True)
adata_ref

In [None]:
# Use ENSEMBL as gene IDs to make sure IDs are unique and correctly matched
# (assumes the 'gene_ids' column holds ENSEMBL IDs — TODO confirm; uniqueness is not checked here).
# Snapshot the current object into .raw before the var index is changed.
adata_ref.raw = adata_ref
# Keep the current var index (presumably gene symbols — verify) in a 'SYMBOL' column,
# then re-index var by 'gene_ids'. SYMBOL must be stored first, before the index is replaced.
adata_ref.var['SYMBOL'] = adata_ref.var.index
adata_ref.var.index = adata_ref.var['gene_ids'].copy()
adata_ref.var_names = adata_ref.var['gene_ids'].copy()
adata_ref.var.index.name = None
# Apply the same SYMBOL / gene_ids re-indexing to the var table of .raw.
adata_ref.raw.var['SYMBOL'] = adata_ref.raw.var.index
adata_ref.raw.var.index = adata_ref.raw.var['gene_ids'].copy()
adata_ref.raw.var.index.name = None

In [None]:
# Remove cells and genes with 0 counts everywhere:
# keep cells with at least one detected gene and genes detected in at least one cell.
sc.pp.filter_cells(adata_ref, min_genes=1)
sc.pp.filter_genes(adata_ref, min_cells=1)

In [None]:
## Before estimating the reference cell type signatures we recommend performing a very permissive gene selection.
## In the 2D histogram produced here, the orange rectangle lies over the excluded genes.
from cell2location.utils.filtering import filter_genes
# `selected` is used below to index the genes (columns) of adata_ref that are kept.
selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)

## Filter the object; .copy() makes adata_ref an independent object rather than a subset view.
adata_ref = adata_ref[:, selected].copy()

In [None]:
## Prepare anndata for the regression model: register obs['Number'] as the batch
## and obs['final_label'] as the cell type label; no extra categorical covariates are used
## (the 'BiogenBatch' option is left commented out).
# NOTE(review): this is the module-level scvi.data.setup_anndata call; newer scvi-tools /
# cell2location releases may expect setup_anndata on the model class — confirm against
# the installed versions.
scvi.data.setup_anndata(adata=adata_ref,
                        # 10X reaction / sample / batch
                        batch_key='Number',
                        # cell type, covariate used for constructing signatures
                        labels_key='final_label',
                        # multiplicative technical effects (platform, 3' vs 5', donor effect)
                        #categorical_covariate_keys=['BiogenBatch']
                       )
# Show a summary of what was registered.
scvi.data.view_anndata_setup(adata_ref)

In [None]:
## Create and train the regression model that estimates the reference signatures.
## NOTE: `mod` is reused further down for the spatial Cell2location model.
from cell2location.models import RegressionModel
mod = RegressionModel(adata_ref)

## Use all data for training (validation not implemented yet, train_size=1)
## Increase max_epochs if the loss plot is still decreasing (default 250; 300 is used here)
mod.train(max_epochs=300, batch_size=2500, train_size=1, lr=0.002, use_gpu=True)

# plot ELBO loss history during training, removing first 20 epochs from the plot
mod.plot_history(20)

In [None]:
# Export a summary of the posterior distribution of the reference model into adata_ref,
# drawing 1000 posterior samples in batches of 2500 on the GPU.
posterior_sampling = {'num_samples': 1000, 'batch_size': 2500, 'use_gpu': True}
adata_ref = mod.export_posterior(adata_ref, sample_kwargs=posterior_sampling)

# Persist the trained model in the reference-signature folder (replacing any earlier save)
mod.save(ref_run_name, overwrite=True)

# Write the annotated reference object next to the model and display its path
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref.write(adata_file)
adata_file

In [None]:
# Plot the QC diagnostics of the trained reference regression model
# (`mod` is still the RegressionModel at this point).
mod.plot_QC()

In [None]:
# Export the estimated expression in each cluster as a genes x cell-types table.

factor_names = adata_ref.uns['mod']['factor_names']
signature_cols = [f'means_per_cluster_mu_fg_{name}' for name in factor_names]

# The per-cluster means are read from varm when present there, otherwise from var.
signature_table = (
    adata_ref.varm['means_per_cluster_mu_fg']
    if 'means_per_cluster_mu_fg' in adata_ref.varm.keys()
    else adata_ref.var
)
inf_aver = signature_table[signature_cols].copy()

# Label the columns with the plain factor names and preview a corner of the table.
inf_aver.columns = factor_names
inf_aver.iloc[0:5, 0:5]

In [None]:
# Define functions for loading in samples

def read_and_qc(sample_name, path=sp_data_folder, mt_prefix='mt-'):
    r""" Read one 10X Visium experiment into an anndata object and compute QC metrics.

    Modify this function if required by your workflow.

    :param sample_name: Name of the sample; also the name of its folder under ``path``
    :param path: folder that contains one sub-folder per sample
    :param mt_prefix: gene-symbol prefix used to flag mitochondrial genes
        (default ``'mt-'``, the prefix used elsewhere in this notebook)
    :return: anndata object whose var is indexed by the original ``gene_ids`` column
        (previous var names kept in ``var['SYMBOL']``), with scanpy QC metrics,
        ``var['mt']``, ``obs['mt_frac']`` and obs names prefixed by the sample name
    """

    # os.path.join works whether or not `path` ends with a separator
    adata = sc.read_visium(os.path.join(path, str(sample_name)),
                           count_file='filtered_feature_bc_matrix.h5', load_images=True)
    # store the sample name as a string for every spot
    adata.obs['sample'] = str(sample_name)

    # keep the current var names as SYMBOL and index genes by the 'gene_ids' column instead
    adata.var['SYMBOL'] = adata.var_names
    adata.var.rename(columns={'gene_ids': 'ENSEMBL'}, inplace=True)
    adata.var_names = adata.var['ENSEMBL']
    adata.var.drop(columns='ENSEMBL', inplace=True)

    # Calculate QC metrics
    sc.pp.calculate_qc_metrics(adata, inplace=True)
    adata.var['mt'] = [gene.startswith(mt_prefix) for gene in adata.var['SYMBOL']]
    # np.asarray(...).ravel() always yields a 1-D vector of per-spot counts, also for a
    # single spot, and does not depend on the row sum being returned as np.matrix
    mt_counts = np.asarray(adata[:, adata.var['mt'].values].X.sum(axis=1)).ravel()
    adata.obs['mt_frac'] = mt_counts / adata.obs['total_counts']

    # add sample name to obs names
    adata.obs_names = adata.obs["sample"] \
                          + '_' + adata.obs_names
    adata.obs.index.name = 'spot_id'
    return adata

def select_slide(adata, s, s_col='sample'):
    r""" Select the data for one slide from the spatial anndata object.

    :param adata: Anndata object with multiple spatial experiments
    :param s: name of selected experiment
    :param s_col: column in adata.obs listing experiment name for each location
    :return: independent copy of ``adata`` restricted to the locations of ``s``, whose
        ``uns['spatial']`` holds only the first entry with ``s`` in its key
    :raises IndexError: if no key of ``adata.uns['spatial']`` contains ``s``
    """

    # Copy explicitly: `uns` is rewritten below, and doing that on a subset view
    # would rely on implicit copy-on-write of the view.
    slide = adata[adata.obs[s_col].isin([s]), :].copy()

    spatial = slide.uns['spatial']
    matching_keys = [key for key in spatial if s in key]
    if not matching_keys:
        raise IndexError(
            f"no entry in uns['spatial'] matches experiment {s!r}; "
            f"available keys: {list(spatial)}"
        )
    s_spatial = matching_keys[0]

    slide.uns['spatial'] = {s_spatial: spatial[s_spatial]}

    return slide

In [None]:
# Load Visium data

# create paths and names to results folders for reference regression and cell2location models
# (these are the same folders as the ref_run_name / run_name values defined with results_folder above)
ref_run_name = f'{results_folder}/reference_signatures'
run_name = f'{results_folder}/cell2location_map'

# Read this run's Visium sample from sp_data_folder and compute its QC metrics.
adata_vis=read_and_qc(sample_id, path=sp_data_folder)

In [None]:
# Flag mitochondria-encoded (MT) genes: symbols starting with 'mt-'
is_mt = np.array([symbol.startswith('mt-') for symbol in adata_vis.var['SYMBOL']], dtype=bool)
adata_vis.var['mt_gene'] = is_mt

# Drop MT genes for spatial mapping, but first stash their counts (dense) in obsm['mt']
adata_vis.obsm['mt'] = adata_vis[:, is_mt].X.toarray()
adata_vis = adata_vis[:, ~is_mt]

In [None]:
# Display the spatial object after the MT genes were removed.
adata_vis

In [None]:
# Number of spots per value of obs["sample"].
adata_vis.obs["sample"].value_counts()

In [None]:
# Spatial mapping
## find shared genes and subset both anndata and reference signatures
## (np.intersect1d returns the shared IDs sorted, so both objects end up in the same gene order)
intersect = np.intersect1d(adata_vis.var_names, inf_aver.index)
adata_vis = adata_vis[:, intersect].copy()
inf_aver = inf_aver.loc[intersect, :].copy()

## prepare anndata for cell2location model, using obs["sample"] as the batch
scvi.data.setup_anndata(adata=adata_vis, batch_key="sample")
scvi.data.view_anndata_setup(adata_vis)

In [None]:
# Create the spatial mapping model from the Visium data and the reference signatures.
# NOTE: this rebinds `mod`, which held the reference RegressionModel until now.
mod = cell2location.models.Cell2location(
    adata_vis, cell_state_df=inf_aver,
    # the expected average cell abundance: tissue-dependent
    # hyper-prior which can be estimated from paired histology:
    N_cells_per_location=5,
    # hyperparameter controlling normalisation of
    # within-experiment variation in RNA detection (using default here):
    detection_alpha=200
)

In [None]:
# train model
mod.train(max_epochs=30000,
          # train using full data (batch_size=None)
          batch_size=None,
          # use all data points in training because
          # we need to estimate cell abundance at all locations
          train_size=1,
          use_gpu=True)

# plot ELBO loss history during training, removing first 1000 epochs from the plot
mod.plot_history(1000)
plt.legend(labels=['full data training']);

In [None]:
# In this section, we export the estimated cell abundance (summary of the posterior distribution).
# 1000 posterior samples are drawn, with batch_size set to the number of spots in mod.adata.
adata_vis = mod.export_posterior(
    adata_vis, sample_kwargs={'num_samples': 1000, 'batch_size': mod.adata.n_obs, 'use_gpu': True}
)

# Save model
mod.save(f"{run_name}", overwrite=True)

# To reload the saved model later:
# mod = cell2location.models.Cell2location.load(f"{run_name}", adata_vis)

# Save anndata object with results
adata_file = f"{run_name}/sp.h5ad"
adata_vis.write(adata_file)
adata_file

In [None]:
# add 5% quantile, representing confident cell abundance, 'at least this amount is present',
# to adata.obs with nice names for plotting (one obs column per factor name)
adata_vis.obs[adata_vis.uns['mod']['factor_names']] = adata_vis.obsm['q05_cell_abundance_w_sf']
celltypes = adata_vis.uns['mod']['factor_names']
# optional manual ordering of the cell types for the plot (unused):
#reorder = [0,7,12,13,14,15,6,1,2,3,4,5,8,9,10,11]

# select one slide (not needed here: adata_vis holds the single sample loaded above)
#from cell2location.utils import select_slide
#slide = select_slide(adata_vis, sample_list[0])

# plot in spatial coordinates: one panel per cell type, 4 panels per row, on a black background
with mpl.rc_context({'axes.facecolor':  'black',
                     'figure.figsize': [4.5, 5]}):

    sc.pl.spatial(adata_vis, cmap='magma',
                  #color = [celltypes[i] for i in reorder],
                  color = celltypes,
                  ncols=4, size=1.3,
                  img_key='lowres',
                  # limit color scale at 99.2% quantile of cell abundance
                  vmin=0, vmax='p99.2'
                 )

In [None]:
# Compute expected expression per cell type from the 5% quantile posterior summary
# (mod.samples["post_sample_q05"]); the result's 'mu' entry is used in the next cell.
expected_dict = mod.module.model.compute_expected_per_cell_type(
    mod.samples["post_sample_q05"], mod.adata
)

In [None]:
# Add to anndata layers: one layer per factor, named after the factor and filled with
# expected_dict['mu'][i] (assumes 'mu' is ordered like mod.factor_names_ — TODO confirm).
for i, n in enumerate(mod.factor_names_):
    adata_vis.layers[n] = expected_dict['mu'][i]

In [None]:
# Save anndata object with results again, now including the per-cell-type layers;
# this writes to the same sp.h5ad path as the earlier save.
adata_file = f"{run_name}/sp.h5ad"
adata_vis.write(adata_file)
adata_file