In [None]:
import os
import random

import numpy as np
import torch
from transformers import AutoTokenizer

from bend.models.dnabert2 import BertForMaskedLM as DNABert2BertForMaskedLM
from utils import generate_random_dna_sequence, get_device, chunkify_sequences, process_chunk_embeddings

# Checkpoint identifier handed to `from_pretrained` for both the model and the tokenizer.
EMBEDDER_PATH = "zhihan1996/DNABERT-2-117M"

# Padding sentinel. NOTE(review): not referenced by any cell visible in this notebook.
PADDING_VALUE = -100

# Resolve the compute device once; the model and its inputs are moved onto it later.
# get_device() also reports its choice (see the "Using device: ..." output).
device = get_device()

Using device: mps


In [4]:
# Load the pretrained masked-LM weights, then switch to inference mode and
# move the parameters onto the selected device.
model = DNABert2BertForMaskedLM.from_pretrained(EMBEDDER_PATH)
model = model.eval().to(device)

# Tokenizer for the same checkpoint; trust_remote_code=True is required by
# this call as written in the original notebook.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH, trust_remote_code=True)



In [5]:
# Fix the RNG state so the sampled sequences, and every output derived from
# them, are identical on each Restart & Run All.
# NOTE(review): generate_random_dna_sequence lives in utils and its RNG source
# is not visible here; seeding both the stdlib and NumPy global generators
# covers the usual cases. Confirm against the helper.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

# Ten toy sequences of 5-15 nucleotides: short enough to inspect by eye,
# long enough that most are split into several chunks below.
sequences = [generate_random_dna_sequence(min_length=5, max_length=15) for _ in range(10)]

for s in sequences:
    print(s, len(s))

CGCCAGAGCACTAT 14
GCCGACC 7
AGTAGATGA 9
ATCGAAGAAA 10
GATTTACACTTAGTT 15
TGAACTCC 8
TTGTT 5
GGCTCGG 7
GTAGATCTGGAC 12
TCATATCTAAAA 12


In [6]:
# Maximum chunk length in nucleotides. Deliberately tiny so that nearly every
# toy sequence is split into several chunks.
MAX_MODEL_LENGTH = 4

# chunk_ids[k] identifies which input sequence chunked_sequences[k] was cut from.
chunked_sequences, chunk_ids = chunkify_sequences(sequences, MAX_MODEL_LENGTH)

for source_idx, piece in zip(chunk_ids, chunked_sequences):
    print(f"Chunk ID: {source_idx}, Sequence: {piece}")

Chunk ID: 0, Sequence: CGCC
Chunk ID: 0, Sequence: AGAG
Chunk ID: 0, Sequence: CACT
Chunk ID: 0, Sequence: AT
Chunk ID: 1, Sequence: GCCG
Chunk ID: 1, Sequence: ACC
Chunk ID: 2, Sequence: AGTA
Chunk ID: 2, Sequence: GATG
Chunk ID: 2, Sequence: A
Chunk ID: 3, Sequence: ATCG
Chunk ID: 3, Sequence: AAGA
Chunk ID: 3, Sequence: AA
Chunk ID: 4, Sequence: GATT
Chunk ID: 4, Sequence: TACA
Chunk ID: 4, Sequence: CTTA
Chunk ID: 4, Sequence: GTT
Chunk ID: 5, Sequence: TGAA
Chunk ID: 5, Sequence: CTCC
Chunk ID: 6, Sequence: TTGT
Chunk ID: 6, Sequence: T
Chunk ID: 7, Sequence: GGCT
Chunk ID: 7, Sequence: CGG
Chunk ID: 8, Sequence: GTAG
Chunk ID: 8, Sequence: ATCT
Chunk ID: 8, Sequence: GGAC
Chunk ID: 9, Sequence: TCAT
Chunk ID: 9, Sequence: ATCT
Chunk ID: 9, Sequence: AAAA


#### Tokenise sequences

In [7]:
# Tokenise all chunks as one batch, padding each row to the longest chunk so
# the result is a rectangular tensor.
output = tokenizer(chunked_sequences, padding="longest", return_token_type_ids=False, return_tensors="pt")

# Unpack the two tensors the model needs.
input_ids, attention_mask = (output[key] for key in ("input_ids", "attention_mask"))

input_ids

