<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/nt_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#!pip install enformer-pytorch

In [3]:
# ONLY FOR COLAB
# !git clone https://github.com/navidh86/perturbseq-10701.git
# %cd ./perturbseq-10701
# !pip install fastparquet tqdm


In [4]:
# !pip install --upgrade git+https://github.com/huggingface/transformers.git

In [5]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
# from transformers import AutoTokenizer, AutoModelForMaskedLM
from enformer_pytorch import Enformer, seq_indices_to_one_hot

import os
import pickle
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
device

'cuda'

In [7]:
# Get dataloader
from reference_data_alternate import (
    PairPerturbSeqDataset,
    perturbseq_collate,
    get_dataloader
)

train_loader = get_dataloader(type="train", batch_size=4)   # small batch for NT
test_loader  = get_dataloader(type="test", batch_size=4)

print("Train size:", len(train_loader))
print("Test size: ", len(test_loader))


Train size: 236692
Test size:  59174


In [8]:
# Calc summary stats
import numpy as np
from torch.utils.data import DataLoader

train_dataset = PairPerturbSeqDataset(type="train")
loader = DataLoader(train_dataset, batch_size=512, collate_fn=perturbseq_collate)

all_y = []
for _, y in loader:
    all_y.extend(y.numpy())

all_y = np.array(all_y)
mu = all_y.mean()
sigma = all_y.std()

print("mu =", mu)
print("sigma =", sigma)


mu = -0.022736948
sigma = 0.15207927


In [9]:
# define loss fucntion
def weighted_mse_loss(pred, target, mu, sigma, alpha=3.0, threshold=1.0):
    z = (target - mu) / sigma
    weights = torch.where(
        torch.abs(z) > threshold,
        torch.tensor(alpha, device=target.device),
        torch.tensor(1.0, device=target.device)
    )
    mse = (pred - target)**2
    return (weights * mse).sum() / weights.sum()


##NEW

In [11]:
# import torch
# import torch.nn as nn
# from transformers import AutoTokenizer, AutoModelForMaskedLM

class NTEncoderCLS(nn.Module):
    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()

        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        seq = seq.upper().replace("U", "T")
        embeds = []

        chunks = [seq[i:i+self.max_len] for i in range(0, len(seq), self.max_len)]

        for chunk in chunks:
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            hidden = outputs.hidden_states[-1].to(self.device)  # (1, L, 1280)

            # ⭐ CLS TOKEN (position 0)
            cls_vec = hidden[:, 0, :].squeeze(0).to(self.device)


            embeds.append(cls_vec)

        return torch.stack(embeds).mean(0)

