# Normalization and batch correction

## Settings

In [None]:
# Path settings -- keep these identical to the values used in the previous notebook.
output_dir = "/mnt/workspace/jdetlef/ext_ana/processed"  # root processing/output directory
test = "all"  # sample/experiment run name

In [None]:
# Normalization method: 'tfidf' or 'total'. If None, both normalization methods are
# performed (via atac.atac_norm), visualized, and you are asked to pick one afterwards.
norm_method = 'total'  # can be 'tfidf', 'total' or None
# remove_pc1: if True, the first PC is removed from TFIDF-LSI normalization before
# calculating neighbors, since the first component correlates with the number of features.
# It is read by the neighbor calculation when norm_method == 'tfidf'; it must be defined
# here, otherwise that cell raises NameError.
remove_pc1 = True
# log_normalize: apply log1p after total-count normalization ('total' method only).
log_normalize = True

# Highly variable feature options
min_cells = 5  # This one is mandatory
max_cells = None

# UMAP related settings
metacol = 'Sample'

# Batch correction: if True, several batch correction methods will be performed
# and you can choose the best one afterwards.
batch_column = "Sample"
perform_batch_correction = True
batch_methods = ["bbknn", "harmony"]  # other options: "mnn", "scanorama"
threads = 8

# Save figures
save_figs = False

## Loading packages and setup

In [None]:
# sctoolbox modules
import sctoolbox.atac_tree as sub_tree
import sctoolbox.creators as cr
import sctoolbox.annotation as an
from sctoolbox.qc_filter import *
import sctoolbox.plotting as pl
from sctoolbox.atac_utils import *
from sctoolbox.analyser import *
# The star import above does not bind the module name itself, but later cells call
# `analyser.<function>` directly.
import sctoolbox.analyser as analyser
import sctoolbox.atac as atac
# import episcanpy
import scanpy as sc
import episcanpy as epi
#from episcanpy.preprocessing import _decomposition
import numpy as np

## Set up the path-handling object

In [None]:
# Create the path-handling object; later cells read input/output paths from it
# (e.g. tree.qc_anndata, tree.norm_correction_anndata).
tree = sub_tree.ATAC_tree()
# Root processing/output directory (same as in the previous notebook).
tree.processing_dir = output_dir
# Sample/experiment run name; presumably used to build run-specific paths -- confirm in ATAC_tree.
tree.run = test

## Load AnnData

In [None]:
# Read the QC-filtered AnnData object; its location is provided by the path tree
# (this path lookup may change in future versions).
adata = epi.read_h5ad(tree.qc_anndata)
adata

## Find highly variable features

In [None]:
# Inspect the per-feature annotation table before selecting variable features.
adata.var 

In [None]:
# Recompute per-feature QC metrics (e.g. the number of cells per feature, which the
# manual cutoff cell below reads from adata.var['n_cells_by_counts']).
adata = analyser.calculate_qc_metrics(adata, var_type='features')
# Select highly variable features using the cells-per-feature cutoffs from the settings.
# NOTE(review): max_cells and min_cells are passed positionally (max first); confirm this
# matches the parameter order of atac.get_variable_features.
atac.get_variable_features(adata, max_cells, min_cells)

In [None]:
# Number of features currently flagged as highly variable.
adata.var["highly_variable"].sum()

In [None]:
# Re-inspect the feature table, now including the 'highly_variable' column.
adata.var

In [None]:
# Ask whether the highly-variable-feature cutoffs should be adjusted manually.
reset_prompt = 'Do you want to change the cutoffs again? answer with yes or no: '
reset_cutoffs_dec = input(reset_prompt)

