<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
# # ONLY FOR COLAB
# !git clone https://github.com/navidh86/perturbseq-10701.git
# %cd ./perturbseq-10701
# !pip install fastparquet tqdm

In [10]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [11]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled_v2.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.005
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled_v2.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.005
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# ========================
# Load Dataloaders
# ========================
# Create dataloaders (train / val / test)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005,
)

val_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='val',
    majority_fraction=0.005,
)

test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005,
)

print('Train size:', len(train_loader.dataset))
print('Val size  :', len(val_loader.dataset))
print('Test size :', len(test_loader.dataset))


Train size: 10845
Val size  : 2324
Test size : 2325


In [12]:
# Build TF/Gene ID maps using ALL (train + val + test)
train_ds = train_loader.dataset
val_ds = val_loader.dataset
test_ds = test_loader.dataset

combined_df = pd.concat([train_ds.df, val_ds.df, test_ds.df]).reset_index(drop=True)

# Unique names
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# Mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)

print('Unique TFs (all splits):', num_tfs)
print('Unique Genes (all splits):', num_genes)


Unique TFs (all splits): 223
Unique Genes (all splits): 4539


In [13]:
# Load cached embeddings and prepare name lists
import pickle

# Load cached TF/gene embeddings produced by the embedding notebook
tf_embed_cache = pickle.load(open('embeds/tf_cls.pkl', 'rb'))
gene_embed_cache = pickle.load(open('embeds/gn_cls.pkl', 'rb'))

# Convert any numpy arrays to torch tensors
for k in list(tf_embed_cache.keys()):
    v = tf_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        tf_embed_cache[k] = torch.tensor(v, dtype=torch.float32)
for k in list(gene_embed_cache.keys()):
    v = gene_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        gene_embed_cache[k] = torch.tensor(v, dtype=torch.float32)

# Expose name lists and counts (from the caches to ensure consistent mapping)
tf_names = list(tf_embed_cache.keys())
gene_names = list(gene_embed_cache.keys())
num_tfs = len(tf_names)
num_genes = len(gene_names)
# number of classes is taken from the training split
num_classes = len(train_loader.dataset.df['expression_label'].unique())

print('TFs in cache:', num_tfs)
print('Genes in cache:', num_genes)
print('Num classes:', num_classes)


TFs in cache: 223
Genes in cache: 5307
Num classes: 3


In [14]:
# ID-only model definition
class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=64, hidden_dim=256, num_classes=3):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, num_classes)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h)

# Instantiate model
model = TFGeneIDModel(num_tfs=num_tfs, num_genes=num_genes, emb_dim=64, hidden_dim=256, num_classes=num_classes).to(device)
print(model)

TFGeneIDModel(
  (tf_emb): Embedding(223, 64)
  (gene_emb): Embedding(5307, 64)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)


In [15]:
# Training / evaluation helpers
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def prepare_id_batch(batch_x, device=device):
    tf_ids = torch.tensor([tf_to_id[item['tf_name']] for item in batch_x], dtype=torch.long, device=device)
    gene_ids = torch.tensor([gene_to_id[item['gene_name']] for item in batch_x], dtype=torch.long, device=device)
    return tf_ids, gene_ids

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0
    pbar = tqdm(loader)
    for batch_x, batch_y in pbar:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)

        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
        pbar.set_postfix({'loss': total_loss/total, 'acc': total_correct/total})

    return total_loss / total, total_correct / total

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)
        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)
        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
    return total_loss/total, total_correct/total

In [16]:
# Final evaluation (nt_classify-style): compute loss, accuracy, macro-F1, and classification report
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report
import json

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model_ntstyle(model, loader):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # support both cache-based and id-based models
        try:
            logits = model(batch_x)
        except TypeError:
            tf_ids, gene_ids = prepare_id_batch(batch_x, device=device)
            logits = model(tf_ids, gene_ids)

        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds


In [17]:
# # Train for a few epochs
# num_epochs = 30
# for epoch in range(1, num_epochs+1):
#     train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
#     val_loss, val_acc = evaluate(model, test_loader)
#     print(f'Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# # Save model
# os.makedirs('models', exist_ok=True)
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'tf_to_id': tf_to_id,
#     'gene_to_id': gene_to_id
# }, 'models/tf_gene_id_baseline.pt')
# print('Saved model to models/tf_gene_id_baseline.pt')

