In [27]:
from bend.models.hyena_dna import HyenaDNAPreTrainedModel, CharacterTokenizer
import os
from utils import generate_random_dna_sequence, get_device, chunkify_sequences, process_chunk_embeddings
import torch
import numpy as np
from bend_batch.datasets import DataSupervised
from hydra import compose, initialize
from torch.utils.data import DataLoader


WORK_PATH = '../../'
EMBEDDER_DIR = os.path.join(WORK_PATH, 'pretrained_models', 'hyenadna')
EMBEDDER_NAME = 'hyenadna-tiny-1k-seqlen'
EMBEDDER_PATH = os.path.join(EMBEDDER_DIR, EMBEDDER_NAME)

PADDING_VALUE = -100

MAX_LENGTHS = {
    "hyenadna-tiny-1k-seqlen": 1024,
    "hyenadna-small-32k-seqlen": 32768,
    "hyenadna-medium-160k-seqlen": 160000,
    "hyenadna-medium-450k-seqlen": 450000,
    "hyenadna-large-1m-seqlen": 1_000_000,
}

device = get_device()

Using device: mps


In [2]:
tokenizer = CharacterTokenizer(
    characters=["A", "C", "G", "T", "N"],  # add DNA characters, N is uncertain
    model_max_length=MAX_LENGTHS[EMBEDDER_NAME]
    + 2,  # to account for special tokens, like EOS
    add_special_tokens=False,  # we handle special tokens elsewhere
    padding_side="left",  # since HyenaDNA is causal, we pad on the left
)

model = HyenaDNAPreTrainedModel.from_pretrained(
    os.path.join(EMBEDDER_DIR, EMBEDDER_NAME),
    EMBEDDER_NAME,
    download=not os.path.exists(EMBEDDER_DIR),
    config=None,
    device=device,
    use_head=False,
    use_lm_head=False,  # we don't use the LM head for embeddings
    n_classes=2,
).eval().to(device)

Loaded pretrained weights ok!


In [3]:
sequences = [generate_random_dna_sequence(min_length=5, max_length=15) for _ in range(10)]
sequences

['AGGCACTAAATCTTA',
 'GAGCATT',
 'TGATCTACCCC',
 'GCGGTGTATATTC',
 'GTGTA',
 'CGCCCTCCTCAGGAA',
 'TTAGTACGAT',
 'GGTCGCTTC',
 'CAAGAAAT',
 'CTAGCA']

#### Divide into chunks

In [25]:
MAX_MODEL_LENGTH = 4
chunked_sequences = []
chunk_ids = []

for seq_idx, seq in enumerate(sequences):
    chunked_sequence = [
        seq[i : i + MAX_MODEL_LENGTH]
        for i in range(0, len(seq), MAX_MODEL_LENGTH)
    ]
    chunked_sequences.extend(chunked_sequence)
    chunk_ids.extend([seq_idx] * len(chunked_sequence))

for seq, chunk_id in zip(chunked_sequences, chunk_ids):
    print(f"Chunk ID: {chunk_id}, Sequence: {seq}")

Chunk ID: 0, Sequence: AGGC
Chunk ID: 0, Sequence: ACTA
Chunk ID: 0, Sequence: AATC
Chunk ID: 0, Sequence: TTA
Chunk ID: 1, Sequence: GAGC
Chunk ID: 1, Sequence: ATT
Chunk ID: 2, Sequence: TGAT
Chunk ID: 2, Sequence: CTAC
Chunk ID: 2, Sequence: CCC
Chunk ID: 3, Sequence: GCGG
Chunk ID: 3, Sequence: TGTA
Chunk ID: 3, Sequence: TATT
Chunk ID: 3, Sequence: C
Chunk ID: 4, Sequence: GTGT
Chunk ID: 4, Sequence: A
Chunk ID: 5, Sequence: CGCC
Chunk ID: 5, Sequence: CTCC
Chunk ID: 5, Sequence: TCAG
Chunk ID: 5, Sequence: GAA
Chunk ID: 6, Sequence: TTAG
Chunk ID: 6, Sequence: TACG
Chunk ID: 6, Sequence: AT
Chunk ID: 7, Sequence: GGTC
Chunk ID: 7, Sequence: GCTT
Chunk ID: 7, Sequence: C
Chunk ID: 8, Sequence: CAAG
Chunk ID: 8, Sequence: AAAT
Chunk ID: 9, Sequence: CTAG
Chunk ID: 9, Sequence: CA


