# Predict Intervention Responsiveness

In [None]:
# Standard libraries
import os
import csv
import time
import warnings
from copy import deepcopy
import joblib 
from itertools import product, combinations, chain

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import shap

# Statistical analysis
from scipy.stats import norm, mannwhitneyu

# Machine learning models
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

# Preprocessing and model evaluation
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, roc_auc_score, f1_score, recall_score,
    precision_recall_curve, confusion_matrix, auc, precision_score,
    balanced_accuracy_score, matthews_corrcoef
)

# Model interpretation
from sklearn.inspection import PartialDependenceDisplay

# Progress bar
from tqdm import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')


In [None]:
# Global random seed for this run (an earlier run used 123).
SEED = 321

In [None]:
# Seed NumPy's global random number generator with the notebook-wide SEED.
np.random.seed(SEED)

## Pre-Processing

### Define threshold for responsiveness

Set the change threshold that classifies a participant as responsive vs. non-responsive.

In [None]:
# DEFINE RESPONSIVENESS
# Threshold on the average change in drinking occasions between active and
# control weeks (negative = reduction). The data-loading cell below only
# loads a responsiveness file for the value -1.
def_response_drink_occasions = -1

### Load data

In [None]:
# Directory that result files are written to.
output_dir = "../../results"

# Raw study data (filtered by condition in a later cell).
data_study1 = pd.read_csv('../../SHINE/osf_study1.csv')
data_study2 = pd.read_csv('../../SHINE/osf_study2.csv')

# Study 1 baseline data (train/val input)
b1_alcohol_self = pd.read_csv('../../SHINE/final_buckets/alcoholself_bucket280225.csv', index_col=0)
b2_group_subjective = pd.read_csv('../../SHINE/final_buckets/subjective_grouperceptions_280225.csv', index_col=0)
b3_group_sociometric = pd.read_csv('../../src/responsiveness/data_social.csv')
b4_brain = pd.read_csv('../../SHINE/final_buckets/brain_bucket_280225.csv', index_col=0)
b5_demographic = pd.read_csv('../../SHINE/final_buckets/demographic_bucket280225.csv', index_col=0)
b6_psychometric = pd.read_csv('../../SHINE/final_buckets/psychometrics_bucket280225.csv', index_col=0)

# Study 2 subjective data (test input)
# Repo-relative path (same folder as the Study 1 buckets) instead of a
# machine-specific absolute path, so the notebook runs on any checkout.
b2_group_subjective_study2 = pd.read_csv('../../SHINE/final_buckets/subjective_grouperceptions_test.csv')

# Study 1 & 2 drinking/responsiveness data (output -> prediction target)

if def_response_drink_occasions == -1:
    responsive_study1 = pd.read_csv('../../SHINE/final_buckets/responsiveness_study1.csv', index_col=0).reset_index()
# elif def_response_drink_occasions == -0.5:
#     responsive_study1 = pd.read_csv('../../SHINE/final_buckets/responsiveness_study1_-0.5.csv', index_col=0).reset_index()
# elif def_response_drink_occasions == -2:
#     responsive_study1 = pd.read_csv('../../SHINE/final_buckets/responsiveness_study1_-2.csv', index_col=0).reset_index()
else:
    # Fail here rather than with a NameError on `responsive_study1` further down.
    raise ValueError(
        f"No responsiveness file configured for def_response_drink_occasions="
        f"{def_response_drink_occasions!r}; only -1 is currently supported."
    )

responsive_study2 = pd.read_csv('../../SHINE/final_buckets/responsiveness_study2.csv', index_col=0).reset_index()

In [None]:
# Keep only the control-condition rows of each study.
data_study1_control = data_study1[data_study1.condition == 'control']
data_study2_control = data_study2[data_study2.condition == 'control']

# Print both counts explicitly: a notebook cell only echoes its last bare
# expression, so the Study 1 count was never shown.
print(f"Study 1 control rows: {len(data_study1_control)}")
print(f"Study 2 control rows: {len(data_study2_control)}")

In [None]:
# Check for duplicates within each DataFrame
# (True if any `id` value occurs more than once in that study's table).
duplicates_study1 = responsive_study1['id'].duplicated().any()
duplicates_study2 = responsive_study2['id'].duplicated().any()

print(f"Study 1 has duplicates: {duplicates_study1}")
print(f"Study 2 has duplicates: {duplicates_study2}")

# Check for overlapping IDs between the two studies
ids_study1 = set(responsive_study1['id'])
ids_study2 = set(responsive_study2['id'])
overlap = ids_study1.intersection(ids_study2)

print(f"Number of overlapping IDs: {len(overlap)}")
# Only list the IDs when there is at least one overlap.
if overlap:
    print(f"Overlapping IDs: {overlap}")


In [None]:
# Columns removed from both responsiveness tables before they are merged
# with the baseline feature buckets.
EXCLUDE_VARS = [
    'group', 'condition', 'active',
    'control', 'difference_drinks_occasions']

# In-place drops: re-running this cell without reloading the data raises a
# KeyError because the columns are already gone.
responsive_study1.drop(columns=EXCLUDE_VARS, inplace=True)
responsive_study2.drop(columns=EXCLUDE_VARS, inplace=True)

In [None]:
# Preview the first rows of the Study 2 responsiveness table after the column drop.
responsive_study2.head()

### Merge Baseline and Target Data

In [None]:
# Training datasets -> Study 1
# Each baseline bucket is inner-joined with the responsiveness table on `id`,
# so only participants present in both tables are kept.
b1_alcohol_self_response = pd.merge(b1_alcohol_self, responsive_study1, on='id', how='inner')
b2_group_subjective_response = pd.merge(b2_group_subjective, responsive_study1, on='id', how='inner')
# NOTE(review): this merges `responsive_study1` with itself and is not used in
# the cells shown below; the left-hand table was presumably meant to be an
# older subjective bucket -- confirm or remove.
b2_group_subjective_response_old = pd.merge(responsive_study1, responsive_study1, on='id', how='inner')
b3_group_sociometric_response = pd.merge(b3_group_sociometric, responsive_study1, on='id', how='inner')
b4_brain_response = pd.merge(b4_brain, responsive_study1, on='id', how='inner')
b5_demographic_response = pd.merge(b5_demographic, responsive_study1, on='id', how='inner')
b6_psychometric_response = pd.merge(b6_psychometric, responsive_study1, on='id', how='inner')

# Row count and number of distinct responsive participants in the first bucket.
print(f'Total IDs Study 1: {len(b1_alcohol_self_response)}')
print(f'Responsive IDs Study 1: {b1_alcohol_self_response[b1_alcohol_self_response["responsive"] == 1]["id"].nunique()}')
print('----------')
# Testing dataset -> Study 2
b2_group_subjective_test = pd.merge(b2_group_subjective_study2, responsive_study2, on='id', how='inner')
print(f'Total IDs Study 2: {len(b2_group_subjective_test)}')
print(f'Responsive IDs Study 2: {b2_group_subjective_test[b2_group_subjective_test["responsive"] == 1]["id"].nunique()}')