# ================================
# Training Loop WITH VALIDATION
# ================================

num_epochs = 30
best_val_f1 = -1
best_state = None

for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, val_loader)

    # compute F1 on train set
    _, _, train_f1, _, _ = eval_model_ntstyle(model, train_loader)

    # compute F1 on val set
    _, _, val_f1, _, _ = eval_model_ntstyle(model, val_loader)

    print(f'Epoch {epoch:02d} | '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} Train F1: {train_f1:.4f} | '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}')

    # SAVE BEST MODEL (based on val F1)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_state = model.state_dict()
        torch.save(best_state, 'models/best_model.pt')
        print(f'***Saved new best model at epoch {epoch} (Val F1={val_f1:.4f})')

print("\nTraining complete.")
print("Best Validation F1:", best_val_f1)


100%|██████████| 85/85 [00:01<00:00, 50.35it/s, loss=0.954, acc=0.519]


Epoch 01 | Train Loss: 0.9543, Train Acc: 0.5192 Train F1: 0.6334 | Val Loss: 0.8542, Val Acc: 0.5869, Val F1: 0.5793
***Saved new best model at epoch 1 (Val F1=0.5793)


100%|██████████| 85/85 [00:01<00:00, 55.98it/s, loss=0.784, acc=0.635]


Epoch 02 | Train Loss: 0.7841, Train Acc: 0.6349 Train F1: 0.6949 | Val Loss: 0.7824, Val Acc: 0.6205, Val F1: 0.6157
***Saved new best model at epoch 2 (Val F1=0.6157)


100%|██████████| 85/85 [00:01<00:00, 54.46it/s, loss=0.699, acc=0.686]


Epoch 03 | Train Loss: 0.6992, Train Acc: 0.6857 Train F1: 0.7378 | Val Loss: 0.7627, Val Acc: 0.6321, Val F1: 0.6327
***Saved new best model at epoch 3 (Val F1=0.6327)


100%|██████████| 85/85 [00:01<00:00, 57.12it/s, loss=0.624, acc=0.726]


Epoch 04 | Train Loss: 0.6239, Train Acc: 0.7261 Train F1: 0.7888 | Val Loss: 0.7478, Val Acc: 0.6446, Val F1: 0.6491
***Saved new best model at epoch 4 (Val F1=0.6491)


100%|██████████| 85/85 [00:01<00:00, 62.27it/s, loss=0.553, acc=0.764]


Epoch 05 | Train Loss: 0.5532, Train Acc: 0.7639 Train F1: 0.8332 | Val Loss: 0.7423, Val Acc: 0.6652, Val F1: 0.6702
***Saved new best model at epoch 5 (Val F1=0.6702)


100%|██████████| 85/85 [00:01<00:00, 59.63it/s, loss=0.503, acc=0.786]


Epoch 06 | Train Loss: 0.5029, Train Acc: 0.7864 Train F1: 0.8570 | Val Loss: 0.7491, Val Acc: 0.6670, Val F1: 0.6704
***Saved new best model at epoch 6 (Val F1=0.6704)


100%|██████████| 85/85 [00:01<00:00, 56.96it/s, loss=0.44, acc=0.814] 


Epoch 07 | Train Loss: 0.4398, Train Acc: 0.8142 Train F1: 0.8952 | Val Loss: 0.7670, Val Acc: 0.6807, Val F1: 0.6851
***Saved new best model at epoch 7 (Val F1=0.6851)


100%|██████████| 85/85 [00:01<00:00, 61.79it/s, loss=0.384, acc=0.844]


Epoch 08 | Train Loss: 0.3843, Train Acc: 0.8437 Train F1: 0.9060 | Val Loss: 0.7930, Val Acc: 0.6760, Val F1: 0.6776


100%|██████████| 85/85 [00:01<00:00, 64.91it/s, loss=0.343, acc=0.864]


Epoch 09 | Train Loss: 0.3433, Train Acc: 0.8644 Train F1: 0.9348 | Val Loss: 0.8263, Val Acc: 0.6833, Val F1: 0.6877
***Saved new best model at epoch 9 (Val F1=0.6877)


