<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/Combine_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm scikit-learn



Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 220, done.[K
remote: Counting objects: 100% (99/99), done.[K
remote: Compressing objects: 100% (80/80), done.[K
remote: Total 220 (delta 47), reused 53 (delta 17), pack-reused 121 (from 2)[K
Receiving objects: 100% (220/220), 260.63 MiB | 16.64 MiB/s, done.
Resolving deltas: 100% (93/93), done.
Updating files: 100% (57/57), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:

import copy
import os
import pickle

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from sklearn.metrics import f1_score, classification_report

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [3]:
from data.reference_data_classification import (
    PairPerturbSeqDataset,
    perturbseq_collate_2,
    get_dataloader,
)

# The three splits read the same labelled parquet file and sequence pickles and
# use the same majority_fraction, so those arguments are declared once.
# Only the split name and the batch size differ between the calls.
_loader_kwargs = dict(
    parquet_path="data/tf_gene_expression_labeled_v2.parquet",
    tf_sequences_path="data/tf_sequences.pkl",
    gene_sequences_path="data/gene_sequences_4000bp.pkl",
    majority_fraction=0.005,
)

train_loader = get_dataloader(batch_size=128, type="train", **_loader_kwargs)
val_loader = get_dataloader(batch_size=256, type="val", **_loader_kwargs)
test_loader = get_dataloader(batch_size=256, type="test", **_loader_kwargs)

# len(loader.dataset) counts individual examples, not batches, so the labels
# say "samples".
print("Train samples:", len(train_loader.dataset))
print("Val samples:", len(val_loader.dataset))
print("Test samples:", len(test_loader.dataset))



Train batches: 10845
Val batches: 2324
Test batches: 2325


In [4]:
# load NT sequence embeddings
def _load_embedding_cache(path):
    """Read a pickled ``name -> embedding`` mapping and return it with tensor values.

    Values that are not already ``torch.Tensor`` objects are converted to
    float32 tensors in place; values that are already tensors are left as is.

    NOTE(review): ``pickle.load`` can execute arbitrary code from the file, so
    only point this at embedding files you produced or otherwise trust.
    """
    # The context manager closes the handle even if unpickling raises.
    with open(path, "rb") as handle:
        cache = pickle.load(handle)

    # ensure everything is torch tensors (existing keys are reassigned, so the
    # mapping's size never changes while it is being iterated)
    for name, emb in cache.items():
        if not isinstance(emb, torch.Tensor):
            cache[name] = torch.tensor(emb, dtype=torch.float32)
    return cache


tf_embed_cache = _load_embedding_cache("./embeds/tf_cls.pkl")
gene_embed_cache = _load_embedding_cache("./embeds/gn_cls.pkl")

# Print the shape of one embedding from each cache as a quick sanity check.
first_tf = next(iter(tf_embed_cache.values()))
first_gene = next(iter(gene_embed_cache.values()))
print("TF emb dim:", first_tf.shape)
print("Gene emb dim:", first_gene.shape)


TF emb dim: torch.Size([1280])
Gene emb dim: torch.Size([1280])


In [5]:
# Grab the dataset object behind each loader.
train_ds, val_ds, test_ds = (
    loader.dataset for loader in (train_loader, val_loader, test_loader)
)

# Use all TF/gene names across ALL splits, so an index exists for every name
# that appears in the train, val or test frames.
split_frames = [train_ds.df, val_ds.df, test_ds.df]
combined_df = pd.concat(split_frames)
combined_df = combined_df.reset_index(drop=True)

# Sorting makes the name -> index assignment independent of row order.
tf_names = combined_df["tf_name"].unique().tolist()
tf_names = sorted(tf_names)
gene_names = combined_df["gene_name"].unique().tolist()
gene_names = sorted(gene_names)

tf_id_map = dict(zip(tf_names, range(len(tf_names))))
gene_id_map = dict(zip(gene_names, range(len(gene_names))))

print("Num TF IDs:", len(tf_id_map))
print("Num Gene IDs:", len(gene_id_map))

# check if each TF/gene has an embedding
missing_tf = list(filter(lambda name: name not in tf_embed_cache, tf_names))
missing_gene = list(filter(lambda name: name not in gene_embed_cache, gene_names))
print("Missing TF in cache:", len(missing_tf))
print("Missing gene in cache:", len(missing_gene))


Num TF IDs: 223
Num Gene IDs: 4539
Missing TF in cache: 0
Missing gene in cache: 0


In [6]:
def one_hot(index, num_classes):
    """Return a float32 vector of length ``num_classes`` with a 1 at ``index``.

    All other positions are 0. ``index`` is applied with ordinary tensor
    indexing, so an out-of-range value raises ``IndexError``.
    """
    encoding = torch.zeros(num_classes, dtype=torch.float)
    encoding[index] = 1
    return encoding


In [7]:
class TFgeneHybridMLP(nn.Module):
    """MLP classifier over a (TF, gene) pair.

    Each pair is described by the concatenation of five pieces: the cached TF
    embedding, the cached gene embedding, their element-wise product, a one-hot
    TF identity vector and a one-hot gene identity vector.

    Args:
        tf_embed_cache: mapping TF name -> 1-D embedding tensor.
        gene_embed_cache: mapping gene name -> 1-D embedding tensor. It must
            have the same length as the TF embeddings, because the two are
            multiplied element-wise.
        tf_id_map: mapping TF name -> position of its 1 in the TF one-hot.
        gene_id_map: mapping gene name -> position of its 1 in the gene one-hot.
        hidden_dim: width of the first two hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, tf_embed_cache, gene_embed_cache, tf_id_map, gene_id_map,
                 hidden_dim=1024, num_classes=3):
        super().__init__()

        self.tf_cache = tf_embed_cache
        self.gene_cache = gene_embed_cache
        self.tf_id_map = tf_id_map
        self.gene_id_map = gene_id_map

        self.num_tfs = len(tf_id_map)
        self.num_genes = len(gene_id_map)

        # Embedding length is read off an arbitrary TF entry.
        seq_dim = next(iter(tf_embed_cache.values())).shape[0]

        # FINAL INPUT SIZE:
        # TF_seq + Gene_seq + interaction + TF_onehot + Gene_onehot
        in_dim = seq_dim*3 + self.num_tfs + self.num_genes

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),

            nn.Linear(hidden_dim//2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)
        )

    def _build_features(self, batch_x):
        """Assemble the ``(batch, in_dim)`` input matrix on the model's device.

        ``batch_x`` is an iterable of dict-like items with ``"tf_name"`` and
        ``"gene_name"`` keys. Column order is
        ``[tf_seq | gene_seq | tf_seq * gene_seq | tf_onehot | gene_onehot]``.
        """
        # Follow the module's own parameters rather than a notebook-level
        # global, so the model still works after being moved to another device.
        target = next(self.parameters()).device

        # Single pass over batch_x, in case it is a one-shot iterable.
        pairs = [(item["tf_name"], item["gene_name"]) for item in batch_x]
        batch_size = len(pairs)

        tf_seq = torch.stack([self.tf_cache[tf] for tf, _ in pairs]).to(target)
        gene_seq = torch.stack([self.gene_cache[gene] for _, gene in pairs]).to(target)
        interaction = tf_seq * gene_seq

        tf_idx = torch.tensor(
            [self.tf_id_map[tf] for tf, _ in pairs], dtype=torch.long, device=target
        )
        gene_idx = torch.tensor(
            [self.gene_id_map[gene] for _, gene in pairs], dtype=torch.long, device=target
        )

        # One row per sample; set a single 1 per row with advanced indexing.
        rows = torch.arange(batch_size, device=target)
        tf_onehot = torch.zeros(batch_size, self.num_tfs, device=target)
        tf_onehot[rows, tf_idx] = 1.0
        gene_onehot = torch.zeros(batch_size, self.num_genes, device=target)
        gene_onehot[rows, gene_idx] = 1.0

        return torch.cat(
            [tf_seq, gene_seq, interaction, tf_onehot, gene_onehot], dim=1
        )

    def forward(self, batch_x):
        """Return ``(batch, num_classes)`` logits for the pairs in ``batch_x``."""
        return self.net(self._build_features(batch_x))


In [8]:
# Cross-entropy criterion applied to the model's logits; the training and
# evaluation functions in this cell both read this module-level object.
loss_fn = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and report running statistics.

    Args:
        model: module called as ``model(batch_x)`` to produce logits.
        loader: iterable yielding ``(batch_x, batch_y)`` pairs.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)`` where both are averaged over every sample
        seen during the pass. Accuracy uses the logits computed before the
        optimizer step for that batch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in loader:
        targets = targets.to(device)
        n_batch = len(targets)

        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so that the final division yields
        # a per-sample average even when the last batch is smaller.
        loss_sum += batch_loss.item() * n_batch
        n_correct += logits.argmax(dim=1).eq(targets).sum().item()
        n_seen += n_batch

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


def eval_model(model, loader):
    """Score ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module called as ``model(batch_x)`` to produce logits.
        loader: iterable yielding ``(batch_x, batch_y)`` pairs.

    Returns:
        ``(avg_loss, accuracy, macro_f1, labels, predictions)``. The first two
        are averaged over all samples; the last two are flat lists holding one
        entry per sample, in iteration order.
    """
    with torch.no_grad():
        model.eval()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        pred_log = []
        label_log = []

        for inputs, targets in loader:
            targets = targets.to(device)
            n_batch = len(targets)

            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            guesses = logits.argmax(dim=1)

            # Weight by batch size so the totals average per sample.
            loss_sum += batch_loss.item() * n_batch
            n_correct += guesses.eq(targets).sum().item()
            n_seen += n_batch

            # Move to host memory before handing the values to scikit-learn.
            pred_log.extend(guesses.cpu().numpy())
            label_log.extend(targets.cpu().numpy())

        mean_loss = loss_sum / n_seen
        accuracy = n_correct / n_seen
        macro_f1 = f1_score(label_log, pred_log, average="macro")

        return mean_loss, accuracy, macro_f1, label_log, pred_log


In [9]:
# Training hyper-parameters, named so they are easy to find and tune.
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4

model = TFgeneHybridMLP(
    tf_embed_cache=tf_embed_cache,
    gene_embed_cache=gene_embed_cache,
    tf_id_map=tf_id_map,
    gene_id_map=gene_id_map,
    hidden_dim=1024,
    num_classes=3,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Model selection is driven by validation macro-F1; -1 guarantees that the
# first epoch is always recorded as the initial best.
best_val_f1 = -1
best_state = None

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)

    # validation
    val_loss, val_acc, val_f1, _, _ = eval_model(model, val_loader)

    print(
        f"Epoch {epoch:02d} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}"
    )

    # save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # state_dict() hands back references to the live parameter tensors, so
        # without a copy later epochs would silently overwrite this snapshot.
        best_state = copy.deepcopy(model.state_dict())
        torch.save(best_state, "best_hybrid_model.pt")
        print(f"*** Saved best model at epoch {epoch} (Val F1={val_f1:.4f})")



Epoch 01 | Train Loss: 1.0992, Train Acc: 0.3526 | Val Loss: 1.0973, Val Acc: 0.3563, Val F1: 0.1751
*** Saved best model at epoch 1 (Val F1=0.1751)
Epoch 02 | Train Loss: 1.0976, Train Acc: 0.3546 | Val Loss: 1.0968, Val Acc: 0.3563, Val F1: 0.1751
Epoch 03 | Train Loss: 1.0967, Train Acc: 0.3553 | Val Loss: 1.0956, Val Acc: 0.3563, Val F1: 0.1751
Epoch 04 | Train Loss: 1.0964, Train Acc: 0.3511 | Val Loss: 1.0935, Val Acc: 0.3563, Val F1: 0.1751
Epoch 05 | Train Loss: 1.0829, Train Acc: 0.3848 | Val Loss: 1.1146, Val Acc: 0.3141, Val F1: 0.1594
Epoch 06 | Train Loss: 1.0263, Train Acc: 0.4444 | Val Loss: 0.9666, Val Acc: 0.4888, Val F1: 0.4483
*** Saved best model at epoch 6 (Val F1=0.4483)
Epoch 07 | Train Loss: 0.9025, Train Acc: 0.5317 | Val Loss: 1.1735, Val Acc: 0.4127, Val F1: 0.3336
Epoch 08 | Train Loss: 0.8469, Train Acc: 0.5599 | Val Loss: 0.7937, Val Acc: 0.5899, Val F1: 0.5918
*** Saved best model at epoch 8 (Val F1=0.5918)
Epoch 09 | Train Loss: 0.7714, Train Acc: 0.6048

In [10]:
# Load best model
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime can still be restored on a CPU-only one.
model.load_state_dict(torch.load("best_hybrid_model.pt", map_location=device))

# Score the restored weights once on the held-out test split.
test_loss, test_acc, test_f1, y_true, y_pred = eval_model(model, test_loader)

print("\n=== FINAL TEST RESULTS ===")
print("Test Loss:", test_loss)
print("Test Accuracy:", test_acc)
print("Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))



=== FINAL TEST RESULTS ===
Test Loss: 0.6436375150372905
Test Accuracy: 0.7255913978494624
Test Macro F1: 0.7257591843778003

Classification Report:
              precision    recall  f1-score   support

           0     0.7644    0.6227    0.6863       766
           1     0.7849    0.9083    0.8421       731
           2     0.6386    0.6594    0.6488       828

    accuracy                         0.7256      2325
   macro avg     0.7293    0.7302    0.7258      2325
weighted avg     0.7260    0.7256    0.7220      2325

