In [None]:
# Colab Only
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm


In [None]:
# Colab Only
!pip install --upgrade git+https://github.com/huggingface/transformers.git

In [None]:
import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class AttentionPool(nn.Module):
    """Attention pooling over the token axis.

    A single linear layer assigns one score to every token. Masked positions
    are excluded, the remaining scores are softmax-normalised, and the result
    is the weighted sum of the token embeddings.
    """

    def __init__(self, dim):
        """
        Args:
            dim: size of the last (feature) dimension of the token embeddings.
        """
        super().__init__()
        self.att = nn.Linear(dim, 1)

    def forward(self, token_embs, mask):
        """Pool token embeddings into a single vector per sequence.

        Args:
            token_embs: tensor of shape (L, D), or batched as (..., L, D).
            mask: tensor of shape (L,), or (..., L) for batched input; entries
                equal to 0 mark positions (e.g. PAD tokens) to ignore.

        Returns:
            Tensor of shape (D,), or (..., D) for batched input.
        """
        scores = self.att(token_embs).squeeze(-1)     # (..., L)
        # Use the most negative value representable in the score dtype: a
        # fixed -1e9 is outside float16's range.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)  # ignore PAD tokens
        # Normalise over the token axis (the last axis of `scores`), so the
        # same code handles unbatched and batched input.
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # (..., L, 1)
        return (weights * token_embs).sum(dim=-2)

class NTEncoderCLS(nn.Module):
    """Sequence encoder that summarises a sequence by its CLS-position vector.

    The input string is cut into consecutive chunks, each chunk is passed
    through a pretrained masked-LM, the last-layer hidden state at position 0
    is taken per chunk, and the per-chunk vectors are averaged.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        """
        Args:
            model_name: Hugging Face model id (or path) used for both the
                tokenizer and the model.
            device: device the model and the tokenized inputs are moved to.
        """
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()  # eval mode; forward() additionally runs under no_grad

        # NOTE(review): this one value is used both as the chunk size in
        # characters and as the padded length in tokens -- confirm intended.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed one sequence.

        Args:
            seq: sequence string; it is upper-cased and every "U" is replaced
                by "T" before tokenization.

        Returns:
            1-D tensor: mean over chunks of the last hidden state at position 0.

        Raises:
            ValueError: if ``seq`` is empty (there would be no chunk to embed).
        """
        seq = seq.upper().replace("U", "T")
        if not seq:
            raise ValueError("Cannot encode an empty sequence.")

        chunk_vecs = []
        for start in range(0, len(seq), self.max_len):
            chunk = seq[start:start + self.max_len]

            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            # Last-layer hidden states, (1, L, D); no device copy is needed
            # because they come straight from the model.
            last_hidden = outputs.hidden_states[-1]

            # CLS TOKEN (position 0) of the single batch element -> (D,)
            chunk_vecs.append(last_hidden[0, 0, :])

        return torch.stack(chunk_vecs).mean(0)

class NTEncoderAttention(nn.Module):
    """Sequence encoder that summarises a sequence by attention pooling.

    The input string is cut into consecutive chunks, each chunk is passed
    through a pretrained masked-LM, the last-layer hidden states are pooled
    with an ``AttentionPool`` (padding excluded via the attention mask), and
    the per-chunk vectors are averaged.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        """
        Args:
            model_name: Hugging Face model id (or path) used for both the
                tokenizer and the model.
            device: device the model, pooling layer and inputs are moved to.
        """
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()  # eval mode; forward() additionally runs under no_grad

        # Size the pooling layer from the loaded model instead of a fixed
        # 1280, so a different `model_name` does not cause a shape mismatch.
        # NOTE(review): the pool is freshly initialised here and forward()
        # runs under no_grad; unless its weights are loaded elsewhere, the
        # pooled embeddings depend on this random initialisation.
        self.pool = AttentionPool(self.model.config.hidden_size).to(device)
        # NOTE(review): this one value is used both as the chunk size in
        # characters and as the padded length in tokens -- confirm intended.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed one sequence.

        Args:
            seq: sequence string; it is upper-cased and every "U" is replaced
                by "T" before tokenization.

        Returns:
            1-D tensor: mean over chunks of the attention-pooled hidden states.

        Raises:
            ValueError: if ``seq`` is empty (there would be no chunk to embed).
        """
        seq = seq.upper().replace("U", "T")
        if not seq:
            raise ValueError("Cannot encode an empty sequence.")

        chunk_vecs = []
        for start in range(0, len(seq), self.max_len):
            chunk = seq[start:start + self.max_len]

            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            # Drop the batch axis of size 1 from both tensors.
            hidden = outputs.hidden_states[-1].squeeze(0)  # (L, D)
            mask = tokens["attention_mask"].squeeze(0)     # (L)

            chunk_vecs.append(self.pool(hidden, mask))

        return torch.stack(chunk_vecs).mean(0)

