# Inference with sCellTransformer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/examples/inference_sCellTransformer.ipynb)

## Installation and imports

In [None]:
!pip install scikit-learn
!pip install torch

In [2]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [3]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.metrics import matthews_corrcoef
from nucleotide_transformer.sCellTransformer.model import build_long_range_nt_fn
from nucleotide_transformer.sCellTransformer.params import download_ckpt
from nucleotide_transformer.sCellTransformer.get_paper_dataset import get_dataset_dataloader

# Select the backend once, then use it both as JAX's default platform and to
# enumerate the devices that `pmap` will run on.
# NOTE(review): forcing "cpu" overrides the optional Colab TPU setup performed
# earlier in the notebook — confirm that is intended.
backend = "cpu"
jax.config.update("jax_platform_name", backend)

devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

  from .autonotebook import tqdm as notebook_tqdm


Devices found: [CpuDevice(id=0)]


In [4]:
# Download the pretrained sCellTransformer checkpoint: the model parameters and
# the configuration used to build the model.
parameters, config = download_ckpt()
# Build the forward function from the config (it is Haiku-transformed later).
forward_fn = build_long_range_nt_fn(config)

Downloading model's weights...


In [5]:
# Turn the Haiku forward function into a pure (init, apply) pair and parallelise
# `apply` over the devices: every input to `apply_fn` carries a leading device axis.
# The parameters are deliberately NOT donated (no `donate_argnums`): the same
# replicated `parameters` are passed to every later `apply_fn` call, and a donated
# buffer is invalidated after the first call on backends that support donation.
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices)

# Replicate the quantities needed for inference on every device. The replicated
# `parameters` are reused by all later inference cells.
# NOTE: this cell is not idempotent — re-running it would transform `forward_fn`
# and replicate `parameters` a second time.
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

## Data

Create a dummy batch of random gene-expression values to run the model on.

In [6]:
# Dummy input: random integer expression values in [0, 5) for `num_cells` cells,
# laid out as one flat token sequence per sample, with a leading device axis
# for the pmapped `apply_fn`.
num_cells = config.num_cells
genes_per_cell = 19968  # presumably the model's number of genes per cell — TODO read from config if exposed
rng = np.random.default_rng(seed=0)  # seeded so the dummy batch is reproducible
dummy_gene_expressions = rng.integers(0, 5, size=(num_devices, 1, genes_per_cell * num_cells))

## Infer on a batch

In [None]:
# Forward pass on the dummy batch; the outputs keep the leading device axis.
outs = apply_fn(parameters, keys, dummy_gene_expressions)

# Per-token class logits. Take device 0, restrict the last (class) axis to its
# first 5 entries, normalise them with a softmax, and keep the probability of
# the last of those 5 classes.
# NOTE(review): what that class represents is not visible here — confirm intent.
logits = outs["logits"]
probabilities = np.asarray(
    jax.nn.softmax(logits[0][..., :5], axis=-1)
)[..., -1]

In [8]:
# One probability per input token.
np.shape(probabilities)

(1, 998400)

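## Replicate an example from the paper

Randomly mask about 15% of the input tokens of a batch from the paper's dataset, let the model predict the masked values, and score the predictions with the Matthews correlation coefficient (MCC).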

In [9]:
# Evaluation settings: fraction of tokens to hide, and how many batches to score.
mask_ratio = 0.15
num_batches = 1

# Paper dataset, one sample per batch (the evaluation cell adds a single leading
# device axis to each batch before calling the pmapped model).
dataloader = get_dataset_dataloader(config, batch_size=1)

In [10]:
# Masked-token evaluation: hide a random `mask_ratio` fraction of the input
# tokens, let the model predict them, and score the predictions with the
# Matthews correlation coefficient (MCC).
all_true_values = []
all_pred_values = []

# The replicated model keys are identical for every batch, so build them once.
keys = jax.device_put_replicated(jax.random.PRNGKey(seed=0), devices=devices)

iterator = iter(dataloader)

for batch_idx in range(num_batches):
    # Only the dataloader signals exhaustion; don't let a StopIteration raised
    # elsewhere in the loop body be mistaken for it.
    try:
        batch = next(iterator)
    except StopIteration:
        print(f"DataLoader exhausted after {batch_idx} batches")
        break

    gene_expressions = jnp.array(batch["gene_expressions"])

    # Seed the mask with the batch index so each batch gets its own mask
    # (batch 0 uses PRNGKey(0), which reproduces the reported result).
    mask_key = jax.random.PRNGKey(batch_idx)
    mask = np.asarray(jax.random.uniform(mask_key, shape=gene_expressions.shape) < mask_ratio)

    # Replace the selected tokens by the mask token; keep the originals as targets.
    masked_gene_expressions = jnp.where(mask, config.mask_token_id, gene_expressions)
    true_values = np.asarray(gene_expressions)[mask]

    # `apply_fn` is pmapped, so add the leading device axis.
    # NOTE(review): a single-sample batch fits on one device only; shard the
    # batch across devices if num_devices > 1.
    outs = apply_fn(parameters, keys, masked_gene_expressions[None])
    logits = np.asarray(outs["logits"][0])

    # The predicted value is the most likely class at each masked position.
    predictions = np.argmax(logits, axis=-1)
    pred_values = predictions[mask]

    all_true_values.append(true_values)
    all_pred_values.append(pred_values)

    print(f"Processed batch {batch_idx + 1}/{num_batches}")

if not all_true_values:
    raise RuntimeError("No batches were evaluated; cannot compute MCC.")

# Concatenate all batches
all_true_values = np.concatenate(all_true_values)
all_pred_values = np.concatenate(all_pred_values)

# Compute Matthews Correlation Coefficient
mcc = matthews_corrcoef(all_true_values, all_pred_values)
print(f"Overall MCC: {mcc:.4f}")

Processed batch 1/1
Overall MCC: 0.2307