#### Tokenise sequences

In [11]:
input_ids = tokenizer(
    chunked_sequences,
    return_tensors="pt",
    return_token_type_ids=False,
    return_attention_mask=False,  # HyenaDNA does not use attention masks
    padding="longest",
)["input_ids"]

# input_ids = output["input_ids"]
# attention_mask = output["attention_mask"]

input_ids

tensor([[ 0,  7,  9,  9,  8,  1],
        [ 0,  7,  8, 10,  7,  1],
        [ 0,  7,  7, 10,  8,  1],
        [ 4,  0, 10, 10,  7,  1],
        [ 0,  9,  7,  9,  8,  1],
        [ 4,  0,  7, 10, 10,  1],
        [ 0, 10,  9,  7, 10,  1],
        [ 0,  8, 10,  7,  8,  1],
        [ 4,  0,  8,  8,  8,  1],
        [ 0,  9,  8,  9,  9,  1],
        [ 0, 10,  9, 10,  7,  1],
        [ 0, 10,  7, 10, 10,  1],
        [ 4,  4,  4,  0,  8,  1],
        [ 0,  9, 10,  9, 10,  1],
        [ 4,  4,  4,  0,  7,  1],
        [ 0,  8,  9,  8,  8,  1],
        [ 0,  8, 10,  8,  8,  1],
        [ 0, 10,  8,  7,  9,  1],
        [ 4,  0,  9,  7,  7,  1],
        [ 0, 10, 10,  7,  9,  1],
        [ 0, 10,  7,  8,  9,  1],
        [ 4,  4,  0,  7, 10,  1],
        [ 0,  9,  9, 10,  8,  1],
        [ 0,  9,  8, 10, 10,  1],
        [ 4,  4,  4,  0,  8,  1],
        [ 0,  8,  7,  7,  9,  1],
        [ 0,  7,  7,  7, 10,  1],
        [ 0,  8, 10,  7,  9,  1],
        [ 4,  4,  0,  8,  7,  1]])

In [12]:
for ids in input_ids:
    print(tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False))

['[CLS]', 'A', 'G', 'G', 'C', '[SEP]']
['[CLS]', 'A', 'C', 'T', 'A', '[SEP]']
['[CLS]', 'A', 'A', 'T', 'C', '[SEP]']
['[PAD]', '[CLS]', 'T', 'T', 'A', '[SEP]']
['[CLS]', 'G', 'A', 'G', 'C', '[SEP]']
['[PAD]', '[CLS]', 'A', 'T', 'T', '[SEP]']
['[CLS]', 'T', 'G', 'A', 'T', '[SEP]']
['[CLS]', 'C', 'T', 'A', 'C', '[SEP]']
['[PAD]', '[CLS]', 'C', 'C', 'C', '[SEP]']
['[CLS]', 'G', 'C', 'G', 'G', '[SEP]']
['[CLS]', 'T', 'G', 'T', 'A', '[SEP]']
['[CLS]', 'T', 'A', 'T', 'T', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[CLS]', 'C', '[SEP]']
['[CLS]', 'G', 'T', 'G', 'T', '[SEP]']
['[PAD]', '[PAD]', '[PAD]', '[CLS]', 'A', '[SEP]']
['[CLS]', 'C', 'G', 'C', 'C', '[SEP]']
['[CLS]', 'C', 'T', 'C', 'C', '[SEP]']
['[CLS]', 'T', 'C', 'A', 'G', '[SEP]']
['[PAD]', '[CLS]', 'G', 'A', 'A', '[SEP]']
['[CLS]', 'T', 'T', 'A', 'G', '[SEP]']
['[CLS]', 'T', 'A', 'C', 'G', '[SEP]']
['[PAD]', '[PAD]', '[CLS]', 'A', 'T', '[SEP]']
['[CLS]', 'G', 'G', 'T', 'C', '[SEP]']
['[CLS]', 'G', 'C', 'T', 'T', '[SEP]']
['[PAD]', '[PAD]

In [8]:
attention_mask

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1]])

