# Unsupervised Fine-tuning of scGPT-spatial

Objectives: GEPS intra, GEPS inter

Requirements: python <= 3.10, torch==2.3.0+cu121, torchtext==0.18.0, numpy<1.24

## Colab Pre-Requisites

In [2]:
import os
# Forces synchronous CUDA kernel launches so errors are reported at the call
# that caused them. Handy for debugging; it typically slows GPU execution, so
# consider removing it for full training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# mount to google drive
from google.colab import drive
# drive.flush_and_unmount()
drive.mount('/content/drive')
# IPython magic: make the project folder on Drive the working directory.
%cd /content/drive/MyDrive/ST_FM_Benchmark/scGPT_spatial

Mounted at /content/drive
/content/drive/MyDrive/ST_FM_Benchmark/scGPT_spatial


In [3]:
# # torchtext 0.18.0 (its latest release) only supports torch 2.3.
# # https://pytorch.org/get-started/locally
# ! pip install torch==2.3.0+cu121 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu121
# ! pip install torchtext==0.18.0
# ! pip install scgpt
# ! pip install datasets
# ! pip install scanpy
# ! pip install tdigest
# ! pip install anndata
# ! pip install numpy==1.23.5

In [4]:
# verify torch version
# expect:
# Python: 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0]
# Torch: 2.3.0+cu121 CUDA: 12.1
# numpy:  1.23.5
# CXX11_ABI: False

# import sys, torch, numpy as np
# print("Python:", sys.version)
# print("Torch:", torch.__version__, "CUDA:", torch.version.cuda)
# print("numpy: ", np.__version__)
# try:
#     print("CXX11_ABI:", torch._C._GLIBCXX_USE_CXX11_ABI)  # 0=FALSE, 1=TRUE
# except Exception as e:
#     print(e)


## Helper Functions

In [5]:
import warnings
from dataclasses import dataclass
from typing import List, Tuple  # type: ignore

import torch

# Silence every warning category for the rest of the session. This hides
# library deprecation noise, but also any genuinely useful warning.
warnings.simplefilter('ignore')


# configs
@dataclass
class Config:
    """Run configuration for unsupervised fine-tuning of scGPT-spatial.

    Groups input paths, training hyperparameters, the AnnData column names
    read by the data pipeline, and domain-prediction settings.
    """

    h5ad: str = ""
    ckpt_dir: str = ""  # Dataset reads all_dict_mean_std.csv from this dir
    finetuned_ckpt_dir: str = "results"

    seed: int = 42
    # Resolved once, when this class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # hyperparams
    n_hvg: int = 600
    n_bins: int = 51
    max_seq_len: int = 601  # n_hvg + 1 for <cls>
    mask_ratio: float = 0.4
    pad_token: str = '<pad>'
    pad_value: int = -2
    mask_value: int = -1
    # Overwritten with the checkpoint's args.json "lr" by load_backbone_and_vocab.
    lr: float = 5e-05
    epochs: int = 15
    batch_size: int = 16
    n_neighbors: int = 7  # Neighbors per patch. Each patch contains (1+n_neighbors) cells.

    # data
    batch_key_col: str = "protocol"  # batch key col name in adata.obs
    gene_names_col: str = "gene_names"  # gene name col in adata.var. Use `index` if gene name is the index of adata.var.
    coordinates_x_col: str = "array_row"  # x coordinates in adata.obs.
    coordinates_y_col: str = "array_col"  # y coordinates in adata.obs.

    # domain prediction
    dom_key: str = "domain_scgpt"  # key for the domain prediction results
    rep_key: str = "X_scGPT"  # key for the embeddings
    method: str = "leiden"  # "leiden" or "louvain"
    target_clusters: int = 6
    leiden_resolutions: Tuple[float, ...] = (0.3, 0.5, 0.8, 1.0, 1.2)

    # Default rate for learning rate decay. It is annotated so it is a real
    # dataclass field (settable via Config(schedule_ratio=...)). It is declared
    # last so the positional order of the fields above is unchanged.
    schedule_ratio: float = 0.9

In [6]:
from scgpt_spatial.tasks.cell_emb import load_pretrained
from pathlib import Path
from scgpt.tokenizer import GeneVocab
import scgpt_spatial
import json


def load_backbone_and_vocab(ckpt_dir: Path, device: str,
                            config: "Config | None" = None) -> tuple[
    scgpt_spatial.model.TransformerModel, GeneVocab, dict]:
    """Load the pretrained scGPT-spatial backbone and its gene vocabulary.

    Reads ``args.json`` (model hyperparameters), ``vocab.json`` and
    ``best_model.pt`` from ``ckpt_dir``. It then builds a ``TransformerModel``
    with those hyperparameters, loads the weights, moves the model to
    ``device`` and switches it to eval mode.

    Args:
        ckpt_dir: checkpoint directory. A ``str`` is accepted as well.
        device: torch device string, used for ``map_location`` and ``.to``.
        config: run configuration whose ``lr`` is overwritten with the
            checkpoint's ``lr``. When omitted, the module-level ``config``
            object is updated, as before.

    Returns:
        ``(model, vocab, model_args)``, where ``model_args`` is the parsed
        ``args.json`` dict.
    """
    ckpt_dir = Path(ckpt_dir)
    model_config_file = ckpt_dir / "args.json"
    model_file = ckpt_dir / "best_model.pt"
    vocab_file = ckpt_dir / "vocab.json"
    with open(model_config_file, "r") as f:
        model_args = json.load(f)
    vocab = GeneVocab.from_file(vocab_file)
    # make sure the special tokens used by the data pipeline exist
    for tok in ("<pad>", "<cls>", "<eoc>"):
        if tok not in vocab:
            vocab.append_token(tok)
    vocab.set_default_index(vocab["<pad>"])

    model = scgpt_spatial.model.TransformerModel(
        ntoken=len(vocab),
        d_model=model_args["embsize"],
        nhead=model_args["nheads"],
        d_hid=model_args["d_hid"],
        nlayers=model_args["nlayers"],
        nlayers_cls=model_args["n_layers_cls"],
        n_cls=1,
        vocab=vocab,
        dropout=model_args["dropout"],
        pad_token=model_args["pad_token"],
        pad_value=model_args["pad_value"],
        do_mvc=True,
        do_dab=False,
        use_batch_labels=True,
        num_batch_labels=4,  # trained by 4 protocols
        input_emb_style=model_args["input_emb_style"],
        n_input_bins=model_args["n_bins"],
        cell_emb_style='cls',
        mvc_decoder_style="inner product",
        ecs_threshold=0.8,
        explicit_zero_prob=False,
        use_generative_training=True,
        use_fast_transformer=False,
        fast_transformer_backend="flash",
        pre_norm=False,
        use_MVC_impute=True,
        impute_MVC_knn_k=model_args["impute_k"],
        use_moe_dec=True,
    )
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    load_pretrained(model, torch.load(model_file, map_location=device),
                    verbose=True)
    model.to(device)
    model.eval()

    # sync config params with the pretrained run
    if config is None:
        # Legacy behaviour: update the notebook-level ``config`` object.
        config = globals()["config"]
    config.lr = model_args["lr"]
    return model, vocab, model_args



In [7]:
import pandas as pd
from sklearn.model_selection import train_test_split
from scipy.sparse import issparse
import scanpy as sc


