<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
# Clone the project repository and make its root the working directory: the cells
# below import from the repo's `data` package and read files through repo-relative
# paths such as 'data/...' and 'embeds/...'.
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
# Install the extra Python packages on top of whatever the runtime already provides.
!pip install fastparquet tqdm

Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 201, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 201 (delta 39), reused 45 (delta 16), pack-reused 121 (from 2)[K
Receiving objects: 100% (201/201), 260.57 MiB | 12.74 MiB/s, done.
Resolving deltas: 100% (85/85), done.
Updating files: 100% (53/53), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:
# Imports and device
import copy
import os
import pickle

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

Device: cuda


In [12]:
# ========================
# Load Dataloaders
# ========================
# Create dataloaders (train / val / test).
# Every split reads the same parquet / sequence files with the same
# `majority_fraction`; only the split name and the batch size differ, so the
# shared arguments are declared once instead of being copy-pasted per call.
common_loader_kwargs = dict(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    majority_fraction=0.005,
)

train_loader = get_dataloader(batch_size=128, type='train', **common_loader_kwargs)
val_loader = get_dataloader(batch_size=256, type='val', **common_loader_kwargs)
test_loader = get_dataloader(batch_size=256, type='test', **common_loader_kwargs)

print('Train size:', len(train_loader.dataset))
print('Val size  :', len(val_loader.dataset))
print('Test size :', len(test_loader.dataset))


Train size: 10845
Val size  : 2324
Test size : 2325


In [18]:
# Build TF/Gene ID maps using ALL (train + val + test)
train_ds, val_ds, test_ds = (
    loader.dataset for loader in (train_loader, val_loader, test_loader)
)

# Stack the three split frames into one frame with a fresh 0..n-1 index.
combined_df = pd.concat([train_ds.df, val_ds.df, test_ds.df], ignore_index=True)

# Unique names, in order of first appearance in the combined frame.
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# name -> integer ID, where the ID is the name's position in the list above.
tf_to_id = dict(zip(tf_names, range(len(tf_names))))
gene_to_id = dict(zip(gene_names, range(len(gene_names))))

num_tfs, num_genes = len(tf_to_id), len(gene_to_id)

print('Unique TFs (all splits):', num_tfs)
print('Unique Genes (all splits):', num_genes)


Unique TFs (all splits): 223
Unique Genes (all splits): 4539


In [19]:
# Load cached embeddings and prepare name lists
import pickle


def _load_embedding_cache(path):
    """Load a pickled mapping and convert every non-tensor value to a float32 tensor.

    The mapping is modified in place (existing torch tensors are left untouched)
    and returned. The file handle is closed as soon as the pickle has been read.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only load cache
    # files that come from a trusted source (here: files under embeds/).
    with open(path, 'rb') as fh:
        cache = pickle.load(fh)
    for key in list(cache.keys()):
        value = cache[key]
        if not isinstance(value, torch.Tensor):
            cache[key] = torch.tensor(value, dtype=torch.float32)
    return cache


# Load cached TF/gene embeddings produced by the embedding notebook
tf_embed_cache = _load_embedding_cache('embeds/tf_cls.pkl')
gene_embed_cache = _load_embedding_cache('embeds/gn_cls.pkl')

# Expose name lists and counts (taken from the caches).
# NOTE(review): this overwrites the tf_names / gene_names / num_tfs / num_genes
# computed from the dataset splits, while tf_to_id / gene_to_id are NOT rebuilt.
# The position of a name in these lists therefore does not correspond to its
# integer ID; the counts are only used to size the model's embedding tables.
tf_names = list(tf_embed_cache.keys())
gene_names = list(gene_embed_cache.keys())
num_tfs = len(tf_names)
num_genes = len(gene_names)

# IDs produced by tf_to_id / gene_to_id range over 0..len(map)-1, so the tables
# sized by num_tfs / num_genes must be at least that large or lookups would go
# out of range.
assert num_tfs >= len(tf_to_id), (
    f'TF cache has {num_tfs} entries but tf_to_id defines {len(tf_to_id)} IDs'
)
assert num_genes >= len(gene_to_id), (
    f'Gene cache has {num_genes} entries but gene_to_id defines {len(gene_to_id)} IDs'
)

# number of classes is taken from the training split
num_classes = len(train_loader.dataset.df['expression_label'].unique())

print('TFs in cache:', num_tfs)
print('Genes in cache:', num_genes)
print('Num classes:', num_classes)


TFs in cache: 223
Genes in cache: 5307
Num classes: 3


In [20]:
# ID-only model definition
class TFGeneIDModel(nn.Module):
    """Classifier that sees only integer TF and gene IDs.

    Each ID is looked up in its own learned embedding table; the two embeddings
    are concatenated and passed through a three-layer MLP that outputs one
    logit per class.

    Args:
        num_tfs: number of rows in the TF embedding table (valid IDs: 0..num_tfs-1).
        num_genes: number of rows in the gene embedding table.
        emb_dim: width of each embedding vector.
        hidden_dim: width of the first hidden layer; the second is hidden_dim // 2.
        num_classes: number of output logits.
        dropout: dropout probability applied after the first hidden layer
            (default 0.2, the value that used to be hard-coded).
    """

    def __init__(self, num_tfs, num_genes, emb_dim=64, hidden_dim=256, num_classes=3, dropout=0.2):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        # Layer order (and therefore the state_dict keys mlp.0 / mlp.3 / mlp.5)
        # is kept as-is so previously saved checkpoints still load.
        self.mlp = nn.Sequential(
            nn.Linear(2 * emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, tf_ids, gene_ids):
        """Return class logits for paired integer ID tensors.

        The two embeddings are concatenated along the last dimension, so the
        output has the same leading shape as the inputs plus a trailing
        dimension of size num_classes.
        """
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h)

# Instantiate model: embedding tables are sized by num_tfs / num_genes and the
# output layer by num_classes, then the module is moved to the selected device.
model = TFGeneIDModel(
    num_tfs=num_tfs,
    num_genes=num_genes,
    emb_dim=64,
    hidden_dim=256,
    num_classes=num_classes,
)
model = model.to(device)
print(model)

TFGeneIDModel(
  (tf_emb): Embedding(223, 64)
  (gene_emb): Embedding(5307, 64)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)


In [21]:
# Training / evaluation helpers
# Cross-entropy over the class logits; shared by the training and evaluation loops.
loss_fn = nn.CrossEntropyLoss()
# Adam over every parameter of `model` (embedding tables and MLP).
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def prepare_id_batch(batch_x, device=device):
    """Translate a batch of samples into TF and gene ID tensors.

    Args:
        batch_x: sequence of items indexable by 'tf_name' and 'gene_name'.
        device: where to place the tensors. The default is whatever the global
            `device` was when this function was defined.

    Returns:
        (tf_ids, gene_ids): two 1-D int64 tensors with one entry per item,
        looked up through the global `tf_to_id` / `gene_to_id` maps.

    Raises:
        KeyError: if a name is missing from the corresponding map.
    """
    def _ids_for(vocab, field):
        indices = [vocab[sample[field]] for sample in batch_x]
        return torch.tensor(indices, dtype=torch.long, device=device)

    return _ids_for(tf_to_id, 'tf_name'), _ids_for(gene_to_id, 'gene_name')

def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over `loader` and report running metrics.

    The model is put in training mode. For every batch the loss from the global
    `loss_fn` is back-propagated and `optimizer` takes one step. Loss and
    accuracy are accumulated per sample (each batch's mean loss is weighted by
    its size) from the logits computed before that step, and shown live on the
    tqdm progress bar.

    Returns:
        (mean_loss, accuracy) over all samples seen in this epoch.
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(loader)
    for batch_x, batch_y in progress:
        batch_y = batch_y.to(device)
        logits = model(*prepare_id_batch(batch_x))
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = len(batch_y)
        n_correct += logits.argmax(dim=1).eq(batch_y).sum().item()
        n_seen += batch_size
        running_loss += loss.item() * batch_size
        progress.set_postfix({'loss': running_loss / n_seen, 'acc': n_correct / n_seen})

    return running_loss / n_seen, n_correct / n_seen

@torch.no_grad()
def evaluate(model, loader):
    """Score `model` on `loader` without tracking gradients.

    The model is switched to eval mode; loss (global `loss_fn`) and accuracy
    are accumulated per sample across all batches.

    Returns:
        (mean_loss, accuracy) over every sample in `loader`.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        logits = model(*prepare_id_batch(batch_x))
        batch_loss = loss_fn(logits, batch_y)

        batch_size = len(batch_y)
        n_correct += logits.argmax(dim=1).eq(batch_y).sum().item()
        n_seen += batch_size
        loss_sum += batch_loss.item() * batch_size

    return loss_sum / n_seen, n_correct / n_seen

In [31]:
# ================================
# Training Loop WITH VALIDATION
# ================================

@torch.no_grad()
def _validate(model, loader):
    """Single evaluation pass returning (mean_loss, accuracy, macro_f1).

    Macro-F1 is computed here rather than through `eval_model_ntstyle`, because
    that helper is only defined in a later cell and is therefore unavailable
    when the notebook is executed top to bottom on a fresh kernel. Doing loss,
    accuracy and F1 in one pass also avoids scoring the validation set twice.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    all_preds, all_labels = [], []
    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)
        logits = model(tf_ids, gene_ids)
        preds = logits.argmax(dim=1)

        batch_size = len(batch_y)
        loss_sum += loss_fn(logits, batch_y).item() * batch_size
        n_correct += (preds == batch_y).sum().item()
        n_seen += batch_size

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(batch_y.cpu().tolist())

    macro_f1 = f1_score(all_labels, all_preds, average="macro")
    return loss_sum / n_seen, n_correct / n_seen, macro_f1


num_epochs = 30
best_val_f1 = -1
best_state = None

best_model_path = 'models/best_model.pt'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc, val_f1 = _validate(model, val_loader)

    print(f'Epoch {epoch:02d} | '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_f1 and val_acc:.4f}, Val F1: {val_f1:.4f}'
          if False else
          f'Epoch {epoch:02d} | '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}')

    # SAVE BEST MODEL (based on val F1)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # state_dict() returns references to the live parameters; take a deep
        # copy so best_state keeps the best weights instead of silently
        # following later optimizer steps.
        best_state = copy.deepcopy(model.state_dict())
        torch.save(best_state, best_model_path)
        print(f'***Saved new best model at epoch {epoch} (Val F1={val_f1:.4f})')

print("\nTraining complete.")
print("Best Validation F1:", best_val_f1)


100%|██████████| 85/85 [00:01<00:00, 79.16it/s, loss=0.0937, acc=0.966]


Epoch 01 | Train Loss: 0.0937, Train Acc: 0.9659 | Val Loss: 1.3444, Val Acc: 0.6876, Val F1: 0.6914
***Saved new best model at epoch 1 (Val F1=0.6914)


100%|██████████| 85/85 [00:01<00:00, 83.05it/s, loss=0.0914, acc=0.967]


Epoch 02 | Train Loss: 0.0914, Train Acc: 0.9670 | Val Loss: 1.3744, Val Acc: 0.6992, Val F1: 0.7034
***Saved new best model at epoch 2 (Val F1=0.7034)


100%|██████████| 85/85 [00:00<00:00, 92.20it/s, loss=0.0854, acc=0.97]


Epoch 03 | Train Loss: 0.0854, Train Acc: 0.9698 | Val Loss: 1.4290, Val Acc: 0.6872, Val F1: 0.6926


100%|██████████| 85/85 [00:00<00:00, 94.82it/s, loss=0.0867, acc=0.968]


Epoch 04 | Train Loss: 0.0867, Train Acc: 0.9678 | Val Loss: 1.4030, Val Acc: 0.6898, Val F1: 0.6931


100%|██████████| 85/85 [00:00<00:00, 90.44it/s, loss=0.075, acc=0.974]


Epoch 05 | Train Loss: 0.0750, Train Acc: 0.9736 | Val Loss: 1.4610, Val Acc: 0.6928, Val F1: 0.6970


100%|██████████| 85/85 [00:00<00:00, 85.47it/s, loss=0.0652, acc=0.978]


Epoch 06 | Train Loss: 0.0652, Train Acc: 0.9779 | Val Loss: 1.5627, Val Acc: 0.6833, Val F1: 0.6891


100%|██████████| 85/85 [00:00<00:00, 85.10it/s, loss=0.0672, acc=0.976]


Epoch 07 | Train Loss: 0.0672, Train Acc: 0.9762 | Val Loss: 1.5553, Val Acc: 0.6971, Val F1: 0.7009


100%|██████████| 85/85 [00:00<00:00, 91.26it/s, loss=0.0686, acc=0.976]


Epoch 08 | Train Loss: 0.0686, Train Acc: 0.9765 | Val Loss: 1.5308, Val Acc: 0.6945, Val F1: 0.6977


100%|██████████| 85/85 [00:01<00:00, 77.43it/s, loss=0.0531, acc=0.98]


Epoch 09 | Train Loss: 0.0531, Train Acc: 0.9804 | Val Loss: 1.6374, Val Acc: 0.6971, Val F1: 0.7009


100%|██████████| 85/85 [00:01<00:00, 84.34it/s, loss=0.057, acc=0.979]


Epoch 10 | Train Loss: 0.0570, Train Acc: 0.9793 | Val Loss: 1.6334, Val Acc: 0.6975, Val F1: 0.7018


100%|██████████| 85/85 [00:00<00:00, 91.76it/s, loss=0.0508, acc=0.982]


Epoch 11 | Train Loss: 0.0508, Train Acc: 0.9817 | Val Loss: 1.6559, Val Acc: 0.6988, Val F1: 0.7025


100%|██████████| 85/85 [00:01<00:00, 80.38it/s, loss=0.0497, acc=0.983]


Epoch 12 | Train Loss: 0.0497, Train Acc: 0.9830 | Val Loss: 1.7597, Val Acc: 0.6984, Val F1: 0.7031


100%|██████████| 85/85 [00:01<00:00, 77.39it/s, loss=0.0468, acc=0.983]


Epoch 13 | Train Loss: 0.0468, Train Acc: 0.9828 | Val Loss: 1.7381, Val Acc: 0.6932, Val F1: 0.6971


100%|██████████| 85/85 [00:00<00:00, 88.87it/s, loss=0.0425, acc=0.986]


Epoch 14 | Train Loss: 0.0425, Train Acc: 0.9858 | Val Loss: 1.8075, Val Acc: 0.6988, Val F1: 0.7034
***Saved new best model at epoch 14 (Val F1=0.7034)


100%|██████████| 85/85 [00:00<00:00, 93.28it/s, loss=0.0386, acc=0.987]


Epoch 15 | Train Loss: 0.0386, Train Acc: 0.9867 | Val Loss: 1.7882, Val Acc: 0.6889, Val F1: 0.6935


100%|██████████| 85/85 [00:01<00:00, 84.90it/s, loss=0.0429, acc=0.984]


Epoch 16 | Train Loss: 0.0429, Train Acc: 0.9839 | Val Loss: 1.8584, Val Acc: 0.6954, Val F1: 0.7003


100%|██████████| 85/85 [00:00<00:00, 90.72it/s, loss=0.0427, acc=0.985]


Epoch 17 | Train Loss: 0.0427, Train Acc: 0.9845 | Val Loss: 1.8409, Val Acc: 0.6975, Val F1: 0.7010


100%|██████████| 85/85 [00:00<00:00, 94.02it/s, loss=0.0408, acc=0.985]


Epoch 18 | Train Loss: 0.0408, Train Acc: 0.9853 | Val Loss: 1.8396, Val Acc: 0.6971, Val F1: 0.7020


100%|██████████| 85/85 [00:01<00:00, 82.95it/s, loss=0.0381, acc=0.987]


Epoch 19 | Train Loss: 0.0381, Train Acc: 0.9874 | Val Loss: 1.8223, Val Acc: 0.6992, Val F1: 0.7037
***Saved new best model at epoch 19 (Val F1=0.7037)


100%|██████████| 85/85 [00:00<00:00, 85.70it/s, loss=0.0339, acc=0.989]


Epoch 20 | Train Loss: 0.0339, Train Acc: 0.9895 | Val Loss: 1.8756, Val Acc: 0.7014, Val F1: 0.7058
***Saved new best model at epoch 20 (Val F1=0.7058)


100%|██████████| 85/85 [00:00<00:00, 87.84it/s, loss=0.0342, acc=0.988]


Epoch 21 | Train Loss: 0.0342, Train Acc: 0.9881 | Val Loss: 1.9143, Val Acc: 0.7044, Val F1: 0.7080
***Saved new best model at epoch 21 (Val F1=0.7080)


100%|██████████| 85/85 [00:00<00:00, 92.93it/s, loss=0.0281, acc=0.99]


Epoch 22 | Train Loss: 0.0281, Train Acc: 0.9901 | Val Loss: 1.9637, Val Acc: 0.7057, Val F1: 0.7103
***Saved new best model at epoch 22 (Val F1=0.7103)


100%|██████████| 85/85 [00:00<00:00, 92.99it/s, loss=0.0275, acc=0.991]


Epoch 23 | Train Loss: 0.0275, Train Acc: 0.9914 | Val Loss: 2.0147, Val Acc: 0.6988, Val F1: 0.7032


100%|██████████| 85/85 [00:00<00:00, 93.41it/s, loss=0.0317, acc=0.989]


Epoch 24 | Train Loss: 0.0317, Train Acc: 0.9890 | Val Loss: 1.9795, Val Acc: 0.7065, Val F1: 0.7100


100%|██████████| 85/85 [00:00<00:00, 87.47it/s, loss=0.0339, acc=0.988]


Epoch 25 | Train Loss: 0.0339, Train Acc: 0.9879 | Val Loss: 1.9820, Val Acc: 0.7052, Val F1: 0.7099


100%|██████████| 85/85 [00:00<00:00, 87.69it/s, loss=0.0287, acc=0.991]


Epoch 26 | Train Loss: 0.0287, Train Acc: 0.9907 | Val Loss: 2.0011, Val Acc: 0.7044, Val F1: 0.7086


100%|██████████| 85/85 [00:00<00:00, 87.80it/s, loss=0.0253, acc=0.992]


Epoch 27 | Train Loss: 0.0253, Train Acc: 0.9916 | Val Loss: 2.0635, Val Acc: 0.6979, Val F1: 0.7035


100%|██████████| 85/85 [00:01<00:00, 74.76it/s, loss=0.0301, acc=0.989]


Epoch 28 | Train Loss: 0.0301, Train Acc: 0.9893 | Val Loss: 2.0641, Val Acc: 0.7040, Val F1: 0.7077


100%|██████████| 85/85 [00:01<00:00, 82.21it/s, loss=0.0242, acc=0.992]


Epoch 29 | Train Loss: 0.0242, Train Acc: 0.9923 | Val Loss: 2.0923, Val Acc: 0.6966, Val F1: 0.7014


100%|██████████| 85/85 [00:00<00:00, 89.33it/s, loss=0.0228, acc=0.992]


Epoch 30 | Train Loss: 0.0228, Train Acc: 0.9920 | Val Loss: 2.1581, Val Acc: 0.7009, Val F1: 0.7055

Training complete.
Best Validation F1: 0.7102865762778435


In [32]:
# Final evaluation (nt_classify-style): compute loss, accuracy, macro-F1, and classification report
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report
import json

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model_ntstyle(model, loader):
    """Evaluate `model` on `loader` and collect per-sample predictions.

    The model is called as `model(batch_x)` first; if that raises TypeError the
    batch is converted with `prepare_id_batch` and the model is called as
    `model(tf_ids, gene_ids)` instead.

    Returns:
        (avg_loss, accuracy, macro_f1, all_labels, all_preds), where the last
        two are flat lists holding the true and predicted class of every sample
        in iteration order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    pred_log, label_log = [], []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # support both cache-based and id-based models
        try:
            logits = model(batch_x)
        except TypeError:
            logits = model(*prepare_id_batch(batch_x, device=device))

        batch_loss = loss_fn(logits, batch_y)
        preds = logits.argmax(dim=1)
        batch_size = len(batch_y)

        loss_sum += batch_loss.item() * batch_size
        n_correct += preds.eq(batch_y).sum().item()
        n_seen += batch_size

        pred_log.extend(preds.cpu().numpy())
        label_log.extend(batch_y.cpu().numpy())

    macro_f1 = f1_score(label_log, pred_log, average="macro")
    return loss_sum / n_seen, n_correct / n_seen, macro_f1, label_log, pred_log


In [33]:
# =======================================
# Final evaluation on TEST using best model
# =======================================

# load best checkpoint
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime still loads on a CPU-only one.
model.load_state_dict(torch.load('models/best_model.pt', map_location=device))

test_loss, test_acc, test_f1, y_true, y_pred = eval_model_ntstyle(model, test_loader)

# Build the text report once and reuse it for both the printout and the JSON file.
report_text = classification_report(y_true, y_pred, digits=4)

print("\n=== FINAL TEST RESULTS (BEST CHECKPOINT) ===")
print("Test Accuracy:", test_acc)
print("Test Macro F1:", test_f1)
print("\nClassification Report:")
print(report_text)

# Save JSON
os.makedirs('results', exist_ok=True)
metrics = {
    'test_loss': float(test_loss),
    'accuracy': float(test_acc),
    'macro_f1': float(test_f1),
    'classification_report': report_text,
}
with open('results/best_model_test_metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print('\nSaved metrics to results/best_model_test_metrics.json')



=== FINAL TEST RESULTS (BEST CHECKPOINT) ===
Test Accuracy: 0.7161290322580646
Test Macro F1: 0.7203459767234563

Classification Report:
              precision    recall  f1-score   support

           0     0.7152    0.6984    0.7067       766
           1     0.8232    0.7579    0.7892       731
           2     0.6372    0.6957    0.6651       828

    accuracy                         0.7161      2325
   macro avg     0.7252    0.7173    0.7203      2325
weighted avg     0.7214    0.7161    0.7178      2325


Saved metrics to results/best_model_test_metrics.json
