# 🧬 Test new dna_embedding_model.py (LoRA-ready) locally

In [4]:
!pip install transformers torch

import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import time

class DNAEmbedder:
    """Embed DNA sequences with a pretrained k-mer transformer (DNABERT-style).

    Each sequence is split into overlapping k-mers, joined with spaces, run
    through the Hugging Face model, and mean-pooled over the token axis of
    ``last_hidden_state`` into one vector per sequence.

    Args:
        model_id: Hugging Face model id or local path passed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.
        k: k-mer length. It should match the tokenizer's vocabulary (6 for
            ``armheb/DNA_bert_6``). NOTE(review): this is not checked
            against the vocab.
        device: torch device string or object. Defaults to ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Raises:
        ValueError: if ``k`` is not a positive integer.
    """

    def __init__(self, model_id="armheb/DNA_bert_6", k=6, device=None):
        # Validate before the slow model download.
        if not isinstance(k, int) or k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self.model_id = model_id
        self.k = k
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        print(f"🧠 Loading model {model_id} on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(model_id).to(self.device)
        self.model.eval()

        # Longest token sequence the model accepts. For BERT-style absolute
        # position embeddings, longer inputs overflow the embedding table and
        # crash, so embed() truncates to this. transformers reports an unset
        # tokenizer limit as a very large int, so min() effectively ignores it.
        limits = [
            n
            for n in (
                getattr(self.model.config, "max_position_embeddings", None),
                getattr(self.tokenizer, "model_max_length", None),
            )
            if isinstance(n, int) and n > 0
        ]
        self.max_length = min(limits) if limits else None

    def tokenize(self, sequence):
        """Return the space-separated overlapping k-mers of *sequence*.

        The sequence is upper-cased and ALL whitespace (spaces, newlines,
        tabs, e.g. from line-wrapped FASTA) is removed first. A sequence
        shorter than ``k`` yields ``""``.
        """
        bases = "".join(sequence.upper().split())
        return " ".join(
            bases[i:i + self.k] for i in range(len(bases) - self.k + 1)
        )

    def embed(self, sequence):
        """Return the mean-pooled embedding of one DNA sequence.

        Inputs longer than the model's token limit are truncated, with a
        printed notice.

        Returns:
            1-D numpy array of length ``model.config.hidden_size`` (768 for
            the default model). The mean is taken over every token the
            tokenizer produced, including any special tokens it inserts.

        Raises:
            ValueError: if the sequence has fewer than ``k`` bases after
                whitespace removal. Otherwise only special tokens would be
                embedded, giving a meaningless vector.
        """
        input_text = self.tokenize(sequence)
        if not input_text:
            raise ValueError(
                f"sequence must contain at least k={self.k} bases to form a k-mer"
            )

        if self.max_length is not None:
            # Assumes one token per k-mer, as with a k-mer vocabulary.
            capacity = self.max_length - self.tokenizer.num_special_tokens_to_add()
            n_kmers = len(input_text.split())
            if n_kmers > capacity:
                print(f"⚠️ {n_kmers} k-mers exceed the model limit; truncating to {capacity}")

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=self.max_length is not None,
            max_length=self.max_length,
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        device = torch.device(self.device)
        with torch.no_grad():
            start = time.perf_counter()
            output = self.model(**inputs)
            # CUDA kernels run asynchronously. Wait for them so the timing
            # reflects the forward pass rather than just the kernel launches.
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            duration = time.perf_counter() - start
            print(f"⚡ Embedding computed in {duration:.2f}s")

        # Mean over the token axis; [0] drops only the batch-of-one dimension.
        return output.last_hidden_state.mean(dim=1)[0].cpu().numpy()

    def embed_batch(self, sequences):
        """Embed each sequence in turn (one forward pass each) and stack them.

        Returns:
            numpy array of shape ``(len(sequences), hidden_size)``. An empty
            input gives shape ``(0, hidden_size)`` instead of raising.
        """
        vectors = [self.embed(seq) for seq in sequences]
        if not vectors:
            # np.vstack([]) raises "need at least one array to concatenate".
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)
        return np.vstack(vectors)



In [5]:
# Load the embedder with its default checkpoint (armheb/DNA_bert_6).
embedder = DNAEmbedder()

# A short toy fragment, just to smoke-test the pipeline end to end.
sequence = "ACGTAGCTAGCTTGACGTTGACGTGACGATCGTACG"
embedding = embedder.embed(sequence)

# Report the vector's shape and a small preview of its values.
print(f"✅ Embedding shape: {embedding.shape}")
print(f"🧬 First 10 values: {embedding[:10]}")

🧠 Loading model armheb/DNA_bert_6 on cpu...
⚡ Embedding computed in 0.07s
✅ Embedding shape: (768,)
🧬 First 10 values: [-0.2573953   0.8604346  -0.33829284 -0.26517555  0.6124543  -0.14811641
  0.2850135  -0.5794153  -0.60324174  0.52297175]
