# <span style="color: steelblue;">Batch Correction using scaLR</span>

Key points

1. This notebook is a tutorial on batch correction with **`SimpleMetaDataLoader`** from the **`scaLR`** library.
2. The data loader can add any `metadata` column to the feature data as a one-hot encoded vector, which can be useful for model training.

## <span style="color: steelblue;">What is batch correction?</span>

- Single-cell genomic datasets related to a specific disease or trait are often compiled from multiple experiments, each sequenced from different single cells under varying conditions, such as capturing times, handling personnel, reagent lots, equipment, and even technology platforms.
- These differences introduce systematic technical variation into the data, known as `batch effects`. When combining such datasets for analysis and modeling, it is crucial to ensure that the model is not biased towards data from certain batches because of their higher or lower value ranges. Therefore, `batch effects` must be corrected for.


## <span style="color: steelblue;">How to perform batch correction?</span>

- Many statistical tools, such as `Scanpy`, `Seurat`, `Harmony`, and `ComBat`, perform `batch correction` by applying `normalization` and `dimensionality reduction` and then removing the `batch effect` by fitting `linear or mixed models`, computing `k-nearest neighbors` or `mutual nearest neighbors`, using `canonical correlation analysis`, etc.

- Traditional `batch correction` methods are robust and widely used, but they have limitations in handling non-linear relationships, scalability, flexibility, and the preservation of biological signals. AI/ML-based batch correction methods offer advantages in these areas, making them a powerful alternative for complex, large-scale single-cell genomic datasets.

- The `batch correction` approach in the **`scaLR`** platform is inspired by [scGPT](https://www.nature.com/articles/s41592-024-02201-0) (Cui et al.), where batch information is integrated into the feature data to inform the model about the origin of each sample. Since batch is a `categorical` variable with no natural order, encoding it as a single `label-encoded` feature (e.g. 0, 1, 2, 3) is not appropriate: it would imply that some batches are larger than, or closer to, others. Instead, batch information is represented as a `one-hot encoded vector`.

- For example, if the dataset has four batches, the one-hot encoding could look like this:
    - Batch 1 -> 1 0 0 0
    - Batch 2 -> 0 1 0 0
    - Batch 3 -> 0 0 1 0
    - Batch 4 -> 0 0 0 1

- These encoded vectors are appended to the feature data. In this case, four additional columns are added, so the model knows which batch each sample comes from while training on samples from different batches.
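A minimal NumPy sketch of the idea above, independent of `scaLR`: it one-hot encodes a batch label per cell and appends the columns to a toy expression matrix. The helper `one_hot_batches` and the toy data are hypothetical; columns follow sorted batch names, which is one common convention. The last lines also show why a single label-encoded column is misleading: it puts batch 1 closer to batch 2 than to batch 4, while all one-hot vectors are equidistant.

```python
import numpy as np


def one_hot_batches(batch_labels):
    """Return (categories, one-hot matrix) for a 1-D array of batch labels.

    Hypothetical helper for illustration; categories are sorted, so column i
    of the one-hot matrix corresponds to categories[i].
    """
    categories, codes = np.unique(batch_labels, return_inverse=True)
    return categories, np.eye(len(categories), dtype=np.float32)[codes]


# Toy expression matrix: 4 cells x 3 genes.
X = np.array([[0.0, 1.0, 2.0],
              [3.0, 0.0, 1.0],
              [1.0, 1.0, 0.0],
              [0.0, 2.0, 5.0]], dtype=np.float32)
batches = np.array(['batch1', 'batch2', 'batch3', 'batch4'])

categories, onehot = one_hot_batches(batches)
X_with_batch = np.hstack([X, onehot])  # 3 gene columns + 4 batch columns
print(X_with_batch.shape)  # (4, 7)

# Label encoding (0, 1, 2, 3) implies an order and unequal distances between batches.
label_codes = np.arange(len(categories))
print(abs(label_codes[0] - label_codes[1]), abs(label_codes[0] - label_codes[3]))  # 1 3

# One-hot vectors of any two different batches are the same distance apart.
print(np.linalg.norm(onehot[0] - onehot[1]), np.linalg.norm(onehot[0] - onehot[3]))
```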


## <span style="color: steelblue;">How is it implemented in the scaLR platform?</span>

- In the **`scaLR`** platform, we have implemented the **`SimpleMetaDataLoader`** data loader, which handles this process automatically.
- You specify the `metadata` column you want to one-hot encode and add to the feature data, and the data loader takes care of the encoding.
- We have also extended this functionality so that multiple `metadata` columns can be included in the feature data as `one-hot encoded vectors`. By passing a list of columns, you can easily incorporate additional information.
- This approach is particularly useful in scenarios such as predicting `disease vs. non-disease` outcomes, where certain `metadata`, such as `cell type`, might enhance the model's predictive power.
- Adding the `cell type` information to the feature data in this way may improve the model's performance.
- In practice, this data loader is mostly used as part of the `scaLR` pipeline rather than as a standalone library utility. The code snippets below explain its basics.
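To illustrate encoding several metadata columns at once, here is a hedged NumPy sketch. It is not `scaLR`'s implementation: `encode_metadata` is a hypothetical helper, the `obs` dict stands in for `adata.obs`, and `mappings` uses the same `id2label`/`label2id` layout that the notebook builds in the batch-correction section. Each column contributes one block of one-hot columns, and the blocks are concatenated after the gene columns.

```python
import numpy as np


def encode_metadata(obs, mappings, metadata_cols):
    """One-hot encode several metadata columns and stack them side by side.

    obs: dict of column name -> list of per-cell labels (stand-in for adata.obs).
    mappings: {column: {'id2label': [...], 'label2id': {...}}}.
    Hypothetical helper for illustration; not part of scaLR.
    """
    blocks = []
    for col in metadata_cols:
        label2id = mappings[col]['label2id']
        codes = np.array([label2id[label] for label in obs[col]])
        # One row per cell, one column per category of this metadata column.
        blocks.append(np.eye(len(label2id), dtype=np.float32)[codes])
    return np.hstack(blocks)


obs = {
    'batch': ['batch1', 'batch2', 'batch1'],
    'cell_type': ['B cell', 'T cell', 'NK cell'],
}
mappings = {
    col: {'id2label': sorted(set(vals)),
          'label2id': {lab: i for i, lab in enumerate(sorted(set(vals)))}}
    for col, vals in obs.items()
}

X = np.zeros((3, 5), dtype=np.float32)  # toy expression: 3 cells x 5 genes
meta = encode_metadata(obs, mappings, ['batch', 'cell_type'])
X_aug = np.hstack([X, meta])
print(X_aug.shape)  # (3, 10): 5 genes + 2 batches + 3 cell types
```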


## <span style="color: steelblue;">Cloning scaLR</span>

In [None]:
!git clone https://github.com/infocusp/scaLR.git

## <span style="color: steelblue;">Library Installation and Imports</span>



In [None]:
!pip install anndata==0.10.9 pydeseq2==0.4.11 scanpy==1.10.3

In [None]:
# Required imports.
import sys
sys.path.append('./scaLR')

import anndata
import numpy as np
import pandas as pd

from scalr.nn.dataloader import build_dataloader, simple_metadataloader

# Setting seed for reproducibility
np.random.seed(0)

## <span style="color: steelblue;">Downloading data</span>

For this tutorial, we will use two datasets from `cellxgene` ([Jin et al. (2021) iScience](https://doi.org/10.1016/j.isci.2021.103115)). The first dataset will serve as batch 1, and the second as batch 2.

In [None]:
!wget -P ./data https://datasets.cellxgene.cziscience.com/16acb1d0-4108-4767-9615-0b42abe09992.h5ad
!wget -P ./data https://datasets.cellxgene.cziscience.com/8651e63c-0f98-4a87-bdbd-2da41bdf6de5.h5ad

## <span style="color: steelblue;">Loading datasets and merging</span>

In [None]:
# Read the files from ./data, where the download cell above saved them.
data_1 = anndata.read_h5ad('./data/16acb1d0-4108-4767-9615-0b42abe09992.h5ad')
data_2 = anndata.read_h5ad('./data/8651e63c-0f98-4a87-bdbd-2da41bdf6de5.h5ad')
print(f'\nDataset-1 has {data_1.n_obs} cells and {data_1.n_vars} genes\nDataset-2 has {data_2.n_obs} cells and {data_2.n_vars} genes')

In [None]:
print(f'Shape of "obs" before adding batch\n\nDataset-1 : {data_1.obs.shape}\nDataset-2 : {data_2.obs.shape}')

In [None]:
# Adding batch information.
data_1.obs['batch'] = 'batch1'
data_2.obs['batch'] = 'batch2'
print(f'Shape of "obs" after adding batch\n\nDataset-1 : {data_1.obs.shape}\nDataset-2 : {data_2.obs.shape}')

In [None]:
# Combining two datasets.
adata = anndata.concat([data_1, data_2])
print(f'Combined dataset has shape : {adata.shape}')

## <span style="color: steelblue;">Batch correction</span>

In [None]:
# Inspect the type annotations of SimpleMetaDataLoader's constructor parameters.
simple_metadataloader.SimpleMetaDataLoader.__init__.__annotations__
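`__annotations__` only lists parameters that carry type hints, and it omits default values. As a hedged alternative, the standard library's `inspect.signature` lists every parameter with its default. The class below is a hypothetical stand-in (not `SimpleMetaDataLoader`), so the example runs without `scaLR`; in the notebook you would pass `simple_metadataloader.SimpleMetaDataLoader.__init__` instead.

```python
import inspect


class ExampleLoader:
    # Hypothetical stand-in class for illustration; not scaLR's SimpleMetaDataLoader.
    def __init__(self, batch_size: int, target: str, mappings: dict, metadata_col=None):
        self.batch_size = batch_size
        self.target = target
        self.mappings = mappings
        self.metadata_col = metadata_col


# __annotations__ holds only the annotated parameters and no defaults,
# so the unannotated `metadata_col` does not appear here.
print(ExampleLoader.__init__.__annotations__)

# inspect.signature lists every parameter, including unannotated ones and defaults.
sig = inspect.signature(ExampleLoader.__init__)
for name, param in sig.parameters.items():
    print(name, param.annotation, param.default)
```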

In [None]:
# Defining required parameters for SimpleMetaDataLoader.

# For batch correction you can pass the `batch` column inside `metadata_col`.
metadata_col = ['batch', ]

# Generating id2label / label2id mappings for every anndata obs column.
mappings = {}
for column_name in adata.obs.columns:
    id2label = adata.obs[column_name].astype('category').cat.categories.tolist()
    label2id = {label: idx for idx, label in enumerate(id2label)}
    mappings[column_name] = {'id2label': id2label, 'label2id': label2id}

In [None]:
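# Creating the dataloader object.
# Note: `batch_size` is the number of cells per mini-batch, unrelated to the experimental `batch` column.
dataloader = simple_metadataloader.SimpleMetaDataLoader(batch_size=3,
                                                        target='cell_type',
                                                        mappings=mappings,
                                                        metadata_col=metadata_col)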

We can check whether `batch` has been added to the feature data as a one-hot encoded vector (i.e. 23,586 genes + 2 batches).
Initially, the feature shape is `(batch_size, 23586)`, and there are 2 batches in the data.
So the number of features after adding this column to the feature data should be `23586 + 2 = 23588`.
Hence, the feature shape after batch correction should be `(batch_size, 23588)`.
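The arithmetic above generalizes to any number of metadata columns: each one-hot encoded column adds one feature per category. The helper below is a hypothetical sketch that assumes the `mappings` layout built earlier in this notebook; in the notebook, `expected_n_features(adata.n_vars, mappings, metadata_col)` could be compared with `feature.shape[1]`.

```python
def expected_n_features(n_genes, mappings, metadata_cols):
    """Genes plus one column per category of every one-hot encoded metadata column.

    Hypothetical helper; assumes mappings[col]['id2label'] lists all categories
    of `col`, and each category becomes exactly one extra feature column.
    """
    return n_genes + sum(len(mappings[col]['id2label']) for col in metadata_cols)


toy_mappings = {'batch': {'id2label': ['batch1', 'batch2'],
                          'label2id': {'batch1': 0, 'batch2': 1}}}
print(expected_n_features(23586, toy_mappings, ['batch']))  # 23588
```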

In [None]:
# Verifying features in the dataloader
for feature, _ in dataloader(adata):
    print('Features shape :', feature.shape)
    print('Features :', feature)
    break

We observe that the feature tensor has a shape of `[3, 23588]`, representing `3` samples with `23,588` features each. The batch information is appended to the gene expression values using one-hot encoding: in this case, each feature vector ends with values `[1, 0]`, representing the samples from `batch 1`.
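As an extra sanity check, the trailing one-hot columns can be mapped back to batch names. This sketch assumes, as the output described above suggests, that the batch columns are the last columns of each feature vector and follow the `id2label` order; `decode_trailing_onehot` is a hypothetical helper. With the notebook's tensor, a call might look like `decode_trailing_onehot(feature.numpy(), mappings['batch']['id2label'])`, valid only when `batch` is the sole (or last) metadata column.

```python
import numpy as np


def decode_trailing_onehot(features, id2label):
    """Map the last len(id2label) columns of each row back to a label.

    Assumes the one-hot metadata columns are the final columns of each row, in
    the same order as `id2label` (hypothetical helper, not part of scaLR).
    """
    n = len(id2label)
    onehot = np.asarray(features)[:, -n:]
    # Each row should contain exactly one 1 among the trailing columns.
    if not np.allclose(onehot.sum(axis=1), 1.0):
        raise ValueError('Trailing columns are not a valid one-hot encoding.')
    return [id2label[i] for i in onehot.argmax(axis=1)]


# Toy features: 3 gene values followed by 2 one-hot batch columns.
toy_features = np.array([[0.5, 0.0, 1.2, 1.0, 0.0],
                         [0.0, 2.1, 0.3, 0.0, 1.0]])
print(decode_trailing_onehot(toy_features, ['batch1', 'batch2']))  # ['batch1', 'batch2']
```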

In [None]:
# Checking the one-hot encoding vector for each batch.
onehot_encoded_batches = dataloader.metadata_onehotencoder['batch'].transform(np.array(['batch1', 'batch2']).reshape(-1, 1))
onehot_encoded_batches.toarray()  # convert the sparse result to a dense array

This confirms that the one-hot encoding vector for `batch1` is `[1, 0]` and for `batch2` is `[0, 1]`.