In [None]:
import sys

# Make locally checked-out copies of cell2location and scHierarchy importable.
# These are site-specific absolute paths; adjust or remove them when the packages
# are installed in the active environment. They must run before the imports below.
sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/BayraktarLab/cell2location/')
sys.path.insert(1, '/lustre/scratch119/casm/team299ly/al15/projects/scHierarchy/')

import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib as mpl

import cell2location
import scvi
import schierarchy

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # Type 42 (TrueType) fonts in exported PDFs, so text is stored as text
import seaborn as sns

In [None]:
# Folder holding the single-cell RNA reference data.
sc_data_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/'

# Folder that receives all outputs of this notebook.
results_folder = '/nfs/casm/team299ly/al15/projects/sc-breast/data_atlas/results/'

# Per-experiment output location (model and annotated AnnData are written here).
ref_run_name = results_folder + 'hierarchical_logist/'

## Load Breast cancer scRNA dataset

In [None]:
## Read the reference dataset and keep the matrix as loaded in a separate layer.
reference_path = sc_data_folder + "atals_processed.h5ad"
adata_ref = anndata.read_h5ad(reference_path)
adata_ref.layers['processed'] = adata_ref.X

# Undo the normalisation: expm1, divide by 10000, then scale every cell (row) by
# its 'nCount_RNA' value.
# NOTE(review): assumes the loaded X is log1p of counts scaled to 10,000 per cell - confirm.
log_normalised = adata_ref.layers['processed']
counts_per_cell = adata_ref.obs[['nCount_RNA']].values
adata_ref.X = (log_normalised.expm1() / 10000).multiply(counts_per_cell).tocsr()

# Flag gene families by name prefix / pattern.
gene_index = adata_ref.var_names
adata_ref.var['mt'] = gene_index.str.startswith('MT-')  # mitochondrial
adata_ref.var['ribo'] = gene_index.str.startswith(("RPS", "RPL"))  # ribosomal
adata_ref.var['hb'] = gene_index.str.contains("^HB[^(P)]")  # hemoglobin

# Keep only genes that carry none of the three flags.
flagged = adata_ref.var['mt'] | adata_ref.var['ribo'] | adata_ref.var['hb']
adata_ref = adata_ref[:, ~flagged]

### Process single cell data

In [None]:
# Before estimating reference cell-type signatures, apply a very permissive gene selection.
# In the 2D histogram drawn by filter_genes, the orange rectangle lies over the excluded genes.
# NOTE(review): the original note claimed the dataset was already filtered this way
# (hence no density under the rectangle) - confirm that this holds for this dataset.
from cell2location.utils.filtering import filter_genes

selected = filter_genes(adata_ref, cell_count_cutoff=5, cell_percentage_cutoff2=0.03, nonz_mean_cutoff=1.12)
adata_ref = adata_ref[:, selected].copy()

# Remove "omnispread" genes, i.e. genes detected in more than this fraction of cells.
# Relies on the 'n_cells' column of adata_ref.var (presumably written by filter_genes - confirm).
max_cell_fraction = 0.8
max_cutoff = (adata_ref.var['n_cells'] / adata_ref.n_obs) > max_cell_fraction

# Report a real percentage: the mean of the boolean mask is a fraction in [0, 1].
print(f'Genes expressed in more than {max_cell_fraction:.0%} of cells: {max_cutoff.mean():.2%}')

# Drop the omnispread genes.
adata_ref = adata_ref[:, ~max_cutoff].copy()

In [None]:
%%time
#qunatile normalise log_transformed data, could be replaced with a transformation of your choice
from schierarchy.utils.data_transformation import data_to_zero_truncated_cdf
adata_ref.layers["cdf"] = np.apply_along_axis(
        data_to_zero_truncated_cdf, 0, adata_ref.layers["processed"].toarray()
)

## Initialise and run the model

In [None]:
from schierarchy import LogisticModel

# Label columns in adata_ref.obs, ordered from the coarsest to the finest level.
level_keys = [
    'celltype_major',
    'celltype_minor',
    'celltype_subset',
]

