In [1]:
# Start-up marker: printed as soon as this script/cell begins executing.
print("IT IS RUNNING!")

"""
Neural Network temporal cluster-start prediction — V2 

Pipeline:

1) Load MULTIPLE CSVs (data_0.csv ... data_4.csv), tag each row with its file of origin (source_file),
   and concatenate into one event-level dataframe.
2) For each patient (subject_id):
   a) Identify "clusters" of dense measurement activity (runs of calendar days with small day gaps).
   b) Construct index days using UPDATED rule:
      - pick ONLY ONE index day per quiet spell (first eligible day after a sparse gap)
      - skip days that fall inside clusters
      - drop index days that fall within N days before a future cluster start
   c) Build UPDATED features at each index:
      - value features for 4 codes (SBP, HR, glucose, HbA1c): last, days since last, window median/IQR/count
      - richer dynamics: delta(last - window_median), delta(last - previous), last relative to patient baseline IQR
      - cadence/behaviour features: measurement rhythm + cluster history
   d) Assign label:
      - positive if a cluster START occurs within [label_start_days .. label_end_days] after the index time
3) Add missingness indicator columns for all features before imputation.
4) Leave-one-CSV-out cross-validation:
   - fold = one held-out source_file
   - impute (median) + scale (StandardScaler) using TRAIN fold only
   - oversample minority class in TRAIN fold only
   - fit MLPClassifier(1 hidden layer), predict_proba on the held-out fold
5) Pool out-of-fold predictions and print one combined evaluation.

Improvements from the old model (V1) to the new model (V2):

V1 (older FIXED MLP):
- Label window: 8..90 days after index
- Index-day construction: could create multiple indices during sparse periods
- Simpler features: last/days_since_last + basic window stats
- Minimal behavioural/cadence context
- No explicit ambiguous-negative removal near cluster starts

V2 (this script):
- Label window updated to 7..30 days (DEFAULT_LABEL_START_DAYS..DEFAULT_LABEL_END_DAYS)
- New ambiguous-negative buffer: drop indices within 7 days before any future cluster start
- New index policy: One index per quiet spell after sparse gap (reduces correlated samples)
- Expanded feature set: deltas + patient baseline-relative features + cadence/cluster-history features
- Adds missingness indicator features prior to imputation


"""

import argparse
from pathlib import Path
import math

import numpy as np
import pandas as pd

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    brier_score_loss,
    classification_report,
    precision_recall_curve,
)
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier



# Default input CSVs (expects the same schema in each file).
# NOTE: these are absolute, machine-specific paths; on any other machine pass
# the files explicitly with the --csv command-line option instead.
DEFAULT_CSV_PATHS = [
    "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_0.csv",
    "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_1.csv",
    "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_2.csv",
    "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_3.csv",
    "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_4.csv",
]

# Feature lookback window (in days) for the per-code summary statistics
# (median / IQR / count) computed at each index time.
DEFAULT_LOOKBACK_DAYS = 90

# Cluster logic (all values are whole calendar days):
# - sparse_gap_days: minimum quiet gap required before a dense run can be considered a "new cluster"
# - dense_gap_days: max allowed gap between consecutive observation days within a cluster
# - min_cluster_days: minimum number of unique activity days required for a cluster
DEFAULT_SPARSE_GAP_DAYS = 6
DEFAULT_DENSE_GAP_DAYS = 2
DEFAULT_MIN_CLUSTER_DAYS = 4

# Restrict clustering to rows whose "table" column equals this value, when that
# column is present. The CLI accepts the literal string "none" to disable the
# restriction.
DEFAULT_CLUSTER_FROM_TABLE = "measurement"

# Updated label horizon: an index is positive when a cluster starts between
# these two offsets (inclusive, in days) after the index time.
DEFAULT_LABEL_START_DAYS = 7
DEFAULT_LABEL_END_DAYS = 30

# UPDATED: drop ambiguous negatives that are too close to a future cluster start
DEFAULT_NEGATIVE_BUFFER_BEFORE_CLUSTER_DAYS = 7

# Optional: cap number of negatives per subject (helps extreme imbalance); 0 = OFF
DEFAULT_MAX_NEG_PER_SUBJECT = 0

# Clinical codes used as features (matched against the "code" column)
CODE_SBP = "LOINC/8480-6"          # systolic BP
CODE_HR = "LOINC/8867-4"           # heart rate
CODE_GLUCOSE = "SNOMED/271649006"  # glucose
CODE_HBA1C = "SNOMED/271650006"    # HbA1c