tensor([[   1,  169,    2,    3,    3],
        [   1,    5,   17,    7,    2],
        [   1,   12, 1038,    2,    3],
        [   1,    5,    8,    2,    3],
        [   1,   36,    7,    2,    3],
        [   1,    5,   13,    2,    3],
        [   1,    5,   35,    2,    3],
        [   1,   83,    2,    3,    3],
        [   1,    5,    2,    3,    3],
        [   1,    5,   16,    7,    2],
        [   1,    9,   17,    2,    3],
        [   1,    9,    2,    3,    3],
        [   1,   73,    2,    3,    3],
        [   1,   80,    2,    3,    3],
        [   1,   81,    2,    3,    3],
        [   1,   31,    2,    3,    3],
        [   1,   52,    2,    3,    3],
        [   1,   78,    2,    3,    3],
        [   1,   10, 1049,    2,    3],
        [   1,    8,    2,    3,    3],
        [   1,   15, 1038,    2,    3],
        [   1,   72,    2,    3,    3],
        [   1,   35,    7,    2,    3],
        [   1,    5, 2877,    2,    3],
        [   1,   33,    6,    2,    3],


In [8]:
# Decode every row back to token strings, keeping [CLS]/[SEP]/[PAD] visible so
# the padding layout can be compared with the attention mask.
for row in input_ids:
    tokens = tokenizer.convert_ids_to_tokens(row, skip_special_tokens=False)
    print(tokens)

['[CLS]', 'CGCC', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'A', 'GA', 'G', '[SEP]']
['[CLS]', 'CA', 'CT', '[SEP]', '[PAD]']
['[CLS]', 'A', 'T', '[SEP]', '[PAD]']
['[CLS]', 'GCC', 'G', '[SEP]', '[PAD]']
['[CLS]', 'A', 'CC', '[SEP]', '[PAD]']
['[CLS]', 'A', 'GTA', '[SEP]', '[PAD]']
['[CLS]', 'GATG', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'A', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'A', 'TC', 'G', '[SEP]']
['[CLS]', 'AA', 'GA', '[SEP]', '[PAD]']
['[CLS]', 'AA', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'GATT', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'TACA', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'CTTA', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'GTT', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'TGAA', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'CTCC', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'TT', 'GT', '[SEP]', '[PAD]']
['[CLS]', 'T', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'GG', 'CT', '[SEP]', '[PAD]']
['[CLS]', 'CGG', '[SEP]', '[PAD]', '[PAD]']
['[CLS]', 'GTA', 'G', '[SEP]', '[PAD]']
['[CLS]', 'A', 'TCT', '[SEP]', '[PAD]']
['[

In [9]:
# Show the padding mask: comparing with the decoded tokens above, 1 marks real
# tokens (including [CLS] and [SEP]) and 0 marks [PAD] positions.
attention_mask

tensor([[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0]])

#### Embed sequences

In [10]:
# Run a single inference-only forward pass over all chunks.
# - torch.no_grad() avoids building an autograd graph that is never used.
# - The tensors are read from `output` rather than from `input_ids` /
#   `attention_mask`, so this cell can be re-run safely even though it rebinds
#   `input_ids` to a NumPy array below.
with torch.no_grad():
    model_output = model(
        input_ids=output["input_ids"].to(device),
        attention_mask=output["attention_mask"].to(device),
    )

# Bring the per-token hidden states back to host memory as a NumPy array.
embeddings = model_output["hidden_states"].cpu().numpy()

# The downstream post-processing cell is given NumPy token IDs alongside the embeddings.
input_ids = output["input_ids"].numpy()

embeddings.shape


(28, 5, 768)

In [11]:
# Last hidden feature at the final token position of every chunk. In the
# recorded output the only non-zero entries belong to the two rows whose
# attention mask is all ones, i.e. rows whose last position is not [PAD].
# NOTE(review): this suggests padded positions come back as zeros; confirm in the model code.
embeddings[:, -1, -1]

array([ 0.        , -0.06105152,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.00236766,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ], dtype=float32)

#### Upsample token embeddings to nucleotide resolution

In [14]:
# Merge the per-chunk embeddings back into one array per original sequence.
# With upsample=True the recorded shapes show one embedding row per nucleotide.
# NOTE(review): the exact merging/upsampling rules live in utils.process_chunk_embeddings.
sequence_embeddings, masked_tokens = process_chunk_embeddings(tokenizer, embeddings, input_ids, chunk_ids, upsample=True)

# Guard against silently checking fewer sequences than were generated.
assert len(sequence_embeddings) == len(sequences), (
    f"Expected {len(sequences)} sequence embeddings, got {len(sequence_embeddings)}"
)

# The distinct chunk IDs do not change inside the loop, so compute them once.
unique_chunk_ids = np.unique(chunk_ids)

for i, seq_emb in enumerate(sequence_embeddings):
    # Joining the recovered tokens must reproduce the original sequence exactly.
    reconstructed = ''.join(masked_tokens[i])
    assert sequences[i] == reconstructed, f"Mismatch in sequence {i}: {sequences[i]} != {reconstructed}"
    print(f"Chunk ID: {unique_chunk_ids[i]}")
    print(f"Sequence: {sequences[i]}  Length: {len(sequences[i])}")
    print(f"  Tokens: {reconstructed}")
    print(f"Embedding shape: {seq_emb.shape}")

Chunk ID: 0
Sequence: CGCCAGAGCACTAT  Length: 14
  Tokens: CGCCAGAGCACTAT
Embedding shape: (14, 768)
Chunk ID: 1
Sequence: GCCGACC  Length: 7
  Tokens: GCCGACC
Embedding shape: (7, 768)
Chunk ID: 2
Sequence: AGTAGATGA  Length: 9
  Tokens: AGTAGATGA
Embedding shape: (9, 768)
Chunk ID: 3
Sequence: ATCGAAGAAA  Length: 10
  Tokens: ATCGAAGAAA
Embedding shape: (10, 768)
Chunk ID: 4
Sequence: GATTTACACTTAGTT  Length: 15
  Tokens: GATTTACACTTAGTT
Embedding shape: (15, 768)
Chunk ID: 5
Sequence: TGAACTCC  Length: 8
  Tokens: TGAACTCC
Embedding shape: (8, 768)
Chunk ID: 6
Sequence: TTGTT  Length: 5
  Tokens: TTGTT
Embedding shape: (5, 768)
Chunk ID: 7
Sequence: GGCTCGG  Length: 7
  Tokens: GGCTCGG
Embedding shape: (7, 768)
Chunk ID: 8
Sequence: GTAGATCTGGAC  Length: 12
  Tokens: GTAGATCTGGAC
Embedding shape: (12, 768)
Chunk ID: 9
Sequence: TCATATCTAAAA  Length: 12
  Tokens: TCATATCTAAAA
Embedding shape: (12, 768)
