<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ONLY FOR COLAB
!git clone https://github.com/navidh86/perturbseq-10701.git
%cd ./perturbseq-10701
!pip install fastparquet tqdm

Cloning into 'perturbseq-10701'...
remote: Enumerating objects: 201, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 201 (delta 39), reused 45 (delta 16), pack-reused 121 (from 2)[K
Receiving objects: 100% (201/201), 260.57 MiB | 12.74 MiB/s, done.
Resolving deltas: 100% (85/85), done.
Updating files: 100% (53/53), done.
/content/perturbseq-10701
Collecting fastparquet
  Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Downloading fastparquet-2024.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.8/1.8 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fastparquet
Successfully installed fastparquet-2024.11.0


In [2]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [12]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled_v2.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.005
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled_v2.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.005
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# ========================
# Load Dataloaders
# ========================
# Create dataloaders (train / val / test)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005,
)

val_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='val',
    majority_fraction=0.005,
)

test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005,
)

print('Train size:', len(train_loader.dataset))
print('Val size  :', len(val_loader.dataset))
print('Test size :', len(test_loader.dataset))


Train size: 10845
Val size  : 2324
Test size : 2325


In [18]:
# Build TF/Gene ID maps using ALL (train + val + test)
train_ds = train_loader.dataset
val_ds = val_loader.dataset
test_ds = test_loader.dataset

combined_df = pd.concat([train_ds.df, val_ds.df, test_ds.df]).reset_index(drop=True)

# Unique names
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# Mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)

print('Unique TFs (all splits):', num_tfs)
print('Unique Genes (all splits):', num_genes)


Unique TFs (all splits): 223
Unique Genes (all splits): 4539


In [19]:
# Load cached embeddings and prepare name lists
import pickle

# Load cached TF/gene embeddings produced by the embedding notebook
tf_embed_cache = pickle.load(open('embeds/tf_cls.pkl', 'rb'))
gene_embed_cache = pickle.load(open('embeds/gn_cls.pkl', 'rb'))

# Convert any numpy arrays to torch tensors
for k in list(tf_embed_cache.keys()):
    v = tf_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        tf_embed_cache[k] = torch.tensor(v, dtype=torch.float32)
for k in list(gene_embed_cache.keys()):
    v = gene_embed_cache[k]
    if not isinstance(v, torch.Tensor):
        gene_embed_cache[k] = torch.tensor(v, dtype=torch.float32)

# Expose name lists and counts (from the caches to ensure consistent mapping)
tf_names = list(tf_embed_cache.keys())
gene_names = list(gene_embed_cache.keys())
num_tfs = len(tf_names)
num_genes = len(gene_names)
# number of classes is taken from the training split
num_classes = len(train_loader.dataset.df['expression_label'].unique())

print('TFs in cache:', num_tfs)
print('Genes in cache:', num_genes)
print('Num classes:', num_classes)


TFs in cache: 223
Genes in cache: 5307
Num classes: 3


In [20]:
# ID-only model definition
class TFGeneIDModel(nn.Module):
    def __init__(self, num_tfs, num_genes, emb_dim=64, hidden_dim=256, num_classes=3):
        super().__init__()
        self.tf_emb = nn.Embedding(num_tfs, emb_dim)
        self.gene_emb = nn.Embedding(num_genes, emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(2*emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Linear(hidden_dim//2, num_classes)
        )

    def forward(self, tf_ids, gene_ids):
        t = self.tf_emb(tf_ids)
        g = self.gene_emb(gene_ids)
        h = torch.cat([t, g], dim=-1)
        return self.mlp(h)

# Instantiate model
model = TFGeneIDModel(num_tfs=num_tfs, num_genes=num_genes, emb_dim=64, hidden_dim=256, num_classes=num_classes).to(device)
print(model)

TFGeneIDModel(
  (tf_emb): Embedding(223, 64)
  (gene_emb): Embedding(5307, 64)
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)


In [21]:
# Training / evaluation helpers
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def prepare_id_batch(batch_x, device=device):
    tf_ids = torch.tensor([tf_to_id[item['tf_name']] for item in batch_x], dtype=torch.long, device=device)
    gene_ids = torch.tensor([gene_to_id[item['gene_name']] for item in batch_x], dtype=torch.long, device=device)
    return tf_ids, gene_ids

def train_one_epoch(model, loader, optimizer):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0
    pbar = tqdm(loader)
    for batch_x, batch_y in pbar:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)

        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
        pbar.set_postfix({'loss': total_loss/total, 'acc': total_correct/total})

    return total_loss / total, total_correct / total

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)
        tf_ids, gene_ids = prepare_id_batch(batch_x)
        logits = model(tf_ids, gene_ids)
        loss = loss_fn(logits, batch_y)
        preds = logits.argmax(dim=1)
        total_correct += (preds == batch_y).sum().item()
        total += len(batch_y)
        total_loss += loss.item() * len(batch_y)
    return total_loss/total, total_correct/total

