# Inference with sCellTransformer - Jax version

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/sct/inference_sCT_jax_example.ipynb)

## Installation and imports

In [None]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    !pip install anndata
    !pip install cellxgene_census
    !pip install scanpy
    !pip install jax
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import matthews_corrcoef
import os
import json
import anndata as ad
import scanpy as sc
import numpy as np
import math
import itertools
from scipy.sparse import issparse
from typing import Any
import cellxgene_census

from nucleotide_transformer.sCellTransformer.model import build_sct_fn
from nucleotide_transformer.sCellTransformer.params import download_ckpt

# Make "cpu" the default JAX platform. The devices actually used for inference
# are selected explicitly through `backend` in the following cells.
jax.config.update("jax_platform_name", "cpu")

Devices found: [CpuDevice(id=0)]


# Specify your backend device

In [None]:
# Use either "cpu", "gpu" or "tpu".
# This name is passed to `jax.devices(...)` in the next cell to pick the devices.
backend = "cpu"

In [None]:
# List the devices available for the selected backend and count them.
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# Load model and infer

In [3]:
# Download the checkpoint: returns the model parameters and its configuration.
parameters, config = download_ckpt()
# Build the forward function from the configuration ...
forward_fn = build_sct_fn(config, name="long_range_nt")
# ... and wrap it with Haiku; `forward_fn.apply` is what gets called below.
# NOTE: `forward_fn` is rebound here, so this cell is not safe to re-run on its
# own without re-running the line above.
forward_fn = hk.transform(forward_fn)

Downloading model's weights...


In [4]:
# 2. Create simple input (no splitting)
dummy_batch_size = 1
# Presumably 19968 gene positions per cell x 50 cells per sample (consistent
# with the `embedding` output shape printed below) - TODO confirm against config.
dummy_sequence_length = 19968 * 50 
# All-zero token ids of shape (batch, sequence).
dummy_tokens = np.zeros((dummy_batch_size, dummy_sequence_length), dtype=np.int32)
# Copy the same batch onto every device in `devices` for use with pmap.
dummy_tokens = jax.device_put_replicated(dummy_tokens, devices=devices)

In [5]:
# 3. Setup devices and keys
# Restrict to the local devices of the backend chosen above. Calling
# `jax.local_devices()` without a backend would return the devices of the
# default platform, silently overriding the user's `backend` choice and
# mismatching the devices `dummy_tokens` was replicated onto.
devices = jax.local_devices(backend=backend)
num_devices = len(devices)
# One PRNG key per device, all derived from a fixed seed for reproducibility.
master_key = jax.random.PRNGKey(seed=0)
keys_batch = jax.random.split(master_key, num_devices)

In [6]:
# 4. Create the pmapped forward function.
# in_axes, one entry per positional argument of `forward_fn.apply`:
#   params -> None: the same parameters are broadcast to every device
#   keys   -> 0:    one PRNG key per device (leading axis = device)
#   data   -> 0:    one copy of the batch per device (leading axis = device);
#                   the input is replicated beforehand, not split
apply_fn = jax.pmap(forward_fn.apply, in_axes=(None, 0, 0), devices=devices)

In [7]:
# 5. Run the forward pass
print("Running forward pass...")
try:
    outs = apply_fn(parameters, keys_batch, dummy_tokens)
except Exception as e:
    # Report the failure, then re-raise so the full traceback stays visible
    # and the notebook does not silently continue after a failed forward pass.
    print(f"✗ Error: {e}")
    raise
print("✓ SUCCESS!")

if isinstance(outs, dict):
    # Summarise every array-like output by its shape.
    print("Output keys:")
    for key, value in outs.items():
        if hasattr(value, 'shape'):
            print(f"  {key}: {value.shape}")

    if "logits" in outs:
        logits = outs["logits"]
        print(f"Logits shape: {logits.shape}")
        # Take the first device's output and the arg-max over the first 5
        # logit channels (the same slice the evaluation function uses below).
        predictions = np.asarray(np.argmax(logits[0,:,:,:5], axis=-1))
        print("Done!")

