In [None]:
import os
import random

import numpy as np
import torch
from bend.models.hyena_dna import HyenaDNAPreTrainedModel, CharacterTokenizer

from utils import generate_random_dna_sequence, get_device, remove_special_tokens_and_padding

# Location of the pretrained HyenaDNA checkpoints. EMBEDDER_PATH is the expected
# local directory of the selected checkpoint.
WORK_PATH = '../../'
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models', 'hyenadna')
EMBEDDER_NAME = 'hyenadna-tiny-1k-seqlen'
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

# Fill value used when right-padding the per-sequence embeddings into one batch
# (pad_sequence in the post-processing cell).
PADDING_VALUE = -100

# Maximum sequence length of each checkpoint; used as the tokenizer's model_max_length.
MAX_LENGTHS = {
    "hyenadna-tiny-1k-seqlen": 1024,
    "hyenadna-small-32k-seqlen": 32768,
    "hyenadna-medium-160k-seqlen": 160000,
    "hyenadna-medium-450k-seqlen": 450000,
    "hyenadna-large-1m-seqlen": 1_000_000,
}

# The demo sequences are random. The RNG used by generate_random_dna_sequence is
# not visible here, so seed Python, NumPy and torch to make Restart & Run All
# reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = get_device()

Using device: mps


In [6]:
# Character-level tokenizer over the DNA alphabet ("N" marks an uncertain base).
tokenizer = CharacterTokenizer(
    characters=["A", "C", "G", "T", "N"],
    # +2 leaves room for the [CLS] and [SEP] tokens wrapped around each sequence.
    model_max_length=MAX_LENGTHS[EMBEDDER_NAME] + 2,
    # NOTE: despite this flag, calling the tokenizer below still wraps every
    # sequence in [CLS] ... [SEP] (see the tokenised output). Those tokens are
    # stripped again in the post-processing cell.
    add_special_tokens=False,
    # HyenaDNA is causal, so pad on the left; every row then ends with its own
    # [SEP] token at the final position.
    padding_side="left",
)

# Load only the backbone (no classification head, no LM head) to get embeddings.
# NOTE(review): assumes from_pretrained joins (path, model_name) internally, as
# BEND's HyenaDNA embedder call does. The previous call passed EMBEDDER_PATH as
# `path`, which would resolve to .../EMBEDDER_NAME/EMBEDDER_NAME, and decided
# whether to download by checking only the parent EMBEDDER_DIR. Confirm against
# bend.models.hyena_dna.
model = HyenaDNAPreTrainedModel.from_pretrained(
    EMBEDDER_DIR,
    EMBEDDER_NAME,
    download=not os.path.exists(EMBEDDER_PATH),
    config=None,
    device=device,
    use_head=False,
    use_lm_head=False,  # we don't use the LM head for embeddings
    n_classes=2,
).eval().to(device)

Loaded pretrained weights ok!


In [7]:
# A small batch of random DNA sequences of varying length, so that the
# tokenizer has to pad them.
N_SEQUENCES = 10
MIN_SEQ_LENGTH, MAX_SEQ_LENGTH = 5, 15

sequences = [
    generate_random_dna_sequence(min_length=MIN_SEQ_LENGTH, max_length=MAX_SEQ_LENGTH)
    for _ in range(N_SEQUENCES)
]
sequences

['ACATGTTCG',
 'ACTTCGAGGATATG',
 'CATAACTTT',
 'TCTCT',
 'TCATATGTC',
 'TTGAGCTGAGGTCCG',
 'GCTCAGCAAATAA',
 'TCATCATCG',
 'ATCGTGTGTGGGG',
 'GCACAC']

#### Tokenise sequences

In [21]:
# Batch-tokenise. padding="longest" pads every row up to the longest tokenised
# sequence, on the left (padding_side set on the tokenizer).
output = tokenizer(
    sequences,
    padding="longest",
    return_token_type_ids=False,
    return_tensors="pt",
)
input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[ 4,  4,  4,  4,  4,  4,  0,  7,  8,  7, 10,  9, 10, 10,  8,  9,  1],
        [ 4,  0,  7,  8, 10, 10,  8,  9,  7,  9,  9,  7, 10,  7, 10,  9,  1],
        [ 4,  4,  4,  4,  4,  4,  0,  8,  7, 10,  7,  7,  8, 10, 10, 10,  1],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  0, 10,  8, 10,  8, 10,  1],
        [ 4,  4,  4,  4,  4,  4,  0, 10,  8,  7, 10,  7, 10,  9, 10,  8,  1],
        [ 0, 10, 10,  9,  7,  9,  8, 10,  9,  7,  9,  9, 10,  8,  8,  9,  1],
        [ 4,  4,  0,  9,  8, 10,  8,  7,  9,  8,  7,  7,  7, 10,  7,  7,  1],
        [ 4,  4,  4,  4,  4,  4,  0, 10,  8,  7, 10,  8,  7, 10,  8,  9,  1],
        [ 4,  4,  0,  7, 10,  8,  9, 10,  9, 10,  9, 10,  9,  9,  9,  9,  1],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  0,  9,  8,  7,  8,  7,  8,  1]])