In [None]:
def ensure_dir(path):
    """Create directory ``path`` (including parents) unless it already exists.

    An empty string is ignored, as is any path that already exists on disk.
    """
    if path == "":
        return
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def _cache_embeddings(encoder, seq_dict, save_path, start_msg, done_label):
    """Encode every sequence in ``seq_dict`` and pickle the name -> embedding map.

    Args:
        encoder: callable taking one sequence and returning its embedding.
        seq_dict: mapping of name -> sequence.
        save_path: destination pickle file; its parent directory is created
            if needed.
        start_msg: message printed before encoding starts.
        done_label: label inserted into the final "Saved ... cache" message.
    """
    ensure_dir(os.path.dirname(save_path))
    cache = {}
    print(start_msg)
    for name, seq in tqdm(seq_dict.items()):
        emb = encoder(seq)
        # Results exposing .cpu() (e.g. torch tensors) are stored via that
        # call; anything else is stored as returned.
        if hasattr(emb, "cpu"):
            emb = emb.cpu()
        cache[name] = emb

    with open(save_path, "wb") as f:
        pickle.dump(cache, f)

    print(f"Saved {done_label} embedding cache to: {os.path.abspath(save_path)}")


def cache_tf_embeddings(encoder, tf_seq_dict, save_path="../embeds/tf_cls.pkl"):
    """Encode all TF sequences and pickle ``{tf_name: embedding}`` to ``save_path``.

    Args:
        encoder: callable taking one sequence and returning its embedding.
        tf_seq_dict: mapping of TF name -> sequence.
        save_path: destination pickle file.
    """
    _cache_embeddings(encoder, tf_seq_dict, save_path, "Caching TF embeddings...", "TF")


def cache_gene_embeddings(encoder, gene_seq_dict, save_path="../embeds/gn_cls.pkl"):
    """Encode all gene sequences and pickle ``{gene_name: embedding}`` to ``save_path``.

    Args:
        encoder: callable taking one sequence and returning its embedding.
        gene_seq_dict: mapping of gene name -> sequence.
        save_path: destination pickle file.
    """
    _cache_embeddings(encoder, gene_seq_dict, save_path, "Caching gene embeddings...", "Gene")

In [None]:
# Load sequence dictionaries (same format as your dataloader expects)
# `with` closes each file as soon as it has been read.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("data/tf_sequences.pkl", "rb") as f:
    tf_seq_dict = pickle.load(f)
with open("data/gene_sequences_4000bp.pkl", "rb") as f:
    gene_seq_dict = pickle.load(f)

# Initialize encoder (choose CLS or Attention)
encoder_cls = NTEncoderCLS(device=device)
# encoder_att = NTEncoderAttention(device=device)

# Cache embeddings
# NOTE(review): inputs are read from "data/" under the working directory but
# outputs go to "../embeds/", one level above it -- confirm this is intended.
cache_tf_embeddings(encoder_cls, tf_seq_dict, save_path="../embeds/tf_cls.pkl")
cache_gene_embeddings(encoder_cls, gene_seq_dict, save_path="../embeds/gn_cls.pkl")