# Register the AnnData object with the model, reading input from the "cdf" layer.
LogisticModel.setup_anndata(adata_ref, layer="cdf")

In [None]:
# Fit the hierarchical logistic model to obtain cell-type signatures.
from schierarchy import LogisticModel

learning_mode = 'fixed-sigma'
mod = LogisticModel(
    adata_ref,
    level_keys=level_keys,
    laplace_learning_mode=learning_mode,
)

# All cells are used for training (train_size=1); validation is not implemented yet.
mod.train(
    max_epochs=600,
    batch_size=2500,
    train_size=1,
    lr=0.01,
    use_gpu=True,
)

# Plot the training loss history.
# NOTE(review): the argument 50 presumably skips the first 50 epochs - confirm
# against the plot_history signature.
mod.plot_history(50)

In [None]:
%%time

# In this section, we export the estimated gene weights and per-cell probabilities 
# (summary of the posterior distribution).
adata_ref = mod.export_posterior(
    adata_ref, sample_kwargs={'num_samples': 50, 'batch_size': 2500, 'use_gpu': True}
)

# Save model
mod.save(f"{ref_run_name}", overwrite=True)

# Save anndata object with results
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref.write(adata_file)
adata_file

## Load model

In [None]:
# Only needed for making predictions: otherwise work directly with adata_file,
# which already stores the exported results.
model = LogisticModel.load(ref_run_name, adata_ref)

# Re-export the posterior summary from the loaded model.
adata_ref = model.export_posterior(
    adata_ref,
    sample_kwargs={
        'num_samples': 50,
        'batch_size': 2500,
        'use_gpu': True,
    },
)

## Visualise hierarchy of marker genes

In [None]:
# Read back the AnnData object (with exported results) saved next to the model.
adata_file = f"{ref_run_name}/sc.h5ad"
adata_ref = sc.read(filename=adata_file)

In [None]:
# Dot plots of selected marker genes for every level of the hierarchy.
# For each label, the 3 genes with the largest 'means_weight' are selected.
# The marker dictionary accumulates across levels, so each plot shows the markers
# of the current level together with those of all coarser levels.
gene_names = adata_ref.var['gene_ids'].values
observed_labels = []

selected_dcit = {}
for level in level_keys:
    for name in adata_ref.obs[level].cat.categories:
        weights = adata_ref.varm[f'means_weight_{level}'][f'means_weight_{level}_{name}'].values
        # Indices of the 3 largest weights (argpartition does not sort them).
        top_n = np.argpartition(weights, -3)[-3:]
        if name not in observed_labels:
            selected_dcit[name] = gene_names[top_n]
        # Record the label so a repeat at a finer level keeps its coarse-level
        # markers; without this the guard above never triggers.
        observed_labels.append(name)
    fig = sc.pl.dotplot(adata_ref, selected_dcit, level, log=True, gene_symbols='gene_ids')

In [None]:
# Restrict to cells whose coarsest-level label is 'T-cells'.
ind = adata_ref.obs[level_keys[0]].isin(['T-cells'])
adata_ref_subset = adata_ref[ind, :]

gene_names = adata_ref.var['gene_ids'].values
observed_labels = []

# One dot plot per level. Labels already seen at a coarser level are skipped, so
# each plot only covers labels that are new at that level.
for level in level_keys:
    selected_dcit = {}
    weight_table = adata_ref_subset.varm[f'means_weight_{level}']
    for name in adata_ref_subset.obs[level].cat.categories:
        weights = weight_table[f'means_weight_{level}_{name}'].values
        # Indices of the 5 largest weights (argpartition does not sort them).
        top_n = np.argpartition(weights, -5)[-5:]
        if name not in observed_labels:
            selected_dcit[name] = gene_names[top_n]
        observed_labels.append(name)

    # Keep only the cells carrying one of the labels selected for this level.
    ind = adata_ref_subset.obs[level].isin(list(selected_dcit))
    adata_ref_subset_v2 = adata_ref_subset[ind, :]
    if adata_ref_subset_v2.n_obs == 0:
        continue
    fig = sc.pl.dotplot(adata_ref_subset_v2, selected_dcit, level, log=True, gene_symbols='gene_ids')