In [None]:
import pandas as pd
import numpy as np

In [8]:
from reference_data_alternate import (
    PairPerturbSeqDataset,
    perturbseq_collate,
    get_dataloader
)

from torch.utils.data import DataLoader

train_loader = get_dataloader(type="train", batch_size=4)
batch_inputs, batch_labels = next(iter(train_loader))

print("Inputs example:", batch_inputs[0])
print("Label example:", batch_labels[0])


Inputs example: ['GCGGCCGCGCGTGGTGGGGGAGGAGGGACCGGCGGCGCCCACGTGGCCTCCGCGGGCCCCGCCAGAGCCTGCGCCCGGGCCCTGACCGCACCTCTCGCCCCGCAGGACCATGGCCAACCTGGAGCGCACCTTCATCGCCATCAAGCCGGACGGCGTGCAGCGCGGCCTGGTGGGCGAGATCATCAAGCGCTTCGAGCAGAAGGGATTCCGCCTCGTGGCCATGAAGTTCCTCCGGGCCTCTGAAGAACACCTGAAGCAGCACTACATTGACCTGAAAGACCGACCATTCTTCCCTGGGCTGGTGAAGTACATGAACTCAGGGCCGGTTGTGGCCATGGTCTGGGAGGGGCTGAACGTGGTGAAGACAGGCCGAGTGATGCTTGGGGAGACCAATCCAGCAGATTCAAAGCCAGGCACCATTCGTGGGGACTTCTGCATTCAGGTTGGCAGGAACATCATTCATGGCAGTGATTCAGTAAAAAGTGCTGAAAAAGAAATCAGCCTATGGTTTAAGCCTGAAGAACTGGTTGACTACAAGTCTTGTGCTCATGACTGGGTCTATGAATAAGAGGTGGACACAACAGCAGTCTCCTTCAGCACGGCGTGGTGTGTCCCTGGACACAGCTCTTCATTCCATTGACTTAGAGGCAACAGGATTGATCATTCTTTTATAGAGCATATTTGCCAATAAAGCTTTTGGAAGCCGGAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', 'TTAAGACTTCAGGCAGACAGATATGAACACATGAAGGGAGTAAAAACTCCAACCTCTGGACCACTTCCTGGAGTGCTCCTAACCCAGACTCCCACAGGGCTTCCTCCCATACTTCATTTAGGTCTTTGGCTCAAATGTTACCTCTTTAGAACTGCCTCAAGACTGGGGCAAGAGGAGCCCCAGTCTAGGGCTCTGAGCACCTGATTCTCTTCCGT

In [9]:
import numpy as np

train_dataset = PairPerturbSeqDataset(type="train")
all_y = []

for _, y in DataLoader(train_dataset, batch_size=512):
    all_y.extend(y.numpy())

all_y = np.array(all_y)
mu = all_y.mean()
sigma = all_y.std()

print("Training mean:", mu)
print("Training std:", sigma)


Training mean: -0.022736955
Training std: 0.15207928


In [10]:
import torch
import torch.nn as nn

def weighted_mse_loss(pred, target, mu, sigma, alpha=3.0, threshold=1.0):
    mse = (pred - target)**2
    z = (target - mu) / sigma
    weights = torch.where(torch.abs(z) > threshold,
                          torch.tensor(alpha, device=target.device),
                          torch.tensor(1.0, device=target.device))
    return (weights * mse).sum() / weights.sum()


In [11]:
import pickle

tf_seqs = pickle.load(open("tf_sequences.pkl", "rb"))
gene_seqs = pickle.load(open("gene_sequences_4000bp.pkl", "rb"))

tf_to_id = {tf: i for i, tf in enumerate(tf_seqs.keys())}
gene_to_id = {gene: i for i, gene in enumerate(gene_seqs.keys())}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)


In [12]:
class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=32):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h).squeeze(-1)


In [13]:
dna_map = {"A":0, "C":1, "G":2, "T":3, "N":0}

def encode_seq(seq):
    return torch.tensor([dna_map.get(ch, 0) for ch in seq], dtype=torch.long)