def split_dataset(h5ad_path: Path, config: Config, vocab: GeneVocab,
                  test_size: float = 0.1):
    """Load an AnnData file, keep in-vocabulary genes and split the cells.

    A gene is dropped when its name is not in ``vocab``. The name comes from
    ``adata.var[config.gene_names_col]``, or from the ``var`` index when that
    setting is ``"index"``. Cells are shuffled and split into train and
    validation sets with ``train_test_split``. The split is seeded with
    ``config.seed`` so it is reproducible.

    Args:
        h5ad_path: path to the ``.h5ad`` file.
        config: supplies the column names and the split seed.
        vocab: pretrained gene vocabulary.
        test_size: fraction of cells held out for validation (default 0.1).

    Returns:
        ``(gene_ids, train_data, valid_data, train_batch_labels,
        valid_batch_labels, train_xy, valid_xy)``. ``gene_ids`` holds the
        vocabulary id of each kept gene, in the column order of the count
        matrices. ``*_xy`` are the (x, y) coordinate rows.
    """
    adata = sc.read_h5ad(str(h5ad_path))

    def gene_names(ad):
        # gene symbols come either from the var index or from a var column
        if config.gene_names_col == 'index':
            return ad.var.index
        return ad.var[config.gene_names_col]

    # tokenize genes: -1 marks genes missing from the pretrained vocabulary
    adata.var["id_in_vocab"] = [
        vocab[gene] if gene in vocab else -1 for gene in gene_names(adata)
    ]
    adata = adata[:, adata.var["id_in_vocab"] >= 0]
    genes = gene_names(adata).tolist()
    gene_ids = np.array(vocab(genes), dtype=int)

    all_counts = adata.X.toarray() if issparse(adata.X) else adata.X
    batch_ids = np.array(adata.obs[config.batch_key_col].tolist())
    coordinates = pd.concat(
        [adata.obs[config.coordinates_x_col],
         adata.obs[config.coordinates_y_col]],
        axis=1).to_numpy()

    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
        train_xy,
        valid_xy,
    ) = train_test_split(
        all_counts, batch_ids, coordinates,
        test_size=test_size, shuffle=True, random_state=config.seed,
    )

    scgpt_spatial.logger.info(
        f"Train set number of samples: {train_data.shape[0]}  Valid set number of samples: {valid_data.shape[0]}"
    )

    return gene_ids, train_data, valid_data, train_batch_labels, valid_batch_labels, train_xy, valid_xy

In [8]:
from torch.utils.data import DataLoader


# dataset. Adopted from scgpt-spatial.
class Dataset(torch.utils.data.Dataset):
    """Normalized expression dataset, adapted from scGPT-spatial.

    Counts are normalized twice:
    1. by the slide-wide mean of all non-zero counts;
    2. per gene, by the ``mean`` column of ``all_dict_mean_std.csv`` in
       ``config.ckpt_dir``.

    Genes missing from that table get a mean computed from the non-zero
    normalized values of this dataset.

    ``__getitem__`` accepts a single row index or a list of row indices (the
    KNN patch sampler yields lists). It prepends the ``<cls>`` token id to the
    gene ids and ``config.pad_value`` to every expression row.
    """

    def __init__(self, count_matrix, gene_ids, batch_labels, coordinates,
                 config: "Config", vocab=None):
        """
        Args:
            count_matrix: dense 2-D (cells, genes) numpy count matrix.
            gene_ids: numpy array of vocabulary ids, one per column.
            batch_labels: numpy array of per-cell batch labels.
            coordinates: numpy array of per-cell spatial coordinates.
            config: must provide ``ckpt_dir`` and ``pad_value``.
            vocab: vocabulary used to look up ``<cls>``. When omitted, the
                module-level ``vocab`` object is used, as before.
        """
        self.slide_mean = np.mean(count_matrix[count_matrix != 0])
        self.count_matrix = count_matrix / self.slide_mean
        self.gene_ids = gene_ids
        self.batch_ids = batch_labels
        self.coordinates = coordinates
        self.config = config
        if vocab is None:
            vocab = globals()["vocab"]
        self.cls_token_id = vocab["<cls>"]
        self.gene_stats_dict = pd.read_csv(
            Path(config.ckpt_dir, "all_dict_mean_std.csv"), index_col=0)

        # For genes not already in gene_stats_dict, use the mean of their
        # non-zero normalized values in the current dataset. Only the "mean"
        # column is set, so stats tables with extra columns (e.g. std) work.
        new_genes = set(self.gene_ids).difference(
            set(self.gene_stats_dict.index.values))
        for gene in new_genes:
            cols = np.where(self.gene_ids == gene)[0]
            col = self.count_matrix[:, cols].flatten()
            values = col[np.nonzero(col)[0]]
            self.gene_stats_dict.loc[gene, 'mean'] = float(values.mean())

        # Per-gene divisors are fixed once the table is complete; compute them once.
        self.mean_divide_by = self.gene_stats_dict.loc[
            self.gene_ids, 'mean'].values

    def __len__(self):
        return len(self.count_matrix)

    def __getitem__(self, idx):
        values = np.divide(self.count_matrix[idx], self.mean_divide_by)
        # append <cls> token at the beginning
        genes = np.insert(self.gene_ids, 0, self.cls_token_id)
        # axis=-1 prepends one value per cell for a single row (1-D) or a patch (2-D)
        values = np.insert(values, 0, self.config.pad_value, axis=-1)
        return {
            "id": idx,
            "genes": torch.from_numpy(genes).long(),
            "expressions": torch.nan_to_num(
                torch.from_numpy(values).float(), nan=0.0),
            # np.asarray handles both scalar and list indices
            "batch_labels": torch.as_tensor(
                np.asarray(self.batch_ids[idx])).long(),
            "coordinates": torch.as_tensor(np.asarray(self.coordinates[idx])),
        }

In [9]:
# Data collator. Adapted from scGPT-spatial.
from scgpt_spatial import binning
from dataclasses import dataclass, field
from typing import Callable, Dict, Mapping, Optional, Tuple

import numpy as np


