# In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.svm import SVR
from sklearn.model_selection import cross_val_score, cross_val_predict, KFold
from sklearn.metrics import r2_score, mean_absolute_error


# Load each individual statistic's feature matrix
def load_stat(stat: str, feature_dir: str):
    """Read the ``pc1`` part of one statistic's parquet feature matrix.

    Args:
        stat: Statistic name; the file looked up is
            ``feature_matrix_<stat>.parquet``.
        feature_dir: Directory that is expected to contain the parquet file.

    Returns:
        ``None`` (after printing a message) when the file does not exist.
        Otherwise a DataFrame built from the ``'pc1'`` selection of the file,
        with its index moved into regular columns, any ``Stromal_A`` column
        removed, and every row that contains a missing value dropped.
    """
    matrix_path = os.path.join(feature_dir, f'feature_matrix_{stat}.parquet')

    if os.path.exists(matrix_path):
        pc1 = pd.read_parquet(matrix_path)['pc1']
        table = pc1.reset_index()
        # 'Stromal_A' is an unwanted column; remove it only when it is there.
        if 'Stromal_A' in table.columns:
            table = table.drop(columns=['Stromal_A'])
        return table.dropna(axis=0, how="any")

    print(f'Missing parquet: {matrix_path}')
    return None


# Merge features across multiple statistics
def merge_stats(dfs: dict):
    """Inner-join per-statistic feature frames into one wide frame.

    Args:
        dfs: Mapping of statistic name to a DataFrame that has ``sample``,
            ``target`` and ``disease`` columns plus feature columns.

    Returns:
        A DataFrame keyed on ``sample``/``target``/``disease`` in which every
        feature column is renamed to ``<column>_<stat>``; only rows present in
        all frames survive the inner join. ``None`` when ``dfs`` is empty.
    """
    key_cols = ['sample', 'target', 'disease']
    merged = None

    for stat_name, frame in dfs.items():
        value_cols = [col for col in frame.columns if col not in key_cols]
        # Suffix features with the statistic name so columns from different
        # statistics cannot collide in the join.
        suffix_map = {col: f"{col}_{stat_name}" for col in value_cols}
        suffixed = frame[key_cols + value_cols].copy().rename(columns=suffix_map)

        if merged is None:
            merged = suffixed
            continue
        merged = merged.merge(suffixed, on=key_cols, how='inner')

    return merged


# Evaluate one or multiple statistics using SVR regression
def evaluate_stat_joint_regression(
    stat: str, feature_dir: str, dfs: dict, cv_splits: int = 10
):
    """Cross-validate an RBF-kernel SVR on one statistic, or on all merged.

    Args:
        stat: Key into ``dfs`` selecting a single statistic, or ``'all'`` to
            merge every frame in ``dfs`` with ``merge_stats`` first.
        feature_dir: Not used by this function; kept so existing callers
            keep working.
        dfs: Mapping of statistic name to its feature DataFrame. The
            ``target`` column is the regression target; ``sample``,
            ``target`` and ``disease`` are treated as metadata and every
            other column is used as a feature.
        cv_splits: Number of shuffled K-fold splits (seeded with 42).

    Returns:
        ``None`` (after printing the reason) when there is no data, no
        feature column, or fewer samples than ``cv_splits``. Otherwise a dict
        with ``stat``, ``n_features``, ``r2_mean`` / ``r2_std`` (over the
        per-fold R² scores), ``r2_overall`` / ``mae_overall`` (over the
        pooled out-of-fold predictions), ``y_true`` and ``y_pred``.
    """
    if stat == 'all':
        df = merge_stats(dfs)
        # Label the joint model with the statistics that were really merged;
        # some may have been skipped when their parquet file was missing.
        stat = '+'.join(str(name) for name in dfs)
    else:
        df = dfs[stat]

    if df is None:
        print(f'Skipping {stat}: no feature matrix available')
        return None

    meta_cols = ['sample', 'target', 'disease']
    feature_cols = [c for c in df.columns if c not in meta_cols]

    # K-fold needs at least one sample per split and the model needs at least
    # one feature; callers already treat a None result as "skip".
    n_samples = len(df)
    if n_samples < cv_splits or not feature_cols:
        print(
            f'Skipping {stat}: {n_samples} samples / {len(feature_cols)} '
            f'features are not enough for {cv_splits}-fold cross-validation'
        )
        return None

    # Continuous target column (must exist in the parquet)
    y = df['target'].to_numpy().astype(float)

    # Feature matrix
    X = df[feature_cols].to_numpy()

    # Regression model
    reg = SVR(kernel='rbf')  # can adjust kernel, C, epsilon, etc.
    cv = KFold(n_splits=cv_splits, shuffle=True, random_state=42)

    # Per-fold R² scores, then pooled out-of-fold predictions for the
    # overall metrics and the downstream plots.
    r2_scores = cross_val_score(reg, X, y, cv=cv, scoring='r2', n_jobs=-1)
    y_pred = cross_val_predict(reg, X, y, cv=cv)
    r2_overall = r2_score(y, y_pred)
    mae_overall = mean_absolute_error(y, y_pred)

    return {
        'stat': stat,
        'n_features': X.shape[1],
        'r2_mean': float(np.mean(r2_scores)),
        'r2_std': float(np.std(r2_scores)),
        'r2_overall': float(r2_overall),
        'mae_overall': float(mae_overall),
        'y_true': y,
        'y_pred': y_pred,
    }



