In [9]:
import os
import sys
import glob
import pickle
import pandas as pd
import numpy as np
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LogisticRegression
from sklearn.utils import resample
from autogluon.tabular import TabularDataset, TabularPredictor
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier
from sklearn.metrics import accuracy_score,f1_score,precision_score,recall_score,balanced_accuracy_score,roc_auc_score,confusion_matrix,roc_curve
from perpetual import PerpetualBooster
from tabpfn import TabPFNClassifier
import seaborn as sns
import random
from sklearn.model_selection import RepeatedKFold
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.feature_selection import SelectKBest,f_classif
from sklearn.model_selection import StratifiedKFold
from collections import Counter

In [10]:
# Make the project-local `loaders` module importable. The path is relative to
# the notebook's working directory, and must be appended before the import below.
sys.path.append("../../Utils")
from loaders import HNSCCFeatureHandler

# Inputs handed to every HNSCCFeatureHandler constructed in this notebook.
# Metadata table (CSV) for the HNSCC cohort.
METADATA_PATH = "../../Supplementary_Tables/ST1/RAW_HNSCC_METADATA_NEW_v10.csv"
# Presumably the list of cross-validation sample IDs — confirm against loaders.
VALID_IDS_PATH = "../../Utils/Lists/cv_ids.txt"
# Presumably the list of hold-out sample IDs — confirm against loaders.
HOLD_IDS_PATH = '../../Utils/Lists/holdout_ids.txt'

def load_feature(path, idx, val, zscore=True, batch=True):
    """Load one fragmentomic feature into a fresh HNSCCFeatureHandler.

    Parameters
    ----------
    path : str
        Glob/path handed to ``load_feature_to_dataframe``.
    idx, val
        Column selectors forwarded unchanged to ``load_feature_to_dataframe``
        (presumably the id column and the value column — confirm in loaders).
    zscore : bool, default True
        Run ``normalize_zscore()`` after loading. Previously this flag was
        accepted but ignored (normalization always ran). The default keeps the
        old behaviour.
    batch : bool, default True
        Run ``batch_correct()`` after the metadata merge.

    Returns
    -------
    HNSCCFeatureHandler
        The populated handler.
    """
    f = HNSCCFeatureHandler(METADATA_PATH, VALID_IDS_PATH, HOLD_IDS_PATH)
    f.load_feature_to_dataframe(path, idx, val)
    if zscore:
        f.normalize_zscore()
    f.merge_feature_metadata()
    if batch:
        f.batch_correct()
    return f
    
# Accumulator for per-fold metric dicts; eval_model appends to it.
metrics = []

mds = load_feature("/projects/b1198/epifluidlab/ravi/0401/headneck/data/hg38_frag_filtered/*.hg38.frag.interval_mds.tsv", 0, 4)

# Flatten the six per-institute ID collections, preserving institute order.
ids = [
    sample_id
    for institute_group in (
        mds.institute1_ids, mds.institute2_ids, mds.institute3_ids,
        mds.institute4_ids, mds.institute5_ids, mds.institute6_ids,
    )
    for sample_id in institute_group
]

# Seed both global RNGs (stdlib and numpy) for the stochastic steps below.
random.seed(42)
np.random.seed(42)

def prepare_train_test(data, train_ids, test_ids):
    """Assemble train, test and hold-out frames from a feature handler.

    Every returned frame has the raw feature columns followed by the
    "Treatment Response", "Patient Number" and "Type of Visit" metadata columns.
    All three frames share this column order; the hold-out frame used to
    differ.

    Parameters
    ----------
    data
        Handler exposing ``get_subset(ids)``, ``get_raw_features(subset)``,
        ``get_metadata_col(col, subset)`` and a ``hold_data`` attribute.
    train_ids, test_ids : list
        Sample identifiers forwarded to ``data.get_subset``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train_df, test_df, hold_df)``.
    """
    meta_cols = ("Treatment Response", "Patient Number", "Type of Visit")

    def assemble(subset):
        # Features first, then the metadata columns used downstream by name.
        parts = [data.get_raw_features(subset)]
        parts.extend(data.get_metadata_col(col, subset) for col in meta_cols)
        return pd.concat(parts, axis=1)

    # Resolve each subset once instead of once per extracted column.
    train_df = assemble(data.get_subset(train_ids))
    test_df = assemble(data.get_subset(test_ids))
    hold_df = assemble(data.hold_data)

    return train_df, test_df, hold_df

