<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/test_dataset_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np

In [2]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm


Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 65, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 65 (delta 18), reused 31 (delta 5), pack-reused 9 (from 1)[K
Receiving objects: 100% (65/65), 91.05 MiB | 36.58 MiB/s, done.
Resolving deltas: 100% (19/19), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [3]:
# Imports & Config

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm
import pickle
import pandas as pd

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using", device)


Using cuda


In [4]:
# Dataloader import

from reference_data_alternate import (
    PairPerturbSeqDataset,
    perturbseq_collate,
    get_dataloader
)


In [5]:
# Load loaders

train_loader = get_dataloader(type="train", batch_size=32)
test_loader  = get_dataloader(type="test", batch_size=32)


In [6]:
# Sample data

batch_x, batch_y = next(iter(train_loader))
print(batch_x[0])
print(batch_y[0])


{'tf_name': 'MYBL1', 'tf_seq': 'GCAGAACTGCTAGCTGCGGGGGAGAGGGCAGGGGTCGGGCGCCTGTGGCGGAGCCGGGCTGGGGCCAGGGCAGGGAGGCTGACAAGCGGCGGGAGAAGCCGGCGGAGGGCGGGATCGCGCCTCCTGACATGTTGGGGGTATCCCTGGCCGGGCCGGGCCGGGGCTAAGAGCGGCGCTGCGGGCCGGGGTCGGGGTCGGGTCGCGGTCCGCCCCCGCTGTCCCTCCGTCCTGCCCTGTCGAGGACGTGCGTTCCGCACTCGGCCGCCTCCAGAGGGAGCGAGGGAAGCGGCTAGAGGATCGGGGAGAAGGAGCATTCGCCGGAGGCTGGAGGAGGCTGACCCGCGTCCCCGCCCAGCCTGCTCCTATGCGGTACTTGAAGGATGGCGAAGAGGTCGCGCAGTGAGGATGAGGATGATGACCTTCAGTATGCCGATCATGATTATGAAGTACCACAACAAAAAGGACTGAAGAAACTCTGGAACAGAGTAAAATGGACAAGGGACGAGGATGATAAATTAAAGAAGTTGGTTGAACAACATGGAACTGATGATTGGACTCTAATTGCTAGTCATCTTCAAAATCGCTCTGATTTTCAGTGCCAGCATCGATGGCAGAAAGTTTTAAATCCTGAATTGATAAAGGGTCCTTGGACTAAAGAAGAAGATCAGAGGGTTATTGAATTAGTTCAGAAATATGGGCCAAAAAGATGGTCTTTAATTGCAAAACATTTAAAAGGAAGAATAGGCAAGCAGTGTAGAGAAAGATGGCATAATCATCTGAATCCTGAGGTAAAGAAATCTTCCTGGACAGAAGAGGAGGACAGGATCATCTATGAAGCACATAAGCGGTTGGGAAATCGTTGGGCAGAAATTGCCAAACTACTTCCAGGAAGGACTGATAATTCTATCAAAAATCATTGGAATTCTACTATGCGAAGAAAAGTGGAACAGGAGGGCTATTTACAAGAT

In [7]:
# Load TF & Gene Name Vocabularies

tf_seqs = pickle.load(open("tf_sequences.pkl", "rb"))
gene_seqs = pickle.load(open("gene_sequences_4000bp.pkl", "rb"))

tf_to_id = {name: i for i, name in enumerate(tf_seqs.keys())}
gene_to_id = {name: i for i, name in enumerate(gene_seqs.keys())}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)

print(num_tfs, "TFs,", num_genes, "genes")


223 TFs, 5307 genes


In [8]:
# Compute Dataset Statistic

train_dataset = PairPerturbSeqDataset(type="train")

all_y = []
loader = DataLoader(train_dataset, batch_size=512, collate_fn=perturbseq_collate)

for _, y in loader:
    all_y.extend(y.numpy())

all_y = np.array(all_y)
mu, sigma = all_y.mean(), all_y.std()

print("Mean:", mu, "Std:", sigma)


Mean: -0.022736955 Std: 0.15207928


In [15]:
# Weighted mse loss define
def weighted_mse_loss(pred, target, mu, sigma, alpha=3.0, threshold=1.0):
    z = (target - mu) / sigma
    weights = torch.where(torch.abs(z) > threshold,
                          torch.tensor(alpha, device=target.device),
                          torch.tensor(1.0, device=target.device))
    mse = (pred - target)**2
    return (weights * mse).sum() / weights.sum()


In [9]:
# DNA Tokenizer

dna_map = {"A":0, "C":1, "G":2, "T":3, "N":0}

def encode_seq(seq):
    return torch.tensor([dna_map.get(ch, 0) for ch in seq], dtype=torch.long)


In [10]:
# TF+Gene Identity Baseline

class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=32):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h).squeeze(-1)


In [11]:
# Sequence Bi-Encoder (Simple Prototype)

class SimpleSeqEncoder(nn.Module):
    def __init__(self, vocab=4, emb_dim=16, hidden=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb_dim)
        self.conv = nn.Conv1d(emb_dim, hidden, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, seq_batch):
        padded = nn.utils.rnn.pad_sequence(seq_batch, batch_first=True)
        x = self.emb(padded)                     # (B, L, emb)
        x = x.transpose(1, 2)                    # (B, emb, L)
        x = torch.relu(self.conv(x))            # (B, hidden, L)
        x = self.pool(x).squeeze(-1)            # (B, hidden)
        return x

class SeqBiEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SimpleSeqEncoder()
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, tf_seqs, gene_seqs):
        tf_h = self.encoder(tf_seqs)
        gene_h = self.encoder(gene_seqs)
        h = torch.cat([tf_h, gene_h], dim=-1)
        return self.mlp(h).squeeze(-1)



In [12]:
# Training Loop (Unified for ID & Seq Models)

def train_one_epoch(model, loader, optimizer, mu, sigma,
                    use_sequences=True, device="cuda"):
    model.train()
    total_loss, N = 0.0, 0

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # ----- Prepare ID inputs -----
        tf_ids = torch.tensor([tf_to_id[x["tf_name"]] for x in batch_x]).long().to(device)
        gene_ids = torch.tensor([gene_to_id[x["gene_name"]] for x in batch_x]).long().to(device)

        # ----- Prepare sequences -----
        if use_sequences:
            tf_seqs = [encode_seq(x["tf_seq"]).to(device) for x in batch_x]
            gene_seqs = [encode_seq(x["gene_seq"]).to(device) for x in batch_x]
        else:
            tf_seqs, gene_seqs = None, None

        # ----- Forward pass -----
        if use_sequences:
            preds = model(tf_seqs, gene_seqs)
        else:
            preds = model(tf_ids, gene_ids)

        loss = weighted_mse_loss(preds, batch_y, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(batch_y)
        N += len(batch_y)

    return total_loss / N


In [13]:
# Evaluate

def evaluate(model, loader, mu, sigma, use_sequences=True, device="cuda"):
    model.eval()
    preds_all, y_all = [], []

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_y = batch_y.to(device)

            tf_ids = torch.tensor([tf_to_id[x["tf_name"]] for x in batch_x]).long().to(device)
            gene_ids = torch.tensor([gene_to_id[x["gene_name"]] for x in batch_x]).long().to(device)

            if use_sequences:
                tf_seqs = [encode_seq(x["tf_seq"]).to(device) for x in batch_x]
                gene_seqs = [encode_seq(x["gene_seq"]).to(device) for x in batch_x]
                preds = model(tf_seqs, gene_seqs)
            else:
                preds = model(tf_ids, gene_ids)

            preds_all.append(preds.cpu())
            y_all.append(batch_y.cpu())

    preds_all = torch.cat(preds_all)
    y_all = torch.cat(y_all)

    mse = ((preds_all - y_all)**2).mean().item()
    corr = torch.corrcoef(torch.stack([preds_all, y_all]))[0,1].item()

    # Large-effect subset
    z = (y_all - mu) / sigma
    mask = torch.abs(z) > 1.0

    if mask.sum() > 0:
        mse_big = ((preds_all[mask] - y_all[mask])**2).mean().item()
        corr_big = torch.corrcoef(torch.stack([preds_all[mask], y_all[mask]]))[0,1].item()
    else:
        mse_big, corr_big = None, None

    return mse, corr, mse_big, corr_big


In [16]:
# Train ID baseline

model_id = TFGeneIDModel(num_tfs, num_genes).to(device)
optimizer = optim.Adam(model_id.parameters(), lr=1e-3)

for epoch in range(5):
    loss = train_one_epoch(model_id, train_loader, optimizer, mu, sigma,
                           use_sequences=False)
    print(f"[ID Baseline] Epoch {epoch} | Loss={loss:.4f}")

mse, corr, big_mse, big_corr = evaluate(model_id, test_loader, mu, sigma,
                                        use_sequences=False)
print("Test: MSE =", mse, "Corr =", corr)
print("Large-effect: MSE =", big_mse, "Corr =", big_corr)


[ID Baseline] Epoch 0 | Loss=0.0438
[ID Baseline] Epoch 1 | Loss=0.0405
[ID Baseline] Epoch 2 | Loss=0.0382
[ID Baseline] Epoch 3 | Loss=0.0377
[ID Baseline] Epoch 4 | Loss=0.0372
Test: MSE = 0.019195543602108955 Corr = 0.33441659808158875
Large-effect: MSE = 0.07673805952072144 Corr = 0.4397140443325043


In [17]:
# Train Seq-Bi-encoder

model_seq = SeqBiEncoder().to(device)
optimizer = optim.Adam(model_seq.parameters(), lr=1e-3)

for epoch in range(5):
    loss = train_one_epoch(model_seq, train_loader, optimizer, mu, sigma,
                           use_sequences=True)
    print(f"[Seq Bi-Encoder] Epoch {epoch} | Loss={loss:.4f}")

mse, corr, big_mse, big_corr = evaluate(model_seq, test_loader, mu, sigma,
                                        use_sequences=True)
print("Test: MSE =", mse, "Corr =", corr)
print("Large-effect: MSE =", big_mse, "Corr =", big_corr)


[Seq Bi-Encoder] Epoch 0 | Loss=0.0442
[Seq Bi-Encoder] Epoch 1 | Loss=0.0432
[Seq Bi-Encoder] Epoch 2 | Loss=0.0431
[Seq Bi-Encoder] Epoch 3 | Loss=0.0438
[Seq Bi-Encoder] Epoch 4 | Loss=0.0437
Test: MSE = 0.020954959094524384 Corr = 0.069491446018219
Large-effect: MSE = 0.093931645154953 Corr = 0.1222829595208168
