# Multi-batch integration - SKNSH & NBLW


In this notebook, an [scGPT](https://www.nature.com/articles/s41592-024-02201-0) model is used to predict a cell type annotation with a given gene expression profile.

This follows the tutorial from scGPT [here](https://github.com/bowang-lab/scGPT/blob/main/tutorials/Tutorial_Annotation.ipynb), but instead of fine-tuning the entire model, a smaller neural network is trained, using the embeddings of the gene expressions as inputs, to make a prediction.

The same approach is made with the [Geneformer](https://www.nature.com/articles/s41586-023-06139-9.epdf?sharing_token=u_5LUGVkd3A8zR-f73lU59RgN0jAjWel9jnR3ZoTv0N2UB4yyXENUK50s6uqjXH69sDxh4Z3J4plYCKlVME-W2WSuRiS96vx6t5ex2-krVDS46JkoVvAvJyWtYXIyj74pDWn_DutZq1oAlDaxfvBpUfSKDdBPJ8SKlTId8uT47M%3D) model and the results are compared against each other.

This approach greatly reduces time and complexity.

In [1]:
#!pip3 install helical
#!conda install -c conda-forge louvain
#!pip3 install datasets --upgrade

In [1]:
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import anndata as ad
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn
from scipy.sparse import lil_matrix
import torch.optim as optim
from helical.models.scgpt import scGPT, scGPTConfig
from helical.models.geneformer import Geneformer, GeneformerConfig
from copy import deepcopy
from torch.nn.functional import one_hot
import scanpy as sc

  from .autonotebook import tqdm as notebook_tqdm

INFO:datasets:PyTorch version 2.6.0 available.


Load target datasets

In [2]:
# Load SKNSH data
adata_sknsh = sc.read_h5ad("../integration/210824_aumc00012sc_10k_scRNAseq.h5ad")
# Load NBLW data
adata_nblw = sc.read_h5ad("../integration/210824_aumc00013sc_10k_scRNAseq.h5ad")

# Load combined data
adata = sc.read_h5ad("combined_aumc00012sc_aumc00013sc_scRNAseq.h5ad")


In [3]:
# Cell labels - add cellline ID to cell labels from individual Seurat objects. 
# This is normally added by merge function in Seurat, but since we are loading preprocessed h5ad files, we need to do it manually here.

nblw_ids = ["nblw_" + s for s in adata_nblw.obs.index.tolist()]
print(nblw_ids)

sknsh_ids = ["sknsh_" + s for s in adata_sknsh.obs.index.tolist()]
print(sknsh_ids)

nblwsknsh_ids = sknsh_ids + nblw_ids
print(nblwsknsh_ids)

['nblw_-13sc_AAACCCAAGGTGATAT-1', 'nblw_-13sc_AAACCCATCCATTTCA-1', 'nblw_-13sc_AAACGAAGTAGGCTGA-1', 'nblw_-13sc_AAACGCTCACCAAAGG-1', 'nblw_-13sc_AAACGCTGTACCGTGC-1', 'nblw_-13sc_AAACGCTTCAGCACCG-1', 'nblw_-13sc_AAAGAACCAGCACGAA-1', 'nblw_-13sc_AAAGAACCATTGTCGA-1', 'nblw_-13sc_AAAGAACTCCTATGGA-1', 'nblw_-13sc_AAAGAACTCTCGCTCA-1', 'nblw_-13sc_AAAGGGCCAACCAGAG-1', 'nblw_-13sc_AAAGGGCCAGGTACGA-1', 'nblw_-13sc_AAAGGGCTCCAACCGG-1', 'nblw_-13sc_AAAGGGCTCCGGTTCT-1', 'nblw_-13sc_AAAGGTAAGAAGATCT-1', 'nblw_-13sc_AAAGGTAGTAGGTCAG-1', 'nblw_-13sc_AAAGTCCAGTATAACG-1', 'nblw_-13sc_AAAGTCCGTGGTTTGT-1', 'nblw_-13sc_AAAGTGAAGACAACAT-1', 'nblw_-13sc_AAAGTGACAATTGCAC-1', 'nblw_-13sc_AAAGTGACATCTAGAC-1', 'nblw_-13sc_AAAGTGAGTTGTAGCT-1', 'nblw_-13sc_AAAGTGAGTTTGGGAG-1', 'nblw_-13sc_AAATGGAAGCCTTTGA-1', 'nblw_-13sc_AAATGGACACCCTAAA-1', 'nblw_-13sc_AAATGGAGTCTCGCGA-1', 'nblw_-13sc_AAATGGAGTTCCTAGA-1', 'nblw_-13sc_AACAAAGGTCAATGGG-1', 'nblw_-13sc_AACAAAGGTCACCTTC-1', 'nblw_-13sc_AACAACCAGCATGTTC-1', 'nblw_-13

In [4]:
# Check if cell Ids match
nblwsknsh_ids == adata.obs.index.tolist()

True

In [5]:
print(adata.X)
print(adata.raw.X)

  (0, 21)	1.8602688956440707
  (0, 27)	0.9083654360057144
  (0, 39)	0.9083654360057144
  (0, 40)	0.9083654360057144
  (0, 43)	2.1723822318241837
  (0, 51)	0.9083654360057144
  (0, 53)	2.1723822318241837
  (0, 63)	0.9083654360057144
  (0, 66)	2.1723822318241837
  (0, 68)	0.9083654360057144
  (0, 78)	2.1723822318241837
  (0, 86)	0.9083654360057144
  (0, 89)	2.1723822318241837
  (0, 93)	0.9083654360057144
  (0, 98)	0.9083654360057144
  (0, 122)	1.4614423910429508
  (0, 132)	1.8602688956440707
  (0, 154)	4.142530858226281
  (0, 157)	0.9083654360057144
  (0, 161)	1.4614423910429508
  (0, 172)	0.9083654360057144
  (0, 178)	1.4614423910429508
  (0, 185)	0.9083654360057144
  (0, 189)	0.9083654360057144
  (0, 190)	2.1723822318241837
  :	:
  (5078, 33157)	1.5965745196438144
  (5078, 33158)	1.5965745196438144
  (5078, 33207)	1.5965745196438144
  (5078, 33209)	1.5965745196438144
  (5078, 33249)	2.822267566031215
  (5078, 33297)	3.7164931290906282
  (5078, 33321)	2.335851343983151
  (5078, 33326)	1

In [36]:
print(adata.raw.X)

  (0, 21)	3.0
  (0, 27)	1.0
  (0, 39)	1.0
  (0, 40)	1.0
  (0, 43)	4.0
  (0, 51)	1.0
  (0, 53)	4.0
  (0, 63)	1.0
  (0, 66)	4.0
  (0, 68)	1.0
  (0, 78)	4.0
  (0, 86)	1.0
  (0, 89)	4.0
  (0, 93)	1.0
  (0, 98)	1.0
  (0, 122)	2.0
  (0, 132)	3.0
  (0, 154)	19.0
  (0, 157)	1.0
  (0, 161)	2.0
  (0, 172)	1.0
  (0, 178)	2.0
  (0, 185)	1.0
  (0, 189)	1.0
  (0, 190)	4.0
  :	:
  (5078, 33157)	1.0
  (5078, 33158)	1.0
  (5078, 33207)	1.0
  (5078, 33209)	1.0
  (5078, 33249)	3.0
  (5078, 33297)	6.0
  (5078, 33321)	2.0
  (5078, 33326)	1.0
  (5078, 33376)	1.0
  (5078, 33394)	2.0
  (5078, 33396)	1.0
  (5078, 33443)	2.0
  (5078, 33456)	1.0
  (5078, 33470)	1.0
  (5078, 33496)	42.0
  (5078, 33497)	25.0
  (5078, 33498)	62.0
  (5078, 33499)	75.0
  (5078, 33501)	129.0
  (5078, 33502)	96.0
  (5078, 33503)	34.0
  (5078, 33504)	1.0
  (5078, 33505)	51.0
  (5078, 33506)	10.0
  (5078, 33508)	69.0


In [7]:
print(adata_sknsh.raw.X)

  (0, 21)	3.0
  (0, 27)	1.0
  (0, 39)	1.0
  (0, 40)	1.0
  (0, 43)	4.0
  (0, 51)	1.0
  (0, 53)	4.0
  (0, 63)	1.0
  (0, 66)	4.0
  (0, 68)	1.0
  (0, 78)	4.0
  (0, 86)	1.0
  (0, 89)	4.0
  (0, 93)	1.0
  (0, 98)	1.0
  (0, 122)	2.0
  (0, 132)	3.0
  (0, 154)	19.0
  (0, 157)	1.0
  (0, 161)	2.0
  (0, 172)	1.0
  (0, 178)	2.0
  (0, 185)	1.0
  (0, 189)	1.0
  (0, 190)	4.0
  :	:
  (2773, 33269)	1.0
  (2773, 33279)	1.0
  (2773, 33297)	2.0
  (2773, 33304)	2.0
  (2773, 33321)	1.0
  (2773, 33380)	1.0
  (2773, 33396)	1.0
  (2773, 33443)	2.0
  (2773, 33446)	1.0
  (2773, 33470)	1.0
  (2773, 33479)	1.0
  (2773, 33490)	1.0
  (2773, 33493)	1.0
  (2773, 33496)	138.0
  (2773, 33497)	65.0
  (2773, 33498)	104.0
  (2773, 33499)	98.0
  (2773, 33501)	140.0
  (2773, 33502)	172.0
  (2773, 33503)	91.0
  (2773, 33504)	6.0
  (2773, 33505)	108.0
  (2773, 33506)	23.0
  (2773, 33507)	2.0
  (2773, 33508)	70.0


In [8]:
print(adata_nblw.raw.X)

  (0, 26)	1.0
  (0, 40)	1.0
  (0, 53)	1.0
  (0, 68)	1.0
  (0, 93)	1.0
  (0, 154)	11.0
  (0, 157)	1.0
  (0, 178)	2.0
  (0, 190)	1.0
  (0, 201)	3.0
  (0, 214)	1.0
  (0, 229)	1.0
  (0, 244)	3.0
  (0, 259)	1.0
  (0, 269)	1.0
  (0, 412)	3.0
  (0, 445)	1.0
  (0, 465)	2.0
  (0, 477)	1.0
  (0, 483)	1.0
  (0, 493)	20.0
  (0, 503)	1.0
  (0, 526)	1.0
  (0, 546)	3.0
  (0, 555)	1.0
  :	:
  (2304, 33157)	1.0
  (2304, 33158)	1.0
  (2304, 33207)	1.0
  (2304, 33209)	1.0
  (2304, 33249)	3.0
  (2304, 33297)	6.0
  (2304, 33321)	2.0
  (2304, 33326)	1.0
  (2304, 33376)	1.0
  (2304, 33394)	2.0
  (2304, 33396)	1.0
  (2304, 33443)	2.0
  (2304, 33456)	1.0
  (2304, 33470)	1.0
  (2304, 33496)	42.0
  (2304, 33497)	25.0
  (2304, 33498)	62.0
  (2304, 33499)	75.0
  (2304, 33501)	129.0
  (2304, 33502)	96.0
  (2304, 33503)	34.0
  (2304, 33504)	1.0
  (2304, 33505)	51.0
  (2304, 33506)	10.0
  (2304, 33508)	69.0


In [9]:
# Function for getting 'data_processed' for scGPT model to check for equality
def get_data_processed(adata): 
    adata.X = adata.raw.X.copy()
    adata.var["gene_name"] = adata.var_names # "Data must have the provided key 'gene_name' in its 'var' section to be processed by the Helical RNA model."

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure and initialize the scGPT model
    scgpt_config = scGPTConfig(batch_size=50, device=device, binning_seed=123)
    scgpt = scGPT(configurer = scgpt_config)

    # Process the data for the scGPT model
    Normalize_SubsetHighlyVariable = False
    if Normalize_SubsetHighlyVariable:
        data_processed = scgpt.process_data(adata, gene_names = "gene_name", fine_tuning=True)
    else:
        data_processed = scgpt.process_data(adata, gene_names = "gene_name")

    return data_processed

In [10]:
# Generate data_processed for combined and individual datasets
combined_data_processed = get_data_processed(adata)
sknsh_data_processed = get_data_processed(adata_sknsh)
nblw_data_processed = get_data_processed(adata_nblw)



INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cell'.
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10213 genes to a total of 23325 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.

INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cell'.
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10213 genes to a total of 23325 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.

INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with

In [11]:
# Shapes of the count matrices
print(combined_data_processed.count_matrix.shape)
print(sknsh_data_processed.count_matrix.shape)
print(nblw_data_processed.count_matrix.shape)

(5079, 23325)
(2774, 23325)
(2305, 23325)


In [12]:
print(combined_data_processed.count_matrix)



[[  0.   0.   0. ...  14.  11. 100.]
 [  0.   0.   0. ...  17.  20. 124.]
 [  0.   0.   0. ...  19.   7.  55.]
 ...
 [  0.   0.   0. ...  41.   3. 124.]
 [  0.   0.   0. ...  30.   5. 152.]
 [  0.   0.   0. ...  10.   0.  69.]]


In [13]:
print(sknsh_data_processed.count_matrix)

[[  0.   0.   0. ...  14.  11. 100.]
 [  0.   0.   0. ...  17.  20. 124.]
 [  0.   0.   0. ...  19.   7.  55.]
 ...
 [  0.   0.   0. ...  20.   2.  68.]
 [  0.   0.   0. ...  24.   0.  79.]
 [  0.   0.   0. ...  23.   2.  70.]]


In [14]:
print(nblw_data_processed.count_matrix)

[[  0.   0.   0. ...  10.   2.  28.]
 [  0.   0.   0. ...  74.   3. 188.]
 [  0.   0.   0. ...  20.  10.  90.]
 ...
 [  0.   0.   0. ...  41.   3. 124.]
 [  0.   0.   0. ...  30.   5. 152.]
 [  0.   0.   0. ...  10.   0.  69.]]


In [15]:
# Append the individual count matrices to compare with the combined count matrix
stacked = np.vstack((sknsh_data_processed.count_matrix, nblw_data_processed.count_matrix)) 
stacked

array([[  0.,   0.,   0., ...,  14.,  11., 100.],
       [  0.,   0.,   0., ...,  17.,  20., 124.],
       [  0.,   0.,   0., ...,  19.,   7.,  55.],
       ...,
       [  0.,   0.,   0., ...,  41.,   3., 124.],
       [  0.,   0.,   0., ...,  30.,   5., 152.],
       [  0.,   0.,   0., ...,  10.,   0.,  69.]])

In [16]:
print(stacked.shape)
print(combined_data_processed.count_matrix.shape)

(5079, 23325)
(5079, 23325)


In [17]:
log_df = combined_data_processed.count_matrix == stacked
log_df

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

In [18]:
# check if data_processed count matrices are equal
np.array_equal(combined_data_processed.count_matrix, stacked)

True

In [19]:
# check if data_processed count matrices are equal with tolerance for floating point errors
np.allclose(combined_data_processed.count_matrix, stacked)

True

We are interested in the names of the cells we want to predict. They are saved in `adata.obs["active_ident_celltypes"]`.

Additionally, we need to know how many distinct cell types/classes we have.

Use the Helical package to get the embeddings of the gene expression profile.

The only thing we need to specify is the column containing the names of the genes. (`gene_name` in this case)

The resulting embeddings are the input features `x`.

# scGPT

In [32]:
# TODO: Print progress of the function

def scgpt_get_embeddings(adata, Normalize_SubsetHighlyVariable=False):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure and initialize the scGPT model
    scgpt_config = scGPTConfig(batch_size=50, device=device, binning_seed=123)
    scgpt = scGPT(configurer = scgpt_config)

    # Process the data for the scGPT model
    if Normalize_SubsetHighlyVariable:
        data_processed = scgpt.process_data(adata, gene_names = "gene_name", fine_tuning=True)
    else:
        data_processed = scgpt.process_data(adata, gene_names = "gene_name")

    # Get embeddings
    x_scgpt = scgpt.get_embeddings(data_processed)
    
    return x_scgpt

In [21]:
def set_my_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [33]:
def print_my_seed():
    print("Numpy seed:", np.random.get_state()[1][0])
    print("Torch seed:", torch.initial_seed())
    if torch.cuda.is_available():
        print("Torch CUDA seed:", torch.cuda.initial_seed())

In [34]:
print(np.__version__)

1.26.4


In [31]:
np.random.default_rng()

Generator(PCG64) at 0x77E944E12DC0

In [24]:
print_my_seed()

Numpy seed: 2147483648
Torch seed: 1213970427842518326
Torch CUDA seed: 6725295639226610


In [25]:
# Force deterministic algorithms (may slow down GPU)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [35]:
# Get embeddings for combined and individual datasets
set_my_seed(123)
print_my_seed()
combined_scgpt_embeddings = scgpt_get_embeddings(adata, Normalize_SubsetHighlyVariable=False)
print_my_seed()
set_my_seed(123)
print_my_seed()
sknsh_scgpt_embeddings = scgpt_get_embeddings(adata_sknsh, Normalize_SubsetHighlyVariable=False)
print_my_seed()
set_my_seed(123)
print_my_seed()
nblw_scgpt_embeddings = scgpt_get_embeddings(adata_nblw, Normalize_SubsetHighlyVariable=False)
print_my_seed()
set_my_seed(123)
print_my_seed()
print(combined_scgpt_embeddings.shape)
print(sknsh_scgpt_embeddings.shape)
print(nblw_scgpt_embeddings.shape)


Numpy seed: 123
Torch seed: 123
Torch CUDA seed: 123



INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cell'.
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10213 genes to a total of 23325 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Started getting embeddings:
  torch.cuda.amp.autocast(enabled=True),

Embedding cells: 100%|██████████| 102/102 [00:44<00:00,  2.28it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


Numpy seed: 1988737281
Torch seed: 123
Torch CUDA seed: 123
Numpy seed: 123
Torch seed: 123
Torch CUDA seed: 123


INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cell'.
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10213 genes to a total of 23325 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 56/56 [00:24<00:00,  2.30it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


Numpy seed: 2100082102
Torch seed: 123
Torch CUDA seed: 123
Numpy seed: 123
Torch seed: 123
Torch CUDA seed: 123


INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cell'.
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10213 genes to a total of 23325 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 47/47 [00:20<00:00,  2.34it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.


Numpy seed: 2239475340
Torch seed: 123
Torch CUDA seed: 123
Numpy seed: 123
Torch seed: 123
Torch CUDA seed: 123
(5079, 512)
(2774, 512)
(2305, 512)


In [37]:
# Append the individual embeddings to compare with the combined embeddings
stacked_scgpt = np.vstack((sknsh_scgpt_embeddings, nblw_scgpt_embeddings))

In [38]:
# check if scgpt embeddings are equal
np.array_equal(combined_scgpt_embeddings, stacked_scgpt)

False

In [39]:
# logical array for element-wise comparison
logical_scgpt = combined_scgpt_embeddings == stacked_scgpt
logical_scgpt

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])

In [40]:
# function to truncate numpy array values to a certain number of decimals
def truncate(arr, decimals=5):
    factor = 10.0 ** decimals
    return np.trunc(arr * factor) / factor

In [41]:
# truncate values to 5 decimals and compare again
stacked_scgpt_trunc = truncate(stacked_scgpt, decimals=5)
combined_scgpt_embeddings_trunc = truncate(combined_scgpt_embeddings, decimals=5)

In [42]:
np.array_equal(stacked_scgpt_trunc, combined_scgpt_embeddings_trunc)    

False

In [43]:
# logical array for element-wise comparison
logical_scgpt_trunc = combined_scgpt_embeddings_trunc == stacked_scgpt_trunc
logical_scgpt_trunc

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])

# Save AnnData

In [None]:
#adata.obsm["scgpt"] = combined_scgpt_embeddings

In [None]:
# adata.raw column is called _index which raises an error when saving
#adata.raw = None

In [None]:
#adata.write_h5ad("combined_aumc00012sc_aumc00013sc_scRNAseq_scGPT.h5ad")

... storing 'Phase' as categorical
