<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm

Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 201, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 201 (delta 39), reused 45 (delta 16), pack-reused 121 (from 2)[K
Receiving objects: 100% (201/201), 260.57 MiB | 12.74 MiB/s, done.
Resolving deltas: 100% (85/85), done.
Updating files: 100% (53/53), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [3]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.01
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.01
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# Create dataloaders (point to data/ paths explicitly)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005
)
test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005
)

print('Train size:', len(train_loader.dataset))
print('Test size :', len(test_loader.dataset))

Train size: 10845
Test size : 2325


In [4]:
train_ds = train_loader.dataset
test_ds = test_loader.dataset
combined_df = pd.concat([train_ds.df, test_ds.df]).reset_index(drop=True)

# unique names from combined set
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# create mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)
# Use classes from training split
num_classes = len(train_ds.df['expression_label'].unique())

print('Unique TFs (combined):', num_tfs)
print('Unique Genes (combined):', num_genes)
print('Num classes:', num_classes)

Unique TFs (combined): 223
Unique Genes (combined): 4305
Num classes: 3


In [5]:
# Load cached embeddings and prepare name lists
import pickle

# Load cached TF/gene embeddings produced by the embedding notebook
tf_embed_cache = pickle.load(open('embeds/tf_cls.pkl', 'rb'))
gene_embed_cache = pickle.load(open('embeds/gn_cls.pkl', 'rb'))

# Convert any numpy arrays to torch tensors
for k in list(tf_embed_cache.keys()):
    v = tf_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        tf_embed_cache[k] = torch.tensor(v, dtype=torch.float32)
for k in list(gene_embed_cache.keys()):
    v = gene_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        gene_embed_cache[k] = torch.tensor(v, dtype=torch.float32)

# Expose name lists and counts (from the caches to ensure consistent mapping)
tf_names = list(tf_embed_cache.keys())
gene_names = list(gene_embed_cache.keys())
num_tfs = len(tf_names)
num_genes = len(gene_names)
# number of classes is taken from the training split
num_classes = len(train_loader.dataset.df['expression_label'].unique())

print('TFs in cache:', num_tfs)
print('Genes in cache:', num_genes)
print('Num classes:', num_classes)


TFs in cache: 223
Genes in cache: 5307
Num classes: 3


In [6]:
# ID-only model definition
class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=64, hidden_dim=256, num_classes=3):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, num_classes)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h)

# Instantiate model
model = TFGeneIDModel(num_tfs=num_tfs, num_genes=num_genes, emb_dim=64, hidden_dim=256, num_classes=num_classes).to(device)
print(model)

TFGeneIDModel(
  (tf_emb): Embedding(223, 64)
  (gene_emb): Embedding(5307, 64)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)


In [7]:
# Training / evaluation helpers
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def prepare_id_batch(batch_x, device=device):
    tf_ids = torch.tensor([tf_to_id[item['tf_name']] for item in batch_x], dtype=torch.long, device=device)
    gene_ids = torch.tensor([gene_to_id[item['gene_name']] for item in batch_x], dtype=torch.long, device=device)
    return tf_ids, gene_ids

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0
    pbar = tqdm(loader)
    for batch_x, batch_y in pbar:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)

        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
        pbar.set_postfix({'loss': total_loss/total, 'acc': total_correct/total})

    return total_loss / total, total_correct / total

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)
        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)
        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
    return total_loss/total, total_correct/total

In [8]:
# Train for a few epochs
num_epochs = 30
for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, test_loader)
    print(f'Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# Save model
os.makedirs('models', exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'tf_to_id': tf_to_id,
    'gene_to_id': gene_to_id
}, 'models/tf_gene_id_baseline.pt')
print('Saved model to models/tf_gene_id_baseline.pt')

100%|██████████| 85/85 [00:01<00:00, 53.02it/s, loss=0.964, acc=0.509]


Epoch 01 | Train Loss: 0.9643, Train Acc: 0.5087 | Val Loss: 0.8636, Val Acc: 0.5712


100%|██████████| 85/85 [00:00<00:00, 87.76it/s, loss=0.789, acc=0.627]


Epoch 02 | Train Loss: 0.7893, Train Acc: 0.6275 | Val Loss: 0.7875, Val Acc: 0.6284


100%|██████████| 85/85 [00:01<00:00, 79.78it/s, loss=0.706, acc=0.677]