In [22]:
# # Train for a few epochs
# num_epochs = 30
# for epoch in range(1, num_epochs+1):
#     train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
#     val_loss, val_acc = evaluate(model, test_loader)
#     print(f'Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# # Save model
# os.makedirs('models', exist_ok=True)
# torch.save({
#     'model_state_dict': model.state_dict(),
#     'tf_to_id': tf_to_id,
#     'gene_to_id': gene_to_id
# }, 'models/tf_gene_id_baseline.pt')
# print('Saved model to models/tf_gene_id_baseline.pt')

# ================================
# Training Loop WITH VALIDATION
# ================================

num_epochs = 30
best_val_f1 = -1
best_state = None

for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer)
    val_loss, val_acc = evaluate(model, val_loader)

    # compute F1 on val set
    _, _, val_f1, _, _ = eval_model_ntstyle(model, val_loader)

    print(f'Epoch {epoch:02d} | '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}')

    # SAVE BEST MODEL (based on val F1)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_state = model.state_dict()
        torch.save(best_state, 'models/best_model.pt')
        print(f'üî• Saved new best model at epoch {epoch} (Val F1={val_f1:.4f})')

print("\nTraining complete.")
print("Best Validation F1:", best_val_f1)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 89.35it/s, loss=0.954, acc=0.521]


Epoch 01 | Train Loss: 0.9539, Train Acc: 0.5206 | Val Loss: 0.8606, Val Acc: 0.5757, Val F1: 0.5723
üî• Saved new best model at epoch 1 (Val F1=0.5723)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 98.37it/s, loss=0.785, acc=0.633]


Epoch 02 | Train Loss: 0.7845, Train Acc: 0.6328 | Val Loss: 0.7990, Val Acc: 0.6179, Val F1: 0.6101
üî• Saved new best model at epoch 2 (Val F1=0.6101)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 98.99it/s, loss=0.701, acc=0.681]


Epoch 03 | Train Loss: 0.7015, Train Acc: 0.6811 | Val Loss: 0.7658, Val Acc: 0.6360, Val F1: 0.6351
üî• Saved new best model at epoch 3 (Val F1=0.6351)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 94.86it/s, loss=0.629, acc=0.718]


Epoch 04 | Train Loss: 0.6295, Train Acc: 0.7178 | Val Loss: 0.7756, Val Acc: 0.6489, Val F1: 0.6532
üî• Saved new best model at epoch 4 (Val F1=0.6532)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 89.20it/s, loss=0.569, acc=0.753]


Epoch 05 | Train Loss: 0.5693, Train Acc: 0.7533 | Val Loss: 0.7598, Val Acc: 0.6528, Val F1: 0.6575
üî• Saved new best model at epoch 5 (Val F1=0.6575)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 89.38it/s, loss=0.512, acc=0.785]


Epoch 06 | Train Loss: 0.5120, Train Acc: 0.7845 | Val Loss: 0.7539, Val Acc: 0.6635, Val F1: 0.6653
üî• Saved new best model at epoch 6 (Val F1=0.6653)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 94.79it/s, loss=0.453, acc=0.815]


Epoch 07 | Train Loss: 0.4534, Train Acc: 0.8148 | Val Loss: 0.7634, Val Acc: 0.6678, Val F1: 0.6726
üî• Saved new best model at epoch 7 (Val F1=0.6726)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 86.23it/s, loss=0.406, acc=0.834]