# Oversampling is used because sklearn's MLPClassifier doesn't accept sample_weight.
# IMPORTANT: oversampling happens ONLY in TRAIN folds (never in the test fold).
def oversample_minority(X, y, random_state=42):
    """
    Balance a binary dataset by duplicating randomly chosen minority rows.

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
    y : numpy array of shape (n_samples,), cast to int (1 = positive, 0 = negative)
    random_state : seed for the NumPy generator driving both the draw and the shuffle

    Returns
    -------
    (X_out, y_out). When both classes are present and unequal in size, the
    minority class is topped up (sampling with replacement) to the majority
    count and the rows are shuffled. Otherwise X is returned untouched together
    with the int-cast y.
    """
    generator = np.random.default_rng(random_state)
    y = y.astype(int)

    positives = np.where(y == 1)[0]
    negatives = np.where(y == 0)[0]

    # A single-class input offers nothing to resample from / towards.
    if positives.size == 0 or negatives.size == 0:
        return X, y

    # The smaller class is the one that gets duplicated.
    minority = positives if len(positives) < len(negatives) else negatives
    majority = negatives if len(positives) < len(negatives) else positives

    shortfall = len(majority) - len(minority)
    if shortfall <= 0:
        return X, y

    duplicates = generator.choice(minority, size=shortfall, replace=True)
    X_aug = np.vstack([X, X[duplicates]])
    y_aug = np.concatenate([y, y[duplicates]])

    # Shuffle so the appended copies do not sit in one contiguous tail block.
    order = generator.permutation(len(y_aug))
    return X_aug[order], y_aug[order]


# Optional: cap negatives per subject to reduce extreme imbalance / correlated negatives.
# NOTE: This is applied AFTER index/feature building, BEFORE CV splitting.
def cap_negatives_per_subject(data: pd.DataFrame, max_neg_per_subject: int, seed: int = 42) -> pd.DataFrame:
    """
    Keep at most `max_neg_per_subject` negative rows (label == 0) per subject.

    All positive rows (label == 1) are kept. Subjects with more negatives than
    the cap have a random subset drawn without replacement. A cap of None or
    <= 0 disables the step and returns `data` itself. Otherwise the result is
    sorted by (subject_id, index_time) with a fresh 0..n-1 index.
    """
    if max_neg_per_subject is None or max_neg_per_subject <= 0:
        return data

    generator = np.random.default_rng(seed)
    pieces = []
    for _, subject_rows in data.groupby("subject_id", sort=False):
        positives = subject_rows[subject_rows["label"] == 1]
        negatives = subject_rows[subject_rows["label"] == 0]
        if len(negatives) > max_neg_per_subject:
            sampled_labels = generator.choice(
                negatives.index.to_numpy(),
                size=max_neg_per_subject,
                replace=False,
            )
            negatives = negatives.loc[sampled_labels]
        pieces.append(pd.concat([positives, negatives], axis=0))

    capped = pd.concat(pieces, ignore_index=True)
    capped = capped.sort_values(["subject_id", "index_time"])
    return capped.reset_index(drop=True)



# Parses mixed-format datetime strings in a pandas Series into datetimes.
def parse_mixed_datetime(series: pd.Series) -> pd.Series:
    """
    Parse a column holding two timestamp spellings into datetimes.

    Values containing "/" are read strictly as day-first "%d/%m/%Y %H:%M".
    All other values go through pandas' general parser with dayfirst=False.
    Anything that fails to parse (including missing input) becomes NaT.
    """
    text = series.astype("string")
    has_slash = text.str.contains("/", na=False)

    # Each pass only sees its own subset; the rest is blanked out to missing.
    slash_style = pd.to_datetime(
        text.where(has_slash),
        format="%d/%m/%Y %H:%M",
        errors="coerce",
    )
    generic_style = pd.to_datetime(
        text.where(~has_slash),
        errors="coerce",
        dayfirst=False,
    )

    # The two passes cover disjoint rows, so filling merges them losslessly.
    return slash_style.fillna(generic_style)


# Loads a CSV, parses time + numeric_value, and sorts records by subject and time.
# Also checks required columns exist.
def load_csv_simple(path: str) -> pd.DataFrame:
    """
    Read one event CSV and return it cleaned and ordered.

    Steps: read the file (comma-separated, BOM-tolerant UTF-8), require the
    'time' and 'subject_id' columns, parse 'time' with parse_mixed_datetime,
    coerce 'numeric_value' to numbers (creating it as all-NaN when absent),
    then stably sort by (subject_id, time) and reset the index.

    Raises
    ------
    ValueError
        If 'time' or 'subject_id' is missing from the file.
    """
    print("Loading CSV from", path)
    frame = pd.read_csv(path, sep=",", encoding="utf-8-sig", low_memory=False)

    present = set(frame.columns)
    if "time" not in present:
        raise ValueError(f"'time' column not found in {path}")
    if "subject_id" not in present:
        raise ValueError(f"'subject_id' column not found in {path}")

    frame["time"] = parse_mixed_datetime(frame["time"])

    # Some event types carry no numeric payload; unparseable entries become NaN.
    if "numeric_value" not in present:
        frame["numeric_value"] = np.nan
    else:
        frame["numeric_value"] = pd.to_numeric(frame["numeric_value"], errors="coerce")

    # mergesort is stable, so rows with equal keys keep their file order.
    ordered = frame.sort_values(["subject_id", "time"], kind="mergesort")
    return ordered.reset_index(drop=True)



