<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/nt_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# # ONLY FOR COLAB
# !git clone https://github.com/navidh86/perturbseq-10701.git
# %cd ./perturbseq-10701
# !pip install fastparquet tqdm


In [2]:

# !pip install --upgrade git+https://github.com/huggingface/transformers.git


In [1]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
# from transformers import AutoTokenizer, AutoModelForMaskedLM
from enformer_pytorch import Enformer, seq_indices_to_one_hot

import os
import pickle
from tqdm import tqdm

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show which device string was selected as this cell's output.
device

'cuda'

In [3]:
from reference_data_classification import (
    PairPerturbSeqDataset,
    perturbseq_collate_2,
    get_dataloader
)

# Arguments common to both splits; the two loaders differ only in `type`.
# NOTE(review): `majority_fraction` is presumably a subsampling ratio applied
# inside get_dataloader — confirm against reference_data_classification.
shared_loader_args = dict(
    parquet_path="tf_gene_expression_labeled_v2.parquet",
    tf_sequences_path="tf_sequences.pkl",
    gene_sequences_path="gene_sequences_4000bp.pkl",
    batch_size=32,
    majority_fraction=0.005,
)

train_loader = get_dataloader(type="train", **shared_loader_args)
test_loader = get_dataloader(type="test", **shared_loader_args)

# Report how many examples each split's dataset holds.
print("Train size:", len(train_loader.dataset))
print("Test size: ", len(test_loader.dataset))

Train size: 11051
Test size:  2764


In [6]:
# Class balance of the training split: row count per `expression_label` value.
train_loader.dataset.df['expression_label'].value_counts()

expression_label
2    4415
1    4010
0    2626
Name: count, dtype: int64

In [7]:
import pickle
import torch

def _load_embedding_cache(path):
    """Load a pickled ``{name: embedding}`` mapping and return it with tensor values.

    The file is opened in a ``with`` block so its handle is closed as soon as
    unpickling finishes. Values that are not already ``torch.Tensor`` are
    replaced in place by float32 tensors; existing tensors are left untouched.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling, so only
    point this at files that you produced or otherwise trust.
    """
    with open(path, "rb") as handle:
        cache = pickle.load(handle)

    # Re-assigning existing keys does not change the dict's size, so it is safe
    # to do while iterating over its items.
    for name, embedding in cache.items():
        if not isinstance(embedding, torch.Tensor):
            cache[name] = torch.tensor(embedding, dtype=torch.float32)

    return cache


# Load cached embeddings (converted to torch tensors if stored as numpy)
tf_embed_cache = _load_embedding_cache("./embeds/tf_enformer_alternate.pkl")
gene_embed_cache = _load_embedding_cache("./embeds/gn_enformer_alternate.pkl")

# Inspect shapes using the first entry of each cache
first_tf = next(iter(tf_embed_cache.values()))
first_gene = next(iter(gene_embed_cache.values()))

print("TF embedding count:", len(tf_embed_cache))
print("Gene embedding count:", len(gene_embed_cache))
print("TF embedding dim:", first_tf.shape)
print("Gene embedding dim:", first_gene.shape)


TF embedding count: 223
Gene embedding count: 5307
TF embedding dim: torch.Size([5313])
Gene embedding dim: torch.Size([5313])