In [None]:
# Merged Study 1 tables keyed by a short bucket name.
dataframes = {
    'alc_self': b1_alcohol_self_response,
    'group_sub': b2_group_subjective_response,
    'group_socio': b3_group_sociometric_response,
    'brain': b4_brain_response,
    'demo': b5_demographic_response,
    'psych': b6_psychometric_response
}

# Print the per-column count of missing values for every bucket.
for key, df in dataframes.items():
    print(f"Missing values in '{key}':")
    print(df.isna().sum())
    print()  # for spacing between outputs

# Feature Selection

## Find highly correlated features within buckets
Identify redundant features, i.e. features that are highly correlated with one another within the same bucket.

In [None]:
# Bucket name -> merged Study 1 table. The values are references to the
# `b*_response` frames, not copies, so in-place edits made through this dict
# also change those frames.
dataframes = {
    'alc_self': b1_alcohol_self_response,
    'group_sub': b2_group_subjective_response,
    'group_socio': b3_group_sociometric_response,
    'brain': b4_brain_response,
    'demo': b5_demographic_response,
    'psych': b6_psychometric_response
}

In [None]:
# Show the column names of the merged subjective group-perception bucket.
b2_group_subjective_response.columns

In [None]:
# Name of the prediction-target column; read as a global by the functions below.
TARGET_VAR = 'responsive'

In [None]:
def find_highly_correlated_features(dataframes, threshold=0.8):
    """
    List feature pairs whose absolute correlation exceeds a threshold.

    The target column (global TARGET_VAR) and 'id' are left out of the
    correlation. Each unordered pair is reported once.

    :param dataframes: dict of {name: dataframe}
    :param threshold: absolute correlation a pair must exceed (strictly) to be reported
    :return: dict of {name: list of (feature_a, feature_b, abs_correlation) tuples}
    """
    pairs_by_name = {}

    for name, df in dataframes.items():
        # Everything except the target and the participant id is a feature.
        feature_cols = [c for c in df.columns if c != TARGET_VAR and c != 'id']

        abs_corr = df[feature_cols].corr().abs()
        labels = abs_corr.columns
        values = abs_corr.to_numpy()

        # Walk the strict upper triangle (row < column), column by column,
        # so every pair is visited exactly once.
        found = []
        for col_pos, col_label in enumerate(labels):
            for row_pos in range(col_pos):
                strength = values[row_pos, col_pos]
                if strength > threshold:
                    found.append((col_label, labels[row_pos], strength))

        pairs_by_name[name] = found

    return pairs_by_name

In [None]:
# Find feature pairs with an absolute within-bucket correlation above 0.8.
correlated_features = find_highly_correlated_features(dataframes, threshold=0.8)

# Display results
for name, pairs in correlated_features.items():
    print(f"\n{name} - Highly Correlated Features:")
    for col1, col2, corr_value in pairs:
        print(f"  {col1} ↔ {col2} : Correlation = {corr_value:.2f}")


Remove highly correlated features:

In [None]:
# Choice is made manually: from each highly correlated pair, one feature was
# picked by hand to be removed.
# The drops are in-place on the frames stored in `dataframes`, which are the
# same objects as the `b*_response` variables, so re-running this cell without
# rebuilding those frames raises a KeyError.

dataframes['brain'].drop(columns=['reward', 'ROI_alc_react_v_rest_neurosynth_cogcontrol', 'ROI_alc_react_v_rest_neurosynth_craving', \
                                  'ROI_alc_react_v_rest_neurosynth_emoreg'], inplace=True)

dataframes['group_socio'].drop(columns=['leaders_deg_in', 'goToBad_deg_in'], inplace=True)

dataframes['psych'].drop(columns=['ACS_focus', 'DERS_strategies', 'BIS_attention_total'], inplace=True)

In [None]:
# Check that all within-category correlations are gone
# (every bucket should now print an empty list of pairs).
correlated_features = find_highly_correlated_features(dataframes, threshold=0.8)

for name, pairs in correlated_features.items():
    print(f"\n{name} - Highly Correlated Features:")
    for col1, col2, corr_value in pairs:
        print(f"  {col1} ↔ {col2} : Correlation = {corr_value:.2f}")


In [None]:
# Number of columns per category. This is the full column count of each
# merged frame, so it still includes `id` and the target column.
{key: df.shape[1] for key, df in dataframes.items()}

## Significance tests: Features
### Mann-Whitney U Tests

The Mann-Whitney U test is a non-parametric hypothesis test that does not assume normally distributed data. It is used here to check which of the remaining features differ significantly between the two groups (responsive vs. non-responsive).

In [None]:
# Rebuild the bucket dict. It references the same `b*_response` frames as
# before, so the in-place column drops made earlier are still in effect.
dataframes = {
    'alc_self': b1_alcohol_self_response,
    'group_sub': b2_group_subjective_response,
    'group_socio': b3_group_sociometric_response,
    'brain': b4_brain_response,
    'demo': b5_demographic_response,
    'psych': b6_psychometric_response
}

In [None]:
def perform_mann_whitney_u(df, target_var, exclude_vars):
    """
    Run a two-sided Mann-Whitney U test per feature, comparing target == 0 vs target == 1.

    The input dataframe is not modified. Each column is coerced to numeric on
    a temporary copy (unparseable entries become NaN), and missing values are
    dropped per group before testing so that one NaN does not turn the whole
    result into NaN.

    :param df: dataframe holding the features and the target column
    :param target_var: name of the binary (0/1) target column
    :param exclude_vars: column name, or iterable of column names, to skip.
        A plain string is treated as ONE column name (not as a set of
        characters / substrings).
    :return: dict of {column: {'U_statistic': ..., 'p_value': ...}}, or
        {column: {'error': message}} when the test could not be computed
        for that column
    """
    if exclude_vars is None:
        excluded = set()
    elif isinstance(exclude_vars, str):
        # `col not in 'id'` would be a substring test and would also skip
        # columns named e.g. 'i' or 'd'.
        excluded = {exclude_vars}
    else:
        excluded = set(exclude_vars)
    excluded.add(target_var)

    is_group0 = df[target_var] == 0
    is_group1 = df[target_var] == 1

    results = {}
    for col in df.columns:
        if col in excluded:
            continue
        try:
            # Coerce on a temporary Series; never write back into `df`.
            values = pd.to_numeric(df[col], errors='coerce')
            group1 = values[is_group0].dropna()
            group2 = values[is_group1].dropna()
            stat, p_value = mannwhitneyu(group1, group2, alternative='two-sided')
            results[col] = {'U_statistic': stat, 'p_value': p_value}
        except Exception as e:
            # Deliberately best-effort: record the failure for this column
            # (e.g. an empty group) and carry on with the remaining ones.
            results[col] = {'error': str(e)}
    return results

In [None]:
# Run the Mann-Whitney U tests for every bucket except the demographic one.
mwu_results = {}
for name, df in dataframes.items():
    if name != 'demo' and TARGET_VAR in df.columns:
        mwu_results[name] = perform_mann_whitney_u(df, TARGET_VAR, 'id')

