# Inference with BulkRNABert - Jax version

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/bulk_rna_bert/inference_bulkrnabert_jax_example.ipynb)

## Installation and imports

In [1]:
!pip install pandas



In [None]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [None]:
import haiku as hk
from huggingface_hub import hf_hub_download
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from nucleotide_transformer.bulk_rna_bert.pretrained import get_pretrained_bulkrnabert_model

## Load model

In [None]:
# Fetch the pretrained BulkRNABert bundle, asking the model to save the
# embeddings of layer 4 (they are read back from the outputs at inference time).
pretrained_bundle = get_pretrained_bulkrnabert_model(embeddings_layers_to_save=(4,))
parameters, forward_fn, tokenizer, config = pretrained_bundle

# Wrap the forward function with Haiku's transform; the resulting object is
# later called through `forward_fn.apply(parameters, rng_key, inputs)`.
forward_fn = hk.transform(forward_fn)

## Download the data

In [None]:
# Location of the sample bulk RNA-seq CSV inside the BulkRNABert HuggingFace repo.
tcga_sample_request = {
    "repo_id": "InstaDeepAI/BulkRNABert",
    "filename": "data/tcga_sample.csv",
    "repo_type": "model",
}

# Download the file and keep the path that hf_hub_download returns.
csv_path = hf_hub_download(**tcga_sample_request)

## Load dataset and preprocess

In [None]:
# Read the CSV and discard the "identifier" column so that only the remaining
# columns end up in the numeric array.
expression_table = pd.read_csv(csv_path).drop(columns=["identifier"])

# Keep only the first row (as a 2-D array with a single row) for this example.
first_sample_values = expression_table.to_numpy()[:1, :]

# Apply the log10(1 + x) transform to the raw values.
gene_expression_array = np.log10(first_sample_values + 1)

# The number of columns must match the number of genes in the model config.
assert gene_expression_array.shape[1] == config.n_genes

# Tokenize the transformed values and convert the ids to an int32 JAX array.
gene_expression_ids = jnp.asarray(
    tokenizer.batch_tokenize(gene_expression_array),
    dtype=jnp.int32,
)

## Inference

In [None]:
# Build the PRNG key from a fixed seed (0) and pass it to the transformed
# forward function.
random_key = jax.random.PRNGKey(0)

# Run the forward pass on the tokenized gene expression ids.
outs = forward_fn.apply(
    parameters,
    random_key,
    gene_expression_ids,
)

# Pull out the saved layer-4 embeddings and average them over axis 1
# (presumably the gene/sequence axis; confirm against the model's output shape).
layer_4_embeddings = outs["embeddings_4"]
gene_expression_mean_embeddings = layer_4_embeddings.mean(axis=1)