<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/Combine_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm scikit-learn



Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 198, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 198 (delta 37), reused 45 (delta 16), pack-reused 121 (from 2)[K
Receiving objects: 100% (198/198), 260.56 MiB | 23.32 MiB/s, done.
Resolving deltas: 100% (83/83), done.
Updating files: 100% (52/52), done.
/content/perturbseq-10701/perturbseq-10701


In [22]:

import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from sklearn.metrics import f1_score, classification_report

import random

# Reproducibility: seed every RNG the pipeline may touch (Python, NumPy, torch).
# This covers DataLoader shuffling, weight init, dropout, and any subsampling
# get_dataloader may do with these RNGs (TODO confirm which RNG it uses).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [23]:
from data.reference_data_classification import (
    PairPerturbSeqDataset,
    perturbseq_collate_2,
    get_dataloader,
)

# Configuration shared by both splits, kept in one place so train and test
# cannot silently drift apart. Only the `type` (split name) differs.
LOADER_KWARGS = dict(
    parquet_path="data/tf_gene_expression_labeled_v2.parquet",
    tf_sequences_path="data/tf_sequences.pkl",
    gene_sequences_path="data/gene_sequences_4000bp.pkl",
    batch_size=64,
    majority_fraction=0.005,
)

train_loader = get_dataloader(type="train", **LOADER_KWARGS)
test_loader = get_dataloader(type="test", **LOADER_KWARGS)

print("Train batches:", len(train_loader))
print("Test batches: ", len(test_loader))


Train batches: 170
Test batches:  37


In [24]:
# load NT sequence embeddings
tf_embed_cache = pickle.load(open("./embeds/tf_cls.pkl", "rb"))
gene_embed_cache = pickle.load(open("./embeds/gn_cls.pkl", "rb"))

# ensure everything is torch tensors
for k in tf_embed_cache:
    if not isinstance(tf_embed_cache[k], torch.Tensor):
        tf_embed_cache[k] = torch.tensor(tf_embed_cache[k], dtype=torch.float32)

for k in gene_embed_cache:
    if not isinstance(gene_embed_cache[k], torch.Tensor):
        gene_embed_cache[k] = torch.tensor(gene_embed_cache[k], dtype=torch.float32)

first_tf = next(iter(tf_embed_cache.values()))
first_gene = next(iter(gene_embed_cache.values()))
print("TF emb dim:", first_tf.shape)
print("Gene emb dim:", first_gene.shape)


TF emb dim: torch.Size([1280])
Gene emb dim: torch.Size([1280])


In [25]:
train_ds = train_loader.dataset
test_ds = test_loader.dataset

# Vocabulary = union of names across both splits, so every pair met at
# evaluation time has a one-hot slot. Sorting makes the id assignment deterministic.
tf_names = sorted(set(train_ds.df["tf_name"]) | set(test_ds.df["tf_name"]))
gene_names = sorted(set(train_ds.df["gene_name"]) | set(test_ds.df["gene_name"]))

tf_id_map = {name: idx for idx, name in enumerate(tf_names)}
gene_id_map = {name: idx for idx, name in enumerate(gene_names)}

print("Num TF IDs:", len(tf_id_map))
print("Num gene IDs:", len(gene_id_map))

# Sanity check: every name in the maps must have a sequence embedding.
# Otherwise the model's forward fails with a KeyError mid-training.
missing_tf = [n for n in tf_names if n not in tf_embed_cache]
missing_gene = [n for n in gene_names if n not in gene_embed_cache]
print("Missing TF in cache:", len(missing_tf))
print("Missing gene in cache:", len(missing_gene))
assert not missing_tf, f"TFs without embeddings (first 5): {missing_tf[:5]}"
assert not missing_gene, f"genes without embeddings (first 5): {missing_gene[:5]}"


Num TF IDs: 223
Num gene IDs: 4305
Missing TF in cache: 0
Missing gene in cache: 0


In [26]:
def one_hot(index, num_classes):
    """Return a float32 vector of length ``num_classes`` with 1.0 at ``index``.

    ``index`` follows normal tensor indexing, so negative values count from the end.
    """
    encoding = torch.zeros((num_classes,), dtype=torch.float32)
    encoding[index] = 1  # stored as 1.0 because the tensor is float32
    return encoding


In [27]:
class TFgeneHybridMLP(nn.Module):
    """MLP classifier for (TF, gene) pairs.

    Each pair is encoded as the concatenation of
    ``[TF embedding, gene embedding, TF * gene (element-wise),
    TF one-hot (len(tf_id_map)), gene one-hot (len(gene_id_map))]``
    and mapped to ``num_classes`` logits.

    Args:
        tf_embed_cache: dict name -> 1-D embedding tensor for TFs.
        gene_embed_cache: dict name -> 1-D embedding tensor for genes. Assumed
            to have the same length as the TF embeddings (the element-wise
            product and ``in_dim`` both rely on it).
        tf_id_map: dict TF name -> one-hot index in ``[0, len(tf_id_map))``.
        gene_id_map: dict gene name -> one-hot index in ``[0, len(gene_id_map))``.
        hidden_dim: width of the first two hidden layers.
        num_classes: number of output logits.
        dropout: dropout probability after the first two hidden layers.
    """

    def __init__(self, tf_embed_cache, gene_embed_cache, tf_id_map, gene_id_map,
                 hidden_dim=1024, num_classes=3, dropout=0.1):
        super().__init__()

        self.tf_cache = tf_embed_cache
        self.gene_cache = gene_embed_cache
        self.tf_id_map = tf_id_map
        self.gene_id_map = gene_id_map

        self.num_tfs = len(tf_id_map)
        self.num_genes = len(gene_id_map)

        seq_dim = next(iter(tf_embed_cache.values())).shape[0]

        # FINAL INPUT SIZE:
        # TF_seq + Gene_seq + interaction + TF_onehot + Gene_onehot
        in_dim = seq_dim * 3 + self.num_tfs + self.num_genes

        # Layer order is unchanged, so existing state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),

            nn.Linear(hidden_dim // 2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)
        )

    def forward(self, batch_x):
        """Return logits of shape ``(len(batch_x), num_classes)``.

        ``batch_x`` is a sequence of dicts with ``"tf_name"`` and ``"gene_name"`` keys.
        Features are built directly on the device/dtype of the model's own
        parameters, so the module works wherever it was moved with ``.to(...)``.
        It does not depend on a notebook-global ``device``.
        """
        ref = next(self.parameters())
        tf_names = [item["tf_name"] for item in batch_x]
        gene_names = [item["gene_name"] for item in batch_x]

        tf_seq = torch.stack([self.tf_cache[n] for n in tf_names]).to(ref.device, ref.dtype)
        gene_seq = torch.stack([self.gene_cache[n] for n in gene_names]).to(ref.device, ref.dtype)

        # Vectorised one-hot encoding instead of per-item zero vectors.
        tf_idx = torch.tensor([self.tf_id_map[n] for n in tf_names], device=ref.device)
        gene_idx = torch.tensor([self.gene_id_map[n] for n in gene_names], device=ref.device)
        tf_onehot = nn.functional.one_hot(tf_idx, self.num_tfs).to(ref.dtype)
        gene_onehot = nn.functional.one_hot(gene_idx, self.num_genes).to(ref.dtype)

        features = torch.cat(
            [tf_seq, gene_seq, tf_seq * gene_seq, tf_onehot, gene_onehot], dim=1
        )
        return self.net(features)


In [28]:
loss_fn = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer, criterion=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module whose forward accepts the loader's ``batch_x`` as-is and
            returns logits of shape (batch, num_classes).
        loader: iterable of ``(batch_x, batch_y)`` pairs. ``batch_y`` presumably
            holds integer class ids (it is compared against ``argmax`` predictions).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function; defaults to the module-level ``loss_fn``.

    Returns:
        ``(mean_loss, accuracy)``, both weighted by number of samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    criterion = loss_fn if criterion is None else criterion
    # Move targets to wherever the model lives rather than a global `device`.
    target_device = next(model.parameters()).device
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(target_device)

        logits = model(batch_x)
        loss = criterion(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = batch_y.size(0)
        # loss is a batch mean; scale by n so the epoch mean is per-sample.
        total_loss += loss.item() * n
        total_correct += (logits.argmax(dim=1) == batch_y).sum().item()
        total_samples += n

    if total_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples


@torch.no_grad()
def eval_model(model, loader, criterion=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: module whose forward accepts the loader's ``batch_x`` as-is.
        loader: iterable of ``(batch_x, batch_y)`` pairs with integer class ids
            in ``batch_y`` (presumably; they are compared to ``argmax`` predictions).
        criterion: loss function; defaults to the module-level ``loss_fn``.

    Returns:
        ``(avg_loss, accuracy, macro_f1, all_labels, all_preds)``. The last two
        are flat Python lists in loader order, suitable for sklearn reports.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    criterion = loss_fn if criterion is None else criterion
    # Match the model's device instead of relying on a global `device`.
    target_device = next(model.parameters()).device
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    all_preds, all_labels = [], []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(target_device)

        logits = model(batch_x)
        loss = criterion(logits, batch_y)

        preds = logits.argmax(dim=1)

        n = batch_y.size(0)
        total_loss += loss.item() * n  # batch-mean loss -> per-sample sum
        total_correct += (preds == batch_y).sum().item()
        total_samples += n

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(batch_y.cpu().tolist())

    if total_samples == 0:
        raise ValueError("eval_model: loader yielded no samples")

    avg_loss = total_loss / total_samples
    acc = total_correct / total_samples
    # zero_division=0 gives the same value as sklearn's default for classes that
    # are never predicted (e.g. early epochs), without the warning spam.
    macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)

    return avg_loss, acc, macro_f1, all_labels, all_preds


In [33]:
# Run hyperparameters, named so they are easy to find, tune and report.
HIDDEN_DIM = 1024
NUM_CLASSES = 3
LEARNING_RATE = 1e-4
NUM_EPOCHS = 30

model = TFgeneHybridMLP(
    tf_embed_cache=tf_embed_cache,
    gene_embed_cache=gene_embed_cache,
    tf_id_map=tf_id_map,
    gene_id_map=gene_id_map,
    hidden_dim=HIDDEN_DIM,
    num_classes=NUM_CLASSES,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch metrics, kept for later plotting/inspection.
# NOTE: the test split is only monitored here. Choosing epochs or
# hyperparameters from these numbers would leak; use a validation split for that.
history = []
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    test_loss, test_acc, test_f1, _, _ = eval_model(model, test_loader)
    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "test_loss": test_loss,
        "test_acc": test_acc,
        "test_f1": test_f1,
    })

    print(
        f"Epoch {epoch:02d} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}"
    )


Epoch 01 | Train Loss: 1.0990, Train Acc: 0.3485 | Test Loss: 1.1003, Test Acc: 0.3144, Test F1: 0.1595
Epoch 02 | Train Loss: 1.0973, Train Acc: 0.3528 | Test Loss: 1.0974, Test Acc: 0.3561, Test F1: 0.1751
Epoch 03 | Train Loss: 1.0978, Train Acc: 0.3525 | Test Loss: 1.0964, Test Acc: 0.3561, Test F1: 0.1751
Epoch 04 | Train Loss: 1.0961, Train Acc: 0.3596 | Test Loss: 1.0948, Test Acc: 0.3561, Test F1: 0.1751
Epoch 05 | Train Loss: 1.0755, Train Acc: 0.3922 | Test Loss: 1.0491, Test Acc: 0.3910, Test F1: 0.2860
Epoch 06 | Train Loss: 0.9767, Train Acc: 0.4814 | Test Loss: 0.8816, Test Acc: 0.5510, Test F1: 0.5288
Epoch 07 | Train Loss: 0.8477, Train Acc: 0.5554 | Test Loss: 0.8809, Test Acc: 0.5497, Test F1: 0.5575
Epoch 08 | Train Loss: 0.8248, Train Acc: 0.5782 | Test Loss: 0.7854, Test Acc: 0.6086, Test F1: 0.6102
Epoch 09 | Train Loss: 0.7698, Train Acc: 0.6119 | Test Loss: 0.8774, Test Acc: 0.5690, Test F1: 0.5765
Epoch 10 | Train Loss: 0.7290, Train Acc: 0.6358 | Test Loss: 0.

In [34]:
# Re-evaluate the final (last-epoch) model on the test split and summarise it.
final_metrics = eval_model(model, test_loader)
test_loss, test_acc, test_f1, y_true, y_pred = final_metrics

for caption, value in (("Final Test Accuracy:", test_acc), ("Final Test Macro F1:", test_f1)):
    print(caption, value)
print("\nClassification Report:")
report = classification_report(y_true, y_pred, digits=4)
print(report)


Final Test Accuracy: 0.730752688172043
Final Test Macro F1: 0.7302464646452856

Classification Report:
              precision    recall  f1-score   support

           0     0.6681    0.8068    0.7309       766
           1     0.8311    0.8618    0.8462       731
           2     0.7025    0.5447    0.6136       828

    accuracy                         0.7308      2325
   macro avg     0.7339    0.7378    0.7302      2325
weighted avg     0.7316    0.7308    0.7254      2325

