In [1]:
import config

import pandas as pd
import numpy as np
import joblib
from collections import Counter

from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder, LabelEncoder, OrdinalEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import SelectFpr, f_classif
from sklearn.utils import resample

In [None]:
# Loading in the mRNA and clinical data:
clinical_df = pd.read_csv("..\ucec_tcga_pan_can_atlas_2018\data_clinical_patient.txt", sep="\t", comment="#", low_memory=False)
clinical_df = clinical_df.set_index('PATIENT_ID')

mrna_df = pd.read_csv("..\ucec_tcga_pan_can_atlas_2018/data_mrna_seq_v2_rsem_zscores_ref_all_samples.txt", sep="\t", comment="#")

# There are 527 patients in the mRNA and 529 patients in the clinical data

# The first 2 columns of the mRNA data are labels (Hugo_Symbol then Entrez_Gene_Id). 
# 13 of the genes do not have Hugo_symbols, so for these I will you the Entrex_Gene_Id as the label.
missing_symbols = mrna_df['Hugo_Symbol'].isnull()
mrna_df.loc[missing_symbols, 'Hugo_Symbol'] = mrna_df.loc[missing_symbols, 'Entrez_Gene_Id'].astype(str)

# There are 7 rows that have both the same Hugo_Symbol and Entrez_Gene_Id but different values for the patients.
# I will rename these rows to have unique labels by appending -1-of-2 and -2-of-2 to the Hugo_Symbol.
# Get value counts
counts = mrna_df['Hugo_Symbol'].value_counts()

# Generate unique labels for duplicates
def label_duplicates(value, index):
    if counts[value] == 1:
        return value  # Keep unique values unchanged
    occurrence = mrna_df.groupby('Hugo_Symbol').cumcount() + 1  # Count occurrences per group
    return f"{value}-{occurrence[index]}-of-{counts[value]}"

# Apply the labeling function
mrna_df['Hugo_Symbol'] = [label_duplicates(value, idx) for idx, value in mrna_df['Hugo_Symbol'].items()]

mrna_df = mrna_df.set_index('Hugo_Symbol')
mrna_df = mrna_df.drop(columns="Entrez_Gene_Id") # removing the label column before I transpose the df
mrna_df= mrna_df.transpose() # now the patients are the index and the genes are the columns
mrna_df.index = [id[:-3] for id in mrna_df.index] # removes extranious -01 so that the patient ids match the clinical data



FileNotFoundError: [Errno 2] No such file or directory: 'ucec_tcga_pan_can_atlas_2018\\data_clinical_patient.txt'

In [None]:
def assign_labels(clinical_df):
    '''given the clinical dataframe, returns the corresposnding labels, 
    assigning 1 for recurrance, 0 for no recurrance, 
    and None if the patient has no recurrence information. 
    Currently uses NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT to identify recurrance.
    If NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT is NaN, uses DSF_STATUS to save the label.'''
    labels = []
    for _, row in clinical_df.iterrows():
        if row['NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT'] == 'Yes':
            labels.append(1)
        elif row['NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT'] == 'No':
            labels.append(0)
        elif pd.isna(row['NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT']):
            if row['DFS_STATUS'] == '1:Recurred/Progressed':
                labels.append(1)
            elif row['DFS_STATUS'] == '0:DiseaseFree':
                labels.append(0)
            else:
                labels.append(None)
    return pd.Series(labels, index=clinical_df.index)

    


def drop_patients_missing_data(clinical_df, mrna_df, labels):
    '''Drops patients from both dataframes that are not present in the other dataframe. 
    Drops patients who are missing labeling data used to define recurrence.
    Returns the cleaned dataframes and labels.'''
    # Find patient IDs not shared between the two dataframes:
    clinical_not_in_mrna = set(clinical_df.index) - set(mrna_df.index)
    mrna_not_in_clinical = set(mrna_df.index) - set(clinical_df.index)
    # There are 2 patients ('TCGA-EY-A1GJ', 'TCGA-AP-A0LQ') in the clinical data that are not in the mRNA data.
    clinical_df = clinical_df.drop(index=clinical_not_in_mrna)
    mrna_df = mrna_df.drop(index=mrna_not_in_clinical)
    labels = labels.drop(index=clinical_not_in_mrna)
    labels = labels.drop(index=mrna_not_in_clinical)
    assert clinical_df.shape[0] == mrna_df.shape[0] == labels.shape[0], "Dataframes have different number of patients after cleaning"

    # Now drop patients missing labeling data used to define recurrence:
    patients_no_label = labels[labels.isna()].index
    clinical_df = clinical_df.drop(index=patients_no_label)
    mrna_df = mrna_df.drop(index=patients_no_label)
    labels = labels.drop(index=patients_no_label)
    assert not labels.isna().any(), "Found unlabeled patient after cleaning"

    return clinical_df, mrna_df, labels