# Identifies clusters of dense measurement activity at day granularity.
# Cluster definition:
# - collapse timestamps to unique calendar days
# - dense run = consecutive days where day gaps <= dense_gap_days
# - accept run as cluster if:
#    (1) run_length >= min_cluster_days
#    (2) preceding gap >= sparse_gap_days (or it's the first run)
def find_clusters_for_patient(
    patient_df: pd.DataFrame,
    sparse_gap_days: int,
    dense_gap_days: int,
    min_cluster_days: int,
    cluster_from_table: str | None,
):
    if cluster_from_table is not None and "table" in patient_df.columns:
        chosen = patient_df[
            (patient_df["table"] == cluster_from_table) & (~patient_df["time"].isna())
        ].copy()
        if chosen.empty:
            chosen = patient_df.dropna(subset=["time"]).copy()
    else:
        chosen = patient_df.dropna(subset=["time"]).copy()

    # No valid timestamps → no clusters
    if chosen.empty:
        return []

    # Work at day granularity
    chosen["day"] = chosen["time"].dt.normalize()
    unique_days = sorted(chosen["day"].unique().tolist())

    clusters = []
    if len(unique_days) < min_cluster_days:
        return clusters

    # Gaps between adjacent observed days
    day_diffs = []
    for i in range(1, len(unique_days)):
        day_diffs.append((unique_days[i] - unique_days[i - 1]).days)

    # Scan dense runs
    run_start_index = 0
    for i, gap in enumerate(day_diffs, start=1):
        if gap <= dense_gap_days:
            continue

        run_end_index = i - 1
        run_length = run_end_index - run_start_index + 1

        if run_length >= min_cluster_days:
            preceding_gap = math.inf if run_start_index == 0 else day_diffs[run_start_index - 1]
            if preceding_gap >= sparse_gap_days:
                start_day = unique_days[run_start_index]
                end_day = unique_days[run_end_index]
                clusters.append((start_day.normalize(), end_day.normalize(), run_length))

        run_start_index = i

    # Tail run
    run_end_index = len(unique_days) - 1
    run_length = run_end_index - run_start_index + 1
    if run_length >= min_cluster_days:
        preceding_gap = math.inf if run_start_index == 0 else day_diffs[run_start_index - 1]
        if preceding_gap >= sparse_gap_days:
            start_day = unique_days[run_start_index]
            end_day = unique_days[run_end_index]
            clusters.append((start_day.normalize(), end_day.normalize(), run_length))

    return clusters


# Convenience: expand cluster spans into a set of calendar days (inclusive).
def build_cluster_day_set(clusters):
    """
    Expand cluster spans into the set of calendar days they cover.

    `clusters` is an iterable of (start_day, end_day, n_days) tuples. Every day
    from start to end inclusive is added as a midnight-normalised Timestamp.
    """
    covered_days = set()
    for first_day, last_day, _n_days in clusters:
        for stamp in pd.date_range(first_day, last_day, freq="D"):
            covered_days.add(pd.Timestamp(stamp).normalize())
    return covered_days


# Convenience: list of cluster start days (day-normalized).
def build_cluster_start_days(clusters):
    """Return the cluster start days as midnight-normalised Timestamps, ascending."""
    start_days = [pd.Timestamp(first_day).normalize() for first_day, _last_day, _n_days in clusters]
    start_days.sort()
    return start_days


# UPDATED: treat index days that fall shortly BEFORE a future cluster start as ambiguous negatives.
# If a cluster begins within buffer_days AFTER index_day, we drop that index.
def is_too_close_to_future_cluster(index_day, cluster_start_days, buffer_days: int) -> bool:
    """
    Tell whether a cluster starts shortly after `index_day`.

    Returns True when some cluster start lies strictly after index_day and at
    most `buffer_days` whole days later. A non-positive buffer or an empty
    start list always yields False.
    """
    if not cluster_start_days or buffer_days <= 0:
        return False
    return any(
        0 < (start_day - index_day).days <= buffer_days
        for start_day in cluster_start_days
    )