In [None]:
# Optionally redefine highly variable features with manual cells-per-feature cutoffs.
# Answers are whitespace-stripped; an empty maximum means "no upper bound", matching
# the max_cells = None default of the settings cell (int('') used to crash here).
if reset_cutoffs_dec.strip().lower() in ('yes', 'y'):
    min_cells = int(input('Enter the minimal number of cells per feature?: ').strip())
    max_input = input('Enter the maximum number of cells per feature?: ').strip()
    max_cells = int(max_input) if max_input else None

    n_cells = adata.var['n_cells_by_counts']
    is_variable = n_cells >= min_cells
    if max_cells is not None:
        is_variable &= n_cells <= max_cells
    adata.var["highly_variable"] = is_variable
    print('Number of highly variable features: ' + str(adata.var["highly_variable"].sum()))

In [None]:
# Violin plot of the highly variable feature distribution.
# NOTE(review): not imported explicitly; presumably provided by one of the sctoolbox
# star imports (qc_filter / atac_utils / analyser) -- confirm.
violin_HVF_distribution(adata)

## Normalization

In [None]:
# Consider using parts of sctoolbox.analyser

In [None]:
# Normalize according to `norm_method` from the settings cell:
#   'tfidf' -> TF-IDF followed by LSI (atac.tfidf / atac.lsi)
#   'total' -> total-count normalization (the pre-log matrix is kept in
#              layers['normalised']), optionally followed by log1p
#   None    -> run both via atac.atac_norm; the next cell asks which result to keep
if norm_method == 'tfidf':
    print('Performing TFIDF and LSI...')
    atac.tfidf(adata)
    atac.lsi(adata)
    print('Done')
elif norm_method == 'total':
    print('Performing total and log1p normalization...')
    sc.pp.normalize_total(adata)
    adata.layers['normalised'] = adata.X.copy()
    if log_normalize:
        epi.pp.log1p(adata)
    print('Done')
elif not norm_method:
    adata_tfidf, adata_total = atac.atac_norm(adata)
else:
    # Previously an unrecognized value (e.g. a typo) silently skipped normalization.
    raise ValueError(f"Unknown norm_method {norm_method!r}; expected 'tfidf', 'total' or None")

In [None]:
# If both normalizations were computed (norm_method is None), ask which one to keep.
# Re-prompt until a valid answer is given, so a typo can no longer leave adata
# silently unnormalized.
if not norm_method:
    user_norm = ''
    while user_norm not in ('total', 'tfidf'):
        user_norm = input('Choose a normalization method (total or tfidf): ').strip().lower()
    adata = adata_total if user_norm == 'total' else adata_tfidf
else:
    # The method was fixed in the settings; later cells check norm_method instead.
    user_norm = None

display(adata)

## PCA

In [None]:
# Run PCA on highly variable features for total-count normalized data, whether 'total'
# was set in the settings or picked interactively (user_norm) after atac.atac_norm.
# The TF-IDF path is not handled here (atac.lsi was called in the normalization cell).
if norm_method == 'total' or user_norm == 'total':
    print('Performing PCA')
    sc.pp.pca(adata, svd_solver='arpack', n_comps=50, use_highly_variable=True)
    print('Done')

### Plot PCA

In [None]:
# Plot the PCA colored by the number of features. The effective method is either
# norm_method from the settings or the interactively chosen user_norm; both branches
# now honour both sources (previously 'total' ignored user_norm).
if norm_method == 'tfidf' or user_norm == 'tfidf':
    plot_pca = epi.pl.pca
elif norm_method == 'total' or user_norm == 'total':
    plot_pca = epi.pl.pca_overview
else:
    plot_pca = None

if plot_pca is not None:
    if save_figs:
        plot_pca(adata, color=['nb_features'], show=False)
        # TODO: saving is not implemented yet -- the intended target was
        # '<OUTPUT_FIGS>/pca_nb_features.png', but OUTPUT_FIGS is not defined in this notebook.
        # NOTE(review): `plt` is not imported explicitly; presumably it comes from one of
        # the sctoolbox star imports -- confirm.
        plt.show()
    else:
        plot_pca(adata, color=['nb_features'])

## Calculate neighbors