Running forward pass...
✓ SUCCESS!
Output keys:
  conv_out: (1, 1, 3900, 1024)
  deconv_out: (1, 1, 64, 998400)
  embedding: (1, 1, 50, 19968, 64)
  logits: (1, 1, 998400, 7)
  transformer_out: (1, 1, 3900, 1024)
Logits shape: (1, 1, 998400, 7)
Done!


# Replicate example from the paper

## Load the h5ad file

In [8]:
# Download the source h5ad file through the public cellxgene census API.
# The file contains single-cell RNA-seq data and corresponds to
# Sst Chodl - MTG: Seattle Alzheimer's Disease Atlas (SEA-AD)
# - Single-cell RNA-seq data: cells x genes expression matrix
# - Sparse data (~90% zeros), typically 16k cells, ~30k genes
# - Contains cell type annotations (neurons, astrocytes, etc.) and metadata
# - Real biological data vs synthetic test data - will need preprocessing for model
# The dataset id is fetched from the pinned census version 2023-12-15 and
# written to "brain.h5ad" in the current working directory.
cellxgene_census.download_source_h5ad(
    "81e91ff8-f619-4ad1-a0c3-b45e1dc63f68",
    to_path="brain.h5ad",
    census_version="2023-12-15"
)

Downloading: 100%|██████████| 42.8M/42.8M [00:53<00:00, 842kB/s] 


## Load dataset

In [9]:
# Both JSON files are resolved relative to the current working directory,
# so the notebook must be run from a directory that contains `data/`.
current_dir = os.getcwd()

# Mapping from ENSEMBL gene id to the gene's index in the dataset vocabulary.
with open(os.path.join(current_dir, "data/ensembl_id_vocab.json"), "r") as f:
    ENSEMBL_ID_VOCAB = json.load(f)

# Mapping, for the considered coding genes only, from the global vocabulary
# index (stored as a string key) to the index among coding genes.
# This restricts the gene set from ~60k genes to ~20k genes.
with open(os.path.join(current_dir, "data/protein_gene_map.json"), "r") as f:
    PROTEIN_GENE_MAP = json.load(f)

## Define dataloader functions

In [10]:
def define_mapping_between_adata_and_model(adata: "ad.AnnData",
                                           ENSEMBL_ID_VOCAB: dict,
                                           PROTEIN_GENE_MAP: dict) -> np.ndarray:
    """
    Build a lookup array from gene column index in `adata` to model gene index.

    Args:
        adata: Object whose `adata.var.feature_name.keys()` yields one gene
            identifier per column, in column order.
        ENSEMBL_ID_VOCAB: Maps a gene identifier to its global vocabulary index.
        PROTEIN_GENE_MAP: Maps the global vocabulary index (as a string key)
            to the gene's index among the considered coding genes.

    Returns:
        An int32 array `m` where `m[i]` is the coding-gene index of column `i`
        of `adata`, or -1 when that column has no entry in both mappings.
        Its length is at least 70000 (the historical fixed size) and grows to
        the number of columns when `adata` has more than 70000 of them.
    """
    names = list(adata.var.feature_name.keys())

    # Size the array so every column index is addressable: a fixed 70000 would
    # raise IndexError for datasets with more columns than that.
    new_gene_map_array = np.full(max(70000, len(names)), -1, dtype=np.int32)

    for i, name in enumerate(names):
        if name not in ENSEMBL_ID_VOCAB:
            continue
        # PROTEIN_GENE_MAP comes from JSON, so its keys are strings.
        index = str(ENSEMBL_ID_VOCAB[name])
        if index in PROTEIN_GENE_MAP:
            new_gene_map_array[i] = PROTEIN_GENE_MAP[index]

    return new_gene_map_array


# Read the h5ad file downloaded above.
# Note that this data download already includes a log normalization
# on the gene expression levels.
adata = sc.read_h5ad('brain.h5ad')
# Create the mapping between column indexes in the downloaded dataset and
# the index in the model for the considered coding genes
# (-1 marks columns that are not part of the model's gene set).
new_gene_map_array = define_mapping_between_adata_and_model(
    adata=adata,
    ENSEMBL_ID_VOCAB=ENSEMBL_ID_VOCAB,
    PROTEIN_GENE_MAP=PROTEIN_GENE_MAP
)