# Output summary
for name, results in mwu_results.items():
    print(f"\n{name} DataFrame Mann-Whitney U Test Results (p-value < 0.05):")
    
    # Ensure only variables with p-values < 0.05 are retained
    # (entries that only carry an 'error' key are skipped; p-values are uncorrected).
    significant_results = {}
    for var, stats in results.items():
        if isinstance(stats, dict) and 'p_value' in stats and stats['p_value'] < 0.05:
            significant_results[var] = stats  
    
    if significant_results:
        # One row per significant variable, columns U_statistic / p_value.
        df_significant = pd.DataFrame(significant_results).T  
        print(df_significant)
    else:
        print("No significant results (p-value < 0.05) found.")
    print("---------------")

# ML Models

## Train / Test Splits

In [None]:

def _impute_with_train_median(X_train, X_test, column):
    """Fill NaNs in `column` of the train (and, if present, test) frame with the TRAINING median."""
    if column not in X_train.columns:
        return
    fill_value = X_train[column].median()
    # Assign the filled column back instead of `X[col].fillna(..., inplace=True)`:
    # the chained in-place form acts on a temporary object and has no effect
    # under pandas copy-on-write.
    if X_train[column].isna().any():
        X_train[column] = X_train[column].fillna(fill_value)
    if isinstance(X_test, pd.DataFrame) and column in X_test.columns and X_test[column].isna().any():
        X_test[column] = X_test[column].fillna(fill_value)


def prepare_features_and_targets(df, test_set=0, resampling=None):
    """
    Split a merged dataframe into features/targets and optionally a stratified hold-out set.

    The target is the column named by the global TARGET_VAR; that column and
    'id' are removed from the features. Missing values in 'income_numeric'
    and 'IAS_mean' are imputed with the median of the training split, which
    is also applied to the test split so that no statistic is estimated from
    held-out data.

    :param df: merged dataframe containing TARGET_VAR (not modified)
    :param test_set: test fraction/size passed to train_test_split; a falsy
        value disables the split
    :param resampling: currently unused; accepted for interface compatibility
    :return: (X_train, Y_train, X_test, Y_test); X_test and Y_test are empty
        lists when no split is requested
    :raises ValueError: if TARGET_VAR is not a column of df
    """
    if TARGET_VAR not in df.columns:
        raise ValueError(f"Target variable '{TARGET_VAR}' not found in dataframe.")

    # Extract target variable and drop excluded columns ('id' may be absent).
    targets = df[TARGET_VAR]
    features = df.drop(columns=[TARGET_VAR, 'id'], errors='ignore')

    # Split into training and test sets (STRATIFIED)
    if test_set:
        X_train, X_test, Y_train, Y_test = train_test_split(
            features, targets, test_size=test_set, stratify=targets
        )
        # Own the split frames so the imputation below writes to real data.
        X_train = X_train.copy()
        X_test = X_test.copy()
    else:
        X_train, Y_train = features, targets
        X_test, Y_test = [], []

    # Median imputation for the columns known to contain missing values
    for column in ('income_numeric', 'IAS_mean'):
        _impute_with_train_median(X_train, X_test, column)

    # TODO: Handle all other missingness here
    return X_train, Y_train, X_test, Y_test

## Random Forest

In [None]:
def _build_classifier(model_type, params):
    """Create an unfitted classifier of the requested type from one grid point."""
    if model_type == 'rf':
        return RandomForestClassifier(
            n_estimators=params.get("n_estimators", 100),
            max_depth=params.get("max_depth"),
            min_samples_split=params.get("min_samples_split", 2),
            min_samples_leaf=params.get("min_samples_leaf", 1),
            class_weight="balanced"
        )
    if model_type == 'xgb':
        return XGBClassifier(
            n_estimators=params.get("n_estimators", 100),
            max_depth=params.get("max_depth", 6),
            learning_rate=params.get("learning_rate", 0.1),
            min_child_weight=params.get("min_child_weight", 1),
            gamma=params.get("gamma", 0),
            subsample=params.get("subsample", 1),
            colsample_bytree=params.get("colsample_bytree", 1),
            scale_pos_weight=params.get("scale_pos_weight", 1),
            use_label_encoder=False,
            eval_metric="logloss"
        )
    raise ValueError(f"Unknown model_type '{model_type}'; expected 'rf' or 'xgb'.")


