<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/baseline_classify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# # ONLY FOR COLAB
# !git clone https://github.com/navidh86/perturbseq-10701.git
# %cd ./perturbseq-10701
# !pip install fastparquet tqdm

In [2]:
# Imports and device
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from tqdm import tqdm

from data.reference_data_classification import get_dataloader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [3]:
# # Create dataloaders (point to data/ paths explicitly)
# train_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=128,
#     type='train',
#     majority_fraction=0.01
# )
# test_loader = get_dataloader(
#     parquet_path='data/tf_gene_expression_labeled.parquet',
#     tf_sequences_path='data/tf_sequences.pkl',
#     gene_sequences_path='data/gene_sequences_4000bp.pkl',
#     batch_size=256,
#     type='test',
#     majority_fraction=0.01
# )

# print('Train size:', len(train_loader.dataset))
# print('Test size :', len(test_loader.dataset))

# Create dataloaders (point to data/ paths explicitly)
train_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=128,
    type='train',
    majority_fraction=0.005
)

validation_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='val',
    majority_fraction=0.005
)

test_loader = get_dataloader(
    parquet_path='data/tf_gene_expression_labeled_v2.parquet',
    tf_sequences_path='data/tf_sequences.pkl',
    gene_sequences_path='data/gene_sequences_4000bp.pkl',
    batch_size=256,
    type='test',
    majority_fraction=0.005
)

print('Train size:', len(train_loader.dataset))
print('Validation size :', len(validation_loader.dataset))
print('Test size :', len(test_loader.dataset))

Train size: 10845
Validation size : 2324
Test size : 2325


In [4]:
train_ds = train_loader.dataset
validation_ds = validation_loader.dataset
test_ds = test_loader.dataset

combined_df = pd.concat([train_ds.df, validation_ds.df, test_ds.df]).reset_index(drop=True)

# unique names from combined set
tf_names = combined_df['tf_name'].unique().tolist()
gene_names = combined_df['gene_name'].unique().tolist()

# create mappings
tf_to_id = {n: i for i, n in enumerate(tf_names)}
gene_to_id = {n: i for i, n in enumerate(gene_names)}

num_tfs = len(tf_to_id)
num_genes = len(gene_to_id)
# Use classes from training split
num_classes = len(train_ds.df['expression_label'].unique())

print('Unique TFs (combined):', num_tfs)
print('Unique Genes (combined):', num_genes)
print('Num classes:', num_classes)

Unique TFs (combined): 223
Unique Genes (combined): 4539
Num classes: 3


In [5]:
def one_hot(index, num_classes):
    v = torch.zeros(num_classes, dtype=torch.float32)
    v[index] = 1.0
    return v

In [6]:
import numpy as np

def prepare_one_hot(tf_name, gene_name):
    tf_id = tf_to_id[tf_name]
    gene_id = gene_to_id[gene_name]

    tf_one_hot = one_hot(tf_id, num_tfs)
    gene_one_hot = one_hot(gene_id, num_genes)

    combined_vector = torch.cat([tf_one_hot, gene_one_hot], dim=0)

    return combined_vector

def prepare_combined_dataset(loader):
    X_list = []
    y_list = []
    
    for batch_x, batch_y in loader:
        for i in range(len(batch_x)):
            item = batch_x[i]
            combined_vector = prepare_one_hot(item['tf_name'], item['gene_name'])
            X_list.append(combined_vector.numpy())
        
        y_list.extend(batch_y.numpy())
    
    X = np.array(X_list)
    y = np.array(y_list)
    return X, y

In [7]:
# Prepare training data
print("Preparing training data with combined one-hot and embeddings...")
X_train, y_train = prepare_combined_dataset(train_loader)
print(f"Train shape after combined: X={X_train.shape}, y={y_train.shape}")

# Prepare validation dataq
X_val, y_val = prepare_combined_dataset(validation_loader)
print(f"Validation shape after combined: X={X_val.shape}, y={y_val.shape}")

# Prepare test data
X_test, y_test = prepare_combined_dataset(test_loader)
print(f"Test shape after combined: X={X_test.shape}, y={y_test.shape}")

Preparing training data with combined one-hot and embeddings...
Train shape after combined: X=(10845, 4762), y=(10845,)
Validation shape after combined: X=(2324, 4762), y=(2324,)
Test shape after combined: X=(2325, 4762), y=(2325,)


In [8]:
print(X_train[0])

[0. 0. 0. ... 0. 0. 0.]


In [9]:
# Random Forest Classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, classification_report, accuracy_score

import json
import pickle
import os

In [10]:
# Train Random Forest
print("\nTraining Random Forest...")
rf_model = RandomForestClassifier(
    n_estimators=200,        # number of trees
    max_depth=20,            # maximum depth of trees
    min_samples_split=20,     # minimum samples to split a node
    min_samples_leaf=5,      # minimum samples in leaf
    max_features='sqrt',
    random_state=10701,
    n_jobs=-1,               # use all CPU cores
    verbose=0,
    class_weight='balanced'  # handle class imbalance
)

rf_model.fit(X_train, y_train)
print("Training complete!")


Training Random Forest...
Training complete!


In [11]:
# Evaluate
for name, X, y in [('Train', X_train, y_train), 
                    ('Val', X_val, y_val), 
                    ('Test', X_test, y_test)]:
    y_pred = rf_model.predict(X)
    acc = accuracy_score(y, y_pred)
    f1 = f1_score(y, y_pred, average='macro')
    print(f"\n=== {name} Set ===")
    print(f"Accuracy: {acc:.4f}")
    print(f"Macro F1: {f1:.4f}")
    print("Classification Report:")
    print(classification_report(y, y_pred))


=== Train Set ===
Accuracy: 0.6155
Macro F1: 0.5918
Classification Report:
              precision    recall  f1-score   support

           0       0.61      0.65      0.63      3573
           1       0.60      0.94      0.73      3409
           2       0.68      0.30      0.42      3863

    accuracy                           0.62     10845
   macro avg       0.63      0.63      0.59     10845
weighted avg       0.63      0.62      0.59     10845


=== Val Set ===
Accuracy: 0.5895
Macro F1: 0.5586
Classification Report:
              precision    recall  f1-score   support

           0       0.58      0.63      0.61       766
           1       0.58      0.94      0.72       730
           2       0.63      0.24      0.35       828

    accuracy                           0.59      2324
   macro avg       0.60      0.60      0.56      2324
weighted avg       0.60      0.59      0.55      2324


=== Test Set ===
Accuracy: 0.5931
Macro F1: 0.5649
Classification Report:
             