In [None]:
def drop_post_diagnosis_clinical_columns(clinical_df):
    '''Removes all columns in the clinical data that are recurrence indicators or are not available at diagnosis.
    Returns the cleaned clinical dataframe and the labels series.'''
    cols_to_drop = [
    "DAYS_LAST_FOLLOWUP",              # follow-up time after diagnosis (future info)
    "FORM_COMPLETION_DATE",            # administrative metadata, not predictive
    "INFORMED_CONSENT_VERIFIED",       # administrative, no biological meaning
    "NEW_TUMOR_EVENT_AFTER_INITIAL_TREATMENT",  # recurrence event → direct leakage
    "PERSON_NEOPLASM_CANCER_STATUS",   # disease status at follow-up → leakage
    "IN_PANCANPATHWAYS_FREEZE",        # technical/analysis flag, not biological
    "OS_STATUS",                       # overall survival outcome → leakage
    "OS_MONTHS",                       # overall survival time → leakage
    "DSS_STATUS",                      # disease-specific survival outcome → leakage
    "DSS_MONTHS",                      # disease-specific survival time → leakage
    "DFS_STATUS",                      # disease-free survival outcome → leakage
    "DFS_MONTHS",                      # disease-free survival time → leakage
    "PFS_STATUS",                      # progression-free survival outcome → leakage
    "PFS_MONTHS"                       # progression-free survival time → leakage
]
    clinical_df = clinical_df.drop(columns=cols_to_drop)
    return clinical_df  

In [None]:
from sklearn.model_selection import train_test_split

def split_train_test(clinical_df, mrna_df, labels, test_size=0.2, random_state=1):
    """
    Splits clinical and mRNA data into train/test sets using precomputed labels.

    Parameters
    ----------
    clinical_df : pd.DataFrame
        Clinical features (indexed by patient ID).
    mrna_df : pd.DataFrame
        mRNA expression features (indexed by patient ID).
    labels : pd.Series
        Precomputed labels indexed by patient ID.
    test_size : float
        Fraction of patients to hold out for testing.
    random_state : int
        Random seed for reproducibility.

    Returns
    -------
    dict of train/test splits:
        {
            "X_clinical_train", "X_clinical_test",
            "X_mrna_train", "X_mrna_test",
            "y_train", "y_test"
        }
    """

    # Train/test split on patient IDs
    train_ids, test_ids = train_test_split(
        labels.index,
        test_size=test_size,
        stratify=labels,
        random_state=random_state
    )

    # Slice dataframes and labels
    splits = {
        "X_clinical_train": clinical_df.loc[train_ids],
        "X_clinical_test":  clinical_df.loc[test_ids],
        "X_mrna_train":     mrna_df.loc[train_ids],
        "X_mrna_test":      mrna_df.loc[test_ids],
        "y_train":          labels.loc[train_ids],
        "y_test":           labels.loc[test_ids]
    }

    return splits

In [None]:
def encode_clinical_features(X_train, X_test, categorical_cols=None, ordinal_cols=None, ordinal_mappings=None):
    """
    Encode clinical features: one-hot for categorical, ordinal for ordinal columns.

    Parameters
    ----------
    X_train : pd.DataFrame
        Training clinical features.
    X_test : pd.DataFrame
        Test clinical features.
    categorical_cols : list of str
        Columns to one-hot encode.
    ordinal_cols : list of str
        Columns to ordinally encode.
    ordinal_mappings : dict
        Mapping of column name -> list of categories in order for ordinal encoding.
        Example: {'TUMOR_GRADE': ['G1', 'G2', 'G3']}

    Returns
    -------
    X_train_encoded, X_test_encoded : pd.DataFrame, pd.DataFrame
        Encoded training and test clinical dataframes.
    """

    X_train_encoded = X_train.copy()
    X_test_encoded = X_test.copy()

    # --- One-hot encode categorical columns ---
    if categorical_cols:
        ohe = OneHotEncoder(sparse=False, handle_unknown='ignore')
        ohe_train = ohe.fit_transform(X_train_encoded[categorical_cols])
        ohe_test  = ohe.transform(X_test_encoded[categorical_cols])

        ohe_columns = ohe.get_feature_names_out(categorical_cols)
        ohe_train_df = pd.DataFrame(ohe_train, columns=ohe_columns, index=X_train_encoded.index)
        ohe_test_df  = pd.DataFrame(ohe_test, columns=ohe_columns, index=X_test_encoded.index)

        X_train_encoded = pd.concat([X_train_encoded.drop(columns=categorical_cols), ohe_train_df], axis=1)
        X_test_encoded  = pd.concat([X_test_encoded.drop(columns=categorical_cols), ohe_test_df], axis=1)

    # --- Ordinal encode ordinal columns ---
    if ordinal_cols:
        if ordinal_mappings is None:
            raise ValueError("You must provide ordinal_mappings when encoding ordinal columns.")

        for col in ordinal_cols:
            encoder = OrdinalEncoder(categories=[ordinal_mappings[col]])
            X_train_encoded[[col]] = encoder.fit_transform(X_train_encoded[[col]])
            X_test_encoded[[col]]  = encoder.transform(X_test_encoded[[col]])

    return X_train_encoded, X_test_encoded