100%|██████████| 85/85 [00:01<00:00, 56.40it/s, loss=0.301, acc=0.881]


Epoch 10 | Train Loss: 0.3011, Train Acc: 0.8814 Train F1: 0.9424 | Val Loss: 0.8909, Val Acc: 0.6889, Val F1: 0.6944
***Saved new best model at epoch 10 (Val F1=0.6944)


100%|██████████| 85/85 [00:01<00:00, 54.23it/s, loss=0.274, acc=0.891]


Epoch 11 | Train Loss: 0.2739, Train Acc: 0.8908 Train F1: 0.9578 | Val Loss: 0.9085, Val Acc: 0.6863, Val F1: 0.6907


100%|██████████| 85/85 [00:01<00:00, 56.44it/s, loss=0.241, acc=0.904]


Epoch 12 | Train Loss: 0.2410, Train Acc: 0.9042 Train F1: 0.9626 | Val Loss: 0.9720, Val Acc: 0.6846, Val F1: 0.6896


100%|██████████| 85/85 [00:01<00:00, 60.77it/s, loss=0.219, acc=0.914]


Epoch 13 | Train Loss: 0.2187, Train Acc: 0.9137 Train F1: 0.9667 | Val Loss: 1.0116, Val Acc: 0.6923, Val F1: 0.6965
***Saved new best model at epoch 13 (Val F1=0.6965)


100%|██████████| 85/85 [00:01<00:00, 57.57it/s, loss=0.189, acc=0.923]


Epoch 14 | Train Loss: 0.1888, Train Acc: 0.9231 Train F1: 0.9768 | Val Loss: 1.0551, Val Acc: 0.6928, Val F1: 0.6974
***Saved new best model at epoch 14 (Val F1=0.6974)


100%|██████████| 85/85 [00:01<00:00, 54.37it/s, loss=0.175, acc=0.932]


Epoch 15 | Train Loss: 0.1752, Train Acc: 0.9322 Train F1: 0.9817 | Val Loss: 1.0910, Val Acc: 0.6842, Val F1: 0.6896


100%|██████████| 85/85 [00:01<00:00, 59.13it/s, loss=0.161, acc=0.939]


Epoch 16 | Train Loss: 0.1607, Train Acc: 0.9392 Train F1: 0.9806 | Val Loss: 1.1580, Val Acc: 0.6799, Val F1: 0.6840


100%|██████████| 85/85 [00:01<00:00, 59.67it/s, loss=0.145, acc=0.942]


Epoch 17 | Train Loss: 0.1450, Train Acc: 0.9424 Train F1: 0.9881 | Val Loss: 1.1951, Val Acc: 0.6889, Val F1: 0.6937


100%|██████████| 85/85 [00:01<00:00, 55.92it/s, loss=0.126, acc=0.95] 


Epoch 18 | Train Loss: 0.1262, Train Acc: 0.9498 Train F1: 0.9916 | Val Loss: 1.2741, Val Acc: 0.6872, Val F1: 0.6928


100%|██████████| 85/85 [00:01<00:00, 55.56it/s, loss=0.115, acc=0.957]


Epoch 19 | Train Loss: 0.1153, Train Acc: 0.9570 Train F1: 0.9934 | Val Loss: 1.2745, Val Acc: 0.6936, Val F1: 0.6968


100%|██████████| 85/85 [00:01<00:00, 63.59it/s, loss=0.106, acc=0.961] 


Epoch 20 | Train Loss: 0.1059, Train Acc: 0.9609 Train F1: 0.9945 | Val Loss: 1.3625, Val Acc: 0.6898, Val F1: 0.6943


100%|██████████| 85/85 [00:01<00:00, 57.19it/s, loss=0.103, acc=0.961] 


Epoch 21 | Train Loss: 0.1035, Train Acc: 0.9612 Train F1: 0.9943 | Val Loss: 1.3569, Val Acc: 0.6919, Val F1: 0.6963


100%|██████████| 85/85 [00:01<00:00, 55.77it/s, loss=0.0932, acc=0.964]