In [11]:
def get_h5ad_scrna_dataset(
        adata: Any,
        new_gene_map_array: np.ndarray,
        num_downsamples: int,
        cell_len: int,
        num_cells: int,
        pad_token_id: int,
        gene_expression_num_bins: int,
        batch_size: int,
) -> Any:
    """
    Creates an infinite iterable of batches from an in-memory expression matrix.

    Cells (rows of ``adata.X``) are visited in row order and cycled endlessly.
    A cell is skipped when it has no strictly positive expression value or when
    none of its expressed genes has a valid entry in ``new_gene_map_array``.
    Each kept cell is binned per cell (bin edges span that cell's own min/max),
    ``num_cells`` cells are concatenated into one sample and ``batch_size``
    samples are stacked into one batch.

    Args:
        adata: Object exposing the cells x genes expression matrix as
            ``adata.X`` (dense array or scipy sparse matrix).
        new_gene_map_array: Array indexed by column of ``adata.X`` giving the
            position of that gene in the per-cell output vector, or -1 when
            the gene must be ignored.
        num_downsamples: The per-cell length is padded up to the next multiple
            of ``2 ** num_downsamples``.
        cell_len: Number of gene positions per cell before padding.
        num_cells: Number of cells concatenated into one sample.
        pad_token_id: Value written in the padded positions (used for both the
            binned and the raw expression arrays).
        gene_expression_num_bins: Number of bin edges; expressed genes receive
            tokens ``1 .. gene_expression_num_bins - 1`` and unexpressed or
            unmapped positions stay at 0.
        batch_size: Number of samples per yielded batch.

    Returns:
        An iterable whose iterator yields dicts with keys
        ``"gene_expressions"`` (int32, ``(batch_size, num_cells * seq_length)``),
        ``"raw_gene_expressions"`` (float32, same shape), ``"bins"`` (the
        concatenated per-cell bin edges, stacked over the batch) and
        ``"source"`` (nested list of ``"h5ad"`` strings). ``len()`` of the
        returned object is ``cell_len``.

    Raises:
        ValueError: On iteration, when the matrix has no rows or when a full
            pass over all cells yields no usable cell (previously this case
            looped forever).

    Note:
        A cell whose valid expressions are all equal contributes 2 bin edges
        instead of ``gene_expression_num_bins``; with ``batch_size > 1`` the
        ``"bins"`` arrays may then differ in length and fail to stack.
    """

    # Extract expression matrix (usually X is sparse)
    expr_matrix = adata.X
    if issparse(expr_matrix):
        # Convert to CSR for efficient row access
        expr_matrix = expr_matrix.tocsr()

    # Calculate sequence length with downsampling
    downsample_factor = 2 ** num_downsamples
    seq_length = math.ceil(cell_len / downsample_factor) * downsample_factor

    class H5adIterableDataset:
        def __init__(self):
            self.length = cell_len
            self.batch_size = batch_size
            self.total_cells = expr_matrix.shape[0]

        def __len__(self):
            return self.length

        def _encode_cell(self, cell_idx):
            """Encode one row into a cell dict, or return None if unusable."""
            if issparse(expr_matrix):
                # For sparse matrix, get the row as a dense array
                cell_expr = expr_matrix[cell_idx, :].toarray().flatten()
            else:
                cell_expr = expr_matrix[cell_idx, :]

            # Keep only strictly positive expressions
            non_zero_mask = cell_expr > 0
            gene_idxs = np.where(non_zero_mask)[0].astype(np.int32)
            if len(gene_idxs) == 0:
                return None
            expressions = cell_expr[non_zero_mask].astype(np.float32)

            # Map dataset columns to model positions and drop unmapped genes
            mapped_idxs = new_gene_map_array[gene_idxs]
            valid_mask = mapped_idxs != -1
            positions = mapped_idxs[valid_mask]
            valid_expr = expressions[valid_mask]
            if len(valid_expr) == 0:
                return None

            # Bin expressions between this cell's own min and max. Compute the
            # extrema once with ndarray methods rather than the Python builtins.
            lowest = valid_expr.min()
            highest = valid_expr.max()
            if lowest == highest:
                # Degenerate range: every expressed gene goes to bin 1
                bin_edges = np.array([lowest - 0.1, highest + 0.1])
                binned = np.ones_like(valid_expr, dtype=np.int32)
            else:
                bin_edges = np.linspace(lowest, highest, gene_expression_num_bins)
                # Nudge the last edge so the max value falls inside the last bin
                bin_edges[-1] += 0.01
                binned = np.digitize(valid_expr, bin_edges)

            # Scatter into fixed-length per-cell vectors (0 = not expressed)
            full_expr = np.zeros(self.length, dtype=np.int32)
            full_expr[positions] = binned

            raw_expr = np.zeros(self.length, dtype=np.float32)
            raw_expr[positions] = valid_expr

            cell = {
                "gene_expressions": full_expr,
                "raw_gene_expressions": raw_expr,
                "bins": bin_edges,
                "source": ["h5ad"],
            }

            # Pad up to the downsampling-compatible length if needed
            if seq_length > self.length:
                for key in ("gene_expressions", "raw_gene_expressions"):
                    padded = np.full(seq_length, pad_token_id,
                                     dtype=cell[key].dtype)
                    padded[:len(cell[key])] = cell[key]
                    cell[key] = padded

            return cell

        def __iter__(self):
            if self.total_cells == 0:
                raise ValueError("The expression matrix contains no cells.")

            # Create infinite iterator over cell indices
            cell_indices = itertools.cycle(range(self.total_cells))
            # Consecutive skipped cells; reaching total_cells means a full pass
            # over the data produced nothing, so cycling further would hang.
            skipped_in_a_row = 0

            while True:
                batch = []

                # Generate batch_size samples
                for _ in range(self.batch_size):
                    cells = []

                    # Collect num_cells cells for one sample
                    while len(cells) < num_cells:
                        cell = self._encode_cell(next(cell_indices))
                        if cell is None:
                            skipped_in_a_row += 1
                            if skipped_in_a_row >= self.total_cells:
                                raise ValueError(
                                    "No cell has expressed genes that map to "
                                    "the model's gene set."
                                )
                            continue
                        skipped_in_a_row = 0
                        cells.append(cell)

                    # Merge cells for this sample: arrays are concatenated,
                    # lists are flattened
                    sample = {}
                    for key in cells[0]:
                        if isinstance(cells[0][key], np.ndarray):
                            sample[key] = np.concatenate([c[key] for c in cells])
                        else:
                            sample[key] = list(itertools.chain.from_iterable(
                                c[key] for c in cells
                            ))
                    batch.append(sample)

                # Stack samples along a new leading batch axis
                batch_result = {}
                for key in batch[0]:
                    if isinstance(batch[0][key], np.ndarray):
                        batch_result[key] = np.stack(
                            [b[key] for b in batch], axis=0)
                    else:
                        batch_result[key] = [b[key] for b in batch]

                yield batch_result

    return H5adIterableDataset()


