In [None]:
import sctoolbox
from sctoolbox.utilities import bgcolor

# Batch effect correction and comparisons
<hr style="border:2px solid black">

<h1><center>⬐ Fill in input data here ⬎</center></h1>

In [None]:
%bgcolor PowderBlue

# Normalization method(s) to run: 'tfidf' and/or 'total'
norm_methods = ["tfidf", "total"]

# Principal components
subset_pcs = True  # True or False
n_pcs = 10  # None -> PCs are chosen automatically; a number overrides the automatic selection

# Number of neighbors used for the neighborhood graph
n_neighbors = 15  # default = 15

# UMAP related settings
metacol = "sample"

# Batch correction: if enabled, several batch correction methods are run
# and the best one can be chosen afterwards.
batch_column = "sample"
perform_batch_correction = True
batch_methods = ["bbknn", "harmony"]  # "mnn", "scanorama"
threads = 8

<hr style="border:2px solid black">

## Set up

In [None]:
import sctoolbox.tools as tools
import sctoolbox.plotting as pl
import sctoolbox.utils as utils
import scanpy as sc
import episcanpy as epi
import numpy as np
import matplotlib.pyplot as plt
# Apply the settings stored in config.yaml under the key "03"
utils.settings_from_config("config.yaml", key="03")

# Figure defaults: vector-friendly output and 600 dpi for saved figures.
# NOTE(review): scanpy=False presumably skips scanpy's own default styling - confirm in the scanpy docs.
sc.set_figure_params(vector_friendly=True, dpi_save=600, scanpy=False)

----------------

## Load anndata 

In [None]:
# Load the input AnnData object and display a summary of it.
# NOTE: the input file name will probably be updated in future.
adata = utils.load_h5ad("anndata_2.h5ad")
display(adata)

----------

## Normalization

In [None]:
# Run every normalization listed in `norm_methods`. The result is used below as a
# dict (method name -> normalized adata), see the .items()/.values() calls that follow.
normalizations = tools.normalize_adata(adata, norm_methods)

--------

## Show PC embedding per method

In [None]:
# Show the PCA embedding of every normalization side by side, colored by the batch column
_ = pl.compare_embeddings(
    [*normalizations.values()],
    adata_names=[*normalizations.keys()],
    var_list=[batch_column],
    embedding="pca",
)

In [None]:
# For every normalization, plot the PCA correlations with the .obs and the .var variables.
# One figure is saved per (method, table) combination.
for method, adata_norm in normalizations.items():
    for which in ("obs", "var"):
        _ = pl.plot_pca_correlation(
            adata_norm,
            which=which,
            title=f"Normalization method = {method}\nCorrelation of .{which} columns with PCA loadings",
            save=f"PCA_{method}_correlation_{which}.pdf",
        )

---------------------

## Remove PC1 for each normalization (TF-IDF and total)

In [None]:
# For each base normalization that was requested, add a "<name>-noPC1" copy
# whose PCA is subset starting at index 1 (start=1), i.e. without the first PC.
for base_method in ("tfidf", "total"):
    if base_method not in norm_methods:
        continue
    no_pc1_key = f"{base_method}-noPC1"
    normalizations[no_pc1_key] = normalizations[base_method].copy()
    tools.subset_PCA(normalizations[no_pc1_key], 50, start=1)

In [None]:
# Select and keep the number of PCs for every normalization.
if subset_pcs:
    for key, adata_normed in normalizations.items():

        # Resolve the number of PCs per adata. The user setting `n_pcs` is never
        # overwritten, so with n_pcs = None the automatic selection is computed
        # separately for every normalization instead of only for the first one.
        n_pcs_selected = tools.define_PC(adata_normed) if n_pcs is None else n_pcs

        # Plot the PCA variance and mark the selected number of PCs
        _, ax = plt.subplots()
        ax.set_title(key)
        _ = pl.plot_pca_variance(adata_normed, save=key + "_PCA_variance_selected.pdf",
                                 n_pcs=50, n_selected=n_pcs_selected, ax=ax)

        # Keep only the selected PCs. The return value is not used, so the notebook
        # relies on the subsetting being applied to `adata_normed` itself.
        tools.subset_PCA(adata_normed, start=0, n_pcs=n_pcs_selected)

--------

