In [1]:
import copy
import json
import os
from pathlib import Path
import sys
import warnings

import torch
from anndata import AnnData
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import pandas as pd
import tqdm

from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.model import TransformerModel
from scgpt.preprocess import Preprocessor
from scgpt.utils import set_seed

# Both settings run after the imports above, so they cannot affect anything
# that was already emitted while those modules were being imported.
# Presumably silences OpenMP (KMP) runtime warnings — TODO confirm the variable
# is still honoured when set after the numeric libraries have been loaded.
os.environ["KMP_WARNINGS"] = "off"
# Blanket filter: ignore every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

  from pkg_resources import get_distribution, DistributionNotFound
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Reproducibility -------------------------------------------------------
set_seed(42)

# --- Vocabulary tokens -----------------------------------------------------
# These tokens are appended to the loaded vocabulary when they are missing.
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]

# --- Sentinel values -------------------------------------------------------
# pad_value is handed to TransformerModel; mask_value is kept for parity with
# the scGPT configuration but is not read elsewhere in this notebook section.
mask_value, pad_value = -1, -2

# --- Preprocessing sizes ---------------------------------------------------
n_hvg = 1200            # passed to Preprocessor as subset_hvg
n_bins = 51             # passed to Preprocessor as binning
n_input_bins = n_bins   # passed to TransformerModel as n_input_bins

In [3]:
# Location of the downloaded pretrained scGPT checkpoint and its companion files.
model_dir = Path("/ix3/djishnu/alw399/SpaceOracle/notebooks/benchmark/scGPT-human") # downloaded pretrained model
model_config_file, model_file, vocab_file = (
    model_dir / file_name
    for file_name in ("args.json", "best_model.pt", "vocab.json")
)

# Load the gene vocabulary and add any special token it does not contain yet.
vocab = GeneVocab.from_file(vocab_file)
for token in (tok for tok in special_tokens if tok not in vocab):
    vocab.append_token(token)

# Read the architecture hyper-parameters stored next to the checkpoint.
model_configs = json.loads(model_config_file.read_text())
print(
    f"Resume model from {model_file}, the model args will override the "
    f"config {model_config_file}."
)
embsize, nhead, d_hid, nlayers, n_layers_cls = (
    model_configs[key]
    for key in ("embsize", "nheads", "d_hid", "nlayers", "n_layers_cls")
)

# Gene-name -> vocabulary-index mapping, used later to look up embeddings.
gene2idx = vocab.get_stoi()

Resume model from /ix3/djishnu/alw399/SpaceOracle/notebooks/benchmark/scGPT-human/best_model.pt, the model args will override the config /ix3/djishnu/alw399/SpaceOracle/notebooks/benchmark/scGPT-human/args.json.


In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    pad_value=pad_value,
    n_input_bins=n_input_bins,
)

# Read the checkpoint once. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine, matching the device fallback above. File or
# deserialisation errors propagate immediately instead of being swallowed.
pretrained_dict = torch.load(model_file, map_location=device)

try:
    model.load_state_dict(pretrained_dict)
    print(f"Loading all model params from {model_file}")
except RuntimeError:
    # load_state_dict reports key/shape mismatches as RuntimeError; in that
    # case only load params that are in the model and match the size.
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    for k, v in pretrained_dict.items():
        print(f"Loading params {k} with shape {v.shape}")
    # Merge and load once, after reporting, rather than once per parameter.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

# Last expression of the cell: moves the model and displays its structure.
model.to(device)

Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
Loading params encoder.enc_norm.weight with shape torch.Size([512])
Loading params encoder.enc_norm.bias with shape torch.Size([512])
Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading params value_encoder.linear1.bias with shape torch.Size([512])
Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading params value_encoder.linear2.bias with shape torch.Size([512])
Loading params value_encoder.norm.weight with shape torch.Size([512])
Loading params value_encoder.norm.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.self_attn.out_proj.bias with shape torch.Size([512])
Loading params transformer_encoder.layers.0.linear1.weight with shape torch.Size([512, 512])
Loading params transformer_encoder.layers.0.linear1.bias with 

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.5, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [5]:
# Flags consumed by the preprocessing cell: data_is_raw drives both the log1p
# switch and the HVG flavor selection there.
data_is_raw = False
ori_batch_col = "batch"

# Read the tonsil dataset from disk.
adata = sc.read_h5ad('/ix/djishnu/shared/djishnu_kor11/training_data_2025/snrna_human_tonsil.h5ad')