In [None]:
class ClinicalPreprocessor:
    def __init__(self, cols_to_remove, categorical_cols, max_null_frac=0.3, uniform_thresh=0.99):
        self.cols_to_remove = cols_to_remove
        self.categorical_cols = categorical_cols
        self.max_null_frac = max_null_frac
        self.uniform_thresh = uniform_thresh
        
        # Saved state after fit
        self.removed_cols_ = []
        self.columns_ = None  # final column order
        self.num_fill_values_ = {}
        self.cat_fill_values_ = {}
    
    def _drop_highly_uniform_columns(self, df):
        """Identifies highly uniform columns (> threshold same value)."""
        cols_to_drop = []
        for col in df.columns:
            non_na_values = df[col].dropna()
            if not non_na_values.empty:
                top_freq = non_na_values.value_counts(normalize=True).iloc[0]
                if top_freq > self.uniform_thresh:
                    cols_to_drop.append(col)
        return cols_to_drop
    
    def fit(self, df):
        # --- Step 1. Drop specified columns
        removed = [c for c in self.cols_to_remove if c in df.columns]
        
        # --- Step 2. Drop columns with too many nulls
        thresh = len(df) * (1 - self.max_null_frac)
        high_null_cols = [c for c in df.columns if df[c].isna().sum() > len(df) - thresh]
        removed.extend(high_null_cols)
        
        # --- Step 3. Drop highly uniform columns
        uniform_cols = self._drop_highly_uniform_columns(df)
        removed.extend(uniform_cols)

        # --- Step 4. Drop all identified columns
        df = df.drop(columns=removed, errors="ignore")
        
        # --- Step 5. Fill NaNs
        # Numerical → median
        numeric_cols = df.select_dtypes(include=['number']).columns
        self.num_fill_values_ = df[numeric_cols].median()
        df[numeric_cols] = df[numeric_cols].fillna(self.num_fill_values_)
        
        # Categorical → mode
        cat_cols = [c for c in self.categorical_cols if c in df.columns]
        self.cat_fill_values_ = {c: df[c].mode().iloc[0] for c in cat_cols if not df[c].dropna().empty}
        for c, mode_val in self.cat_fill_values_.items():
            df[c] = df[c].fillna(mode_val)
        
        # --- Step 6. One-hot encode categorical
        df_enc = pd.get_dummies(df, columns=cat_cols, drop_first=True)
        
        # Save results
        self.removed_cols_ = removed
        self.columns_ = df_enc.columns.tolist()
        
        return self
    
    def transform(self, df):
        # Drop removed cols
        df = df.drop(columns=[c for c in self.removed_cols_ if c in df.columns], errors="ignore")
        
        # --- Fill NaNs using training fill values
        numeric_cols = df.select_dtypes(include=['number']).columns
        for c in numeric_cols:
            if c in self.num_fill_values_:
                df[c] = df[c].fillna(self.num_fill_values_[c])
        
        cat_cols = [c for c in self.categorical_cols if c in df.columns]
        for c in cat_cols:
            if c in self.cat_fill_values_:
                df[c] = df[c].fillna(self.cat_fill_values_[c])
        
        # One-hot encode
        df_enc = pd.get_dummies(df, columns=cat_cols, drop_first=True)
        
        # Reindex to training columns (fill missing with 0)
        df_enc = df_enc.reindex(columns=self.columns_, fill_value=0)
        
        return df_enc

In [None]:
# class MrnaPreprocessor:
#     def __init__(self, max_null_frac=0.3, uniform_thresh=0.99, corr_thresh=0.9, var_thresh=1e-5, literature_genes=set()):
#         self.max_null_frac = max_null_frac
#         self.uniform_thresh = uniform_thresh
#         self.corr_thresh = corr_thresh
#         self.var_thresh = var_thresh
#         self.literature_genes = literature_genes

