<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm

Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (63/63), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 184 (delta 29), reused 36 (delta 11), pack-reused 121 (from 2)[K
Receiving objects: 100% (184/184), 260.54 MiB | 15.34 MiB/s, done.
Resolving deltas: 100% (75/75), done.
Updating files: 100% (52/52), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m71.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [3]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.01
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.01
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# Create dataloaders (point to data/ paths explicitly)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005
)
test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005
)

print('Train size:', len(train_loader.dataset))
print('Test size :', len(test_loader.dataset))

Train size: 12395
Test size : 3099


In [4]:
train_ds = train_loader.dataset
test_ds = test_loader.dataset
combined_df = pd.concat([train_ds.df, test_ds.df]).reset_index(drop=True)

# unique names from combined set
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# create mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)
# Use classes from training split
num_classes = len(train_ds.df['expression_label'].unique())

print('Unique TFs (combined):', num_tfs)
print('Unique Genes (combined):', num_genes)
print('Num classes:', num_classes)

Unique TFs (combined): 223
Unique Genes (combined): 4539
Num classes: 3


In [5]:
# Load cached embeddings and prepare name lists
import pickle

# Load cached TF/gene embeddings produced by the embedding notebook
tf_embed_cache = pickle.load(open('embeds/tf_cls.pkl', 'rb'))
gene_embed_cache = pickle.load(open('embeds/gn_cls.pkl', 'rb'))

# Convert any numpy arrays to torch tensors
for k in list(tf_embed_cache.keys()):
    v = tf_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        tf_embed_cache[k] = torch.tensor(v, dtype=torch.float32)
for k in list(gene_embed_cache.keys()):
    v = gene_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        gene_embed_cache[k] = torch.tensor(v, dtype=torch.float32)

# Expose name lists and counts (from the caches to ensure consistent mapping)
tf_names = list(tf_embed_cache.keys())
gene_names = list(gene_embed_cache.keys())
num_tfs = len(tf_names)
num_genes = len(gene_names)
# number of classes is taken from the training split
num_classes = len(train_loader.dataset.df['expression_label'].unique())

print('TFs in cache:', num_tfs)
print('Genes in cache:', num_genes)
print('Num classes:', num_classes)


TFs in cache: 223
Genes in cache: 5307
Num classes: 3


In [6]:
# ID-only model definition
class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=64, hidden_dim=256, num_classes=3):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, num_classes)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h)

# Instantiate model
model = TFGeneIDModel(num_tfs=num_tfs, num_genes=num_genes, emb_dim=64, hidden_dim=256, num_classes=num_classes).to(device)
print(model)

TFGeneIDModel(
  (tf_emb): Embedding(223, 64)
  (gene_emb): Embedding(5307, 64)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)


In [7]:
# Training / evaluation helpers
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def prepare_id_batch(batch_x, device=device):
    tf_ids = torch.tensor([tf_to_id[item['tf_name']] for item in batch_x], dtype=torch.long, device=device)
    gene_ids = torch.tensor([gene_to_id[item['gene_name']] for item in batch_x], dtype=torch.long, device=device)
    return tf_ids, gene_ids

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0
    pbar = tqdm(loader)
    for batch_x, batch_y in pbar:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)

        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
        pbar.set_postfix({'loss': total_loss/total, 'acc': total_correct/total})

    return total_loss / total, total_correct / total

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)
        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)
        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
    return total_loss/total, total_correct/total

In [8]:
# Train for a few epochs
num_epochs = 5
for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, test_loader)
    print(f'Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# Save model
os.makedirs('models', exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'tf_to_id': tf_to_id,
    'gene_to_id': gene_to_id
}, 'models/tf_gene_id_baseline.pt')
print('Saved model to models/tf_gene_id_baseline.pt')

100%|██████████| 97/97 [00:01<00:00, 51.32it/s, loss=0.931, acc=0.529]


Epoch 01 | Train Loss: 0.9310, Train Acc: 0.5293 | Val Loss: 0.8341, Val Acc: 0.5941


100%|██████████| 97/97 [00:01<00:00, 92.66it/s, loss=0.762, acc=0.637]


Epoch 02 | Train Loss: 0.7623, Train Acc: 0.6368 | Val Loss: 0.7690, Val Acc: 0.6286


100%|██████████| 97/97 [00:01<00:00, 85.99it/s, loss=0.687, acc=0.681]


Epoch 03 | Train Loss: 0.6873, Train Acc: 0.6814 | Val Loss: 0.7358, Val Acc: 0.6563


100%|██████████| 97/97 [00:01<00:00, 83.15it/s, loss=0.615, acc=0.725]


Epoch 04 | Train Loss: 0.6149, Train Acc: 0.7250 | Val Loss: 0.7264, Val Acc: 0.6751


100%|██████████| 97/97 [00:01<00:00, 78.06it/s, loss=0.546, acc=0.759]


Epoch 05 | Train Loss: 0.5459, Train Acc: 0.7589 | Val Loss: 0.7191, Val Acc: 0.6828
Saved model to models/tf_gene_id_baseline.pt


In [10]:
# Final evaluation (nt_classify-style): compute loss, accuracy, macro-F1, and classification report
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report
import json

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model_ntstyle(model, loader):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # support both cache-based and id-based models
        try:
            logits = model(batch_x)
        except TypeError:
            tf_ids, gene_ids = prepare_id_batch(batch_x, device=device)
            logits = model(tf_ids, gene_ids)

        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds

# Run final eval and print same outputs as nt_classify
test_loss, test_acc, test_f1, y_true, y_pred = eval_model_ntstyle(model, test_loader)

print("Final Test Accuracy:", test_acc)
print("Final Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))

# Save same metrics JSON for reproducibility
os.makedirs('results', exist_ok=True)
metrics = {
    'test_loss': float(test_loss),
    'accuracy': float(test_acc),
    'macro_f1': float(test_f1),
    'classification_report': classification_report(y_true, y_pred, digits=4),
}
with open('results/baseline_classify_metrics_ntstyle.json', 'w') as f:
    json.dump(metrics, f, indent=2)
print('\nSaved metrics to results/baseline_classify_metrics_ntstyle.json')

Final Test Accuracy: 0.6828009035172636
Final Test Macro F1: 0.686565429925103

Classification Report:
              precision    recall  f1-score   support

           0     0.6539    0.6513    0.6526      1021
           1     0.8042    0.8183    0.8112       974
           2     0.5995    0.5924    0.5959      1104

    accuracy                         0.6828      3099
   macro avg     0.6859    0.6873    0.6866      3099
weighted avg     0.6817    0.6828    0.6822      3099


Saved metrics to results/baseline_classify_metrics_ntstyle.json
