<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/data/gen_embed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab Only
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm


Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 160, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (31/31), done.[K
remote: Total 160 (delta 17), reused 23 (delta 7), pack-reused 121 (from 2)[K
Receiving objects: 100% (160/160), 252.11 MiB | 15.09 MiB/s, done.
Resolving deltas: 100% (63/63), done.
Updating files: 100% (46/46), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m72.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:
# Colab Only
!pip install --upgrade git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-p0he3ex0
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-p0he3ex0
  Resolved https://github.com/huggingface/transformers.git to commit a5c061d24e31706c18670c0ae6579b5c72d545f3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub<2.0,>=1.0.0 (from transformers==5.0.0.dev0)
  Downloading huggingface_hub-1.1.7-py3-none-any.whl.metadata (13 kB)
Downloading huggingface_hub-1.1.7-py3-none-any.whl (516 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m516.2/516.2 kB[0m [31m36.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... 

In [3]:
import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
class AttentionPool(nn.Module):
    """Learned attention pooling over the token axis.

    A single linear layer scores every token embedding. The scores are
    softmax-normalised across tokens (masked positions get effectively zero
    weight) and used to form a weighted sum of the embeddings.

    Args:
        dim: Size ``D`` of each token embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.att = nn.Linear(dim, 1)

    def forward(self, token_embs, mask):
        """Pool token embeddings into a single vector.

        Args:
            token_embs: Tensor of shape ``(L, D)``; leading batch axes are
                also accepted, i.e. ``(..., L, D)``.
            mask: Tensor of shape ``(L,)`` (or ``(..., L)``); entries equal to
                0 mark positions such as PAD tokens that must be ignored.

        Returns:
            Tensor of shape ``(D,)`` (or ``(..., D)`` for batched input).
        """
        scores = self.att(token_embs).squeeze(-1)               # (..., L)
        # Fill with the dtype's own minimum instead of a literal -1e9: that
        # literal is outside float16's range, so it breaks half-precision use.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)      # ignore PAD tokens
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)   # (..., L, 1)
        # Reduce over the token axis (second to last) so batch axes are kept.
        return (weights * token_embs).sum(dim=-2)

class NTEncoder(nn.Module):
    """Frozen pretrained encoder that mean-pools the last hidden state.

    The tokenizer and masked-LM model named by ``model_name`` are loaded with
    ``from_pretrained``; the model is moved to ``device`` and put in eval mode.

    Args:
        model_name: Hugging Face model identifier to load.
        device: Device the model and the tokenized inputs are placed on.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()

        # Used both as the tokenizer's max_length and as the chunk size below.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed one nucleotide sequence.

        The sequence is upper-cased, ``U`` is replaced by ``T``, and the result
        is cut into consecutive chunks of ``max_len`` characters. Each chunk is
        embedded as the mean of its last-layer hidden states over the positions
        where the attention mask is 1; the chunk embeddings are then averaged
        with equal weight.

        Args:
            seq: Non-empty nucleotide string.

        Returns:
            1-D tensor holding the sequence embedding.

        Raises:
            ValueError: If ``seq`` is empty (there would be no chunk to embed).
        """
        if not seq:
            raise ValueError("seq must be a non-empty nucleotide string")

        seq = seq.upper().replace("U", "T")

        # NOTE(review): chunks are cut by *characters* while max_len is the
        # tokenizer's length limit - confirm this chunk size is intended.
        chunks = [seq[i:i + self.max_len] for i in range(0, len(seq), self.max_len)]
        chunk_embs = []

        for chunk in chunks:
            # Each chunk is a batch of one, so no padding is requested: padded
            # positions would be masked out of the mean anyway and only add work.
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            input_ids = tokens["input_ids"]
            attention_mask = tokens["attention_mask"]

            # encoder_attention_mask is passed exactly as the original call did;
            # the sibling encoders in this file omit it.
            out = self.model(
                input_ids,
                attention_mask=attention_mask,
                encoder_attention_mask=attention_mask,
                output_hidden_states=True
            )

            hidden = out.hidden_states[-1].squeeze(0)          # (L, D)
            weights = attention_mask.squeeze(0).unsqueeze(-1)  # (L, 1)

            # Mean over the attended (mask == 1) positions only.
            chunk_embs.append((hidden * weights).sum(0) / weights.sum())

        return torch.stack(chunk_embs).mean(0)

class NTEncoderCLS(nn.Module):
    """Frozen pretrained encoder that uses the first-token (CLS) hidden state.

    The tokenizer and masked-LM model named by ``model_name`` are loaded with
    ``from_pretrained``; the model is moved to ``device`` and put in eval mode.

    Args:
        model_name: Hugging Face model identifier to load.
        device: Device the model and the tokenized inputs are placed on.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()

        # Used both as the tokenizer's max_length and as the chunk size below.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed one nucleotide sequence.

        The sequence is upper-cased, ``U`` is replaced by ``T``, and the result
        is cut into consecutive chunks of ``max_len`` characters. Each chunk is
        represented by the last-layer hidden state at position 0, and the chunk
        vectors are averaged with equal weight.

        Args:
            seq: Non-empty nucleotide string.

        Returns:
            1-D tensor holding the sequence embedding.

        Raises:
            ValueError: If ``seq`` is empty (there would be no chunk to embed).
        """
        if not seq:
            raise ValueError("seq must be a non-empty nucleotide string")

        seq = seq.upper().replace("U", "T")

        # NOTE(review): chunks are cut by *characters* while max_len is the
        # tokenizer's length limit - confirm this chunk size is intended.
        chunks = [seq[i:i + self.max_len] for i in range(0, len(seq), self.max_len)]
        embeds = []

        for chunk in chunks:
            # Each chunk is a batch of one and only position 0 is read, so no
            # padding is requested: padded positions would only add work.
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            # The inputs are already on self.device, so the outputs need no
            # further .to(...) calls.
            hidden = outputs.hidden_states[-1]  # (1, L, D)

            # CLS token (position 0) of the single sequence in the batch.
            embeds.append(hidden[0, 0])

        return torch.stack(embeds).mean(0)

class NTEncoderAttention(nn.Module):
    """Frozen pretrained encoder followed by attention pooling over tokens.

    The tokenizer and masked-LM model named by ``model_name`` are loaded with
    ``from_pretrained``; the model is moved to ``device`` and put in eval mode.

    NOTE(review): the ``AttentionPool`` layer is created here with fresh
    weights and ``forward`` runs under ``no_grad``, so unless trained weights
    are loaded into ``self.pool`` the output depends on the random
    initialisation - confirm this is intended and seed torch if it is.

    Args:
        model_name: Hugging Face model identifier to load.
        device: Device the model, pool and tokenized inputs are placed on.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()

        # Size the pool from the loaded model instead of hard-coding 1280, so a
        # different model_name does not produce a shape mismatch.
        self.pool = AttentionPool(self.model.config.hidden_size).to(device)
        # Used both as the tokenizer's max_length and as the chunk size below.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed one nucleotide sequence.

        The sequence is upper-cased, ``U`` is replaced by ``T``, and the result
        is cut into consecutive chunks of ``max_len`` characters. Each chunk's
        last-layer hidden states are reduced by ``self.pool`` using the
        attention mask, and the chunk vectors are averaged with equal weight.

        Args:
            seq: Non-empty nucleotide string.

        Returns:
            1-D tensor holding the sequence embedding.

        Raises:
            ValueError: If ``seq`` is empty (there would be no chunk to embed).
        """
        if not seq:
            raise ValueError("seq must be a non-empty nucleotide string")

        seq = seq.upper().replace("U", "T")

        # NOTE(review): chunks are cut by *characters* while max_len is the
        # tokenizer's length limit - confirm this chunk size is intended.
        chunks = [seq[i:i + self.max_len] for i in range(0, len(seq), self.max_len)]
        embeds = []

        for chunk in chunks:
            # Each chunk is a batch of one, so no padding is requested: padded
            # positions would be masked out by the pool and only add work.
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            hidden = outputs.hidden_states[-1].squeeze(0)  # (L, D)
            mask = tokens["attention_mask"].squeeze(0)     # (L)

            # Pool and inputs share self.device, so no extra .to(...) is needed.
            embeds.append(self.pool(hidden, mask))

        return torch.stack(embeds).mean(0)

In [5]:
def ensure_dir(path):
    """Create directory ``path`` (with parents) unless it is empty or already exists.

    An empty string is treated as "no directory component" and is a no-op, as
    is any path that already exists on disk.
    """
    if path == "":
        return
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def _cache_embeddings(encoder, seq_dict, save_path, caching_label, saved_label):
    """Encode every sequence in ``seq_dict`` and pickle the results to ``save_path``.

    Shared implementation behind ``cache_tf_embeddings`` and
    ``cache_gene_embeddings``; the two labels only affect the printed messages.
    """
    ensure_dir(os.path.dirname(save_path))
    cache = {}
    print(f"Caching {caching_label} embeddings...")
    for name, seq in tqdm(seq_dict.items()):
        emb = encoder(seq)
        # Move tensor-like results to the CPU before storing them; anything
        # without a .cpu() method is stored as returned.
        if hasattr(emb, "cpu"):
            emb = emb.cpu()
        cache[name] = emb

    # The whole cache is written once, after every sequence has been encoded.
    with open(save_path, "wb") as f:
        pickle.dump(cache, f)

    print(f"Saved {saved_label} embedding cache to: {os.path.abspath(save_path)}")


def cache_tf_embeddings(encoder, tf_seq_dict, save_path="../embeds/tf_cls.pkl"):
    """Encode all TF sequences and pickle a ``{tf_name: embedding}`` dict.

    Args:
        encoder: Callable mapping one sequence to its embedding.
        tf_seq_dict: Mapping from TF name to sequence.
        save_path: Destination pickle file; its parent directory is created
            if missing.
    """
    _cache_embeddings(encoder, tf_seq_dict, save_path, "TF", "TF")


def cache_gene_embeddings(encoder, gene_seq_dict, save_path="../embeds/gn_cls.pkl"):
    """Encode all gene sequences and pickle a ``{gene_name: embedding}`` dict.

    Args:
        encoder: Callable mapping one sequence to its embedding.
        gene_seq_dict: Mapping from gene name to sequence.
        save_path: Destination pickle file; its parent directory is created
            if missing.
    """
    _cache_embeddings(encoder, gene_seq_dict, save_path, "gene", "Gene")

In [None]:
# Load sequence dictionaries (same format as your dataloader expects).
# Files are opened with `with` so the handles are closed right after loading.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
with open("data/tf_sequences.pkl", "rb") as f:
    tf_seq_dict = pickle.load(f)
with open("data/gene_sequences_4000bp.pkl", "rb") as f:
    gene_seq_dict = pickle.load(f)

# Initialize encoder: CLS pooling here; use NTEncoderAttention(device=device)
# instead for attention pooling.
encoder_cls = NTEncoderCLS(device=device)

# Cache embeddings.
# NOTE(review): inputs are read from data/ under the working directory, while
# the caches go to ../embeds, one level above it (the saved-path line in this
# cell's output shows /content/embeds). Confirm that location is intended.
cache_tf_embeddings(encoder_cls, tf_seq_dict, save_path="../embeds/tf_cls.pkl")
cache_gene_embeddings(encoder_cls, gene_seq_dict, save_path="../embeds/gn_cls.pkl")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/396 [00:00<?, ?it/s]

EsmForMaskedLM LOAD REPORT from: InstaDeepAI/nucleotide-transformer-500m-human-ref
Key                         | Status     |  | 
----------------------------+------------+--+-
esm.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Caching TF embeddings...


100%|██████████| 223/223 [06:33<00:00,  1.77s/it]


Saved TF embedding cache to: /content/embeds/tf_cls_2.pkl
Caching gene embeddings...


  3%|▎         | 161/5307 [03:10<1:41:28,  1.18s/it]