In [14]:
class SimpleSeqEncoder(nn.Module):
    def __init__(self, vocab=4, emb_dim=16, hidden=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb_dim)
        self.conv = nn.Conv1d(emb_dim, hidden, kernel_size=7, padding=3)
        self.pool = nn.AdaptiveMaxPool1d(1)

    def forward(self, seq_batch):
        # seq_batch: list of 1D LongTensors (variable length)
        padded = nn.utils.rnn.pad_sequence(seq_batch, batch_first=True)
        x = self.emb(padded)                     # (B, L, emb)
        x = x.permute(0,2,1)                     # (B, emb, L)
        x = torch.relu(self.conv(x))             # (B, hidden, L)
        x = self.pool(x).squeeze(-1)             # (B, hidden)
        return x

class SeqBiEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SimpleSeqEncoder()
        self.mlp = nn.Sequential(
            nn.Linear(128, 64),   # because encoder outputs 64-d each
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, tf_seqs, gene_seqs):
        tf_h = self.encoder(tf_seqs)
        gene_h = self.encoder(gene_seqs)
        h = torch.cat([tf_h, gene_h], dim=-1)
        return self.mlp(h).squeeze(-1)


In [None]:
# Old prepare_id_batch function

def prepare_id_batch(batch_inputs):
    tf_ids = []
    gene_ids = []
    for tf_seq, gene_seq in batch_inputs:
        # keys are exactly those in your parquet
        tf_name = tf_seq   # since the dataset returns raw string seq, get name via reverse lookup
        gene_name = gene_seq
        tf_ids.append(tf_to_id[tf_name])
        gene_ids.append(gene_to_id[gene_name])
    return torch.tensor(tf_ids), torch.tensor(gene_ids)


In [None]:
# Old training loop for sequence model

import torch.optim as optim

# model = SeqBiEncoder().cuda()
model = SeqBiEncoder()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    model.train()
    for batch_inputs, batch_labels in train_loader:
        
        # encode sequences
        tf_seqs = [encode_seq(tf) for tf, gene in batch_inputs]
        gene_seqs = [encode_seq(gene) for tf, gene in batch_inputs]

        # tf_seqs = [t.cuda() for t in tf_seqs]
        # gene_seqs = [g.cuda() for g in gene_seqs]
        # y = batch_labels.cuda()
        y= batch_labels

        preds = model(tf_seqs, gene_seqs)

        loss = weighted_mse_loss(preds, y, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch", epoch, "Loss:", loss.item())


In [None]:
# New loop with option for ID only model

import torch
import torch.nn as nn
import torch.optim as optim

def train_one_epoch(model, loader, mu, sigma, optimizer, use_sequences=True, device="cuda"):
    model.train()
    total_loss = 0.0
    count = 0

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # ------------------------
        # 1. Prepare TF and gene IDs
        # ------------------------
        tf_ids = torch.tensor([tf_to_id[x["tf_name"]] for x in batch_x], dtype=torch.long).to(device)
        gene_ids = torch.tensor([gene_to_id[x["gene_name"]] for x in batch_x], dtype=torch.long).to(device)

        # ------------------------
        # 2. Prepare sequences (optional)
        # ------------------------
        if use_sequences:
            tf_seqs = [encode_seq(x["tf_seq"]).to(device) for x in batch_x]
            gene_seqs = [encode_seq(x["gene_seq"]).to(device) for x in batch_x]
        else:
            tf_seqs, gene_seqs = None, None

        # ------------------------
        # 3. Forward pass
        # ------------------------
        if use_sequences:
            preds = model(tf_seqs, gene_seqs)          # sequence model
        else:
            preds = model(tf_ids, gene_ids)            # ID baseline model

        # ------------------------
        # 4. Weighted MSE Loss
        # ------------------------
        z = (batch_y - mu) / sigma
        weights = torch.where(torch.abs(z) > 1.0, torch.tensor(3.0, device=device), torch.tensor(1.0, device=device))
        loss = ((preds - batch_y) ** 2 * weights).sum() / weights.sum()

        # ------------------------
        # 5. Backprop
        # ------------------------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(batch_y)
        count += len(batch_y)

    return total_loss / count


In [None]:
# ID only model training, baseline

model = TFGeneIDModel(num_tfs, num_genes).to("cuda")
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    loss = train_one_epoch(model, train_loader, mu, sigma, optimizer,
                           use_sequences=False, device="cuda")
    print(f"Epoch {epoch} | Loss = {loss:.4f}")


In [None]:
# Sequence-based model training

model = SeqBiEncoder().to("cuda")
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    loss = train_one_epoch(model, train_loader, mu, sigma, optimizer,
                           use_sequences=True, device="cuda")
    print(f"Epoch {epoch} | Loss = {loss:.4f}")