# UPDATED index construction (V2):
# - ONE index day per quiet spell after sparse gap
# - skip cluster days
# - drop indices too close to future cluster starts (ambiguous-negative buffer)
# - store (subject_id, index_time, index_day, source_file)
def build_index_days_one_per_quiet_spell(
    patient_df: pd.DataFrame,
    clusters,
    sparse_gap_days: int,
    negative_buffer_before_cluster_days: int = 0,
):
    """
    Pick at most one index day per quiet spell for a single patient.

    The patient's observed calendar days are walked in order. A quiet spell is
    (re)opened at the first observed day and whenever the gap to the previous
    observed day is >= sparse_gap_days. While a spell is open, the first day
    that is not inside a cluster and not within
    negative_buffer_before_cluster_days before a future cluster start becomes
    the index, and the spell is closed until the next sparse gap.

    Parameters
    ----------
    patient_df : rows for one patient with 'subject_id' and 'time' columns;
        'source_file' is optional.
    clusters : list of (start_day, end_day, n_days) tuples for this patient.
    sparse_gap_days : gap (days) that opens a new quiet spell.
    negative_buffer_before_cluster_days : indices with a cluster starting
        1..N days later are dropped; <= 0 disables the buffer.

    Returns
    -------
    list of (subject_id, index_time, index_day, source_file) tuples, where
    index_time is the earliest timestamp on index_day and source_file is taken
    from the first row carrying that timestamp ("unknown" when the column is
    absent).
    """
    df = patient_df.dropna(subset=["time"]).copy()
    if df.empty:
        return []

    df["time"] = pd.to_datetime(df["time"], errors="coerce")
    df = df.dropna(subset=["time"])
    df["day"] = df["time"].dt.normalize()

    cluster_day_set = build_cluster_day_set(clusters)
    cluster_start_days = build_cluster_start_days(clusters)
    sid = patient_df["subject_id"].iloc[0]

    # One pass instead of rescanning the whole frame for every candidate day:
    # a stable sort by time followed by "first row per day" yields, for each
    # day, its earliest timestamp and the first row (in original order) that
    # carries it. The result is ordered by day as well.
    first_rows = df.sort_values("time", kind="mergesort").drop_duplicates(subset="day", keep="first")
    days = first_rows["day"].tolist()
    first_times = first_rows["time"].tolist()
    if "source_file" in first_rows.columns:
        sources = first_rows["source_file"].tolist()
    else:
        sources = ["unknown"] * len(days)

    index_list = []
    prev_day = None
    # True while no index has been taken yet for the current quiet spell.
    eligible_open = True

    for current_day, first_time, src_file in zip(days, first_times, sources):
        gap_days = math.inf if prev_day is None else (current_day - prev_day).days
        prev_day = current_day

        # A sufficiently long gap starts a new quiet spell.
        if gap_days >= sparse_gap_days:
            eligible_open = True

        if not eligible_open or current_day in cluster_day_set:
            continue

        # Ambiguous negative: a cluster starts too soon after this day. The
        # spell stays open, so a later day may still be chosen.
        if is_too_close_to_future_cluster(
            current_day,
            cluster_start_days,
            negative_buffer_before_cluster_days,
        ):
            continue

        index_list.append((sid, first_time, current_day, src_file))
        eligible_open = False  # no more indices until the next sparse gap

    return index_list



# Robust IQR helper:
# - returns NaN if too few points
def robust_iqr(x: np.ndarray) -> float:
    """
    Interquartile range (Q3 - Q1) of `x`, ignoring NaN entries.

    Returns NaN when `x` holds fewer than four elements (NaNs included in the
    length check), since quartiles of so few points are not informative.
    """
    if len(x) < 4:
        return np.nan
    lower, upper = (float(np.nanquantile(x, q)) for q in (0.25, 0.75))
    return upper - lower


# Patient baselines (computed once per patient):
# - baseline median and IQR per code across all patient measurement history
# Used to normalise "last value" relative to the patient’s typical range.
def build_patient_baselines(meas_df: pd.DataFrame, codes):
    """
    Compute a (median, IQR) baseline per code from the rows of `meas_df`.

    For each entry of `codes`, the non-missing 'numeric_value' entries of rows
    whose 'code' equals it are summarised. Codes with no values map to
    (NaN, NaN); the IQR comes from robust_iqr and may itself be NaN.

    Returns a dict {code: (median, iqr)} in the order of `codes`.
    """
    baselines = {}
    for code in codes:
        matches_code = meas_df["code"] == code
        values = meas_df.loc[matches_code, "numeric_value"].dropna().to_numpy(dtype=float)
        if len(values) > 0:
            baselines[code] = (float(np.nanmedian(values)), robust_iqr(values))
        else:
            baselines[code] = (np.nan, np.nan)
    return baselines