# Create dataset
# Despite its name, `dataloader` is the custom iterable returned by
# `get_h5ad_scrna_dataset` (an endless batch generator), not a torch DataLoader.
dataloader = get_h5ad_scrna_dataset(
    adata=adata,
    new_gene_map_array=new_gene_map_array,
    num_downsamples=config.num_downsamples,
    # One position per considered coding gene.
    cell_len=len(PROTEIN_GENE_MAP),
    num_cells=config.num_cells,
    pad_token_id=config.pad_token_id,
    # Must stay in sync with the `[..., :5]` logit slice used at evaluation time.
    gene_expression_num_bins=5,
    batch_size=1,
)

## Define evaluation function

In [12]:
def evaluate_gene_expression_imputation(
        train_dataloader: DataLoader,
        num_batches: int|None = 10,
        mask_ratio: float = 0.15,
) -> float:
    """
    Evaluate gene expression imputation using the Matthews Correlation Coefficient.

    For each batch, a random subset of positions is replaced by
    ``config.mask_token_id``, the model predicts a bin for every position, and
    the predictions at the masked positions are compared with the original
    binned values. A single MCC is computed over the masked positions pooled
    from all evaluated batches (it is not an average of per-batch MCCs).

    Relies on the notebook globals ``config``, ``parameters``, ``apply_fn``
    and ``devices``.

    Args:
        train_dataloader: Iterable yielding dicts with a "gene_expressions" entry
        num_batches: Number of batches to evaluate. If None,
            ``len(train_dataloader)`` batches are evaluated
        mask_ratio: Ratio of tokens to mask for imputation

    Returns:
        float: MCC over all masked positions of the evaluated batches
            (also printed).

    Raises:
        ValueError: If no batch could be evaluated.
    """
    # Lists to store true and predicted values for masked tokens
    all_true_values = []
    all_pred_values = []

    iterator = iter(train_dataloader)

    if num_batches is None:
        num_batches = len(train_dataloader)

    for batch_idx in tqdm(range(num_batches)):
        # Only fetching the batch can exhaust the iterator, so keep the
        # try block limited to that call.
        try:
            batch = next(iterator)
        except StopIteration:
            print(f"DataLoader exhausted after {batch_idx} batches")
            break

        gene_expressions = jnp.array(batch["gene_expressions"])

        # Random mask. NOTE: the key is fixed, so the same positions are
        # masked in every batch (kept as-is to preserve reported numbers).
        mask = jax.random.uniform(jax.random.PRNGKey(0), shape=gene_expressions.shape) < mask_ratio
        masked_gene_expressions = jnp.where(mask, config.mask_token_id, gene_expressions)

        # Keep original values at masked positions for evaluation
        mask_np = np.asarray(mask)
        true_values = np.asarray(gene_expressions)[mask_np]

        # Replicate input and key over devices: adds the leading device axis
        # pmap expects. A bare expand_dims only works with a single device.
        masked_gene_expressions = jax.device_put_replicated(
            masked_gene_expressions, devices=devices)
        random_key = jax.random.PRNGKey(seed=0)
        keys = jax.device_put_replicated(random_key, devices=devices)

        # Forward pass
        outs = apply_fn(parameters, keys, masked_gene_expressions)
        logits = outs["logits"]

        # All devices hold the same input, so read the first device's output.
        # Arg-max over the first 5 logit channels (presumably the expression
        # bins 0..4 produced with gene_expression_num_bins=5 - TODO confirm).
        predictions = np.asarray(np.argmax(logits[0,:,:,:5], axis=-1))
        pred_values = predictions[mask_np]

        # Store true and predicted values
        all_true_values.append(true_values)
        all_pred_values.append(pred_values)

    if not all_true_values:
        raise ValueError("No batches were evaluated; cannot compute the MCC.")

    # Concatenate all batches
    all_true_values = np.concatenate(all_true_values)
    all_pred_values = np.concatenate(all_pred_values)

    # Compute Matthews Correlation Coefficient
    # - Ranges from -1 to +1; used here in its multiclass form (one class per bin)
    # - +1: perfect prediction, 0: random prediction, -1: total disagreement
    # - Balanced metric that works well with imbalanced datasets
    # - Takes the full confusion matrix into account

    mcc = matthews_corrcoef(all_true_values, all_pred_values)
    print(f"Overall MCC: {mcc:.4f}")
    return float(mcc)

In [13]:
# Evaluate the model
# Please define the number of batches and the mask ratio for evaluation
# The reported metric is a single MCC computed over the masked positions
# pooled from all evaluated batches, with one class per expression bin.
evaluate_gene_expression_imputation(
    dataloader,
    num_batches=1,
    mask_ratio=0.15,
)

100%|██████████| 1/1 [00:55<00:00, 55.30s/it]

Overall MCC: 0.3703