Epoch 08 | Train Loss: 0.4055, Train Acc: 0.8339 | Val Loss: 0.7922, Val Acc: 0.6713, Val F1: 0.6734
üî• Saved new best model at epoch 8 (Val F1=0.6734)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 83.83it/s, loss=0.356, acc=0.855]


Epoch 09 | Train Loss: 0.3557, Train Acc: 0.8547 | Val Loss: 0.7991, Val Acc: 0.6790, Val F1: 0.6822
üî• Saved new best model at epoch 9 (Val F1=0.6822)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 88.30it/s, loss=0.322, acc=0.87]


Epoch 10 | Train Loss: 0.3220, Train Acc: 0.8695 | Val Loss: 0.8452, Val Acc: 0.6769, Val F1: 0.6810


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 82.08it/s, loss=0.291, acc=0.884]


Epoch 11 | Train Loss: 0.2911, Train Acc: 0.8835 | Val Loss: 0.8913, Val Acc: 0.6867, Val F1: 0.6901
üî• Saved new best model at epoch 11 (Val F1=0.6901)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 84.82it/s, loss=0.264, acc=0.895]


Epoch 12 | Train Loss: 0.2641, Train Acc: 0.8946 | Val Loss: 0.8966, Val Acc: 0.6880, Val F1: 0.6913
üî• Saved new best model at epoch 12 (Val F1=0.6913)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 89.24it/s, loss=0.228, acc=0.912]


Epoch 13 | Train Loss: 0.2283, Train Acc: 0.9122 | Val Loss: 0.9441, Val Acc: 0.6872, Val F1: 0.6917
üî• Saved new best model at epoch 13 (Val F1=0.6917)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 85.04it/s, loss=0.206, acc=0.918]


Epoch 14 | Train Loss: 0.2059, Train Acc: 0.9178 | Val Loss: 0.9993, Val Acc: 0.6777, Val F1: 0.6813


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 90.56it/s, loss=0.181, acc=0.93]


Epoch 15 | Train Loss: 0.1806, Train Acc: 0.9296 | Val Loss: 1.0778, Val Acc: 0.6786, Val F1: 0.6836


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 93.69it/s, loss=0.158, acc=0.939]


Epoch 16 | Train Loss: 0.1582, Train Acc: 0.9386 | Val Loss: 1.0928, Val Acc: 0.6850, Val F1: 0.6893


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 87.93it/s, loss=0.152, acc=0.943]


Epoch 17 | Train Loss: 0.1522, Train Acc: 0.9428 | Val Loss: 1.1220, Val Acc: 0.6915, Val F1: 0.6954
üî• Saved new best model at epoch 17 (Val F1=0.6954)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 77.26it/s, loss=0.134, acc=0.95]


Epoch 18 | Train Loss: 0.1341, Train Acc: 0.9505 | Val Loss: 1.1655, Val Acc: 0.6919, Val F1: 0.6963
üî• Saved new best model at epoch 18 (Val F1=0.6963)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 92.44it/s, loss=0.119, acc=0.956]


Epoch 19 | Train Loss: 0.1191, Train Acc: 0.9556 | Val Loss: 1.2408, Val Acc: 0.6889, Val F1: 0.6912


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 84.18it/s, loss=0.113, acc=0.957]


Epoch 20 | Train Loss: 0.1126, Train Acc: 0.9575 | Val Loss: 1.2672, Val Acc: 0.6898, Val F1: 0.6932


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 90.55it/s, loss=0.105, acc=0.963]


Epoch 21 | Train Loss: 0.1047, Train Acc: 0.9632 | Val Loss: 1.3036, Val Acc: 0.6984, Val F1: 0.7036
üî• Saved new best model at epoch 21 (Val F1=0.7036)


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 95.34it/s, loss=0.103, acc=0.962]


Epoch 22 | Train Loss: 0.1033, Train Acc: 0.9625 | Val Loss: 1.3074, Val Acc: 0.6971, Val F1: 0.7020


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 86.51it/s, loss=0.0874, acc=0.967]


Epoch 23 | Train Loss: 0.0874, Train Acc: 0.9672 | Val Loss: 1.3875, Val Acc: 0.6928, Val F1: 0.6976


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 86.71it/s, loss=0.0839, acc=0.969]


