In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
from sklearn.metrics import roc_auc_score, roc_curve


# Directory containing the per-statistic feature_matrix_{stat}.parquet files
feature_dir = '../data/cristiano_cfdnas_dhs_pca'

# Statistic name -> (per-sample file template, array key).
# The array key is presumably the array name inside the .npz archive; it is
# None for the .npy statistics. Insertion order fixes the evaluation order.
STATS = dict(
    ocf=('{sid}__{dhs}_sorted_ocf.npy', None),
    lwps=('{sid}__{dhs}_sorted_lwps.npy', None),
    ifs=('{sid}__{dhs}_sorted_ifs.npz', 'ifs_scores'),
    pfe=('{sid}__{dhs}_sorted_pfe.npz', 'pfe_scores'),
    fdi=('{sid}__{dhs}_sorted_fdi.npz', 'overlapping_fdi_scores'),
)

In [None]:
def load_feature_df(stat: str, feature_dir: str) -> "pd.DataFrame | None":
    """Read the raw feature-matrix parquet for one statistic.

    Parameters
    ----------
    stat : str
        Statistic name; the file read is ``feature_matrix_{stat}.parquet``.
    feature_dir : str
        Directory containing the parquet file.

    Returns
    -------
    pd.DataFrame or None
        The parquet contents unchanged, or None (after printing the missing
        path) when the file does not exist.
    """
    parquet_path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')
    if not os.path.exists(parquet_path):
        print(f'Missing parquet: {parquet_path}')
        return None

    return pd.read_parquet(parquet_path)


fdi_features = load_feature_df('fdi', feature_dir)
if fdi_features is None:
    # load_feature_df has already printed the missing path; stop with a clear
    # error instead of a TypeError from subscripting None.
    raise FileNotFoundError(f'feature_matrix_fdi.parquet not found in {feature_dir}')
df = fdi_features['pc1'].reset_index()
print(f'Shape before: {df.shape}')
# Drop the Stromal_A column rather than rows: presumably it is NaN for 100+
# samples (TODO confirm), which dropna below would otherwise discard.
df = df.drop(columns=['Stromal_A'])
# Count distinct samples that still have at least one missing feature value
n_unique_nan_samples = df.loc[df.isna().any(axis=1), 'sample'].nunique(dropna=False)
print(f'There are {n_unique_nan_samples} samples with at least 1 NaN')
df_clean = df.dropna(axis=0, how="any")
print(f'Shape after: {df_clean.shape}')

In [None]:
def load_stat(stat: str, feature_dir: str):
    """Load and clean the PC1 feature matrix for one statistic.

    Reads ``feature_matrix_{stat}.parquet`` from ``feature_dir`` and then:

    - selects the ``pc1`` columns (presumably the top level of a column MultiIndex);
    - moves the index into ordinary columns;
    - drops the ``Stromal_A`` column;
    - removes every row that still contains a NaN.

    Returns None (after printing the missing path) if the file does not exist.
    """
    path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')
    if os.path.exists(path):
        return (
            pd.read_parquet(path)['pc1']
            .reset_index()
            .drop(columns=['Stromal_A'])
            .dropna(axis=0, how="any")
        )
    print(f'Missing parquet: {path}')
    return None


def merge_stats(dfs: dict):
    """Inner-join the per-statistic feature frames into one wide frame.

    Every frame in ``dfs`` must carry ``sample``, ``binary`` and ``disease``
    columns; all its other columns are treated as features and suffixed with
    ``_{stat}`` so they stay distinct after the join. Frames are joined in dict
    order on the three key columns, keeping only rows present in all of them.

    Returns None when ``dfs`` is empty.
    """
    key_cols = ['sample', 'binary', 'disease']
    merged = None
    for stat_name, stat_df in dfs.items():
        value_cols = [col for col in stat_df.columns if col not in key_cols]
        suffixed = (
            stat_df[key_cols + value_cols]
            .copy()
            .rename(columns={col: f"{col}_{stat_name}" for col in value_cols})
        )
        merged = (
            suffixed
            if merged is None
            else pd.merge(merged, suffixed, on=key_cols, how='inner')
        )
    return merged


