In [78]:
import os
import random

import numpy as np
import torch
from transformers import AutoTokenizer

from bend.models.awd_lstm import AWDLSTMModelForInference
from utils import generate_random_dna_sequence, get_device, pad_embeddings, remove_special_tokens_and_padding

# Paths to the pretrained AWD-LSTM embedder (tokenizer + model weights).
WORK_PATH = '../../'
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models')
EMBEDDER_NAME = 'awd_lstm'
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

# NOTE(review): MAX_SEQ_LENGTH is not referenced by any cell shown here, and the
# generated sequences are up to 15 nt long -- confirm whether it is still needed.
MAX_SEQ_LENGTH = 4

# Value expected in padded embedding positions (presumably the default used by
# `pad_embeddings` -- confirm).
PADDING_VALUE = -100

# Seed every RNG the random sequence generator might draw from, so that
# Restart & Run All reproduces the same sequences and embeddings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device that the model and input ids are moved to below; get_device() prints its choice.
device = get_device()

Using device: mps


In [79]:
# Load the tokenizer and model saved under EMBEDDER_PATH.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH)
# Move the model to the target device once, then switch it to inference mode
# (eval() puts layers such as dropout into evaluation behaviour).
model = AWDLSTMModelForInference.from_pretrained(EMBEDDER_PATH).to(device).eval()

In [80]:
# Ten random DNA sequences whose lengths are bounded by min_length/max_length,
# so the batch has mixed lengths and needs padding.
sequences = []
for _ in range(10):
    sequences.append(generate_random_dna_sequence(min_length=5, max_length=15))
sequences

['ATCAGGCCCAC',
 'GACATGCCGGACGTT',
 'TTCAGT',
 'CACTAAGTGTAGTC',
 'GGAGCATT',
 'GTTCAGGTAA',
 'ATTTTTT',
 'ATCAGGCTCGGGT',
 'CTACTGC',
 'CCCACGTT']

#### Tokenise sequences

In [97]:
# Pad every sequence (on the right) to the longest one in the batch and return
# PyTorch tensors; only input_ids and attention_mask are used downstream.
output = tokenizer(
    sequences,
    padding="longest",
    return_tensors="pt",
    return_token_type_ids=False,
)
input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[ 8, 11,  9,  8, 10, 10,  9,  9,  9,  8,  9,  0,  0,  0,  0],
        [10,  8,  9,  8, 11, 10,  9,  9, 10, 10,  8,  9, 10, 11, 11],
        [11, 11,  9,  8, 10, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 9,  8,  9, 11,  8,  8, 10, 11, 10, 11,  8, 10, 11,  9,  0],
        [10, 10,  8, 10,  9,  8, 11, 11,  0,  0,  0,  0,  0,  0,  0],
        [10, 11, 11,  9,  8, 10, 10, 11,  8,  8,  0,  0,  0,  0,  0],
        [ 8, 11, 11, 11, 11, 11, 11,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 8, 11,  9,  8, 10, 10,  9, 11,  9, 10, 10, 10, 11,  0,  0],
        [ 9, 11,  8,  9, 11, 10,  9,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 9,  9,  9,  8,  9, 10, 11, 11,  0,  0,  0,  0,  0,  0,  0]])

In [98]:
# Decode each padded row back to tokens to inspect the vocabulary and padding.
for row_ids in input_ids:
    row_tokens = tokenizer.convert_ids_to_tokens(row_ids, skip_special_tokens=False)
    print(row_tokens)

['a', 't', 'c', 'a', 'g', 'g', 'c', 'c', 'c', 'a', 'c', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 'a', 'c', 'a', 't', 'g', 'c', 'c', 'g', 'g', 'a', 'c', 'g', 't', 't']
['t', 't', 'c', 'a', 'g', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['c', 'a', 'c', 't', 'a', 'a', 'g', 't', 'g', 't', 'a', 'g', 't', 'c', '[PAD]']
['g', 'g', 'a', 'g', 'c', 'a', 't', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['g', 't', 't', 'c', 'a', 'g', 'g', 't', 'a', 'a', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 't', 't', 't', 't', 't', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['a', 't', 'c', 'a', 'g', 'g', 'c', 't', 'c', 'g', 'g', 'g', 't', '[PAD]', '[PAD]']
['c', 't', 'a', 'c', 't', 'g', 'c', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['c', 'c', 'c', 'a', 'c', 'g', 't', 't', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [99]:
# 1 = real nucleotide token, 0 = [PAD]; padding sits at the right of each row.
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])

#### Embed sequences

In [100]:
# Inference only: torch.no_grad() skips building the autograd graph, which saves
# memory compared with running the forward pass and detaching afterwards.
with torch.no_grad():
    embeddings = model(input_ids=input_ids.to(device)).last_hidden_state
embeddings.size()

torch.Size([10, 15, 64])

In [101]:
# Last hidden unit at the final position of every sequence. For all rows shorter
# than the batch maximum this position is padding, but the raw model output
# there is not masked yet (values are small non-zero numbers).
embeddings[:, -1, -1]

tensor([-0.0009,  0.0065, -0.0006,  0.0011,  0.0009, -0.0003, -0.0029,  0.0031,
        -0.0004, -0.0005], device='mps:0')

#### Remove padding and re-stack into a padded batch (padding value -100)

In [102]:
# The AWD-LSTM tokenizer adds no [CLS]/[SEP] tokens (see the decoded tokens
# above) and pads on the right, so the attention mask alone tells which
# positions are padding.
embeddings = pad_embeddings(embeddings, attention_mask)
embeddings.shape

torch.Size([10, 15, 64])

In [103]:
# Last hidden unit at the final position of every sequence after pad_embeddings.
# NOTE(review): in the recorded output only rows 8 and 9 show -100, yet the
# attention mask also marks position 14 as padding for rows 0 and 2-7, and their
# values are unchanged from the previous cell. pad_embeddings may not be masking
# every padded position (possibly device/MPS-specific) -- TODO verify.
embeddings[:, -1, -1]

tensor([-9.0983e-04,  6.5451e-03, -6.4345e-04,  1.0745e-03,  8.6566e-04,
        -3.3110e-04, -2.8806e-03,  3.1128e-03, -1.0000e+02, -1.0000e+02],
       device='mps:0')

#### Remove padding and convert to a list of per-sequence tensors

In [104]:
# Boolean mask on the embeddings' device. It is bound to a new name rather than
# overwriting `attention_mask`, so this cell and the earlier pad_embeddings
# cell stay re-runnable (a numpy array has no `.numpy()`).
keep_mask = attention_mask.bool().to(embeddings.device)

# Drop padded positions per sequence: (seq_len, hidden) -> (n_real_tokens, hidden).
masked_embeddings = [emb[keep] for emb, keep in zip(embeddings, keep_mask)]

# Each unpadded sequence must retain exactly its number of real tokens.
assert [len(emb) for emb in masked_embeddings] == attention_mask.sum(dim=1).tolist()

for emb in masked_embeddings:
    print(emb.shape)
     

torch.Size([11, 64])
torch.Size([15, 64])
torch.Size([6, 64])
torch.Size([14, 64])
torch.Size([8, 64])
torch.Size([10, 64])
torch.Size([7, 64])
torch.Size([13, 64])
torch.Size([7, 64])
torch.Size([8, 64])