class AttentionPool(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.att = nn.Linear(dim, 1)

    def forward(self, token_embs, mask):
        # token_embs: (L, D)
        scores = self.att(token_embs).squeeze(-1)     # (L)
        scores = scores.masked_fill(mask == 0, -1e9)  # ignore PAD tokens
        weights = torch.softmax(scores, dim=0).unsqueeze(-1)
        return (weights * token_embs).sum(0)


In [10]:
# Replace the encoder classes with Enformer encoder
class EnformerEncoder(nn.Module):
    def __init__(self, device="cuda"):
        super().__init__()
        self.device = device
        
        # Load pretrained Enformer
        self.model = Enformer.from_pretrained('EleutherAI/enformer-official-rough').to(device)
        self.model.eval()
        
        # Enformer expects sequences of length 196608 (divisible by 128)
        self.target_length = 196608
        
    def _prepare_sequence(self, seq: str):
        """Convert DNA sequence to one-hot encoded tensor."""
        seq = seq.upper().replace('U', 'T')
        
        # Pad or truncate to target length
        if len(seq) < self.target_length:
            # Pad with N (will be ignored)
            seq = seq + 'N' * (self.target_length - len(seq))
        else:
            # Take center region
            start = (len(seq) - self.target_length) // 2
            seq = seq[start:start + self.target_length]
        
        # Convert to indices: A=0, C=1, G=2, T=3, N=4
        seq_map = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        indices = torch.tensor([seq_map.get(c, 4) for c in seq], dtype=torch.long)
        
        # One-hot encode (ignore N tokens)
        one_hot = seq_indices_to_one_hot(indices.unsqueeze(0))  # (1, target_length, 4)
        
        return one_hot.to(self.device)
    
    @torch.no_grad()
    def forward(self, seq: str):
        """Encode sequence using Enformer and return mean pooled embedding."""
        one_hot = self._prepare_sequence(seq)
        
        # Get Enformer output
        # Returns dict with 'human' and 'mouse' predictions
        # Each has shape (1, 896, 5313) - 896 bins, 5313 tracks
        output = self.model(one_hot)
        
        # Use human predictions and pool over bins
        human_pred = output['human']  # (1, 896, 5313)
        
        # Mean pool over spatial dimension (896 bins)
        embedding = human_pred.mean(dim=1).squeeze(0)  # (5313,)
        
        return embedding

In [12]:
class NTEncoderAttention(nn.Module):
    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-500m-human-ref", device="cuda"):
        super().__init__()
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
        self.model.eval()


        self.pool = AttentionPool(1280).to(device)
        self.max_len = self.tokenizer.model_max_length

    @torch.no_grad()
    def forward(self, seq: str):
        seq = seq.upper().replace("U", "T")
        embeds = []

        chunks = [seq[i:i+self.max_len] for i in range(0, len(seq), self.max_len)]

        for chunk in chunks:
            tokens = self.tokenizer(
                [chunk],
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_len,
                truncation=True
            ).to(self.device)

            outputs = self.model(
                tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                output_hidden_states=True
            )

            hidden = outputs.hidden_states[-1].squeeze(0)  # (L, 1280), on CUDA
            mask = tokens["attention_mask"].squeeze(0)     # (L), on CUDA

            # ⭐ Now pooling is also on CUDA
            att_vec = self.pool(hidden, mask).to(self.device)

            embeds.append(att_vec)

        return torch.stack(embeds).mean(0)


In [13]:
# encoder_mean = NTEncoderMean(device=device)
# encoder_cls  = NTEncoderCLS(device=device)
# encoder_att  = NTEncoderAttention(device=device)


In [11]:
#Function to Save embeddings
# import pickle
# from tqdm import tqdm

# def cache_tf_embeddings(encoder, tf_seq_dict, save_path="tf_embed_cache.pkl"):
#     cache = {}
#     print("Caching TF embeddings...")
#     for tf_name, seq in tqdm(tf_seq_dict.items()):
#         cache[tf_name] = encoder(seq).cpu()
#     pickle.dump(cache, open(save_path, "wb"))
#     print("Saved TF embedding cache to", save_path)


# def cache_gene_embeddings(encoder, gene_seq_dict, save_path="gene_embed_cache.pkl"):
#     cache = {}
#     print("Caching Gene embeddings...")
#     for gene_name, seq in tqdm(gene_seq_dict.items()):
#         cache[gene_name] = encoder(seq).cpu()
#     pickle.dump(cache, open(save_path, "wb"))
#     print("Saved Gene embedding cache to", save_path)

#Function to Save embeddings


def ensure_dir(path):
    if path != "" and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def cache_tf_embeddings(encoder, tf_seq_dict, save_path="./embeds/tf_embed_cache_nt_cls.pkl"):
    # FIX: ensure the directory of save_path exists
    ensure_dir(os.path.dirname(save_path))

    cache = {}
    print("Caching TF embeddings...")

    for tf_name, seq in tqdm(tf_seq_dict.items()):
        emb = encoder(seq)
        if hasattr(emb, "cpu"):
            emb = emb.cpu()
        cache[tf_name] = emb

    with open(save_path, "wb") as f:
        pickle.dump(cache, f)

    print(f"Saved TF embedding cache to: {os.path.abspath(save_path)}")


def cache_gene_embeddings(encoder, gene_seq_dict, save_path="./embeds/gene_embed_cache_nt_cls.pkl"):
    # FIX: ensure the directory exists
    ensure_dir(os.path.dirname(save_path))

    cache = {}
    print("Caching gene embeddings...")

    for gene_name, seq in tqdm(gene_seq_dict.items()):
        emb = encoder(seq)
        if hasattr(emb, "cpu"):
            emb = emb.cpu()
        cache[gene_name] = emb

    with open(save_path, "wb") as f:
        pickle.dump(cache, f)

    print(f"Saved Gene embedding cache to: {os.path.abspath(save_path)}")


In [12]:
# Generate and Save embeddings
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"

# Load your sequence dictionaries (the same ones the dataloader uses)
tf_seq_dict = pickle.load(open("tf_sequences.pkl", "rb"))
gene_seq_dict = pickle.load(open("gene_sequences_4000bp.pkl", "rb"))

# Initialize encoder
# encoder_mean = NTEncoderMean(device=device)
# encoder_cls  = NTEncoderCLS(device=device)
# encoder_att  = NTEncoderAttention(device=device)
encoder_enformer = EnformerEncoder(device=device)

# Cache embeddings (takes 3–20 minutes total depending on sizes)
# cache_tf_embeddings(encoder, tf_seq_dict)
# cache_gene_embeddings(encoder, gene_seq_dict)
# cache_tf_embeddings(encoder_att, tf_seq_dict, save_path="./embeds/tf_attn.pkl")
# cache_gene_embeddings(encoder_att, gene_seq_dict, save_path="./embeds/gn_attn.pkl")

# cache_tf_embeddings(encoder_cls, tf_seq_dict, save_path="./embeds/tf_cls.pkl")
cache_tf_embeddings(encoder_enformer, tf_seq_dict, save_path="./embeds/tf_enformer.pkl")
# cache_gene_embeddings(encoder_cls, gene_seq_dict, save_path="./embeds/gn_cls.pkl")
cache_gene_embeddings(encoder_enformer, gene_seq_dict, save_path="./embeds/gn_enformer.pkl")


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


Caching TF embeddings...


100%|██████████| 223/223 [02:19<00:00,  1.60it/s]


Saved TF embedding cache to: c:\Users\navid\Desktop\CMU\Academic\ML\Project\perturbseq\embeds\tf_enformer.pkl
Caching gene embeddings...


100%|██████████| 5307/5307 [54:21<00:00,  1.63it/s]


Saved Gene embedding cache to: c:\Users\navid\Desktop\CMU\Academic\ML\Project\perturbseq\embeds\gn_enformer.pkl


In [13]:
# #  Download embeddings
# import shutil
# from google.colab import files

# # Path to your embedding directory
# embed_dir = "./embeds"

# # Output zip file name
# zip_name = "nt_embedding_caches.zip"

# # Create zip
# shutil.make_archive("nt_embedding_caches", 'zip', embed_dir)

# # Download zip
# files.download(zip_name)


In [14]:
# New dataset/dataloader with embeddings
class CachedEmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, parquet_path, tf_cache_path, gene_cache_path, type="train", train_fraction=0.8, seed=10701):
        df = pd.read_parquet(parquet_path)

        # load caches
        self.tf_cache = pickle.load(open(tf_cache_path, "rb"))
        self.gene_cache = pickle.load(open(gene_cache_path, "rb"))

        # remove entries missing from cache
        df = df[df["tf_name"].isin(self.tf_cache.keys())]
        df = df[df["gene_name"].isin(self.gene_cache.keys())]

        # shuffle + split
        df = df.sample(frac=1.0, random_state=seed)
        n = int(train_fraction * len(df))
        if type == "train":
            self.df = df.iloc[:n].reset_index(drop=True)
        else:
            self.df = df.iloc[n:].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        tf_emb = self.tf_cache[row["tf_name"]]     # tensor
        gene_emb = self.gene_cache[row["gene_name"]]
        y = torch.tensor(row["expression"], dtype=torch.float32)
        return tf_emb, gene_emb, y