# Every cell gets the same batch label, and the existing "cell_type"
# annotation is copied into a string-typed "celltype" column.
adata.obs['batch'] = 'tonsil'
adata.obs["celltype"] = adata.obs["cell_type"].astype(str)


In [6]:
# Run the scGPT preprocessing pipeline on `adata`. The options are collected
# in one mapping so each step can be read (and tweaked) in isolation.
preprocess_options = dict(
    # input selection
    use_key="X",
    # steps 1-2: gene / cell filtering
    filter_gene_by_counts=3,
    filter_cell_by_counts=False,
    # step 3: total-count normalisation and where to store it
    normalize_total=1e4,
    result_normed_key="X_normed",
    # step 4: log1p is tied to data_is_raw (False in this notebook)
    log1p=data_is_raw,
    result_log1p_key="X_log1p",
    # step 5: keep n_hvg highly variable genes; the flavor depends on data_is_raw
    subset_hvg=n_hvg,
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    # step 6: value binning and where to store it
    binning=n_bins,
    result_binned_key="X_binned",
)
preprocessor = Preprocessor(**preprocess_options)
preprocessor(adata, batch_key="batch")

scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Subsetting highly variable genes ...
scGPT - INFO - Binning data ...


In [7]:
# Retrieve the data-independent gene embeddings from scGPT.
# Vocabulary indices are taken in gene2idx iteration order; the next cell
# relies on that order to pair each embedding row with its gene name.
gene_ids = np.array(list(gene2idx.values()))
# Inference only: skip autograd bookkeeping for the whole-vocabulary lookup.
with torch.no_grad():
    gene_embeddings = model.encoder(torch.tensor(gene_ids, dtype=torch.long).to(device))
# Move the result to host memory as a NumPy array (one row per vocabulary entry).
gene_embeddings = gene_embeddings.detach().cpu().numpy()

In [8]:
# Keep only the genes that are present both in the preprocessed dataset
# (adata.var.index) and in the scGPT vocabulary.
# The dataset gene names are materialised once as a set so each vocabulary
# entry is checked in constant time instead of rebuilding and scanning a list.
# NOTE: this rebinds `gene_embeddings` from the embedding array to a
# {gene: vector} dict, so re-run the previous cell before re-running this one.
dataset_genes = set(adata.var.index.tolist())
gene_embeddings = {gene: gene_embeddings[i] for i, gene in enumerate(gene2idx.keys()) if gene in dataset_genes}
print('Retrieved gene embeddings for {} genes.'.format(len(gene_embeddings)))

Retrieved gene embeddings for 1192 genes.


In [9]:
# Construct gene embedding network
# Wraps the {gene: embedding vector} dict built in the previous cell in scGPT's
# GeneEmbedding helper; `embed` is the object the following cell calls
# get_adata() and get_metagenes() on.
embed = GeneEmbedding(gene_embeddings)

100%|██████████| 1192/1192 [00:00<00:00, 2110430.72it/s]


In [10]:
# Perform Louvain clustering with desired resolution; here we specify resolution=40
# NOTE(review): the clustering algorithm is chosen inside GeneEmbedding.get_adata
# and is not visible here — confirm it is Louvain for the installed scGPT version.
gdata = embed.get_adata(resolution=40)
# Retrieve the gene clusters
# `metagenes` is iterated with .items() in the next cell, i.e. it maps a
# cluster key to a sized collection of genes.
metagenes = embed.get_metagenes(gdata)

2025-06-30 14:09:16.029308: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-30 14:09:16.043516: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-30 14:09:16.047566: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-30 14:09:16.058336: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the a

In [11]:
# Gene programs = clusters that contain at least five genes; smaller clusters
# are dropped.
mgs = {
    cluster_id: members
    for cluster_id, members in metagenes.items()
    if len(members) >= 5
}

In [12]:
# Persist the filtered gene programs as JSON.
# `json` and `Path` come from the notebook's import cell, so no extra import
# is needed here.
mgs_path = Path('/ix/djishnu/shared/djishnu_kor11/scGPT_outputs/tonsil_mgs_pretrained.json')
# Create the output directory if needed so the write does not fail on a
# fresh checkout / new shared folder.
mgs_path.parent.mkdir(parents=True, exist_ok=True)
with open(mgs_path, 'w') as f:
    json.dump(mgs, f)

In [13]:
# Compute scGPT cell embeddings for the preprocessed dataset using the
# pretrained checkpoint directory. Gene names are read from the 'index'
# column; the embeddings are later read from obsm['X_scGPT'] of the result.
embed_options = dict(
    gene_col='index',
    batch_size=64,
    return_new_adata=False,
)
ref_embed_adata = scg.tasks.embed_data(adata, model_dir, **embed_options)

