In [None]:
import os
import random

import numpy as np
import torch
from transformers import AutoTokenizer

from bend.models.dnabert2 import BertForMaskedLM as DNABert2BertForMaskedLM
from utils import generate_random_dna_sequence, get_device, remove_special_tokens_and_padding, upsample

EMBEDDER_PATH = 'zhihan1996/DNABERT-2-117M'

# Fill value used by pad_sequence when the per-nucleotide embeddings of
# different-length sequences are batched together in the Upsample section.
# NOTE(review): presumably -100 to match the usual "ignore index" convention — confirm.
PADDING_VALUE = -100

# Seed every RNG the notebook might touch so the random test sequences (and hence
# every output below) are reproducible under Restart Kernel -> Run All.
# NOTE(review): which RNG utils.generate_random_dna_sequence draws from is not
# visible here; random, numpy and torch are all seeded to be safe — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = get_device()

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


Using device: mps


In [2]:
# Load the DNABERT-2 masked-LM model, switch it to eval mode, and move it to the
# selected device; then load the tokenizer published with the same checkpoint.
model = DNABert2BertForMaskedLM.from_pretrained(EMBEDDER_PATH)
model = model.eval()
model = model.to(device)

# NOTE(review): trust_remote_code=True allows the Hub repository to execute its own
# Python code on load; consider pinning a `revision` if this notebook is shared.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH, trust_remote_code=True)



In [3]:
# Draw ten random DNA strings whose lengths vary between min_length and max_length
# (bound inclusivity is up to utils.generate_random_dna_sequence), so the tokenised
# batch below needs padding.
sequences = []
while len(sequences) < 10:
    sequences.append(generate_random_dna_sequence(min_length=5, max_length=15))
sequences

['ACGAATGAGGCC',
 'TTGCGTGCGC',
 'GGACATC',
 'AAACATATGGCGGC',
 'TTTACCTAGAGA',
 'CGGGTT',
 'ATAGGCTACGGGTT',
 'GGTAG',
 'TCGCACTCAG',
 'CAGTG']

#### Tokenise sequences

In [4]:
# Encode the whole batch in one call. padding="longest" pads every row up to the
# longest tokenised sequence; token_type_ids are not requested.
output = tokenizer(
    sequences,
    padding="longest",
    return_token_type_ids=False,
    return_tensors="pt",
)
input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[   1,    5,  166,   23,  136,    2],
        [   1,   10,  577,  118,    2,    3],
        [   1,   33,  278,    2,    3,    3],
        [   1,   18,  902,  247,    6,    2],
        [   1,   94,  114,   50,    2,    3],
        [   1,   72,   31,    2,    3,    3],
        [   1,    5,   99, 3547,   31,    2],
        [   1,  138,    7,    2,    3,    3],
        [   1,  704,   63,    7,    2,    3],
        [   1,  176,    2,    3,    3,    3]])

In [5]:
# List the token strings of each row, keeping [CLS]/[SEP]/[PAD] so the padding
# layout of the batch is visible.
for row_ids in input_ids:
    row_tokens = tokenizer.convert_ids_to_tokens(row_ids, skip_special_tokens=False)
    print(row_tokens)

['[CLS]', 'A', 'CGAA', 'TGA', 'GGCC', '[SEP]']
['[CLS]', 'TT', 'GCGTG', 'CGC', '[SEP]', '[PAD]']
['[CLS]', 'GGA', 'CATC', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'AAA', 'CATATG', 'GCGG', 'C', '[SEP]']
['[CLS]', 'TTTA', 'CCTA', 'GAGA', '[SEP]', '[PAD]']
['[CLS]', 'CGG', 'GTT', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'A', 'TAGG', 'CTACGG', 'GTT', '[SEP]']
['[CLS]', 'GGTA', 'G', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'TCGCA', 'CTCA', 'G', '[SEP]', '[PAD]']
['[CLS]', 'CAGTG', '[SEP]', '[PAD]', '[PAD]', '[PAD]']


In [6]:
# 1 marks a real token (including [CLS]/[SEP]), 0 marks a [PAD] slot — compare row
# by row with the token listing above.
attention_mask

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0]])

#### Embed sequences

In [14]:
# Forward pass without autograd tracking: only the activations are needed, so
# there is no reason to build (and hold device memory for) a backward graph.
with torch.no_grad():
    model_output = model(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
    )

# Per-token hidden states, shape (batch, tokens, hidden), moved back to the CPU
# for the post-processing below.
embeddings = model_output["hidden_states"].cpu()
embeddings.size()

torch.Size([10, 6, 768])

In [18]:
# Last token position, last hidden feature of every sequence. In the recorded run,
# rows whose final slot is [PAD] come out as exactly 0.
embeddings[:, -1, -1]

tensor([-0.0203,  0.0000,  0.0000,  0.0340,  0.0000,  0.0000, -0.1651,  0.0000,
         0.0000,  0.0000])

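#### Upsample to nucleotide resolution

Special-token and padding rows are dropped, and the remaining token embeddings are upsampled so that every sequence ends up with one embedding row per nucleotide. The per-sequence results are then padded to a common length with `PADDING_VALUE`.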

In [None]:
# Bring the token-level embeddings back to nucleotide resolution one sequence at a
# time, then pad the results into a single batch tensor.
per_sequence_embeddings = []

for sequence, ids, emb in zip(sequences, input_ids, embeddings):
    print('---')
    print('Embedding size: ', emb.size())
    # Drop the rows belonging to special tokens ([CLS]/[SEP]) and padding.
    token_emb = remove_special_tokens_and_padding(tokenizer, ids, emb)
    print('Masked embedding size: ', token_emb.size())
    # Expand the remaining token rows so there is one row per nucleotide.
    nucleotide_emb = upsample(tokenizer, ids, token_emb)
    print('Upsampled embedding size: ', nucleotide_emb.size())
    # Guard the invariant the rest of the analysis relies on.
    assert nucleotide_emb.size(0) == len(sequence), (
        f"upsampled length {nucleotide_emb.size(0)} != "
        f"sequence length {len(sequence)} for {sequence!r}"
    )
    per_sequence_embeddings.append(nucleotide_emb)

# Shorter sequences are right-padded with PADDING_VALUE up to the longest one.
upsampled_embeddings = torch.nn.utils.rnn.pad_sequence(
    per_sequence_embeddings, batch_first=True, padding_value=PADDING_VALUE
)
assert upsampled_embeddings.shape[:2] == (len(sequences), max(map(len, sequences)))
upsampled_embeddings.size()

---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([4, 768])
Upsampled embedding size:  torch.Size([12, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([3, 768])
Upsampled embedding size:  torch.Size([10, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([2, 768])
Upsampled embedding size:  torch.Size([7, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([4, 768])
Upsampled embedding size:  torch.Size([14, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([3, 768])
Upsampled embedding size:  torch.Size([12, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([2, 768])
Upsampled embedding size:  torch.Size([6, 768])
---
Embedding size:  torch.Size([6, 768])
Masked embedding size:  torch.Size([4, 768])
Upsampled embedding size:  torch.Size([14, 768])
---
Embedding size:  torch.Size([6, 768])
Masked e

torch.Size([10, 14, 768])

In [17]:
# Last nucleotide position, last feature: only the longest sequence(s) have a real
# value here; every shorter sequence shows PADDING_VALUE (-100).
upsampled_embeddings[:, -1, -1]

tensor([-1.0000e+02, -1.0000e+02, -1.0000e+02,  3.4659e-01, -1.0000e+02,
        -1.0000e+02,  7.2293e-03, -1.0000e+02, -1.0000e+02, -1.0000e+02])