Epoch 22 | Train Loss: 0.0932, Train Acc: 0.9637 Train F1: 0.9966 | Val Loss: 1.4121, Val Acc: 0.6954, Val F1: 0.6996
***Saved new best model at epoch 22 (Val F1=0.6996)


100%|██████████| 85/85 [00:01<00:00, 57.08it/s, loss=0.087, acc=0.968] 


Epoch 23 | Train Loss: 0.0870, Train Acc: 0.9675 Train F1: 0.9963 | Val Loss: 1.4403, Val Acc: 0.6941, Val F1: 0.6981


100%|██████████| 85/85 [00:01<00:00, 56.87it/s, loss=0.0783, acc=0.971]


Epoch 24 | Train Loss: 0.0783, Train Acc: 0.9711 Train F1: 0.9957 | Val Loss: 1.4797, Val Acc: 0.6807, Val F1: 0.6862


100%|██████████| 85/85 [00:01<00:00, 59.63it/s, loss=0.0715, acc=0.972]


Epoch 25 | Train Loss: 0.0715, Train Acc: 0.9722 Train F1: 0.9966 | Val Loss: 1.5037, Val Acc: 0.6992, Val F1: 0.7027
***Saved new best model at epoch 25 (Val F1=0.7027)


100%|██████████| 85/85 [00:01<00:00, 57.38it/s, loss=0.0753, acc=0.972]


Epoch 26 | Train Loss: 0.0753, Train Acc: 0.9721 Train F1: 0.9968 | Val Loss: 1.5727, Val Acc: 0.6880, Val F1: 0.6929


100%|██████████| 85/85 [00:01<00:00, 56.78it/s, loss=0.0703, acc=0.973]


Epoch 27 | Train Loss: 0.0703, Train Acc: 0.9731 Train F1: 0.9984 | Val Loss: 1.5918, Val Acc: 0.6958, Val F1: 0.6998


100%|██████████| 85/85 [00:01<00:00, 64.54it/s, loss=0.0623, acc=0.978]


Epoch 28 | Train Loss: 0.0623, Train Acc: 0.9782 Train F1: 0.9989 | Val Loss: 1.6435, Val Acc: 0.6949, Val F1: 0.6991


100%|██████████| 85/85 [00:01<00:00, 56.73it/s, loss=0.0534, acc=0.98] 


Epoch 29 | Train Loss: 0.0534, Train Acc: 0.9797 Train F1: 0.9990 | Val Loss: 1.6911, Val Acc: 0.6936, Val F1: 0.6987


100%|██████████| 85/85 [00:01<00:00, 56.07it/s, loss=0.0508, acc=0.982]


Epoch 30 | Train Loss: 0.0508, Train Acc: 0.9816 Train F1: 0.9987 | Val Loss: 1.7693, Val Acc: 0.6910, Val F1: 0.6945

Training complete.
Best Validation F1: 0.7026722653077618


In [18]:
# =======================================
# Final evaluation on TEST using best model
# =======================================

# load best checkpoint
model.load_state_dict(torch.load('models/best_model.pt'))

test_loss, test_acc, test_f1, y_true, y_pred = eval_model_ntstyle(model, test_loader)

print("\n=== FINAL TEST RESULTS (BEST CHECKPOINT) ===")
print("Test Accuracy:", test_acc)
print("Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))

# Save JSON
os.makedirs('results', exist_ok=True)
metrics = {
    'test_loss': float(test_loss),
    'accuracy': float(test_acc),
    'macro_f1': float(test_f1),
    'classification_report': classification_report(y_true, y_pred, digits=4),
}
with open('results/best_model_test_metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print('\nSaved metrics to results/best_model_test_metrics.json')



=== FINAL TEST RESULTS (BEST CHECKPOINT) ===
Test Accuracy: 0.7010752688172043
Test Macro F1: 0.7033612089520026

Classification Report:
              precision    recall  f1-score   support

           0     0.6826    0.7076    0.6949       766
           1     0.7860    0.8140    0.7997       731
           2     0.6370    0.5954    0.6155       828

    accuracy                         0.7011      2325
   macro avg     0.7019    0.7056    0.7034      2325
weighted avg     0.6989    0.7011    0.6996      2325


Saved metrics to results/best_model_test_metrics.json
