## Installation and imports

In [1]:
import os

# Turn off XLA's up-front reservation of most of each GPU's memory. With the
# default preallocation, cuDNN may be unable to create its handle
# ("Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR"), which is a
# commonly reported cause of the failure seen in the inference cell.
# NOTE(review): confirm this resolves the error on the target machine.
# This must run before jax is first imported (possibly indirectly through
# nucleotide_transformer). `setdefault` keeps any value already exported.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import random

import nucleotide_transformer

SEED = 123
random.seed(SEED)  # makes the randomly generated input sequence reproducible

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.mypretrained import get_pretrained_model

In [3]:
# Accelerator devices attached to this JAX process.
devices = list(jax.local_devices())
print(devices)

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]


In [4]:
# Working directory of the notebook; pretrained checkpoints are stored under it.
NT_DIR = os.path.abspath(os.curdir)

In [5]:
#@title Select a model
#@markdown ---
# Name of the pretrained checkpoint to load; the `#@param` list enumerates the
# choices offered by the form. Keep that trailing annotation on the same line.
model_name = '2B5_multi_species'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [6]:
# Load the pretrained weights, the (un-transformed) forward function, the
# tokenizer and the model config for the selected checkpoint.
# NOTE(review): embeddings are requested from layers 20, 24, 28 and 32; confirm
# the selected model has at least 32 layers (e.g. if a 500M model is chosen).
checkpoint_dir = os.path.join(NT_DIR, 'checkpoints')
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(20, 24, 28, 32),
    attention_maps_to_save=(),
    max_positions=32,  # every tokenized sequence is padded to this length
    chkpt_dir=checkpoint_dir,
)
# Turn the haiku function into a pure (init, apply) pair.
forward_fn = hk.transform(forward_fn)

## Define your input data and tokenize it
Inspect the `tokens_str` variable to see how each sequence has been split into tokens. All sequences are padded to the `max_positions` value passed to `get_pretrained_model`.

In [7]:
# Get data and tokenize it
seq_len = 10
# One random nucleotide sequence of length `seq_len` (reproducible thanks to
# the `random.seed` call in the first cell).
sequences = [''.join(random.choice("ATCG") for _ in range(seq_len))]
# Tokenize once and reuse the result. `list()` guards against batch_tokenize
# returning a one-shot iterator, since the result is read twice below.
# Element 0 of each entry holds the token strings, element 1 the token ids.
batch_tokens = list(tokenizer.batch_tokenize(sequences))
tokens_str = [b[0] for b in batch_tokens]
tokens_ids = [b[1] for b in batch_tokens]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

In [8]:
# Preview the first 10 nucleotides of the generated sequence.
sequences[0][0:10]

'ACAGCAAGCC'

## Do the Inference
The first run of this cell is slower than subsequent runs because JAX compiles the computation the first time it is executed.

In [9]:
# Platform JAX dispatches computations to by default (e.g. 'gpu').
jax.default_backend()

'gpu'

In [10]:
# Every device visible to JAX on the default backend.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

In [11]:
%%time
# Initialize random key
random_key = jax.random.PRNGKey(0)


# Infer
outs = forward_fn.apply(parameters, random_key, tokens)

2023-04-03 15:25:57.201003: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [None]:
# Names of all outputs produced by the forward pass.
output_names = outs.keys()
print(output_names)

## Retrieve embeddings

In [None]:
# Shape of the token embeddings saved from layer 20.
layer_20_embeddings = outs["embeddings_20"]
print(layer_20_embeddings.shape)

**Additional tip**: if you want to compute, for instance, mean embeddings, don't forget to remove the CLS token and the padded positions!

In [None]:
# Mean-pool one saved layer's token embeddings over the real (non-CLS,
# non-padding) tokens of each sequence.
EMBEDDING_LAYER = 20  # must be one of the layers in `embeddings_layers_to_save`
embeddings = outs[f"embeddings_{EMBEDDING_LAYER}"][:, 1:, :]  # removing CLS token
# True for real tokens, False for padding. The trailing axis lets the mask
# broadcast over the embedding dimension.
padding_mask = jnp.expand_dims(tokens[:, 1:] != tokenizer.pad_token_id, axis=-1)
masked_embeddings = embeddings * padding_mask  # multiply by 0 pad tokens embeddings
sequences_lengths = jnp.sum(padding_mask, axis=1)
# Clamp the count to 1 so a sequence with no real tokens yields a zero vector
# instead of NaN (0 / 0).
mean_embeddings = jnp.sum(masked_embeddings, axis=1) / jnp.maximum(sequences_lengths, 1)

In [None]:
# One pooled embedding vector per input sequence.
print(jnp.shape(mean_embeddings))