# Inference with MOJO - Jax version

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/notebooks/mojo/inference_mojo_jax_example.ipynb)

## Installation and imports

In [None]:
# Install pandas, which is used further down to read the sample CSV files.
!pip install pandas

In [None]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [None]:
import haiku as hk
from huggingface_hub import hf_hub_download
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from nucleotide_transformer.mojo.pretrained import get_mojo_pretrained_model

## Load model


In [None]:
# Get pretrained MOJO model.
# The loader returns four objects: the model parameters, the (untransformed)
# forward function, a mapping of tokenizers indexed by omic name, and the
# model config (its `sequence_length` is checked against the data below).
parameters, forward_fn, tokenizers, config = get_mojo_pretrained_model()
# Wrap the forward function with Haiku so it can be called through
# `forward_fn.apply(parameters, rng_key, inputs)` in the inference cell.
forward_fn = hk.transform(forward_fn)

## Download, load and preprocess the data

In [None]:
# Number of samples (rows) kept from each CSV file.
n_examples = 4
# Maps each omic name to its preprocessed array of shape
# (n_examples, config.sequence_length).
omic_dict = {}

for omic in ("rnaseq", "methylation"):
    # Fetch the sample file for this omic from the Hugging Face Hub.
    csv_path = hf_hub_download(
        repo_id="InstaDeepAI/MOJO",
        filename=f"data/tcga_{omic}_sample.csv",
        repo_type="model",
    )
    # Drop the two metadata columns and keep only the first `n_examples` rows
    # of the remaining values.
    omic_array = (
        pd.read_csv(csv_path)
        .drop(columns=["identifier", "cohort"])
        .to_numpy()[:n_examples]
    )
    # RNA-seq values are log10(1 + x)-transformed; methylation is left as is.
    if omic == "rnaseq":
        omic_array = np.log10(omic_array + 1)
    # Each row must provide exactly one value per model input position.
    assert omic_array.shape[1] == config.sequence_length
    omic_dict[omic] = omic_array

In [None]:
# Tokenize each omic array with its matching tokenizer (padding to a fixed
# length) and store the token ids as int32 JAX arrays, keyed by omic name.
tokens_ids = {
    omic_name: jnp.asarray(
        tokenizers[omic_name].batch_tokenize(values, pad_to_fixed_length=True),
        dtype=jnp.int32,
    )
    for omic_name, values in omic_dict.items()
}

## Inference

In [None]:
# Inference: run the transformed forward function on the tokenized inputs.
# A fixed PRNG seed (0) is used for the key passed to `apply`.
random_key = jax.random.PRNGKey(0)
outs = forward_fn.apply(parameters, random_key, tokens_ids)

# Get embedding from last transformer layer, averaged over axis 1
# (presumably the sequence axis, giving one vector per sample — TODO confirm).
mean_embedding = outs["after_transformer_embedding"].mean(axis=1)