# UPDATED value features per code:
# - last value + days since last
# - lookback window median/IQR/count
# - delta(last - window_median)
# - delta(last - previous value)
# - normalised last relative to patient baseline IQR
def build_features_for_index(
    patient_df: pd.DataFrame,
    index_time: pd.Timestamp,
    lookback_days: int,
    patient_baselines: dict,
):
    """
    Build the per-code value features for one index time.

    Only rows with table == "measurement" are used when a 'table' column
    exists. For each of the four clinical codes (SBP, HR, glucose, HbA1c) the
    following are emitted, prefixed with the code (slashes replaced by "_"):

    - last value at or before index_time, and whole days since it
    - median / IQR / count of values inside [index_time - lookback_days, index_time]
    - last minus window median, and last minus the previous value
    - (last - baseline median) / baseline IQR, using `patient_baselines`

    Anything that cannot be computed is NaN (the count is 0 instead).

    Parameters
    ----------
    patient_df : rows for one patient ('time', 'code', 'numeric_value').
    index_time : prediction time; later rows are ignored.
    lookback_days : window length in days for the summary statistics.
    patient_baselines : dict {code: (median, iqr)}; missing codes are treated
        as (NaN, NaN).

    Returns
    -------
    dict of feature name -> value, in a fixed insertion order.
    """
    measurements = patient_df.copy()
    if "table" in measurements.columns:
        measurements = measurements[measurements["table"] == "measurement"].copy()

    window_open = index_time - pd.Timedelta(days=lookback_days)
    features = {}

    for code in (CODE_SBP, CODE_HR, CODE_GLUCOSE, CODE_HBA1C):
        prefix = code.replace("/", "_")
        has_time = ~measurements["time"].isna()
        code_rows = measurements[(measurements["code"] == code) & has_time].copy()

        # --- last / previous observation at or before the index time ---
        history = code_rows[code_rows["time"] <= index_time].sort_values("time")
        last_val = np.nan
        last_days_since = np.nan
        prev_val = np.nan
        if not history.empty:
            newest = history.iloc[-1]
            last_val = newest["numeric_value"]
            last_days_since = (index_time - newest["time"]).days
            if len(history) >= 2:
                prev_val = history.iloc[-2]["numeric_value"]

        # --- summary statistics over the lookback window ---
        in_window = (code_rows["time"] >= window_open) & (code_rows["time"] <= index_time)
        window_vals = code_rows.loc[in_window, "numeric_value"].dropna().to_numpy(dtype=float)
        if len(window_vals) > 0:
            window_median = float(np.nanmedian(window_vals))
            window_iqr = robust_iqr(window_vals)
            window_count = int(len(window_vals))
        else:
            window_median = np.nan
            window_iqr = np.nan
            window_count = 0

        # --- dynamics: a delta is only defined when both operands exist ---
        if np.isnan(last_val) or np.isnan(window_median):
            delta_last_median = np.nan
        else:
            delta_last_median = float(last_val - window_median)

        if np.isnan(last_val) or np.isnan(prev_val):
            delta_last_prev = np.nan
        else:
            delta_last_prev = float(last_val - prev_val)

        # --- last value expressed in units of the patient's baseline IQR ---
        base_median, base_iqr = patient_baselines.get(code, (np.nan, np.nan))
        baseline_usable = not (
            np.isnan(last_val) or np.isnan(base_median) or np.isnan(base_iqr) or base_iqr == 0
        )
        last_rel = float((last_val - base_median) / base_iqr) if baseline_usable else np.nan

        # Insertion order fixes the downstream feature-column order.
        features[f"{prefix}_last"] = last_val
        features[f"{prefix}_days_since_last"] = last_days_since
        features[f"{prefix}_median_{lookback_days}d"] = window_median
        features[f"{prefix}_iqr_{lookback_days}d"] = window_iqr
        features[f"{prefix}_count_{lookback_days}d"] = window_count
        features[f"{prefix}_delta_last_median_{lookback_days}d"] = delta_last_median
        features[f"{prefix}_delta_last_prev"] = delta_last_prev
        features[f"{prefix}_last_rel_patient_iqr"] = last_rel

    return features


# NEW cadence/behaviour features:
# - measurement rhythm: days since last measurement day, last gap between measurement days
# - activity counts: number of distinct measurement days in last 7/30/90 days
# - cluster history: total clusters, days since last cluster start/end (before index day)
def _cluster_history_features(idx_day, clusters):
    """Summarise only the clusters that started / ended on or before `idx_day`."""
    past_starts = [pd.Timestamp(s) for (s, _, _) in (clusters or []) if pd.Timestamp(s) <= idx_day]
    past_ends = [pd.Timestamp(e) for (_, e, _) in (clusters or []) if pd.Timestamp(e) <= idx_day]
    return {
        "cluster_count_total": int(len(past_starts)),
        "days_since_last_cluster_start": (
            float((idx_day - max(past_starts)).days) if past_starts else np.nan
        ),
        "days_since_last_cluster_end": (
            float((idx_day - max(past_ends)).days) if past_ends else np.nan
        ),
    }


def build_cadence_features(
    patient_df: pd.DataFrame,
    index_time: pd.Timestamp,
    clusters,
):
    """
    Build measurement-rhythm and cluster-history features for one index time.

    Rhythm features use rows with table == "measurement" when a 'table' column
    exists (all timed rows otherwise):

    - any_days_since_last_meas: days from the latest measurement day at or
      before index_time to the index day
    - any_last_gap_days: gap between the two latest such measurement days
    - any_count_days_{7,30,90}d: distinct measurement days in the window
      [index day - w, index day]

    Cluster-history features use `clusters`, a list of
    (start_day, end_day, n_days) tuples:

    - cluster_count_total: clusters that STARTED on or before the index day.
      Clusters that begin later are excluded on purpose, because counting them
      would leak the label (a future cluster start) into the features.
    - days_since_last_cluster_start / days_since_last_cluster_end: days since
      the latest past start / end, NaN when there is none. These are computed
      even when the patient has no measurement rows.

    Returns
    -------
    dict of feature name -> value, always with the same keys in the same order.
    """
    idx_day = index_time.normalize()
    windows = (7, 30, 90)

    # Defaults describe "no usable measurement rows".
    out = {"any_days_since_last_meas": np.nan, "any_last_gap_days": np.nan}
    for w in windows:
        out[f"any_count_days_{w}d"] = 0

    df = patient_df.dropna(subset=["time"])
    # Restrict cadence metrics to the measurement table if present
    if "table" in df.columns:
        df = df[df["table"] == "measurement"]
    times = pd.to_datetime(df["time"], errors="coerce").dropna()

    if not times.empty:
        days = times.dt.normalize()

        # Distinct measurement days observed up to (and including) the index time
        days_upto = sorted(pd.to_datetime(days[times <= index_time].unique()))
        if days_upto:
            out["any_days_since_last_meas"] = float((idx_day - days_upto[-1]).days)
            if len(days_upto) >= 2:
                out["any_last_gap_days"] = float((days_upto[-1] - days_upto[-2]).days)

        # Distinct measurement-day counts in short / medium / long windows
        for w in windows:
            window_start = idx_day - pd.Timedelta(days=w)
            in_window = (days >= window_start) & (days <= idx_day)
            out[f"any_count_days_{w}d"] = int(days[in_window].nunique())

    out.update(_cluster_history_features(idx_day, clusters))
    return out