# Directory holding the per-statistic parquet feature matrices.
feature_dir = '../data/cristiano_cfdnas_dhs_pca'
# Statistics to evaluate, in the order they appear in the result list.
STATS = ['ocf', 'lwps', 'ifs', 'pfe', 'fdi']

results = []
dfs = {}

# Load each statistic and evaluate it on its own; statistics whose matrix is
# missing or empty are left out of both `dfs` and `results`.
for stat in STATS:
    df = load_stat(stat, feature_dir)
    if df is not None and not df.empty:
        dfs[stat] = df

        result = evaluate_stat_joint_regression(stat, feature_dir, dfs)
        if result is not None:
            results.append(result)

# Joint model over every statistic that could be loaded.
if dfs:
    result_all = evaluate_stat_joint_regression('all', feature_dir, dfs)
    if result_all is not None:
        results.append(result_all)


# Plot the collected results: an R² bar chart, true-vs-predicted scatter
# panels, and residual densities. Each figure is saved as a PNG in
# `feature_dir` and then displayed.
if results:
    stats_order = [r['stat'] for r in results]
    r2_means = [r['r2_mean'] for r in results]
    r2_stds = [r['r2_std'] for r in results]
    mae_overall_values = [r['mae_overall'] for r in results]

    # bar plot: per-fold cross-validated R² (mean ± std) per statistic
    plt.figure(figsize=(8, 6))
    bars = plt.bar(stats_order, r2_means, yerr=r2_stds, capsize=5)
    plt.ylabel('Cross-validated R² (mean ± std)')
    plt.xlabel('Test statistic')
    plt.title('Regression Performance using Joint DHS PC1 Features + SVR')

    for bar, val in zip(bars, r2_means):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.2f}",
            ha='center',
            # R² can be negative: anchor the label beyond the end of the bar
            # in either direction instead of drawing it inside the bar.
            va='bottom' if val >= 0 else 'top',
            fontsize=9,
        )

    plt.tight_layout()
    out_path_r2 = os.path.join(feature_dir, 'regression_pc1_r2_by_stat.png')
    plt.savefig(out_path_r2, dpi=200)
    # Show before closing: once a figure is closed there is nothing to show.
    plt.show()
    plt.close()

    # scatter: True vs Predicted, one panel per result
    n_stats = len(results)
    fig, axes = plt.subplots(1, n_stats, figsize=(5 * n_stats, 5), sharey=True)
    # With a single panel, subplots returns a bare Axes rather than an array.
    if n_stats == 1:
        axes = [axes]

    for ax, r in zip(axes, results):
        sns.scatterplot(x=r['y_true'], y=r['y_pred'], ax=ax, s=20, alpha=0.7)
        # Identity line over the range of the true values (perfect prediction).
        ax.plot(
            [min(r['y_true']), max(r['y_true'])],
            [min(r['y_true']), max(r['y_true'])],
            'k--',
            lw=1,
        )
        ax.set_xlabel('True Values')
        ax.set_ylabel('Predicted Values')
        ax.set_title(
            f"{r['stat']} (R²={r['r2_overall']:.2f}, MAE={r['mae_overall']:.2f})"
        )

    fig.tight_layout()
    out_path_scatter = os.path.join(
        feature_dir, 'regression_true_vs_pred_by_stat.png'
    )
    fig.savefig(out_path_scatter, dpi=200)
    plt.show()
    plt.close(fig)

    # residual distributions, all results overlaid on one axis
    plt.figure(figsize=(8, 6))
    for r in results:
        residuals = r['y_true'] - r['y_pred']
        sns.kdeplot(
            residuals,
            label=f"{r['stat']} (std={np.std(residuals):.2f})",
            fill=True,
            alpha=0.3,
        )

    plt.xlabel('Residual (True - Predicted)')
    plt.ylabel('Density')
    plt.title('Residual Distributions Across All Statistics')
    plt.legend()
    plt.tight_layout()
    out_path_resid = os.path.join(feature_dir, 'regression_residuals_all_stats.png')
    plt.savefig(out_path_resid, dpi=200)
    plt.show()
    plt.close()