In [22]:
# Map each row of ids back to tokens, keeping [PAD]/[CLS]/[SEP] visible so the
# left padding can be inspected.
for row_ids in input_ids:
    row_tokens = tokenizer.convert_ids_to_tokens(row_ids, skip_special_tokens=False)
    print(row_tokens)

['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[CLS]', 'A', 'C', 'A', 'T', 'G', 'T', 'T', 'C', 'G', '[SEP]']
['[PAD]', '[CLS]', 'A', 'C', 'T', 'T', 'C', 'G', 'A', 'G', 'G', 'A', 'T', 'A', 'T', 'G', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[CLS]', 'C', 'A', 'T', 'A', 'A', 'C', 'T', 'T', 'T', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[CLS]', 'T', 'C', 'T', 'C', 'T', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[CLS]', 'T', 'C', 'A', 'T', 'A', 'T', 'G', 'T', 'C', '[SEP]']
['[CLS]', 'T', 'T', 'G', 'A', 'G', 'C', 'T', 'G', 'A', 'G', 'G', 'T', 'C', 'C', 'G', '[SEP]']
['[PAD]', '[PAD]', '[CLS]', 'G', 'C', 'T', 'C', 'A', 'G', 'C', 'A', 'A', 'A', 'T', 'A', 'A', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[CLS]', 'T', 'C', 'A', 'T', 'C', 'A', 'T', 'C', 'G', '[SEP]']
['[PAD]', '[PAD]', '[CLS]', 'A', 'T', 'C', 'G', 'T', 'G', 'T', 'G', 'T', 'G', 'G', 'G', 'G', '[SEP]'

In [23]:
# 0 = left padding ([PAD]); 1 = real token, including [CLS] and [SEP].
attention_mask

tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])

#### Embed sequences

In [24]:
# The tokenizer already returned a torch tensor (return_tensors="pt"). .long()
# makes the integer dtype explicit without the legacy torch.LongTensor constructor.
input_ids = input_ids.long()

# Inference only: disable autograd so no computation graph is built.
# NOTE(review): attention_mask is not passed to the model. Padding is on the left
# and the model is causal, so [PAD] positions precede the real tokens and may
# influence their embeddings. Confirm whether the model can take a mask.
with torch.no_grad():
    embeddings = model(input_ids=input_ids.to(device)).cpu()

embeddings.size()

torch.Size([10, 17, 128])

In [28]:
# Last hidden feature (index -1 of the 128 dims) at the final position, which is
# [SEP] for every row because of left padding: one scalar per sequence, not a
# full embedding vector.
embeddings[:, -1, -1]

tensor([-0.1129,  1.1244,  0.4640,  0.2453,  0.1486,  0.7546,  0.2239,  0.0606,
         0.1209,  0.8705])

#### Remove padding and special tokens, then re-pad into a batch

In [None]:
# Inputs were left-padded and wrapped in [CLS] ... [SEP]. Strip padding and
# special tokens from each row, then right-pad the batch with PADDING_VALUE so
# that position i of every row corresponds to nucleotide i.
per_sequence_embeddings = []

for seq, ids, emb in zip(sequences, input_ids, embeddings):
    print('Embedding size: ', emb.size())
    stripped_emb = remove_special_tokens_and_padding(tokenizer, ids, emb)
    print('Masked emb size after postprocessing: ', stripped_emb.size())
    # Character-level tokenizer: exactly one embedding per nucleotide must remain.
    assert stripped_emb.size(0) == len(seq), (stripped_emb.size(0), len(seq))
    per_sequence_embeddings.append(stripped_emb)

masked_embeddings = torch.nn.utils.rnn.pad_sequence(
    per_sequence_embeddings, batch_first=True, padding_value=PADDING_VALUE
)
masked_embeddings.size()

Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([9, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([14, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([9, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([5, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([9, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([15, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([13, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([9, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([13, 128])
Embedding size:  torch.Size([17, 128])
Masked emb size after postprocessing:  torch.Size([6, 128])


torch.Size([10, 15, 128])

In [27]:
# Last hidden feature at the last (right-padded) position. It equals PADDING_VALUE
# (-100) for every sequence shorter than the longest one.
masked_embeddings[:, -1, -1]

tensor([-1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02,
        -2.0537e-02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02])