# # Notebook 2: Patient-Wise Validation with Group Cross-Validation
#
# ## Goals
# * Understand the critical concept of **Information Leakage** in medical datasets.
# * Learn why standard K-Fold/Stratified K-Fold fail when multiple samples exist per patient.
# * Implement **Patient-Wise (Group) Cross-Validation** using `GroupKFold`.
# * Understand and implement `LeaveOneGroupOut` (LOGO).
# * Demonstrate how `GroupKFold` prevents patient overlap between training and validation folds.
#
# **This is arguably one of the most important validation concepts for medical AI.**

# ## 1. The Problem: Information Leakage from Non-Independent Samples
#
# In many medical datasets, we have multiple data points (samples) originating from the same individual patient or subject. Examples:
# *   Multiple MRI slices from the same patient scan.
# *   Several ECG readings over time from one person.
# *   Multiple pathology slide images from a single biopsy.
# *   Longitudinal data from Electronic Health Records (EHRs) for one patient.
#
# These samples are **not statistically independent**. Samples from the same patient are likely to be more similar to each other than to samples from different patients due to shared underlying biology, genetics, lifestyle, etc.
#
# **What happens if we use standard K-Fold or Stratified K-Fold?**
# These methods split data at the *sample* level. It's highly likely that for a given fold, some samples from Patient A will end up in the training set, while other samples from the *same* Patient A end up in the validation set.
#
# **This is Information Leakage:** The model learns patient-specific features during training. When it evaluates samples from the same patient in the validation set, it performs well not necessarily because it learned a generalizable biological pattern, but because it recognizes the patient it saw in training.
#
# **Result:** The validation score becomes **artificially inflated**, giving a false sense of high performance. The model may fail drastically when deployed on *new, unseen patients*.
#
# **Goal:** We typically want to estimate how well the model generalizes to **new patients**, not just new samples from patients already seen.
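#
# To make the sample-level splitting problem concrete, the sketch below counts how many validation patients also appear in the training set when plain `KFold` is used. The toy arrays (`patient_ids`, `X_toy`) are hypothetical and exist only for this illustration; they are not the notebook's development data.

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical toy data: 20 patients with 5 samples each (100 samples total).
n_patients, samples_per_patient = 20, 5
patient_ids = np.repeat(np.arange(n_patients), samples_per_patient)
X_toy = np.zeros((len(patient_ids), 1))  # features are irrelevant for this check

kf = KFold(n_splits=5, shuffle=True, random_state=0)
shared_counts = []
for fold, (train_idx, val_idx) in enumerate(kf.split(X_toy), start=1):
    # Patients whose samples appear on both sides of this split
    shared = np.intersect1d(patient_ids[train_idx], patient_ids[val_idx])
    shared_counts.append(len(shared))
    n_val_patients = len(np.unique(patient_ids[val_idx]))
    print(f"Fold {fold}: {len(shared)} of {n_val_patients} "
          f"validation patients also appear in training")
```

# Every shared patient is a potential leakage path: the model is validated on someone it has already seen.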

# ## 2. The Solution: Group (Patient-Wise) Cross-Validation
#
# Group CV techniques ensure that all samples belonging to the same group (e.g., patient) are kept together within the *same* split (either all in training or all in validation for a given fold).
#
# *   **`GroupKFold`**: Partitions *groups* into K folds. Each group is assigned to exactly one validation fold across the K iterations.
# *   **`LeaveOneGroupOut` (LOGO)**: Iterates through each unique group, holding out that entire group for validation and training on all other groups. (Equivalent to `GroupKFold` where `n_splits` equals the number of unique groups).
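#
# The relationship between the two splitters can be checked on a small example. This is a minimal sketch on hypothetical toy data (`groups_toy`, `X_toy`, and the helper `held_out_groups` are illustrative names, not part of the notebook): with `n_splits` equal to the number of patients, both splitters should hold out exactly one patient per fold.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut

# Hypothetical toy data: 4 patients with unequal numbers of samples.
groups_toy = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 3])
X_toy = np.zeros((len(groups_toy), 1))

def held_out_groups(cv):
    """Return, for each split, the set of patient IDs in the validation fold."""
    return [
        frozenset(groups_toy[val_idx].tolist())
        for _, val_idx in cv.split(X_toy, groups=groups_toy)
    ]

logo_folds = held_out_groups(LeaveOneGroupOut())
gkf_folds = held_out_groups(GroupKFold(n_splits=len(np.unique(groups_toy))))

# Fold order may differ between the two splitters, so compare as sets
same_partition = set(logo_folds) == set(gkf_folds)
print(f"Same held-out patients: {same_partition}")
```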

# ## 3. Setup and Data
#
# We need the development data including the crucial `groups_dev` array (patient IDs) generated in Notebook 0.

# +
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import (
    KFold, StratifiedKFold, GroupKFold, LeaveOneGroupOut,
    cross_val_score, cross_validate # Use cross_validate for more metrics
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
import time

# %matplotlib inline
sns.set(style='whitegrid')

# Assume X_dev, y_dev, groups_dev are loaded from Notebook 0
# If you didn't use %store or are running standalone, regenerate or load them here.
RANDOM_STATE = 42
try:
    X_dev.shape # Check if variable exists
    print("Using data loaded from previous notebook.")
    # Cast patient IDs to int for consistency with the standalone path below.
    # (GroupKFold does not require integers; any labels np.unique can handle work.)
    groups_dev = np.asarray(groups_dev).astype(int)
except NameError:
    print("Generating synthetic data for standalone execution...")
    from sklearn.datasets import make_classification
    N_SAMPLES_DEV = 400
    N_FEATURES = 20
    N_CLASSES = 2
    N_PATIENTS_DEV = 80
    IMBALANCE = 0.8
    X_dev, y_dev = make_classification(
        n_samples=N_SAMPLES_DEV, n_features=N_FEATURES, n_informative=10, n_redundant=5, n_repeated=0,
        n_classes=N_CLASSES, n_clusters_per_class=2, weights=[IMBALANCE, 1.0 - IMBALANCE],
        flip_y=0.05, class_sep=0.8, random_state=RANDOM_STATE
    )
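    # Use a seeded generator so the patient assignment is reproducible
    # (the seed must be set before any random draw).
    rng = np.random.default_rng(RANDOM_STATE)
    samples_per_patient = N_SAMPLES_DEV // N_PATIENTS_DEV
    groups_dev = np.repeat(np.arange(N_PATIENTS_DEV), samples_per_patient)
    remaining_samples = N_SAMPLES_DEV % N_PATIENTS_DEV
    if remaining_samples > 0:
        groups_dev = np.concatenate([groups_dev, rng.choice(N_PATIENTS_DEV, remaining_samples)])
    # NOTE: patient IDs are assigned at random, so these features carry no
    # patient-specific signal; expect little gap between grouped and ungrouped CV.
    rng.shuffle(groups_dev)
    groups_dev = groups_dev.astype(int)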
    print(f"Generated X_dev shape: {X_dev.shape}, y_dev shape: {y_dev.shape}, groups_dev shape: {groups_dev.shape}")


# Define the model
model = LogisticRegression(solver='liblinear', random_state=RANDOM_STATE)

# Calculate number of unique groups (patients)
n_unique_groups = len(np.unique(groups_dev))
print(f"\nNumber of unique patients in Dev set: {n_unique_groups}")
# -

# ## 4. Applying GroupKFold
#
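# `GroupKFold` requires the `groups` array to be passed during splitting, and `n_splits` cannot be greater than the number of unique groups. It does *not* shuffle by default: groups are assigned to folds deterministically, largest group first, each going to the fold that currently holds the fewest samples, so the folds end up with roughly equal numbers of samples. It does not stratify on the class label.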
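#
# The way `GroupKFold` spreads unequal patients over folds can be inspected directly. This is a minimal sketch on hypothetical toy data (`group_sizes`, `groups_toy`, and `X_toy` are illustrative names): six patients of different sizes are split into three folds, and the loop reports which patients and how many samples land in each validation fold.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical toy data: 6 patients contributing 6, 5, 4, 3, 2 and 1 samples.
group_sizes = [6, 5, 4, 3, 2, 1]
groups_toy = np.repeat(np.arange(len(group_sizes)), group_sizes)
X_toy = np.zeros((len(groups_toy), 1))

gkf_toy = GroupKFold(n_splits=3)
fold_sizes, fold_patients = [], []
for fold, (_, val_idx) in enumerate(gkf_toy.split(X_toy, groups=groups_toy), start=1):
    # Patients held out in this fold, and how many samples they contribute
    patients = sorted(set(groups_toy[val_idx].tolist()))
    fold_sizes.append(len(val_idx))
    fold_patients.append(patients)
    print(f"Fold {fold}: patients {patients}, {len(val_idx)} validation samples")
```

# Each patient should appear in exactly one validation fold, and the fold sizes should be close to each other even though the patients differ in size.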

# +
# Ensure n_splits is not greater than the number of unique groups
N_SPLITS_GROUP = 5
if N_SPLITS_GROUP > n_unique_groups:
    N_SPLITS_GROUP = n_unique_groups
    print(f"Warning: n_splits reduced to {N_SPLITS_GROUP} (number of unique groups)")

gkf = GroupKFold(n_splits=N_SPLITS_GROUP)

print(f"\n--- Running {N_SPLITS_GROUP}-Fold Group Cross-Validation ---")

# Using cross_validate to get multiple scores easily
scoring_metrics = ['accuracy', 'roc_auc']

start_time = time.time()
# **Crucially, pass the 'groups' array to cross_validate**
gkf_results = cross_validate(
    model,
    X_dev,
    y_dev,
    groups=groups_dev, # THIS IS THE KEY DIFFERENCE
    cv=gkf,
    scoring=scoring_metrics,
    n_jobs=-1,
    return_train_score=False # Don't need train scores for this demo
)
gkf_time = time.time() - start_time

print("\nGroupKFold Results:")
print(f"  Fold Test Accuracies: {gkf_results['test_accuracy']}")
print(f"  Mean Test Accuracy:   {gkf_results['test_accuracy'].mean():.4f}")
print(f"  Std Test Accuracy:    {gkf_results['test_accuracy'].std():.4f}")
print(f"\n  Fold Test AUCs:       {gkf_results['test_roc_auc']}")
print(f"  Mean Test AUC:        {gkf_results['test_roc_auc'].mean():.4f}")
print(f"  Std Test AUC:         {gkf_results['test_roc_auc'].std():.4f}")
print(f"\nTime taken: {gkf_time:.2f} seconds")

# Compare with StratifiedKFold (often overly optimistic if groups exist)
print("\n--- For Comparison: Running Stratified K-Fold (potential leakage) ---")
skf = StratifiedKFold(n_splits=N_SPLITS_GROUP, shuffle=True, random_state=RANDOM_STATE)
skf_results = cross_validate(
    model, X_dev, y_dev, # No groups argument here!
    cv=skf, scoring=scoring_metrics, n_jobs=-1
)
print("\nStratifiedKFold Results (POTENTIALLY INFLATED):")
print(f"  Mean Test Accuracy:   {skf_results['test_accuracy'].mean():.4f} (+/- {skf_results['test_accuracy'].std():.4f})")
print(f"  Mean Test AUC:        {skf_results['test_roc_auc'].mean():.4f} (+/- {skf_results['test_roc_auc'].std():.4f})")
# -

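# **Observation:** Compare the `GroupKFold` scores (especially AUC) with the `StratifiedKFold` scores. If the grouped scores are noticeably lower, or vary more from fold to fold, standard CV was probably benefiting from information leakage, and the `GroupKFold` results provide a more realistic estimate of performance on *new patients*. If you ran the standalone synthetic data, the patient IDs were assigned at random and carry no patient-specific signal, so the two estimates should be close; the gap appears when samples from one patient genuinely resemble each other.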
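#
# The inflation caused by leakage is easiest to see on a deliberately extreme example. The sketch below is an illustration under stated assumptions, not a model of real clinical data: each hypothetical patient gets a random label and a random feature "fingerprint" that is unrelated to the label, so nothing generalizes to new patients. A 1-nearest-neighbour classifier (chosen here because it memorizes training points) can then only score well by recognizing a patient it has already seen.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)

# Hypothetical data: 60 patients x 5 samples. Each patient has a random label and a
# random feature offset ("fingerprint") that is unrelated to that label.
n_patients, samples_per_patient, n_features = 60, 5, 10
patient_labels = rng.integers(0, 2, size=n_patients)
patient_centers = rng.normal(size=(n_patients, n_features))

groups_demo = np.repeat(np.arange(n_patients), samples_per_patient)
noise = 0.1 * rng.normal(size=(len(groups_demo), n_features))
X_demo = patient_centers[groups_demo] + noise
y_demo = patient_labels[groups_demo]

# 1-nearest-neighbour memorizes training samples, so it can exploit patient identity.
clf = KNeighborsClassifier(n_neighbors=1)

# Sample-level split: the same patient usually lands in both train and validation.
leaky_acc = cross_val_score(
    clf, X_demo, y_demo,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
).mean()

# Patient-level split: validation patients are never seen during training.
grouped_acc = cross_val_score(
    clf, X_demo, y_demo, groups=groups_demo, cv=GroupKFold(n_splits=5),
).mean()

print(f"Sample-level CV accuracy:  {leaky_acc:.3f}")
print(f"Patient-level CV accuracy: {grouped_acc:.3f}")
```

# The sample-level score should be close to perfect while the patient-level score should hover around chance, which is the honest answer for data that contains no generalizable signal.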

# ## 5. Demonstrating Group Separation Manually
#
# Let's iterate through the `GroupKFold` splits manually to explicitly show that patient IDs do not overlap between training and validation sets in any fold.

# +
print("\n--- Manual Group K-Fold Loop: Verifying Patient Separation ---")
fold_counter = 1
patient_overlap_found = False

for train_index, val_index in gkf.split(X_dev, y_dev, groups=groups_dev):
    print(f"\nFold {fold_counter}:")
    X_train_fold, X_val_fold = X_dev[train_index], X_dev[val_index]
    y_train_fold, y_val_fold = y_dev[train_index], y_dev[val_index]
    groups_train_fold, groups_val_fold = groups_dev[train_index], groups_dev[val_index]

    print(f"  Train size: {len(X_train_fold)}, Val size: {len(X_val_fold)}")

    # Get unique patient IDs in this fold's train and validation sets
    train_patients = set(groups_train_fold)
    val_patients = set(groups_val_fold)
    print(f"  Unique patients in Train set: {len(train_patients)}")
    print(f"  Unique patients in Val set:   {len(val_patients)}")

    # THE CRITICAL CHECK: Ensure no patient is in both sets
    common_patients_in_fold = train_patients.intersection(val_patients)
    print(f"  Patients common to Train and Val in this fold: {len(common_patients_in_fold)}")

    if len(common_patients_in_fold) > 0:
        print(f"  WARNING: Patient overlap detected in Fold {fold_counter}!")
        patient_overlap_found = True
        # break # Stop if overlap found, indicates an issue

    fold_counter += 1

if not patient_overlap_found:
    print("\nSuccess: No patient overlap was detected between train/validation sets in any fold using GroupKFold.")
else:
    print("\nError: Patient overlap detected. Check implementation or group assignments.")
# -

# ## 6. LeaveOneGroupOut (LOGO)
#
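# LOGO holds out one patient at a time and trains on all the others, so it fits one model per patient. This becomes computationally expensive when there are many groups (patients). Because each validation fold contains only one patient's samples, per-fold scores are noisy, and ROC AUC is undefined for any patient whose samples all belong to a single class.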

# +
logo = LeaveOneGroupOut()
n_splits_logo = logo.get_n_splits(X_dev, y_dev, groups=groups_dev)
print(f"\n--- Running Leave-One-Group-Out Cross-Validation ({n_splits_logo} Splits) ---")
print(f"(This corresponds to {n_unique_groups} unique patients)")

# This can be slow if n_unique_groups is large
if n_splits_logo <= 100: # Demo limit
    start_time = time.time()
    logo_results = cross_validate(
        model,
        X_dev,
        y_dev,
        groups=groups_dev, # Pass groups!
        cv=logo,
        scoring=scoring_metrics,
        n_jobs=-1
    )
    logo_time = time.time() - start_time

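    # Each fold holds one patient, so AUC is undefined (NaN) whenever that patient's
    # samples are all one class; summarize AUC with NaN-aware statistics.
    logo_auc = logo_results['test_roc_auc']
    n_valid_auc = int(np.sum(~np.isnan(logo_auc)))
    print("\nLOGO Results:")
    print(f"  Mean Test Accuracy:   {logo_results['test_accuracy'].mean():.4f} (+/- {logo_results['test_accuracy'].std():.4f})")
    print(f"  Mean Test AUC:        {np.nanmean(logo_auc):.4f} (+/- {np.nanstd(logo_auc):.4f}) over {n_valid_auc}/{n_splits_logo} folds with both classes")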
    print(f"\nTime taken: {logo_time:.2f} seconds")
else:
    print(f"\nSkipping LOGO execution as number of groups ({n_splits_logo}) is large.")

# -
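
# With LOGO, a per-fold AUC cannot be computed for a patient whose samples all share one class. One common workaround, sketched here on hypothetical stand-in data (`X_toy`, `y_toy`, `groups_toy`), is to collect out-of-fold probabilities with `cross_val_predict` and compute a single AUC over the pooled predictions. This assumes that pooling scores from models trained on slightly different data is acceptable for your analysis.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import LeaveOneGroupOut, cross_val_predict

# Hypothetical stand-in data: 200 samples from 40 patients (5 samples each).
X_toy, y_toy = make_classification(
    n_samples=200, n_features=20, n_informative=10, weights=[0.8, 0.2], random_state=0
)
groups_toy = np.repeat(np.arange(40), 5)

# Out-of-fold probabilities: each sample is scored by a model that never saw its patient.
oof_proba = cross_val_predict(
    LogisticRegression(solver='liblinear', random_state=0),
    X_toy, y_toy,
    groups=groups_toy,
    cv=LeaveOneGroupOut(),
    method='predict_proba',
)[:, 1]

# One AUC over all pooled predictions instead of one (possibly undefined) AUC per patient.
pooled_auc = roc_auc_score(y_toy, oof_proba)
print(f"Pooled LOGO AUC: {pooled_auc:.4f}")
```

# The pooled value is a single point estimate, so it gives no per-fold spread; treat it as a complement to the per-fold accuracy rather than a replacement.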

# ## 7. Conclusion
#
# **If your data has multiple samples per patient (or other grouping factor), using `GroupKFold` (or related group methods) is essential for obtaining a realistic and trustworthy estimate of your model's generalization performance to new, unseen subjects.** Standard K-Fold or Stratified K-Fold will likely produce overly optimistic results due to information leakage. Always check your data structure and apply patient-wise validation when appropriate in medical AI tasks.