In [15]:
def get_cached_loader(type="train", batch_size=32):
    ds = CachedEmbeddingDataset(
        parquet_path="tf_gene_expression.parquet",
        # tf_cache_path="tf_embed_cache.pkl",
        # gene_cache_path="gene_embed_cache.pkl",
        # tf_cache_path="./embeds/tf_attn.pkl",
        # gene_cache_path="./embeds/gn_attn.pkl",
        # tf_cache_path="./embeds/tf_cls.pkl",
        # gene_cache_path="./embeds/gn_cls.pkl",
        tf_cache_path="./embeds/tf_enformer.pkl",
        gene_cache_path="./embeds/gn_enformer.pkl",
        type=type
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=(type=="train"))

# def get_cached_loader(type="train", batch_size=32):
#     ds = CachedEmbeddingDataset(
#         parquet_path="tf_gene_expression.parquet",
#         tf_cache_path="./embeds/tf_embed_cache.pkl",
#         gene_cache_path="./embeds/gene_embed_cache.pkl",
#         type=type
#     )
#     return DataLoader(ds, batch_size=batch_size, shuffle=(type=="train"))



In [19]:
# # MLP OLD
# class NTBiEncoderFast(nn.Module):
#     def __init__(self, emb_dim=1280):
#         super().__init__()
#         self.mlp = nn.Sequential(
#             nn.Linear(emb_dim * 2, 512),
#             nn.ReLU(),
#             nn.Linear(512, 128),
#             nn.ReLU(),
#             nn.Linear(128, 1)
#         )

#     def forward(self, tf_embs, gene_embs):
#         h = torch.cat([tf_embs, gene_embs], dim=-1)
#         return self.mlp(h).squeeze(-1)


In [16]:
# MLP new
class InteractionMLP(nn.Module):
    def __init__(self, emb_dim=1280):
        super().__init__()

        # TF emb (1280) + gene emb (1280) + interaction (1280)
        in_dim = emb_dim * 3

        self.net = nn.Sequential(
            nn.Linear(in_dim, 2048),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(1024, 512),
            nn.ReLU(),

            nn.Linear(512, 128),
            nn.ReLU(),

            nn.Linear(128, 1)
        )

    def forward(self, tf_emb, gene_emb):
        interaction = tf_emb * gene_emb
        h = torch.cat([tf_emb, gene_emb, interaction], dim=-1)
        return self.net(h).squeeze(-1)


