<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# # ONLY FOR COLAB
# !git clone https://github.com/navidh86/perturbseq-10701.git
# %cd ./perturbseq-10701
# !pip install fastparquet tqdm

In [1]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm
import numpy as np

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [2]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.01
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.01
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# Create dataloaders (point to data/ paths explicitly)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005
)

validation_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='val',
    majority_fraction=0.005
)

test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005
)

print('Train size:', len(train_loader.dataset))
print('Validation size :', len(validation_loader.dataset))
print('Test size :', len(test_loader.dataset))

Train size: 10845
Validation size : 2324
Test size : 2325


In [3]:
train_ds = train_loader.dataset
validation_ds = validation_loader.dataset
test_ds = test_loader.dataset

combined_df = pd.concat([train_ds.df, validation_ds.df, test_ds.df]).reset_index(drop=True)

# unique names from combined set
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# create mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)
# Use classes from training split
num_classes = len(train_ds.df['expression_label'].unique())

print('Unique TFs (combined):', num_tfs)
print('Unique Genes (combined):', num_genes)
print('Num classes:', num_classes)

Unique TFs (combined): 223
Unique Genes (combined): 4539
Num classes: 3


In [4]:
# load NT sequence pca embeddings
tf_embed_cache_nt = pickle.load(open("./embeds/tf_cls_pca.pkl", "rb"))
gene_embed_cache_nt = pickle.load(open("./embeds/gn_cls_pca.pkl", "rb"))

# Load enformer pca embeddings
tf_embed_cache = pickle.load(open("./embeds/tf_enformer_alternate_pca.pkl", "rb"))
gene_embed_cache = pickle.load(open("./embeds/gn_enformer_alternate_pca.pkl", "rb"))

# # ensure everything is torch tensors
# for k in tf_embed_cache:
#     if not isinstance(tf_embed_cache[k], torch.Tensor):
#         tf_embed_cache[k] = torch.tensor(tf_embed_cache[k], dtype=torch.float32)

# for k in gene_embed_cache:
#     if not isinstance(gene_embed_cache[k], torch.Tensor):
#         gene_embed_cache[k] = torch.tensor(gene_embed_cache[k], dtype=torch.float32)
first_tf_nt = next(iter(tf_embed_cache_nt.values()))
first_gene_nt = next(iter(gene_embed_cache_nt.values()))
print("TF NT emb dim:", first_tf_nt.shape)
print("Gene NT emb dim:", first_gene_nt.shape)

first_tf = next(iter(tf_embed_cache.values()))
first_gene = next(iter(gene_embed_cache.values()))
print("TF enformer emb dim:", first_tf.shape)
print("Gene enformer emb dim:", first_gene.shape)



TF NT emb dim: torch.Size([110])
Gene NT emb dim: torch.Size([528])
TF enformer emb dim: torch.Size([18])
Gene enformer emb dim: torch.Size([26])


In [5]:
def one_hot(index, num_classes):
    v = torch.zeros(num_classes, dtype=torch.float32)
    v[index] = 1.0
    return v

In [58]:
def prepare_combined(tf_name, gene_name):
    tf_id = tf_to_id[tf_name]
    gene_id = gene_to_id[gene_name]
    
    # tf_one_hot = one_hot(tf_id, num_tfs)
    # gene_one_hot = one_hot(gene_id, num_genes)

    tf_embed_nt = tf_embed_cache_nt[tf_name]
    gene_embed_nt = gene_embed_cache_nt[gene_name]
    
    tf_embed = tf_embed_cache[tf_name]
    gene_embed = gene_embed_cache[gene_name]

    # interaction = torch.nn.functional.pad(tf_embed, (0, len(gene_embed)-len(tf_embed))) * gene_embed
    
    # combined_vector = torch.cat([tf_one_hot, gene_one_hot, tf_embed, gene_embed, interaction], dim=0)
    # combined_vector = torch.cat([tf_one_hot, gene_one_hot, tf_embed, gene_embed], dim=0)
    # combined_vector = torch.cat([tf_embed, gene_embed, interaction], dim=0)
    combined_vector = torch.cat([tf_embed_nt, gene_embed_nt, tf_embed, gene_embed], dim=0)
    # combined_vector = torch.cat([torch.tensor([tf_id], dtype=torch.float32), torch.tensor([gene_id], dtype=torch.float32), tf_embed_nt, gene_embed_nt, tf_embed, gene_embed], dim=0)

    return combined_vector

