# Batch Correction of the JUMP Target2 Data

In [1]:
import os
import scanpy as sc
from bcc.batch_correction import *
from bcc.scib_metrics import *

INFO:lightning_fabric.utilities.seed:Global seed set to 0


This notebook shows example function calls for batch-correcting the JUMP Target2 data with several existing integration methods.

In this project, batch correction was performed in three ways:
1. Low level: the method is applied to each source separately to correct for plate effects.
2. High level: the low-level corrected embeddings from step 1 are corrected again, this time to remove differences between sources.
3. Direct: the method corrects for source on the whole data set in a single step, without a prior plate-level correction.
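
The difference between the three strategies can be illustrated with a toy sketch. It is not the `bcc` implementation: per-group mean-centering stands in for a real integration method, and the helper names (`center_by`, `low_level`, `high_level`, `direct`) and the record layout are assumptions made only for this example.

```python
from collections import defaultdict
from statistics import mean


def center_by(records, key, value="x"):
    """Toy stand-in for a batch-correction method: subtract each group's mean."""
    groups = defaultdict(list)
    for r in records:
        groups[r[key]].append(r[value])
    means = {g: mean(v) for g, v in groups.items()}
    return [{**r, value: r[value] - means[r[key]]} for r in records]


def low_level(records):
    # Strategy 1: correct plate effects separately within each source.
    corrected = []
    for source in sorted({r["source"] for r in records}):
        subset = [r for r in records if r["source"] == source]
        corrected.extend(center_by(subset, "plate"))
    return corrected


def high_level(low_corrected):
    # Strategy 2: start from the low-level output and correct for source.
    return center_by(low_corrected, "source")


def direct(records):
    # Strategy 3: correct for source on the whole data set in one step.
    return center_by(records, "source")


# Hypothetical toy data: one feature, two sources, three plates.
records = [
    {"source": "s1", "plate": "p1", "x": 1.0},
    {"source": "s1", "plate": "p1", "x": 3.0},
    {"source": "s1", "plate": "p2", "x": 10.0},
    {"source": "s1", "plate": "p2", "x": 12.0},
    {"source": "s2", "plate": "p3", "x": 100.0},
    {"source": "s2", "plate": "p3", "x": 104.0},
]

low = low_level(records)
high = high_level(low)
direct_result = direct(records)
```

In this toy data the direct result still carries the offset between plates `p1` and `p2` inside source `s1`, whereas the low-level and high-level results do not. Real integration methods behave very differently from mean-centering, so this only shows which grouping variable each strategy uses and in which order.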

Five existing integration methods were tested:
* [1 Harmony](#harmony)
* [2 Scanorama](#scanorama)
* [3 scGen](#scgen)
* [4 scanVI](#scanvi)
* [5 scVI](#scvi)

In [2]:
# directory that contains the data file and where the integration results will be saved
data_path = "../../data/jump/"
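
The cells in the method sections save up to three result files per method, named `<method>_low.h5ad`, `<method>_high.h5ad`, and `<method>.h5ad`. As a minimal sketch, this naming scheme could be collected in one place. The helper `result_path` is hypothetical and not part of `bcc`.

```python
import os


def result_path(data_path, method, level=None):
    """Hypothetical helper: build the output file name for one method and level.

    level is "low", "high", or None (direct correction for source).
    """
    if level not in ("low", "high", None):
        raise ValueError(f"unknown level: {level!r}")
    suffix = f"_{level}" if level else ""
    return os.path.join(data_path, f"{method}{suffix}.h5ad")


# Same directory as used in this notebook.
example_dir = "../../data/jump/"
harmony_low_file = result_path(example_dir, "harmony", "low")
harmony_direct_file = result_path(example_dir, "harmony")
```

A call such as `adata_harmony.write_h5ad(result_path(data_path, "harmony"))` would then write to the same location as the explicit f-strings used in the cells that follow.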

In [3]:
data_file = f"{data_path}jump_target2_spherized.h5ad"
if not os.path.exists(data_file):
    raise FileNotFoundError("Please download the JUMP data and perform the preprocessing!")

In [4]:
adata = sc.read_h5ad(data_file)

## 1 Integration with Harmony <a class="anchor" id="harmony"></a>

In [None]:
# correct for nested batch effects on low level
adata_harmony_per_source = harmony_integration(adata, batch="Metadata_Plate", hierarchical="Metadata_Source")

In [16]:
adata_harmony_per_source.write_h5ad(f"{data_path}harmony_low.h5ad")

In [20]:
# correct for nested batch effects on high level
adata_harmony_overall = harmony_integration(reset_corrected_anndata(adata_harmony_per_source), batch="Metadata_Source")

	Initialization is completed.
	Completed 1 / 10 iteration(s).
	Completed 2 / 10 iteration(s).
	Completed 3 / 10 iteration(s).
	Completed 4 / 10 iteration(s).
	Completed 5 / 10 iteration(s).
	Completed 6 / 10 iteration(s).
	Completed 7 / 10 iteration(s).
	Completed 8 / 10 iteration(s).
Reach convergence after 8 iteration(s).


In [21]:
adata_harmony_overall.write_h5ad(f"{data_path}harmony_high.h5ad")

Correct for source directly on the whole data set

In [5]:
adata_harmony = harmony_integration(adata, batch="Metadata_Source")

	Initialization is completed.
	Completed 1 / 10 iteration(s).
	Completed 2 / 10 iteration(s).
Reach convergence after 2 iteration(s).


In [8]:
adata_harmony.write_h5ad(f"{data_path}harmony.h5ad")

## 2 Integration with Scanorama <a class="anchor" id="scanorama"></a>

In [None]:
# correct for nested batch effects on low level
adata_scanorama_per_source = scanorama_integration(adata, batch="Metadata_Plate", hierarchical="Metadata_Source")

In [14]:
adata_scanorama_per_source.write_h5ad(f"{data_path}scanorama_low.h5ad")

In [16]:
# correct for nested batch effects on high level
adata_scanorama_overall = scanorama_integration(reset_corrected_anndata(adata_scanorama_per_source), batch="Metadata_Source")

Found 100 genes among all datasets
[[0.   0.04 0.   0.   0.09 0.   0.07 0.03 0.01 0.05 0.  ]
 [0.   0.   0.03 0.01 0.   0.04 0.09 0.03 0.01 0.02 0.01]
 [0.   0.   0.   0.01 0.   0.01 0.01 0.   0.02 0.02 0.02]
 [0.   0.   0.   0.   0.   0.01 0.02 0.02 0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.01 0.01 0.05 0.01 0.03 0.02]
 [0.   0.   0.   0.   0.   0.   0.04 0.06 0.01 0.05 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.06 0.   0.11 0.07]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.07 0.1  0.03]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.04 0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.05]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Processing datasets (6, 9)


In [20]:
adata_scanorama_overall.write_h5ad(f"{data_path}scanorama_high.h5ad")

Correct for source directly on the whole data set

In [10]:
adata_scanorama = scanorama_integration(adata, batch="Metadata_Source")

Found 558 genes among all datasets
[[0.   0.12 0.07 0.06 0.13 0.03 0.04 0.05 0.04 0.01 0.04]
 [0.   0.   0.1  0.11 0.04 0.13 0.12 0.08 0.03 0.03 0.03]
 [0.   0.   0.   0.08 0.06 0.04 0.08 0.1  0.04 0.04 0.06]
 [0.   0.   0.   0.   0.07 0.08 0.13 0.06 0.04 0.04 0.05]
 [0.   0.   0.   0.   0.   0.1  0.08 0.11 0.05 0.58 0.05]
 [0.   0.   0.   0.   0.   0.   0.16 0.32 0.07 0.02 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.41 0.07 0.03 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.37 0.17 0.11]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.07 0.06]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.1 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Processing datasets (4, 9)
Processing datasets (6, 7)
Processing datasets (7, 8)
Processing datasets (5, 7)
Processing datasets (7, 9)
Processing datasets (5, 6)
Processing datasets (1, 5)
Processing datasets (0, 4)
Processing datasets (3, 6)
Processing datasets (0, 1)
Processing datasets (1, 6)
Processing datasets (1, 3)

In [12]:
adata_scanorama.write_h5ad(f"{data_path}scanorama.h5ad")

--------------------------------------

The remaining methods (scGen, scanVI, and scVI) train neural networks and take considerably longer, so you may want to run them on a GPU.
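
Before starting a long training run, it can help to check whether a GPU is visible at all. The following is a minimal sketch, assuming PyTorch is the backend (the Lightning log lines in this notebook suggest so). The function `pick_accelerator` is hypothetical, and how the `bcc` wrappers select a device is not shown in this notebook.

```python
def pick_accelerator():
    """Return "gpu" if PyTorch can see a CUDA device, otherwise "cpu".

    Hypothetical helper: the returned string is only a label for the user
    and is not passed to any bcc function here.
    """
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to query, so fall back to CPU.
        return "cpu"
    return "gpu" if torch.cuda.is_available() else "cpu"


accelerator = pick_accelerator()
print(f"Training would run on: {accelerator}")
```

The scGen log output in the next section reports `GPU available: False`, which is the case in which this check would return `"cpu"`.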

## 3 Integration with scGen <a class="anchor" id="scgen"></a>

In [4]:
# correct for nested batch effects on low level
adata_scgen_per_source = scgen_integration(adata, batch="Metadata_Plate", labels="Metadata_JCP2022", hierarchical="Metadata_Source")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 10/100:   9%|▉         | 9/100 [00:31<03:51,  2.55s/it, v_num=1, train_loss_step=6.52e+4, train_loss_epoch=1.36e+5]

In [None]:
adata_scgen_per_source.write_h5ad(f"{data_path}scgen_low.h5ad")

In [57]:
# correct for nested batch effects on high level
adata_scgen_overall = scgen_integration(reset_corrected_anndata(adata_scgen_per_source), batch="Metadata_Source", labels="Metadata_JCP2022")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 26/100:  26%|██▌       | 26/100 [14:21<40:51, 33.12s/it, v_num=1, train_loss_step=0.0424, train_loss_epoch=0.421]
Monitored metric elbo_validation did not improve in the last 25 records. Best score: 226.901. Signaling Trainer to stop.
INFO     Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup


In [None]:
adata_scgen_overall.write_h5ad(f"{data_path}scgen_high.h5ad")

In [None]:
# correct for source directly
adata_scgen = scgen_integration(adata, batch="Metadata_Source", labels="Metadata_JCP2022")

In [None]:
adata_scgen.write_h5ad(f"{data_path}scgen.h5ad")

## 4 Integration with scanVI <a class="anchor" id="scanvi"></a>

In [None]:
# correct for nested batch effects on low level
adata_scanvi_per_source = scanvi_integration(adata, batch="Metadata_Plate", labels="Metadata_JCP2022", hierarchical="Metadata_Source")

In [None]:
adata_scanvi_per_source.write_h5ad(f"{data_path}scanvi_low.h5ad")

In [None]:
# correct for nested batch effects on high level
adata_scanvi_overall = scanvi_integration(reset_corrected_anndata(adata_scanvi_per_source), batch="Metadata_Source", labels="Metadata_JCP2022")

In [None]:
adata_scanvi_overall.write_h5ad(f"{data_path}scanvi_high.h5ad")

In [None]:
# correct for source directly
adata_scanvi = scanvi_integration(adata, batch="Metadata_Source", labels="Metadata_JCP2022")

In [None]:
adata_scanvi.write_h5ad(f"{data_path}scanvi.h5ad")

## 5 Integration with scVI <a class="anchor" id="scvi"></a>

In [None]:
# correct for nested batch effects on low level
adata_scvi_per_source = scvi_integration(adata, batch="Metadata_Plate", hierarchical="Metadata_Source")

In [None]:
adata_scvi_per_source.write_h5ad(f"{data_path}scvi_low.h5ad")

In [None]:
# correct for nested batch effects on high level
adata_scvi_overall = scvi_integration(reset_corrected_anndata(adata_scvi_per_source), batch="Metadata_Source")

In [None]:
adata_scvi_overall.write_h5ad(f"{data_path}scvi_high.h5ad")

In [None]:
# correct for source directly
adata_scvi = scvi_integration(adata, batch="Metadata_Source")

In [None]:
adata_scvi.write_h5ad(f"{data_path}scvi.h5ad")