In [None]:
import os
import random

import numpy as np
import torch

from bend.models.hyena_dna import HyenaDNAPreTrainedModel, CharacterTokenizer
from utils import generate_random_dna_sequence, get_device, chunkify_sequences, process_chunk_embeddings

WORK_PATH = '../../'
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models', 'hyenadna')
EMBEDDER_NAME = 'hyenadna-tiny-1k-seqlen'
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

PADDING_VALUE = -100

# Maximum input length of each HyenaDNA checkpoint (mirrors the "-seqlen" suffix
# in its name); used to size the tokenizer below.
MAX_LENGTHS = {
    "hyenadna-tiny-1k-seqlen": 1024,
    "hyenadna-small-32k-seqlen": 32768,
    "hyenadna-medium-160k-seqlen": 160000,
    "hyenadna-medium-450k-seqlen": 450000,
    "hyenadna-large-1m-seqlen": 1_000_000,
}

# The demo sequences are generated randomly, so seed every RNG the helpers might
# draw from; otherwise Restart & Run All produces different data each time.
# NOTE(review): assumes generate_random_dna_sequence uses `random`, NumPy or torch — confirm in utils.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = get_device()

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


Using device: mps


In [2]:
tokenizer = CharacterTokenizer(
    characters=["A", "C", "G", "T", "N"],  # DNA alphabet; N marks an uncertain base
    # +2 leaves room for the special tokens ([CLS] and [SEP]) placed around each input.
    model_max_length=MAX_LENGTHS[EMBEDDER_NAME] + 2,
    # NOTE: this constructor kwarg does not stop the tokenizer call from adding
    # special tokens -- the tokenised chunks in this notebook still start with
    # [CLS] and end with [SEP]. Pass add_special_tokens=False at call time to drop them.
    add_special_tokens=False,
    padding_side="left",  # since HyenaDNA is causal, we pad on the left
)

# Download only when this particular checkpoint is missing. Checking the shared
# EMBEDDER_DIR instead would skip the download whenever any other HyenaDNA
# checkpoint had already been fetched there.
model = HyenaDNAPreTrainedModel.from_pretrained(
    EMBEDDER_PATH,
    EMBEDDER_NAME,
    download=not os.path.exists(EMBEDDER_PATH),
    config=None,
    device=device,
    use_head=False,
    use_lm_head=False,  # we don't use the LM head for embeddings
    n_classes=2,
).eval().to(device)

Loaded pretrained weights ok!


In [3]:
# A handful of short random toy sequences: small enough to inspect by eye, long
# enough to be split into several chunks below.
N_TOY_SEQUENCES = 10
sequences = []
for _ in range(N_TOY_SEQUENCES):
    sequences.append(generate_random_dna_sequence(min_length=5, max_length=15))
sequences

['CACCGAACAGCGGCG',
 'CCTGTTAACGT',
 'TATGTAGG',
 'ACACCTCACCCCG',
 'TCCATCAAAACGCAG',
 'AGCACT',
 'ACGGTCCAGAACCGC',
 'CGTAA',
 'AGCTTGGGGC',
 'GTTAGCCAT']

#### Divide into chunks

In [4]:
# Deliberately tiny chunk length (the checkpoint itself accepts
# MAX_LENGTHS[EMBEDDER_NAME] tokens) so the toy sequences are split into several
# chunks and the reassembly step further down is actually exercised.
MAX_MODEL_LENGTH = 4
chunked_sequences, chunk_ids = chunkify_sequences(sequences, MAX_MODEL_LENGTH)

# chunk_ids[k] names the source sequence of chunked_sequences[k].
for chunk_id, chunk in zip(chunk_ids, chunked_sequences):
    print(f"Chunk ID: {chunk_id}, Sequence: {chunk}")

Chunk ID: 0, Sequence: CACC
Chunk ID: 0, Sequence: GAAC
Chunk ID: 0, Sequence: AGCG
Chunk ID: 0, Sequence: GCG
Chunk ID: 1, Sequence: CCTG
Chunk ID: 1, Sequence: TTAA
Chunk ID: 1, Sequence: CGT
Chunk ID: 2, Sequence: TATG
Chunk ID: 2, Sequence: TAGG
Chunk ID: 3, Sequence: ACAC
Chunk ID: 3, Sequence: CTCA
Chunk ID: 3, Sequence: CCCC
Chunk ID: 3, Sequence: G
Chunk ID: 4, Sequence: TCCA
Chunk ID: 4, Sequence: TCAA
Chunk ID: 4, Sequence: AACG
Chunk ID: 4, Sequence: CAG
Chunk ID: 5, Sequence: AGCA
Chunk ID: 5, Sequence: CT
Chunk ID: 6, Sequence: ACGG
Chunk ID: 6, Sequence: TCCA
Chunk ID: 6, Sequence: GAAC
Chunk ID: 6, Sequence: CGC
Chunk ID: 7, Sequence: CGTA
Chunk ID: 7, Sequence: A
Chunk ID: 8, Sequence: AGCT
Chunk ID: 8, Sequence: TGGG
Chunk ID: 8, Sequence: GC
Chunk ID: 9, Sequence: GTTA
Chunk ID: 9, Sequence: GCCA
Chunk ID: 9, Sequence: T