def prepare_combined_dataset(loader):
    X_list = []
    y_list = []
    
    for batch_x, batch_y in loader:
        for i in range(len(batch_x)):
            item = batch_x[i]
            combined_vector = prepare_combined(item['tf_name'], item['gene_name'])
            X_list.append(combined_vector.numpy())
        
        y_list.extend(batch_y.numpy())
    
    X = np.array(X_list)
    y = np.array(y_list)
    return X, y

In [59]:
# Prepare training data
print("Preparing training data with combined one-hot and embeddings...")
X_train, y_train = prepare_combined_dataset(train_loader)
print(f"Train shape after combined: X={X_train.shape}, y={y_train.shape}")

# Prepare validation dataq
X_val, y_val = prepare_combined_dataset(validation_loader)
print(f"Validation shape after combined: X={X_val.shape}, y={y_val.shape}")

# Prepare test data
X_test, y_test = prepare_combined_dataset(test_loader)
print(f"Test shape after combined: X={X_test.shape}, y={y_test.shape}")

Preparing training data with combined one-hot and embeddings...
Train shape after combined: X=(10845, 682), y=(10845,)
Validation shape after combined: X=(2324, 682), y=(2324,)
Test shape after combined: X=(2325, 682), y=(2325,)


In [60]:
print(X_train[0])

[ 6.69465363e-01  3.95498425e-02  2.41943657e-01 -2.85729796e-01
 -5.06737866e-02 -1.48269283e-02 -4.06667106e-02 -7.05079511e-02
 -1.47939818e-02 -4.10747016e-03 -1.16001561e-01  2.87545789e-02
 -2.61678211e-02 -2.15113144e-02  3.66010629e-02 -3.81143913e-02
 -2.38849372e-02  5.63825704e-02 -4.25549522e-02 -1.56994462e-02
 -3.58185992e-02  1.84632558e-02 -1.27990888e-02 -6.21355325e-03
  5.14769647e-03  1.04746275e-01 -2.38462053e-02  1.82160903e-02
  2.49444861e-02  3.09170596e-02 -4.00513746e-02  7.89891370e-03
  2.41348799e-02 -3.57803442e-02 -4.98031154e-02 -4.72851330e-03
 -2.12176163e-02 -1.58248432e-02 -1.54785812e-02  1.76582914e-02
 -1.54541340e-02  5.84011711e-03  2.08526626e-02  8.65996722e-03
 -5.56244329e-02 -2.60078534e-02 -2.06190685e-04 -9.99410916e-03
  2.06296537e-02  2.28822324e-02  3.16553339e-02  4.01316443e-03
  2.64919326e-02  1.72106270e-02  3.58469808e-03 -1.85258891e-02
  6.22777827e-03  8.77312664e-03  2.85055228e-02  1.38870850e-02
  1.46972132e-03 -2.40831

In [61]:
#!pip install xgboost

In [62]:
# XGBoost and metrics
from xgboost import XGBClassifier
from sklearn.metrics import f1_score, classification_report, accuracy_score
import json

In [75]:
# Cell 10 - Train XGBoost with one-hot encoded features
print("\nTraining XGBoost with combined encoding...")
# the following gets train f1 of 0.96, test f1 of 0.75  
# xgb_model = XGBClassifier(
#     n_estimators=200,            # Further reduce (was 100)
#     max_depth=10,                # Even shallower (was 6)
#     learning_rate=0.1,         # Slower (was 0.05)
#     subsample=0.8,              # More aggressive (was 0.6)
#     colsample_bytree=0.5,       # More aggressive (was 0.6)
#     colsample_bylevel=0.5,      # Keep aggressive
#     colsample_bynode=0.5,       # Keep aggressive
#     min_child_weight=10,        # Increase (was 10)
#     gamma=1.0,                  # Increase (was 1.0)
#     reg_alpha=3.0,              # Increase L1 (was 2.0)
#     reg_lambda=3.0,             # Increase L2 (was 5.0)
#     random_state=10701,
#     n_jobs=-1,
#     tree_method='hist',
#     eval_metric='mlogloss',
#     early_stopping_rounds=30,
#     verbosity=0
# )

xgb_model = XGBClassifier(
    n_estimators=200,            # More trees (was 200)
    max_depth=7,                # Deeper (was 10)
    learning_rate=0.05,          # Slower learning (was 0.1)
    subsample=0.85,              # Less aggressive (was 0.8)
    colsample_bytree=0.7,        # Less aggressive (was 0.5)
    colsample_bylevel=0.7,       # Less aggressive (was 0.5)
    colsample_bynode=0.7,        # Less aggressive (was 0.5)
    min_child_weight=5,          # Lower (was 10)
    gamma=1,                   # Lower (was 1.5)
    reg_alpha=3.0,               # Lower L1 (was 4.0)
    reg_lambda=3.0,              # Lower L2 (was 4.0)
    random_state=10701,
    n_jobs=-1,
    tree_method='hist',
    eval_metric='mlogloss',
    early_stopping_rounds=30,
    verbosity=1
)

xgb_model.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    verbose=10
)


