## Installation and imports

In [1]:
import os
import nucleotide_transformer
import random
random.seed(123)

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.mypretrained import get_pretrained_model

In [3]:
# Initialize random key
random_key = jax.random.PRNGKey(0)

In [4]:
NT_DIR = os.path.abspath('.')

In [5]:
#@title Select a model
#@markdown ---
model_name = '2B5_multi_species'#@param['500M_human_ref', '500M_1000G', '2B5_1000G', '2B5_multi_species']
#@markdown ---

In [None]:
# Get pretrained model
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(20),
    attention_maps_to_save=(),
    max_positions=1000,
    chkpt_dir = os.path.join(NT_DIR, 'checkpoints')
)
forward_fn = hk.transform(forward_fn)

## Do the Inference
The first time you query this cell will be slower than usual inference because of the computation graph compilation.

In [17]:
tokens.shape

(20, 1000)

In [13]:
%%time
seq_len = 5979
n_seqs = 20
sequences = [''.join(random.choice("ATCG") for i in range(seq_len-1)) + 'N' for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)
del outs4

TypeError: argument of type 'int' is not iterable

In [None]:
%%time
seq_len = 5985
n_seqs = 20
sequences = [''.join(random.choice("ATNCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)
del outs4

In [28]:
%%time
seq_len = 5989
n_seqs = 20
sequences = [''.join(random.choice("ATNCG") for i in range(seq_len)) for j in range(n_seqs)]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)
outs4 = forward_fn.apply(parameters, random_key, tokens)
del outs4

CPU times: user 32min 37s, sys: 11min 43s, total: 44min 20s
Wall time: 3min 3s


## Retrieve embeddings

In [None]:
print(outs["embeddings_20"].shape)

**Additional Tip**: Don't forget to remove the cls token and padded positions if you want for instance to compute mean embeddings!

In [None]:
embeddings = outs["embeddings_20"][:, 1:, :]  # removing CLS token
padding_mask = jnp.expand_dims(tokens[:, 1:] != tokenizer.pad_token_id, axis=-1)
masked_embeddings = embeddings * padding_mask  # multiply by 0 pad tokens embeddings
sequences_lengths = jnp.sum(padding_mask, axis=1)
mean_embeddings = jnp.sum(masked_embeddings, axis=1) / sequences_lengths

In [None]:
print(mean_embeddings.shape)