# Positive if any cluster starts within [label_start_days .. label_end_days] after index_time.
def label_index_row(index_time, clusters, label_start_days: int, label_end_days: int) -> int:
    """
    Label one index: 1 if a cluster starts inside the prediction horizon, else 0.

    The horizon runs from the calendar day of (index_time + label_start_days)
    to the calendar day of (index_time + label_end_days), both inclusive.
    `clusters` is a list of (start_day, end_day, n_days) tuples; an empty or
    missing list yields 0.
    """
    if not clusters:
        return 0

    first_day = (index_time + pd.Timedelta(days=label_start_days)).normalize()
    last_day = (index_time + pd.Timedelta(days=label_end_days)).normalize()

    starts_in_horizon = any(
        (cluster_start >= first_day) and (cluster_start <= last_day)
        for cluster_start, _end, _n_days in clusters
    )
    return 1 if starts_in_horizon else 0


# Choose threshold that maximises F1 from pooled PR curve (useful under class imbalance).
def best_f1_threshold(y_true, y_prob):
    """
    Find the probability cut-off with the highest F1 along the PR curve.

    Returns (threshold, f1, precision, recall) as plain floats. The
    precision/recall arrays returned by precision_recall_curve are indexed past
    the end of the thresholds array; if the best point falls there, the
    threshold is reported as 1.0.
    """
    precisions, recalls, cutoffs = precision_recall_curve(y_true, y_prob)

    # The small epsilon keeps the ratio finite where precision + recall == 0.
    f1_scores = (2 * precisions * recalls) / (precisions + recalls + 1e-12)
    top = int(np.nanargmax(f1_scores))

    if top < len(cutoffs):
        chosen = float(cutoffs[top])
    else:
        chosen = 1.0

    return chosen, float(f1_scores[top]), float(precisions[top]), float(recalls[top])


def _build_arg_parser():
    """Create the command-line parser; every option has a default."""
    parser = argparse.ArgumentParser(description="Temporal cluster-start prediction (MLP) — V2 UPDATED")
    # Hidden option; presumably absorbs the "-f <file>" argument a Jupyter
    # kernel passes when this code runs inside a notebook.
    parser.add_argument("-f", default=None, help=argparse.SUPPRESS)

    # Data + temporal settings
    parser.add_argument("--csv", nargs="+", default=DEFAULT_CSV_PATHS, help="One or more CSV file paths.")
    parser.add_argument("--lookback-days", type=int, default=DEFAULT_LOOKBACK_DAYS)
    parser.add_argument("--sparse-gap-days", type=int, default=DEFAULT_SPARSE_GAP_DAYS)
    parser.add_argument("--dense-gap-days", type=int, default=DEFAULT_DENSE_GAP_DAYS)
    parser.add_argument("--min-cluster-days", type=int, default=DEFAULT_MIN_CLUSTER_DAYS)
    parser.add_argument("--cluster-from-table", default=DEFAULT_CLUSTER_FROM_TABLE)

    # UPDATED label horizon controls (7..30 default)
    parser.add_argument("--label-start-days", type=int, default=DEFAULT_LABEL_START_DAYS)
    parser.add_argument("--label-end-days", type=int, default=DEFAULT_LABEL_END_DAYS)

    # UPDATED: ambiguous-negative buffer (drop indices close to future cluster starts)
    parser.add_argument(
        "--negative-buffer-before-cluster-days",
        type=int,
        default=DEFAULT_NEGATIVE_BUFFER_BEFORE_CLUSTER_DAYS
    )

    # Optional imbalance control at dataset level (in addition to oversampling)
    parser.add_argument("--max-neg-per-subject", type=int, default=DEFAULT_MAX_NEG_PER_SUBJECT)

    # MLP hyperparams (optional overrides)
    parser.add_argument("--hidden-units", type=int, default=12)
    parser.add_argument("--max-iter", type=int, default=500)
    parser.add_argument("--alpha", type=float, default=1e-4)
    parser.add_argument("--lr", type=float, default=1e-3)
    return parser


def _load_events(csv_paths):
    """Load every CSV, tag rows with their file name, and return one sorted frame."""
    parts = []
    for path in csv_paths:
        p = Path(path).expanduser()
        part = load_csv_simple(str(p))

        # Tag each row with the CSV filename (defines CV folds)
        part["source_file"] = p.name
        parts.append(part)

    df = (
        pd.concat(parts, ignore_index=True)
        .sort_values(["subject_id", "time"])
        .reset_index(drop=True)
    )
    print(f"Combined rows: {len(df)}  Files loaded: {len(parts)}")
    return df