In [8]:
class InteractionMLP(nn.Module):
    """MLP classifier over the concatenation of a cached TF and gene embedding.

    Args:
        tf_embed_cache: mapping from TF name to a 1-D embedding tensor.
        gene_embed_cache: mapping from gene name to a 1-D embedding tensor.
        hidden_dim: width of the first two hidden layers.
        num_classes: number of output logits.

    The caches are kept as plain attributes (not buffers), so they are not
    moved by ``.to(...)``; ``forward`` moves each assembled batch instead.
    """

    def __init__(self, tf_embed_cache, gene_embed_cache, hidden_dim=1024, num_classes=3):
        super().__init__()

        self.tf_embed_cache = tf_embed_cache
        self.gene_embed_cache = gene_embed_cache

        # Embedding widths are read from the first entry of each cache rather
        # than hard-coded, so any cached embedding size works.
        tf_dim = next(iter(tf_embed_cache.values())).shape[0]
        gene_dim = next(iter(gene_embed_cache.values())).shape[0]

        in_dim = tf_dim + gene_dim   # concatenated [tf, gene] vector

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),

            nn.Linear(hidden_dim // 2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)   # one logit per class
        )

    def forward(self, batch_x):
        """Return logits of shape ``(batch, num_classes)``.

        ``batch_x`` is a list of dicts:
        [
           {"tf_name": ..., "gene_name": ...},
           ...
        ]
        Each name is looked up in the corresponding cache; an unknown name
        raises ``KeyError``.
        """
        vectors = []

        for item in batch_x:
            tf_vec = self.tf_embed_cache[item["tf_name"]]
            gene_vec = self.gene_embed_cache[item["gene_name"]]
            vectors.append(torch.cat([tf_vec, gene_vec], dim=-1))

        # Follow the module's own parameters instead of a notebook-level
        # `device` global, so the model works wherever it has been moved.
        target_device = next(self.parameters()).device
        X = torch.stack(vectors).to(target_device)
        return self.net(X)


In [9]:
class InteractionMLP2(nn.Module):
    """MLP classifier over ``[tf, gene, tf * gene]`` built from cached embeddings.

    Args:
        tf_embed_cache: mapping from TF name to a 1-D embedding tensor.
        gene_embed_cache: mapping from gene name to a 1-D embedding tensor.
        hidden_dim: width of the first two hidden layers.
        num_classes: number of output logits.

    Raises:
        ValueError: if the TF and gene embeddings differ in width, because the
            element-wise product feature needs matching sizes.
    """

    def __init__(self, tf_embed_cache, gene_embed_cache, hidden_dim=1024, num_classes=3):
        super().__init__()

        self.tf_cache = tf_embed_cache
        self.gene_cache = gene_embed_cache

        tf_dim = next(iter(tf_embed_cache.values())).shape[0]
        gene_dim = next(iter(gene_embed_cache.values())).shape[0]

        # Fail fast here rather than with a shape error deep inside forward().
        if tf_dim != gene_dim:
            raise ValueError(
                f"TF and gene embeddings must have the same width for the "
                f"element-wise product, got {tf_dim} and {gene_dim}"
            )

        in_dim = tf_dim + gene_dim + tf_dim  # concat + product

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),

            nn.Linear(hidden_dim//2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)
        )

    def forward(self, batch_x):
        """Return logits of shape ``(batch, num_classes)``.

        ``batch_x`` is a list of dicts with ``"tf_name"`` and ``"gene_name"``
        keys; an unknown name raises ``KeyError``.
        """
        vecs = []
        for item in batch_x:
            tf = self.tf_cache[item["tf_name"]]
            gene = self.gene_cache[item["gene_name"]]
            inter = tf * gene  # element-wise interaction feature
            vecs.append(torch.cat([tf, gene, inter], dim=-1))

        # Follow the module's own parameters instead of a notebook-level
        # `device` global, so the model works wherever it has been moved.
        target_device = next(self.parameters()).device
        x = torch.stack(vecs).to(target_device)
        return self.net(x)


In [11]:
loss_fn = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return ``(loss, accuracy)``.

    Args:
        model: module called as ``model(batch_x)`` that returns class logits.
        loader: iterable yielding ``(batch_x, batch_y)`` pairs, where
            ``batch_y`` is a tensor of class indices.
        optimizer: optimizer over the model's parameters.

    Returns:
        Tuple of the sample-weighted mean loss and the accuracy over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for batch_x, batch_y in loader:
        logits = model(batch_x)  # shape (B, num_classes)
        # Put the labels wherever the model produced its logits, rather than
        # relying on a notebook-level `device` global.
        batch_y = batch_y.to(logits.device)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = len(batch_y)
        preds = logits.argmax(dim=1)
        # The loss is a batch mean, so weight it by batch size before summing.
        total_loss += loss.item() * batch_size
        total_correct += (preds == batch_y).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")

    return total_loss/total_samples, total_correct/total_samples