Epoch 03 | Train Loss: 0.7057, Train Acc: 0.6768 | Val Loss: 0.7632, Val Acc: 0.6353


100%|██████████| 85/85 [00:01<00:00, 79.16it/s, loss=0.64, acc=0.715]


Epoch 04 | Train Loss: 0.6404, Train Acc: 0.7145 | Val Loss: 0.7562, Val Acc: 0.6477


100%|██████████| 85/85 [00:00<00:00, 89.74it/s, loss=0.57, acc=0.75]


Epoch 05 | Train Loss: 0.5704, Train Acc: 0.7502 | Val Loss: 0.7402, Val Acc: 0.6675


100%|██████████| 85/85 [00:00<00:00, 89.13it/s, loss=0.506, acc=0.784]


Epoch 06 | Train Loss: 0.5062, Train Acc: 0.7841 | Val Loss: 0.7705, Val Acc: 0.6723


100%|██████████| 85/85 [00:00<00:00, 87.37it/s, loss=0.451, acc=0.809]


Epoch 07 | Train Loss: 0.4514, Train Acc: 0.8094 | Val Loss: 0.7779, Val Acc: 0.6757


100%|██████████| 85/85 [00:00<00:00, 95.42it/s, loss=0.405, acc=0.836]


Epoch 08 | Train Loss: 0.4046, Train Acc: 0.8357 | Val Loss: 0.7867, Val Acc: 0.6839


100%|██████████| 85/85 [00:00<00:00, 91.40it/s, loss=0.351, acc=0.859]


Epoch 09 | Train Loss: 0.3512, Train Acc: 0.8590 | Val Loss: 0.8354, Val Acc: 0.6839


100%|██████████| 85/85 [00:00<00:00, 96.13it/s, loss=0.314, acc=0.872]


Epoch 10 | Train Loss: 0.3142, Train Acc: 0.8717 | Val Loss: 0.8800, Val Acc: 0.6834


100%|██████████| 85/85 [00:00<00:00, 95.43it/s, loss=0.292, acc=0.88]


Epoch 11 | Train Loss: 0.2923, Train Acc: 0.8803 | Val Loss: 0.8874, Val Acc: 0.6938


100%|██████████| 85/85 [00:00<00:00, 94.49it/s, loss=0.255, acc=0.899]


Epoch 12 | Train Loss: 0.2550, Train Acc: 0.8995 | Val Loss: 0.9664, Val Acc: 0.6826


100%|██████████| 85/85 [00:01<00:00, 79.24it/s, loss=0.234, acc=0.906]


Epoch 13 | Train Loss: 0.2336, Train Acc: 0.9062 | Val Loss: 0.9848, Val Acc: 0.6865


100%|██████████| 85/85 [00:01<00:00, 79.98it/s, loss=0.196, acc=0.922]


Epoch 14 | Train Loss: 0.1958, Train Acc: 0.9221 | Val Loss: 1.0193, Val Acc: 0.6890


100%|██████████| 85/85 [00:01<00:00, 79.74it/s, loss=0.181, acc=0.931]


Epoch 15 | Train Loss: 0.1812, Train Acc: 0.9314 | Val Loss: 1.0681, Val Acc: 0.6877


100%|██████████| 85/85 [00:00<00:00, 92.61it/s, loss=0.165, acc=0.938]


Epoch 16 | Train Loss: 0.1653, Train Acc: 0.9382 | Val Loss: 1.1293, Val Acc: 0.6951


100%|██████████| 85/85 [00:00<00:00, 90.65it/s, loss=0.145, acc=0.945]


Epoch 17 | Train Loss: 0.1454, Train Acc: 0.9449 | Val Loss: 1.1634, Val Acc: 0.6895


100%|██████████| 85/85 [00:00<00:00, 90.30it/s, loss=0.14, acc=0.946]


Epoch 18 | Train Loss: 0.1398, Train Acc: 0.9462 | Val Loss: 1.2269, Val Acc: 0.6873


100%|██████████| 85/85 [00:00<00:00, 95.07it/s, loss=0.125, acc=0.953]


Epoch 19 | Train Loss: 0.1253, Train Acc: 0.9528 | Val Loss: 1.2610, Val Acc: 0.6942


100%|██████████| 85/85 [00:01<00:00, 81.91it/s, loss=0.112, acc=0.958]


