# Normalization and batch correction

## Settings

In [None]:
# Choose normalization method(s); each listed method produces its own normalized object
norm_methods = ['tfidf', 'total']  # can be 'tfidf' and/or 'total'

# log_normalize: if True, epi.pp.log1p is applied to the loaded object in the
# "log normalize" step.
log_normalize = True

# Highly Variable Features options
min_cells = 5  # This one is mandatory
max_cells = None

# UMAP related settings
# NOTE(review): metacol is not read by any cell in this notebook - confirm whether it is still needed.
metacol = 'sample'

# batch correction: If True, every method in batch_methods is run and you can
# choose the best one afterwards. batch_column is the obs column holding the batch labels.
batch_column = "sample"
perform_batch_correction = True
batch_methods = ["bbknn", "harmony"]  # "mnn", "scanorama"
threads = 8

# save figures
# NOTE(review): save_figs is not read by any cell in this notebook (plots use output=None).
save_figs = False

--------------

## Loading packages and setup

In [None]:
import sctoolbox.tools as tools
import sctoolbox.plotting as pl
import sctoolbox.utils as utils
import scanpy as sc
import episcanpy as epi
import numpy as np

# Load the sctoolbox settings stored under the "03" key of the project config file
config_path = "config.yaml"
utils.settings_from_config(config_path, key="03")

## Load AnnData object

In [None]:
# Input file written by the previous notebook (the file name may change in the future)
input_h5ad = "anndata_2.h5ad"
adata = utils.load_h5ad(input_h5ad)
display(adata)

In [None]:
# NOTE(review): writes the constant "1" into obs["Sample"] (capital S) for every cell.
# No later cell reads "Sample" (batch_column and metacol use lowercase "sample"),
# so this looks like leftover debug code - confirm before removing.
constant_sample_label = "1"
adata.obs["Sample"] = constant_sample_label

## Find highly variable features

In [None]:
# Recompute the feature-level QC metrics (including the number of cells per feature)
adata = tools.calculate_qc_metrics(
    adata,
    var_type='features',
)

# Flag highly variable features using the max_cells / min_cells thresholds from Settings
tools.get_variable_features(adata, max_cells, min_cells)

In [None]:
#Number of variable genes selected
# Number of features flagged as highly variable (displayed as the cell output)
n_highly_variable = adata.var["highly_variable"].sum()
n_highly_variable

In [None]:
# Plot the distribution of the highly variable features selected above
pl.violin_HVF_distribution(adata)

-----------

## Normalization

In [None]:
# Normalize with every method in norm_methods. The result is looked up by method
# name later (e.g. normalizations["tfidf"]), with one AnnData object per method.
normalizations = tools.atac_norm(adata, norm_methods)

### Remove PC1 for TF-IDF normalization (no code cell implements this step yet)

### Log-normalize

In [None]:
# Optionally log-transform the loaded object in place.
# NOTE(review): this runs after `normalizations` was already built, and `adata` is
# reassigned from `normalizations` further down, so this log1p probably does not
# reach the normalized objects. Confirm whether it should run before atac_norm instead.
if log_normalize:
    epi.pp.log1p(adata)

### Calculate standard neighbors and UMAP for each normalized AnnData object

In [None]:
# Build the neighbor graph separately for every normalized object.
# Use a dedicated loop variable so that the notebook-level `adata` is not
# silently rebound to whichever normalization happens to come last.
for norm_adata in normalizations.values():
    sc.pp.neighbors(norm_adata, n_neighbors=15, n_pcs=50, method='umap', metric='euclidean')

In [None]:
# Compute UMAP embeddings for every normalized object, honoring the `threads` setting
# the same way the batch-correction UMAP call does.
tools.wrap_umap(normalizations.values(), threads=threads)

### Compare between normalizations

In [None]:
# Plot an overview (PCA, PCA variance, UMAP) of each normalization method,
# colored by the number of features per cell, to compare the methods
_ = pl.anndata_overview(normalizations,
                        plots=["PCA", "PCA-var", "UMAP"],
                        color_by=["n_features"],
                        output=None)

In [None]:
# Choose the final normalization method (must be one of the methods in norm_methods)
norm_method = "tfidf"

# Fail with a readable message listing the valid choices instead of a bare KeyError
if norm_method not in normalizations:
    raise KeyError(f"Normalization '{norm_method}' is not available. Choose one of: {list(normalizations)}")
adata = normalizations[norm_method]

-----------

## Plot final PCA

In [None]:
# PCA overview of the selected normalization, colored by the per-cell feature count.
# Uses the "n_features" obs column, which is the same column used in the normalization comparison plot.
sc.pl.pca_overview(adata, color=['n_features'], show=False)

-----------

## Batch Correction

In [None]:
# Ensure the batch column is stored as a pandas categorical before batch correction
batch_values = adata.obs[batch_column]
adata.obs[batch_column] = batch_values.astype("category")

In [None]:
# Without batch correction, the normalized object is the only ("uncorrected") candidate;
# otherwise, run every method listed in batch_methods.
if not perform_batch_correction:
    batch_corrections = {"uncorrected": adata}
else:
    batch_corrections = tools.wrap_corrections(
        adata,
        batch_key=batch_column,
        methods=batch_methods,
    )

In [None]:
# Compute UMAP embeddings for every candidate in batch_corrections (uses the `threads` setting)
tools.wrap_umap(batch_corrections.values(), threads=threads)

In [None]:
# Should preliminary Leiden clustering be performed on each batch-correction result?
do_clustering = True  # True or False; if True, "leiden" is added to the overview plot colors

In [None]:
# Optional preliminary Leiden clustering (resolution 0.1) on every batch-correction result.
# The clustering is stored on each object and "leiden" is added to the plot colors.
# A dedicated loop variable keeps the notebook-level `adata` from being rebound
# to the last corrected object.
color_by = []
if do_clustering:
    for corrected_adata in batch_corrections.values():
        sc.tl.leiden(corrected_adata, resolution=0.1)
    color_by.append("leiden")

In [None]:
# Calculate LISI scores for the batch column on every candidate in batch_corrections.
# inplace=True presumably stores the scores on the objects themselves - confirm in sctoolbox.
tools.wrap_batch_evaluation(batch_corrections, batch_key=batch_column, threads=threads, inplace=True)

In [None]:
# Overview plots of every batch-correction result, colored by the optional
# clustering and by the batch column
overview_colors = color_by + [batch_column]
_ = pl.anndata_overview(batch_corrections,
                        color_by=overview_colors,
                        output=None)

In [None]:
# Choose an anndata object to proceed with. This must be a key of batch_corrections,
# e.g. "uncorrected" (with correction enabled, presumably also the names in batch_methods).
batch_name = "uncorrected"

# Fail with a readable message listing the valid choices instead of a bare KeyError
if batch_name not in batch_corrections:
    raise KeyError(f"Batch correction '{batch_name}' is not available. Choose one of: {list(batch_corrections)}")
adata_corrected = batch_corrections[batch_name]

---------

## Save anndata

In [None]:
# Save the object selected via batch_name (adata_corrected). That selection is not
# what `adata` holds, so saving `adata` would write the wrong object.
adata_output = "anndata_3.h5ad"
utils.save_h5ad(adata_corrected, adata_output)