<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/nt_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm


Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 141, done.[K
remote: Counting objects: 100% (20/20), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 141 (delta 10), reused 14 (delta 6), pack-reused 121 (from 2)[K
Receiving objects: 100% (141/141), 252.09 MiB | 22.77 MiB/s, done.
Resolving deltas: 100% (56/56), done.
Updating files: 100% (43/43), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:

!pip install --upgrade git+https://github.com/huggingface/transformers.git


Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-g5ry4yxe
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-g5ry4yxe
  Resolved https://github.com/huggingface/transformers.git to commit bc7a268fed343ab22446ec86115cf2727b38a5eb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub<2.0,>=1.0.0 (from transformers==5.0.0.dev0)
  Downloading huggingface_hub-1.1.7-py3-none-any.whl.metadata (13 kB)
Downloading huggingface_hub-1.1.7-py3-none-any.whl (516 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m516.2/516.2 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... 

In [3]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

import os
import pickle
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
from reference_data_classification import (
    PairPerturbSeqDataset,
    perturbseq_collate_2,
    get_dataloader
)

train_loader = get_dataloader(
    parquet_path="tf_gene_expression_labeled.parquet",
    tf_sequences_path="tf_sequences.pkl",
    gene_sequences_path="gene_sequences_4000bp.pkl",
    batch_size=32,
    type="train",
    majority_fraction=0.01
)

test_loader = get_dataloader(
    parquet_path="tf_gene_expression_labeled.parquet",
    tf_sequences_path="tf_sequences.pkl",
    gene_sequences_path="gene_sequences_4000bp.pkl",
    batch_size=32,
    type="test",
    majority_fraction=0.01
)

print("Train size:", len(train_loader.dataset))
print("Test size: ", len(test_loader.dataset))

Train size: 23427
Test size:  5858


In [15]:
import pickle
import torch

# Load cached embeddings
tf_embed_cache = pickle.load(open("./embeds/tf_cls.pkl", "rb"))
gene_embed_cache = pickle.load(open("./embeds/gn_cls.pkl", "rb"))

# Convert all embeddings to torch tensors (if stored as numpy)
for k in tf_embed_cache:
    if not isinstance(tf_embed_cache[k], torch.Tensor):
        tf_embed_cache[k] = torch.tensor(tf_embed_cache[k], dtype=torch.float32)

for k in gene_embed_cache:
    if not isinstance(gene_embed_cache[k], torch.Tensor):
        gene_embed_cache[k] = torch.tensor(gene_embed_cache[k], dtype=torch.float32)

# Inspect shapes
first_tf = next(iter(tf_embed_cache.values()))
first_gene = next(iter(gene_embed_cache.values()))

print("TF embedding count:", len(tf_embed_cache))
print("Gene embedding count:", len(gene_embed_cache))
print("TF embedding dim:", first_tf.shape)
print("Gene embedding dim:", first_gene.shape)


TF embedding count: 223
Gene embedding count: 5307
TF embedding dim: torch.Size([1280])
Gene embedding dim: torch.Size([1280])


In [24]:
class InteractionMLP(nn.Module):
    def __init__(self, tf_embed_cache, gene_embed_cache, hidden_dim=1024, num_classes=3):
        super().__init__()

        self.tf_embed_cache = tf_embed_cache
        self.gene_embed_cache = gene_embed_cache

        tf_dim = next(iter(tf_embed_cache.values())).shape[0]
        gene_dim = next(iter(gene_embed_cache.values())).shape[0]

        in_dim = tf_dim + gene_dim   # 1280 + 1280

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),

            nn.Linear(hidden_dim // 2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)   # logits for 3 classes
        )

    def forward(self, batch_x):
        """
        batch_x = list of dicts:
        [
           {"tf_name": ..., "gene_name": ...},
           ...
        ]
        """
        vectors = []

        for item in batch_x:
            tf_vec = self.tf_embed_cache[item["tf_name"]]
            gene_vec = self.gene_embed_cache[item["gene_name"]]
            pair_vec = torch.cat([tf_vec, gene_vec], dim=-1)
            vectors.append(pair_vec)

        X = torch.stack(vectors).to(device)
        return self.net(X)   # shape (batch, 3)


In [None]:
class InteractionMLP2(nn.Module):
    def __init__(self, tf_embed_cache, gene_embed_cache, hidden_dim=1024, num_classes=3):
        super().__init__()

        self.tf_cache = tf_embed_cache
        self.gene_cache = gene_embed_cache

        tf_dim = next(iter(tf_embed_cache.values())).shape[0]
        gene_dim = next(iter(gene_embed_cache.values())).shape[0]

        in_dim = tf_dim + gene_dim + tf_dim  # concat + product

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),

            nn.Linear(hidden_dim//2, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)
        )

    def forward(self, batch_x):
        vecs = []
        for item in batch_x:
            tf = self.tf_cache[item["tf_name"]]
            gene = self.gene_cache[item["gene_name"]]
            inter = tf * gene  # IMPORTANT
            vecs.append(torch.cat([tf, gene, inter], dim=-1))
        x = torch.stack(vecs).to(device)
        return self.net(x)


In [25]:
loss_fn = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        logits = model(batch_x)  # shape (B, 3)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

    return total_loss/total_samples, total_correct/total_samples


In [26]:
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model(model, loader):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds


In [27]:
model = InteractionMLP2(
    tf_embed_cache=tf_embed_cache,
    gene_embed_cache=gene_embed_cache
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [28]:
for epoch in range(1, 11):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    test_loss, test_acc, test_f1, _, _ = eval_model(model, test_loader)

    print(f"Epoch {epoch:02d} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")


Epoch 01 | Train Loss: 1.0910, Train Acc: 0.3945 | Test Loss: 1.0885, Test Acc: 0.3981, Test F1: 0.1898
Epoch 02 | Train Loss: 1.0888, Train Acc: 0.3981 | Test Loss: 1.0869, Test Acc: 0.3981, Test F1: 0.1898
Epoch 03 | Train Loss: 1.0860, Train Acc: 0.3981 | Test Loss: 1.0894, Test Acc: 0.3981, Test F1: 0.1898
Epoch 04 | Train Loss: 1.0733, Train Acc: 0.4158 | Test Loss: 1.0697, Test Acc: 0.4133, Test F1: 0.2264
Epoch 05 | Train Loss: 1.0424, Train Acc: 0.4570 | Test Loss: 1.0315, Test Acc: 0.4850, Test F1: 0.3682
Epoch 06 | Train Loss: 1.0361, Train Acc: 0.4465 | Test Loss: 1.0137, Test Acc: 0.4843, Test F1: 0.3804
Epoch 07 | Train Loss: 1.0206, Train Acc: 0.4568 | Test Loss: 1.0482, Test Acc: 0.4423, Test F1: 0.4095
Epoch 08 | Train Loss: 1.0141, Train Acc: 0.4569 | Test Loss: 0.9785, Test Acc: 0.5205, Test F1: 0.4852
Epoch 09 | Train Loss: 1.0018, Train Acc: 0.4728 | Test Loss: 0.9577, Test Acc: 0.5159, Test F1: 0.3947
Epoch 10 | Train Loss: 0.9985, Train Acc: 0.4712 | Test Loss: 0.

In [29]:
test_loss, test_acc, test_f1, y_true, y_pred = eval_model(model, test_loader)

print("Final Test Accuracy:", test_acc)
print("Final Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))


Final Test Accuracy: 0.49231819733697507
Final Test Macro F1: 0.38852434226512056

Classification Report:
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000      1654
           1     0.7096    0.5952    0.6474      2332
           2     0.3834    0.7991    0.5182      1872

    accuracy                         0.4923      5858
   macro avg     0.3643    0.4648    0.3885      5858
weighted avg     0.4050    0.4923    0.4233      5858



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