def _score_fold(Y_test, Y_pred, Y_prob, eval_metric):
    """Compute every requested metric this function knows for one fitted model on one fold."""
    scores = {}

    if Y_prob is not None:
        if 'auc' in eval_metric:
            scores['auc'] = roc_auc_score(Y_test, Y_prob)
        if 'pr_auc' in eval_metric:
            precision, recall, _ = precision_recall_curve(Y_test, Y_prob)
            scores['pr_auc'] = auc(recall, precision)

    if 'f1' in eval_metric:
        scores['f1'] = f1_score(Y_test, Y_pred)
    if 'accuracy' in eval_metric:
        scores['accuracy'] = accuracy_score(Y_test, Y_pred)

    if any(name in eval_metric for name in ('specificity', 'sensitivity', 'PPV', 'NPV')):
        tn, fp, fn, tp = confusion_matrix(Y_test, Y_pred).ravel()
        if 'specificity' in eval_metric:
            scores['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else np.nan
        if 'sensitivity' in eval_metric:
            scores['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else np.nan
        if 'PPV' in eval_metric:
            scores['PPV'] = tp / (tp + fp) if (tp + fp) > 0 else np.nan
        if 'NPV' in eval_metric:
            scores['NPV'] = tn / (tn + fn) if (tn + fn) > 0 else np.nan

    # Accept both spellings: the test-set metrics in this file use 'MCC'.
    for mcc_key in ('mcc', 'MCC'):
        if mcc_key in eval_metric:
            scores[mcc_key] = matthews_corrcoef(Y_test, Y_pred)
    if 'balancedAcc' in eval_metric:
        scores['balancedAcc'] = balanced_accuracy_score(Y_test, Y_pred)

    return scores


def random_forest_kfold_grid_search(
    X, Y, param_grid, k=5, CV_reps=1, eval_metric=['auc'], model_choice_metric='auc', 
    res_dir=".", model_type='rf', combo='alcohol'
):
    """
    Exhaustive grid search with stratified k-fold cross-validation.

    For every parameter combination a fresh shuffled stratified k-fold split
    is drawn. On each fold the model is fitted CV_reps times; the metrics are
    averaged (mean) over the repetitions and then over the folds. The
    combination with the highest mean `model_choice_metric` wins.

    Recognised metric names: 'auc', 'pr_auc', 'f1', 'accuracy', 'specificity',
    'sensitivity', 'PPV', 'NPV', 'mcc'/'MCC', 'balancedAcc'. Any other name in
    `eval_metric` is reported as NaN.

    :param X: feature dataframe (indexed positionally via .iloc)
    :param Y: target series aligned with X by position
    :param param_grid: dict of {parameter name: list of candidate values}
    :param k: number of folds
    :param CV_reps: number of fits per fold
    :param eval_metric: list of metric names to compute (never mutated here)
    :param model_choice_metric: metric used to pick the best combination;
        must be listed in eval_metric
    :param res_dir: unused; kept for interface compatibility
    :param model_type: 'rf' (RandomForestClassifier) or 'xgb' (XGBClassifier)
    :param combo: unused; kept for interface compatibility
    :return: (best_model, best_scores, best_params). best_model is the model
        from the last fold/repetition of the winning combination, i.e. fitted
        on a single training fold only. All three are None if no combination
        produced a comparable (non-NaN) score.
    :raises ValueError: if model_choice_metric is not in eval_metric or
        model_type is unknown
    """
    if model_choice_metric not in eval_metric:
        # Fail before any model is trained instead of with a KeyError afterwards.
        raise ValueError(
            f"model_choice_metric '{model_choice_metric}' must be one of eval_metric: {list(eval_metric)}"
        )

    # Generate all parameter combinations
    param_names = list(param_grid.keys())
    param_combinations = list(product(*param_grid.values()))

    # Track the best combination according to the chosen metric
    best_model = None
    best_scores = None
    best_params = None
    best_model_choice_value = -np.inf

    kf = StratifiedKFold(n_splits=k, shuffle=True)

    for params in param_combinations:
        current_params = dict(zip(param_names, params))

        # One mean value per fold and metric
        all_folds_metrics = {metric: [] for metric in eval_metric}

        for train_index, test_index in kf.split(X, Y):  # k-fold cv split

            X_train, X_test = X.iloc[train_index], X.iloc[test_index]
            Y_train, Y_test = Y.iloc[train_index], Y.iloc[test_index]

            rep_metrics = {metric: [] for metric in eval_metric}  # Reset for each fold

            for _ in range(CV_reps):  # Repeat that split j times
                model = _build_classifier(model_type, current_params)
                model.fit(X_train, Y_train)
                Y_pred = model.predict(X_test)
                Y_prob = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None

                for metric, value in _score_fold(Y_test, Y_pred, Y_prob, eval_metric).items():
                    rep_metrics[metric].append(value)

            # Mean over the repetitions of this fold (NaN when nothing was computed)
            for metric in eval_metric:
                values = rep_metrics[metric]
                all_folds_metrics[metric].append(np.mean(values) if values else np.nan)

        # Mean over all folds
        mean_metrics = {metric: np.mean(values) for metric, values in all_folds_metrics.items()}

        # Select best model based on the mean of model_choice_metric
        if mean_metrics[model_choice_metric] > best_model_choice_value:
            best_model_choice_value = mean_metrics[model_choice_metric]
            best_model = model
            best_params = current_params
            best_scores = mean_metrics

    return best_model, best_scores, best_params


In [None]:
def save_metrics_to_csv(results_dict, results_dir, filename):
    """
    Write per-run metric values to a CSV file in long-by-run format.

    The file has the columns ``run, group, <metric...>`` (metrics sorted by
    name) and one row per run of each group. Metric lists of unequal length
    are padded with NaN up to the longest list of the group, and a metric a
    group does not have is written as NaN, so ragged input no longer raises
    or silently truncates.

    :param results_dict: dict of {group: {metric name: list of per-run values}}
    :param results_dir: output directory (created if missing)
    :param filename: name of the CSV file inside results_dir
    :return: None
    """
    os.makedirs(results_dir, exist_ok=True)

    # Define the output file path
    file_path = os.path.join(results_dir, filename)

    # Union of all metric names, sorted for a stable column order
    all_metrics = sorted({name for metrics in results_dict.values() for name in metrics})

    with open(file_path, mode="w", newline="") as file:
        writer = csv.writer(file)

        # Write header
        writer.writerow(["run", "group", *all_metrics])

        # Write data
        for group, metrics in results_dict.items():
            # Longest list defines the number of runs; 0 for a group without metrics.
            num_runs = max((len(values) for values in metrics.values()), default=0)
            for run_idx in range(num_runs):
                row = [run_idx, str(group)]
                for metric in all_metrics:
                    values = metrics.get(metric, ())
                    row.append(values[run_idx] if run_idx < len(values) else np.nan)
                writer.writerow(row)

In [None]:
def compute_test_metrics(Y_test, test_predictions, proba_predictions, test_scores):
    """
    Compute the hold-out metrics for one run and append them to `test_scores`.

    All metrics are computed first and appended afterwards, so an exception
    in any metric leaves `test_scores` untouched instead of with lists of
    unequal length.

    :param Y_test: true binary labels (Series, array or list)
    :param test_predictions: predicted class labels
    :param proba_predictions: predicted scores/probabilities for the positive class
    :param test_scores: dict of {metric name: list}; modified in place.
        A list is created for any metric name that is not yet a key.
    :return: the same `test_scores` dict
    """
    # np.asarray works for Series, arrays and lists alike; calling .ravel()
    # on a pandas Series directly is deprecated.
    Y_test_flat = np.asarray(Y_test).ravel()

    tn, fp, fn, tp = confusion_matrix(Y_test_flat, test_predictions).ravel()
    precision, recall, _ = precision_recall_curve(Y_test_flat, proba_predictions)

    run_values = {
        'auc': roc_auc_score(Y_test_flat, proba_predictions),
        'f1': f1_score(Y_test_flat, test_predictions),
        'accuracy': accuracy_score(Y_test_flat, test_predictions),
        # Ratios are NaN when their denominator is zero.
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else np.nan,
        'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else np.nan,
        'PPV': tp / (tp + fp) if (tp + fp) > 0 else np.nan,
        'NPV': tn / (tn + fn) if (tn + fn) > 0 else np.nan,
        'MCC': matthews_corrcoef(Y_test_flat, test_predictions),
        'balancedAcc': balanced_accuracy_score(Y_test_flat, test_predictions),
        'pr_auc': auc(recall, precision),
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp,
    }

    for metric, value in run_values.items():
        test_scores.setdefault(metric, []).append(value)

    return test_scores

In [None]:
def flatten_score_dict(score_dict, res_dir=None, filename=None):
    """
    Turn a nested score summary into a flat table with one row per combination.

    Each combination (a tuple) is spread over the columns ``Factor_1`` ...
    ``Factor_n``; every metric contributes ``<metric>_mean``,
    ``<metric>_CI_lower`` and ``<metric>_CI_upper``. The table is also written
    to ``{res_dir}/{filename}`` when both are given.

    :param score_dict: {combination: {metric: {'mean': ..., '95%_CI': (lower, upper)}}}
    :param res_dir: optional output directory
    :param filename: optional CSV file name
    :return: the flattened dataframe
    """
    records = []
    for combination, metrics in score_dict.items():
        record = {"Combination": combination}
        for metric, summary in metrics.items():
            record[f"{metric}_mean"] = summary["mean"]
            interval = summary["95%_CI"]
            record[f"{metric}_CI_lower"] = interval[0]
            record[f"{metric}_CI_upper"] = interval[1]
        records.append(record)

    table = pd.DataFrame(records)

    # One Factor_i column per position of the longest combination tuple.
    combination_col = table["Combination"]
    widest = combination_col.map(len).max()
    factor_names = [f"Factor_{position}" for position in range(1, widest + 1)]
    factors = pd.DataFrame(combination_col.tolist(), columns=factor_names)

    metrics_only = table.drop(columns="Combination")
    table = pd.concat([factors, metrics_only], axis=1)

    if res_dir and filename:
        table.to_csv(f"{res_dir}/{filename}", index=False)

    return table

In [None]:
def plot_shap_summary_with_percentages(all_shap_values, all_test_data, res_dir, combo):
    """
    Save a SHAP summary plot whose feature labels carry their relative importance.

    SHAP values and test data of all runs are stacked, the share of each
    feature in the total mean absolute SHAP value is computed, and that
    percentage is appended to the y-axis labels of the summary plot. The
    figure is written to
    ``{res_dir}/{combo}_shap_summary_plot_with_percentages.png``.

    :param all_shap_values: list of 2-D SHAP arrays, one per run
    :param all_test_data: list of dataframes matching all_shap_values
    :param res_dir: directory the PNG is written to
    :param combo: label used as the file-name prefix
    :return: names of the (up to) two features with the largest mean absolute
        SHAP value, most important first
    """
    # Preferred display names for selected variables
    pretty_names = {
        "avg_alcmost": "Peer Perception: Drinking Amount",
        "groupAtt_alc": "Peer Attitudes: Alcohol",
        "avg_alcmost_freq": "Peer Perception: Drinking Frequency",
        "alc_norm_5_r": "Perceived Peer Pressure",
        "groupAtt_binge": "Peer Attitudes: Binges"
    }

    # Stack the runs into one SHAP matrix / one test frame
    stacked_shap = np.vstack(all_shap_values)
    stacked_data = pd.concat(all_test_data, ignore_index=True)

    # Share of each feature in the summed mean |SHAP|, in percent
    mean_abs_shap = np.abs(stacked_shap).mean(axis=0)
    share_pct = 100 * mean_abs_shap / mean_abs_shap.sum()

    # Draw the summary plot without displaying it
    plt.figure()
    shap.summary_plot(stacked_shap, stacked_data, show=False, cmap='winter')

    # Read back the feature names in the order SHAP put them on the y-axis
    ax = plt.gca()
    plotted_features = [label.get_text() for label in ax.get_yticklabels()]

    # Build "<display name> (<share>%)" labels; the share is looked up via the
    # feature's column position in the stacked data
    columns = stacked_data.columns
    new_labels = []
    for raw_name in plotted_features:
        shown_name = pretty_names.get(raw_name, raw_name)
        position = columns.get_loc(raw_name)
        new_labels.append(f"{shown_name} ({share_pct[position]:.1f}%)")
    ax.set_yticklabels(new_labels, fontsize=10)

    # Save updated plot
    plt.tight_layout()
    plt.savefig(f"{res_dir}/{combo}_shap_summary_plot_with_percentages.png", dpi=300, bbox_inches="tight")
    plt.close()

    # Two most important features, largest mean |SHAP| first
    ranked_desc = np.argsort(mean_abs_shap)[::-1]
    return columns[ranked_desc[:2]].tolist()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.inspection import PartialDependenceDisplay
from scipy.interpolate import interp1d

def plot_pdp_across_runs(best_model, res_dir, all_test_data, feature_names=None, interaction_pair=None, colors=None, title=None):
    """
    Plots PDPs with mean and std across multiple test sets for each feature.
    Optionally adds an interaction plot.

    For every feature the partial dependence curve is computed once per test
    set, interpolated onto a common 100-point grid and summarised as
    mean +/- one standard deviation. The figure is saved as a PNG in res_dir
    and then closed.

    Parameters:
        best_model: trained model
        res_dir: directory the PNG is written to
        all_test_data: list of pd.DataFrames used for PDP evaluation
        feature_names: list of features to plot (default: all features in data)
        interaction_pair: tuple of two features to plot interaction PDP
            (computed on the concatenation of all test sets)
        colors: optional color list (at least three entries are used)
        title: optional file-name suffix; when given the file is
            ``pdp_plots{title}.png`` and the interaction suffix is not added

    Raises:
        ValueError: if interaction_pair is given but does not hold exactly two features
    """
    if interaction_pair and len(interaction_pair) != 2:
        raise ValueError(f"interaction_pair must contain exactly two features, got {interaction_pair!r}")

    if colors is None:
        colors = ["#22223B", "#4A4E69", "#9A8C98", "#C9ADA7", "#F2E9E4"]

    final_test_data = pd.concat(all_test_data, ignore_index=True)

    if feature_names is None:
        feature_names = final_test_data.columns.tolist()

    # Optional: map feature names to preferred display names
    name_mapping = {
        "avg_alcmost": "Peer Drinking Amount",
        "groupAtt_alc": "Peer Attitudes: Alcohol",
        "avg_alcmost_freq": "Peer Drinking Frequency",
        "alc_norm_5_r": "Perceived Peer Pressure",
        "groupAtt_binge": "Peer Attitudes: Binges"
    }
    display_names = [name_mapping.get(name, name) for name in feature_names]

    num_features = len(feature_names)
    num_plots = num_features + (1 if interaction_pair else 0)
    num_cols = 3
    num_rows = -(-num_plots // num_cols)  # ceiling division

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 5, num_rows * 4))
    axes = np.atleast_1d(axes).flatten()

    for idx, feature_name in enumerate(feature_names):
        pdp_values = []
        feature_values_list = []

        for dat in all_test_data:
            # The display is drawn on a throw-away figure only to read the
            # curve back; always close it, even if the PDP computation fails.
            dummy_fig = plt.figure()
            try:
                ax_dummy = dummy_fig.add_subplot()
                ax_dummy.set_visible(False)
                pdp_display = PartialDependenceDisplay.from_estimator(best_model, dat, [feature_name], ax=ax_dummy)
            finally:
                plt.close(dummy_fig)

            curve = pdp_display.lines_[0][0]
            feature_values_list.append(curve.get_xdata())
            pdp_values.append(curve.get_ydata())

        # Common grid spanning the value range seen across all test sets
        common_feature_values = np.linspace(min(map(min, feature_values_list)),
                                            max(map(max, feature_values_list)), num=100)

        # Curves are extrapolated linearly where a test set's own grid is narrower.
        interpolated_pdp_values = [
            interp1d(x_vals, y_vals, kind="linear", fill_value="extrapolate")(common_feature_values)
            for x_vals, y_vals in zip(feature_values_list, pdp_values)
        ]

        pdp_values = np.array(interpolated_pdp_values)
        pdp_mean = np.mean(pdp_values, axis=0)
        pdp_std = np.std(pdp_values, axis=0)

        ax = axes[idx]
        ax.plot(common_feature_values, pdp_mean, label="Mean PDP", color=colors[0], lw=2)
        ax.fill_between(common_feature_values, pdp_mean - pdp_std, pdp_mean + pdp_std,
                        color=colors[2], alpha=0.5, label="Std Dev")
        ax.set_ylabel("Predicted Value")
        ax.set_title(f"PDP for {display_names[idx]}")
        ax.legend()

    if interaction_pair:
        ax = axes[num_features]
        PartialDependenceDisplay.from_estimator(best_model, final_test_data,
                                                [interaction_pair], ax=ax)
        name_a = name_mapping.get(interaction_pair[0], interaction_pair[0])
        name_b = name_mapping.get(interaction_pair[1], interaction_pair[1])
        ax.set_title(f"Interaction: {name_a} & {name_b}")

    # Remove the unused cells of the subplot grid
    for i in range(num_plots, len(axes)):
        fig.delaxes(axes[i])

    # Build filename
    interaction_suffix = f"_{interaction_pair[0]}_{interaction_pair[1]}" if interaction_pair else ""
    if not title:
        filename = f"{res_dir}/pdp_plots{interaction_suffix}.png"
    else:
        filename = f"{res_dir}/pdp_plots{title}.png"

    # Lay out, save and close via the explicit figure handle rather than
    # pyplot's "current figure", which many figures were opened and closed under.
    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches="tight")
    plt.close(fig)