In [12]:
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: module called as ``model(batch_x)`` that returns class logits.
        loader: iterable yielding ``(batch_x, batch_y)`` pairs, where
            ``batch_y`` is a tensor of class indices.

    Returns:
        ``(avg_loss, accuracy, macro_f1, all_labels, all_preds)``; the last two
        are flat lists of true and predicted class indices in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        logits = model(batch_x)
        # Put the labels wherever the model produced its logits, rather than
        # relying on a notebook-level `device` global.
        batch_y = batch_y.to(logits.device)
        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        batch_size = len(batch_y)
        # The loss is a batch mean, so weight it by batch size before summing.
        total_loss += loss.item() * batch_size
        total_correct += (preds == batch_y).sum().item()
        total_samples += batch_size

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute evaluation metrics")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds


In [13]:
# Build the concatenation-based classifier on top of the cached embeddings,
# then place its parameters on the selected device.
model = InteractionMLP(tf_embed_cache, gene_embed_cache)
model = model.to(device)

# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [14]:
NUM_EPOCHS = 30

# One training pass followed by a test-set evaluation per epoch; the label and
# prediction lists returned by eval_model are not needed here and are dropped.
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    test_loss, test_acc, test_f1, _, _ = eval_model(model, test_loader)

    summary = (
        f"Epoch {epoch:02d} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}"
    )
    print(summary)


Epoch 01 | Train Loss: 1.0792, Train Acc: 0.3983 | Test Loss: 1.0753, Test Acc: 0.4092, Test F1: 0.3083
Epoch 02 | Train Loss: 1.0748, Train Acc: 0.3972 | Test Loss: 1.0743, Test Acc: 0.3994, Test F1: 0.1903
Epoch 03 | Train Loss: 1.0711, Train Acc: 0.4002 | Test Loss: 1.0677, Test Acc: 0.4016, Test F1: 0.1969
Epoch 04 | Train Loss: 1.0636, Train Acc: 0.4129 | Test Loss: 1.0558, Test Acc: 0.4334, Test F1: 0.3464
Epoch 05 | Train Loss: 1.0454, Train Acc: 0.4370 | Test Loss: 1.0251, Test Acc: 0.4645, Test F1: 0.4266
Epoch 06 | Train Loss: 1.0278, Train Acc: 0.4567 | Test Loss: 1.0049, Test Acc: 0.4674, Test F1: 0.4281
Epoch 07 | Train Loss: 1.0268, Train Acc: 0.4496 | Test Loss: 1.0160, Test Acc: 0.4555, Test F1: 0.4047
Epoch 08 | Train Loss: 1.0130, Train Acc: 0.4647 | Test Loss: 1.0123, Test Acc: 0.4928, Test F1: 0.4073
Epoch 09 | Train Loss: 1.0051, Train Acc: 0.4749 | Test Loss: 0.9877, Test Acc: 0.4993, Test F1: 0.4549
Epoch 10 | Train Loss: 0.9970, Train Acc: 0.4831 | Test Loss: 0.

In [16]:
# Evaluate the trained model on the test loader, keeping the true and predicted
# labels so a per-class report can be printed below.
test_loss, test_acc, test_f1, y_true, y_pred = eval_model(model, test_loader)

print("Final Test Accuracy:", test_acc)
print("Final Test Macro F1:", test_f1)
print("\nClassification Report:")
# Per-class precision / recall / F1 and support, printed to 4 decimal places.
print(classification_report(y_true, y_pred, digits=4))


Final Test Accuracy: 0.5289435600578871
Final Test Macro F1: 0.44182153567152554

Classification Report:
              precision    recall  f1-score   support

           0     0.6224    0.0928    0.1616       657
           1     0.5120    0.8285    0.6329      1003
           2     0.5465    0.5163    0.5310      1104

    accuracy                         0.5289      2764
   macro avg     0.5603    0.4792    0.4418      2764
weighted avg     0.5520    0.5289    0.4802      2764