#### Embed sequences

In [13]:
input_ids = torch.LongTensor(input_ids)
embeddings = (
    model(input_ids=input_ids.to(device)).detach().cpu().numpy()
)
input_ids = input_ids.numpy()

embeddings.shape

(29, 6, 128)

In [9]:
embeddings[:, -1, -1]  # last token embeddings

array([-0.16824079,  0.04744038,  0.04514444,  0.19083378,  0.05306307,
       -0.19100481,  0.13585165,  0.49546304, -0.11520602,  0.35716903,
       -0.50841016, -0.03071012, -0.33436307, -0.3122338 , -0.5464077 ,
       -0.23307246, -0.33157188,  0.02428012,  0.2100084 , -0.37095368,
       -0.3122338 ,  0.04744038, -0.04458451,  0.42001677, -0.2036671 ,
        0.61932755, -0.2555642 , -0.07029043,  0.30389643, -0.2047453 ,
       -0.21036223], dtype=float32)

#### Post-process chunks into a list

In [14]:
masked_embeddings = []

for sequence_idx in np.unique(chunk_ids):

    mask_sequence = sequence_idx == chunk_ids
    concat_embeddings = np.concatenate(embeddings[mask_sequence], axis=0)
    concat_input_ids = np.concatenate(input_ids[mask_sequence], axis=0)

    # concat_embeddings = self._remove_special_tokens(
    #     concat_input_ids, concat_embeddings
    # )

    masked_embeddings.append(concat_embeddings)

In [16]:
for emb in masked_embeddings:
    print(emb.shape)

(24, 128)
(12, 128)
(18, 128)
(24, 128)
(12, 128)
(24, 128)
(18, 128)
(18, 128)
(12, 128)
(12, 128)


In [18]:
sequence_embeddings, masked_tokens = process_chunk_embeddings(tokenizer, embeddings, input_ids, chunk_ids, upsample=False)

for i, seq_emb in enumerate(sequence_embeddings):
    assert sequences[i] == ''.join(masked_tokens[i]), f"Mismatch in sequence {i}: {sequences[i]} != {''.join(masked_tokens[i])}"
    print(f"Chunk ID: {np.unique(chunk_ids)[i]}")
    print(f"Sequence: {sequences[i]}  Length: {len(sequences[i])}")
    print(f"  Tokens: {''.join(masked_tokens[i])}")
    print(f"Embedding shape: {seq_emb.shape}")

Chunk ID: 0
Sequence: CACCGAACAGCGGCG  Length: 15
  Tokens: CACCGAACAGCGGCG
Embedding shape: (15, 128)
Chunk ID: 1
Sequence: CCTGTTAACGT  Length: 11
  Tokens: CCTGTTAACGT
Embedding shape: (11, 128)
Chunk ID: 2
Sequence: TATGTAGG  Length: 8
  Tokens: TATGTAGG
Embedding shape: (8, 128)
Chunk ID: 3
Sequence: ACACCTCACCCCG  Length: 13
  Tokens: ACACCTCACCCCG
Embedding shape: (13, 128)
Chunk ID: 4
Sequence: TCCATCAAAACGCAG  Length: 15
  Tokens: TCCATCAAAACGCAG
Embedding shape: (15, 128)
Chunk ID: 5
Sequence: AGCACT  Length: 6
  Tokens: AGCACT
Embedding shape: (6, 128)
Chunk ID: 6
Sequence: ACGGTCCAGAACCGC  Length: 15
  Tokens: ACGGTCCAGAACCGC
Embedding shape: (15, 128)
Chunk ID: 7
Sequence: CGTAA  Length: 5
  Tokens: CGTAA
Embedding shape: (5, 128)
Chunk ID: 8
Sequence: AGCTTGGGGC  Length: 10
  Tokens: AGCTTGGGGC
Embedding shape: (10, 128)
Chunk ID: 9
Sequence: GTTAGCCAT  Length: 9
  Tokens: GTTAGCCAT