In [11]:
def compute_fold_metrics(y_true, y_prob, y_pred, fold=None):
    """Compute binary-classification metrics for one evaluation fold.

    Parameters
    ----------
    y_true : array-like of bool/int
        Ground-truth labels; 1/True is the positive class.
    y_prob : array-like of float
        Positive-class scores. When `fold` ends with "Patient", this must be a
        pandas Series: its index is stored as the patient identifiers.
    y_pred : array-like of bool/int
        Hard predictions.
    fold : str, optional
        Label stored under "fold". A label ending in "Patient" also records
        per-patient identifiers and probabilities. The default ``None`` used to
        crash because ``fold.endswith`` was evaluated before the None check.

    Returns
    -------
    dict
        accuracy, f1, precision, recall, balanced_ac, the tn/fp/fn/tp counts,
        roc_auc and roc_curve (only when both classes are present in y_true),
        patients, probabilities, and fold (when given).
    """
    m = {
        "accuracy":      accuracy_score(y_true, y_pred),
        "f1":            f1_score(y_true, y_pred, zero_division=0),
        "precision":     precision_score(y_true, y_pred, zero_division=0),
        "recall":        recall_score(y_true, y_pred, zero_division=0),
        "balanced_ac":   balanced_accuracy_score(y_true, y_pred),
    }

    # labels=[0, 1] pins a 2x2 matrix even if one class is absent from the fold.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    m.update({"tn": tn, "fp": fp, "fn": fn, "tp": tp})

    # ROC metrics are undefined when y_true contains a single class.
    if len(np.unique(y_true)) == 2:
        m["roc_auc"] = roc_auc_score(y_true, y_prob)
        fpr, tpr, thresholds = roc_curve(y_true, y_prob)
        m["roc_curve"] = {
            "fpr": fpr.tolist(),
            "tpr": tpr.tolist(),
            "thresholds": thresholds.tolist()
        }

    is_patient_level = fold is not None and fold.endswith("Patient")
    m["patients"] = list(y_prob.index) if is_patient_level else None
    m["probabilities"] = list(y_prob) if is_patient_level else None

    if fold is not None:
        m["fold"] = fold
    return m

