In [1]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import joblib
import config
from preprocessing_utils import (
    load_clinical_data,
    load_mrna_data,
    load_mutation_data,
    generate_recurrence_labels,
    drop_patients_missing_data,
    MrnaPreprocessorWrapper,
    ClinicalPreprocessorWrapper,
    MutationPreprocessorWrapper,
)


In [None]:
# === Load data ===
# Every input path comes from `config`; one loader call per modality, plus the
# recurrence labels, which are built from the treatment, status and clinical files.
clinical_df = load_clinical_data(config.CLINICAL_DATA_PATH)
mrna_df = load_mrna_data(config.MRNA_DATA_PATH)
mutation_df = load_mutation_data(config.MUTATION_DATA_PATH)
labels = generate_recurrence_labels(
    treatment_file=config.TREATMENT_DATA_PATH,
    status_file=config.STATUS_DATA_PATH,
    clinical_file=config.CLINICAL_DATA_PATH,
)

# Report the shape of everything that was just loaded.
for caption, loaded in (
    ("Clinical data shape:", clinical_df),
    ("mRNA data shape:", mrna_df),
    ("Mutation data shape:", mutation_df),
    ("Labels shape:", labels),
):
    print(caption, loaded.shape)



Clinical data shape: (529, 37)
mRNA data shape: (527, 20531)
Mutation data shape: (515, 19112)
Labels shape: (529,)


In [5]:
# === Drop patients with missing data ===
# All four objects are rebound to whatever drop_patients_missing_data returns
# (presumably the subset of patients present in every input; confirm in
# preprocessing_utils).
clinical_df, mrna_df, mutation_df, labels = drop_patients_missing_data(
    clinical_df, mrna_df, mutation_df, labels
)

print("After dropping missing:")
for caption, filtered in (
    ("Clinical data shape:", clinical_df),
    ("mRNA data shape:", mrna_df),
    ("Mutation data shape:", mutation_df),
    ("Labels shape:", labels),
):
    print(caption, filtered.shape)

# === Ensure consistent ordering of patients ===
# Inner-join the three modalities into one frame, then reorder the labels to
# follow that frame's row index so features and labels line up row by row.
full_df = clinical_df.join(mrna_df, how="inner")
full_df = full_df.join(mutation_df, how="inner")
labels = labels.loc[full_df.index]

# === 70/15/15 Split ===
# Two stratified splits with the same seed: first hold out 30% of the patients,
# then cut that hold-out in half to obtain the validation and test sets.
split_seed = config.SEED

X_train, X_temp, y_train, y_temp = train_test_split(
    full_df,
    labels,
    test_size=0.30,
    stratify=labels,
    random_state=split_seed,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.50,
    stratify=y_temp,
    random_state=split_seed,
)

print(f"Train size: {len(X_train)}, Val size: {len(X_val)}, Test size: {len(X_test)}")

# === Separate modalities again ===
# The joined frame was only needed so that every modality is split on the same
# patients; slice each split back apart using the per-modality column labels.
clinical_cols = clinical_df.columns
mrna_cols = mrna_df.columns
mutation_cols = mutation_df.columns

clinical_train, clinical_val, clinical_test = (
    X_train[clinical_cols],
    X_val[clinical_cols],
    X_test[clinical_cols],
)
mrna_train, mrna_val, mrna_test = (
    X_train[mrna_cols],
    X_val[mrna_cols],
    X_test[mrna_cols],
)
mutation_train, mutation_val, mutation_test = (
    X_train[mutation_cols],
    X_val[mutation_cols],
    X_test[mutation_cols],
)

