#  Single Slide Domain Detection w/ Cell Embeddings

This notebook will fine tune the scGPT-spatial model for the single slide domain detection task and generate output in a csv file.

Input: adata with spatial coordinates and ground truth domain labels

Output: adata with spatial coordinates and predicted domain labels / cell name, spatial coordinates and predicted domain labels

## Colab Pre-requisites

In [1]:
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# mount to google drive
from google.colab import drive

# drive.flush_and_unmount()
drive.mount('/content/drive')
%cd /content/drive/MyDrive/ST_FM_Benchmark/scGPT_spatial

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/ST_FM_Benchmark/scGPT_spatial


In [2]:
# torchtext only support torch 2.3; torch text 0.18.0 is latest version.
# https://pytorch.org/get-started/locally
! pip install torch==2.3.0+cu121 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.3.0+cu121
  Downloading https://download.pytorch.org/whl/cu121/torch-2.3.0%2Bcu121-cp311-cp311-linux_x86_64.whl (781.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m781.0/781.0 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.18.0
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.18.0%2Bcu121-cp311-cp311-linux_x86_64.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m105.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0+cu121)
  Downloading https://download.pytorch.org/whl/cu121/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m123.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0+cu121)
  Downl

In [3]:
! pip install -r requirements.txt

Collecting scgpt (from -r requirements.txt (line 1))
  Downloading scgpt-0.2.4-py3-none-any.whl.metadata (10.0 kB)
Collecting scanpy (from -r requirements.txt (line 3))
  Downloading scanpy-1.11.4-py3-none-any.whl.metadata (9.2 kB)
Collecting tdigest (from -r requirements.txt (line 4))
  Downloading tdigest-0.5.2.2-py3-none-any.whl.metadata (4.9 kB)
Collecting torchtext==0.18.0 (from -r requirements.txt (line 5))
  Downloading torchtext-0.18.0-cp311-cp311-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting anndata (from -r requirements.txt (line 6))
  Downloading anndata-0.12.2-py3-none-any.whl.metadata (9.6 kB)
