### AntiSplodge Reference Training
After generating all of the dataset variants, this notebook is used to train AntiSplodge for the reference datasets.

This is based on the [minimal AntiSplodge example](https://github.com/HealthML/AntiSplodge/blob/52e71a1e40b4161926ebbe45f8e18fc02de69bdc/AntiSplodge_minimal_example.ipynb) provided by the authors.

In [25]:
import anndata as ann
import scanpy as sc
import numpy as np
import pickle

import antisplodge as AS

from collections import Counter

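After importing the required packages, we load the reference datasets into memory. The paths below assume the .h5ad files are stored in ../../data/reference relative to the notebook; adjust REFERENCE_PATH if your data lives elsewhere.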

In [2]:
REFERENCE_PATH = "../../data/reference"
REFERENCE_DATASETS = {
    "Harvard_gender_female_full": f"{REFERENCE_PATH}/hca_harvard_gender_Female.sc.h5ad",
    "Harvard_gender_female_V1": f"{REFERENCE_PATH}/hca_harvard_gender_Female_-muscles.sc.h5ad",
    "Harvard_gender_female_V2": f"{REFERENCE_PATH}/hca_harvard_gender_Female_-endothelial.sc.h5ad",
    "Sanger_gender_female_full": f"{REFERENCE_PATH}/hca_sanger_gender_Female.sc.h5ad",
    "Sanger_gender_female_V1": f"{REFERENCE_PATH}/hca_sanger_gender_Female_-muscles.sc.h5ad",
    "Sanger_gender_female_V2": f"{REFERENCE_PATH}/hca_sanger_gender_Female_-endothelial.sc.h5ad",
}

datasets = {}

for name, path in REFERENCE_DATASETS.items():
    datasets[name] = ann.read_h5ad(path)  # ann.read is a deprecated alias of read_h5ad

In [3]:
for dat, dataset in datasets.items():
    print(dat, dataset.shape)

Harvard_gender_female_full (44898, 33538)
Harvard_gender_female_V1 (21169, 33538)
Harvard_gender_female_V2 (41957, 33538)
Sanger_gender_female_full (36966, 33538)
Sanger_gender_female_V1 (20159, 33538)
Sanger_gender_female_V2 (34174, 33538)


### Find marker genes
Because of the high number of genes (N=33538), we reduce the gene set to a smaller set of marker genes, which speeds up training and reduces the memory footprint. Training would also work with all genes, but reaching an equally good model would take considerably longer.

Before looking for marker genes, we normalize each cell's profile to sum to 1 using scanpy's `normalize_total` function:

In [4]:
for dataset in datasets.values():
    sc.pp.normalize_total(dataset, target_sum=1)
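To make the effect of `target_sum=1` concrete, here is a plain-Python sketch on a toy count matrix (the helper `normalize_rows` and the numbers are hypothetical): every row (cell) is divided by its own total, so all profiles end up on the same scale regardless of sequencing depth.

```python
def normalize_rows(counts, target_sum=1.0):
    """Scale each row (cell) so that its entries sum to target_sum.

    Hypothetical helper mimicking the idea of sc.pp.normalize_total.
    Rows with a zero total are left unchanged in this sketch.
    """
    normalized = []
    for row in counts:
        total = sum(row)
        if total == 0:
            normalized.append(list(row))
        else:
            normalized.append([value * target_sum / total for value in row])
    return normalized


# toy matrix: 2 cells x 3 genes
toy_counts = [[2, 6, 2], [0, 1, 3]]
print(normalize_rows(toy_counts))  # [[0.2, 0.6, 0.2], [0.0, 0.25, 0.75]]
```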

In [5]:
#
# Rank the genes of each cell type against all other cells with a univariate t-test;
# the results are stored in adata.uns['ranks']
#
# usually you should use method='logreg', but t-test is faster for demonstration purposes
for dataset in datasets.values():
    sc.tl.rank_genes_groups(dataset, groupby='cell_type', method='t-test', use_raw=False, key_added='ranks')
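To give an idea of what a t-test score means here, the sketch below computes a one-vs-rest Welch t-statistic for a single gene on toy values (hypothetical helpers `welch_t` and `score_gene`). This assumes the general form of the statistic; scanpy's implementation may differ in details. Roughly, `rank_genes_groups` computes such a score for every gene and cell type and stores the genes sorted by decreasing score in `adata.uns['ranks']['names']` and `adata.uns['ranks']['scores']`, which is what `getGenes` reads next.

```python
import math
from statistics import mean, variance


def welch_t(group, rest):
    """Welch's t-statistic of `group` versus `rest` (each needs >= 2 values)."""
    var_term = variance(group) / len(group) + variance(rest) / len(rest)
    return (mean(group) - mean(rest)) / math.sqrt(var_term)


def score_gene(expr, labels, cell_type):
    """One-vs-rest score of a single gene for one cell type (hypothetical helper)."""
    group = [x for x, label in zip(expr, labels) if label == cell_type]
    rest = [x for x, label in zip(expr, labels) if label != cell_type]
    return welch_t(group, rest)


# toy normalized expression of one gene in 7 cells
expr = [0.30, 0.28, 0.35, 0.01, 0.02, 0.00, 0.03]
labels = ["Pericytes"] * 3 + ["Fibroblast"] * 4

print(score_gene(expr, labels, "Pericytes"))   # large positive: candidate marker
print(score_gene(expr, labels, "Fibroblast"))  # negative: lower than in other cells
```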

In [6]:
#
# Get the corresponding gene sets
#
def getGenes(adata, key, ct, min_genes, score_threshold=0.01):
    genes = []

    # for each cell type, take the top-ranked genes (sorted by decreasing score)
    for cell_ in ct:
        names = adata.uns[key]['names'][cell_]
        scores = adata.uns[key]['scores'][cell_]

        # keep at least min_genes, then extend while the score is above score_threshold
        index = min_genes
        while index < len(scores) and scores[index] > score_threshold:
            index += 1

        genes_ = names[0:index]
        print(cell_, len(genes_))
        genes.extend(genes_)

    np_genes = np.unique(np.array(genes))
    print("Length of unique genes:", len(np_genes))

    return np_genes

# keep at least the top 50 genes per cell type, and keep adding genes while their t-test score exceeds 50;
# only ventricular cardiomyocytes (and pericytes in the Harvard full and V2 datasets) have more genes above this threshold
gene_sets = {}
for dat, dataset in datasets.items():
    gene_sets[dat] = getGenes(dataset, 'ranks', np.unique(dataset.obs['cell_type']), min_genes=50, score_threshold=50)

Adipocytes 50
Atrial_Cardiomyocyte 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Myeloid 50
Neuronal 50
Pericytes 76
Smooth_muscle_cells 50
Ventricular_Cardiomyocyte 344
Length of unique genes: 736
Adipocytes 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Myeloid 50
Neuronal 50
Pericytes 50
Smooth_muscle_cells 50
Length of unique genes: 378
Adipocytes 50
Atrial_Cardiomyocyte 50
Fibroblast 50
Lymphoid 50
Myeloid 50
Neuronal 50
Pericytes 81
Smooth_muscle_cells 50
Ventricular_Cardiomyocyte 334
Length of unique genes: 698
Adipocytes 50
Atrial_Cardiomyocyte 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Mesothelial 50
Myeloid 50
Neuronal 50
Pericytes 50
Smooth_muscle_cells 50
Ventricular_Cardiomyocyte 250
Length of unique genes: 670
Adipocytes 50
Endothelial 50
Fibroblast 50
Lymphoid 50
Mesothelial 50
Myeloid 50
Neuronal 50
Pericytes 50
Smooth_muscle_cells 50
Length of unique genes: 416
Adipocytes 50
Atrial_Cardiomyocyte 50
Fibroblast 50
Lymphoid 50
Mesothelial 50
Myeloid 50
Neuronal 50
Pericytes 5
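The selection rule in `getGenes` explains why most cell types get exactly 50 genes while ventricular cardiomyocytes get several hundred: a cell type only gets more than `min_genes` when further genes still score above the threshold. A minimal sketch of just that counting rule on hypothetical, already sorted scores (the helper `n_selected` is illustrative, not part of the notebook):

```python
def n_selected(scores, min_genes, score_threshold):
    """Number of top-ranked genes kept for one cell type.

    `scores` must be sorted in decreasing order, as in rank_genes_groups.
    Always keep `min_genes`, then extend while the next score is above
    `score_threshold` (same rule as getGenes).
    """
    index = min_genes
    while index < len(scores) and scores[index] > score_threshold:
        index += 1
    return index


# hypothetical sorted t-test scores for two cell types
sharp_markers = [120, 95, 80, 61, 55, 40, 12, 3]
weak_markers = [48, 30, 22, 10, 5, 4, 2, 1]

print(n_selected(sharp_markers, min_genes=3, score_threshold=50))  # 5
print(n_selected(weak_markers, min_genes=3, score_threshold=50))   # 3
```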

In [7]:
for dat in list(datasets.keys()):
    # keep only the marker genes; .copy() turns the AnnData view into an independent object
    datasets[dat] = datasets[dat][:, gene_sets[dat]].copy()

Removing the non-marker genes discards part of each cell's total, so the profiles no longer sum to 1. We therefore normalize them again to obtain profiles with equal totals:
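A small sketch of why the second normalization is needed, using one hypothetical profile over four genes (the helper `subset_and_renormalize` and the gene names are illustrative): after keeping genes A and C, the kept values sum to 0.625, and rescaling brings them back to a total of 1.

```python
def subset_and_renormalize(profile, gene_names, keep):
    """Keep only the genes in `keep` (original order) and rescale to sum to 1.

    Hypothetical helper; a zero total is returned unscaled in this sketch.
    """
    keep_set = set(keep)
    kept = {g: v for g, v in zip(gene_names, profile) if g in keep_set}
    total = sum(kept.values())
    if total == 0:
        return kept
    return {g: v / total for g, v in kept.items()}


profile = [0.5, 0.25, 0.125, 0.125]  # already normalized to sum to 1
genes = ["A", "B", "C", "D"]
print(subset_and_renormalize(profile, genes, ["A", "C"]))  # {'A': 0.8, 'C': 0.2}
```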

In [8]:
for dataset in datasets.values():
    sc.pp.normalize_total(dataset, target_sum=1)



In [9]:
# the profiles are now much smaller, which reduces memory usage and speeds up training
for dat, dataset in datasets.items():
    print(dat, dataset.shape)

Harvard_gender_female_full (44898, 736)
Harvard_gender_female_V1 (21169, 378)
Harvard_gender_female_V2 (41957, 698)
Sanger_gender_female_full (36966, 670)
Sanger_gender_female_V1 (20159, 416)
Sanger_gender_female_V2 (34174, 628)


### AntiSplodge experiment
We then set up the AntiSplodge experiments. For each dataset we:

1) Create a new experiment by passing the single-cell (SC) dataset to AntiSplodge's DeconvolutionExperiment class.
2) Set the cell type column to 'cell_type'.
3) Split the dataset into 80% train, 10% validation, and 10% test.

In [11]:
experiments = {}

for dat, dataset in datasets.items():
    experiment = AS.DeconvolutionExperiment(dataset)
    experiment.setVerbosity(True)

    # the cell type labels are stored in the 'cell_type' column of .obs
    experiment.setCellTypeColumn('cell_type')
    # Use 80% as train data and split the rest into a 50/50 split validation and testing
    experiment.splitTrainTestValidation(train=0.8, rest=0.5)
    experiments[dat] = experiment
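The split parameters can be read as "80% of the cells for training, then the remaining 20% divided 50/50 into validation and test". The sketch below shows that arithmetic as a plain random split of cell indices; `split_train_test_validation` is a hypothetical helper, not AntiSplodge's implementation, which may, for example, split per cell type.

```python
import random


def split_train_test_validation(n_cells, train=0.8, rest=0.5, seed=0):
    """Randomly split cell indices into train / validation / test index lists.

    `train` is the training fraction; the remaining cells are split so that a
    fraction `rest` goes to validation and the others to test.
    """
    indices = list(range(n_cells))
    random.Random(seed).shuffle(indices)
    n_train = int(n_cells * train)
    remaining = indices[n_train:]
    n_validation = int(len(remaining) * rest)
    return indices[:n_train], remaining[:n_validation], remaining[n_validation:]


# e.g. the 44898 cells of Harvard_gender_female_full
train_idx, val_idx, test_idx = split_train_test_validation(44898)
print(len(train_idx), len(val_idx), len(test_idx))  # 35918 4490 4490
```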

We then generate 500,000 training, 10,000 validation, and 10,000 test profiles for each dataset, and load them into data loaders (batch size 2,500) for training the neural network.

In [13]:
for experiment in experiments.values():
    experiment.generateTrainTestValidation(num_profiles=[500000,10000,10000], CD=[10,10])
    # Load the profiles into data loaders
    experiment.setupDataLoaders(batch_size=2500)

GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
GENERATING PROFILES
GENERATING TRAIN DATASET (N=500000)
GENERATING VALIDATION DATASET (N=10000)
GENERATING TEST DATASET (N=10000)
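Conceptually, each generated sample is a synthetic profile obtained by mixing single cells, paired with the cell type proportions that produced it; the network is trained to predict those proportions from the profile. The sketch below shows one generic way to build such a sample under stated assumptions: cells are already normalized to sum to 1, a hypothetical fixed `n_cells` cells are drawn uniformly at random, and `simulate_profile` is an illustrative helper. It is not AntiSplodge's actual generator, and the `CD` argument used above is not modelled.

```python
import random
from collections import Counter


def simulate_profile(cells, labels, cell_types, n_cells, rng):
    """Build one synthetic spot by summing `n_cells` randomly drawn single cells.

    Returns the normalized expression profile and the true cell type
    proportions (the regression target), ordered as in `cell_types`.
    Assumes each cell profile is normalized, so the summed total is > 0.
    """
    picked = [rng.randrange(len(cells)) for _ in range(n_cells)]
    n_genes = len(cells[0])
    profile = [sum(cells[i][g] for i in picked) for g in range(n_genes)]
    total = sum(profile)
    profile = [v / total for v in profile]
    counts = Counter(labels[i] for i in picked)
    proportions = [counts[ct] / n_cells for ct in cell_types]
    return profile, proportions


# toy data: 3 normalized cells over 2 genes, two cell types
rng = random.Random(0)
cells = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
labels = ["A", "B", "A"]
profile, props = simulate_profile(cells, labels, ["A", "B"], n_cells=10, rng=rng)
print(profile, props)
```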


We then define the model (with default values) and place it on the first CUDA device (cuda_id=0).

In [14]:
# Initialize the neural network model and allocate it to the specified CUDA device
# Use 'cuda_id="cpu"' if you want to allocate it to a cpu
for experiment in experiments.values():
    experiment.setupModel(cuda_id=0)

(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0
(CUDA) device is: cuda:0


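We then train each network with 25 warm restarts: whenever a training session finishes, the best model parameters found so far are loaded back onto the model, and the next session continues from there. The learning rate is lowered as the restarts progress (0.001, then 0.0005 from restart 5, 0.0001 from restart 10, and 0.00005 from restart 15).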
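The schedule and the keep-the-best behaviour of the restarts can be summarized in plain Python. Both helpers are hypothetical illustrations: `learning_rate` reproduces the thresholds of the training loop, and `warm_restarts` only tracks the best validation error, whereas the actual checkpoint saving and reloading happens inside `AS.train`. A session that does not improve leaves the best error unchanged, which matches restarts 1 and 2 reporting the same JSD in the log.

```python
def learning_rate(k):
    """Learning rate used at warm restart k (same thresholds as the training loop)."""
    if k >= 15:
        return 0.00005
    if k >= 10:
        return 0.0001
    if k >= 5:
        return 0.0005
    return 0.001


def warm_restarts(session_errors):
    """Best validation error after each restart.

    Each session starts from the best parameters found so far, so a session
    that does not improve leaves the best error (and checkpoint) unchanged.
    """
    best = None
    history = []
    for err in session_errors:
        if best is None or err < best:
            best = err
        history.append(best)
    return history


print([learning_rate(k) for k in (0, 5, 10, 15, 24)])
print(warm_restarts([0.117, 0.1158, 0.1162]))  # [0.117, 0.1158, 0.1158]
```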

In [33]:
for dat, experiment in experiments.items():
    # do 25 warm restarts with a decreasing learning rate
    print("Training experiment for dataset")
    experiment.setVerbosity(False)
    lr = 0.001
    best_error = None  # no target error to beat in the beginning
    for k in range(25):

        # Step the learning rate down as the restarts progress
        # (consider changing it more dynamically during the run)
        if k >= 5:
            lr = 0.0005
        if k >= 10:
            lr = 0.0001
        if k >= 15:
            lr = 0.00005
        experiment.setupOptimizerAndCriterion(learning_rate=lr)
        # Train the experiment by passing it to the AntiSplodge training function;
        # for longer training, increase the patience threshold
        AS.train(experiment, save_file=f"{dat}.model.pt", patience=5, best_loss=best_error)
        # the validation JSD becomes the target error to beat in the next restart
        best_error = AS.getMeanJSD(experiment, "validation")

        print("Restart [{}] - JSDs".format(k), best_error)
    experiment.setVerbosity(True)

Training experiment for dataset
Finished training (checkpoint saved in: Harvard_gender_female_full.model.pt)
Time elapsed: 62.10 (1.03 Minutes)
Autoloading best parameters onto model (auto_load_model_on_finish==True)
Restoring checkpoint: Harvard_gender_female_full.model.pt
Restart [0] - JSDs 0.11716707077861588
Finished training (checkpoint saved in: Harvard_gender_female_full.model.pt)
Time elapsed: 55.64 (0.93 Minutes)
Autoloading best parameters onto model (auto_load_model_on_finish==True)
Restoring checkpoint: Harvard_gender_female_full.model.pt
Restart [1] - JSDs 0.11581717265608009
Finished training (checkpoint saved in: Harvard_gender_female_full.model.pt)
Time elapsed: 37.35 (0.62 Minutes)
Autoloading best parameters onto model (auto_load_model_on_finish==True)
Restoring checkpoint: Harvard_gender_female_full.model.pt
Restart [2] - JSDs 0.11581717265608009
Finished training (checkpoint saved in: Harvard_gender_female_full.model.pt)
Time elapsed: 37.14 (0.62 Minutes)
Autoloadin

In [34]:
for dat, experiment in experiments.items():
    print(dat, "Test mean JSDs:", "{:2f}%".format(AS.getMeanJSD(experiment, "test")*100))

Harvard_gender_female_full Test mean JSDs: 10.627662%
Harvard_gender_female_V1 Test mean JSDs: 9.920403%
Harvard_gender_female_V2 Test mean JSDs: 9.255984%
Sanger_gender_female_full Test mean JSDs: 12.768898%
Sanger_gender_female_V1 Test mean JSDs: 16.391650%
Sanger_gender_female_V2 Test mean JSDs: 14.407963%
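The test error above is a mean Jensen-Shannon divergence (JSD), presumably between the true and predicted cell type proportions of each synthetic test profile. A minimal sketch of such a metric, assuming the base-2 divergence (bounded by 0 and 1); AntiSplodge's `getMeanJSD` may use a different log base or the Jensen-Shannon distance (the square root of the divergence), so these values are not directly comparable to the numbers above. `jsd` and `mean_jsd` are hypothetical helpers.

```python
import math


def jsd(p, q, base=2.0):
    """Jensen-Shannon divergence between two discrete distributions p and q."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        # terms with a_i == 0 contribute nothing; b_i > 0 whenever a_i > 0
        return sum(ai * math.log(ai / bi, base) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def mean_jsd(true_props, pred_props):
    """Average JSD over paired lists of proportion vectors."""
    return sum(jsd(t, p) for t, p in zip(true_props, pred_props)) / len(true_props)


true_props = [[0.7, 0.3, 0.0], [0.2, 0.2, 0.6]]
pred_props = [[0.6, 0.3, 0.1], [0.2, 0.3, 0.5]]
print("{:.2f}%".format(mean_jsd(true_props, pred_props) * 100))
```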


In [35]:
# Export gene and cell type names as signature for reuse
for dat in datasets.keys():
    signature = (list(datasets[dat].var_names), list(experiments[dat].celltypes))
    with open(f"{dat}.signature.pickle", "wb") as f:
        pickle.dump(signature, f)
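To reuse a trained model, the signature can be loaded back, and new profiles presumably need to be arranged in the same gene order as the training data. A minimal sketch under that assumption: `save_signature` mirrors the export above, while `load_signature`, `align_to_signature`, and the gene and file names are hypothetical; genes missing from the new data are filled with 0 in this sketch.

```python
import os
import pickle
import tempfile


def save_signature(path, genes, celltypes):
    """Write the (genes, cell types) tuple in the same format as the export above."""
    with open(path, "wb") as f:
        pickle.dump((list(genes), list(celltypes)), f)


def load_signature(path):
    """Read back the (genes, cell types) tuple."""
    with open(path, "rb") as f:
        genes, celltypes = pickle.load(f)
    return genes, celltypes


def align_to_signature(profile, gene_names, signature_genes):
    """Reorder a profile to the signature gene order; missing genes become 0.0."""
    lookup = dict(zip(gene_names, profile))
    return [lookup.get(g, 0.0) for g in signature_genes]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.signature.pickle")
    save_signature(path, ["GENE_A", "GENE_B"], ["Adipocytes", "Endothelial"])
    genes, celltypes = load_signature(path)

spot = align_to_signature([3.0, 1.0], ["GENE_B", "GENE_C"], genes)
print(genes, celltypes, spot)  # ['GENE_A', 'GENE_B'] ['Adipocytes', 'Endothelial'] [0.0, 3.0]
```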