## Installation and imports

In [1]:
import os
import nucleotide_transformer
import random
# Seed Python's `random` module so the random DNA sequences generated further
# down (via random.choice) are reproducible from a fresh kernel.
random.seed(123)

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.mypretrained import get_pretrained_model

In [3]:
# Absolute path of the current working directory; the checkpoint folder is
# resolved relative to it when the model is loaded.
NT_DIR = os.path.abspath(os.curdir)

In [4]:
#@title Select a model
#@markdown ---
# Name of the pretrained checkpoint handed to get_pretrained_model below
# (the #@param annotation lists the selectable values of the form field).
model_name = "2B5_multi_species"  #@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [5]:
# Load the pretrained parameters, forward function, tokenizer and config for
# the selected checkpoint.
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    chkpt_dir=os.path.join(NT_DIR, 'checkpoints'),
    # Fixed tokenized length: shorter inputs are padded to it, longer ones are rejected.
    max_positions=1000,
    # Each layer listed here shows up as an `embeddings_<layer>` entry in the outputs.
    embeddings_layers_to_save=(20, 24, 28, 32),
    attention_maps_to_save=(),
    mixed_precision=False,
)
# Wrap with Haiku so the model is invoked as forward_fn.apply(parameters, rng, tokens).
forward_fn = hk.transform(forward_fn)

## Define your input data and tokenize it
You can inspect the `tokens_str` variable to see how your sequences have been split into tokens. All sequences are padded to the `max_positions` value you set when loading the model.

In [9]:
# Generate a batch of random DNA sequences and tokenize it
seq_len = 10
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for _ in range(seq_len)) for _ in range(n_seqs)]
# Tokenize once and reuse the result: the first element of each item is the
# token strings, the second is the token ids.
tokenized = list(tokenizer.batch_tokenize(sequences))
tokens_ids = [item[1] for item in tokenized]
tokens_str = [item[0] for item in tokenized]
# Integer id matrix that is fed to the model
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Do the Inference
The first time you run this cell, it will be slower than subsequent inference calls because the computation graph has to be compiled.

In [11]:
# Show which backend JAX selected by default (the recorded run reports 'cpu').
jax.default_backend()

'cpu'

In [12]:
# List the devices visible to JAX (a single CPU device in the recorded run).
jax.devices()

[CpuDevice(id=0)]

In [13]:
# Fixed PRNG key (seed 0) passed to every forward_fn.apply call below.
random_key = jax.random.PRNGKey(seed=0)

In [14]:
%%time
# Initialize random key
random_key = jax.random.PRNGKey(0)

# Infer
outs = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 31min 55s, sys: 10min 51s, total: 42min 47s
Wall time: 3min 4s


In [15]:
# List the entries returned by the forward pass: one `embeddings_<layer>` per
# saved layer, plus the logits.
print(outs.keys())

dict_keys(['embeddings_20', 'embeddings_24', 'embeddings_28', 'embeddings_32', 'logits'])


In [16]:
%%time
seq_len = 10
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs2 = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 31min 59s, sys: 10min 39s, total: 42min 39s
Wall time: 3min 5s


In [17]:
%%time
seq_len = 1200
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs3 = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 32min 12s, sys: 11min 35s, total: 43min 48s
Wall time: 3min 5s


In [18]:
%%time
seq_len = 5900
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 32min 33s, sys: 11min 28s, total: 44min 1s
Wall time: 3min 4s


In [28]:
%%time
seq_len = 5989
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)

CPU times: user 32min 37s, sys: 11min 43s, total: 44min 20s
Wall time: 3min 3s


In [29]:
%%time
seq_len = 5990
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)

ValueError: Found a sequence with length 1001 that exceeds the fixed length to tokenize (1000).

## Retrieve embeddings

In [None]:
# Shape of the embeddings saved under the "embeddings_20" key of the first forward pass.
print(outs["embeddings_20"].shape)

**Additional tip**: Don't forget to remove the CLS token and the padded positions if you want, for instance, to compute mean embeddings!

In [None]:
# NOTE: `tokens` must still be the batch that produced `outs`; re-run the
# tokenization and first inference cells if it has been reassigned since.
# Drop the first position (CLS token) from the embeddings.
embeddings = outs["embeddings_20"][:, 1:, :]
# True for real tokens, False for padding; CLS position dropped to stay aligned.
is_real_token = tokens[:, 1:] != tokenizer.pad_token_id
# Trailing axis added so the mask broadcasts across the embedding dimension.
padding_mask = is_real_token[..., None]
# Zero out the embeddings at padded positions.
masked_embeddings = embeddings * padding_mask
# Number of non-padding tokens per sequence.
sequences_lengths = padding_mask.sum(axis=1)
# Average over the non-padding positions only.
mean_embeddings = masked_embeddings.sum(axis=1) / sequences_lengths

In [None]:
# Shape of the per-sequence mean embeddings computed in the previous cell.
print(mean_embeddings.shape)