#         # Saved state after fit
#         self.removed_cols_ = []
#         self.medians_ = {}
#         self.columns_ = None

#     def _drop_highly_uniform_columns(self, df):
#         """Identify and drop highly uniform columns (> threshold)."""
#         cols_to_drop = []
#         for col in df.columns:
#             non_na_values = df[col].dropna()
#             if not non_na_values.empty:
#                 top_freq = non_na_values.value_counts(normalize=True).iloc[0]
#                 if top_freq > self.uniform_thresh:
#                     cols_to_drop.append(col)
#         return df.drop(columns=cols_to_drop), cols_to_drop

#     def _prune_correlated_features(self, df):
#         """Prune correlated features above correlation threshold."""
#         corr_matrix = df.corr().abs()
#         np.fill_diagonal(corr_matrix.values, 0)

#         high_corr_map = {
#             gene: set(corr_matrix.index[corr_matrix.loc[gene] >= self.corr_thresh])
#             for gene in corr_matrix.columns
#         }

#         genes_to_keep = set(corr_matrix.columns)
#         genes_to_remove = set()

#         while True:
#             correlated_genes = {g: nbrs for g, nbrs in high_corr_map.items() if nbrs & genes_to_keep}
#             if not correlated_genes:
#                 break

#             degrees = {g: len(nbrs & genes_to_keep) for g, nbrs in correlated_genes.items() if g in genes_to_keep}
#             if not degrees:
#                 break

#             worst_gene = max(degrees, key=lambda g: degrees[g])

#             if worst_gene in self.literature_genes:
#                 neighbors = correlated_genes[worst_gene] & genes_to_keep
#                 non_lit_neighbors = [n for n in neighbors if n not in self.literature_genes]
#                 if non_lit_neighbors:
#                     worst_gene = min(non_lit_neighbors, key=lambda n: df[n].var())
#                 else:
#                     break
#             else:
#                 ties = [g for g, d in degrees.items() if d == degrees[worst_gene]]
#                 if len(ties) > 1:
#                     worst_gene = min(ties, key=lambda g: df[g].var())
            
#             genes_to_remove.add(worst_gene)
#             genes_to_keep.remove(worst_gene)

#         return df[list(genes_to_keep)], genes_to_remove


#     def fit(self, df):
#         removed = []

#         # Step 1. Drop columns with too many nulls
#         high_null_cols = [c for c in df.columns if df[c].isna().sum() > len(df) * self.max_null_frac]
#         removed.extend(high_null_cols)
#         df_temp = df.drop(columns=high_null_cols, errors="ignore")
#         print(f"Dropped {len(high_null_cols)} columns with >{self.max_null_frac*100}% nulls")

#         # Step 2. Drop highly uniform columns
#         df_temp, uniform_cols = self._drop_highly_uniform_columns(df_temp)
#         removed.extend(uniform_cols)
#         print(f"Dropped {len(uniform_cols)} highly uniform columns")

#         # Step 3. Fill NaNs with median
#         self.medians_ = df_temp.median().to_dict()
#         df_temp = df_temp.fillna(self.medians_)

#         # Step 4. Variance filter
#         low_var_cols = [c for c in df_temp.columns if df_temp[c].var() < self.var_thresh]
#         df_temp = df_temp.drop(columns=low_var_cols, errors="ignore")
#         removed.extend(low_var_cols)
#         print(f"Dropped {len(low_var_cols)} low variance columns (<{self.var_thresh})")

#         # CHECKING TO SEE IF PRUNING IS HELPING OR HURTING
#         # # Step 5. Prune correlated features
#         # df_temp, correlated_genes = self._prune_correlated_features(df_temp)
#         # removed.extend(correlated_genes)
#         # print(f"Dropped {len(correlated_genes)} correlated genes (>{self.corr_thresh} correlation)")

#         # Save final state
#         self.removed_cols_ = list(set(removed))
#         self.columns_ = df_temp.columns.tolist()

#         return self

#     def transform(self, df):
#         # Drop known removed cols
#         df = df.drop(columns=[c for c in self.removed_cols_ if c in df.columns])

#         # Fill NaNs with median
#         df = df.fillna(self.medians_)

#         # Check column alignment
#         missing = set(self.columns_) - set(df.columns)
#         extra = set(df.columns) - set(self.columns_)
#         if missing or extra:
#             raise ValueError(
#                 f"Column mismatch! Missing: {missing}, Extra: {extra}, "
#                 f"{len(missing)} missing, {len(extra)} extra"
#             )