Training XGBoost with combined encoding...
[0]	validation_0-mlogloss:1.08028
[10]	validation_0-mlogloss:0.95616
[20]	validation_0-mlogloss:0.88243
[30]	validation_0-mlogloss:0.83189
[40]	validation_0-mlogloss:0.79384
[50]	validation_0-mlogloss:0.76076
[60]	validation_0-mlogloss:0.73676
[70]	validation_0-mlogloss:0.71654
[80]	validation_0-mlogloss:0.69903
[90]	validation_0-mlogloss:0.68490
[100]	validation_0-mlogloss:0.67335
[110]	validation_0-mlogloss:0.66393
[120]	validation_0-mlogloss:0.65524
[130]	validation_0-mlogloss:0.64742
[140]	validation_0-mlogloss:0.64096
[150]	validation_0-mlogloss:0.63553
[160]	validation_0-mlogloss:0.63035
[170]	validation_0-mlogloss:0.62628
[180]	validation_0-mlogloss:0.62280
[190]	validation_0-mlogloss:0.62052
[199]	validation_0-mlogloss:0.61775


0,1,2
,objective,'multi:softprob'
,base_score,
,booster,
,callbacks,
,colsample_bylevel,0.7
,colsample_bynode,0.7
,colsample_bytree,0.7
,device,
,early_stopping_rounds,30
,enable_categorical,False


In [76]:
# Evaluate on all sets
for name, X, y in [('Train', X_train, y_train), 
                    ('Val', X_val, y_val), 
                    ('Test', X_test, y_test)]:
    y_pred = xgb_model.predict(X)
    acc = accuracy_score(y, y_pred)
    f1 = f1_score(y, y_pred, average='macro')
    print(f"\n=== {name} Set ===")
    print(f"Accuracy: {acc:.4f}")
    print(f"Macro F1: {f1:.4f}")
    print("Classification Report:")
    print(classification_report(y, y_pred, digits=4))


=== Train Set ===
Accuracy: 0.9472
Macro F1: 0.9477
Classification Report:
              precision    recall  f1-score   support

           0     0.9485    0.9423    0.9454      3573
           1     0.9622    0.9630    0.9626      3409
           2     0.9328    0.9376    0.9352      3863

    accuracy                         0.9472     10845
   macro avg     0.9478    0.9477    0.9477     10845
weighted avg     0.9472    0.9472    0.9472     10845


=== Val Set ===
Accuracy: 0.7410
Macro F1: 0.7435
Classification Report:
              precision    recall  f1-score   support

           0     0.6969    0.7232    0.7098       766
           1     0.8461    0.8808    0.8631       730
           2     0.6827    0.6341    0.6575       828

    accuracy                         0.7410      2324
   macro avg     0.7419    0.7460    0.7435      2324
weighted avg     0.7387    0.7410    0.7393      2324


=== Test Set ===
Accuracy: 0.7514
Macro F1: 0.7533
Classification Report:
             

# Grid search

In [11]:
# Manual Grid Search with tqdm
from sklearn.model_selection import cross_val_score
from itertools import product
import numpy as np

param_grid = {
    'n_estimators': [150, 250],
    'max_depth': [8, 12],
    'learning_rate': [0.05, 0.15],
    'subsample': [0.7, 0.85],
    'colsample_bytree': [0.7, 0.85],
    'reg_lambda': [1.0, 2.0]
    # Removed min_child_weight and reg_alpha to reduce combinations
}

# Generate all combinations
keys = list(param_grid.keys())
values = list(param_grid.values())
param_combinations = [dict(zip(keys, v)) for v in product(*values)]

print(f"Total combinations to test: {len(param_combinations)}")

# Manual grid search with progress bar
best_score = -np.inf
best_params = None
results = []