In [None]:
# Initialise SHAP's JavaScript support for interactive plots in the notebook.
shap.initjs()

def _positive_class_shap(shap_values):
    """Return the 2-D (samples x features) SHAP matrix for the positive class.

    Handles the output layouts a tree explainer may return for a binary
    classifier: a list with one array per class, a 3-D array with a trailing
    class axis, or an already 2-D array.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[1] if len(shap_values) > 1 else shap_values[0])
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        return shap_values[:, :, 1]
    return shap_values


def _mean_with_ci(scores, z):
    """Summarise scores as mean and normal-approximation confidence interval (sample std, ddof=1)."""
    mean_score = np.mean(scores)
    std_error = np.std(scores, ddof=1) / np.sqrt(len(scores))
    return {
        'mean': mean_score,
        '95%_CI': (mean_score - z * std_error, mean_score + z * std_error)
    }


def run_rf_train_test(dataframes, param_grid, eval_metrics, outer_reps=50, k=5, CV_reps=5, model_choice_metric='f1', 
                      res_dir="./results/", model_type='xgb', test_set=0.3, resampling=None, permutation=False):
    """
    Repeated train/test evaluation for every single bucket and every pair of buckets.

    For each combination the buckets are inner-joined on ('id', TARGET_VAR).
    Then, `outer_reps` times: a stratified train/test split is drawn, the
    hyper-parameters are chosen by cross-validated grid search on the training
    part, the selected model is refitted on the full training part and scored
    on the test part, and SHAP values of the test part are collected.

    Written to a time-stamped sub-directory of `res_dir`: the model of the
    repetition with the best validation score per combination, a SHAP summary
    plot, PDP plots, per-run and summarised validation/test score CSVs, and --
    for the ('group_sub',) combination -- the ten models with the best
    validation scores.

    :param dataframes: dict of {bucket name: dataframe with 'id' and TARGET_VAR}
    :param param_grid: dict of {parameter name: candidate values}
    :param eval_metrics: metric names to record; must include model_choice_metric
    :param outer_reps: number of train/test repetitions per combination
    :param k: number of CV folds in the grid search
    :param CV_reps: number of fits per CV fold
    :param model_choice_metric: metric used for model selection
    :param res_dir: parent directory for the results
    :param model_type: 'rf' or 'xgb'
    :param test_set: hold-out fraction/size; must be non-zero
    :param resampling: forwarded to prepare_features_and_targets
    :param permutation: if True, training and test labels are shuffled before
        the grid search (label-permutation baseline)
    :return: None
    :raises ValueError: if test_set is falsy or TARGET_VAR is missing after merging
    """
    if not test_set:
        raise ValueError("test_set must be non-zero: the models are evaluated on a hold-out split.")

    timestamp = int(time.time())
    res_dir = f"{res_dir}/{timestamp}_{model_type}_outer{outer_reps}_cvrep{CV_reps}_k{k}_{model_choice_metric}_testsize{test_set}_resampling{resampling}_perm{permutation}/"
    os.makedirs(res_dir, exist_ok=True)

    keys = list(dataframes.keys())

    # combine data categories: all single buckets and all pairs
    combinations_keys = list(chain.from_iterable(combinations(keys, r) for r in range(1, 3)))

    z = norm.ppf(0.975)  # 95% confidence level

    combo_validation_scores = {}
    combo_test_scores = {}
    all_val_scores = {}
    all_test_scores = {}

    for combo in tqdm(combinations_keys):
        validation_scores = {metric: [] for metric in eval_metrics}
        test_scores = {metric: [] for metric in eval_metrics}

        merged_df = dataframes[combo[0]].copy()
        for key in combo[1:]:
            merged_df = merged_df.merge(dataframes[key].copy(), how='inner', on=['id', TARGET_VAR])
        if TARGET_VAR not in merged_df.columns:
            raise ValueError(f"Target variable '{TARGET_VAR}' not found in merged dataframe for combo: {combo}")

        all_shap_values = []
        all_test_data = []
        best_overall_score = -np.inf
        best_model_for_combo = None
        top_models_group_sub = []  # (validation score, model) per repetition

        for _ in range(outer_reps):  # i repetitions of train/test

            # Prepare train/test split for this i (random & stratified)
            X_data, Y_data, X_test, Y_test = prepare_features_and_targets(merged_df.copy(), test_set=test_set, resampling=resampling)

            # Shuffle labels for permutation tests
            if permutation:
                Y_data = Y_data.sample(frac=1, random_state=None).reset_index(drop=True)
                Y_test = Y_test.sample(frac=1, random_state=None).reset_index(drop=True)

            # The grid search runs in both modes; it used to sit in the `else`
            # branch of the permutation check, leaving `best_scores` undefined
            # for permutation runs.
            best_model, best_scores, best_params = random_forest_kfold_grid_search(
                X_data, Y_data,
                param_grid, k=k,
                CV_reps=CV_reps,
                eval_metric=eval_metrics,
                model_choice_metric=model_choice_metric,
                res_dir=res_dir,
                model_type=model_type,
                combo=combo)

            # Collect metrics
            for metric, score in best_scores.items():
                validation_scores[metric].append(score)

            # Retrain the best model on the full training dataset and evaluate on the test set
            best_model.fit(X_data, Y_data)
            test_predictions = best_model.predict(X_test)
            proba_predictions = best_model.predict_proba(X_test)[:, 1]

            explainer = shap.TreeExplainer(best_model)
            shap_values = _positive_class_shap(explainer.shap_values(X_test))

            # Append SHAP values and test data for later aggregation
            all_shap_values.append(shap_values)
            all_test_data.append(pd.DataFrame(X_test))

            if best_scores[model_choice_metric] > best_overall_score:
                best_overall_score = best_scores[model_choice_metric]
                best_model_for_combo = best_model

            if combo == ('group_sub',):
                top_models_group_sub.append((best_scores[model_choice_metric], deepcopy(best_model)))

            # Calculate and append metrics for the test set
            test_scores = compute_test_metrics(Y_test, test_predictions, proba_predictions, test_scores)

        # Persist the model of the repetition with the best validation score
        joblib.dump(best_model_for_combo, f"{res_dir}/model_{'_'.join(combo)}.joblib")

        # Save top 10 models for group_sub combo
        if combo == ('group_sub',):
            top_models_group_sub.sort(key=lambda x: x[0], reverse=True)

            subdir = os.path.join(res_dir, "top10_group_sub_models")
            os.makedirs(subdir, exist_ok=True)

            for i, (score, model) in enumerate(top_models_group_sub[:10]):
                joblib.dump(model, f"{subdir}/model_rank{i+1}_score{score:.4f}.joblib")

        top2_features = plot_shap_summary_with_percentages(all_shap_values, all_test_data, res_dir, combo)

        plot_pdp_across_runs(
            best_model=best_model_for_combo,
            res_dir=res_dir,
            all_test_data=all_test_data,
            # An interaction plot needs two features; skip it for single-feature data.
            interaction_pair=tuple(top2_features) if len(top2_features) == 2 else None
        )

        # Mean and 95% CI for validation scores
        combo_validation_scores[combo] = {
            metric: _mean_with_ci(scores, z) for metric, scores in validation_scores.items()
        }
        all_val_scores[combo] = validation_scores
        save_metrics_to_csv(all_val_scores, res_dir, 'all_val_scores.csv')

        # Mean and 95% CI for test scores
        combo_test_scores[combo] = {
            metric: _mean_with_ci(scores, z) for metric, scores in test_scores.items()
        }
        all_test_scores[combo] = test_scores
        save_metrics_to_csv(all_test_scores, res_dir, 'all_test_scores.csv')

        # Rewrite the summary tables after every combination so partial
        # results survive an interrupted run.
        flatten_score_dict(combo_validation_scores, res_dir=res_dir, filename="validation_scores.csv")
        flatten_score_dict(combo_test_scores, res_dir=res_dir, filename="test_scores.csv")

    return

In [None]:
def test_oos(test_df, res_dir, best_model, best_params, plot=True):
    """Score a fitted classifier on an out-of-sample dataframe.

    Parameters
    ----------
    test_df : pandas.DataFrame
        Must contain the ``TARGET_VAR`` column (labels) and an ``id`` column;
        all remaining columns are passed to the model as features.
    res_dir : str
        Output directory forwarded to the plotting helpers (used only when
        ``plot`` is true).
    best_model : fitted classifier
        Must expose ``predict`` and ``predict_proba``. When ``plot`` is true
        it is also handed to ``shap.TreeExplainer``.
    best_params : object
        Returned unchanged as the second element of the result.
    plot : bool, default True
        When true, compute SHAP values and draw the SHAP summary and PDP plots.

    Returns
    -------
    tuple
        ``(scores, best_params)`` where ``scores`` maps metric names
        ('auc', 'f1', ..., 'pr_auc', 'tn', 'fn', 'tp', 'fp') to values.
    """
    # Separate features from the label and the participant identifier
    test_features = test_df.drop(columns=[TARGET_VAR, 'id'])
    test_labels = test_df[TARGET_VAR]

    # Make predictions
    y_pred_proba = best_model.predict_proba(test_features)[:, 1]  # Probabilities for the positive class
    y_pred = best_model.predict(test_features)

    # Compute metrics
    tn, fp, fn, tp = confusion_matrix(test_labels, y_pred).ravel()

    # Area under the precision-recall curve. The previous implementation used
    # roc_auc_score here, which made 'pr_auc' a duplicate of 'auc'.
    pr_precision, pr_recall, _ = precision_recall_curve(test_labels, y_pred_proba)

    scores = {
        'auc': roc_auc_score(test_labels, y_pred_proba),
        'f1': f1_score(test_labels, y_pred),
        'accuracy': accuracy_score(test_labels, y_pred),
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else np.nan,
        'sensitivity': recall_score(test_labels, y_pred),
        'PPV': precision_score(test_labels, y_pred),
        'NPV': tn / (tn + fn) if (tn + fn) > 0 else np.nan,
        'MCC': matthews_corrcoef(test_labels, y_pred),
        'balancedAcc': balanced_accuracy_score(test_labels, y_pred),
        'pr_auc': auc(pr_recall, pr_precision),
        'tn': tn,
        'fn': fn,
        'tp': tp,
        'fp': fp
    }

    if plot:
        # SHAP values are only consumed by the plots, so they are computed
        # here rather than on every call (repeated plot=False evaluation would
        # otherwise pay for an explanation it never uses).
        explainer = shap.TreeExplainer(best_model)
        shap_values = explainer.shap_values(test_features)

        # Keep the positive-class slice. Depending on the shap version/model,
        # the result is a list with one matrix per class, a 3-D array
        # (samples, features, classes), or already a 2-D matrix.
        if isinstance(shap_values, list):
            shap_values = shap_values[1]
        elif np.ndim(shap_values) == 3:
            shap_values = shap_values[:, :, 1]

        plot_shap_summary_with_percentages(
            all_shap_values=[shap_values],
            all_test_data=[pd.DataFrame(test_features)],
            res_dir=res_dir,  # or pass a dynamic path
            combo="test_oos"
        )
        plot_pdp_across_runs(
            best_model=best_model,
            res_dir=res_dir,
            all_test_data=[pd.DataFrame(test_features)],
            interaction_pair=("avg_alcmost_freq", "avg_alcmost"),
            title="study_2"
        )

    return scores, best_params


## Run Analysis

In [None]:
# Feature-bucket name -> dataframe; this mapping is handed to
# run_rf_train_test as its `dataframes` argument.
dataframes = dict(
    demo=b5_demographic_response,
    alc_self=b1_alcohol_self_response,
    psych=b6_psychometric_response,
    group_sub=b2_group_subjective_response,
    group_socio=b3_group_sociometric_response,
    brain=b4_brain_response,
)

In [None]:
# Hyperparameter grid handed to run_rf_train_test as `param_grid`
param_grid = dict(
    n_estimators=[50],
    max_depth=[3, 5],
    min_samples_split=[2, 4, 8],
    min_samples_leaf=[2, 3, 5],
)

# Metric names handed to run_rf_train_test as `eval_metrics`:
# threshold/ranking metrics first, then the raw confusion-matrix counts
eval_metrics = [
    'auc', 'f1', 'accuracy', 'specificity', 'sensitivity',
    'PPV', 'NPV', 'MCC', 'balancedAcc', 'pr_auc',
] + ['tn', 'fn', 'tp', 'fp']

#### 3-fold CV

##### Normal Run

In [None]:
# Main ("normal") run: random-forest model, k=3, labels not permuted.
run_rf_train_test(
    # inputs
    dataframes=dataframes,
    param_grid=param_grid,
    eval_metrics=eval_metrics,
    # model and selection criterion
    model_type='rf',
    model_choice_metric='auc',
    # repetition / splitting settings
    outer_reps=100,
    CV_reps=5,
    k=3,
    test_set=0.3,
    # optional procedures (both disabled here)
    resampling=None,
    permutation=False,
    # output location
    res_dir="./results_finalbuckets/",
)

In [None]:
# # sensitivities
# run_rf_train_test(
#     dataframes=dataframes,
#     param_grid=param_grid,
#     eval_metrics=eval_metrics,
#     outer_reps=100,
#     k=5,
#     CV_reps=5,
#     model_choice_metric='auc',
#     res_dir="./results_finalbuckets/",
#     model_type='rf',
#     test_set=0.3,
#     resampling=None,
#     permutation=False
# )

In [None]:
# run_rf_train_test(
#     dataframes=dataframes,
#     param_grid=param_grid,
#     eval_metrics=eval_metrics,
#     outer_reps=100,
#     k=3,
#     CV_reps=5,
#     model_choice_metric='auc',
#     res_dir="./results_finalbuckets/",
#     model_type='rf',
#     test_set=0.4,
#     resampling=None,
#     permutation=False
# )

##### Permutation Test

In [None]:
# run_rf_train_test(
#     dataframes=dataframes,
#     param_grid=param_grid,
#     eval_metrics=eval_metrics,
#     outer_reps=100,
#     k=3,
#     CV_reps=5,
#     model_choice_metric='auc',
#     res_dir="./results_finalbuckets/",
#     model_type='rf',
#     test_set=0.3,
#     resampling=None,
#     permutation=True
# )

##### Out-of-sample Test

In [None]:
# Out-of-sample testing: load the saved group_sub model from the run
# directory and score it on b2_group_subjective_test.
res_dir = './results_finalbuckets/1747686299_rf_outer100_cvrep5_k3_auc_testsize0.3_resamplingNone_permFalse'
loaded_model = joblib.load(f"{res_dir}/model_group_sub.joblib")
scores, best_params = test_oos(
    test_df=b2_group_subjective_test,
    res_dir=res_dir,
    best_model=loaded_model,
    best_params=None,
    plot=True,
)
scores

In [None]:
import os
import joblib
import pandas as pd
import numpy as np
import scipy.stats as st
from sklearn.utils import resample
from tqdm import tqdm
import glob

def evaluate_top_models(res_dir, test_df, target_col='responsive', group_name="('group_sub',)",
                        top_n=10, n_iterations=10, desired_positive_rate=0.24, plot=False,
                        model_subdir="top10_group_sub_models"):
    """
    Loads the top N saved models from a result directory, repeatedly resamples
    the test set to a desired positive rate, and evaluates each model on every
    resample with ``test_oos``.

    Parameters:
        res_dir: Run directory; models are read from ``res_dir/model_subdir``.
        test_df: Test dataframe. Rows with any missing value are dropped.
        target_col: Column used to split positives (== 1) from negatives (== 0).
            Note that ``test_oos`` reads its labels from the module-level
            ``TARGET_VAR``, so the two should name the same column.
        group_name: Not used by the current implementation (models are found
            via their ``model_rank{i}_*.joblib`` file names); kept so existing
            callers keep working.
        top_n: Ranks 1..top_n are looked up; ranks with no file are skipped.
        n_iterations: Number of resamples evaluated per model.
        desired_positive_rate: Fraction of positives in each resample.
        plot: Forwarded to ``test_oos``.
        model_subdir: Sub-directory of ``res_dir`` holding the ranked models.

    Returns:
        summary_df: DataFrame with mean and 95% CI for each metric.
        all_scores_df: Raw scores from each resampling.

    Raises:
        ValueError: If no model file is found, or if a class that must be
            sampled has no rows in ``test_df``.
    """
    # Load the ranked models; sorting makes the pick deterministic if several
    # files share a rank prefix (glob order is otherwise arbitrary).
    model_dir = os.path.join(res_dir, model_subdir)
    top_models = []
    for rank in range(1, top_n + 1):
        pattern = os.path.join(model_dir, f"model_rank{rank}_*.joblib")
        matched_files = sorted(glob.glob(pattern))
        if matched_files:
            top_models.append(joblib.load(matched_files[0]))

    if not top_models:
        raise ValueError(f"No models found in {model_dir}")

    # Drop missing values
    test_df = test_df.dropna()

    # Split positives and negatives
    positive_cases = test_df[test_df[target_col] == 1]
    negative_cases = test_df[test_df[target_col] == 0]

    # Sample sizes are the same for every resample, so compute them once
    total_samples = len(test_df)
    n_pos = int(total_samples * desired_positive_rate)
    n_neg = total_samples - n_pos

    if n_pos > 0 and positive_cases.empty:
        raise ValueError(f"No positive cases ({target_col} == 1) available to resample")
    if n_neg > 0 and negative_cases.empty:
        raise ValueError(f"No negative cases ({target_col} == 0) available to resample")

    # Negatives are drawn without replacement whenever enough are available
    # (the original behavior). If the requested count exceeds the pool, which
    # happens when the desired positive rate is below the observed one, fall
    # back to sampling with replacement instead of failing.
    replace_negatives = n_neg > len(negative_cases)

    all_scores = []

    for model in tqdm(top_models, desc="Evaluating top models"):
        for _ in range(n_iterations):
            # Stratified resampling (random_state=None draws from NumPy's
            # global RNG, so seeding np.random beforehand fixes the sequence)
            pos_sample = resample(positive_cases, replace=True, n_samples=n_pos, random_state=None)
            neg_sample = resample(negative_cases, replace=replace_negatives, n_samples=n_neg, random_state=None)

            balanced_df = pd.concat([pos_sample, neg_sample]).sample(frac=1).reset_index(drop=True)

            scores, _ = test_oos(balanced_df, res_dir, model, [], plot=plot)
            all_scores.append(scores)

    scores_df = pd.DataFrame(all_scores)

    # Compute summary: t-based 95% CI around the mean of each metric
    mean_scores = scores_df.mean()
    ci_lower, ci_upper = st.t.interval(0.95, df=len(scores_df)-1, loc=mean_scores, scale=scores_df.sem())

    summary_df = pd.DataFrame({
        'Mean': mean_scores,
        '95% CI Lower': ci_lower,
        '95% CI Upper': ci_upper
    })

    return summary_df, scores_df


In [None]:
# Re-seed NumPy's global RNG so the resampling inside evaluate_top_models
# is repeatable.
np.random.seed(SEED)

# Relative run directory, matching the other paths in this notebook, instead
# of a user-specific absolute path that only resolves on one machine.
summary_df, all_scores_df = evaluate_top_models(
    res_dir="./results_finalbuckets/1747686299_rf_outer100_cvrep5_k3_auc_testsize0.3_resamplingNone_permFalse",
    test_df=b2_group_subjective_test,
    top_n=1,
    n_iterations=100,
    desired_positive_rate=0.23,
    plot=False
)

In [None]:
# Display the per-metric mean and 95% CI table returned by evaluate_top_models
summary_df

### Sensitivity Analyses

In [None]:
# # sensitivity 
# run_rf_train_test(
#     dataframes=dataframes,
#     param_grid=param_grid,
#     eval_metrics=eval_metrics,
#     outer_reps=100,
#     k=3,
#     CV_reps=5,
#     model_choice_metric='auc',
#     res_dir="./results_finalbuckets/",
#     model_type='rf',
#     test_set=0.4,
#     resampling=None,
#     permutation=False
# )

# run_rf_train_test(
#     dataframes=dataframes,
#     param_grid=param_grid,
#     eval_metrics=eval_metrics,
#     outer_reps=100,
#     k=5,
#     CV_reps=5,
#     model_choice_metric='auc',
#     res_dir="./results_finalbuckets/",
#     model_type='rf',
#     test_set=0.3,
#     resampling=None,
#     permutation=False
# )