#         # Reorder df to match training column order
#         df = df[self.columns_]

#         return df


In [None]:
# def stability_feature_selection(X, y, n_boots=100, alpha=0.05, stability_threshold=0.8, random_state=42):
#     """
#     Stability-based feature selection using bootstrapped univariate tests.

#     Parameters
#     ----------
#     X : pd.DataFrame
#         Feature matrix (e.g. mRNA expression).
#     y : pd.Series
#         Labels.
#     n_boots : int
#         Number of bootstrap samples.
#     alpha : float
#         Significance level for SelectFpr.
#     stability_threshold : float
#         Minimum selection frequency to keep a feature.
#     random_state : int
#         Reproducibility seed.

#     Returns
#     -------
#     selected_features : list
#         List of stable feature names.
#     selection_freq : pd.Series
#         Selection frequency for each feature.
#     """
#     np.random.seed(random_state)
#     feature_counts = pd.Series(0, index=X.columns)

#     for i in range(n_boots):
#         X_boot, y_boot = resample(
#             X, y,
#             stratify=y,
#             n_samples=len(y),
#             replace=True,
#             random_state=random_state+i
#         )
#         selector = SelectFpr(score_func=f_classif, alpha=alpha)
#         selector.fit(X_boot, y_boot)
#         selected = X_boot.columns[selector.get_support()]
#         feature_counts[selected] += 1

#     selection_freq = feature_counts / n_boots
#     selected_features = selection_freq[selection_freq >= stability_threshold].index.tolist()

#     print(f"{len(selected_features)} features selected (out of {X.shape[1]})")
#     return selected_features, selection_freq