def print_classification_metrics(df):
    """Print tab-separated summary metrics for a patient-level prediction frame.

    Uses the columns "Treatment Response" (truth), "Predicted Treatment
    Response" (hard prediction), with "Responder" as the positive label, and
    "Final Weighted Prediction" (score, used for ROC AUC).
    """
    is_responder = df["Treatment Response"] == "Responder"
    predicted_responder = df["Predicted Treatment Response"] == "Responder"
    scores = df["Final Weighted Prediction"]

    label_based = (
        ("accuracy", accuracy_score),
        ("f1", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
        ("balanced_ac", balanced_accuracy_score),
    )
    results = {name: scorer(is_responder, predicted_responder) for name, scorer in label_based}
    results["roc_auc"] = roc_auc_score(is_responder, scores)

    for name, value in results.items():
        print(f"{name}\t{value:.3f}")

def plot_metrics(metrics, title=''):
    """Draw the ROC curve and confusion matrix from one metrics dict side by side.

    `metrics` must contain 'roc_curve' (with 'fpr'/'tpr'), 'roc_auc' and the
    'tn'/'fp'/'fn'/'tp' counts, as produced by compute_fold_metrics. The
    figure is saved to f'{title}.pdf'.
    """
    curve = metrics['roc_curve']
    fpr = curve['fpr']
    tpr = curve['tpr']
    auc_value = metrics['roc_auc']
    confusion = np.array([
        [metrics['tn'], metrics['fp']],
        [metrics['fn'], metrics['tp']],
    ])
    class_names = ['Non-Responder', 'Responder']

    plt.figure(figsize=(10, 4), dpi=1000)

    # Left panel: ROC curve against the chance diagonal.
    plt.subplot(1, 2, 1)
    plt.plot(fpr, tpr, color='mediumpurple', linewidth=2)
    plt.plot([0, 1], [0, 1], color='black', alpha=0.3, linestyle='--', linewidth=2)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.text(0.55, 0.05, f'ROC AUC: {auc_value:.2f}', fontsize=12, color='red')
    plt.xlim(-0.03, 1)
    plt.ylim(0, 1.03)

    # Right panel: confusion matrix (rows = true label, columns = predicted).
    plt.subplot(1, 2, 2)
    sns.heatmap(confusion, annot=True, fmt='d', cmap='Purples', cbar=False,
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    plt.tight_layout()
    plt.savefig(f'{title}.pdf', bbox_inches='tight')

def adjust_coefficients_mean(group):
    """Collapse one patient's samples by averaging their 'mds' scores.

    The mean score is thresholded at 0.5 (inclusive) to give the predicted
    label. The true label is taken from the group's first row.
    """
    mean_score = group['mds'].mean()
    if mean_score >= 0.5:
        predicted = 'Responder'
    else:
        predicted = 'Non-Responder'
    true_label = group['Treatment Response'].iloc[0]

    summary = {}
    summary['Final Weighted Prediction'] = mean_score
    summary['Treatment Response'] = true_label
    summary['Predicted Treatment Response'] = predicted
    return pd.Series(summary)

def adjust_coefficients_priority(group, visit_priority=None, threshold=0.5):
    """Collapse one patient's samples to the highest-priority visit's score.

    Intended for ``DataFrame.groupby('Patient Number').apply(...)``. Each row's
    'Type of Visit' is mapped through `visit_priority`. The row with the
    largest priority supplies both the score ('mds') and the true label. Ties
    go to the first such row. Visit types missing from the map are ignored.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows for one patient with 'mds', 'Treatment Response' and
        'Type of Visit' columns.
    visit_priority : dict, optional
        Visit name -> rank. Defaults to
        ``{'Screen': 0, 'Day 0': 1, 'Adj Wk 1': 2}``.
    threshold : float, default 0.5
        Scores ``>= threshold`` are predicted 'Responder'.

    Returns
    -------
    pandas.Series
        'Final Weighted Prediction', 'Treatment Response',
        'Predicted Treatment Response'.

    Raises
    ------
    ValueError
        If none of the group's visit types appear in `visit_priority`.
        Previously this surfaced as an obscure idxmax/.loc failure.
    """
    if visit_priority is None:
        visit_priority = {'Screen': 0, 'Day 0': 1, 'Adj Wk 1': 2}

    # Unmapped visits become NaN and are skipped by idxmax (skipna=True).
    priority = group['Type of Visit'].map(visit_priority)
    if priority.isna().all():
        raise ValueError(
            "adjust_coefficients_priority: no recognised 'Type of Visit' in group; "
            f"got {list(group['Type of Visit'].unique())}"
        )
    chosen = group.loc[priority.idxmax()]

    prob = chosen['mds']
    resp = chosen['Treatment Response']
    return pd.Series({
        'Final Weighted Prediction': prob,
        'Treatment Response': resp,
        'Predicted Treatment Response': 'Responder' if prob >= threshold else 'Non-Responder'
    })

def plot_roc(df, title, fname):
    """Draw a small patient-level ROC curve, save it to `fname` and show it.

    Positives are rows whose 'Treatment Response' is 'Responder'; scores come
    from 'Final Weighted Prediction'.
    """
    is_responder = df['Treatment Response'] == 'Responder'
    scores = df['Final Weighted Prediction']
    fpr, tpr, _ = roc_curve(is_responder, scores)
    auc_score = roc_auc_score(is_responder, scores)

    font = 6
    plt.figure(figsize=(2, 2), dpi=300)
    plt.plot(fpr, tpr, color="mediumpurple")
    plt.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.3)
    plt.text(0.4, 0.05, f"ROC AUC: {auc_score:.2f}", fontsize=font)
    for set_label, text in (
        (plt.title, title),
        (plt.xlabel, "False Positive Rate"),
        (plt.ylabel, "True Positive Rate"),
    ):
        set_label(text, fontsize=font)
    plt.xticks(fontsize=font)
    plt.yticks(fontsize=font)
    plt.gca().set_aspect('equal')
    plt.tight_layout()
    plt.savefig(fname)
    plt.show()

def plot_cm(df, title, fname):
    """Draw a small confusion-matrix heatmap, save it to `fname` and show it.

    Rows are true labels and columns are predicted labels, both ordered
    Responder first. This is the reverse of plot_metrics' order.
    """
    class_order = ['Responder', 'Non-Responder']
    font = 6
    cm = confusion_matrix(
        df['Treatment Response'],
        df['Predicted Treatment Response'],
        labels=class_order,
    )

    plt.figure(figsize=(2.5, 2.5), dpi=300)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Purples', cbar=False,
                xticklabels=list(class_order), yticklabels=list(class_order),
                annot_kws={"fontsize": font})
    for set_label, text in (
        (plt.title, title),
        (plt.xlabel, 'Predicted Label'),
        (plt.ylabel, 'True Label'),
    ):
        set_label(text, fontsize=font)
    plt.xticks(fontsize=font)
    plt.yticks(fontsize=font)
    plt.tight_layout()
    plt.savefig(fname)
    plt.show()