for params in tqdm(param_combinations, desc="Grid Search"):
    model = XGBClassifier(
        **params,
        random_state=10701,
        n_jobs=-1,
        tree_method='hist',
        eval_metric='mlogloss',
        verbosity=0
    )
    
    # Cross-validation
    scores = cross_val_score(
        model, X_train, y_train,
        cv=3,
        scoring='f1_macro',
        n_jobs=1  # Important: don't parallelize CV
    )
    
    mean_score = scores.mean()
    std_score = scores.std()
    
    results.append({
        'params': params,
        'mean_f1': mean_score,
        'std_f1': std_score
    })
    
    if mean_score > best_score:
        best_score = mean_score
        best_params = params
        print(f"  New best! F1={mean_score:.4f} (±{std_score:.4f})")

print("\n✓ Grid Search Complete!")
print(f"Best CV F1: {best_score:.4f}")
print(f"Best parameters: {best_params}")

# Train final model with best params
print("\nTraining final model with best parameters...")
xgb_model = XGBClassifier(
    **best_params,
    random_state=10701,
    n_jobs=-1,
    tree_method='hist',
    eval_metric='mlogloss',
    verbosity=0
)

xgb_model.fit(X_train, y_train)

# Evaluate
for name, X, y in [('Train', X_train, y_train), 
                    ('Val', X_val, y_val), 
                    ('Test', X_test, y_test)]:
    y_pred = xgb_model.predict(X)
    acc = accuracy_score(y, y_pred)
    f1 = f1_score(y, y_pred, average='macro')
    print(f"\n=== {name} Set ===")
    print(f"Accuracy: {acc:.4f}, Macro F1: {f1:.4f}")

# Save results
results_df = pd.DataFrame(results)
results_df = results_df.sort_values('mean_f1', ascending=False)
results_df.to_csv('results/grid_search_results.csv', index=False)
print("\nTop 10 parameter combinations saved to results/grid_search_results.csv")
print(results_df.head(10))

Total combinations to test: 64


Grid Search:   2%|▏         | 1/64 [00:16<17:23, 16.56s/it]

  New best! F1=0.5999 (±0.0139)


Grid Search:   5%|▍         | 3/64 [00:57<19:47, 19.47s/it]

  New best! F1=0.6020 (±0.0134)


Grid Search:  14%|█▍        | 9/64 [02:36<14:44, 16.09s/it]

  New best! F1=0.6551 (±0.0063)


Grid Search:  17%|█▋        | 11/64 [03:04<13:19, 15.09s/it]

  New best! F1=0.6576 (±0.0076)


Grid Search:  39%|███▉      | 25/64 [07:07<11:46, 18.12s/it]

  New best! F1=0.6730 (±0.0080)


Grid Search:  42%|████▏     | 27/64 [07:41<10:47, 17.50s/it]

  New best! F1=0.6742 (±0.0055)


Grid Search:  64%|██████▍   | 41/64 [12:52<09:23, 24.49s/it]

  New best! F1=0.6763 (±0.0053)


Grid Search:  73%|███████▎  | 47/64 [15:10<06:29, 22.89s/it]

  New best! F1=0.6765 (±0.0052)


Grid Search:  89%|████████▉ | 57/64 [19:46<03:08, 26.86s/it]

  New best! F1=0.6857 (±0.0053)


Grid Search:  92%|█████████▏| 59/64 [20:36<02:08, 25.77s/it]

  New best! F1=0.6866 (±0.0062)


Grid Search:  95%|█████████▌| 61/64 [21:25<01:15, 25.13s/it]

  New best! F1=0.6874 (±0.0066)


Grid Search:  98%|█████████▊| 63/64 [22:14<00:24, 24.87s/it]

  New best! F1=0.6882 (±0.0061)


Grid Search: 100%|██████████| 64/64 [22:39<00:00, 21.24s/it]



✓ Grid Search Complete!
Best CV F1: 0.6882
Best parameters: {'n_estimators': 250, 'max_depth': 12, 'learning_rate': 0.15, 'subsample': 0.85, 'colsample_bytree': 0.85, 'reg_lambda': 1.0}

Training final model with best parameters...

=== Train Set ===
Accuracy: 0.7517, Macro F1: 0.7516

=== Val Set ===
Accuracy: 0.6936, Macro F1: 0.6904

=== Test Set ===
Accuracy: 0.7191, Macro F1: 0.7161

Top 10 parameter combinations saved to results/grid_search_results.csv
                                               params   mean_f1    std_f1
62  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.688155  0.006108
60  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.687379  0.006591
58  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.686594  0.006224
56  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.685724  0.005347
63  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.681179  0.007222
61  {'n_estimators': 250, 'max_depth': 12, 'learni...  0.680978  0.007099
57  {'n_estimators