# <span style="color: steelblue;">Batch Correction using scaLR</span>

Key points

1. This notebook is a tutorial on batch correction with the metadata dataloader (`SimpleMetaDataLoader`) from the scaLR library.
2. The dataloader can add any metadata column to the feature data as a one-hot encoded vector, which can be useful for model training.

## <span style="color: steelblue;">What is batch correction?</span>

- Single-cell genomic datasets for a specific disease or trait are often compiled from multiple experiments, each sequencing different cells under varying conditions such as capture time, handling personnel, reagent lot, equipment, and even technology platform.
- These differences introduce large technical variations, known as batch effects, into the data. When combining such datasets for analysis and modeling, it's crucial that the model is not biased towards data from certain batches simply because their values fall in a higher or lower range. It is therefore necessary to correct for batch effects in these datasets.
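To make the idea concrete, here is a small simulated example (synthetic data invented for illustration, not from any real dataset). Both batches share the same underlying expression, but `batch2` carries a constant additive technical offset. The per-batch means then differ even though the biology is identical, which is exactly the kind of value-range difference a model could mistake for signal.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same underlying "biological" signal for all 200 cells x 5 genes.
biology = rng.normal(loc=5.0, scale=1.0, size=(200, 5))
batch = np.array(['batch1'] * 100 + ['batch2'] * 100)

# Simulated batch effect: every cell in batch2 is shifted up by a technical offset of 2.0.
offset = np.where(batch == 'batch2', 2.0, 0.0)[:, None]
observed = biology + offset

# Per-batch means now differ by roughly the offset, purely for technical reasons.
mean_b1 = observed[batch == 'batch1'].mean(axis=0)
mean_b2 = observed[batch == 'batch2'].mean(axis=0)
print('batch1 mean per gene:', mean_b1.round(2))
print('batch2 mean per gene:', mean_b2.round(2))
print('gap per gene:', (mean_b2 - mean_b1).round(2))
```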


## <span style="color: steelblue;">How to perform batch correction?</span>

- Many statistical tools, such as Scanpy, Seurat, Harmony, and ComBat, perform batch correction by normalizing the data, reducing its dimensionality, and then removing the batch effect with techniques such as fitting linear or mixed models, computing k-nearest-neighbor or mutual-nearest-neighbor distances, or canonical correlation analysis.

- Traditional batch-correction methods are robust and widely used, but they are limited in their handling of non-linear relationships, scalability, flexibility, and preservation of biological signal. AI/ML-based batch-correction methods can offer significant advantages in these areas, making them a powerful alternative for complex, large-scale single-cell genomic datasets.

- The batch correction approach in the scaLR platform is inspired by the scGPT tool, where batch information is integrated into the feature data to tell the model which batch each sample comes from. Because batch is a categorical variable, including it directly as a label-encoded feature is not appropriate: label encoding imposes an ordering (for example, batch 4 > batch 1) even though no batch is inherently greater than another. Instead, batch information is represented as a one-hot encoded vector.

- For example, if we have four batches in the dataset, the one-hot encoding would work as follows:
    - Batch 1 -> 0 0 0 1
    - Batch 2 -> 0 0 1 0
    - Batch 3 -> 0 1 0 0
    - Batch 4 -> 1 0 0 0

- These encoded vectors represent the batches and are added to the feature data. In this case, four additional columns will be included in the feature data, ensuring that the model is aware of the batch information while training on samples from different batches.
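The reasoning above can be checked numerically with a minimal NumPy sketch (synthetic features and illustrative variable names; this is not scaLR's code). Label encoding places `batch4` three units from `batch1` but only one unit from `batch3`, an ordering with no meaning, whereas all one-hot vectors are the same distance apart. Which position holds the 1 is arbitrary: this sketch assigns positions in sorted batch order, the reverse of the listing above, which is equally valid.

```python
import numpy as np

batches = np.array(['batch1', 'batch2', 'batch3', 'batch4'])
categories = sorted(set(batches))  # fixed column order for the encoding
label2id = {label: i for i, label in enumerate(categories)}
ids = np.array([label2id[b] for b in batches])

# Label encoding: one integer column, implying batch1 < batch2 < batch3 < batch4.
label_encoded = ids[:, None].astype(float)

# One-hot encoding: category i becomes row i of the identity matrix.
one_hot = np.eye(len(categories))[ids]

def pairwise_dist(x):
    # Euclidean distance between every pair of rows.
    diff = x[:, None, :] - x[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

print(pairwise_dist(label_encoded))  # distances 1, 2, 3: an artificial ordering
print(pairwise_dist(one_hot))        # every off-diagonal distance is sqrt(2)

# Appending the four one-hot columns to a 4 x 7 feature matrix gives 4 x 11.
features = np.random.default_rng(0).random((4, 7))
augmented = np.concatenate([features, one_hot], axis=1)
print(augmented.shape)  # (4, 11)
```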


## <span style="color: steelblue;">How is it implemented in the scaLR platform?</span>

- In the scaLR platform, we've implemented the `SimpleMetaDataLoader` dataloader, which handles this process automatically.
- You specify the metadata column you want to one-hot encode, and the dataloader takes care of the encoding and appends the result to the feature data.
- We've also extended this functionality so that multiple metadata columns can be included as one-hot encoded vectors in the feature data. By passing a list of columns, you can easily incorporate additional information.
- This is particularly useful in tasks such as predicting disease vs. non-disease, where metadata like cell type might enhance the model's predictive power.
- Adding cell type information to the feature data in this way can improve the model's performance.
- Generally, this dataloader is not used as a standalone library utility; it is mostly used as part of the scaLR pipeline. The code snippets below explain its basics.
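To illustrate the multi-column idea, here is a hypothetical helper, `append_metadata_onehot`, in plain NumPy/pandas. It is a sketch under stated assumptions, not `SimpleMetaDataLoader`'s actual implementation: it assumes mappings shaped like the ones built in this notebook (a `label2id` dict per column), and that each column's one-hot block is appended after the gene features in the order the columns are listed.

```python
import numpy as np
import pandas as pd

def append_metadata_onehot(X, obs, metadata_cols, mappings):
    """Hypothetical helper: append one one-hot block per metadata column to X."""
    blocks = [np.asarray(X, dtype=np.float32)]
    for col in metadata_cols:
        label2id = mappings[col]['label2id']
        ids = obs[col].map(label2id).to_numpy()
        # One identity-matrix row per cell; block width = number of categories.
        blocks.append(np.eye(len(label2id), dtype=np.float32)[ids])
    return np.concatenate(blocks, axis=1)

# Toy data: 5 cells x 7 genes, with two categorical metadata columns.
obs = pd.DataFrame({
    'batch': ['batch1', 'batch2', 'batch1', 'batch2', 'batch1'],
    'env': ['env1', 'env2', 'env3', 'env1', 'env2'],
})
X = np.random.default_rng(0).random((5, 7))

# label2id per column, using the sorted category order pandas infers.
mappings = {
    col: {'label2id': {label: i for i, label in
                       enumerate(obs[col].astype('category').cat.categories)}}
    for col in obs.columns
}

out = append_metadata_onehot(X, obs, ['batch', 'env'], mappings)
print(out.shape)  # (5, 12): 7 genes + 2 batch columns + 3 env columns
```

Passing only `['batch']` would give 7 + 2 = 9 columns, the same width checked at the end of this notebook.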


In [None]:
import sys
# Make the scaLR package importable; replace with the path to your local copy of scaLR.
sys.path.append('/path/to/scaLR')

In [None]:
# Required imports
import anndata
import numpy as np
import pandas as pd

from scalr.nn.dataloader import build_dataloader, simple_metadataloader

# Setting seed for reproducibility
np.random.seed(0)

In [None]:
# Creating anndata object.
adata = anndata.AnnData(X=np.random.rand(15, 7))
adata.obs = pd.DataFrame.from_dict({
    'celltype': np.random.choice(['B', 'C', 'DC', 'T'], size=15),
    'batch': np.random.choice(['batch1', 'batch2'], size=15),
    'env': np.random.choice(['env1', 'env2', 'env3'], size=15)
})
adata.obs.index = adata.obs.index.astype('O')

In [None]:
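# Constructor parameters of `SimpleDataLoader`, which the `simple_metadataloader` module builds on.
# `SimpleMetaDataLoader` additionally expects `metadata_col` (see the config below);
# `target` and `mappings` are passed to `build_dataloader` below.
simple_metadataloader.SimpleDataLoader.__init__.__annotations__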

{'batch_size': int, 'target': str, 'mappings': dict, 'padding': int}

In [None]:
# Defining the required parameters for the metadata dataloader.

# For batch correction, pass the `batch` column in `metadata_col`.
metadata_col = ['batch']

# Batch correction requires the `SimpleMetaDataLoader`.
dataloader_config = {
    'name': 'SimpleMetaDataLoader',
    'params': {
        'batch_size': 3,
        'metadata_col': metadata_col
    }
}

# Generating label mappings for every column in `adata.obs`:
# `id2label` lists the sorted categories, `label2id` maps each category to its index.
mappings = {}
for column_name in adata.obs.columns:
    id2label = adata.obs[column_name].astype('category').cat.categories.tolist()
    label2id = {label: idx for idx, label in enumerate(id2label)}
    mappings[column_name] = {'id2label': id2label, 'label2id': label2id}

In [None]:
# Creating the dataloader object.
dataloader, _ = build_dataloader(dataloader_config=dataloader_config,
                                 adata=adata,
                                 target='celltype',
                                 mappings=mappings)

In [None]:
# Check that `batch` was added to the feature data as one-hot encoded vectors.
# The original features have shape (batch_size, 7), where batch_size is the dataloader's
# minibatch size (3 here), and the data contains 2 batches. After adding this column,
# the number of features should be 7 + 2 = 9, so the shape should be (batch_size, 9).

for feature, _ in dataloader:
    print('Features shape :', feature.shape)
    break

Features shape : torch.Size([3, 9])
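The shape check confirms that two columns were added, but not what they contain. A slightly stronger check is sketched below as a hypothetical helper, `check_onehot_tail`. It assumes, without confirmation from the scaLR source, that the one-hot columns are appended after the original features; for a CPU PyTorch tensor such as `feature`, you would pass `feature.numpy()`.

```python
import numpy as np

def check_onehot_tail(features, n_original, n_categories):
    """Hypothetical check: columns after the original features form valid one-hot rows."""
    features = np.asarray(features)
    assert features.shape[1] == n_original + n_categories, 'unexpected feature width'
    tail = features[:, n_original:]
    # Every entry must be 0 or 1, and each row must contain exactly one 1.
    is_binary = np.isin(tail, (0.0, 1.0)).all()
    one_per_row = np.allclose(tail.sum(axis=1), 1.0)
    return bool(is_binary and one_per_row)

# Synthetic stand-in for one minibatch: 3 cells x (7 genes + 2 batch columns).
rng = np.random.default_rng(0)
genes = rng.random((3, 7))
batch_onehot = np.eye(2)[[0, 1, 0]]
minibatch = np.concatenate([genes, batch_onehot], axis=1)

print(check_onehot_tail(minibatch, n_original=7, n_categories=2))  # True
print(check_onehot_tail(genes, n_original=5, n_categories=2))      # False: tail is random floats
```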