Embedding shape: (9, 128)


# Real dataset

In [None]:
task = "gene_finding"  # or any other task you want to test

with initialize(version_base=None, config_path="../../config/"):
    cfg_batch = compose(config_name="config", overrides=[f"tasks@task={task}"])

dataset = DataSupervised(
    '../../data/gene_finding/gene_finding.bed',
    '../../data/genomes/GRCh38.primary_assembly.genome.fa',
    hdf5_path='../../data/gene_finding/gene_finding.hdf5',
    label_depth=cfg_batch["task"]["dataset"].get("label_depth", None),
    sequence_length=cfg_batch["task"]["dataset"].get("sequence_length", None),
    split='test',  # or 'train', 'valid' depending on your needs
)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

Filtered annotations to 597
Filtered annotations from 5977 to 597




  0%|          | 0/597 [00:00<?, ?it/s]

In [None]:
MAX_MODEL_LENGTH = 1024  # Adjust this based on your model's max length

for _, (sequences, _) in enumerate(dataloader):
        print(f"Len sequence: {len(sequences[0])} bases")

        chunked_sequences = []
        chunk_ids = []

        for seq_idx, seq in enumerate(sequences):
            chunked_sequence = [
                seq[i : i + MAX_MODEL_LENGTH]
                for i in range(0, len(seq), MAX_MODEL_LENGTH)
            ]
            chunked_sequences.extend(chunked_sequence)
            chunk_ids.extend([seq_idx] * len(chunked_sequence))

        for seq, chunk_id in zip(chunked_sequences, chunk_ids):
            print(f"Chunk ID: {chunk_id}, Sequence: {seq} bases")

        input_ids = tokenizer(
            chunked_sequences,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=False,  # HyenaDNA does not use attention masks
            padding="longest",
        )["input_ids"]

        for ids in input_ids:
            print(tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False))

        break

Len sequence: 6567 bases
Chunk ID: 0, Sequence: CAAGGGGTCCCTACCCACCCCGCCCCTGGCTGGGGGCCGTAGGCTGGCCCTGCCCCGCTGCGCCGGAGGTGGACCTCTACCAGGGCAGTTTCTCTCTGAGGCTGCGCGCTAAGGCGGTGGGCGGTCCCAGGCAGGCCCAGAAGCTGGGCAGCCTCTGCCGGGTTCCGGGAAAAGGAGCTCCTGCTGCCACTGCTCTTCCGGAGCCTGCAGCATGGGGCCCCTGCCGCGCACCGTGGAGCTCTTCTATGACGTGCTGTCCCCCTACTCCTGGCTGGGCTTCGAGGTGACGCTGGGAGGGGTCGCCTCGGCAGTGTCTGGGGAGTGAGGGCGGAGGGAAGAGTGAGCGCAGGGCTCCGGGACAGAGGTCTCGTGTAACTCCTGGCGCCGCCCAAGGGGTTAAGGCAAGCAGGGAGAGCTCCGGGGCTGAAGGTCACTTTGTGCTTTTAAACGGAATAGAGTCGCTGGCTCCAACCCGAGCCTTTATATCCCGACTGCAGTTTCCCACGGTGGTGGAAAGAGGGGCGGCTCCAAACATTCAGGGAGAAATGCAGAACACGGCCACCTCTTAGCCCAACGATGGCAGTTTTGGGGAAAACTGGCCACAGGAGCGAAGATCCTGGAATAGATTTCTAGTAGTGAGGAGAGTTTTGCCAATTTCAAACCAAACAATGTAGGGTGTCCGCATAACCCCTGCTCTGCACTTAGCGCCTCCCTTCCCTTTCCCTCGGCTCTTTCACTTTTCCATCTTGCACACTGGCAATCGTTCTTGACCTCTGCCCGCAGATCCTGTGCCGGTATCAGAATATCTGGAACATCAACCTGCAGTTGCGGCCCAGCCTCATAACAGGGATCATGAAAGACAGTGGTAGGAAGGGAGGGTCGGGGCAGGGGTGATCTCAGTGGCCCAAGAGAGCCGACCCCAGCGGGTCCTTATATTGGCACTGGCAGCCAG