def _build_index_rows(df, args, cluster_from_table):
    """Return one dict (metadata + features + label) per index day, over all patients."""
    rows = []
    codes = [CODE_SBP, CODE_HR, CODE_GLUCOSE, CODE_HBA1C]

    for _, patient in df.groupby("subject_id", sort=False):
        clusters = find_clusters_for_patient(
            patient,
            sparse_gap_days=args.sparse_gap_days,
            dense_gap_days=args.dense_gap_days,
            min_cluster_days=args.min_cluster_days,
            cluster_from_table=cluster_from_table,
        )

        # Patient baselines (median/IQR per code), computed once per patient.
        # NOTE(review): these use the patient's WHOLE measurement history,
        # including values recorded after an index time, so the
        # "*_last_rel_patient_iqr" features look ahead. Consider restricting
        # the baseline to rows at or before each index time.
        meas_for_baseline = patient.copy()
        if "table" in meas_for_baseline.columns:
            meas_for_baseline = meas_for_baseline[meas_for_baseline["table"] == "measurement"].copy()
        patient_baselines = build_patient_baselines(meas_for_baseline, codes)

        # One index per quiet spell, skipping cluster days and ambiguous negatives
        index_days = build_index_days_one_per_quiet_spell(
            patient,
            clusters,
            sparse_gap_days=args.sparse_gap_days,
            negative_buffer_before_cluster_days=args.negative_buffer_before_cluster_days,
        )

        for subject, idx_time, idx_day, src_file in index_days:
            feats = build_features_for_index(
                patient,
                idx_time,
                lookback_days=args.lookback_days,
                patient_baselines=patient_baselines,
            )
            cadence_feats = build_cadence_features(patient, idx_time, clusters)
            label = label_index_row(
                idx_time,
                clusters,
                label_start_days=args.label_start_days,
                label_end_days=args.label_end_days,
            )

            row = {
                "subject_id": subject,
                "index_time": idx_time,
                "index_day": idx_day,
                "label": int(label),
                "source_file": src_file,
            }
            row.update(feats)
            row.update(cadence_feats)
            rows.append(row)

    return rows


def _run_leave_one_file_out(data, feature_cols, files, args):
    """Run leave-one-CSV-out CV; return per-fold lists of true labels and predicted probabilities."""
    y_true_all = []
    y_prob_all = []

    for test_file in files:
        test_df = data[data["source_file"] == test_file]
        train_df = data[data["source_file"] != test_file]

        if test_df.empty or train_df.empty:
            continue

        X_train = train_df[feature_cols].to_numpy(dtype=float)
        y_train = train_df["label"].astype(int).to_numpy()
        X_test = test_df[feature_cols].to_numpy(dtype=float)
        y_test = test_df["label"].astype(int).to_numpy()

        # Skip fold if training has only one class
        if len(np.unique(y_train)) < 2:
            print(f"Skipping fold {test_file}: only one class in training.")
            continue

        # Impute missing values (fit on TRAIN only to avoid leakage)
        imputer = SimpleImputer(strategy="median")
        X_train_imp = imputer.fit_transform(X_train)
        X_test_imp = imputer.transform(X_test)

        # Scale features (NN training is sensitive to feature scale)
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train_imp)
        X_test_scaled = scaler.transform(X_test_imp)

        # Oversample minority class (TRAIN only)
        X_train_bal, y_train_bal = oversample_minority(X_train_scaled, y_train, random_state=42)

        # NOTE(review): with early_stopping=True the classifier holds out
        # validation_fraction of the data it is given, which here is the
        # oversampled set, so copies of one minority row can land on both sides
        # of that internal split and make the stopping score optimistic.
        model = MLPClassifier(
            hidden_layer_sizes=(args.hidden_units,),
            activation="relu",
            solver="adam",
            alpha=args.alpha,
            learning_rate_init=args.lr,
            max_iter=args.max_iter,
            early_stopping=True,
            validation_fraction=0.2,
            n_iter_no_change=20,
            random_state=42,
        )

        model.fit(X_train_bal, y_train_bal)
        y_prob = model.predict_proba(X_test_scaled)[:, 1]

        y_true_all.append(y_test)
        y_prob_all.append(y_prob)

        # Per-fold reporting (kept brief; pooled results are primary). AUROC and
        # AUPRC need both classes in the held-out fold.
        if len(np.unique(y_test)) > 1:
            fold_auroc = roc_auc_score(y_test, y_prob)
            fold_auprc = average_precision_score(y_test, y_prob)
            print(f"Fold {test_file}: AUROC={fold_auroc:.3f}  AUPRC={fold_auprc:.3f}  n={len(y_test)}  pos={int(y_test.sum())}")

    return y_true_all, y_prob_all


