# Inference with sCellTransformer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/examples/inference_sCellTransformer.ipynb)

## Installation and imports

In [1]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from nucleotide_transformer.sCellTransformer.model import build_long_range_nt_fn
from nucleotide_transformer.sCellTransformer.params import download_ckpt

# Inference runs on the CPU. `backend` is the single place where the platform
# is chosen: it drives both the JAX platform setting and the device lookup.
backend = "cpu"
jax.config.update("jax_platform_name", backend)

# Every device of that backend takes part in the `pmap` below.
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

Devices found: [CpuDevice(id=0)]


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Download the pretrained checkpoint. It provides the model parameters together
# with the configuration that the architecture has to be built from.
parameters, config = download_ckpt()

# Build the raw forward function from that configuration. It still has to be
# wrapped with `hk.transform` before it can be applied to `parameters`.
forward_fn = build_long_range_nt_fn(config)

Downloading model's weights...


In [4]:
# Wrap the forward function with Haiku, then parallelise `apply` over the devices.
# A new name is used instead of rebinding `forward_fn`, so re-running this cell
# never calls `hk.transform` on an already-transformed object.
# The parameters are deliberately not donated (no `donate_argnums`). For
# inference the outputs presumably cannot reuse the weight buffers, so donation
# only produced the buffer-donation warning. On backends that honour donation it
# could also invalidate the replicated parameters that later calls reuse.
transformed_forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(transformed_forward_fn.apply, devices=devices)

# Put the quantities required for inference on the devices, replicated once per
# device (`pmap` maps over their leading axis). This only needs to be done once:
# any further call to `apply_fn` can reuse them.
# NOTE: `parameters` is rebound to its replicated version, so re-running this cell
# without first re-running the checkpoint cell would replicate it a second time.
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

## Prepare dummy data

In [5]:
# Random dummy input: one sample per device (leading axis, mapped by `pmap`),
# with integer values in [0, 5) for each of the `19968 * num_cells` positions.
# NOTE(review): 19968 is presumably the number of genes per cell the model
# expects. Confirm against the model configuration.
# A seeded generator keeps the dummy input, and therefore the outputs,
# reproducible across runs.
num_cells = config.num_cells
num_genes_per_cell = 19968
rng = np.random.default_rng(seed=0)
dummy_gene_expressions = rng.integers(
    0, 5, size=(num_devices, 1, num_genes_per_cell * num_cells)
)

## Run inference on a batch

In [6]:
# Infer
outs = apply_fn(parameters, keys, dummy_gene_expressions) 

# Obtain the logits over the genomic features
logits = outs["logits"]
probabilities = np.asarray(jax.nn.softmax(logits[0, :, :, :5], axis=-1))[...,-1]

See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


In [7]:
# One row per sample of the first device, one probability per input position.
np.shape(probabilities)

(1, 998400)