Collecting numpy<1.24 (from -r requirements.txt (line 8))
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Collecting cell-gears<0.0.3 (from scgpt->-r requirements.txt (line 1))
  Downloading cell-gears-0.0.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting leidenalg>=0.8.10 (from scgpt->-r requ

In [None]:
# verify torch version
# expect:
# Python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
# Torch: 2.3.0+cu121 CUDA: 12.1
# numpy:  1.23.5
# CXX11_ABI: False
import sys, torch, numpy as np

print("Python:", sys.version)
print("Torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("numpy: ", np.__version__)
try:
    print("CXX11_ABI:", torch._C._GLIBCXX_USE_CXX11_ABI)  # 0=FALSE, 1=TRUE
except Exception as e:
    print(e)


Python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
Torch: 2.3.0+cu121 CUDA: 12.1
numpy:  2.0.2
CXX11_ABI: False


In [None]:
# # install flash-attn without build
# !pip -q install -U pip setuptools wheel packaging ninja
# %env MAX_JOBS=4
# !pip install flash-attn --no-build-isolation

In [None]:
# # verify flash attn
# import torch
# import flash_attn
# print("Torch:", torch.__version__)
# print("FlashAttention:", getattr(flash_attn, "__version__", "unknown"))

## Zero-shot Domain Detection

In [5]:
from typing import Iterable
import scanpy as sc
from sklearn.metrics import silhouette_score

def predict_domain_using_embedding(adata,
                                   dom_key: str = "domain",  # output label for clustering result
                                   method: str = "leiden", # "leiden" or "louvain"
                                   rep_key: str = "X_scGPT",  # adata.obsm[rep_key]: embedding
                                   target_clusters: int = 6,
                                   n_neighbors: int = 15,
                                   resolution_grid: Iterable[float] = (0.3, 0.5,
                                                                       0.8, 1.0,
                                                                       1.2),
                                   return_silhouette: bool = True,
                                   ):
    """
    predict domain using generated embedding.
    Returns:
      labels: np.ndarray[int] — Domain labels for each spot (encoded integers)
      best_res: Optional[float] — Resolution to use (if automatically selected)
      best_score: Optional[float] — Silhouette score (if calculated)
    Side Effects:
      - Writes adata.obsm[rep_key] = (n_spot, D)
      - Writes adata.obs[dom_key] = pandas.Categorical
    Depends:
      - adata.obsm[rep_key]
    """
    sc.pp.neighbors(adata, use_rep=rep_key, n_neighbors=n_neighbors)

    def _cluster_at(res):
        if method == "leiden":
            sc.tl.leiden(adata, resolution=res, key_added=f"{dom_key}_tmp")
        elif method == "louvain":
            sc.tl.louvain(adata, resolution=res, key_added=f"{dom_key}_tmp")
        else:
            raise ValueError("method must be 'leiden' or 'louvain'")
        return adata.obs[f"{dom_key}_tmp"].astype(
            "category").cat.codes.to_numpy()

    best_res, best_score, best_labels, best_clusters = None, -1.0, None, None
    if target_clusters is None:
        for res in resolution_grid:
            labels = _cluster_at(res)
            if len(np.unique(labels)) < 2:
                score = -1.0
            else:
                try:
                    score = silhouette_score(adata.obsm[rep_key], labels)
                except Exception:
                    score = -1.0
            if score > best_score:
                best_res, best_score, best_labels = res, score, labels
    else:
        for res in resolution_grid:
            labels = _cluster_at(res)
            try:
                score = silhouette_score(adata.obsm[rep_key], labels)
            except Exception:
                score = -1.0
            if best_clusters is None or best_clusters >= abs(
                    len(np.unique(labels)) - target_clusters):
                best_clusters = abs(len(np.unique(labels)) - target_clusters)
                best_res, best_score, best_labels = res, score, labels

    labels = best_labels if best_labels is not None else _cluster_at(
        resolution_grid[0])
    adata.obs[dom_key] = labels
    adata.obs[dom_key] = adata.obs[dom_key].astype("category")
    return labels, best_res, (best_score if return_silhouette else None)

In [2]:
# load data
import numpy as np
import scanpy as sc

data_folder = '../../data/1_visium/'
adata = sc.read_h5ad(data_folder + '1_visium_scgpt_zero_shot.h5ad')
adata = adata[np.logical_not(adata.obs['ground_truth'].isna())]  #remove NAN
print(adata)
print(adata.obs.ground_truth.unique())

View of AnnData object with n_obs × n_vars = 4221 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'Region', 'ground_truth'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'X_scGPT', 'spatial'
['Layer1', 'Layer3', 'WM', 'Layer6', 'Layer5', 'Layer2', 'Layer4']
Categories (7, object): ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']


In [7]:
# run leiden for domain detection w. scGPT-spatial embeddings
labels, _, _ = predict_domain_using_embedding(adata, dom_key="domain_scgpt",
                                              method="leiden",
                                              rep_key="X_scGPT",
                                              target_clusters=6,
                                              n_neighbors=7)

# save domain detection results to adata / csv.
adata.write(data_folder + "1_visium_scgpt_zero_shot_domain_detection.h5ad")

print(adata)
print(adata.obs.domain_scgpt)

AnnData object with n_obs × n_vars = 4221 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'Region', 'ground_truth', 'domain_scgpt_tmp', 'domain_scgpt'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial', 'neighbors', 'leiden'
    obsm: 'X_scGPT', 'spatial'
    obsp: 'distances', 'connectivities'
AAACAACGAATAGTTC-1    0
AAACAAGTATCTCCCA-1    1
AAACAATCTACTAGCA-1    0
AAACACCAATAACTGC-1    6
AAACAGCTTTCAGAAG-1    3
                     ..
TTGTTGTGTGTCAAGA-1    5
TTGTTTCACATCCAGG-1    4
TTGTTTCATTAGTCTA-1    6
TTGTTTCCATACAACT-1    4
TTGTTTGTGTAAATTC-1    0
Name: domain_scgpt, Length: 4221, dtype: category
Categories (7, int8): [0, 1, 2, 3, 4, 5, 6]