In [12]:
def train_model(
    df, response_col="Treatment Response",
    k_features=6,
    base_model_cls=None, base_model_kwargs=None,
    random_state=42
):
    """Fit a TruncatedSVD projection plus classifier on a class-balanced sample.

    The majority class is downsampled without replacement to the minority
    class size, and the balanced frame is shuffled. The features are projected
    onto `k_features` SVD components, and `base_model_cls` is fit on the
    projection with labels 1 = "Responder", 0 = anything else.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature columns plus `response_col`, "Patient Number" and
        "Type of Visit". The last three are excluded from the features.
    response_col : str
        Label column.
    k_features : int
        Number of SVD components.
    base_model_cls : type
        Estimator class with a scikit-learn style ``fit`` that returns the
        fitted estimator. Required.
    base_model_kwargs : dict, optional
        Keyword arguments for `base_model_cls`.
    random_state : int
        Seed for downsampling, shuffling and the SVD.

    Returns
    -------
    tuple
        ``(fitted TruncatedSVD, fitted estimator)``.

    Raises
    ------
    ValueError
        If `base_model_cls` is None, or `response_col` has fewer than two classes.
    """
    if base_model_cls is None:
        raise ValueError("train_model requires base_model_cls (e.g. TabPFNClassifier)")

    class_counts = df[response_col].value_counts()
    if len(class_counts) < 2:
        raise ValueError(
            f"train_model needs at least two classes in {response_col!r}; "
            f"got {list(class_counts.index)}"
        )
    # value_counts sorts by descending count, so position 0 is the majority
    # and position -1 the minority. idxmax()/idxmin() both return the FIRST
    # label on a tie, which picked the same class twice and dropped the other
    # when the classes were already balanced.
    # NOTE(review): with more than two classes, only these two are kept.
    majority_class = class_counts.index[0]
    minority_class = class_counts.index[-1]

    # Local imports mirror the file-level ones so the function is self-contained.
    from sklearn.decomposition import TruncatedSVD
    from sklearn.utils import resample

    df_minority = df[df[response_col] == minority_class]
    df_majority = df[df[response_col] == majority_class]

    df_majority_downsampled = resample(
        df_majority,
        replace=False,
        n_samples=len(df_minority),
        random_state=random_state
    )

    # Shuffle so the two classes are interleaved.
    df_balanced = pd.concat([df_minority, df_majority_downsampled]).sample(frac=1, random_state=random_state)

    X = df_balanced.drop(columns=[response_col, "Patient Number", "Type of Visit"])
    y = (df_balanced[response_col] == "Responder").astype(int).to_numpy()

    final_svd = TruncatedSVD(n_components=k_features, random_state=random_state).fit(X)
    final_model = base_model_cls(**(base_model_kwargs or {})).fit(final_svd.transform(X), y)

    return final_svd, final_model