In [None]:
class MrnaPreprocessor:
    def __init__(self,
                 max_null_frac=0.3,
                 uniform_thresh=0.99,
                 corr_thresh=0.9,
                 var_thresh=1e-5,
                 literature_genes=set(),
                 use_stability_selection=True,
                 n_boots=100,
                 fpr_alpha=0.05,
                 stability_threshold=0.8,
                 random_state=42):

        self.max_null_frac = max_null_frac
        self.uniform_thresh = uniform_thresh
        self.corr_thresh = corr_thresh
        self.var_thresh = var_thresh
        self.literature_genes = literature_genes

        # Stability selection params
        self.use_stability_selection = use_stability_selection
        self.n_boots = n_boots
        self.fpr_alpha = fpr_alpha
        self.stability_threshold = stability_threshold
        self.random_state = random_state

        # Saved state after fit
        self.removed_cols_ = []
        self.medians_ = {}
        self.columns_ = None
        self.selection_freq_ = None

    def _drop_highly_uniform_columns(self, df):
        """Identify and drop highly uniform columns (> threshold)."""
        cols_to_drop = []
        for col in df.columns:
            non_na_values = df[col].dropna()
            if not non_na_values.empty:
                top_freq = non_na_values.value_counts(normalize=True).iloc[0]
                if top_freq > self.uniform_thresh:
                    cols_to_drop.append(col)
        return df.drop(columns=cols_to_drop), cols_to_drop

    def _prune_correlated_features(self, df):
        """Prune correlated features above correlation threshold."""
        corr_matrix = df.corr().abs()
        np.fill_diagonal(corr_matrix.values, 0)

        high_corr_map = {
            gene: set(corr_matrix.index[corr_matrix.loc[gene] >= self.corr_thresh])
            for gene in corr_matrix.columns
        }

        genes_to_keep = set(corr_matrix.columns)
        genes_to_remove = set()

        while True:
            correlated_genes = {g: nbrs for g, nbrs in high_corr_map.items() if nbrs & genes_to_keep}
            if not correlated_genes:
                break

            degrees = {g: len(nbrs & genes_to_keep) for g, nbrs in correlated_genes.items() if g in genes_to_keep}
            if not degrees:
                break

            worst_gene = max(degrees, key=lambda g: degrees[g])

            if worst_gene in self.literature_genes:
                neighbors = correlated_genes[worst_gene] & genes_to_keep
                non_lit_neighbors = [n for n in neighbors if n not in self.literature_genes]
                if non_lit_neighbors:
                    worst_gene = min(non_lit_neighbors, key=lambda n: df[n].var())
                else:
                    break
            else:
                ties = [g for g, d in degrees.items() if d == degrees[worst_gene]]
                if len(ties) > 1:
                    worst_gene = min(ties, key=lambda g: df[g].var())
            
            genes_to_remove.add(worst_gene)
            genes_to_keep.remove(worst_gene)

        return df[list(genes_to_keep)], genes_to_remove

    def _stability_feature_selection(self, X, y):
        """Bootstrap stability-based feature selection (Jessie’s approach)."""
        np.random.seed(self.random_state)
        feature_counts = pd.Series(0, index=X.columns)

        for i in range(self.n_boots):
            X_boot, y_boot = resample(
                X, y,
                stratify=y,
                n_samples=len(y),
                replace=True,
                random_state=self.random_state+i
            )
            selector = SelectFpr(score_func=f_classif, alpha=self.fpr_alpha)
            selector.fit(X_boot, y_boot)
            selected = X_boot.columns[selector.get_support()]
            feature_counts[selected] += 1

        selection_freq = feature_counts / self.n_boots
        selected_features = selection_freq[selection_freq >= self.stability_threshold].index.tolist()

        print(f"Stability selection: kept {len(selected_features)} / {X.shape[1]} features "
              f"({self.stability_threshold*100:.0f}% stability threshold)")

        self.selection_freq_ = selection_freq
        return X[selected_features], list(set(X.columns) - set(selected_features))

    def fit(self, df, y=None):
        removed = []

        # Step 1. Drop columns with too many nulls
        high_null_cols = [c for c in df.columns if df[c].isna().sum() > len(df) * self.max_null_frac]
        removed.extend(high_null_cols)
        df_temp = df.drop(columns=high_null_cols, errors="ignore")
        print(f"Dropped {len(high_null_cols)} columns with >{self.max_null_frac*100}% nulls")

        # Step 2. Drop highly uniform columns
        df_temp, uniform_cols = self._drop_highly_uniform_columns(df_temp)
        removed.extend(uniform_cols)
        print(f"Dropped {len(uniform_cols)} highly uniform columns")

        # Step 3. Fill NaNs with median
        self.medians_ = df_temp.median().to_dict()
        df_temp = df_temp.fillna(self.medians_)

        # Step 4. Variance filter
        low_var_cols = [c for c in df_temp.columns if df_temp[c].var() < self.var_thresh]
        df_temp = df_temp.drop(columns=low_var_cols, errors="ignore")
        removed.extend(low_var_cols)
        print(f"Dropped {len(low_var_cols)} low variance columns (<{self.var_thresh})")

        # Step 5. Prune correlated features
        df_temp, correlated_genes = self._prune_correlated_features(df_temp)
        removed.extend(correlated_genes)
        print(f"Dropped {len(correlated_genes)} correlated genes (>{self.corr_thresh} correlation)")

        # Step 6. Stability-based selection
        if self.use_stability_selection:
            if y is None:
                raise ValueError("y labels required for stability-based feature selection")
            df_temp, dropped_stability = self._stability_feature_selection(df_temp, y)
            removed.extend(dropped_stability)

        # Save final state
        self.removed_cols_ = list(set(removed))
        self.columns_ = df_temp.columns.tolist()

        return self

    def transform(self, df):
        # Drop known removed cols
        df = df.drop(columns=[c for c in self.removed_cols_ if c in df.columns], errors="ignore")

        # Fill NaNs with median
        df = df.fillna(self.medians_)

        # Check column alignment
        missing = set(self.columns_) - set(df.columns)
        extra = set(df.columns) - set(self.columns_)
        if missing or extra:
            raise ValueError(
                f"Column mismatch! Missing: {missing}, Extra: {extra}, "
                f"{len(missing)} missing, {len(extra)} extra"
            )

        # Reorder df to match training column order
        df = df[self.columns_]

        return df


In [None]:
labels = assign_labels(clinical_df)
clinical_df, mrna_df, labels = drop_patients_missing_data(clinical_df, mrna_df, labels)
clinical_df = drop_post_diagnosis_clinical_columns(clinical_df)
splits = split_train_test(clinical_df, mrna_df, labels, test_size=0.2, random_state=42)

clinical_train = splits["X_clinical_train"]
clinical_test  = splits["X_clinical_test"]
mrna_train     = splits["X_mrna_train"]
mrna_test      = splits["X_mrna_test"]
y_train          = splits["y_train"]
y_test           = splits["y_test"]

cols_to_remove = [
    "CANCER_TYPE_ACRONYM",
    "OTHER_PATIENT_ID",
    "SEX",
    "AJCC_PATHOLOGIC_TUMOR_STAGE",
    "DAYS_TO_INITIAL_PATHOLOGIC_DIAGNOSIS",
    "HISTORY_NEOADJUVANT_TRTYN",
    "PATH_M_STAGE",
    "ICD_O_3_SITE", # removed because is the same as ICD_10
    "ICD_O_3_"
]
categorical_cols = ['SUBTYPE', 
                    'ETHNICITY', 
                    "ICD_10", 
                    "ICD_O_3_HISTOLOGY", 
                    "PRIOR_DX", 
                    "RACE", 
                    "RADIATION_THERAPY", 
                    "GENETIC_ANCESTRY_LABEL"
]

