In [None]:
from sctoolbox.utils.jupyter import bgcolor, _compare_version
# Name of this notebook, handed to the version check below
nb_name = "03_batch.ipynb"

# Default value, set to prevent an error in the init cell
selected_pcs = None

# NOTE(review): presumably validates the notebook against the installed sctoolbox version - confirm
_compare_version(nb_name)

# Batch effect correction and comparisons
<hr style="border:2px solid black">

<h1><center>⬐ Fill in input data here ⬎</center></h1>

In [None]:
%bgcolor PowderBlue

# Set the species of the data
species = "zebrafish"

# Set number of cores to use for multiprocessing
threads = 4

# Options for highly variable genes
min_limit = 1000
max_limit = 5000

# Set number of PCs
subset_pcs = True  #True or False
n_pcs = None  #If None, PCs are chosen automatically. Set n_pcs to overwrite the automatic selection of PCs.

# Set number of neighbors
n_neighbors = 15  #Default=15

# Should preliminary clustering be performed?
do_clustering = True #True or False

# Options for batch correction
batch_column = "batch"  #a column in adata.obs containing batch information
perform_batch_correction = True
batch_methods = ["bbknn", "mnn", "harmony", "scanorama"] # , combat (excluded in this example due to slow runtime)

<hr style="border:2px solid black">

## Set up

In [None]:
import pandas as pd
import scanpy as sc

import sctoolbox
import sctoolbox.utils as utils
import sctoolbox.tools as tools
import sctoolbox.plotting as pl

# Read the sctoolbox settings stored under key "03" of config.yaml
sctoolbox.settings.settings_from_config("config.yaml", key="03")

# Additional options for figures
sc.set_figure_params(
    scanpy=False,
    vector_friendly=True,
    dpi_save=600,
)

-------

## Loading the anndata

In [None]:
# Load the AnnData object for this notebook
adata = utils.adata.load_h5ad("anndata_2.h5ad")

# Show a compact preview: at most 5 rows, all columns.
# The canonical pandas option names are used here; the former spellings
# ("display.max.rows" / "display.max.columns") are not registered options and
# were only resolved through pandas' regex matching of option names.
with pd.option_context("display.max_rows", 5, "display.max_columns", None):
    display(adata)
    display(adata.obs)
    display(adata.var)

In [None]:
# Ensure that the batch column is a category: cast to str first, then to category
batch_labels = adata.obs[batch_column].astype(str)
adata.obs[batch_column] = batch_labels.astype("category")

---------

## Plot highly expressed genes

In [None]:
# Draw the highest expressed genes without showing the plot right away (show=False)
sc.pl.highest_expr_genes(adata, show=False)
# NOTE(review): presumably writes the current figure to "highly_expressed.pdf" - confirm against sctoolbox
pl._save_figure("highly_expressed.pdf")

In [None]:
# Forwarded to tools.normalize_adata in the normalization step as exclude_highly_expressed
exclude_highly_expressed = True

-------------

## Normalization

In [None]:
# Save raw layer before normalization (stored as a copy of adata.X)
# NOTE(review): re-running this cell after the normalization cell would presumably
# store the normalized values under "raw" - confirm before re-executing single cells
adata.layers["raw"] = adata.X.copy()

In [None]:
# Normalize with the "total" method; the returned object is subscripted with the method name
normalized = tools.normalize_adata(
    adata,
    method="total",
    target_sum=None,
    exclude_highly_expressed=exclude_highly_expressed,
)
adata = normalized["total"]

----------

## Predict Cell Cycle
Predict the division phase of each cell.

In [None]:
# Predict the cell cycle phase in place; no custom S / G2M gene lists are given
tools.predict_cell_cycle(
    adata,
    species=species,
    s_genes=None,
    g2m_genes=None,
    inplace=True,
)

# Append "phase" to the "obs_metrics" entry
# NOTE(review): presumably this is the list read from adata.uns["sctoolbox"]["obs_metrics"] for the PCA embedding plot - confirm
utils.add_uns_info(adata, "obs_metrics", ["phase"], how="append")

-----------

## Find highly variable genes

In [None]:
# Annotate highly variable genes using the limits set in the parameters
hvg_limits = (min_limit, max_limit)
tools.annot_HVG(adata, hvg_range=hvg_limits, save="highly_variable.pdf")

In [None]:
# Number of variable genes selected: sum over adata.var["highly_variable"]
# (equals the count of flagged genes, assuming the column is boolean)
adata.var["highly_variable"].sum()

---------

## PCA and neighbors for uncorrected data

In [None]:
# Run the PCA with the arpack solver, restricted to the highly variable genes
# NOTE(review): newer scanpy releases reportedly deprecate use_highly_variable in favor of mask_var - confirm against the pinned scanpy version
sc.pp.pca(adata, svd_solver='arpack', use_highly_variable=True)

