<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/nt_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import pandas as pd
import numpy as np

In [20]:
# ONLY FOR COLAB
# Clone the project repository, switch the working directory into it, then
# install the extra packages used below.
# NOTE(review): the working directory recorded in this cell's output is nested
# one level deeper (.../perturbseq-10701/perturbseq-10701), which suggests the
# cell was run from inside an existing checkout -- it is not idempotent on re-run.
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
# NOTE(review): no versions are pinned here, so a later run may install different releases.
!pip install fastparquet tqdm


Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 79, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (61/61), done.[K
remote: Total 79 (delta 26), reused 35 (delta 7), pack-reused 9 (from 1)[K
Receiving objects: 100% (79/79), 91.05 MiB | 14.79 MiB/s, done.
Resolving deltas: 100% (27/27), done.
/content/perturbseq-10701/perturbseq-10701


In [21]:
# Install transformers straight from the repository's default branch (no version pin).
# The recorded run resolved to commit d08b98b965176ea9cf8c8e8b24995c955b7e2ec9.
!pip install --upgrade git+https://github.com/huggingface/transformers.git

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-rcrxdg76
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-rcrxdg76
  Resolved https://github.com/huggingface/transformers.git to commit d08b98b965176ea9cf8c8e8b24995c955b7e2ec9
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [22]:
from reference_data_alternate import (
    PairPerturbSeqDataset,
    perturbseq_collate,
    get_dataloader
)

# Small batches for the NT (sequence-level) pipeline.
train_loader = get_dataloader(type="train", batch_size=4)
test_loader = get_dataloader(type="test", batch_size=4)

# NOTE(review): assuming get_dataloader returns a torch DataLoader, len() is the
# number of batches, not samples (the cached loaders further down report 14794
# batches of 64, which matches ~236692 batches of 4), so label it as such.
print("Train batches:", len(train_loader))
print("Test batches: ", len(test_loader))


Train size: 236692
Test size:  59174


In [23]:
import numpy as np
from torch.utils.data import DataLoader

# Mean and standard deviation of the training targets; both are passed to the
# weighted loss and to the evaluation routine to define z-score thresholds.
train_dataset = PairPerturbSeqDataset(type="train")
loader = DataLoader(train_dataset, batch_size=512, collate_fn=perturbseq_collate)

# Keep one array per batch and join them once, rather than growing a Python
# list one scalar at a time; the resulting values are the same.
target_batches = [y.numpy() for _, y in loader]
all_y = np.concatenate(target_batches) if target_batches else np.array([])

mu = all_y.mean()
sigma = all_y.std()

print("mu =", mu)
print("sigma =", sigma)


mu = -0.022736955
sigma = 0.15207928


In [24]:
def weighted_mse_loss(pred, target, mu, sigma, alpha=3.0, threshold=1.0):
    """Mean squared error in which large-effect targets count `alpha` times.

    A target is treated as large-effect when its z-score,
    ``(target - mu) / sigma``, has an absolute value strictly greater than
    `threshold`; every other target has weight 1. The result is the weighted
    sum of squared errors divided by the sum of the weights.
    """
    where = target.device
    heavy = torch.tensor(alpha, device=where)
    light = torch.tensor(1.0, device=where)

    z_scores = (target - mu) / sigma
    is_large = z_scores.abs() > threshold
    per_item_weight = torch.where(is_large, heavy, light)

    squared_error = (pred - target).pow(2)
    return torch.sum(per_item_weight * squared_error) / torch.sum(per_item_weight)


In [39]:
# Delete the Hugging Face cache directories under /root/.cache, then reinstall
# transformers from the repository's default branch (the same command as the
# earlier install cell).
!rm -rf /root/.cache/huggingface/transformers
!rm -rf /root/.cache/huggingface/hub

!pip install --upgrade git+https://github.com/huggingface/transformers.git


Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-o7u0femr
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-o7u0femr
  Resolved https://github.com/huggingface/transformers.git to commit d08b98b965176ea9cf8c8e8b24995c955b7e2ec9
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [40]:
from transformers import AutoTokenizer
# Debug probe: show which tokenizer class AutoTokenizer resolves to for this
# checkpoint and whether it exposes `batch_encode_plus` (the recorded run
# printed EsmTokenizer and False).
tok = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-500m-human-ref")
print(type(tok))
print(tok.__class__.__name__)
print("batch_encode_plus" in dir(tok))


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

<class 'transformers.models.esm.tokenization_esm.EsmTokenizer'>
EsmTokenizer
False


In [45]:
# NT ENCODER
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

class NTEncoder(nn.Module):
    """Inference-only wrapper that turns one nucleotide sequence into one vector.

    The sequence is upper-cased, "U" is replaced by "T", and the string is cut
    into pieces of at most `max_len` characters. Each piece is embedded as the
    attention-masked mean of the model's last hidden layer; the per-piece
    vectors are then averaged with equal weight.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        # NOTE(review): .train() on this module or on a parent module would
        # switch the backbone back to training mode -- confirm callers keep it
        # in eval mode.
        self.model.eval()

        # NOTE(review): this is the tokenizer's limit (presumably in tokens),
        # but forward() uses it as a chunk length in characters -- confirm
        # that this is intended.
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        """Embed `seq` and return a 1-D tensor on `self.device`.

        Raises:
            ValueError: if `seq` is empty. An empty string yields no chunks,
                which previously surfaced as an obscure `torch.stack` failure.
        """
        seq = seq.upper().replace("U", "T")
        if not seq:
            raise ValueError("NTEncoder.forward received an empty sequence")

        # Fixed-width character chunks; the last one may be shorter.
        chunks = [seq[i:i + self.max_len] for i in range(0, len(seq), self.max_len)]
        chunk_embs = []

        for chunk in chunks:
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            input_ids = tokens["input_ids"]
            attention_mask = tokens["attention_mask"]

            out = self.model(
                input_ids,
                attention_mask=attention_mask,
                encoder_attention_mask=attention_mask,
                output_hidden_states=True
            )

            # Last layer of the hidden-state stack, batch dimension removed.
            hidden = out.hidden_states[-1].squeeze(0)       # (L, D)
            attn = attention_mask.squeeze(0).unsqueeze(-1)  # (L, 1)

            # Mean over positions whose attention mask is 1 (padding excluded).
            embed = (hidden * attn).sum(0) / attn.sum()
            chunk_embs.append(embed)

        # Equal-weight average across chunks, regardless of chunk length.
        return torch.stack(chunk_embs).mean(0)


In [49]:
# Biencoder model
class NTBiEncoder(nn.Module):
    """Scores (TF sequence, gene sequence) pairs with one encoder and an MLP head.

    Both sequences of a pair go through the same `encoder`; the two embeddings
    are concatenated and mapped to a single scalar per pair.
    """

    def __init__(self, encoder, emb_dim=1280):
        super().__init__()
        self.encoder = encoder
        pair_dim = emb_dim * 2  # TF embedding followed by gene embedding
        self.mlp = nn.Sequential(
            nn.Linear(pair_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    @torch.no_grad()
    def encode(self, seq):
        """Embed a single sequence with gradient tracking disabled."""
        return self.encoder(seq)

    def forward(self, tf_seqs, gene_seqs):
        """Return one score per (tf, gene) pair; the inputs are zipped together."""
        # Encode pair by pair (TF first, then gene), as zip yields them.
        encoded_pairs = [
            (self.encoder(tf), self.encoder(gene))
            for tf, gene in zip(tf_seqs, gene_seqs)
        ]
        tf_batch = torch.stack([tf_vec for tf_vec, _ in encoded_pairs])
        gene_batch = torch.stack([gene_vec for _, gene_vec in encoded_pairs])

        joint = torch.cat((tf_batch, gene_batch), dim=-1)
        return self.mlp(joint).squeeze(-1)



In [54]:
# import pickle
# from tqdm import tqdm

# def cache_tf_embeddings(encoder, tf_seq_dict, save_path="tf_embed_cache.pkl"):
#     cache = {}
#     print("Caching TF embeddings...")
#     for tf_name, seq in tqdm(tf_seq_dict.items()):
#         cache[tf_name] = encoder(seq).cpu()
#     pickle.dump(cache, open(save_path, "wb"))
#     print("Saved TF embedding cache to", save_path)


# def cache_gene_embeddings(encoder, gene_seq_dict, save_path="gene_embed_cache.pkl"):
#     cache = {}
#     print("Caching Gene embeddings...")
#     for gene_name, seq in tqdm(gene_seq_dict.items()):
#         cache[gene_name] = encoder(seq).cpu()
#     pickle.dump(cache, open(save_path, "wb"))
#     print("Saved Gene embedding cache to", save_path)


import os
import pickle
from tqdm import tqdm

def ensure_dir(path):
    """Create directory `path` (with parents) unless it is "" or already exists."""
    if path != "" and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def _cache_embeddings(encoder, seq_dict, save_path):
    """Embed every sequence in `seq_dict` and pickle {name: embedding} to `save_path`.

    Shared implementation behind the TF and gene cache functions, which
    differ only in their log messages and default paths.
    """
    # The directory part of save_path must exist before the file is opened.
    ensure_dir(os.path.dirname(save_path))

    cache = {}
    for name, seq in tqdm(seq_dict.items()):
        emb = encoder(seq)
        # Anything exposing .cpu() (e.g. a tensor) is moved to the CPU before
        # pickling; other values are stored unchanged.
        if hasattr(emb, "cpu"):
            emb = emb.cpu()
        cache[name] = emb

    with open(save_path, "wb") as f:
        pickle.dump(cache, f)


def cache_tf_embeddings(encoder, tf_seq_dict, save_path="./embeds/tf_embed_cache.pkl"):
    """Embed each TF sequence with `encoder` and pickle the result to `save_path`.

    Args:
        encoder: callable mapping a sequence to its embedding.
        tf_seq_dict: mapping of TF name -> sequence.
        save_path: output pickle path; its directory is created if missing.
    """
    print("Caching TF embeddings...")
    _cache_embeddings(encoder, tf_seq_dict, save_path)
    print(f"Saved TF embedding cache to: {os.path.abspath(save_path)}")


def cache_gene_embeddings(encoder, gene_seq_dict, save_path="./embeds/gene_embed_cache.pkl"):
    """Embed each gene sequence with `encoder` and pickle the result to `save_path`.

    Args:
        encoder: callable mapping a sequence to its embedding.
        gene_seq_dict: mapping of gene name -> sequence.
        save_path: output pickle path; its directory is created if missing.
    """
    print("Caching gene embeddings...")
    _cache_embeddings(encoder, gene_seq_dict, save_path)
    print(f"Saved Gene embedding cache to: {os.path.abspath(save_path)}")


In [55]:
# Load your sequence dictionaries (the same ones the dataloader uses)
# NOTE: pickle can execute arbitrary code on load -- only open files you trust.
with open("tf_sequences.pkl", "rb") as f:
    tf_seq_dict = pickle.load(f)
with open("gene_sequences_4000bp.pkl", "rb") as f:
    gene_seq_dict = pickle.load(f)

# `device` is otherwise first assigned in the training cell further down, so
# define it here to keep this cell runnable on a fresh kernel.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize encoder
encoder = NTEncoder(device=device)

# Cache embeddings (the recorded run took ~7 min for TFs and ~1h47 for genes)
cache_tf_embeddings(encoder, tf_seq_dict)
cache_gene_embeddings(encoder, gene_seq_dict)


Loading weights:   0%|          | 0/396 [00:00<?, ?it/s]

EsmForMaskedLM LOAD REPORT from: InstaDeepAI/nucleotide-transformer-500m-human-ref
Key                         | Status     |  | 
----------------------------+------------+--+-
esm.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Caching TF embeddings...


100%|██████████| 223/223 [06:46<00:00,  1.83s/it]


Saved TF embedding cache to tf_embed_cache.pkl
Caching Gene embeddings...


100%|██████████| 5307/5307 [1:46:48<00:00,  1.21s/it]


Saved Gene embedding cache to gene_embed_cache.pkl


In [None]:
import shutil
from google.colab import files

# Zip the contents of the embedding directory and download the archive
# through the Colab file helper.
embed_dir = "./embeds"
zip_name = "nt_embedding_caches.zip"

# make_archive adds the ".zip" suffix to the base name itself.
shutil.make_archive(base_name="nt_embedding_caches", format="zip", root_dir=embed_dir)

files.download(zip_name)


In [56]:
class CachedEmbeddingDataset(torch.utils.data.Dataset):
    """Dataset of (tf_embedding, gene_embedding, expression) built from cached embeddings.

    Rows of the parquet table whose `tf_name` or `gene_name` has no entry in
    the corresponding cache are dropped. The remaining rows are shuffled with
    `seed`; the first `train_fraction` of them form the "train" split and the
    rest are returned for any other `type` value.
    """

    def __init__(self, parquet_path, tf_cache_path, gene_cache_path, type="train", train_fraction=0.8, seed=10701):
        df = pd.read_parquet(parquet_path)

        # Load caches with context managers so the file handles are closed.
        # NOTE: pickle can execute arbitrary code on load -- only point these
        # at cache files you trust.
        with open(tf_cache_path, "rb") as f:
            self.tf_cache = pickle.load(f)
        with open(gene_cache_path, "rb") as f:
            self.gene_cache = pickle.load(f)

        # remove entries missing from cache
        df = df[df["tf_name"].isin(self.tf_cache.keys())]
        df = df[df["gene_name"].isin(self.gene_cache.keys())]

        # shuffle + split
        df = df.sample(frac=1.0, random_state=seed)
        n = int(train_fraction * len(df))
        if type == "train":
            self.df = df.iloc[:n].reset_index(drop=True)
        else:
            self.df = df.iloc[n:].reset_index(drop=True)

        # Pull the three columns out once so __getitem__ does not have to
        # materialise a pandas row for every sample.
        self._tf_names = self.df["tf_name"].tolist()
        self._gene_names = self.df["gene_name"].tolist()
        self._targets = torch.tensor(self.df["expression"].to_numpy(), dtype=torch.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return (tf_embedding, gene_embedding, target) for row `idx`; target is a 0-d float32 tensor."""
        tf_emb = self.tf_cache[self._tf_names[idx]]
        gene_emb = self.gene_cache[self._gene_names[idx]]
        return tf_emb, gene_emb, self._targets[idx]


In [57]:
def get_cached_loader(type="train", batch_size=32, embed_dir="./embeds"):
    """Build a DataLoader over `CachedEmbeddingDataset` for the requested split.

    The cache writers in this notebook save to ``./embeds/`` by default, so
    each cache file is looked up in `embed_dir` first. If it is absent there
    but present in the working directory (where earlier runs wrote it), the
    working-directory copy is used.

    Args:
        type: "train" selects the training split (and enables shuffling);
            any other value selects the remaining rows.
        batch_size: batch size passed to the DataLoader.
        embed_dir: directory searched first for the embedding cache files.
    """
    import os

    def _locate(filename):
        in_embed_dir = os.path.join(embed_dir, filename)
        if os.path.exists(in_embed_dir) or not os.path.exists(filename):
            # Also the answer when neither copy exists, so the eventual
            # FileNotFoundError names the expected location.
            return in_embed_dir
        return filename

    ds = CachedEmbeddingDataset(
        parquet_path="tf_gene_expression.parquet",
        tf_cache_path=_locate("tf_embed_cache.pkl"),
        gene_cache_path=_locate("gene_embed_cache.pkl"),
        type=type
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=(type == "train"))


In [58]:
class NTBiEncoderFast(nn.Module):
    """MLP head that scores precomputed (TF, gene) embedding pairs.

    The two embeddings are concatenated along the last dimension and passed
    through a 3-layer MLP that outputs one scalar per pair.
    """

    def __init__(self, emb_dim=1280):
        super().__init__()
        pair_dim = 2 * emb_dim
        head_layers = [
            nn.Linear(pair_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, tf_embs, gene_embs):
        """Return the scores with the trailing size-1 dimension squeezed away."""
        joint = torch.cat((tf_embs, gene_embs), dim=-1)
        scores = self.mlp(joint)
        return scores.squeeze(-1)


In [59]:
# Updated Training Loop
from tqdm import tqdm

def train_one_epoch_cached_nt(model, loader, optimizer, mu, sigma, device="cuda"):
    """Run one optimisation pass over `loader` and return the mean per-sample loss.

    Each batch of (tf_emb, gene_emb, y) is moved to `device`, scored by
    `model`, and penalised with `weighted_mse_loss` using `mu` and `sigma`.
    The returned value is the batch losses averaged with batch-size weights.
    """
    model.train()
    progress = tqdm(loader)
    loss_sum = 0.0
    seen = 0

    for tf_emb, gene_emb, y in progress:
        tf_emb, gene_emb, y = tf_emb.to(device), gene_emb.to(device), y.to(device)

        loss = weighted_mse_loss(model(tf_emb, gene_emb), y, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running total and the bar.
        batch_loss = loss.item()
        batch_size = len(y)
        loss_sum += batch_loss * batch_size
        seen += batch_size
        progress.set_postfix({"loss": batch_loss})

    return loss_sum / seen




In [60]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_loader = get_cached_loader(type="train", batch_size=64)
test_loader = get_cached_loader(type="test", batch_size=64)

model_nt = NTBiEncoderFast(emb_dim=1280).to(device)
optimizer = torch.optim.Adam(model_nt.parameters(), lr=1e-4)

for epoch in range(5):
    # Pass the device explicitly: the helper defaults to "cuda", which would
    # fail on a CPU-only machine where the model was placed on "cpu" above.
    loss = train_one_epoch_cached_nt(model_nt, train_loader, optimizer, mu, sigma, device=device)
    print("Epoch", epoch, "Loss", loss)


100%|██████████| 14794/14794 [01:54<00:00, 129.64it/s, loss=0.0159]


Epoch 0 Loss 0.04461956790330804


100%|██████████| 14794/14794 [01:52<00:00, 130.96it/s, loss=0.0485]


Epoch 1 Loss 0.04406978431607792


100%|██████████| 14794/14794 [01:52<00:00, 131.34it/s, loss=0.0138]


Epoch 2 Loss 0.04387991112987205


100%|██████████| 14794/14794 [01:52<00:00, 131.66it/s, loss=0.085]


Epoch 3 Loss 0.04357769544650502


100%|██████████| 14794/14794 [01:53<00:00, 130.46it/s, loss=0.0337]

Epoch 4 Loss 0.04340518663419025





In [None]:
def evaluate_nt_cached(model, loader, mu, sigma, device="cuda"):
    """Evaluate `model` on `loader` and return (mse, corr, mse_big, corr_big).

    `mse` and `corr` are the mean squared error and the Pearson correlation
    between predictions and targets over the whole loader. `mse_big` and
    `corr_big` are the same metrics restricted to targets whose z-score
    ``(y - mu) / sigma`` exceeds 1.0 in absolute value; both are None when no
    target qualifies.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for tf_emb, gene_emb, y in loader:
            batch_pred = model(tf_emb.to(device), gene_emb.to(device))
            pred_chunks.append(batch_pred.cpu())
            target_chunks.append(y.to(device).cpu())

    preds_all = torch.cat(pred_chunks)
    y_all = torch.cat(target_chunks)

    def _mse_and_corr(p, t):
        # Mean squared error plus the off-diagonal entry of the 2x2 correlation matrix.
        mse_val = (p - t).pow(2).mean().item()
        corr_val = torch.corrcoef(torch.stack([p, t]))[0, 1].item()
        return mse_val, corr_val

    mse, corr = _mse_and_corr(preds_all, y_all)

    # Large-effect subset
    large = torch.abs((y_all - mu) / sigma) > 1.0
    if large.sum() > 0:
        mse_big, corr_big = _mse_and_corr(preds_all[large], y_all[large])
    else:
        mse_big, corr_big = None, None

    return mse, corr, mse_big, corr_big


In [None]:
mse, corr, mse_big, corr_big = evaluate_nt_cached(model_nt, test_loader, mu, sigma, device=device)

# evaluate_nt_cached returns None for the big-effect metrics when no test
# target clears the z-score threshold; formatting None with a float spec would
# raise TypeError, so fall back to a placeholder.
big_mse_text = "n/a" if mse_big is None else f"{mse_big:.6f}"
big_corr_text = "n/a" if corr_big is None else f"{corr_big:.4f}"

print("=== Evaluation Results ===")
print(f"Test MSE:          {mse:.6f}")
print(f"Test Corr:         {corr:.4f}")
print(f"Big-Effect MSE:    {big_mse_text}")
print(f"Big-Effect Corr:   {big_corr_text}")


In [None]:
# Persist the trained model's parameters (state_dict only, not the module
# object) to a file in the working directory.
torch.save(model_nt.state_dict(), "nt_mlp_model.pt")
print("Model saved as nt_mlp_model.pt")


In [None]:
from google.colab import files
# Colab only: hand the saved weights file to the Colab download helper.
files.download("nt_mlp_model.pt")
