# Inference with pretrained Nucleotide Transformer models

## Installation and imports

In [1]:
import os
import nucleotide_transformer
import random
# Seed Python's `random` module so the randomly generated sequence used later is reproducible.
random.seed(123)

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model

In [3]:
# Expose only GPUs 0 and 1 to this process, then list the devices JAX can see.
# NOTE(review): the variable is set after `import jax`; presumably it still takes
# effect because no device has been queried before this point — confirm on a GPU host.
os.environ.update(CUDA_VISIBLE_DEVICES='0,1')
devices = jax.local_devices()
print(devices)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


[CpuDevice(id=0)]


In [4]:
# Absolute path of the parent of the current working directory;
# the checkpoint directory is looked up underneath it.
NT_DIR = os.path.abspath(os.pardir)

In [5]:
#@title Select a model
#@markdown ---
model_name = "2B5_multi_species"  #@param ['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [6]:
# Load the pretrained checkpoint chosen via `model_name`. The call returns the
# parameters, the raw forward function, the tokenizer and the model config.
# NOTE(review): the following cell calls get_pretrained_model again with different
# settings and rebinds the same four names, so these values are overwritten
# before inference.
parameters, raw_forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(20,),
    attention_maps_to_save=((1, 4), (7, 18)),  # presumably (layer, head) pairs — confirm
    max_positions=32,  # length the token sequences are padded to
    chkpt_dir=os.path.join(NT_DIR, 'checkpoints'),
)
# Wrap with Haiku so the model is invoked through `forward_fn.apply(...)`.
forward_fn = hk.transform(raw_forward_fn)

In [7]:
# Load the model chosen in the "Select a model" cell above, keeping the embeddings
# of several layers and no attention maps. The selector is intentionally not
# repeated in this cell: re-assigning `model_name` here would silently override
# the choice made above.
# NOTE(review): layer indices up to 32 match `num_layers=32` of the default
# 2B5_multi_species model — confirm they are valid before switching to another model.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(20, 24, 28, 32),
    attention_maps_to_save=(),
    max_positions=1024,  # length the token sequences are padded to
    chkpt_dir=os.path.join(NT_DIR, 'checkpoints'),
)
# Wrap with Haiku so the model is invoked through `forward_fn.apply(...)`.
forward_fn = hk.transform(forward_fn)

## Define your input data and tokenize it
Inspect the `tokens_str` variable to see how your sequences have been split into tokens. All sequences are padded to the value you set for `max_positions`.

In [8]:
# Display the configuration object returned by get_pretrained_model.
config

NucleotideTransformerConfig(alphabet_size=4105, pad_token_id=1, mask_token_id=2, max_positions=1000, embed_scale=1.0, emb_layer_norm_before=False, attention_heads=20, key_size=128, embed_dim=2560, ffn_embed_dim=10240, num_layers=32, token_dropout=True, masking_ratio=0.15, masking_prob=0.8, use_gradient_checkpointing=False, embeddings_layers_to_save=(20, 24, 38, 32), attention_maps_to_save=())

In [9]:
# Get data and tokenize it
seq_len = 10
sequences = [''.join(random.choice("ATCG") for i in range(seq_len))]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

In [10]:
# Preview the first 10 characters of the generated sequence.
sequences[0][:10]

'ACAGCAAGCC'

## Do the Inference
The first time you run this cell, it will be slower than subsequent inference calls because the computation graph has to be compiled.

In [11]:
%%time
# PRNG key handed to the Haiku-transformed function; the fixed seed 0 yields the same key on every run.
random_key = jax.random.PRNGKey(0)

# Run the forward pass on the tokenized batch. `outs` is a dict of model outputs
# (its keys are printed in the next cell).
outs = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 1min 59s, sys: 37.6 s, total: 2min 36s
Wall time: 12.8 s


In [12]:
# List the names of the outputs returned by the forward pass.
print(outs.keys())

dict_keys(['embeddings_20', 'embeddings_24', 'embeddings_32', 'logits'])


## Retrieve embeddings
And use them as you please! Enjoy!

In [28]:
# Shape of the saved layer-20 embeddings — presumably (batch, tokens, embed_dim); confirm.
print(outs["embeddings_20"].shape)

(1, 1024, 2560)


**Additional tip**: Don't forget to remove the CLS token and the padded positions if you want to compute, for instance, mean embeddings!

In [16]:
# Mean-pool the layer-20 embeddings over non-padding tokens only.
embeddings = outs["embeddings_20"][:, 1:, :]  # drop the first position (CLS token)
# True where a token is not padding; the trailing axis lets the mask broadcast over the embeddings.
padding_mask = (tokens[:, 1:] != tokenizer.pad_token_id)[..., None]
masked_embeddings = jnp.multiply(embeddings, padding_mask)  # padded positions are multiplied by 0
sequences_lengths = padding_mask.sum(axis=1)  # count of non-padding tokens per sequence
mean_embeddings = masked_embeddings.sum(axis=1) / sequences_lengths

In [17]:
# Check the shape of the mean-pooled embeddings.
print(mean_embeddings.shape)

(1, 2560)