Epoch 20 | Train Loss: 0.1117, Train Acc: 0.9583 | Val Loss: 1.2940, Val Acc: 0.6946


100%|██████████| 85/85 [00:00<00:00, 94.09it/s, loss=0.0989, acc=0.963]


Epoch 21 | Train Loss: 0.0989, Train Acc: 0.9626 | Val Loss: 1.3681, Val Acc: 0.6882


100%|██████████| 85/85 [00:00<00:00, 89.27it/s, loss=0.0946, acc=0.965]


Epoch 22 | Train Loss: 0.0946, Train Acc: 0.9647 | Val Loss: 1.3569, Val Acc: 0.6951


100%|██████████| 85/85 [00:00<00:00, 93.05it/s, loss=0.0844, acc=0.97]


Epoch 23 | Train Loss: 0.0844, Train Acc: 0.9700 | Val Loss: 1.4341, Val Acc: 0.6933


100%|██████████| 85/85 [00:00<00:00, 85.89it/s, loss=0.0765, acc=0.972]


Epoch 24 | Train Loss: 0.0765, Train Acc: 0.9719 | Val Loss: 1.4742, Val Acc: 0.6972


100%|██████████| 85/85 [00:00<00:00, 91.15it/s, loss=0.0808, acc=0.971]


Epoch 25 | Train Loss: 0.0808, Train Acc: 0.9708 | Val Loss: 1.4711, Val Acc: 0.6925


100%|██████████| 85/85 [00:01<00:00, 84.35it/s, loss=0.0673, acc=0.976]


Epoch 26 | Train Loss: 0.0673, Train Acc: 0.9762 | Val Loss: 1.5427, Val Acc: 0.7002


100%|██████████| 85/85 [00:00<00:00, 87.29it/s, loss=0.0686, acc=0.975]


Epoch 27 | Train Loss: 0.0686, Train Acc: 0.9747 | Val Loss: 1.5450, Val Acc: 0.6920


100%|██████████| 85/85 [00:00<00:00, 90.52it/s, loss=0.0575, acc=0.978]


Epoch 28 | Train Loss: 0.0575, Train Acc: 0.9782 | Val Loss: 1.5907, Val Acc: 0.6933


100%|██████████| 85/85 [00:00<00:00, 90.62it/s, loss=0.0567, acc=0.978]


Epoch 29 | Train Loss: 0.0567, Train Acc: 0.9782 | Val Loss: 1.6480, Val Acc: 0.6959


100%|██████████| 85/85 [00:00<00:00, 87.88it/s, loss=0.0573, acc=0.98]


Epoch 30 | Train Loss: 0.0573, Train Acc: 0.9805 | Val Loss: 1.6543, Val Acc: 0.6920
Saved model to models/tf_gene_id_baseline.pt


In [9]:
# Final evaluation (nt_classify-style): compute loss, accuracy, macro-F1, and classification report
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report
import json

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model_ntstyle(model, loader):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # support both cache-based and id-based models
        try:
            logits = model(batch_x)
        except TypeError:
            tf_ids, gene_ids = prepare_id_batch(batch_x, device=device)
            logits = model(tf_ids, gene_ids)

        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds

# Run final eval and print same outputs as nt_classify
test_loss, test_acc, test_f1, y_true, y_pred = eval_model_ntstyle(model, test_loader)

print("Final Test Accuracy:", test_acc)
print("Final Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))

# Save same metrics JSON for reproducibility
os.makedirs('results', exist_ok=True)
metrics = {
    'test_loss': float(test_loss),
    'accuracy': float(test_acc),
    'macro_f1': float(test_f1),
    'classification_report': classification_report(y_true, y_pred, digits=4),
}
with open('results/baseline_classify_metrics_ntstyle.json', 'w') as f:
    json.dump(metrics, f, indent=2)
print('\nSaved metrics to results/baseline_classify_metrics_ntstyle.json')

Final Test Accuracy: 0.6920430107526881
Final Test Macro F1: 0.6960908410748012

Classification Report:
              precision    recall  f1-score   support

           0     0.6902    0.6514    0.6702       766
           1     0.7966    0.7715    0.7839       731
           2     0.6107    0.6594    0.6341       828

    accuracy                         0.6920      2325
   macro avg     0.6992    0.6941    0.6961      2325
weighted avg     0.6954    0.6920    0.6931      2325


Saved metrics to results/baseline_classify_metrics_ntstyle.json