def eval_model(
    df_test, svd, model, name,
    response_col="Treatment Response"
):
    """Score a held-out frame and record patient-level metrics.

    Every row of `df_test` is projected with `svd` and scored with
    ``model.predict_proba(...)[:, 1]``. The per-sample scores are collapsed to
    one row per "Patient Number" with adjust_coefficients_priority. The result
    of compute_fold_metrics, tagged ``fold=f"MDS_{name}_Patient"``, is appended
    to the module-level `metrics` list.

    Parameters
    ----------
    df_test : pandas.DataFrame
        Feature columns plus `response_col`, "Patient Number", "Type of Visit".
    svd, model
        Fitted projection and classifier, as returned by train_model.
    name : str
        Identifier embedded in the fold label.
    response_col : str
        Label column. It is renamed to "Treatment Response" for the
        patient-level aggregation.

    Returns
    -------
    dict
        The metrics dict that was appended. The original returned None.
    """
    features = df_test.drop(columns=[response_col, "Patient Number", "Type of Visit"])
    # Score all rows positionally. The old per-visit loop ran once ("Overall"),
    # discarded its intermediates, and its get_indexer lookup failed on
    # duplicate index labels.
    sample_scores = pd.Series(
        model.predict_proba(svd.transform(features))[:, 1],
        index=df_test.index,
        name="mds",
    )

    per_sample = pd.concat(
        [
            sample_scores,
            df_test[response_col].rename("Treatment Response"),
            df_test["Patient Number"],
            df_test["Type of Visit"],
        ],
        axis=1,
    )
    per_patient = per_sample.groupby("Patient Number").apply(adjust_coefficients_priority)

    fold_metrics = compute_fold_metrics(
        per_patient["Treatment Response"] == "Responder",
        per_patient["Final Weighted Prediction"],
        per_patient["Predicted Treatment Response"] == "Responder",
        fold=f"MDS_{name}_Patient",
    )
    metrics.append(fold_metrics)
    return fold_metrics

In [13]:
from itertools import combinations

institute_ids = [
    list(mds.institute1_idsa),
    list(mds.institute2_idsa),
    list(mds.institute3_idsa),
    list(mds.institute4_idsa),
    list(mds.institute5_idsa),
    list(mds.institute6_idsa),
]


def load_feature_pooled(path, idx, val, zscore=True, batch=True):
    """Load a feature with the CV and hold-out samples pooled into one frame.

    Both ``data`` and ``hold_data`` on the returned handler are set to the
    same concatenated frame, so callers choose train/test rows purely by
    sample id via ``get_subset``.

    NOTE(review): ``batch_correct`` runs after pooling. Assuming it operates on
    ``data``/``hold_data``, it therefore sees the institute that is later
    held out. Confirm this transductive setup is intended.
    """
    f = HNSCCFeatureHandler(METADATA_PATH, VALID_IDS_PATH, HOLD_IDS_PATH)
    f.load_feature_to_dataframe(path, idx, val)
    if zscore:
        f.normalize_zscore()
    f.merge_feature_metadata()
    all_data = pd.concat([f.data, f.hold_data])
    f.hold_data = all_data
    f.data = all_data
    if batch:
        f.batch_correct()
    return f


# Leave-one-institute-out: each institute in turn is the test set.
for test_combo in combinations(range(len(institute_ids)), 1):
    print(test_combo)
    # Reloaded each iteration, as before, in case downstream calls mutate the handler.
    mds = load_feature_pooled("/projects/b1198/epifluidlab/ravi/0401/headneck/data/hg38_frag_filtered/*.hg38.frag.interval_mds.tsv", 0, 4)
    # Linear flattening; replaces the quadratic sum(lists, []).
    train_ids = [
        sample_id
        for inst, group in enumerate(institute_ids)
        if inst not in test_combo
        for sample_id in group
    ]
    test_ids = [sample_id for inst in test_combo for sample_id in institute_ids[inst]]
    mds_train, mds_test, _ = prepare_train_test(mds, train_ids, test_ids)
    assert set(mds_train.index).isdisjoint(mds_test.index), "Train and test sets overlap!"

    mds_svd, mds_m = train_model(
        mds_train,
        k_features=6,
        base_model_cls=TabPFNClassifier,
        random_state=42
    )
    eval_model(mds_test, mds_svd, mds_m, f"Institute_Hold_Out_{test_combo[0]}")