clinical_preproc = ClinicalPreprocessor(
    cols_to_remove=cols_to_remove,
    categorical_cols=categorical_cols,
    max_null_frac=config.MAX_NULL_VALS,
    uniform_thresh=config.UNIFORM_THRESHOLD
)

# Fit on train
clinical_train = clinical_preproc.fit(clinical_train).transform(clinical_train)

# Apply same preprocessing to test
clinical_test = clinical_preproc.transform(clinical_test)

# See which columns were dropped
print("Removed columns:", clinical_preproc.removed_cols_)

# Initialize preprocessor with config values
mrna_preproc = MrnaPreprocessor(
    max_null_frac=config.MAX_NULL_VALS,
    uniform_thresh=config.UNIFORM_THRESHOLD,
    corr_thresh=config.CORRELATION_THRESHOLD,
    var_thresh=config.VARIANCE_THRESHOLD,
    literature_genes=config.LITERATURE_GENES,
    n_boots=config.N_BOOTSTRAPS,
    fpr_alpha=config.FPR_ALPHA,
    stability_threshold=config.STABILITY_THRESHOLD,
    random_state=42
)

# Fit on training set, then transform train + test
mrna_train = mrna_preproc.fit(mrna_train, y_train).transform(mrna_train)
mrna_test  = mrna_preproc.transform(mrna_test)


Removed columns: ['CANCER_TYPE_ACRONYM', 'OTHER_PATIENT_ID', 'SEX', 'AJCC_PATHOLOGIC_TUMOR_STAGE', 'DAYS_TO_INITIAL_PATHOLOGIC_DIAGNOSIS', 'HISTORY_NEOADJUVANT_TRTYN', 'PATH_M_STAGE', 'ICD_O_3_SITE', 'AJCC_PATHOLOGIC_TUMOR_STAGE', 'ETHNICITY', 'PATH_M_STAGE', 'PATH_N_STAGE', 'PATH_T_STAGE', 'PRIMARY_LYMPH_NODE_PRESENTATION_ASSESSMENT', 'CANCER_TYPE_ACRONYM', 'SEX', 'DAYS_TO_INITIAL_PATHOLOGIC_DIAGNOSIS', 'HISTORY_NEOADJUVANT_TRTYN']
Dropped 3024 columns with >25.0% nulls
Dropped 0 highly uniform columns
Dropped 0 low variance columns (<1e-05)
Stability selection: kept 113 / 17507 features (98% stability threshold)


In [None]:
print("Removed mRNA columns:", len(mrna_preproc.removed_cols_))
print("Final mRNA columns:", len(mrna_preproc.columns_))

X_train = clinical_train.join(mrna_train, how="inner")
X_test = clinical_test.join(mrna_test, how="inner")

old_X_train = joblib.load(config.X_TRAIN_PATH)
print("Length of X_test:", len(old_X_train.columns))

print(set(X_train.columns) - set(old_X_train.columns))
print(set(old_X_train.columns) - set(X_train.columns))

joblib.dump(X_train, "new_data/X_train_no_prune.pkl")
joblib.dump(X_test, "new_data/X_test_no_prune.pkl")
joblib.dump(y_train, "new_data/y_train_no_prune.pkl")
joblib.dump(y_test, "new_data/y_test_no_prune.pkl")