In [None]:
# Build the neighborhood graph when the method was fixed in the settings.
# NOTE(review): remove_pc1 is only evaluated on the 'tfidf' path and must be defined
# in the settings cell, otherwise this raises NameError.
if norm_method:
    if norm_method == 'tfidf' and remove_pc1:
        n_neighbors, n_pcs = 10, 30
    else:
        n_neighbors, n_pcs = 15, 50
    print('Calculating neighbors')
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, method='umap', metric='euclidean')
    print('Done')

## UMAP

In [None]:
# Grid-search UMAP min_dist/spread values and color the embeddings by `metacol`.
# Uses the `threads` setting instead of a hard-coded thread count, consistent with
# the other parallel steps in this notebook.
pl.search_umap_parameters(adata,
                          dist_range=(0.1, 0.4, 0.1),
                          spread_range=(2.0, 3.0, 0.5),
                          metacol=metacol,
                          n_components=2,
                          verbose=True,
                          threads=threads,
                          save=None)

## Batch Correction

In [None]:
# Make sure the batch column is stored as a categorical.
batch_labels = adata.obs[batch_column]
adata.obs[batch_column] = batch_labels.astype("category")

In [None]:
# Either run every method in `batch_methods`, or wrap the uncorrected object alone, so
# the following cells can always look results up by name in `batch_corrections`.
batch_corrections = (
    analyser.wrap_corrections(adata, batch_key=batch_column, methods=batch_methods)
    if perform_batch_correction
    else {"uncorrected": adata}
)

In [None]:
# Compute a standard UMAP for every AnnData object in batch_corrections.
corrected_objects = batch_corrections.values()
analyser.wrap_umap(corrected_objects, threads=threads)

In [None]:
# Should preliminary Leiden clustering be run on each batch-corrected object?
# If True, the overview plot is additionally colored by "leiden".
do_clustering = True #True or False

In [None]:
# Perform preliminary Leiden clustering on every batch-corrected object, if enabled,
# and record the resulting column for coloring the overview plot.
leiden_resolution = 0.1
color_by = []
if do_clustering:
    # A dedicated loop variable keeps the global `adata` from being rebound to the
    # last object in batch_corrections.
    for corrected_adata in batch_corrections.values():
        sc.tl.leiden(corrected_adata, resolution=leiden_resolution)
    color_by.append("leiden")

In [None]:
# Score each batch-corrected object with LISI on the batch column (inplace=True,
# so results are stored on the objects rather than returned separately).
analyser.wrap_batch_evaluation(
    batch_corrections,
    batch_key=batch_column,
    threads=threads,
    inplace=True,
)

In [None]:
# Plot an overview of all batch-correction results (colored by `color_by`) to a PDF in
# the tree's plot directory. os.path.join works whether or not the directory path ends
# with a separator; plain '+' concatenation silently required a trailing '/'.
import os

_ = pl.anndata_overview(batch_corrections, color_by=color_by,
                        output=os.path.join(tree.norm_correction_plots, "batch_correction_overview.pdf"))

In [None]:
# Choose which AnnData object to continue with. Unknown names fall back to the
# uncorrected object, and the fallback is reported instead of happening silently
# (the old bare `except:` also hid unrelated errors).
batch_name = input('Choose an anndata object to proceed. Type the name of the batch correction or uncorrected: ').strip()
try:
    adata_corrected = batch_corrections[batch_name]
except KeyError:
    print(f"'{batch_name}' is not one of {list(batch_corrections)}; falling back to 'uncorrected'.")
    adata_corrected = batch_corrections['uncorrected']

## Save AnnData

In [None]:
# Write the selected AnnData object to the output path provided by the path tree.
adata_corrected.write(filename=tree.norm_correction_anndata)

In [None]:
# Keep a copy of this notebook next to its outputs for reproducibility.
import os
import shutil

notebook_name = '3_normalization_batch_correction.ipynb'
source_notebook = os.path.join(os.getcwd(), notebook_name)
target_notebook = os.path.join(tree.norm_correction_dir, notebook_name)
shutil.copyfile(source_notebook, target_notebook)