In [None]:
from bend.models.awd_lstm import AWDLSTMModelForInference
from transformers import AutoTokenizer
import os
from utils import generate_random_dna_sequence, get_device, pad_embeddings
import torch

# Base directory, expressed relative to the notebook's working directory.
WORK_PATH = '../../'
# Folder under WORK_PATH that contains the pretrained embedders.
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models')
# Sub-folder name of the embedder used in this notebook.
EMBEDDER_NAME = 'awd_lstm'
# Full path handed to `from_pretrained` for both the tokenizer and the model.
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

# NOTE(review): PADDING_VALUE is not referenced by any cell visible in this
# notebook; presumably it mirrors the fill value used by `pad_embeddings`
# (the padded output shows -100) — confirm, or pass it explicitly.
PADDING_VALUE = -100

# Device that the model and the input ids are moved to (helper from utils).
device = get_device()

Using device: mps


In [14]:
# Load the tokenizer and the pretrained model from the same local directory.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH)
# Move the model to the target device and switch it to evaluation mode.
# Both calls are idempotent and return the module, so one of each is enough.
model = AWDLSTMModelForInference.from_pretrained(EMBEDDER_PATH).to(device).eval()

In [15]:
# Build a toy batch of 10 random DNA strings (helper called with min_length=5,
# max_length=15) so that the batch contains sequences of different lengths.
# NOTE(review): no random seed is set in the visible cells, so the sequences —
# and every output below — will presumably differ between runs; confirm how the
# helper draws its randomness and seed it if reproducible outputs are wanted.
sequences = [generate_random_dna_sequence(min_length=5, max_length=15) for _ in range(10)]
sequences

['CCGGAT',
 'GTGCACGTGA',
 'GCAAGCA',
 'GCCGTAACCCTC',
 'GCACC',
 'ATTGC',
 'GTGCAG',
 'TATCCGAGT',
 'TTGGCTGTTTGGA',
 'TCAACGGAGTGCA']

#### Tokenise sequences

In [16]:
# Encode the whole batch in one call: pad every sequence to the longest one,
# skip token type ids, and return PyTorch tensors.
output = tokenizer(
    sequences,
    padding="longest",
    return_token_type_ids=False,
    return_tensors="pt",
)

# Pull out the two tensors used by the following cells.
input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[ 9,  9, 10, 10,  8, 11,  0,  0,  0,  0,  0,  0,  0],
        [10, 11, 10,  9,  8,  9, 10, 11, 10,  8,  0,  0,  0],
        [10,  9,  8,  8, 10,  9,  8,  0,  0,  0,  0,  0,  0],
        [10,  9,  9, 10, 11,  8,  8,  9,  9,  9, 11,  9,  0],
        [10,  9,  8,  9,  9,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 8, 11, 11, 10,  9,  0,  0,  0,  0,  0,  0,  0,  0],
        [10, 11, 10,  9,  8, 10,  0,  0,  0,  0,  0,  0,  0],
        [11,  8, 11,  9,  9, 10,  8, 10, 11,  0,  0,  0,  0],
        [11, 11, 10, 10,  9, 11, 10, 11, 11, 11, 10, 10,  8],
        [11,  9,  8,  8,  9, 10, 10,  8, 10, 11, 10,  9,  8]])

In [17]:
# Map each row of ids back to its tokens, keeping special tokens so that the
# padding is visible in the printout.
for row_ids in input_ids:
    row_tokens = tokenizer.convert_ids_to_tokens(row_ids, skip_special_tokens=False)
    print(row_tokens)

['c', 'c', 'g', 'g', 'a', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 't', 'g', 'c', 'a', 'c', 'g', 't', 'g', 'a', '[PAD]', '[PAD]', '[PAD]']
['g', 'c', 'a', 'a', 'g', 'c', 'a', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 'c', 'c', 'g', 't', 'a', 'a', 'c', 'c', 'c', 't', 'c', '[PAD]']
['g', 'c', 'a', 'c', 'c', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 't', 't', 'g', 'c', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 't', 'g', 'c', 'a', 'g', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['t', 'a', 't', 'c', 'c', 'g', 'a', 'g', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['t', 't', 'g', 'g', 'c', 't', 'g', 't', 't', 't', 'g', 'g', 'a']
['t', 'c', 'a', 'a', 'c', 'g', 'g', 'a', 'g', 't', 'g', 'c', 'a']


In [18]:
# Show the mask returned by the tokenizer; in the rendered outputs its 1s line
# up with real tokens and its 0s with [PAD] positions of `input_ids`.
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

#### Embed sequences

In [19]:
# Inference only: run the forward pass with autograd disabled so no graph is
# built and no activations are kept for a backward pass. The result is already
# detached from autograd. `torch.no_grad()` is used rather than
# `torch.inference_mode()` so the tensor can still be modified by later cells.
with torch.no_grad():
    embeddings = model(input_ids=input_ids.to(device)).last_hidden_state
embeddings.size()

torch.Size([10, 13, 64])

In [20]:
# Peek at one value per batch row: last index of the second axis, last index of
# the third axis. Assuming a (batch, position, feature) layout — as the printed
# size [10, 13, 64] suggests — this is the final position, which is padding for
# every sequence shorter than the longest one.
embeddings[:, -1, -1]

tensor([ 5.7450e-04, -3.8959e-03,  1.6868e-03,  2.5601e-03, -1.4707e-05,
        -7.9252e-04,  2.0713e-03,  3.0819e-03,  5.4204e-03,  3.2390e-03],
       device='mps:0')

#### Remove padding and stack into batch

In [None]:
# The AWD-LSTM tokenizer adds no [CLS] / [SEP] tokens and pads on the right,
# so the tokenizer's attention mask can be passed to `pad_embeddings` as-is.
# NOTE(review): in the saved outputs the last position becomes -100 only for the
# two rows whose attention mask is 1 there, while rows ending in [PAD] keep
# their values. Verify which mask polarity `pad_embeddings` expects, or re-run
# the notebook top to bottom to rule out stale output.
# NOTE(review): this rebinds `embeddings`, so the unpadded tensor is no longer
# available after this cell.

embeddings = pad_embeddings(embeddings, attention_mask)
embeddings.size()

torch.Size([10, 13, 64])

In [27]:
embeddings[:, -1, -1]

tensor([ 5.7450e-04, -3.8959e-03,  1.6868e-03,  2.5601e-03, -1.4707e-05,
        -7.9252e-04,  2.0713e-03,  3.0819e-03, -1.0000e+02, -1.0000e+02],
       device='mps:0')