In [None]:
# Updated Training Loop
from tqdm import tqdm

def train_one_epoch_cached_enformer(model, loader, optimizer, mu, sigma, device="cuda"):
    model.train()
    total_loss, N = 0.0, 0
    pbar = tqdm(loader)

    for tf_emb, gene_emb, y in pbar:
        tf_emb = tf_emb.to(device)
        gene_emb = gene_emb.to(device)
        y = y.to(device)

        preds = model(tf_emb, gene_emb)
        loss = weighted_mse_loss(preds, y, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(y)
        N += len(y)
        pbar.set_postfix({"loss": loss.item()})

    return total_loss / N




In [None]:
train_loader = get_cached_loader(type="train", batch_size=64)
test_loader = get_cached_loader(type="test", batch_size=64)

# model_nt = NTBiEncoderFast(emb_dim=1280).to(device)
# model_nt = InteractionMLP(emb_dim=1280).to(device)
model_enformer = InteractionMLP(emb_dim=5313).to(device) # for enformer
optimizer = torch.optim.Adam(model_enformer.parameters(), lr=1e-4)

for epoch in range(10):
    loss = train_one_epoch_cached_enformer(model_enformer, train_loader, optimizer, mu, sigma)
    print("Epoch", epoch, "Loss", loss)


100%|██████████| 14794/14794 [05:58<00:00, 41.21it/s, loss=0.0297] 


Epoch 0 Loss 0.04464342606947554


100%|██████████| 14794/14794 [06:14<00:00, 39.52it/s, loss=0.015]  


Epoch 1 Loss 0.04469939464102886


100%|██████████| 14794/14794 [06:18<00:00, 39.04it/s, loss=0.0504] 


Epoch 2 Loss 0.044399045377206965


100%|██████████| 14794/14794 [06:17<00:00, 39.16it/s, loss=0.1]    


Epoch 3 Loss 0.04436302752354335


100%|██████████| 14794/14794 [06:37<00:00, 37.22it/s, loss=0.0732] 


Epoch 4 Loss 0.044479404794626996


100%|██████████| 14794/14794 [06:26<00:00, 38.26it/s, loss=0.0227] 


Epoch 5 Loss 0.044042329071691644


100%|██████████| 14794/14794 [06:23<00:00, 38.57it/s, loss=0.0211] 


Epoch 6 Loss 0.04433271750442492


100%|██████████| 14794/14794 [06:50<00:00, 36.04it/s, loss=0.0272] 


Epoch 7 Loss 0.044423117650708964


100%|██████████| 14794/14794 [06:16<00:00, 39.27it/s, loss=0.00598]


Epoch 8 Loss 0.044464552987611754


100%|██████████| 14794/14794 [06:21<00:00, 38.78it/s, loss=0.0173] 

Epoch 9 Loss 0.04436450114990348





In [19]:
def evaluate_enformer_cached(model, loader, mu, sigma, device="cuda"):
    model.eval()
    preds_all, y_all = [], []

    with torch.no_grad():
        for tf_emb, gene_emb, y in loader:
            tf_emb = tf_emb.to(device)
            gene_emb = gene_emb.to(device)
            y = y.to(device)

            preds = model(tf_emb, gene_emb)

            preds_all.append(preds.cpu())
            y_all.append(y.cpu())

    preds_all = torch.cat(preds_all)
    y_all = torch.cat(y_all)

    mse = ((preds_all - y_all)**2).mean().item()
    corr = torch.corrcoef(torch.stack([preds_all, y_all]))[0, 1].item()

    # Large-effect subset
    z = (y_all - mu) / sigma
    mask = torch.abs(z) > 1.0

    if mask.sum() > 0:
        mse_big = ((preds_all[mask] - y_all[mask])**2).mean().item()
        corr_big = torch.corrcoef(torch.stack([preds_all[mask], y_all[mask]]))[0, 1].item()
    else:
        mse_big, corr_big = None, None

    return mse, corr, mse_big, corr_big


In [20]:
mse, corr, mse_big, corr_big = evaluate_enformer_cached(model_enformer, test_loader, mu, sigma, device=device)

print("=== Evaluation Results ===")
print(f"Test MSE:          {mse:.6f}")
print(f"Test Corr:         {corr:.4f}")
print(f"Big-Effect MSE:    {mse_big:.6f}")
print(f"Big-Effect Corr:   {corr_big:.4f}")


=== Evaluation Results ===
Test MSE:          0.020895
Test Corr:         0.0005
Big-Effect MSE:    0.094879
Big-Effect Corr:   0.0010


In [None]:
torch.save(model_enformer.state_dict(), "models/enformer_int_model.pt")
print("Model saved as enformer_int_model.pt")


Model saved as enformer_int_model.pt


In [None]:
# from google.colab import files
# files.download("nt_cls_int_model.pt")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>