# Inference with pretrained Nucleotide Transformer models

## Installation and imports

In [1]:
import os
import nucleotide_transformer
import random
# Fix the seed of Python's `random` module so the random DNA sequence drawn later is reproducible.
random.seed(123)

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

In [3]:
# Parent of the current working directory; the checkpoints folder is resolved under it.
NT_DIR = os.path.abspath(os.pardir)

In [4]:
#@title Select a model
#@markdown ---
# Name of the pretrained checkpoint; it is passed to get_pretrained_model in the next cell.
# The #@param list on the same line is a Colab form annotation enumerating the selectable names.
model_name = '2B5_multi_species'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [5]:
# Fetch the pretrained model pieces for the checkpoint selected above.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    chkpt_dir=os.path.join(NT_DIR, 'checkpoints'),
    # Fixed length that tokenized sequences are padded to (and may not exceed).
    max_positions=32,
    mixed_precision=False,
    # Layer whose embeddings are returned; exposed in the outputs as "embeddings_20".
    embeddings_layers_to_save=(20,),
    # (layer, number) pairs; exposed as "attention_map_layer_<layer>_number_<number>".
    attention_maps_to_save=((1, 4), (7, 18)),
)
# Wrap with Haiku so the model is called as forward_fn.apply(parameters, rng_key, tokens).
forward_fn = hk.transform(forward_fn)

In [5]:
#@title Select a model
#@markdown ---
model_name = '2B5_multi_species'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

# NOTE(review): this cell repeats the model selection and loading done in the two
# cells above; the only differing argument is max_positions (7000 here vs 32 there).
# Running both calls get_pretrained_model twice and the bindings from this cell win.
# The outputs recorded further down (tokenizer fixed length 32, embeddings of shape
# (2, 32, 2560)) match max_positions=32, not 7000 — confirm which value is intended
# and keep a single loading cell.
# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),
    max_positions=7000,
    chkpt_dir = os.path.join(NT_DIR, 'checkpoints')
)
# Wrap with Haiku; the transformed object is later used through forward_fn.apply(...)
forward_fn = hk.transform(forward_fn)

## Define your input data and tokenize it
Inspect the `tokens_str` variable to see how your sequences were split into tokens. All sequences are padded to the `max_positions` value you set when loading the model, so no tokenized sequence may be longer than that value.

In [6]:
# Get data and tokenize it
# seq_len is a number of nucleotides. The tokenizer checks the length in *tokens*
# against the max_positions given to get_pretrained_model and raises a ValueError
# when a sequence exceeds it.
seq_len = 7000
sequences = [''.join(random.choice("ATCG") for _ in range(seq_len))]

# Tokenize once and reuse the result: each item pairs the token strings (index 0)
# with the matching token ids (index 1).
tokenized = tokenizer.batch_tokenize(sequences)
tokens_str = [pair[0] for pair in tokenized]
tokens_ids = [pair[1] for pair in tokenized]

# Integer array of token ids that is fed to the model.
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

ValueError: Found a sequence with length 1171 that exceeds the fixed length to tokenize (32).

In [None]:
# Preview the first 10 characters of the first input sequence.
sequences[0][:10]

## Do the Inference
The first time you run this cell, it will be slower than subsequent inference calls because the computation graph has to be compiled.

In [7]:
%%time
# Initialize random key
# (the transformed forward_fn.apply takes a PRNG key as its second argument;
# seed 0 is fixed so the same key is used on every run)
random_key = jax.random.PRNGKey(0)

# Infer
# `outs` is a dict of named outputs; its keys are printed in the next cell.
outs = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 15.7 s, sys: 420 ms, total: 16.2 s
Wall time: 2.18 s


In [8]:
# List the named outputs of the forward pass: the requested attention maps, the requested embeddings, and the logits.
print(outs.keys())

dict_keys(['attention_map_layer_1_number_4', 'attention_map_layer_7_number_18', 'embeddings_20', 'logits'])


## Retrieve embeddings
And use them as you please! Enjoy!

In [9]:
# Shape of the embeddings saved for layer 20.
# Presumably (number of sequences, token positions, embedding size) — the recorded
# second axis (32) equals the max_positions used when the model was loaded.
print(outs["embeddings_20"].shape)

(2, 32, 2560)


**Additional Tip**: Don't forget to remove the CLS token and the padded positions if you want, for instance, to compute mean embeddings!

In [12]:
# Drop position 0 (the CLS token) from the layer-20 embeddings.
embeddings = outs["embeddings_20"][:, 1:]
# True on real tokens, False on padding; the trailing axis lets the mask
# broadcast against the embedding dimension.
padding_mask = (tokens[:, 1:] != tokenizer.pad_token_id)[..., None]
# Zero out the embeddings of padded positions so they do not contribute to the sum.
masked_embeddings = padding_mask * embeddings
# Number of non-padding tokens in each sequence.
sequences_lengths = padding_mask.sum(axis=1)
# Average over the token axis, counting only the non-padding positions.
mean_embeddings = masked_embeddings.sum(axis=1) / sequences_lengths

In [13]:
# The token axis has been averaged out, leaving one mean vector per sequence.
print(mean_embeddings.shape)

(2, 2560)