(0,)
(1,)
(2,)
(3,)
(4,)
(5,)


In [14]:
def generate_random_lists(num_lists, seed=42, population=range(1, 69), k=10):
    """Draw `num_lists` reproducible random samples (without replacement).

    Uses a private ``random.Random(seed)``, so the draws are identical to the
    old ``random.seed(seed)`` + ``random.sample`` version. The difference is
    that the global `random` module is no longer reseeded as a side effect.

    Parameters
    ----------
    num_lists : int
        Number of samples to draw.
    seed : int, default 42
        Seed for the private generator.
    population : sequence, default range(1, 69)
        Values to sample from (patient numbers 1-68 by default).
    k : int, default 10
        Size of each sample.

    Returns
    -------
    list of list
        `num_lists` lists of `k` distinct values from `population`.
    """
    rng = random.Random(seed)
    return [rng.sample(population, k) for _ in range(num_lists)]

N_REPEATS = 100  # number of random patient hold-out repeats

train_ids = (
    list(mds.institute1_idsa) +
    list(mds.institute2_idsa) +
    list(mds.institute3_idsa) +
    list(mds.institute4_idsa) +
    list(mds.institute5_idsa) +
    list(mds.institute6_idsa)
)

test_ids = []


def load_feature_holdout(path, idx, val, holdout_patients, zscore=True, batch=True):
    """Load a feature and split it by patient number into train vs hold-out.

    CV and hold-out samples are pooled first. Rows whose "Patient Number" is
    in `holdout_patients` become ``hold_data``; all other rows become ``data``.
    ``batch_correct`` runs after this split.

    The patient list is an explicit parameter. Previously a new loader was
    redefined on every iteration as a closure over the loop variable.
    """
    f = HNSCCFeatureHandler(METADATA_PATH, VALID_IDS_PATH, HOLD_IDS_PATH)
    f.load_feature_to_dataframe(path, idx, val)
    if zscore:
        f.normalize_zscore()
    f.merge_feature_metadata()
    some_data = pd.concat([f.data, f.hold_data])
    in_holdout = some_data["Patient Number"].isin(holdout_patients)
    f.hold_data = some_data[in_holdout]
    f.data = some_data[~in_holdout]
    if batch:
        f.batch_correct()
    return f


lists = generate_random_lists(N_REPEATS)
for i, patient_list in enumerate(tqdm(lists, desc="Repeats")):
    mds = load_feature_holdout("/projects/b1198/epifluidlab/ravi/0401/headneck/data/hg38_frag_filtered/*.hg38.frag.interval_mds.tsv", 0, 4, patient_list)

    # NOTE(review): train_ids includes the held-out patients' samples. This
    # presumably relies on get_subset restricting to `data`. The assert below
    # guards the split.
    mds_train, _, mds_hold = prepare_train_test(mds, train_ids, test_ids)
    assert set(mds_train.index).isdisjoint(mds_hold.index), "Train and hold sets overlap!"

    mds_svd, mds_m = train_model(
        mds_train,
        k_features=6,
        base_model_cls=TabPFNClassifier,
        random_state=42
    )
    eval_model(mds_hold, mds_svd, mds_m, f"Pre_Hold_Out_{i}")

Repeats: 100%|██████████| 100/100 [05:10<00:00,  3.10s/it]


In [15]:
# Persist the per-patient predictions collected by every eval_model call.
(
    pd.DataFrame(metrics)
    .loc[:, ["patients", "probabilities", "fold"]]
    .to_csv('SURV_PREDICTIONS.csv', index=False)
)