def evaluate_stat_joint(stat: str, feature_dir: str, dfs: dict, cv_splits: int = 10):
    """Cross-validate an SVM binary classifier on DHS PC1 features.

    Parameters
    ----------
    stat : str
        Key into ``dfs`` selecting one statistic, or ``'all'`` to use the
        inner join (on sample/binary/disease) of every statistic in ``dfs``
        built by ``merge_stats``.
    feature_dir : str
        Not used by this function; kept so existing call sites keep working.
    dfs : dict
        Statistic name -> DataFrame with ``sample``, ``binary`` and ``disease``
        columns plus one numeric column per DHS feature.
    cv_splits : int, default 10
        Number of folds of the shuffled ``StratifiedKFold``.

    Returns
    -------
    dict or None
        Keys:

        - ``stat``: the label of the evaluated model.
        - ``n_features``: number of feature columns used.
        - ``auc_mean`` / ``auc_std``: mean and std of the per-fold
          ``roc_auc`` scores from ``cross_val_score``.
        - ``auc_proba_overall``: ROC AUC of the pooled out-of-fold
          ``predict_proba`` column 1.
        - ``y_true`` and ``y_pred_proba``.

        ``y_true`` is the ``LabelEncoder`` encoding of ``binary``, so
        class 1 is the second label in sorted order.
        Returns None (after printing a message) when there are no rows to
        fit, e.g. when the ``'all'`` join shares no samples.
    """
    if stat == 'all':
        df = merge_stats(dfs)
        # Label the combined model with the statistics actually merged rather
        # than a fixed list, so a missing parquet does not mislabel the result.
        stat = '+'.join(dfs)
    else:
        df = dfs[stat]

    if df is None or df.empty:
        print(f'No samples available for {stat}; skipping')
        return None

    # target labels
    le = LabelEncoder()
    y = le.fit_transform(df['binary'])

    # feature matrix: every non-metadata (DHS) column
    meta_cols = ['sample', 'binary', 'disease']
    feature_cols = [c for c in df.columns if c not in meta_cols]
    X = df[feature_cols].to_numpy()

    # NOTE(review): features are not standardized. SVMs are not scale
    # invariant, so in the 'all' model statistics with larger numeric ranges
    # may dominate the kernel; consider a StandardScaler pipeline.
    clf = SVC(probability=True, random_state=42)
    cv = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=42)
    # Per-fold AUCs, summarised as mean/std below
    scores = cross_val_score(clf, X, y, cv=cv, scoring='roc_auc', n_jobs=-1)

    # Pooled out-of-fold probabilities: one AUC over all samples, and the
    # inputs for the ROC curve plot.
    y_pred_proba = cross_val_predict(
        clf, X, y, cv=cv, method="predict_proba", n_jobs=-1
    )[:, 1]
    auc_proba = roc_auc_score(y, y_pred_proba)

    return {
        'stat': stat,
        'n_features': X.shape[1],
        'auc_mean': float(np.mean(scores)),
        'auc_std': float(np.std(scores)),
        'auc_proba_overall': auc_proba,
        'y_true': y,
        'y_pred_proba': y_pred_proba,
    }
        
results = []
dfs = {}
for stat in STATS:
    df = load_stat(stat, feature_dir)
    if df is None or df.empty:
        continue
    dfs[stat] = df

    result = evaluate_stat_joint(stat, feature_dir, dfs)
    if result is not None:
        results.append(result)

# Joint model over the inner-joined features of every statistic that loaded
if dfs:
    result_all = evaluate_stat_joint('all', feature_dir, dfs)
    if result_all is not None:
        results.append(result_all)

if results:
    stats_order = [r['stat'] for r in results]
    auc_means = [r['auc_mean'] for r in results]
    auc_stds = [r['auc_std'] for r in results]
    auc_proba_overall_values = [r['auc_proba_overall'] for r in results]

    # Bar chart: pooled out-of-fold ROC AUC per statistic
    fig, ax = plt.subplots(figsize=(8, 6))
    bars = ax.bar(stats_order, auc_proba_overall_values)
    ax.set_ylabel('Overall ROC AUC proba (all DHS sites)')
    ax.set_xlabel('Test statistic')
    for bar, val in zip(bars, auc_proba_overall_values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.2f}",
            ha='center',
            va='bottom',
            fontsize=9,
        )
    ax.set_title('Binary classification using joint DHS PC1 features + SVM')
    fig.tight_layout()
    out_path_classification = os.path.join(feature_dir, 'classification_pc1_auc_by_stat_overall_auc_proba.png')
    fig.savefig(out_path_classification, dpi=200)
    # Show before closing: calling plt.show() after plt.close() displays nothing
    plt.show()
    plt.close(fig)

    # ROC curves built from the cross-validated probabilities
    fig, ax = plt.subplots(figsize=(8, 6))
    for r in results:
        fpr, tpr, _ = roc_curve(r['y_true'], r['y_pred_proba'])
        # evaluate_stat_joint stores the pooled AUC under 'auc_proba_overall'
        ax.plot(fpr, tpr, label=f"{r['stat']} (AUC={r['auc_proba_overall']:.3f})")

    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC curves for all statistics (cross-validated predictions)')
    ax.legend()
    fig.tight_layout()
    out_path_classification = os.path.join(feature_dir, 'roc_curves_pc1_all_stats.png')
    fig.savefig(out_path_classification, dpi=200)
    plt.show()
    plt.close(fig)