Removed mRNA columns: 20418
Final mRNA columns: 113
Length of X_test: 17529
{'ICD_O_3_HISTOLOGY_8255/3', 'ICD_O_3_HISTOLOGY_8382/3', 'ICD_O_3_HISTOLOGY_8460/3', 'ICD_O_3_HISTOLOGY_8310/3', 'ICD_O_3_HISTOLOGY_8441/3', 'SUBTYPE_UCEC_POLE', 'ICD_O_3_HISTOLOGY_8461/3', 'ICD_O_3_HISTOLOGY_8380/3', 'SUBTYPE_UCEC_MSI', 'SUBTYPE_UCEC_CN_LOW'}
{'SAMD12', 'SEC24A', 'GIPR', 'TNF', 'HIC1', 'C9orf3', 'CHADL', 'RTEL1', 'HOXD3', 'HNMT', 'VGF', 'TMEM198', 'SFTPA1', 'NSFL1C', 'TMEM216', 'MTTP', 'RBM3', 'ROR2', 'TCP11L2', 'DQX1', 'NCRNA00201', 'URGCP', 'PGRMC2', 'NAV3', 'SLC5A1', 'TNNC2', 'KBTBD8', 'KIAA1671', 'DCAF8', 'CDK13', 'SCAND1', 'C17orf65', 'FAM200A', 'DPH2', 'LINC00896', 'NXPH4', 'XIRP1', 'RAB41', 'WDR81', 'PRKACB', 'DPY19L2P4', 'C12orf10', 'OTUD6B', 'PIF1', 'PLA2G2F', 'HDHD1A', 'CAND2', 'OPRK1', 'SPDYE3', '155060', 'DHX9', 'CABP1', 'COASY', 'SLC10A4', 'TRIP6', 'AGTRAP', 'C7orf42', 'FBLN5', 'MTFMT', 'C14orf129', 'VDR', 'TMEM215', 'C22orf27', 'ZNF860', 'FBXO22-AS1', 'KRTAP5-9', 'RGS18', 'KRT7',

['new_data/y_test_no_prune.pkl']

In [None]:
def compare_pipelines(X_train_old, X_test_old, X_train_new, X_test_new, top_n=10):
    report = {}

    # 1. Column check
    cols_old = list(X_train_old.columns)
    cols_new = list(X_train_new.columns)
    missing = set(cols_old) - set(cols_new)
    extra = set(cols_new) - set(cols_old)
    report["column_check"] = {
        "missing_in_new": missing,
        "extra_in_new": extra,
        "order_diff": (cols_old != cols_new)
    }

    # 2. Summary stats difference
    def summarize(df):
        return pd.DataFrame({
            "median": df.median(),
            "mean": df.mean(),
            "std": df.std()
        })

    old_train_stats = summarize(X_train_old)
    new_train_stats = summarize(X_train_new)

    # Compute absolute diffs
    diffs = (new_train_stats - old_train_stats).abs()
    diffs["max_change"] = diffs.max(axis=1)

    # 3. Biggest shifts
    top_shifts = diffs.sort_values("max_change", ascending=False).head(top_n)

    report["largest_shifts"] = top_shifts

    # 4. Distribution summary
    print("\n")
    report["train_shape"] = (X_train_old.shape, X_train_new.shape)
    report["test_shape"] = (X_test_old.shape, X_test_new.shape)

    return report

report = compare_pipelines(joblib.load(config.X_TRAIN_PATH), joblib.load(config.X_TEST_PATH), X_train, X_test)
print(report)



{'column_check': {'missing_in_new': {'SAMD12', 'SEC24A', 'GIPR', 'TNF', 'HIC1', 'C9orf3', 'CHADL', 'RTEL1', 'HOXD3', 'HNMT', 'VGF', 'TMEM198', 'SFTPA1', 'NSFL1C', 'TMEM216', 'MTTP', 'RBM3', 'ROR2', 'TCP11L2', 'DQX1', 'NCRNA00201', 'URGCP', 'PGRMC2', 'NAV3', 'SLC5A1', 'TNNC2', 'KBTBD8', 'KIAA1671', 'DCAF8', 'CDK13', 'SCAND1', 'C17orf65', 'FAM200A', 'DPH2', 'LINC00896', 'NXPH4', 'XIRP1', 'RAB41', 'WDR81', 'PRKACB', 'DPY19L2P4', 'C12orf10', 'OTUD6B', 'PIF1', 'PLA2G2F', 'HDHD1A', 'CAND2', 'OPRK1', 'SPDYE3', '155060', 'DHX9', 'CABP1', 'COASY', 'SLC10A4', 'TRIP6', 'AGTRAP', 'C7orf42', 'FBLN5', 'MTFMT', 'C14orf129', 'VDR', 'TMEM215', 'C22orf27', 'ZNF860', 'FBXO22-AS1', 'KRTAP5-9', 'RGS18', 'KRT7', 'SLC38A3', 'MBD5', 'ZNF410', 'CHD9', 'TBX20', 'ZNF263', 'BAT2L1', 'CCNL1', 'DSCR6', 'KDELR2', 'GDI2', 'DTX4', 'PSMG3', 'GTF2H2C', 'HSPA4', 'RBM45', 'RTN4R', 'ZNF354A', 'PIK3IP1', 'CLEC4G', 'FGR', 'CCDC155', 'NUDT16L1', 'FBXL15', 'MRPL34', 'CHST10', 'PLEKHO1', 'YME1L1', 'RACGAP1P', 'NUP93', 'TIA1',