# === Initialize and fit each preprocessor on training data ===
# One preprocessor per modality, configured entirely from `config`.
clinical_prep = ClinicalPreprocessorWrapper(
    cols_to_remove=config.CLINICAL_COLS_TO_REMOVE,
    categorical_cols=config.CATEGORICAL_COLS,
    max_null_frac=config.CLINICAL_MAX_NULL_FRAC,
    uniform_thresh=config.CLINICAL_UNIFORM_THRESH,
)
mrna_prep = MrnaPreprocessorWrapper(
    max_null_frac=config.MAX_NULL_FRAC,
    uniform_thresh=config.UNIFORM_THRESHOLD,
    corr_thresh=config.CORRELATION_THRESHOLD,
    var_thresh=config.VARIANCE_THRESHOLD,
    re_run_pruning=config.RE_RUN_PRUNING,
    literature_genes=config.LITERATURE_GENES,
    correlated_genes_path=config.CORRELATED_GENES_PATH,
    use_stability_selection=config.USE_STABILITY_SELECTION,
    n_boots=config.N_BOOTS_FPR,
    fpr_alpha=config.FPR_ALPHA,
    stability_threshold=config.STABILITY_THRESHOLD_FPR,
    random_state=config.SEED,
)
mutation_prep = MutationPreprocessorWrapper(
    max_null_frac=config.MUTATION_MAX_NULL_FRAC,
    uniform_thresh=config.MUTATION_UNIFORM_THRESH,
)

# Fit on train
# Only the training split is passed to fit(); validation and test data are
# never seen by the preprocessors before transform().
# NOTE(review): the stored output of this cell ends with
# "TypeError: unhashable type: 'DataFrame'" after the split sizes were printed.
# The traceback is not preserved, so the failing call cannot be identified from
# the notebook alone -- re-run on a fresh kernel and check the wrappers'
# fit/transform signatures in preprocessing_utils.
for prep, train_part in (
    (clinical_prep, clinical_train),
    (mrna_prep, mrna_train),
    (mutation_prep, mutation_train),
):
    prep.fit(train_part)

# Transform train, val, test
# Each modality's three splits go through its own fitted preprocessor, in
# train -> val -> test order.
clinical_train, clinical_val, clinical_test = [
    clinical_prep.transform(part)
    for part in (clinical_train, clinical_val, clinical_test)
]

mrna_train, mrna_val, mrna_test = [
    mrna_prep.transform(part)
    for part in (mrna_train, mrna_val, mrna_test)
]

mutation_train, mutation_val, mutation_test = [
    mutation_prep.transform(part)
    for part in (mutation_train, mutation_val, mutation_test)
]

# === Create output directories and save all 12 datasets ===
# The previous version created "../data/<split>" but wrote every file to
# "data/<split>/...", so the directories actually written to were never
# guaranteed to exist. A single root is now used for both steps; the file
# paths themselves are unchanged ("data/<split>/<name>.joblib", relative to the
# kernel's working directory).
# NOTE(review): if the artifacts were meant to live in "../data", change
# OUTPUT_ROOT -- directory creation and dumps will then stay in sync.
OUTPUT_ROOT = "data"

# split name -> {file stem -> object to persist}
datasets_by_split = {
    "train": {
        "clinical": clinical_train,
        "mrna": mrna_train,
        "mutation": mutation_train,
        "labels": y_train,
    },
    "val": {
        "clinical": clinical_val,
        "mrna": mrna_val,
        "mutation": mutation_val,
        "labels": y_val,
    },
    "test": {
        "clinical": clinical_test,
        "mrna": mrna_test,
        "mutation": mutation_test,
        "labels": y_test,
    },
}

for split, datasets in datasets_by_split.items():
    split_dir = f"{OUTPUT_ROOT}/{split}"
    # Create the exact directory the dumps below write into.
    os.makedirs(split_dir, exist_ok=True)
    for stem, dataset in datasets.items():
        joblib.dump(dataset, f"{split_dir}/{stem}.joblib")

After dropping missing:
Clinical data shape: (452, 37)
mRNA data shape: (452, 20531)
Mutation data shape: (452, 19112)
Labels shape: (452,)
Train size: 316, Val size: 68, Test size: 68


TypeError: unhashable type: 'DataFrame'