In [None]:
import os
import random

import numpy as np
import torch
from transformers import AutoTokenizer

from bend.models.dilated_cnn import ConvNetModel
from utils import generate_random_dna_sequence, get_device, pad_embeddings

# Location of the pretrained ConvNet embedder; both the tokenizer and the
# model weights are loaded from EMBEDDER_PATH further down.
WORK_PATH = '../../'
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models')
EMBEDDER_NAME = 'convnet'
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

# Value expected at padded positions after masking (the masked output at the
# end of the notebook shows -100 there).
# NOTE(review): this constant is not passed to `pad_embeddings`; the -100 it
# produces presumably comes from that helper's own default — confirm they agree.
PADDING_VALUE = -100

# Seed the RNGs so Restart & Run All reproduces the same random sequences and
# therefore the same tokens/embeddings.
# NOTE(review): which RNG `generate_random_dna_sequence` draws from is not
# visible here; seeding `random`, NumPy and torch covers the usual candidates.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = get_device()

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


Using device: mps


In [2]:
# Load the tokenizer and the pretrained ConvNet from the same local directory.
# The model is moved to `device` once and switched to eval mode for inference
# (the original chain repeated both `.eval()` and `.to(device)`).
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH)
model = ConvNetModel.from_pretrained(EMBEDDER_PATH).to(device).eval()

In [3]:
# A small batch of random DNA sequences of varying length, so that the
# tokenizer has to pad the shorter ones.
N_SEQUENCES = 10
MIN_SEQ_LEN, MAX_SEQ_LEN = 5, 15

sequences = [
    generate_random_dna_sequence(min_length=MIN_SEQ_LEN, max_length=MAX_SEQ_LEN)
    for _ in range(N_SEQUENCES)
]
sequences

['TGCTGCG',
 'ACTGTAACGATCT',
 'TCCGC',
 'ACCCTGGGGAGAGC',
 'ATTATA',
 'ATTGTATCACCGA',
 'TAGTTGCCTCCCTGC',
 'TTACGGGA',
 'GACACGCTA',
 'AATCGCTG']

#### Tokenise sequences

In [4]:
# Tokenise the whole batch into PyTorch tensors, padding every sequence up to
# the longest one (the output below shows the padding is appended on the
# right); token type ids are not requested.
output = tokenizer(
    sequences,
    padding="longest",
    return_token_type_ids=False,
    return_tensors="pt",
)

input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[6, 5, 4, 6, 5, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 4, 6, 5, 6, 3, 3, 4, 5, 3, 6, 4, 6, 0, 0],
        [6, 4, 4, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 4, 4, 4, 6, 5, 5, 5, 5, 3, 5, 3, 5, 4, 0],
        [3, 6, 6, 3, 6, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 6, 6, 5, 6, 3, 6, 4, 3, 4, 4, 5, 3, 0, 0],
        [6, 3, 5, 6, 6, 5, 4, 4, 6, 4, 4, 4, 6, 5, 4],
        [6, 6, 3, 4, 5, 5, 5, 3, 0, 0, 0, 0, 0, 0, 0],
        [5, 3, 4, 3, 4, 5, 4, 6, 3, 0, 0, 0, 0, 0, 0],
        [3, 3, 6, 4, 5, 4, 6, 5, 0, 0, 0, 0, 0, 0, 0]])

In [5]:
# Map each row of token ids back to token strings, keeping [PAD] visible.
for token_ids in input_ids:
    tokens = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
    print(tokens)

['t', 'g', 'c', 't', 'g', 'c', 'g', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 'c', 't', 'g', 't', 'a', 'a', 'c', 'g', 'a', 't', 'c', 't', '[PAD]', '[PAD]']
['t', 'c', 'c', 'g', 'c', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 'c', 'c', 'c', 't', 'g', 'g', 'g', 'g', 'a', 'g', 'a', 'g', 'c', '[PAD]']
['a', 't', 't', 'a', 't', 'a', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 't', 't', 'g', 't', 'a', 't', 'c', 'a', 'c', 'c', 'g', 'a', '[PAD]', '[PAD]']
['t', 'a', 'g', 't', 't', 'g', 'c', 'c', 't', 'c', 'c', 'c', 't', 'g', 'c']
['t', 't', 'a', 'c', 'g', 'g', 'g', 'a', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 'a', 'c', 'a', 'c', 'g', 'c', 't', 'a', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 'a', 't', 'c', 'g', 'c', 't', 'g', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [6]:
# 1 = real nucleotide, 0 = padding; lines up with the [PAD] tokens printed above.
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])

#### Embed sequences

