### AntiSplodge minimal example
In this short, minimal tutorial we use the "global" dataset located at https://www.heartcellatlas.org/.
We deconvolute the major heart cell types and check the Jensen-Shannon divergence (JSD) of the method.
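
For readers new to the metric, JSD compares a predicted vector of cell-type proportions with the true one. The sketch below is a plain NumPy illustration of the standard definition with base-2 logarithms, so values lie between 0 (identical) and 1 (no overlap). The function name is ours, and AntiSplodge's own `getMeanJSD` may differ in details such as the logarithm base or whether it reports the divergence or its square root.

```python
import numpy as np

def jensen_shannon_divergence(p, q):
    """Jensen-Shannon divergence (base 2) between two proportion vectors."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Make sure both vectors are proper distributions.
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)  # mixture of the two distributions

    def kl(a, b):
        # Kullback-Leibler divergence; entries with a == 0 contribute nothing.
        mask = a > 0
        return float(np.sum(a[mask] * np.log2(a[mask] / b[mask])))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

true_props = [0.5, 0.3, 0.2]   # hypothetical true proportions
pred_props = [0.4, 0.4, 0.2]   # hypothetical predicted proportions
print(jensen_shannon_divergence(true_props, pred_props))
```

A lower value means the predicted proportions are closer to the true ones.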

You can download the dataset directly from https://cellgeni.cog.sanger.ac.uk/heartcellatlas/data/global_raw.h5ad.

First, import the packages we need. Remember to install them with "pip install PACKAGE" in your terminal.

In [1]:
import anndata as ann
import scanpy as sc
import numpy as np

import antisplodge as AS

from collections import Counter

After importing the required packages, load the dataset into memory. Adjust the path in the cell below so that it points to your copy of the dataset.

In [2]:
SC = ann.read_h5ad("../../data/reference/hca_harvard_gender_Female.sc.h5ad")

In [3]:
SC.obs

Unnamed: 0,NRP,age_group,cell_source,cell_type,donor,gender,n_counts,n_genes,percent_mito,percent_ribo,...,Used,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,pct_counts_in_top_50_genes,pct_counts_in_top_100_genes,pct_counts_in_top_200_genes,pct_counts_in_top_500_genes,.split
AAACCCAAGCTACTGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3182.0,1521,0.000943,0.002200,...,Yes,1521,7.327781,3182.0,8.065579,35.575110,42.520427,51.571339,67.913262,train
AAACCCAGTACCGCGT-1-H0015_apex,No,50-55,Harvard-Nuclei,Pericytes,H5,Female,1202.0,726,0.000832,0.000000,...,Yes,726,6.588926,1202.0,7.092574,39.933444,47.920133,56.239601,81.198003,train
AAACCCATCAAACCCA-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3804.0,1584,0.000263,0.001314,...,Yes,1584,7.368340,3804.0,8.244071,42.245005,48.475289,56.729758,71.503680,train
AAACCCATCGCAACAT-1-H0015_apex,No,50-55,Harvard-Nuclei,Pericytes,H5,Female,3529.0,1796,0.001700,0.001133,...,Yes,1796,7.493874,3529.0,8.169053,30.830264,37.432700,46.188722,63.247379,train
AAACCCATCTGGTCAA-1-H0015_apex,No,50-55,Harvard-Nuclei,Ventricular_Cardiomyocyte,H5,Female,3906.0,1677,0.000768,0.001024,...,Yes,1677,7.425358,3906.0,8.270525,40.117768,46.313364,54.557092,69.866871,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGATCTCCAATCCC-1-H0035_septum,No,45-50,Harvard-Nuclei,Ventricular_Cardiomyocyte,H7,Female,1615.0,941,0.000000,0.000619,...,Yes,941,6.848005,1615.0,7.387709,29.349845,38.823529,51.207430,72.693498,train
TTTGGAGAGCATCCTA-1-H0035_septum,No,45-50,Harvard-Nuclei,Smooth_muscle_cells,H7,Female,872.0,595,0.000000,0.001147,...,Yes,595,6.390241,872.0,6.771935,30.733945,42.201835,54.701835,89.105505,train
TTTGGAGGTCTAGGCC-1-H0035_septum,No,45-50,Harvard-Nuclei,Pericytes,H7,Female,806.0,608,0.002481,0.002481,...,Yes,608,6.411818,806.0,6.693324,27.047146,36.972705,49.379653,86.600496,train
TTTGGTTTCAAGGCTT-1-H0035_septum,No,45-50,Harvard-Nuclei,Ventricular_Cardiomyocyte,H7,Female,2470.0,1286,0.000405,0.000810,...,Yes,1286,7.160069,2470.0,7.812378,31.578947,39.230769,49.473684,68.178138,train


Let us look at the number of cells in the loaded dataset (N=44898)

In [6]:
print(SC)

AnnData object with n_obs × n_vars = 44898 × 33538
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', '.split'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'cell_type_colors'
    obsm: 'X_pca', 'X_umap'


### Find marker genes
Because of the high number of genes (N=33538), we reduce the gene set to a smaller one in order to speed up training and reduce the memory footprint. The method also works with all genes, but reaching an equally good model then takes considerably longer.

However, before we look for marker genes, we scale every profile so that its counts sum to 1, using scanpy's normalize_total function, as shown below:

In [7]:
SC.layers["norm"] = sc.pp.normalize_total(SC, target_sum=1, inplace=False)["X"]
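
Conceptually, scaling to a target sum of 1 divides each cell's counts by that cell's total. The sketch below shows this on a toy count matrix with plain NumPy. It is an illustration of the idea, not scanpy's implementation, which also handles sparse matrices and other edge cases.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 4 genes (columns).
counts = np.array([
    [10, 0, 5, 5],
    [1, 1, 1, 1],
    [0, 8, 0, 2],
], dtype=float)

# Divide every row by its total so that each profile sums to 1.
row_totals = counts.sum(axis=1, keepdims=True)
profiles = counts / row_totals

print(profiles.sum(axis=1))
```

After this step, profiles from cells with very different sequencing depth become directly comparable.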

In [9]:
#
# Find the top 'key' marker genes for each cell type using a univariate t-test analysis
#
# usually you should use method='logreg', but t-test is faster for demonstration purposes
sc.tl.rank_genes_groups(SC, groupby='cell_type', method='t-test', layer="norm", use_raw=False, key_added='ranks')
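
To give an intuition for the scores this ranking is based on, the sketch below computes a per-gene Welch t-statistic of one group against all remaining cells. To our understanding this is the kind of statistic behind method='t-test', but the code is a NumPy illustration on random toy data, not scanpy's implementation; `welch_t_scores` is a hypothetical helper.

```python
import numpy as np

def welch_t_scores(X, in_group):
    """Per-gene Welch t-statistic of one group against all other cells.

    X        : (n_cells, n_genes) array of normalized expression
    in_group : boolean mask of length n_cells selecting the group
    """
    a, b = X[in_group], X[~in_group]
    mean_diff = a.mean(axis=0) - b.mean(axis=0)
    # Unbiased variances (ddof=1), each divided by its group size.
    se = np.sqrt(a.var(axis=0, ddof=1) / len(a) + b.var(axis=0, ddof=1) / len(b))
    return mean_diff / se

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 3))
labels = np.array(["A"] * 20 + ["B"] * 20)
X[labels == "A", 0] += 5.0  # make gene 0 a marker for group A

scores = welch_t_scores(X, labels == "A")
ranking = np.argsort(-scores)  # genes ordered from highest to lowest score
print(ranking[0])
```

Genes with the highest scores for a cell type are the candidates for its marker set.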

In [12]:
#
# Get the corresponding gene sets
#
def getGenes(adata, key, ct, min_genes, score_threshold=0.01):
    genes = []
    
    # get the top-ranked genes for each cell type
    for cell_ in ct:
        names = adata.uns[key]['names'][cell_]
        scores = adata.uns[key]['scores'][cell_]
        
        # always include the first min_genes genes, then keep adding genes
        # while their score stays above score_threshold (stop at the end of the ranking)
        index = min_genes
        while index < len(scores) and scores[index] > score_threshold:
            index += 1
        
        genes_ = names[0:index]
        
        print(cell_, len(genes_))
        
        genes.extend(genes_)
        
    np_genes = np.unique(np.array(genes))
    print("Length of unique genes:",len(np_genes))
    
    return np_genes

# use the top 50 genes per cell type, plus any further genes scoring above the threshold (50);
# in this run only pericytes and ventricular cardiomyocytes get more than 50 genes
gene_set = getGenes(SC, 'ranks', np.unique(SC.obs['cell_type']), 50, 50)

Adipocytes 50
Atrial_Cardiomyocyte 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Myeloid 50
Neuronal 50
Pericytes 76
Smooth_muscle_cells 50
Ventricular_Cardiomyocyte 344
Length of unique genes: 736
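
The selection rule can be seen in isolation on a toy ranking. In the sketch below, `n_selected` is a hypothetical helper and the scores are made up; it keeps the first `min_genes` genes and then extends the selection while the next score is still above the threshold. This is why most cell types above report exactly 50 genes, while Pericytes and Ventricular_Cardiomyocyte report more.

```python
import numpy as np

def n_selected(scores, min_genes, score_threshold):
    """Number of top-ranked genes kept for one cell type."""
    index = min_genes
    # Extend the selection while the next gene still scores above the threshold.
    while index < len(scores) and scores[index] > score_threshold:
        index += 1
    return index

# Hypothetical scores, sorted from best to worst as in a ranking.
toy_scores = np.array([90.0, 70.0, 60.0, 55.0, 40.0, 10.0])

few = n_selected(toy_scores, min_genes=2, score_threshold=50)
many = n_selected(toy_scores, min_genes=5, score_threshold=50)
print(few, many)
```

With `min_genes=2` two extra genes clear the threshold, whereas with `min_genes=5` no further gene does and only the minimum is kept.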


In [13]:
SC = SC[:,gene_set].copy() # filter SC; copy so we work on actual data rather than a view

After removing the unwanted genes we need to scale again, so that all profiles once more have equal total counts.

In [14]:
sc.pp.normalize_total(SC, target_sum=1)
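
The need for this second scaling step can be checked on a toy example: once columns are dropped, profiles that summed to 1 no longer do. The numbers below are made up for illustration.

```python
import numpy as np

# Two toy profiles over 4 genes, each already scaled to sum to 1.
full_profiles = np.array([
    [0.50, 0.25, 0.15, 0.10],
    [0.10, 0.10, 0.40, 0.40],
])

# Keep only genes 0 and 2 (standing in for the selected marker genes).
subset = full_profiles[:, [0, 2]]
print(subset.sum(axis=1))  # totals now differ between profiles

# Rescale so that every reduced profile sums to 1 again.
rescaled = subset / subset.sum(axis=1, keepdims=True)
print(rescaled.sum(axis=1))
```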



In [15]:
# the profiles are now much smaller, which reduces memory use and speeds up training
print(SC)

AnnData object with n_obs × n_vars = 44898 × 736
    obs: 'NRP', 'age_group', 'cell_source', 'cell_type', 'donor', 'gender', 'n_counts', 'n_genes', 'percent_mito', 'percent_ribo', 'region', 'sample', 'scrublet_score', 'source', 'type', 'version', 'cell_states', 'Used', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', '.split'
    var: 'gene_ids-Harvard-Nuclei', 'feature_types-Harvard-Nuclei', 'gene_ids-Sanger-Nuclei', 'feature_types-Sanger-Nuclei', 'gene_ids-Sanger-Cells', 'feature_types-Sanger-Cells', 'gene_ids-Sanger-CD45', 'feature_types-Sanger-CD45', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts'
    uns: 'cell_type_colors', 'ranks'
    obsm: 'X_pca', 'X_umap'
    layers: 'norm'


### AntiSplodge experiment
We then set up the AntiSplodge experiment. In this step we:

1) Create a new experiment by passing SC to AntiSplodge's 'DeconvolutionExperiment'.
2) Set the cell type column to 'cell_type' in the dataset.
3) Split the dataset into 80% train, 10% validation, and 10% test.

In [16]:
experiment = AS.DeconvolutionExperiment(SC)

experiment.setVerbosity(True)

# Tell the experiment which column of SC.obs holds the cell type labels
experiment.setCellTypeColumn('cell_type')
# Use 80% as training data and split the rest 50/50 into validation and test
experiment.splitTrainTestValidation(train=0.8, rest=0.5)
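
To see how `train=0.8, rest=0.5` yields an 80/10/10 split, the sketch below cuts shuffled cell indices accordingly. It is a generic NumPy illustration, not AntiSplodge's implementation: `split_indices` is a hypothetical helper, and we assume that `rest` is the share of the non-training cells used for validation.

```python
import numpy as np

def split_indices(n_cells, train=0.8, rest=0.5, seed=0):
    """Shuffle cell indices and cut them into train/validation/test parts."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_cells)
    n_train = int(round(n_cells * train))
    # Assumption: `rest` is the share of the remaining cells used for validation.
    n_val = int(round((n_cells - n_train) * rest))
    return order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]

train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Keeping the three sets of cells disjoint means validation and test profiles are built from cells the network never saw during training.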

We then generate 500,000 training, 10,000 validation, and 10,000 test samples,
and load these into data loaders for use in training the neural network.

In [17]:
experiment.generateTrainTestValidation(num_profiles=[500000,10000,10000], CD=[10,10])
# Load the profiles into data loaders
experiment.setupDataLoaders(batch_size=2500)

GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
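
The idea behind such generated samples is to mix single-cell profiles into synthetic mixtures with known cell-type proportions. The sketch below illustrates this on random toy data. It is not AntiSplodge's sampling procedure, which this tutorial does not show; `make_profile` is a hypothetical helper, and reading `CD=[10,10]` as the minimum and maximum number of cells per profile is our assumption.

```python
import numpy as np

def make_profile(X, cell_types, cd_min, cd_max, rng):
    """Mix randomly chosen cells into one synthetic profile.

    Returns the rescaled summed expression and the cell-type proportions.
    """
    n_cells = rng.integers(cd_min, cd_max + 1)   # number of cells in this profile
    chosen = rng.choice(len(X), size=n_cells, replace=True)
    profile = X[chosen].sum(axis=0)
    profile = profile / profile.sum()            # rescale so the profile sums to 1
    types = np.unique(cell_types)
    # Fraction of the chosen cells that belongs to each cell type.
    props = np.array([(cell_types[chosen] == t).mean() for t in types])
    return profile, props

rng = np.random.default_rng(0)
X = rng.random((30, 5)) + 0.1                    # toy expression: 30 cells x 5 genes
cell_types = np.array(["A", "B", "C"] * 10)      # toy labels
profile, props = make_profile(X, cell_types, cd_min=10, cd_max=10, rng=rng)
print(profile.sum(), props.sum())
```

The network is then trained to predict `props` from `profile`.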


We then define the model (with default values), which requests the first CUDA device (cuda_id=1). In the run shown here, the output reports that the model was placed on the CPU instead.

In [18]:
# Initialize Neural network-model and allocate it to the cuda_id specified
# Use 'cuda_id="cpu"' if you want to allocate it to a cpu
experiment.setupModel()

(CUDA) device is: cpu


We then train the network with 25 warm restarts: whenever a training session finishes, we load the best model state found so far back into the model and continue from there in the next session.

In [21]:
# do 25 warm restarts with decreasing learning rate
print("Training experiment for dataset")
lr = 0.001
best_error=None # no target error to beat in the beginning
for k in range(25):

    # Consider changing learning rate (lr) during run more dynamically
    if k >= 5:
        lr = 0.0005
    if k >= 10:
        lr = 0.0001
    if k >= 15:
        lr = 0.00005
    experiment.setupOptimizerAndCriterion(learning_rate = lr)
    # Train the experiment constructed by passing the experiment to the AntiSplodge training function
    AS.train(experiment, save_file="ModelCheckpoint.pt", patience=5, best_loss=best_error) # For longer training, increase patience threshold
    best_error = AS.getMeanJSD(experiment, "validation") # set best error as the target error to beat

    print("Restart [{}] - JSDs".format(k),AS.getMeanJSD(experiment, "train"), AS.getMeanJSD(experiment, "validation"))

Training experiment for dataset
Epoch: 001 | Epochs since last increase: 000
Loss: (Train) 0.01422 | (Valid): 0.01678 - (mean)JSD: (Train) 0.08914 | (Valid) 0.10575 

Epoch: 002 | Epochs since last increase: 001
Loss: (Train) 0.01326 | (Valid): 0.01733 - (mean)JSD: (Train) 0.09444 | (Valid) 0.11322 

Epoch: 003 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01290 | (Valid): 0.01578 - (mean)JSD: (Train) 0.08748 | (Valid) 0.10299 

Epoch: 004 | Epochs since last increase: 000 | Better solution found
Loss: (Train) 0.01271 | (Valid): 0.01517 - (mean)JSD: (Train) 0.08233 | (Valid) 0.10082 

Epoch: 005 | Epochs since last increase: 000 | Better solution found


KeyboardInterrupt: 
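
The stepwise learning-rate schedule used in the training loop above can also be written as a small helper, which makes the thresholds easier to check and change. `learning_rate_for_restart` is a hypothetical name and not part of AntiSplodge.

```python
def learning_rate_for_restart(k):
    """Learning rate for restart k, using the same thresholds as the loop above."""
    if k >= 15:
        return 0.00005
    if k >= 10:
        return 0.0001
    if k >= 5:
        return 0.0005
    return 0.001

schedule = [learning_rate_for_restart(k) for k in range(25)]
print(schedule[0], schedule[5], schedule[10], schedule[15])
```

Inside the loop, `lr = learning_rate_for_restart(k)` would then replace the chain of `if` statements.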

In [35]:
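# NOTE: 'dataset_strings', 'experiments' and 'gene_sets' are not defined in this minimal example;
# this cell and the next assume one trained experiment per cell source, keyed by name
for dat in dataset_strings: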
    print(dat, "Test mean JSDs:", "{:2f}%".format(AS.getMeanJSD(experiments[dat], "test")*100))

Harvard-Nuclei Test mean JSDs: 8.354745%
Sanger-Nuclei Test mean JSDs: 9.354040%
Sanger-Cells Test mean JSDs: 9.200285%
Sanger-CD45 Test mean JSDs: 9.146629%


In [51]:
for dat in dataset_strings:
    experiments[dat].SC_train.write("{}_SC_train.h5ad".format(dat))
    experiments[dat].SC_val.write("{}_SC_val.h5ad".format(dat))
    experiments[dat].SC_test.write("{}_SC_test.h5ad".format(dat))
    
    np.savez_compressed(
        "{}_HCA_Dataset_for_Fabian.npz".format(dat),
        #X_train=np.array(experiments[dat].X_train),
        #X_val=np.array(experiments[dat].X_val),
        ST_X_test=np.array(experiments[dat].X_test),
        
        #Y_train=np.array(experiments[dat].Y_train_prop),
        #Y_val=np.array(experiments[dat].Y_val_prop),
        ST_Y_test=np.array(experiments[dat].Y_test_prop),
        
        genes=gene_sets[dat],
    )
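
An archive written this way can be read back with `np.load`, looking up each array by the keyword it was saved under. The sketch below uses a toy archive in a temporary folder; the array contents, shapes, file name, and gene names are made up, and only the keys `ST_X_test`, `ST_Y_test`, and `genes` follow the export cell above.

```python
import os
import tempfile
import numpy as np

# Toy stand-ins for the exported arrays (shapes are illustrative only).
X_test = np.random.default_rng(0).random((4, 3))
Y_test = np.full((4, 2), 0.5)
genes = np.array(["GENE_A", "GENE_B", "GENE_C"])  # hypothetical gene names

path = os.path.join(tempfile.mkdtemp(), "toy_export.npz")
np.savez_compressed(path, ST_X_test=X_test, ST_Y_test=Y_test, genes=genes)

# Read the archive back; arrays are looked up by the keyword used when saving.
with np.load(path) as archive:
    loaded_X = archive["ST_X_test"]
    loaded_Y = archive["ST_Y_test"]
    loaded_genes = archive["genes"]

print(loaded_X.shape, loaded_Y.shape, list(loaded_genes))
```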