#### Tokenise sequences

In [5]:
# Tokenise all chunks as one batch; shorter chunks are left-padded with [PAD]
# up to the longest chunk.
tokenized = tokenizer(
    chunked_sequences,
    padding="longest",
    return_tensors="pt",
    return_token_type_ids=False,
)
input_ids, attention_mask = tokenized["input_ids"], tokenized["attention_mask"]

input_ids

tensor([[ 0,  8,  7,  8,  8,  1],
        [ 0,  9,  7,  7,  8,  1],
        [ 0,  7,  9,  8,  9,  1],
        [ 4,  0,  9,  8,  9,  1],
        [ 0,  8,  8, 10,  9,  1],
        [ 0, 10, 10,  7,  7,  1],
        [ 4,  0,  8,  9, 10,  1],
        [ 0, 10,  7, 10,  9,  1],
        [ 0, 10,  7,  9,  9,  1],
        [ 0,  7,  8,  7,  8,  1],
        [ 0,  8, 10,  8,  7,  1],
        [ 0,  8,  8,  8,  8,  1],
        [ 4,  4,  4,  0,  9,  1],
        [ 0, 10,  8,  8,  7,  1],
        [ 0, 10,  8,  7,  7,  1],
        [ 0,  7,  7,  8,  9,  1],
        [ 4,  0,  8,  7,  9,  1],
        [ 0,  7,  9,  8,  7,  1],
        [ 4,  4,  0,  8, 10,  1],
        [ 0,  7,  8,  9,  9,  1],
        [ 0, 10,  8,  8,  7,  1],
        [ 0,  9,  7,  7,  8,  1],
        [ 4,  0,  8,  9,  8,  1],
        [ 0,  8,  9, 10,  7,  1],
        [ 4,  4,  4,  0,  7,  1],
        [ 0,  7,  9,  8, 10,  1],
        [ 0, 10,  9,  9,  9,  1],
        [ 4,  4,  0,  9,  8,  1],
        [ 0,  9, 10, 10,  7,  1],
        [ 0,  

In [6]:
# Decode every padded row back to token strings (special tokens kept) to show
# where [PAD], [CLS] and [SEP] end up.
decoded_rows = (
    tokenizer.convert_ids_to_tokens(row, skip_special_tokens=False) for row in input_ids
)
for tokens in decoded_rows:
    print(tokens)

['[CLS]', 'C', 'A', 'C', 'C', '[SEP]']
['[CLS]', 'G', 'A', 'A', 'C', '[SEP]']
['[CLS]', 'A', 'G', 'C', 'G', '[SEP]']
['[PAD]', '[CLS]', 'G', 'C', 'G', '[SEP]']
['[CLS]', 'C', 'C', 'T', 'G', '[SEP]']
['[CLS]', 'T', 'T', 'A', 'A', '[SEP]']
['[PAD]', '[CLS]', 'C', 'G', 'T', '[SEP]']
['[CLS]', 'T', 'A', 'T', 'G', '[SEP]']
['[CLS]', 'T', 'A', 'G', 'G', '[SEP]']
['[CLS]', 'A', 'C', 'A', 'C', '[SEP]']
['[CLS]', 'C', 'T', 'C', 'A', '[SEP]']
['[CLS]', 'C', 'C', 'C', 'C', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[CLS]', 'G', '[SEP]']
['[CLS]', 'T', 'C', 'C', 'A', '[SEP]']
['[CLS]', 'T', 'C', 'A', 'A', '[SEP]']
['[CLS]', 'A', 'A', 'C', 'G', '[SEP]']
['[PAD]', '[CLS]', 'C', 'A', 'G', '[SEP]']
['[CLS]', 'A', 'G', 'C', 'A', '[SEP]']
['[PAD]', '[PAD]', '[CLS]', 'C', 'T', '[SEP]']
['[CLS]', 'A', 'C', 'G', 'G', '[SEP]']
['[CLS]', 'T', 'C', 'C', 'A', '[SEP]']
['[CLS]', 'G', 'A', 'A', 'C', '[SEP]']
['[PAD]', '[CLS]', 'C', 'G', 'C', '[SEP]']
['[CLS]', 'C', 'G', 'T', 'A', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', 

In [7]:
# 1 = real token ([CLS], bases, [SEP]); 0 = left padding. The model call below
# receives only input_ids, so this mask is shown for inspection only.
attention_mask

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1]])

#### Embed sequences

In [8]:
# torch.as_tensor accepts both the tensor from the tokenisation cell and the
# NumPy array this cell leaves behind, so re-running the cell is safe.
input_ids_tensor = torch.as_tensor(input_ids, dtype=torch.long)

# Pure forward pass: inference_mode skips recording the autograd graph (and
# keeping activations alive for a backward pass that never happens).
with torch.inference_mode():
    embeddings = model(input_ids=input_ids_tensor.to(device)).cpu().numpy()

# process_chunk_embeddings (below) is given the token ids as a NumPy array.
input_ids = input_ids_tensor.numpy()

embeddings.shape

(31, 6, 128)

In [9]:
# Embedding of the final position of every chunk -> (n_chunks, hidden_dim).
# With left padding that position is always [SEP], not a nucleotide.
last_position_embeddings = embeddings[:, -1, :]
last_position_embeddings[:, -1]  # only the last feature, to keep the output short

array([-0.16824079,  0.04744038,  0.04514444,  0.19083378,  0.05306307,
       -0.19100481,  0.13585165,  0.49546304, -0.11520602,  0.35716903,
       -0.50841016, -0.03071012, -0.33436307, -0.3122338 , -0.5464077 ,
       -0.23307246, -0.33157188,  0.02428012,  0.2100084 , -0.37095368,
       -0.3122338 ,  0.04744038, -0.04458451,  0.42001677, -0.2036671 ,
        0.61932755, -0.2555642 , -0.07029043,  0.30389643, -0.2047453 ,
       -0.21036223], dtype=float32)

#### Reassemble chunk embeddings into per-sequence embeddings

In [18]:
sequence_embeddings, masked_tokens = process_chunk_embeddings(
    tokenizer, embeddings, input_ids, chunk_ids, upsample=False
)

# Every original sequence must come back exactly once.
assert len(sequence_embeddings) == len(sequences) == len(masked_tokens), (
    f"Expected {len(sequences)} reassembled sequences, got "
    f"{len(sequence_embeddings)} embeddings and {len(masked_tokens)} token lists"
)
# Entry i is paired with the i-th sorted unique chunk id, i.e. this assumes
# process_chunk_embeddings returns sequences in chunk-id order.
unique_chunk_ids = np.unique(chunk_ids)  # computed once, not on every iteration

for i, seq_emb in enumerate(sequence_embeddings):
    reconstructed = ''.join(masked_tokens[i])
    assert sequences[i] == reconstructed, f"Mismatch in sequence {i}: {sequences[i]} != {reconstructed}"
    # With upsample=False there should be exactly one embedding row per base.
    assert seq_emb.shape[0] == len(sequences[i]), (
        f"Embedding length mismatch in sequence {i}: {seq_emb.shape[0]} != {len(sequences[i])}"
    )
    print(f"Chunk ID: {unique_chunk_ids[i]}")
    print(f"Sequence: {sequences[i]}  Length: {len(sequences[i])}")
    print(f"  Tokens: {reconstructed}")
    print(f"Embedding shape: {seq_emb.shape}")

Chunk ID: 0
Sequence: CACCGAACAGCGGCG  Length: 15
  Tokens: CACCGAACAGCGGCG
Embedding shape: (15, 128)
Chunk ID: 1
Sequence: CCTGTTAACGT  Length: 11
  Tokens: CCTGTTAACGT
Embedding shape: (11, 128)
Chunk ID: 2
Sequence: TATGTAGG  Length: 8
  Tokens: TATGTAGG
Embedding shape: (8, 128)
Chunk ID: 3
Sequence: ACACCTCACCCCG  Length: 13
  Tokens: ACACCTCACCCCG
Embedding shape: (13, 128)
Chunk ID: 4
Sequence: TCCATCAAAACGCAG  Length: 15
  Tokens: TCCATCAAAACGCAG
Embedding shape: (15, 128)
Chunk ID: 5
Sequence: AGCACT  Length: 6
  Tokens: AGCACT
Embedding shape: (6, 128)
Chunk ID: 6
Sequence: ACGGTCCAGAACCGC  Length: 15
  Tokens: ACGGTCCAGAACCGC
Embedding shape: (15, 128)
Chunk ID: 7
Sequence: CGTAA  Length: 5
  Tokens: CGTAA
Embedding shape: (5, 128)
Chunk ID: 8
Sequence: AGCTTGGGGC  Length: 10
  Tokens: AGCTTGGGGC
Embedding shape: (10, 128)
Chunk ID: 9
Sequence: GTTAGCCAT  Length: 9
  Tokens: GTTAGCCAT
Embedding shape: (9, 128)
