In [4]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    precision_recall_curve, auc, roc_auc_score, classification_report,
    make_scorer, f1_score, recall_score, precision_score
)
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Suppress every warning for the rest of the run and switch seaborn to its
# dark-grid theme for all figures produced below.
warnings.filterwarnings(action="ignore")
sns.set(style="darkgrid")

# Read the dataset, then separate the 'Outcome' target column from the
# remaining columns, which are used as features.
data = pd.read_csv("no_transformer_data.csv")
y = data["Outcome"]
X = data.drop("Outcome", axis=1)

# Candidate classifiers, each paired with the hyperparameter grid explored by
# the grid search. Grid keys carry the "model__" prefix because the estimator
# is placed in a pipeline step named "model" further down.
def _model_grid(**options):
    """Return a parameter grid whose keys address the pipeline's 'model' step."""
    return {f'model__{name}': values for name, values in options.items()}


models = {}

models['KNN'] = (
    KNeighborsClassifier(),
    _model_grid(
        n_neighbors=[3, 5, 7, 9, 11, 15, 21],
        weights=['uniform', 'distance'],
        metric=['euclidean', 'manhattan'],
    ),
)

# probability=True is required so the fitted SVC exposes predict_proba.
models['SVM'] = (
    SVC(probability=True, class_weight='balanced', random_state=42),
    _model_grid(
        C=[0.1, 1, 10],
        kernel=['linear'],
    ),
)

models['LogReg'] = (
    LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42),
    _model_grid(
        C=[0.01, 0.1, 1, 10, 100],
        penalty=['l1', 'l2'],
        solver=['liblinear'],
    ),
)

models['RandomForest'] = (
    RandomForestClassifier(class_weight='balanced', random_state=42),
    _model_grid(
        n_estimators=[100, 200],
        max_depth=[10, 20, None],
        min_samples_split=[2, 5],
        min_samples_leaf=[1, 3],
    ),
)

# Scoring: metrics recorded for every grid-search candidate. The first three
# are computed from hard predictions with class 1 as the positive label.
scoring = {
    'f1': make_scorer(f1_score, pos_label=1),
    'recall': make_scorer(recall_score, pos_label=1),
    'precision': make_scorer(precision_score, pos_label=1),
    # Use scikit-learn's built-in ROC-AUC scorer by name. The previous
    # make_scorer(roc_auc_score, needs_proba=True) relied on a keyword that
    # was deprecated in scikit-learn 1.4 and removed in 1.6, where it raises
    # a TypeError; the named scorer works on old and new releases alike.
    'roc_auc': 'roc_auc'
}

# Nested cross-validation.
# Outer loop: 5 stratified, shuffled folds used to estimate precision-recall
# performance on data the tuned model has never seen.
# Inner loop: a 3-fold GridSearchCV on each outer training split selects the
# hyperparameters (the best candidate is chosen and refit by recall).
outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# model name -> (per-fold precision arrays, per-fold recall arrays,
#                per-fold PR-AUC values)
pr_curves = {}

# Run outer CV for each model
for model_name, (estimator, param_grid) in models.items():
    print(f"\n--- {model_name} ---")
    precision_list, recall_list, auc_list = [], [], []

    for fold, (train_idx, test_idx) in enumerate(outer_cv.split(X, y), 1):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        if model_name == 'KNN':
            # SMOTE is applied only for KNN (the other models are built with
            # class_weight='balanced'). It is a pipeline step rather than a
            # pre-processing pass so that it is re-fitted on each inner
            # training split only: resampling before the grid search would
            # put synthetic points, interpolated from training rows, into the
            # inner validation folds and bias hyperparameter selection.
            # It sits after the scaler so its neighbour search runs on
            # standardized features. imblearn's Pipeline skips samplers at
            # predict time, so test data is never resampled.
            pipeline = ImbPipeline([
                ('scaler', StandardScaler()),
                ('smote', SMOTE(random_state=fold)),
                ('model', estimator)
            ])
        else:
            pipeline = Pipeline([
                ('scaler', StandardScaler()),
                ('model', estimator)
            ])

        # GridSearchCV clones the pipeline for every candidate, so sharing
        # the same `estimator` object across folds does not leak fitted state.
        grid_search = GridSearchCV(
            pipeline, param_grid, cv=3, scoring=scoring,
            refit='recall', n_jobs=-1
        )
        grid_search.fit(X_train, y_train)
        best_model = grid_search.best_estimator_

        # Column 1 of predict_proba is taken as the positive-class score.
        y_pred_proba = best_model.predict_proba(X_test)[:, 1]

        precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
        # Area under the PR curve via trapezoidal integration over recall.
        pr_auc = auc(recall, precision)

        precision_list.append(precision)
        recall_list.append(recall)
        auc_list.append(pr_auc)

    pr_curves[model_name] = (precision_list, recall_list, auc_list)

# Plot separate PR curves per model: one figure per classifier, one line per
# outer fold, saved as PR_<model>.png in the output directory.
output_dir = Path('../../output')
# Create the target directory up front so savefig cannot fail with
# FileNotFoundError after the cross-validation has already been run.
output_dir.mkdir(parents=True, exist_ok=True)

for model_name, (precisions, recalls, aucs) in pr_curves.items():
    plt.figure(figsize=(8, 6))
    for fold_no, (fold_recall, fold_precision, fold_auc) in enumerate(
        zip(recalls, precisions, aucs), 1
    ):
        plt.plot(
            fold_recall, fold_precision,
            label=f'Fold {fold_no} (AUC = {fold_auc:.2f})'
        )

    plt.xlabel('Recall (Sensitivity)')
    plt.ylabel('Precision')
    # The fold count comes from the stored curves so the title stays correct
    # if the number of outer splits is changed.
    plt.title(
        f'Precision–Recall Curve — {model_name} ({len(precisions)}-fold CV)'
    )
    plt.legend(loc='lower left')
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(output_dir / f'PR_{model_name}.png', dpi=300)
    # Close the figure to release its memory before drawing the next one.
    plt.close()



--- KNN ---

--- SVM ---

--- LogReg ---

--- RandomForest ---