def _print_pooled_report(y_true_all, y_prob_all):
    """Print the single pooled out-of-fold evaluation."""
    print("\n================== POOLED OOF RESULTS ==================")
    print(f"Test rows: {len(y_true_all)}  Positives: {int(y_true_all.sum())}  Negatives: {int(len(y_true_all)-y_true_all.sum())}")
    print(f"Prevalence: {y_true_all.mean():.4f}")

    if len(np.unique(y_true_all)) > 1:
        print("AUROC:", round(roc_auc_score(y_true_all, y_prob_all), 3))
        print("AUPRC:", round(average_precision_score(y_true_all, y_prob_all), 3))
    print("Brier:", round(brier_score_loss(y_true_all, y_prob_all), 4))

    # Probability distribution sanity check
    qs = np.quantile(y_prob_all, [0.0, 0.5, 0.9, 0.99, 1.0])
    print(f"Predicted prob quantiles [min,50%,90%,99%,max]: {np.round(qs, 4)}")

    # Report @ 0.5 threshold
    y_pred_05 = (y_prob_all >= 0.5).astype(int)
    print("\n=== Classification report @ threshold = 0.5 ===")
    print(classification_report(y_true_all, y_pred_05, digits=3))

    # Best-F1 threshold (pooled PR curve).
    # NOTE(review): the threshold is tuned on the same pooled predictions it is
    # then scored on, so this second report is an optimistic upper bound.
    thr, best_f1, best_p, best_r = best_f1_threshold(y_true_all, y_prob_all)
    y_pred_best = (y_prob_all >= thr).astype(int)
    print("\n=== Best-F1 threshold from PR curve ===")
    print(f"Chosen threshold: {thr:.4f}  (F1={best_f1:.3f}, Precision={best_p:.3f}, Recall={best_r:.3f})")
    print(classification_report(y_true_all, y_pred_best, digits=3))


def main():
    """
    Run the full pipeline: load CSVs, build the index-level dataset, evaluate
    an MLP with leave-one-CSV-out cross-validation, and print pooled metrics.

    Command-line options (all optional) are defined in _build_arg_parser;
    unrecognised arguments are reported and ignored. The function returns None
    and communicates only through printed output. It stops early, with a
    message, when no index rows are produced, when fewer than two source files
    contribute rows, or when no fold yields predictions.
    """
    args, unknown = _build_arg_parser().parse_known_args()
    if unknown:
        print(f"Ignoring unknown args: {unknown}")

    # Allow "--cluster-from-table none" to disable table restriction
    cluster_from_table = None if str(args.cluster_from_table).lower() == "none" else args.cluster_from_table

    df = _load_events(args.csv)

    # Diagnostic: how many files contribute to each subject (helps spot leakage)
    leakage_check = df.groupby("subject_id")["source_file"].nunique().value_counts().sort_index()
    print("\nNumber of CSV files per subject:")
    print(leakage_check)
    print("--------------------------------------------------\n")

    print(f"Unique subjects: {df['subject_id'].nunique(dropna=True)}")

    rows = _build_index_rows(df, args, cluster_from_table)

    # Check BEFORE building/sorting the frame: an empty list gives a frame with
    # no columns, and sorting that by "subject_id" raises KeyError instead of
    # reaching this message.
    if not rows:
        print("No index rows produced. Check your cluster/index settings and data coverage.")
        return

    data = (
        pd.DataFrame(rows)
        .sort_values(["subject_id", "index_time"])
        .reset_index(drop=True)
    )

    pos = int(data["label"].sum())
    neg = int(len(data) - pos)
    print(f"Dataset size (pre-cap): {len(data)}  Positives: {pos}  Negatives: {neg}")

    # Optional: cap negatives per subject (default off)
    data = cap_negatives_per_subject(data, args.max_neg_per_subject, seed=42)
    pos = int(data["label"].sum())
    neg = int(len(data) - pos)
    print(f"Dataset size (post-cap): {len(data)}  Positives: {pos}  Negatives: {neg}")

    meta_cols = ["subject_id", "index_time", "index_day", "label", "source_file"]
    feature_cols = [c for c in data.columns if c not in meta_cols]

    # Missingness indicators are added BEFORE imputation (captures informative
    # missingness). Built in one step and appended in feature order.
    missing_flags = data[feature_cols].isna().astype(int).add_suffix("_is_missing")
    data = pd.concat([data, missing_flags], axis=1)

    # Refresh feature columns to include missingness flags
    feature_cols = [c for c in data.columns if c not in meta_cols]

    files = sorted(data["source_file"].dropna().unique().tolist())
    if len(files) < 2:
        print("Not enough distinct source_file values to run leave-one-CSV-out CV.")
        return

    print(f"\nLabel horizon: {args.label_start_days}..{args.label_end_days} days")
    print(f"Negative buffer before cluster starts: {args.negative_buffer_before_cluster_days} days\n")

    # CV loop: each fold holds out one CSV file
    y_true_folds, y_prob_folds = _run_leave_one_file_out(data, feature_cols, files, args)

    if not y_true_folds:
        print("No CV folds produced predictions (check class balance per fold / source_file assignment).")
        return

    # Pool out-of-fold predictions and report once
    _print_pooled_report(np.concatenate(y_true_folds), np.concatenate(y_prob_folds))


# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


IT IS RUNNING!
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_0.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_1.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_2.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_3.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_4.csv
Combined rows: 10502789  Files loaded: 5

Number of CSV files per subject:
source_file
1    946
Name: count, dtype: int64
--------------------------------------------------

Unique subjects: 946
Dataset size (pre-cap): 62730  Positives: 1464  Negatives: 61266
Dataset size (post-cap): 62730  Positives: 1464  Negatives: 