In [26]:
# Forward pass without autograd bookkeeping: no graph is built, so nothing
# needs to be `.detach()`-ed afterwards and less memory is held on `device`.
# The result has one vector per token position, padded positions included.
with torch.no_grad():
    embeddings = model(input_ids=input_ids.to(device)).last_hidden_state
embeddings.size()

torch.Size([10, 15, 256])

In [27]:
# Raw model output: positions fed with [PAD] still carry non-zero values here.
embeddings

tensor([[[ 2.3741e+00,  2.5866e+00,  3.5126e-01,  ..., -3.7956e+00,
           7.1755e-01,  5.9118e+00],
         [ 8.9125e+00,  5.6957e+00, -4.0249e+00,  ..., -2.4655e-02,
           3.1164e-01, -3.6897e+00],
         [ 7.6138e+00,  5.4017e+00, -2.0462e+00,  ..., -1.9721e-02,
           1.3770e+00, -1.8145e+00],
         ...,
         [ 7.3280e+00,  1.5545e+00, -8.3563e-01,  ..., -4.6375e+00,
           3.7045e-01,  7.1262e+00],
         [ 8.2953e+00,  3.1282e-01, -2.2682e+00,  ..., -1.4121e+00,
           1.6485e+00, -1.5751e+00],
         [ 7.5224e+00, -2.9341e+00, -2.3099e+00,  ..., -1.8222e+00,
          -1.0883e+00,  2.9411e+00]],

        [[ 4.2401e+00,  1.1803e+00, -2.7010e+00,  ..., -1.5993e+00,
           2.0799e+00,  5.3699e+00],
         [ 1.5626e+00,  2.7013e+00, -8.0157e-01,  ...,  7.3217e-01,
           1.8076e+00,  2.2686e+00],
         [ 5.7458e+00, -4.9225e+00,  3.4357e+00,  ..., -1.9603e-01,
          -2.7376e-01,  8.5376e-01],
         ...,
         [ 6.1897e+00, -1

#### Remove padding and stack into batch

In [None]:
# There are no [CLS] and [SEP] tokens in the ConvNet model
# and the embeddings are padded to the right.
# We can use the attention mask directly.
#
# Store the result under a new name: `embeddings` keeps the raw model output,
# and the inspection cell below reads `masked_embeddings` (previously never
# assigned, so a fresh Run All failed with NameError).
masked_embeddings = pad_embeddings(embeddings, attention_mask)
masked_embeddings.size()

tensor([[[ 2.3741e+00,  2.5866e+00,  3.5126e-01,  ..., -3.7956e+00,
           7.1755e-01,  5.9118e+00],
         [ 8.9125e+00,  5.6957e+00, -4.0249e+00,  ..., -2.4655e-02,
           3.1164e-01, -3.6897e+00],
         [ 7.6138e+00,  5.4017e+00, -2.0462e+00,  ..., -1.9721e-02,
           1.3770e+00, -1.8145e+00],
         ...,
         [ 7.3280e+00,  1.5545e+00, -8.3563e-01,  ..., -4.6375e+00,
           3.7045e-01,  7.1262e+00],
         [ 8.2953e+00,  3.1282e-01, -2.2682e+00,  ..., -1.4121e+00,
           1.6485e+00, -1.5751e+00],
         [ 7.5224e+00, -2.9341e+00, -2.3099e+00,  ..., -1.8222e+00,
          -1.0883e+00,  2.9411e+00]],

        [[ 4.2401e+00,  1.1803e+00, -2.7010e+00,  ..., -1.5993e+00,
           2.0799e+00,  5.3699e+00],
         [ 1.5626e+00,  2.7013e+00, -8.0157e-01,  ...,  7.3217e-01,
           1.8076e+00,  2.2686e+00],
         [ 5.7458e+00, -4.9225e+00,  3.4357e+00,  ..., -1.9603e-01,
          -2.7376e-01,  8.5376e-01],
         ...,
         [ 6.1897e+00, -1

In [29]:
# Last feature of the last position of every sequence. Only sequences that
# fill the whole padded length keep a real value there; every sequence whose
# last token is padding must hold PADDING_VALUE.
last_features = masked_embeddings[:, -1, -1]
# Compare on CPU because `attention_mask` was never moved to `device`.
assert torch.all(last_features.cpu()[attention_mask[:, -1] == 0] == PADDING_VALUE), \
    "padded positions were not filled with PADDING_VALUE"
last_features

tensor([-100.0000, -100.0000, -100.0000, -100.0000, -100.0000, -100.0000,
           1.4555, -100.0000, -100.0000, -100.0000], device='mps:0')