# Inference with pretrained Nucleotide Transformer models

## Installation and imports

In [None]:
!pip install ../.

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

## Download the weights
The following cell downloads the weights of any of the four Nucleotide Transformer models. It returns the weights dictionary, the Haiku forward function, the tokenizer and the config dictionary. Supported model names are **500M_human_ref**, **500M_1000G**, **2B5_1000G** and **2B5_multi_species**.

Also specify the layers at which you'd like to collect embeddings (e.g. (5, 10, 20) to get embeddings at layers 5, 10 and 20) and the maximum number of tokens in the sequences you will run inference on. You can use values up to 1024 (counting the class token that is added automatically at the beginning of each sequence); however, we recommend keeping this number as small as possible to reduce memory usage and inference time.
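To pick a small `max_positions`, you can estimate the token count of your longest sequence. The sketch below assumes the tokenizer splits each sequence into non-overlapping 6-mers, tokenizes any leftover nucleotides one by one, and then adds one class token; it ignores special cases such as `N` characters. The helpers `estimate_num_tokens` and `choose_max_positions` are hypothetical, and inspecting `tokens_str` after tokenization remains the reliable check.

```python
def estimate_num_tokens(sequence: str, k: int = 6, add_cls: bool = True) -> int:
    """Estimate the token count, ASSUMING non-overlapping k-mers plus single-nucleotide leftovers."""
    n_kmers, leftover = divmod(len(sequence), k)
    return n_kmers + leftover + (1 if add_cls else 0)


def choose_max_positions(sequences: list[str], k: int = 6, limit: int = 1024) -> int:
    """Return the smallest max_positions that fits every sequence (hypothetical helper)."""
    longest = max(estimate_num_tokens(seq, k) for seq in sequences)
    if longest > limit:
        raise ValueError(f"Longest sequence needs {longest} tokens, above the {limit} limit")
    return longest


sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
print(choose_max_positions(sequences))  # prints 9 under the stated assumption
```

Under this assumption, both example sequences fit comfortably within the `max_positions=32` used below.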

In [None]:
# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name="500M_1000G",
    mixed_precision=False,
    embeddings_layers_to_save=(20,),
    max_positions=32,
)
forward_fn = hk.transform(forward_fn)

## Define your input data and tokenize it
You can inspect the `tokens_str` variable to see how your sequences have been split into tokens. All sequences are padded to the `max_positions` length you set above.

In [None]:
# Get data and tokenize it (tokenize once, then split into strings and IDs)
sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
tokenized = tokenizer.batch_tokenize(sequences)
tokens_ids = [b[1] for b in tokenized]
tokens_str = [b[0] for b in tokenized]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
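Padding to a fixed length gives every row of `tokens` exactly `max_positions` entries, so the batch can be stored as one rectangular array. The minimal sketch below illustrates the idea on plain lists of token IDs, assuming padding is appended at the end; the tokenizer already does this for you. The function name `pad_token_ids` and the pad ID `1` in the example are hypothetical (the real value is `tokenizer.pad_token_id`).

```python
import jax.numpy as jnp


def pad_token_ids(batch_ids, max_positions, pad_id):
    """Right-pad each list of token IDs with pad_id up to max_positions (hypothetical helper)."""
    padded = []
    for ids in batch_ids:
        if len(ids) > max_positions:
            raise ValueError(f"Sequence of {len(ids)} tokens exceeds max_positions={max_positions}")
        padded.append(list(ids) + [pad_id] * (max_positions - len(ids)))
    return jnp.asarray(padded, dtype=jnp.int32)


example = pad_token_ids([[3, 4], [3, 4, 5]], max_positions=4, pad_id=1)
print(example.tolist())  # [[3, 4, 1, 1], [3, 4, 5, 1]]
```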

## Do the Inference
The first run of this cell is slower than subsequent runs because the computation graph is compiled on first use.

In [None]:
%%time
# Initialize random key
random_key = jax.random.PRNGKey(0)

# Infer
outs = forward_fn.apply(parameters, random_key, tokens)
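To see the compilation cost in isolation, the sketch below times two calls of a small `jax.jit`-compiled function: the first call includes tracing and compilation, while the second reuses the cached executable. `toy_compute` is illustrative only. Wrapping the model call the same way, e.g. `jax.jit(forward_fn.apply)`, is a common JAX pattern for faster repeated inference, but that is an assumption here rather than something this notebook does.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_compute(x):
    # Small illustrative workload
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256), dtype=jnp.float32)
timings = []
for call in range(2):
    start = time.perf_counter()
    result = toy_compute(x).block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    timings.append(time.perf_counter() - start)
    print(f"call {call}: {timings[-1]:.4f} s")
```

Calling `block_until_ready()` matters: without it, the timer may stop before the computation has actually run.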

## Retrieve embeddings
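The outputs dictionary holds the embeddings of each layer you requested, under keys such as `embeddings_20`. Use them as you please! Enjoy!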

In [None]:
outs["embeddings_20"].shape

**Additional Tip**: If you want to compute mean embeddings, for instance, don't forget to remove the CLS token and the padded positions!

In [None]:
embeddings = outs["embeddings_20"][:, 1:, :]  # remove the CLS token
padding_mask = jnp.expand_dims(tokens[:, 1:] != tokenizer.pad_token_id, axis=-1)  # True at non-padded positions
masked_embeddings = embeddings * padding_mask  # zero out the embeddings of padded tokens
mean_embeddings = jnp.sum(masked_embeddings, axis=1) / jnp.sum(padding_mask, axis=1)  # average over real tokens only

In [None]:
mean_embeddings.shape
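The masking logic can be checked on a tiny hand-made batch where the expected means are easy to compute by hand. The sketch below wraps the steps above in a hypothetical `masked_mean` helper. The token IDs (class token `3`, pad token `1`) are made up for the example, and `jnp.maximum(..., 1)` is an added guard against division by zero for a sequence made only of padding.

```python
import jax.numpy as jnp


def masked_mean(embeddings, tokens, pad_token_id):
    """Mean over non-padded positions, excluding the first (CLS) token (hypothetical helper)."""
    emb = embeddings[:, 1:, :]  # drop the CLS position
    mask = (tokens[:, 1:] != pad_token_id)[..., None]  # shape (batch, length - 1, 1)
    summed = jnp.sum(emb * mask, axis=1)  # padded positions contribute zero
    counts = jnp.maximum(jnp.sum(mask, axis=1), 1)  # number of real tokens, at least 1
    return summed / counts


tokens = jnp.array([[3, 7, 8, 1], [3, 9, 1, 1]])
embeddings = jnp.arange(16, dtype=jnp.float32).reshape(2, 4, 2)
means = masked_mean(embeddings, tokens, pad_token_id=1)
print(means)  # [[ 3.  4.] [10. 11.]]
```

The first row averages positions 1 and 2 (`[2, 3]` and `[4, 5]`), and the second row keeps only position 1 (`[10, 11]`).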