In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import scanpy as sc
import scgpt.tasks.cell_emb as cell_emb
import torch

from scgpt.model import TransformerModel
from scgpt.tasks.cell_emb import get_batch_cell_embeddings
from scgpt.tokenizer import GeneVocab
from scgpt.utils import load_pretrained
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader

import warnings
warnings.filterwarnings("ignore")
sc.settings.verbosity = 0

  from pkg_resources import get_distribution, DistributionNotFound
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Disable multiprocessing
if not hasattr(os, "sched_getaffinity"):
    os.sched_getaffinity = lambda x: [0]
    
_Orig = DataLoader
def NoMP(*args, **kwargs):
    kwargs["num_workers"] = 0
    return _Orig(*args, **kwargs)

torch.utils.data.DataLoader = NoMP
cell_emb.DataLoader = NoMP

In [3]:
# Paths and parameters
data_path = "../../data/merged_data/"
model_dir = "../../data/_utils/scGPT/"

In [4]:
# Read and subset data
adata = sc.read_h5ad(data_path + "adata_all_raw.h5ad")
np.random.seed(42)
idx = np.random.choice(adata.n_obs, 51200, replace=False)
adata = adata[idx].copy()
adata

AnnData object with n_obs × n_vars = 51200 × 5001
    obs: 'cell_id', 'global_x', 'global_y', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'cell_type_merged', 'batch'
    var: 'gene_ids', 'feature_types', 'genome', 'gene'
    uns: 'batch_colors', 'cell_type_merged_colors'

In [5]:
# Load vocabulary
vocab_file = os.path.join(model_dir, "vocab.json")
vocab = GeneVocab.from_file(vocab_file)

In [6]:
# Map genes to vocabulary indices
adata.var["id_in_vocab"] = [vocab[g] if g in vocab else -1 for g in adata.var["gene"]]
adata = adata[:, adata.var["id_in_vocab"] >= 0].copy()
gene_ids = np.array(adata.var["id_in_vocab"])

In [7]:
# Load model configurations
with open(os.path.join(model_dir, "args.json"), "r") as f:
    model_configs = json.load(f)

In [8]:
# Initialize model
model = TransformerModel(
    ntoken=len(vocab),
    d_model=model_configs["embsize"],
    nhead=model_configs["nheads"],
    d_hid=model_configs["d_hid"],
    nlayers=model_configs["nlayers"],
    nlayers_cls=model_configs["n_layers_cls"],
    vocab=vocab,
    dropout=model_configs["dropout"],
    pad_token=model_configs["pad_token"],
    pad_value=model_configs["pad_value"],
    do_mvc=True,
    do_dab=False,
    use_batch_labels=False,
    domain_spec_batchnorm=False,
    explicit_zero_prob=True,
    n_input_bins=model_configs["n_bins"],
    use_fast_transformer=False,
)

In [9]:
# Load pretrained model weights
state_dict = torch.load(os.path.join(model_dir, "best_model.pt"), map_location="cpu")
load_pretrained(model, state_dict, verbose=True)

scGPT - INFO - Loading parameter encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading parameter encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading parameter encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading parameter value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading parameter value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading parameter value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading parameter value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading parameter value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading parameter value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading parameter transformer_encoder.layers.0.self_attn.in_proj_weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading parameter transformer_encoder.layers.0.self_attn.

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [11]:
import scipy.sparse as sp
def filter_empty_cells(adata):
    X = adata.X
    if sp.issparse(X):
        nnz = np.asarray((X != 0).sum(axis=1)).ravel()
    else:
        nnz = (X != 0).sum(axis=1)
    keep = nnz > 0
    print(f"Keeping {keep.sum()}/{adata.n_obs} cells with >0 nonzero genes in adata.X")
    return adata[keep].copy()

adata = filter_empty_cells(adata)
adata

Keeping 51176/51200 cells with >0 nonzero genes in adata.X


AnnData object with n_obs × n_vars = 51176 × 4993
    obs: 'cell_id', 'global_x', 'global_y', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'cell_type_merged', 'batch'
    var: 'gene_ids', 'feature_types', 'genome', 'gene', 'id_in_vocab'
    uns: 'batch_colors', 'cell_type_merged_colors'

In [12]:
# Get cell embeddings
model.eval()
embeddings = get_batch_cell_embeddings(
    adata,
    model=model,
    vocab=vocab,
    model_configs=model_configs,
    gene_ids=gene_ids,
    batch_size=64,
    max_length=1200,
)

Embedding cells: 100%|██████████| 800/800 [3:23:18<00:00, 15.25s/it]  


In [13]:
# Compute t-SNE on the embeddings
adata.obsm["X_scGPT"] = embeddings
X = adata.obsm["X_scGPT"]
X_tsne = TSNE(n_components=2, perplexity=30, learning_rate="auto", init="pca").fit_transform(X)
adata.obsm["X_tsne_scGPT"] = X_tsne

In [14]:
embeddings.shape

(51176, 512)

In [17]:
# Plot t-SNE colored by cell_type_merged
cell_types = adata.obs["cell_type_merged"].astype("category")
colors = cell_types.cat.codes

plt.figure(figsize=(6,6))
scatter = plt.scatter(X_tsne[:,0], X_tsne[:,1], s=0.5, alpha=0.7, c=colors, cmap="tab20")
plt.title(" ")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.xticks([])
plt.yticks([])
handles = [
    plt.Line2D([], [], marker="o", linestyle="", color=scatter.cmap(scatter.norm(i)), label=cat)
    for i, cat in enumerate(cell_types.cat.categories)
]
plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', title="Cell Types")
plt.savefig("scGPT_tsne_cell_type_merged.png", bbox_inches="tight", dpi=300)
plt.close()

In [26]:
# Plot t-SNE colored by batch
batches = adata.obs["batch"].astype("category")
colors = batches.cat.codes

plt.figure(figsize=(6,6))
scatter = plt.scatter(X_tsne[:,0], X_tsne[:,1], s=0.5, alpha=0.7, c=colors, cmap="tab10")
plt.title(" ")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.xticks([])
plt.yticks([])
handles = [
    plt.Line2D([], [], marker="o", linestyle="", color=scatter.cmap(scatter.norm(i)), label=cat)
    for i, cat in enumerate(batches.cat.categories)
]
plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', title="Batches")
plt.savefig("scGPT_tsne_batch.png", bbox_inches="tight", dpi=300)
plt.close()