# Display the returned AnnData summary.
ref_embed_adata

scGPT - INFO - match 1192/1200 genes in vocabulary of size 60697.


Embedding cells: 100%|██████████| 91/91 [00:03<00:00, 25.34it/s]


AnnData object with n_obs × n_vars = 5778 × 1192
    obs: 'cell_type', 'author_cell_type', 'cell_type_int', 'banksy_celltypes', 'cell_type_2', 'batch', 'celltype'
    var: 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection', 'index', 'id_in_vocab'
    uns: 'cell_thresholds', 'cell_type_int_colors', 'received_ligands', 'received_ligands_tfl', 'hvg'
    obsm: 'spatial', 'spatial_unscaled', 'bin_edges', 'X_scGPT'
    layers: 'imputed_count', 'normalized_count', 'X_normed', 'X_binned'

In [14]:
# Tabulate the per-cell scGPT embeddings, one row per cell, labelled with the
# observation index of the embedded AnnData.
cell_embedding_matrix = ref_embed_adata.obsm['X_scGPT']
cell_labels = ref_embed_adata.obs.index
embed_df = pd.DataFrame(cell_embedding_matrix, index=cell_labels)
embed_df

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGCGCCTTG-1,0.012679,0.033990,-0.022575,-0.027113,0.002759,-0.003039,0.016741,-0.010990,-0.005743,-0.010845,...,-0.011440,-0.001572,0.000572,0.003772,-0.012034,0.001211,0.038917,-0.012217,0.034038,-0.033374
AAACCCAAGTGGACGT-1,0.055772,0.058613,-0.045575,-0.028732,0.023830,-0.038945,-0.011158,-0.005452,-0.005985,0.002775,...,-0.020498,-0.018862,0.015582,0.025737,-0.004703,-0.014754,0.002766,-0.041661,0.027332,-0.016341
AAACCCACAGAAGTGC-1,0.042358,0.061969,-0.018097,-0.032435,0.015256,-0.023132,0.008818,-0.002340,0.000971,-0.003514,...,-0.023124,-0.001365,0.032850,0.013485,-0.003391,-0.029026,0.008960,-0.010630,0.025934,-0.036061
AAACCCAGTCATTGCA-1,0.026779,0.044952,-0.017846,-0.029361,0.026572,-0.001715,0.022163,-0.001436,0.007882,0.002111,...,-0.006393,-0.011647,0.007375,0.000434,-0.013978,0.014779,0.023302,-0.016247,0.032590,-0.022023
AAACCCATCATCGCAA-1,0.021346,0.043378,-0.035684,-0.012263,0.026250,-0.008220,0.011903,-0.010840,-0.011226,-0.002133,...,-0.006123,-0.026347,0.015204,0.021433,-0.020125,-0.009858,0.014942,-0.003959,0.033371,-0.034511
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGCAGGGACTA-1,0.044359,0.042555,-0.004852,-0.015482,0.012198,-0.010583,0.005981,-0.009249,-0.020597,-0.016300,...,-0.016108,-0.018804,0.023457,0.024487,-0.003530,-0.012398,0.018969,-0.005771,0.036932,-0.020645
TTTGTTGCATTGTAGC-1,0.056101,0.047190,-0.047522,-0.031512,0.029193,-0.038083,-0.018171,0.002924,-0.006285,0.011071,...,-0.014117,-0.023777,0.004803,0.027973,0.010112,-0.005548,0.017982,-0.041940,0.016465,-0.026971
TTTGTTGGTACCACGC-1,0.040479,0.030753,-0.002240,-0.014950,0.003016,-0.023222,-0.002127,-0.004745,-0.030631,-0.010744,...,-0.027701,-0.024816,0.011036,0.021804,-0.010900,-0.003421,0.009828,-0.026497,0.018989,-0.029901
TTTGTTGGTCTGTCCT-1,0.032625,0.033501,-0.036847,-0.035308,0.017868,-0.004681,0.010571,0.001721,-0.006476,-0.014111,...,0.004110,-0.001332,0.010089,0.001788,-0.014333,-0.012018,0.046623,-0.009620,0.032957,-0.034998


In [15]:
# Write the per-cell embedding table (rows indexed by cell, one column per
# embedding dimension) to a Parquet file on the shared drive.
# NOTE(review): the columns are the default integer labels — confirm the
# Parquet engine in use accepts non-string column names.
embed_df.to_parquet('/ix/djishnu/shared/djishnu_kor11/scGPT_outputs/tonsil_embeddings_pretrained.parquet')