In [None]:
# Plot the PCA variance (n_pcs=50) and save it to "PCA_variance.pdf"; the return value is discarded
_ = pl.plot_pca_variance(adata, save="PCA_variance.pdf", n_pcs=50)

In [None]:
# Color the PCA embedding by the stored obs metrics plus the batch column
# to show potential correlations
pca_colors = list(adata.uns["sctoolbox"]["obs_metrics"])
pca_colors.append(batch_column)

sc.pl.pca(adata, color=pca_colors, ncols=3, show=False)
pl._save_figure("PCA_embedding.pdf")

In [None]:
# Correlate the PCA with the .obs columns and save the plot
_ = pl.plot_pca_correlation(
    adata,
    which="obs",
    title="Correlation of .obs columns with PCA loadings",
    save="PCA_correlation_obs.pdf",
)

In [None]:
# Correlate the PCA with the .var columns and save the plot
_ = pl.plot_pca_correlation(
    adata,
    which="var",
    title="Correlation of .var columns with PCA loadings",
    save="PCA_correlation_var.pdf",
)

### Choose a subset of PCs (optional)

In [None]:
# Only subset the PCs if this was chosen in the parameters
if subset_pcs:

    # Fall back to the automatic identification when no number was set
    n_pcs = tools.define_PC(adata) if n_pcs is None else n_pcs

    # Plot the variance with the selection marked, then subset the PCA
    _ = pl.plot_pca_variance(adata, n_pcs=50, n_selected=n_pcs, save="PCA_variance_selected.pdf")
    tools.subset_PCA(adata, n_pcs=n_pcs, start=0)

In [None]:
# Exclude the first PC from both the cell coordinates and the gene loadings.
# Each execution drops the first remaining column, so run this cell only once.
for container, key in ((adata.obsm, "X_pca"), (adata.varm, "PCs")):
    container[key] = container[key][:, 1:]

### Calculate neighbors

In [None]:
# Compute the neighbor graph with the number of neighbors set in the parameters
# NOTE(review): presumably built on the (subsetted) X_pca representation - confirm against the scanpy defaults
sc.pp.neighbors(adata, n_neighbors=n_neighbors)

------------

## Batch correction (optional)

In [None]:
# Either run the chosen correction methods or keep only the uncorrected object;
# in both cases batch_corrections can be indexed by name
batch_corrections = (
    tools.wrap_corrections(adata, batch_key=batch_column, methods=batch_methods)
    if perform_batch_correction
    else {"uncorrected": adata}
)

### Plot overview of batch corrections

In [None]:
# Run standard umap for every object in batch_corrections, using the configured number of threads
tools.wrap_umap(batch_corrections.values(), threads=threads)

In [None]:
# Always color the overview by batch; add the clustering if it was chosen
color_by = [batch_column]

if do_clustering:
    # Use a dedicated loop variable: looping with the name "adata" would rebind
    # the notebook-level adata to the last entry of batch_corrections.
    for corrected_adata in batch_corrections.values():
        sc.tl.leiden(corrected_adata)
    color_by.append("leiden")

##### LISI score:
The LISI score (stored in adata.obs) indicates the effective number of different categories represented in the local neighborhood of each cell. If the cells are well mixed, we expect the LISI score to be close to n for a dataset with n batches.

##### The higher the LISI score, the better the batch correction method worked to normalize the batch effect and mix the cells from different batches.



In [None]:
# Calculate LISI scores for batch
# NOTE(review): with inplace=True the scores are presumably written into the objects of batch_corrections - confirm
tools.wrap_batch_evaluation(batch_corrections, batch_key=batch_column, threads=threads, inplace=True)

In [None]:
# Plot the overview of batch correction methods, colored by the entries of color_by,
# and write it to "batch_correction_overview.pdf"; the return value is discarded
_ = pl.anndata_overview(batch_corrections, color_by=color_by, 
                        output="batch_correction_overview.pdf")

### Select the final object

In [None]:
%bgcolor PowderBlue

# Key of the entry in batch_corrections to continue with
selected = "harmony"

In [None]:
# Stop early if the selection does not match any available object,
# and name the valid choices so the selection can be corrected directly
if selected not in batch_corrections:
    available = ", ".join(f"'{key}'" for key in batch_corrections)
    raise KeyError(f"'{selected}' is not a key in batch_corrections. Available keys: {available}")

In [None]:
# Continue with the object stored under the selected key
adata = batch_corrections[selected]

-------------

## Saving adata for next notebook

In [None]:
# Display the final object before saving
adata

In [None]:
# Saving the data for the next notebook
# load_h5ad is accessed through utils.adata when loading; use the same submodule for saving
adata_output = "anndata_3.h5ad"
utils.adata.save_h5ad(adata, adata_output)

In [None]:
# Close the sctoolbox logfile at the end of the notebook
sctoolbox.settings.close_logfile()