## Calculate standard neighbors and UMAP for each adata

In [None]:
# Compute the neighborhood graph for every normalized adata.
# The loop variable is deliberately not named `adata`, so the notebook-level
# `adata` is not rebound as a side effect of this loop.
for adata_norm in normalizations.values():
    sc.pp.neighbors(adata_norm, n_neighbors=n_neighbors, method='umap', metric='euclidean')

In [None]:
# Run UMAP for every normalized adata, honoring the configured number of threads
tools.wrap_umap(normalizations.values(), threads=threads)

------------------------

## Compare between normalizations

In [None]:
# Plot an overview (PCA, PCA variance, UMAP) of the normalization methods
_ = pl.anndata_overview(
    normalizations,
    plots=["PCA", "PCA-var", "UMAP"],
    color_by=["n_features"],
    output=None,
)

In [None]:
# Choose the final normalization method; it must be one of the keys of `normalizations`
norm_method = "tfidf-noPC1"

# Fail early with a helpful message, e.g. if the base method was not listed in norm_methods
if norm_method not in normalizations:
    raise KeyError(f"Normalization '{norm_method}' is not available. Choose one of: {list(normalizations.keys())}")

adata = normalizations[norm_method]

-----------

## Plot final PCA

In [None]:
# PCA overview of the chosen normalization, colored by the number of features
sc.pl.pca_overview(adata, color=['n_features'], show=False)

# Label and file name use `norm_method` (the normalization selected for `adata`),
# not a loop variable left over from an earlier cell, so they match the plotted data.
_ = pl.plot_pca_correlation(adata, which="var",
                            title=f"Normalization method = {norm_method}\nCorrelation of .var columns with PCA loadings",
                            save=f"PCA_{norm_method}_correlation_var.pdf")

-----------

## Batch correction

In [None]:
# Run the selected batch correction methods, or fall back to the uncorrected adata only.
# Either way the result is a dict mapping a name to an adata.
batch_corrections = (
    tools.wrap_corrections(adata, batch_key=batch_column, methods=batch_methods)
    if perform_batch_correction
    else {"uncorrected": adata}
)

In [None]:
# Run standard UMAP for all adatas in `batch_corrections`, using the configured number of threads
tools.wrap_umap(batch_corrections.values(), threads=threads)

In [None]:
# Toggle: perform a preliminary clustering on each batch-corrected adata?
do_clustering = True  # True or False

In [None]:
# Perform an additional (preliminary) Leiden clustering if it was chosen.
# `color_by` collects the .obs columns to color the overview plots by.
color_by = []
if do_clustering:
    # The loop variable is deliberately not named `adata`, so the notebook-level
    # `adata` is not rebound to whichever adata happens to come last in the dict.
    for adata_batch in batch_corrections.values():
        sc.tl.leiden(adata_batch, resolution=0.1)
    color_by.append("leiden")

In [None]:
# Calculate LISI scores for batch on every adata in `batch_corrections`.
# NOTE(review): inplace=True is passed and no return value is used, so the scores are
# presumably stored on the adatas themselves - confirm against the sctoolbox docs.
tools.wrap_batch_evaluation(batch_corrections, batch_key=batch_column, threads=threads, inplace=True)

In [None]:
# Plot the overview of batch correction methods.
# Ensure the batch column is categorical on every adata that is plotted,
# not only on a single one.
for adata_batch in batch_corrections.values():
    adata_batch.obs[batch_column] = adata_batch.obs[batch_column].astype("category")

_ = pl.anndata_overview(batch_corrections, color_by=color_by + [batch_column],
                        output=None)

In [None]:
# Choose an anndata object to proceed; it must be one of the keys of `batch_corrections`
batch_name = "bbknn"

# Fail early with a helpful message, e.g. if batch correction was skipped or the method was not run
if batch_name not in batch_corrections:
    raise KeyError(f"Batch correction '{batch_name}' is not available. Choose one of: {list(batch_corrections.keys())}")

adata_corrected = batch_corrections[batch_name]

---------

## Save anndata

In [None]:
# Save the anndata object chosen above (`adata_corrected`), not whichever
# object the name `adata` happens to point to at this point.
adata_output = "anndata_3.h5ad"
utils.save_h5ad(adata_corrected, adata_output)