@dataclass
class DataCollator:
    """
    Data collator for the masked / generative expression-value tasks. It pads
    the sequences to the maximum length in the batch and either masks the
    gene expression values ("pcpt") or splits genes into perception and
    generation parts ("both").

    Args:
        do_padding (:obj:`bool`): whether to pad the sequences to the max length.
        pad_token_id (:obj:`int`, optional): the token id to use for padding.
            This is required if do_padding is True.
        pad_value (:obj:`int`): the value to use for padding the expression
            values to the max length.
        do_mlm (:obj:`bool`): whether to do masking with MLM.
        do_binning (:obj:`bool`): whether to bin the expression values.
        n_bins (:obj:`int`): the number of bins to use for binning.
        mlm_probability (:obj:`float` or list/tuple of floats): the probability
            of masking with MLM. With a list/tuple, one value is drawn at
            random each time it is used.
        mask_value (:obj:`int`): the value to fill at the expression positions
            that are masked.
        max_length (:obj:`int`, optional): the maximum length of the sequences.
            This is required if do_padding is True.
        sampling (:obj:`bool`): whether to do sampling instead of truncation if
            length > max_length.
        reserve_keys (:obj:`List[str]`, optional): a list of keys in the examples
            to reserve in the output dictionary. Default to []. These fields
            are stacked unchanged into the output.
        append_tokens (:obj:`List[Callable]`, optional): a list of functions to
            append tokens to the beginning of the sequence. Each function takes
            an example as input and appends tokens to the beginning of the
            sequence after `keep_first_n_tokens` tokens. This is useful when
            special tokens have been added to the beginning of the sequence.
            Default to [].
        keep_first_n_tokens (:obj:`int`): the number of tokens at the beginning
            of the sequence to keep unchanged by sampling, masking, binning and
            the generation split. This is useful when special tokens (e.g.
            <cls>) have been added to the beginning of the sequence. **Note**
            that `append_tokens` is handled automatically, so only count other
            tokens here. Default to 1.
        data_style (:obj:`str`): the style of the data. If "pcpt", the data is
            masked and padded for perception training. If "gen", only the gene
            tokens are provided, but not the expression values, for pure
            generative training. If "both", the output contains both parts.
            Choices: "pcpt", "gen", "both". Default to "both".
        device (:obj:`str`): device the output tensors are moved to.
    """

    do_padding: bool = True
    pad_token_id: Optional[int] = None
    pad_value: int = 0
    do_mlm: bool = True
    do_binning: bool = True
    n_bins: int = 51
    mlm_probability: float = 0.15
    mask_value: int = -1
    max_length: Optional[int] = None
    sampling: bool = True
    reserve_keys: List[str] = field(default_factory=lambda: [])
    append_tokens: List[Callable] = field(default_factory=lambda: [])
    keep_first_n_tokens: int = 1  # for <cls>
    data_style: str = "both"
    device: str = "cpu"

    def __post_init__(self):
        if self.do_padding:
            if self.pad_token_id is None:
                raise ValueError("`pad_token_id` is required if `do_padding`.")
            if self.max_length is None:
                raise ValueError("`max_length` is required if `do_padding`.")

        if self.do_binning:
            if self.n_bins < 2:
                raise ValueError("`n_bins` must be greater than 1.")

        if isinstance(self.mlm_probability, float):
            if self.mlm_probability <= 0 or self.mlm_probability >= 1:
                raise ValueError("`mlm_probability` must be between 0 and 1.")
        elif isinstance(self.mlm_probability, (list, tuple)):
            if min(self.mlm_probability) <= 0 or max(self.mlm_probability) >= 1:
                raise ValueError("`mlm_probability` must be between 0 and 1.")
        else:
            raise ValueError(
                "`mlm_probability` must be a float or iterable of floats.")

        if isinstance(self.reserve_keys, str):
            self.reserve_keys = [self.reserve_keys]

        if len(self.append_tokens) > 0:
            self.original_prefix_n_tokens = self.keep_first_n_tokens
            self.keep_first_n_tokens = self.keep_first_n_tokens + len(
                self.append_tokens
            )

        # max_length may legitimately be None when padding is disabled, so it
        # is only compared when set.
        if self.keep_first_n_tokens < 0 or (
                self.max_length is not None
                and self.keep_first_n_tokens > self.max_length):
            raise ValueError(
                "`keep_first_n_tokens` must be between 0 and `max_length` "
                f"({self.max_length})."
            )

        if self.data_style not in ["pcpt", "gen", "both"]:
            raise ValueError(
                "`data_style` must be one of 'pcpt', 'gen', 'both'.")

    def __call__(
            self, examples: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            examples (:obj:`List[Dict[str, torch.Tensor]]`): a list of data dicts.
                Each dict holds a "genes" tensor and an "expressions" tensor,
                plus any `reserve_keys`, e.g.:
                    {'id': tensor(184117),
                    'genes': tensor([36572, 17868, ..., 17072]),
                    'expressions': tensor([ 0.,  2., ..., 18.])}

        Returns:
            :obj:`Dict[str, torch.Tensor]`: a dict of tensors.
        """
        if len(self.reserve_keys) > 0:
            assert all(key in examples[0] for key in self.reserve_keys), (
                f"reserve_keys must be a subset of the keys in the examples. "
                f"Got {self.reserve_keys} but expected keys in {list(examples[0].keys())}."
            )

        if len(self.append_tokens) > 0:
            for i, func in enumerate(self.append_tokens):
                examples = self._append_token(
                    examples, func,
                    prefix_n_tokens=self.original_prefix_n_tokens + i
                )

        if self.data_style == "pcpt":
            data_dict = self._call_pcpt(examples)
        elif self.data_style == "gen":
            data_dict = self._call_gen(examples)
        elif self.data_style == "both":
            data_dict = self._call_both(examples)

        # add reserved keys
        for key in self.reserve_keys:
            data_ = [example[key] for example in examples]
            data_dict[key] = torch.stack(data_, dim=0).to(self.device)

        return data_dict

    def _append_token(
            self,
            examples: List[Dict[str, torch.Tensor]],
            func: Callable,
            prefix_n_tokens: int,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Apply `func` to every example, in place in the list. The function is
        expected to append tokens to the beginning of the sequence after
        `prefix_n_tokens` tokens.

        Args:
            examples (:obj:`List[Dict[str, torch.Tensor]]`): a list of data dicts.
            func (:obj:`Callable`): called as ``func(example, prefix_n_tokens)``
                and returns the updated example.
            prefix_n_tokens (:obj:`int`): the number of tokens in the beginning
                of the sequence to keep unchanged from sampling.

        Returns:
            List[Dict[str, torch.Tensor]]: the updated examples (same list object).
        """
        for i in range(len(examples)):
            examples[i] = func(examples[i], prefix_n_tokens)
        return examples

    def _call_pcpt(
            self, examples: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        Bin (optionally), pad/truncate and mask each example.

        Each example is like:
            {'id': tensor(184117),
            'genes': tensor([36572, 17868, ..., 17072]),
            'expressions': tensor([ 0.,  2., ..., 18.])}

        Args:
            examples (:obj:`List[Dict[str, torch.Tensor]]`): a list of examples.
                Each example is a dictionary of tensors.
        Returns:
            :obj:`Dict[str, torch.Tensor]`: keys "gene", "expr", "masked_expr".

        Raises:
            NotImplementedError: if the examples are not mappings.
        """
        if not isinstance(examples[0], Mapping):
            raise NotImplementedError(
                "DataCollator only supports mapping-style examples.")

        max_ori_len = max(len(example["genes"]) for example in examples)
        _max_length = self.max_length if max_ori_len >= self.max_length else max_ori_len

        # pad and truncate
        padded_genes = []
        padded_expressions = []
        for i in range(len(examples)):
            genes = examples[i]["genes"]
            expressions = examples[i]["expressions"]
            if self.do_binning:
                expressions[self.keep_first_n_tokens:] = binning(
                    row=expressions[self.keep_first_n_tokens:],
                    n_bins=self.n_bins,
                )

            genes, expressions = self._sample_or_truncate_plus_pad(
                genes, expressions, _max_length
            )  # torch tensors of length _max_length

            padded_genes.append(genes)
            padded_expressions.append(expressions)

        padded_genes = torch.stack(padded_genes, dim=0).to(self.device)
        padded_expressions = torch.stack(padded_expressions, dim=0).to(
            self.device)

        data_dict = {
            "gene": padded_genes,
            "expr": padded_expressions,
        }

        # mask
        if self.do_mlm:
            masked_expressions = self._mask(
                padded_expressions, self.keep_first_n_tokens
            )
        else:
            masked_expressions = padded_expressions
        data_dict["masked_expr"] = masked_expressions

        return data_dict

    def _call_gen(
            self,
            examples: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """
        This method would simply return the gene ids, with needed padding. There
        is no masking for pure generative training, and no input of expr values.

        Each example is like:
            {'id': tensor(184117),
            'genes': tensor([36572, 17868, ..., 17072])}

        Raises:
            NotImplementedError: always; not implemented in this version.
        """
        raise NotImplementedError("Code coming soon with pre-training branch.")

    def _call_both(
            self,
            examples: List[Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        """
        Split every cell into a perception part and a generation part.

        Each example holds a 1-D "genes" tensor and an "expressions" tensor.
        The "expressions" tensor is either 1-D (one cell) or 2-D (one row per
        cell, all rows sharing "genes"). For every cell:

        - About ``_max_length * mlm_probability`` query genes are chosen at
          random for the generation part, aiming at roughly half zero and half
          non-zero expressions. The first `keep_first_n_tokens` positions
          (e.g. <cls>) are never chosen.
        - The generation part keeps its raw expression values as targets.
        - The remaining genes form the perception part and are binned when
          `do_binning`.
        - Both parts are padded/truncated to ``_max_length``.

        Returns:
            Dict[str, torch.Tensor]: a dict with keys "gen_gene",
            "gen_expr_target", "pcpt_gene" and "pcpt_expr", each stacked over
            all cells in the batch.

        Raises:
            NotImplementedError: if the examples are not mappings.
        """
        if not isinstance(examples[0], Mapping):
            raise NotImplementedError(
                "DataCollator only supports mapping-style examples.")

        max_ori_len = max(len(example["genes"]) for example in examples)
        _max_length = self.max_length if max_ori_len >= self.max_length else max_ori_len
        n_keep = self.keep_first_n_tokens

        padded_genes = []
        padded_expressions = []
        gen_genes, gen_expr_targets = [], []
        for example in examples:
            genes = example["genes"]
            expressions = example["expressions"]
            rows = expressions.unsqueeze(0) if expressions.dim() == 1 else expressions
            for expression in rows:
                # Candidate query genes exclude the first n_keep positions, so
                # special tokens always stay at the front of the perception part.
                candidates = expression[n_keep:]
                nz_idx = torch.nonzero(candidates, as_tuple=True)[0] + n_keep
                z_idx = torch.nonzero(candidates == 0, as_tuple=True)[0] + n_keep
                n_nz, n_z = nz_idx.numel(), z_idx.numel()

                # randomly choose query genes with around 1:1 zero and non-zero expressions
                total_gen_gene_num = int(_max_length * self.get_mlm_probability())
                if n_nz <= n_z:
                    nz_gen_gene_num = int(min(n_nz * 0.9, total_gen_gene_num // 2))
                    z_gen_gene_num = total_gen_gene_num - nz_gen_gene_num
                else:
                    z_gen_gene_num = int(min(n_z * 0.9, total_gen_gene_num // 2))
                    nz_gen_gene_num = total_gen_gene_num - z_gen_gene_num
                z_gen_idx = z_idx[
                    torch.randperm(n_z, device=z_idx.device)[:z_gen_gene_num]]
                nz_gen_idx = nz_idx[
                    torch.randperm(n_nz, device=nz_idx.device)[:nz_gen_gene_num]]
                gen_gene_idx = torch.cat((nz_gen_idx, z_gen_idx), dim=0)
                gen_gene, gen_expr_target = self._sample_or_truncate_plus_pad(
                    genes[gen_gene_idx], expression[gen_gene_idx], _max_length
                )
                gen_genes.append(gen_gene)
                gen_expr_targets.append(gen_expr_target)

                # the remaining genes form the perception part
                keep = torch.ones(expression.shape[0], dtype=torch.bool,
                                  device=expression.device)
                keep[gen_gene_idx] = False
                pcpt_expression = expression[keep]  # boolean indexing copies
                pcpt_genes = genes[keep]
                if self.do_binning:
                    pcpt_expression[n_keep:] = binning(
                        row=pcpt_expression[n_keep:],
                        n_bins=self.n_bins,
                    )

                pcpt_genes, pcpt_expression = self._sample_or_truncate_plus_pad(
                    pcpt_genes, pcpt_expression, _max_length
                )  # torch tensors of length _max_length

                padded_genes.append(pcpt_genes)
                padded_expressions.append(pcpt_expression)

        padded_gen_genes = torch.stack(gen_genes, dim=0).to(self.device)
        padded_gen_expr_target = torch.stack(gen_expr_targets, dim=0).to(
            self.device)
        padded_genes = torch.stack(padded_genes, dim=0).to(self.device)
        padded_expressions = torch.stack(padded_expressions, dim=0).to(
            self.device)

        return {
            "gen_gene": padded_gen_genes,
            "gen_expr_target": padded_gen_expr_target,
            "pcpt_gene": padded_genes,
            "pcpt_expr": padded_expressions,
        }

    def get_mlm_probability(self) -> float:
        """
        Get the mlm probability for the current step. A list/tuple setting
        draws one value at random.
        """
        if isinstance(self.mlm_probability, float):
            return self.mlm_probability
        elif isinstance(self.mlm_probability, (list, tuple)):
            # random choose a probability
            return np.random.choice(self.mlm_probability)
        else:
            raise ValueError(
                "mlm_probability must be a float or a list of floats, "
                f"but got {self.mlm_probability}."
            )

    def _mask(
            self, expressions: torch.Tensor, keep_first_n_tokens: int = 0
    ) -> torch.Tensor:
        """
        Mask the expression values with MLM.

        Positions equal to `pad_value` are never masked. The first
        `keep_first_n_tokens` columns are returned unchanged. The recursion
        strips them before sampling, so no other column is protected.
        """
        if keep_first_n_tokens > 0:
            result_ = self._mask(
                expressions[:, keep_first_n_tokens:],
                keep_first_n_tokens=0,
            )
            return torch.cat([expressions[:, :keep_first_n_tokens], result_],
                             dim=1)

        probability_matrix = torch.full(
            expressions.shape, float(self.get_mlm_probability()),
            dtype=torch.float32, device=expressions.device)
        # set padded position probability to 0
        probability_matrix[expressions.eq(self.pad_value)] = 0

        mask = torch.bernoulli(probability_matrix).bool()
        return expressions.masked_fill(mask, self.mask_value)

    def _sample_or_truncate_plus_pad(
            self,
            genes: torch.LongTensor,
            expressions: torch.Tensor,
            max_length: int,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """Return ``(genes, expressions)`` resized to exactly ``max_length``."""
        assert len(genes) == len(expressions)
        if len(genes) == max_length:
            return genes, expressions
        if len(genes) > max_length:  # sample or truncate
            if self.sampling:
                return self._sample(genes, expressions, max_length)
            else:
                return genes[:max_length], expressions[:max_length]
        else:  # pad
            return self._pad(genes, expressions, max_length)

    def _sample(
            self,
            genes: torch.LongTensor,
            expressions: torch.Tensor,
            max_length: int,
    ) -> Tuple[torch.LongTensor, torch.Tensor]:
        """Randomly keep ``max_length`` positions, always keeping the first
        `keep_first_n_tokens` ones."""
        # NOTE: the fastest way to sample in torch has been benchmarked here
        # https://discuss.pytorch.org/t/torch-equivalent-of-numpy-random-choice/16146/19
        # it shows the randperm on gpu is the fastest.
        # NOTE: also, the current implementation permutes the order of the genes
        # and expressions, which is probably a nice augmentation.
        if self.keep_first_n_tokens == 0:
            indices = torch.randperm(len(genes), device=genes.device)[
                      :max_length]
            return genes[indices], expressions[indices]

        # keep the first n tokens unchanged
        _n = self.keep_first_n_tokens
        indices = torch.randperm(len(genes) - _n, device=genes.device)[
                  : max_length - _n]
        indices = torch.cat(
            [torch.arange(_n, device=indices.device), indices + _n], dim=0)
        return genes[indices], expressions[indices]

    def _pad(
            self,
            genes: torch.LongTensor,
            expressions: torch.Tensor,
            max_length: int,
    ):
        """Right-pad genes with `pad_token_id` and expressions with `pad_value`."""
        genes = torch.cat(
            [
                genes,
                torch.full(
                    (max_length - len(genes),),
                    self.pad_token_id,
                    dtype=genes.dtype,
                    device=genes.device,
                ),
            ]
        )
        expressions = torch.cat(
            [
                expressions,
                torch.full(
                    (max_length - len(expressions),),
                    self.pad_value,
                    dtype=expressions.dtype,
                    device=expressions.device,
                ),
            ]
        )
        return genes, expressions

In [10]:
# spatial aware sampler
from __future__ import annotations
import random
from typing import List
from typing import Iterable

import torch
from torch.utils.data import Sampler


# per slide (default single slide)
class KNNPatchBatchSampler(Sampler[List[int]]):
    """Yield batches built from spatial k-NN patches of a single slide.

    At construction, ``len(coordinates) // (k_neighbors + 1)`` anchor cells
    are drawn with an RNG seeded by ``seed``. Anchors are processed in draw
    order. Each anchor's ``k_neighbors + 1`` nearest cells (Euclidean
    distance, the anchor itself included) form one patch. Once used, an
    anchor is excluded from later patches. Its neighbours are not excluded,
    so patches may overlap.

    Each iteration shuffles the patch order, concatenates
    ``patches_per_batch`` patches into one batch, and shuffles the cell
    indices inside that batch.
    """

    def __init__(
            self,
            coordinates,
            patches_per_batch: int = 8,
            # batch_size = (k_neighbors + 1) * patches_per_batch
            k_neighbors: int = 7,  # neighbors per patch
            seed: int = 42,
    ):
        if patches_per_batch < 1:
            raise ValueError(
                f"patches_per_batch must be >= 1, got {patches_per_batch}; "
                "is batch_size smaller than k_neighbors + 1?")
        self.rng = random.Random(seed)
        self.patches_per_batch = patches_per_batch
        self.k_neighbors = k_neighbors
        patch_size = k_neighbors + 1
        n_cells = len(coordinates)
        # one anchor per patch, drawn from the seeded RNG (not the global one)
        anchor_indices = self.rng.sample(range(n_cells), n_cells // patch_size)

        self.patches = []
        if anchor_indices:
            # pre-compute pairwise distances, shape [n_cells, n_cells]
            coords = torch.tensor(coordinates, dtype=torch.float32)
            dist = torch.cdist(coords, coords, p=2)
            for anchor_index in anchor_indices:
                # only the anchor's row is needed, not a top-k over the whole matrix
                top_dist, top_idx = torch.topk(
                    dist[anchor_index], k=patch_size, largest=False,
                    sorted=True)
                if torch.isinf(top_dist).any().item():
                    # not enough usable cells left for a full patch (drop last)
                    break
                self.patches.append(top_idx.tolist())
                dist[anchor_index, :] = torch.inf
                dist[:, anchor_index] = torch.inf
        # may be smaller than the number of anchors if the loop stopped early
        self.patch_num = len(self.patches)

    def __iter__(self) -> Iterable[List[int]]:
        order = torch.randperm(self.patch_num).tolist()
        for start in range(0, self.patch_num, self.patches_per_batch):
            batch = [cell
                     for patch_idx in order[start:start + self.patches_per_batch]
                     for cell in self.patches[patch_idx]]
            self.rng.shuffle(batch)
            yield batch

    def __len__(self) -> int:
        # NOTE: this counts cells, not yielded batches; train_epoch divides
        # len(loader) by batch_size to get the number of batches.
        return self.patch_num * (self.k_neighbors + 1)

In [11]:
# prepare data loader
def prepare_dataloader(count_matrix, gene_ids, batch_labels, coordinates, vocab,
                       config: Config) -> DataLoader:
    """Build the DataLoader used for spatial fine-tuning.

    The KNN patch sampler yields lists of cell indices. It is passed as
    ``sampler`` (not ``batch_sampler``) with DataLoader's default
    ``batch_size=1``, so each yielded list reaches ``Dataset.__getitem__`` as
    a single index. The collator then receives a one-element list holding
    the whole patch batch.

    Args:
        count_matrix, gene_ids, batch_labels, coordinates: forwarded to ``Dataset``.
        vocab: gene vocabulary; used for the padding token id.
        config: run configuration (padding/masking values, ``n_bins``,
            ``max_seq_len``, ``batch_size``, ``n_neighbors``, ``seed``).
    """
    dataset = Dataset(count_matrix, gene_ids, batch_labels, coordinates, config)

    collator = DataCollator(
        do_padding=True,
        pad_token_id=vocab[config.pad_token],
        pad_value=config.pad_value,
        do_mlm=True,
        do_binning=True,
        n_bins=config.n_bins,  # keep binning in sync with the run config
        mlm_probability=config.mask_ratio,
        mask_value=config.mask_value,
        max_length=config.max_seq_len,
        sampling=True,
        reserve_keys=['coordinates', 'batch_labels'],
        keep_first_n_tokens=1,
        data_style='both',
    )
    # each batch holds patches_per_batch patches of (n_neighbors + 1) cells
    patches_per_batch = config.batch_size // (config.n_neighbors + 1)
    return DataLoader(
        dataset,
        sampler=KNNPatchBatchSampler(coordinates,
                                     patches_per_batch=patches_per_batch,
                                     k_neighbors=config.n_neighbors,
                                     seed=config.seed),
        collate_fn=collator,
        num_workers=0,
        pin_memory=True,
    )

In [12]:
from torch import nn


def train_epoch(epoch, model: nn.Module, loader, scheduler, optimizer,
                criterion, scaler, device, vocab, config: "Config",
                log_interval: int = 100):
    """
    Train the model for one epoch on the MSE + GEPC-intra + GEPC-inter objective.

    For each batch the loss is the sum of:
      - criterion(gen_preds,   gen_expr_target,          non-pad gen positions)
      - criterion(mvc_output,  [pcpt_expr | gen_target], gen positions only)
      - criterion(impute_pred, [pcpt_expr | gen_target], gen positions only)
    Each batch is followed by one gradient step with gradients clipped to
    max-norm 1.0.

    Args:
        epoch: epoch number; used only in log messages.
        model: scGPT-spatial model whose output dict contains "gen_preds",
            "mvc_output" and "impute_pred".
        loader: iterable of collated batch dicts with keys "pcpt_gene",
            "pcpt_expr", "gen_gene", "gen_expr_target", "coordinates" and
            "batch_labels".
        scheduler: LR scheduler; only read here for logging (the caller steps it).
        optimizer: optimizer over ``model.parameters()``.
        criterion: masked loss, called as ``criterion(pred, target, mask)``.
        scaler: ``GradScaler`` (possibly disabled).
        device: device the batch tensors are moved to.
        vocab: token -> id mapping; used to look up ``config.pad_token``.
        config: reads ``pad_token`` and ``batch_size``.
        log_interval: log running averages every ``log_interval`` batches.
    """
    # evaluate() puts the model in eval mode; switch dropout etc. back on.
    model.train()
    pad_id = vocab[config.pad_token]

    total_loss = total_mse = total_geps_intra = total_geps_inter = 0.0

    # Used only in log messages. NOTE(review): assumes len(loader) counts spots
    # (the KNN sampler's __len__), so dividing by batch_size approximates the
    # number of batches. Confirm if the sampler changes.
    num_batches = len(loader) // config.batch_size
    autocast_device = torch.device(device).type
    start_time = time.time()
    for batch, batch_data in enumerate(loader):
        pcpt_gene = batch_data["pcpt_gene"].to(device)
        pcpt_expr = batch_data["pcpt_expr"].to(device)
        pcpt_key_padding_mask = pcpt_gene.eq(pad_id)
        gen_gene = batch_data["gen_gene"].to(device)
        gen_expr_target = batch_data["gen_expr_target"].to(device)
        gen_key_padding_mask = gen_gene.eq(pad_id)
        coordinates = batch_data["coordinates"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)

        # Mixed precision is disabled; the context is kept so it can be toggled.
        with torch.autocast(device_type=autocast_device, enabled=False):
            output_dict = model(
                pcpt_genes=pcpt_gene,
                pcpt_values=pcpt_expr,
                pcpt_key_padding_mask=pcpt_key_padding_mask,
                gen_genes=gen_gene,
                gen_key_padding_mask=gen_key_padding_mask,
                batch_labels=batch_labels,
                coordinates=coordinates,
                CLS=False,
                MVC=True,
                ECS=False,
                MVC_impute=True,
                do_sample=False,
                input_cell_emb=None,
                generative_training=True
            )  # dict: key: (batch, embsize)

            # MSE on the generated (non-pad) gene positions.
            masked_positions = gen_gene.ne(pad_id)
            loss_mse = criterion(
                output_dict["gen_preds"], gen_expr_target, masked_positions
            )

            # GEPC targets span [pcpt | gen]; only the gen positions are scored.
            geps_target = torch.cat([pcpt_expr, gen_expr_target], dim=1)
            pcpt_mask = torch.zeros_like(pcpt_expr, dtype=torch.bool)
            geps_masked_position = torch.cat(
                [pcpt_mask, masked_positions], dim=1)

            # gepc intra
            loss_geps_intra = criterion(
                output_dict["mvc_output"], geps_target, geps_masked_position
            )
            # gepc inter
            loss_geps_inter = criterion(
                output_dict["impute_pred"], geps_target, geps_masked_position
            )

            # Summed out of place so loss_mse keeps its own value for logging.
            loss = loss_mse + loss_geps_intra + loss_geps_inter

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            1.0,
            error_if_nonfinite=False
        )
        if torch.isfinite(total_norm):
            scaler.step(optimizer)
        else:
            # Clipping cannot repair NaN/Inf gradients, and a disabled scaler
            # does not skip the step itself, so skip it to keep weights finite.
            scgpt_spatial.logger.warning(
                f"grad total_norm is {total_norm} at step {batch}; "
                "skipping optimizer step."
            )
        scaler.update()

        # logging
        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_geps_intra += loss_geps_intra.item()
        total_geps_inter += loss_geps_inter.item()
        # (batch + 1) so every window covers exactly log_interval batches.
        if (batch + 1) % log_interval == 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            scgpt_spatial.logger.info(
                f"| epoch {epoch:3d} | {batch + 1:3d}/{num_batches:3d} batches | "
                f"lr {lr:.2e} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {total_loss / log_interval:5.2f} | "
                f"mse {total_mse / log_interval:5.2f} | "
                f"gepc_intra {total_geps_intra / log_interval:5.2f} | "
                f"gepc_inter {total_geps_inter / log_interval:5.2f} |"
            )
            total_loss = total_mse = total_geps_intra = total_geps_inter = 0.0
            start_time = time.time()

In [13]:
from scgpt.loss import masked_relative_error

def evaluate(model: nn.Module, loader: DataLoader, criterion, device, vocab,
             config: "Config"):
    """
    Evaluate the model on the evaluation data.

    Computes the same three-term loss as training (MSE + GEPC-intra +
    GEPC-inter) and the masked relative error of ``gen_preds``, with no
    gradient tracking. The model's train/eval mode is restored on exit.

    Args:
        model: scGPT-spatial model whose output dict contains "gen_preds",
            "mvc_output" and "impute_pred".
        loader: iterable of collated batch dicts (same keys as training).
        criterion: masked loss, called as ``criterion(pred, target, mask)``.
        device: device the batch tensors are moved to.
        vocab: token -> id mapping; used to look up ``config.pad_token``.
        config: reads ``pad_token``.

    Returns:
        (mean loss, mean masked relative error). Each per-batch value is
        weighted by the batch's cell count and averaged over all cells.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    pad_id = vocab[config.pad_token]
    autocast_device = torch.device(device).type
    total_loss = 0.0
    total_error = 0.0
    total_num = 0
    try:
        with torch.no_grad():
            for batch_data in loader:
                pcpt_gene = batch_data["pcpt_gene"].to(device)
                pcpt_expr = batch_data["pcpt_expr"].to(device)
                pcpt_key_padding_mask = pcpt_gene.eq(pad_id)
                gen_gene = batch_data["gen_gene"].to(device)
                gen_expr_target = batch_data["gen_expr_target"].to(device)
                gen_key_padding_mask = gen_gene.eq(pad_id)
                coordinates = batch_data["coordinates"].to(device)
                batch_labels = batch_data["batch_labels"].to(device)

                with torch.autocast(device_type=autocast_device, enabled=False):
                    output_dict = model(
                        pcpt_genes=pcpt_gene,
                        pcpt_values=pcpt_expr,
                        pcpt_key_padding_mask=pcpt_key_padding_mask,
                        gen_genes=gen_gene,
                        gen_key_padding_mask=gen_key_padding_mask,
                        batch_labels=batch_labels,
                        coordinates=coordinates,
                        CLS=False,
                        MVC=True,
                        ECS=False,
                        MVC_impute=True,
                        do_sample=False,
                        input_cell_emb=None,
                        generative_training=True
                    )  # dict: key: (batch, embsize)

                masked_positions = gen_gene.ne(pad_id)
                loss_mse = criterion(
                    output_dict["gen_preds"], gen_expr_target, masked_positions
                )
                mre = masked_relative_error(
                    output_dict["gen_preds"], gen_expr_target, masked_positions
                )

                # GEPC targets span [pcpt | gen]; only gen positions are scored.
                geps_target = torch.cat([pcpt_expr, gen_expr_target], dim=1)
                pcpt_mask = torch.zeros_like(pcpt_expr, dtype=torch.bool)
                geps_masked_position = torch.cat(
                    [pcpt_mask, masked_positions], dim=1)
                # gepc intra
                loss_geps_intra = criterion(
                    output_dict["mvc_output"], geps_target, geps_masked_position
                )
                # gepc inter
                loss_geps_inter = criterion(
                    output_dict["impute_pred"], geps_target, geps_masked_position
                )
                loss = loss_mse + loss_geps_intra + loss_geps_inter

                # The criteria return per-batch means; weight them by batch size
                # so the final division by total_num yields a per-cell average.
                n_cells = len(pcpt_gene)
                total_loss += loss.item() * n_cells
                total_error += mre.item() * n_cells
                total_num += n_cells
    finally:
        model.train(was_training)

    return total_loss / total_num, total_error / total_num

## Unsupervised Finetune

In [14]:
# Fine-tuning configuration: Visium slice, starting from the v1 checkpoint.
config = Config(
    # data and checkpoint locations
    h5ad='data/1_visium.h5ad',
    ckpt_dir='checkpoints/scGPT_spatial_v1',
    finetuned_ckpt_dir='finetuned_checkpoints/0919_epoches_5',
    # AnnData column / key names
    gene_names_col='gene_name',
    batch_key_col='protocol',
    coordinates_x_col='array_row',
    coordinates_y_col='array_col',
    rep_key='X_scGPT_finetuned',
    # schedule; prepare_dataloader packs batch_size // (n_neighbors + 1)
    # KNN patches into each batch
    n_neighbors=7,
    batch_size=16,
    epochs=5,
)

In [15]:
import warnings
warnings.filterwarnings('ignore')
from scgpt.loss import masked_mse_loss
import copy
import time
import logging

# Fine-tune the scGPT-spatial foundation model (MSE + GEPC-intra + GEPC-inter)
# and save the weights with the lowest validation loss.
torch.manual_seed(config.seed)
np.random.seed(config.seed)
scgpt_spatial.logger.setLevel(logging.INFO)

finetuned_ckpt_dir = Path(config.finetuned_ckpt_dir)
finetuned_ckpt_dir.mkdir(parents=True, exist_ok=True)

# Step 1: load pre-trained model and vocab
model, vocab, _ = load_backbone_and_vocab(Path(config.ckpt_dir), config.device)

# Step 2: load and split data. Will do binning in Dataset.
gene_ids, train_data, valid_data, train_batch_labels, valid_batch_labels, train_xy, valid_xy = split_dataset(
    Path(config.h5ad), config, vocab)

# Step 3: finetune scGPT-spatial with MVC and impute MVC
criterion = masked_mse_loss
optimizer = torch.optim.Adam(
    model.parameters(), lr=config.lr, eps=1e-8
)
gamma = config.schedule_ratio
if isinstance(gamma, (tuple, list)):
    gamma = float(gamma[0])
# step_size=1 and one scheduler.step() per epoch: lr is multiplied by gamma
# after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=gamma)
scaler = torch.cuda.amp.GradScaler(enabled=False)

# Build the loaders once. The KNN patch sampler is built from train_xy /
# valid_xy, which do not change between epochs, so rebuilding it every epoch
# only repeated that work.
train_loader = prepare_dataloader(
    count_matrix=train_data,
    gene_ids=gene_ids,
    batch_labels=train_batch_labels,
    coordinates=train_xy,
    vocab=vocab,
    config=config,
)
valid_loader = prepare_dataloader(
    count_matrix=valid_data,
    gene_ids=gene_ids,
    batch_labels=valid_batch_labels,
    coordinates=valid_xy,
    vocab=vocab,
    config=config,
)

best_val_loss = float("inf")
best_model = None
best_model_epoch = None
for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()

    train_epoch(
        epoch,
        model,
        loader=train_loader,
        scheduler=scheduler,
        optimizer=optimizer,
        scaler=scaler,
        criterion=criterion,
        device=config.device,
        vocab=vocab,
        config=config
    )
    val_loss, val_mre = evaluate(
        model,
        loader=valid_loader,
        criterion=criterion,
        device=config.device,
        vocab=vocab,
        config=config
    )
    elapsed = time.time() - epoch_start_time
    scgpt_spatial.logger.info("-" * 89)
    scgpt_spatial.logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
    )
    scgpt_spatial.logger.info("-" * 89)

    # Snapshot only on improvement. Previously these lines sat outside the
    # `if`, so the "best" model was always the last epoch.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        scgpt_spatial.logger.info(
            f"Best model with score {best_val_loss:5.4f}")

    scheduler.step()