Epoch 24 | Train Loss: 0.0839, Train Acc: 0.9693 | Val Loss: 1.4625, Val Acc: 0.6880, Val F1: 0.6936


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 88.24it/s, loss=0.0778, acc=0.972]


Epoch 25 | Train Loss: 0.0778, Train Acc: 0.9722 | Val Loss: 1.4757, Val Acc: 0.6910, Val F1: 0.6942


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 80.56it/s, loss=0.0715, acc=0.975]


Epoch 26 | Train Loss: 0.0715, Train Acc: 0.9746 | Val Loss: 1.4561, Val Acc: 0.6941, Val F1: 0.6982


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 77.65it/s, loss=0.0619, acc=0.978]


Epoch 27 | Train Loss: 0.0619, Train Acc: 0.9781 | Val Loss: 1.5787, Val Acc: 0.6915, Val F1: 0.6958


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:01<00:00, 82.57it/s, loss=0.0589, acc=0.978]


Epoch 28 | Train Loss: 0.0589, Train Acc: 0.9783 | Val Loss: 1.5913, Val Acc: 0.6954, Val F1: 0.6990


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 88.25it/s, loss=0.0642, acc=0.976]


Epoch 29 | Train Loss: 0.0642, Train Acc: 0.9762 | Val Loss: 1.6216, Val Acc: 0.6984, Val F1: 0.7015


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 85/85 [00:00<00:00, 93.11it/s, loss=0.0577, acc=0.979]


Epoch 30 | Train Loss: 0.0577, Train Acc: 0.9792 | Val Loss: 1.6320, Val Acc: 0.6932, Val F1: 0.6978

Training complete.
Best Validation F1: 0.7035648040383767


In [29]:
# Final evaluation (nt_classify-style): compute loss, accuracy, macro-F1, and classification report
import torch.nn as nn
from sklearn.metrics import f1_score, classification_report
import json

loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def eval_model_ntstyle(model, loader):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    all_preds = []
    all_labels = []

    for batch_x, batch_y in loader:
        batch_y = batch_y.to(device)

        # support both cache-based and id-based models
        try:
            logits = model(batch_x)
        except TypeError:
            tf_ids, gene_ids = prepare_id_batch(batch_x, device=device)
            logits = model(tf_ids, gene_ids)

        loss = loss_fn(logits, batch_y)

        preds = logits.argmax(dim=1)

        total_loss += loss.item() * len(batch_y)
        total_correct += (preds == batch_y).sum().item()
        total_samples += len(batch_y)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    macro_f1 = f1_score(all_labels, all_preds, average="macro")

    return avg_loss, accuracy, macro_f1, all_labels, all_preds


In [30]:
# =======================================
# Final evaluation on TEST using best model
# =======================================

# load best checkpoint
model.load_state_dict(torch.load('models/best_model.pt'))

test_loss, test_acc, test_f1, y_true, y_pred = eval_model_ntstyle(model, test_loader)

print("\n=== FINAL TEST RESULTS (BEST CHECKPOINT) ===")
print("Test Accuracy:", test_acc)
print("Test Macro F1:", test_f1)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, digits=4))

# Save JSON
os.makedirs('results', exist_ok=True)
metrics = {
    'test_loss': float(test_loss),
    'accuracy': float(test_acc),
    'macro_f1': float(test_f1),
    'classification_report': classification_report(y_true, y_pred, digits=4),
}
with open('results/best_model_test_metrics.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print('\nSaved metrics to results/best_model_test_metrics.json')



=== FINAL TEST RESULTS (BEST CHECKPOINT) ===
Test Accuracy: 0.7096774193548387
Test Macro F1: 0.7141396280920671

Classification Report:
              precision    recall  f1-score   support

           0     0.7004    0.6958    0.6981       766
           1     0.8249    0.7538    0.7877       731
           2     0.6317    0.6836    0.6566       828

    accuracy                         0.7097      2325
   macro avg     0.7190    0.7111    0.7141      2325
weighted avg     0.7151    0.7097    0.7115      2325


Saved metrics to results/best_model_test_metrics.json