if best_model is None:
    # No epoch beat float("inf") (e.g. every validation loss was NaN);
    # save the final weights rather than crash on best_model.state_dict().
    scgpt_spatial.logger.warning(
        "No epoch produced a finite validation loss; saving the final model.")
    best_model = model
    best_model_epoch = config.epochs

scgpt_spatial.logger.info(f"Saving model to {config.finetuned_ckpt_dir}")
torch.save(best_model.state_dict(), finetuned_ckpt_dir / "best_model.pt")


INFO:scGPT-spatial:Loading parameter encoder.embedding.weight with shape torch.Size([60697, 512])
INFO:scGPT-spatial:Loading parameter encoder.enc_norm.weight with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter encoder.enc_norm.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter flag_encoder.weight with shape torch.Size([2, 512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear1.weight with shape torch.Size([512, 1])
INFO:scGPT-spatial:Loading parameter value_encoder.linear1.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear2.weight with shape torch.Size([512, 512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear2.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.norm.weight with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.norm.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter batch_encoder.embedding.weight w

## Example: Zero-shot Domain Detection with the Fine-tuned Model

In [16]:
# Load the Visium slice, drop spots without a ground-truth label, and add the
# columns the fine-tuning / embedding cells use ('gene_name', integer 'protocol').
import numpy as np
import pandas as pd
import scanpy as sc


adata = sc.read_h5ad('data/1_visium.h5ad')
# Boolean indexing returns a view. .copy() materializes it so the .var/.obs
# assignments below modify a real AnnData instead of an implicit view copy.
adata = adata[adata.obs['ground_truth'].notna()].copy()
adata.var['gene_name'] = adata.var.index

adata.obs['protocol'] = 'visium'
PROTO2ID = {
  "merfish": 0,
  "visium": 1,
  "visium_hd": 2,
  "xenium": 3
}   # keep aligned with scGPT-spatial v1 checkpoints
adata.obs['protocol'] = adata.obs['protocol'].map(PROTO2ID).astype(int).to_numpy()

print(adata)
print(adata.obs.ground_truth.unique())
# NOTE: this overwrites the input file in place; the fine-tuning cell reads
# the same path and expects the integer 'protocol' column written here.
adata.write('data/1_visium.h5ad')

AnnData object with n_obs × n_vars = 4221 × 33538
    obs: 'in_tissue', 'array_row', 'array_col', 'Region', 'ground_truth', 'protocol'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_name'
    uns: 'spatial'
    obsm: 'spatial'
['Layer1', 'Layer3', 'WM', 'Layer6', 'Layer5', 'Layer2', 'Layer4']
Categories (7, object): ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']


In [17]:
import warnings
warnings.filterwarnings('ignore')
import scgpt_spatial

# Checkpoint locations. Only the fine-tuned directory is passed to
# embed_data below; model_dir records the pre-trained backbone path.
model_dir = 'checkpoints/scGPT_spatial_v1'
finetuned_model_dir = 'finetuned_checkpoints/0919_epoches_5'

# Run scGPT-spatial inference with the fine-tuned weights. With
# return_new_adata=False the embeddings are presumably stored on the given
# object; the printed result shows an 'X_scGPT' entry in obsm.
adata = scgpt_spatial.tasks.embed_data(
    adata,
    finetuned_model_dir,
    batch_size=64,
    gene_col='gene_name',
    obs_to_save=['array_row', 'array_col', 'ground_truth'],
    use_fast_transformer=False,
    return_new_adata=False,
)
print(adata)

INFO:scGPT-spatial:match 23325/33538 genes in vocabulary of size 60697.
INFO:scGPT-spatial:Loading parameter encoder.embedding.weight with shape torch.Size([60697, 512])
INFO:scGPT-spatial:Loading parameter encoder.enc_norm.weight with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter encoder.enc_norm.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter flag_encoder.weight with shape torch.Size([2, 512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear1.weight with shape torch.Size([512, 1])
INFO:scGPT-spatial:Loading parameter value_encoder.linear1.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear2.weight with shape torch.Size([512, 512])
INFO:scGPT-spatial:Loading parameter value_encoder.linear2.bias with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.norm.weight with shape torch.Size([512])
INFO:scGPT-spatial:Loading parameter value_encoder.norm.bias with shape torch.Size([512

AnnData object with n_obs × n_vars = 4221 × 23325
    obs: 'in_tissue', 'array_row', 'array_col', 'Region', 'ground_truth', 'protocol'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_name', 'id_in_vocab'
    uns: 'spatial'
    obsm: 'spatial', 'X_scGPT'





In [18]:
import scanpy as sc
from sklearn.metrics import silhouette_score

def predict_domain_using_embedding(h5ad: str,
                                   dom_key: str = "domain",  # output label for clustering result
                                   method: str = "leiden", # "leiden" or "louvain"
                                   rep_key: str = "X_scGPT",  # adata.obsm[rep_key]: embedding
                                   target_clusters: int = 6,
                                   n_neighbors: int = 15,
                                   resolution_grid: Iterable[float] = (0.3, 0.5,
                                                                       0.8, 1.0,
                                                                       1.2),
                                   return_silhouette: bool = True,
                                   ):
    """
    Cluster spots into spatial domains from a precomputed embedding.

    Args:
      h5ad: the AnnData to cluster (the notebook passes the object itself), or
        a path to an .h5ad file, which is then read with ``sc.read_h5ad``.
        Results written to a file-loaded object are not saved back to disk.
      dom_key: obs column that receives the final domain labels.
      method: "leiden" or "louvain".
      rep_key: obsm key of the embedding used for the neighbor graph and the
        silhouette scores.
      target_clusters: if not None, pick the resolution whose cluster count
        is closest to this value (ties go to the later resolution in the
        grid). If None, pick the resolution with the highest silhouette.
      n_neighbors: number of neighbors for ``sc.pp.neighbors``.
      resolution_grid: resolutions to try, in order.
      return_silhouette: whether to return the silhouette score.
    Returns:
      labels: np.ndarray[int] — integer domain code per spot.
      best_res: Optional[float] — selected resolution (None if the fallback
        to the first grid resolution was used).
      best_score: Optional[float] — silhouette of the selected labels (-1.0 if
        it could not be computed), or None when return_silhouette is False.
    Side Effects (on the AnnData):
      - sc.pp.neighbors results (uns['neighbors'], obsp distances/connectivities)
      - obs[f"{dom_key}_tmp"]: labels from the LAST resolution tried
      - obs[dom_key]: selected labels as a categorical
    Depends:
      - adata.obsm[rep_key]
    Raises:
      ValueError: unknown ``method`` or empty ``resolution_grid``.
    """
    # Validate cheaply before building the neighbor graph.
    if method not in ("leiden", "louvain"):
        raise ValueError("method must be 'leiden' or 'louvain'")
    # Materialize once: a generator would be exhausted by the search loop and
    # could not be indexed for the fallback below.
    resolution_grid = tuple(resolution_grid)
    if not resolution_grid:
        raise ValueError("resolution_grid must contain at least one resolution")

    # Operate on the argument, not on a notebook-global `adata`.
    adata = sc.read_h5ad(h5ad) if isinstance(h5ad, (str, os.PathLike)) else h5ad

    sc.pp.neighbors(adata, use_rep=rep_key, n_neighbors=n_neighbors)
    cluster_fn = sc.tl.leiden if method == "leiden" else sc.tl.louvain

    def _cluster_at(res):
        """Cluster at resolution `res`; return integer category codes."""
        cluster_fn(adata, resolution=res, key_added=f"{dom_key}_tmp")
        return adata.obs[f"{dom_key}_tmp"].astype(
            "category").cat.codes.to_numpy()

    def _silhouette(labels):
        """Silhouette of `labels` on the embedding, or -1.0 if undefined."""
        if len(np.unique(labels)) < 2:
            return -1.0
        try:
            return silhouette_score(adata.obsm[rep_key], labels)
        except ValueError:
            return -1.0

    best_res, best_score, best_labels, best_gap = None, -1.0, None, None
    for res in resolution_grid:
        labels = _cluster_at(res)
        score = _silhouette(labels)
        if target_clusters is None:
            better = score > best_score
        else:
            gap = abs(len(np.unique(labels)) - target_clusters)
            # `<=`: on equal gaps the later (higher) resolution wins.
            better = best_gap is None or gap <= best_gap
            if better:
                best_gap = gap
        if better:
            best_res, best_score, best_labels = res, score, labels

    labels = best_labels if best_labels is not None else _cluster_at(
        resolution_grid[0])
    adata.obs[dom_key] = labels
    adata.obs[dom_key] = adata.obs[dom_key].astype("category")
    return labels, best_res, (best_score if return_silhouette else None)

In [19]:
# Domain-detection settings (only the fields read below are set).
config = Config(
    h5ad='data/1_visium.h5ad',
    method="leiden",
    rep_key="X_scGPT",
    dom_key="domain_scgpt",
    n_neighbors=7,
    target_clusters=6,
)

# Leiden clustering on the scGPT-spatial embeddings in
# adata.obsm[config.rep_key]; labels also go to adata.obs[config.dom_key].
labels, _, _ = predict_domain_using_embedding(
    adata,
    dom_key=config.dom_key,
    method=config.method,
    rep_key=config.rep_key,
    target_clusters=config.target_clusters,
    n_neighbors=config.n_neighbors,
)

# Persist the annotated AnnData (only an .h5ad is written here, no CSV).
adata.write("data/1_visium_domain_detection.h5ad")

print(adata)
print(adata.obs.domain_scgpt)

AnnData object with n_obs × n_vars = 4221 × 23325
    obs: 'in_tissue', 'array_row', 'array_col', 'Region', 'ground_truth', 'protocol', 'domain_scgpt_tmp', 'domain_scgpt'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_name', 'id_in_vocab'
    uns: 'spatial', 'neighbors', 'domain_scgpt_tmp'
    obsm: 'spatial', 'X_scGPT'
    obsp: 'distances', 'connectivities'
AAACAACGAATAGTTC-1    0
AAACAAGTATCTCCCA-1    2
AAACAATCTACTAGCA-1    1
AAACACCAATAACTGC-1    3
AAACAGCTTTCAGAAG-1    4
                     ..
TTGTTGTGTGTCAAGA-1    1
TTGTTTCACATCCAGG-1    4
TTGTTTCATTAGTCTA-1    0
TTGTTTCCATACAACT-1    3
TTGTTTGTGTAAATTC-1    4
Name: domain_scgpt, Length: 4221, dtype: category
Categories (6, int8): [0, 1, 2, 3, 4, 5]
