In [3]:
# Start-up marker so it is obvious the cell/script has begun executing.
# NOTE(review): because this statement comes first, the string literal below is a
# plain expression statement, not the module docstring (the module's __doc__ stays None).
print('IT IS RUNNING!')
"""
Neural network version (MLPClassifier: 1 hidden layer, 12 nodes, probabilistic output)

Pipeline:
1) Load CSVs + tag with source_file
2) Build index rows + features + labels
3) Leave-one-CSV-out CV
4) Impute + scale (NN needs scaling)
5) Oversample positives in the TRAIN fold only
6) Fit MLPClassifier(12 hidden units), predict_proba, pool metrics
"""

import argparse
from pathlib import Path
import math

import numpy as np
import pandas as pd

from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    brier_score_loss,
    classification_report,
    precision_recall_curve,
)
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier


# DEFAULT SETTINGS
# Input CSV exports. main() tags every row with its file name and holds out one
# file at a time during cross-validation.
# NOTE(review): these are absolute paths on one developer machine; pass --csv to override.
DEFAULT_CSV_PATHS = ["/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_0.csv",
                     "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_1.csv",
                     "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_2.csv",
                     "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_3.csv",
                     "/Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_4.csv"]

# Default for --test-size. NOTE(review): main() parses this option but never reads it.
DEFAULT_TEST_SIZE = 0.2
# Width (days) of the look-back window behind the per-code median / IQR / count features.
DEFAULT_LOOKBACK_DAYS = 90
# Minimum gap (days) since the previous recorded day for a day to become an index day;
# also the minimum quiet gap that must precede a dense run for it to count as a cluster.
DEFAULT_SPARSE_GAP_DAYS = 6
# Largest gap (days) between consecutive recorded days that keeps them in the same dense run.
DEFAULT_DENSE_GAP_DAYS = 2
# Minimum number of distinct recorded days a dense run needs to count as a cluster.
DEFAULT_MIN_CLUSTER_DAYS = 4
# Value of the "table" column whose rows drive cluster detection
# (passing "none" on the command line disables the filter).
DEFAULT_CLUSTER_FROM_TABLE = "measurement"
# label_index_row(): a cluster counts as a positive only if it starts at least
# (this value + 1) days after the index time ...
EXCLUSION_BEFORE_CLUSTER_DAYS = 7
# ... and no later than this many days after the index time.
PREDICTION_HORIZON_DAYS = 90

# Codes matched against the "code" column when building features.
# NOTE(review): verify the two SNOMED codes against a SNOMED CT browser. To the
# reviewer's knowledge 271649006 / 271650006 are the systolic / diastolic blood
# pressure observables, which would not match the GLUCOSE / HBA1C names below.
CODE_SBP = "LOINC/8480-6"
CODE_HR = "LOINC/8867-4"
CODE_GLUCOSE = "SNOMED/271649006"
CODE_HBA1C = "SNOMED/271650006"


# ---------- helper: oversample minority class in training fold
def oversample_minority(X, y, random_state=42):
    """
    Balance a binary training set by duplicating randomly drawn minority rows.

    X: numpy array (n_samples, n_features)
    y: numpy array (n_samples,) holding 0/1 labels (cast to int here)
    random_state: seed for the numpy Generator used for drawing and shuffling

    Returns (X_balanced, y_balanced) in shuffled order. If one class is
    missing, or both classes already have the same size, the inputs are
    returned as-is (unshuffled, with y cast to int).
    """
    generator = np.random.default_rng(random_state)
    y = y.astype(int)

    positives = np.flatnonzero(y == 1)
    negatives = np.flatnonzero(y == 0)

    # With a single class present there is nothing to balance against.
    if positives.size == 0 or negatives.size == 0:
        return X, y

    # Pick the smaller class and work out how many rows it is short by.
    if positives.size < negatives.size:
        minority, shortfall = positives, negatives.size - positives.size
    else:
        minority, shortfall = negatives, positives.size - negatives.size

    if shortfall <= 0:
        return X, y

    # Draw duplicates (with replacement) from the minority rows.
    duplicates = generator.choice(minority, size=shortfall, replace=True)
    X_out = np.vstack([X, X[duplicates]])
    y_out = np.concatenate([y, y[duplicates]])

    # Shuffle so the appended duplicates are not all at the end.
    order = generator.permutation(len(y_out))
    return X_out[order], y_out[order]


# ---------- data loading, cluster detection, index days, features, labels ----------
def parse_mixed_datetime(series):
    """
    Parse a column that mixes two timestamp notations into datetimes.

    Values containing "/" are parsed strictly as "%d/%m/%Y %H:%M"; every other
    value is handed to pandas' general parser (month-first). Anything that
    fails to parse, including slash values that do not match the strict
    format, becomes NaT.
    """
    text = series.astype("string")
    has_slash = text.str.contains("/", na=False)

    # Each pass only sees its own subset; the other entries are blanked to <NA>.
    slash_part = pd.to_datetime(
        text.where(has_slash),
        format="%d/%m/%Y %H:%M",
        errors="coerce",
    )
    other_part = pd.to_datetime(
        text.mask(has_slash),
        errors="coerce",
        dayfirst=False,
    )
    return slash_part.fillna(other_part)


def load_csv_simple(path):
    """
    Read one CSV export and return it with parsed columns, ordered by subject and time.

    "time" is converted with parse_mixed_datetime and "numeric_value" is coerced
    to numbers (unparseable entries become NaN). The result is sorted by
    ("subject_id", "time") and given a fresh 0..n-1 index.
    """
    print("Loading CSV from", path)
    frame = pd.read_csv(path, sep=",", encoding="utf-8-sig", low_memory=False)

    frame["time"] = parse_mixed_datetime(frame["time"])
    # .get() tolerates a file without the column; it then becomes all-NaN.
    raw_numeric = frame.get("numeric_value")
    frame["numeric_value"] = pd.to_numeric(raw_numeric, errors="coerce")

    ordered = frame.sort_values(["subject_id", "time"], kind="mergesort")
    return ordered.reset_index(drop=True)


def _append_cluster_if_qualifying(clusters, unique_days, day_diffs, run_start, run_end,
                                  min_cluster_days, sparse_gap_days):
    """Append (start_day, end_day, n_days) for one dense run if it is long enough and follows a sparse gap."""
    run_length = run_end - run_start + 1
    if run_length < min_cluster_days:
        return
    # The very first run has no predecessor, so its preceding gap is treated as infinite.
    preceding_gap = math.inf if run_start == 0 else day_diffs[run_start - 1]
    if preceding_gap >= sparse_gap_days:
        clusters.append(
            (unique_days[run_start].normalize(), unique_days[run_end].normalize(), run_length)
        )


def find_clusters_for_patient(patient_df, sparse_gap_days, dense_gap_days, min_cluster_days, cluster_from_table):
    """
    Find dense runs of recorded days ("clusters") for a single patient.

    A run is a maximal sequence of distinct calendar days in which consecutive
    days are at most `dense_gap_days` apart. A run is reported as a cluster when
    it spans at least `min_cluster_days` distinct days and the gap between its
    first day and the previous recorded day is at least `sparse_gap_days` (the
    first run of the patient always passes this check).

    Parameters
    ----------
    patient_df : DataFrame with a datetime "time" column and, optionally, a
        "table" column.
    sparse_gap_days, dense_gap_days, min_cluster_days : thresholds in days.
    cluster_from_table : if not None and a "table" column exists, only rows with
        that table value are used; when that leaves nothing, all rows with a
        timestamp are used instead.

    Returns
    -------
    list of (start_day, end_day, n_distinct_days) tuples with midnight
    Timestamps, in chronological order. Empty when nothing qualifies.
    """
    has_time = patient_df["time"].notna()
    if cluster_from_table is not None and "table" in patient_df.columns:
        chosen = patient_df[(patient_df["table"] == cluster_from_table) & has_time]
        if chosen.empty:
            # Fall back to every timestamped row when the requested table is absent.
            chosen = patient_df[has_time]
    else:
        chosen = patient_df[has_time]

    if chosen.empty:
        return []

    # Going through DatetimeIndex guarantees Timestamp elements regardless of
    # whether Series.unique() hands back a numpy array or a pandas array.
    unique_days = list(pd.DatetimeIndex(chosen["time"].dt.normalize().unique()).sort_values())

    clusters = []
    if len(unique_days) < min_cluster_days:
        return clusters

    # day_diffs[k] is the gap between unique_days[k] and unique_days[k + 1].
    day_diffs = [(later - earlier).days for earlier, later in zip(unique_days, unique_days[1:])]

    run_start = 0
    for i, gap in enumerate(day_diffs, start=1):
        if gap <= dense_gap_days:
            continue  # still inside the current dense run
        # The gap before day i is too large: the run ended on day i - 1.
        _append_cluster_if_qualifying(
            clusters, unique_days, day_diffs, run_start, i - 1, min_cluster_days, sparse_gap_days
        )
        run_start = i

    # Close the run that extends to the last recorded day.
    _append_cluster_if_qualifying(
        clusters, unique_days, day_diffs, run_start, len(unique_days) - 1, min_cluster_days, sparse_gap_days
    )
    return clusters


def build_index_days(patient_df, clusters, sparse_gap_days):
    """
    Pick the "index" days of one patient: recorded days that lie outside every
    cluster and come at least `sparse_gap_days` after the previous recorded day
    (the first recorded day always satisfies the gap condition).

    Parameters
    ----------
    patient_df : DataFrame with "subject_id" and "time" columns and, optionally,
        a "source_file" column.
    clusters : iterable of (start_day, end_day, n_days) tuples; every calendar
        day from start_day to end_day inclusive is treated as in-cluster.
    sparse_gap_days : minimum gap in days to the previous recorded day.

    Returns
    -------
    list of (subject_id, first_time_of_day, day, source_file) tuples in
    chronological order. `source_file` is taken from the earliest record of the
    day (first in input order on ties) or is "unknown" when the column is absent.
    """
    df = patient_df.dropna(subset=["time"]).copy()
    if df.empty:
        return []
    df["time"] = pd.to_datetime(df["time"], errors="coerce")
    df = df.dropna(subset=["time"])
    df["day"] = df["time"].dt.normalize()

    cluster_day_set = {
        pd.Timestamp(d).normalize()
        for start_day, end_day, _ in clusters
        for d in pd.date_range(start_day, end_day, freq="D")
    }

    # One pass instead of re-filtering the frame for every day: a stable sort by
    # time keeps input order among equal timestamps, so the first row kept per
    # day is that day's earliest record (ties broken by input order).
    first_rows = df.sort_values("time", kind="mergesort").drop_duplicates(subset="day", keep="first")
    day_values = first_rows["day"].tolist()
    first_times = first_rows["time"].tolist()
    if "source_file" in first_rows.columns:
        sources = first_rows["source_file"].tolist()
    else:
        sources = ["unknown"] * len(day_values)

    sid = patient_df["subject_id"].iloc[0]
    index_list, prev_day = [], None

    for current_day, first_time, src_file in zip(day_values, first_times, sources):
        gap_days = math.inf if prev_day is None else (current_day - prev_day).days
        if current_day not in cluster_day_set and gap_days >= sparse_gap_days:
            index_list.append((sid, first_time, current_day, src_file))
        # The gap is always measured to the previous recorded day, index day or not.
        prev_day = current_day

    return index_list


def build_features_for_index(patient_df, index_time, lookback_days):
    """
    Build the feature dict for one index time of one patient.

    For each tracked code (CODE_SBP, CODE_HR, CODE_GLUCOSE, CODE_HBA1C) five
    entries are produced, keyed by the code with "/" replaced by "_":

    * ``<code>_last``                most recent numeric_value at or before index_time
    * ``<code>_days_since_last``     whole days between that record and index_time
    * ``<code>_median_<N>d``         median of non-null values in the look-back window
    * ``<code>_iqr_<N>d``            Q3 - Q1 over the window (needs at least 4 values)
    * ``<code>_count_<N>d``          number of non-null values in the window

    The window is [index_time - lookback_days, index_time], inclusive on both
    ends. Missing quantities are NaN; the count is 0 when the window is empty.
    If a "table" column exists, only rows whose table is "measurement" are used.
    """
    measurements = patient_df.copy()
    if "table" in measurements.columns:
        measurements = measurements[measurements["table"] == "measurement"].copy()

    window_start = index_time - pd.Timedelta(days=lookback_days)
    features = {}

    for code in (CODE_SBP, CODE_HR, CODE_GLUCOSE, CODE_HBA1C):
        prefix = code.replace("/", "_")
        timed = measurements["time"].notna()
        records = measurements[(measurements["code"] == code) & timed].copy()

        # --- most recent observation at or before the index time ---
        history = records[records["time"] <= index_time].sort_values("time")
        if history.empty:
            latest_value, days_since_latest = np.nan, np.nan
        else:
            latest = history.iloc[-1]
            latest_value = latest["numeric_value"]
            days_since_latest = (index_time - latest["time"]).days

        # --- summary statistics over the look-back window ---
        in_window = (records["time"] >= window_start) & (records["time"] <= index_time)
        values = records[in_window]["numeric_value"].dropna()

        n_values = 0
        median_value = np.nan
        iqr_value = np.nan
        if not values.empty:
            median_value = float(values.median())
            n_values = int(len(values))
            # An IQR over fewer than four points is left undefined.
            if n_values >= 4:
                lower = float(values.quantile(0.25))
                upper = float(values.quantile(0.75))
                iqr_value = upper - lower

        # Insertion order defines the feature column order downstream.
        features[f"{prefix}_last"] = latest_value
        features[f"{prefix}_days_since_last"] = days_since_latest
        features[f"{prefix}_median_{lookback_days}d"] = median_value
        features[f"{prefix}_iqr_{lookback_days}d"] = iqr_value
        features[f"{prefix}_count_{lookback_days}d"] = n_values

    return features


def label_index_row(index_time, clusters):
    """
    Return 1 if any cluster starts inside the prediction horizon of index_time, else 0.

    The horizon spans from midnight of
    (index_time + EXCLUSION_BEFORE_CLUSTER_DAYS + 1 days) to midnight of
    (index_time + PREDICTION_HORIZON_DAYS days), both ends inclusive.
    `clusters` is a sequence of (start, end, n_days) tuples; only the start is used.
    """
    if not clusters:
        return 0

    earliest = (index_time + pd.Timedelta(days=EXCLUSION_BEFORE_CLUSTER_DAYS + 1)).normalize()
    latest = (index_time + pd.Timedelta(days=PREDICTION_HORIZON_DAYS)).normalize()

    starts_in_horizon = any(
        (cluster_start >= earliest) and (cluster_start <= latest)
        for cluster_start, _, _ in clusters
    )
    return 1 if starts_in_horizon else 0


def best_f1_threshold(y_true, y_prob):
    """
    Find the probability threshold that maximises F1 along the precision-recall curve.

    Returns (threshold, f1, precision, recall) as plain floats. If the best
    point is the curve's final point, which has no associated threshold,
    the threshold is reported as 1.0.
    """
    prec, rec, cutoffs = precision_recall_curve(y_true, y_prob)
    # The small epsilon keeps the ratio finite where precision + recall == 0.
    f1_scores = (2 * prec * rec) / (prec + rec + 1e-12)
    top = int(np.nanargmax(f1_scores))

    # The curve has one more (precision, recall) point than it has thresholds.
    if top < len(cutoffs):
        chosen = float(cutoffs[top])
    else:
        chosen = 1.0

    return chosen, float(f1_scores[top]), float(prec[top]), float(rec[top])


def main():
    """
    Run the whole experiment from the command line or a notebook cell.

    Steps:
    1. Parse options and load every CSV, tagging rows with their file name.
    2. Per patient: detect clusters, pick index days, build features and labels.
    3. Leave-one-CSV-out cross-validation: median-impute and standardise using
       the training fold, oversample the minority class in the training fold
       only, fit a 12-unit MLPClassifier and predict on the held-out file.
    4. Pool the held-out predictions and print AUROC, AUPRC, Brier score,
       probability quantiles and classification reports at 0.5 and at the
       best-F1 threshold.

    Returns None. Prints a message and returns early when no index rows are
    produced, when fewer than two source files are present, or when no fold
    yields predictions.
    """
    parser = argparse.ArgumentParser(description="Beginner temporal hospitalization model (multi-CSV)")
    # Hidden option; presumably absorbs the "-f <connection file>" argument a
    # Jupyter kernel is started with, so the parser also works inside a notebook.
    parser.add_argument("-f", default=None, help=argparse.SUPPRESS)
    parser.add_argument("--csv", nargs="+", default=DEFAULT_CSV_PATHS, help="One or more CSV file paths.")
    # NOTE(review): --test-size is accepted for compatibility but is not used below;
    # evaluation is leave-one-CSV-out, not a random split.
    parser.add_argument("--test-size", type=float, default=DEFAULT_TEST_SIZE)
    parser.add_argument("--lookback-days", type=int, default=DEFAULT_LOOKBACK_DAYS)
    parser.add_argument("--sparse-gap-days", type=int, default=DEFAULT_SPARSE_GAP_DAYS)
    parser.add_argument("--dense-gap-days", type=int, default=DEFAULT_DENSE_GAP_DAYS)
    parser.add_argument("--min-cluster-days", type=int, default=DEFAULT_MIN_CLUSTER_DAYS)
    parser.add_argument("--cluster-from-table", default=DEFAULT_CLUSTER_FROM_TABLE)
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unknown args: {unknown}")

    # The literal string "none" (any case) switches the table filter off.
    cluster_from_table = None if str(args.cluster_from_table).lower() == "none" else args.cluster_from_table

    dfs = []
    for path in args.csv:
        p = Path(path).expanduser()
        part = load_csv_simple(str(p))
        part["source_file"] = p.name
        dfs.append(part)

    df = (pd.concat(dfs, ignore_index=True)
          .sort_values(["subject_id", "time"])
          .reset_index(drop=True))
    print(f"Combined rows: {len(df)}  Files loaded: {len(dfs)}")

    subjects = df["subject_id"].dropna().unique()
    print(f"Unique subjects: {len(subjects)}")

    rows = []
    for _, patient in df.groupby("subject_id", sort=False):
        clusters = find_clusters_for_patient(
            patient,
            sparse_gap_days=args.sparse_gap_days,
            dense_gap_days=args.dense_gap_days,
            min_cluster_days=args.min_cluster_days,
            cluster_from_table=cluster_from_table
        )
        index_days = build_index_days(patient, clusters, sparse_gap_days=args.sparse_gap_days)

        for sid2, idx_time, idx_day, src_file in index_days:
            feats = build_features_for_index(patient, idx_time, lookback_days=args.lookback_days)
            label = label_index_row(idx_time, clusters)
            row = {
                "subject_id": sid2,
                "index_time": idx_time,
                "index_day": idx_day,
                "label": label,
                "source_file": src_file
            }
            row.update(feats)
            rows.append(row)

    # Without this guard an empty frame has no columns and sort_values raises KeyError.
    if not rows:
        print("No index rows were generated from the input data; nothing to model.")
        return

    data = (pd.DataFrame(rows)
            .sort_values(["subject_id", "index_time"])
            .reset_index(drop=True))

    print(f"Dataset size: {len(data)}  Positives: {data['label'].sum()}  Negatives: {len(data)-data['label'].sum()}")

    meta_cols = ["subject_id", "index_time", "index_day", "label", "source_file"]
    feature_cols = [c for c in data.columns if c not in meta_cols]

    files = sorted(data["source_file"].dropna().unique().tolist())
    if len(files) < 2:
        print("Not enough distinct source_file values to run leave-one-CSV-out CV.")
        return

    y_true_all = []
    y_prob_all = []

    for test_file in files:
        test_df = data[data["source_file"] == test_file]
        train_df = data[data["source_file"] != test_file]

        if test_df.empty or train_df.empty:
            print(f"Skipping fold {test_file}: empty train or test split.")
            continue

        X_train = train_df[feature_cols].to_numpy(dtype=float)
        y_train = train_df["label"].astype(int).to_numpy()
        X_test = test_df[feature_cols].to_numpy(dtype=float)
        y_test = test_df["label"].astype(int).to_numpy()

        if len(np.unique(y_train)) < 2:
            print(f"Skipping fold {test_file}: training labels contain a single class.")
            continue

        # Impute + scale, fitted on the training fold only to avoid test leakage.
        imputer = SimpleImputer(strategy="median")
        scaler = StandardScaler()

        X_train_imp = imputer.fit_transform(X_train)
        X_test_imp = imputer.transform(X_test)

        X_train_scaled = scaler.fit_transform(X_train_imp)
        X_test_scaled = scaler.transform(X_test_imp)

        # Oversample minority class in TRAIN fold only (fix imbalance without sample_weight)
        X_train_bal, y_train_bal = oversample_minority(X_train_scaled, y_train, random_state=42)

        # NOTE(review): early_stopping carves its validation split out of the
        # oversampled data, so duplicated rows can sit on both sides of that split.
        model = MLPClassifier(
            hidden_layer_sizes=(12,),
            activation="relu",
            solver="adam",
            alpha=1e-4,
            learning_rate_init=1e-3,
            max_iter=500,
            early_stopping=True,
            validation_fraction=0.2,
            n_iter_no_change=20,
            random_state=42
        )

        model.fit(X_train_bal, y_train_bal)
        y_prob = model.predict_proba(X_test_scaled)[:, 1]

        y_true_all.append(y_test)
        y_prob_all.append(y_prob)

    if not y_true_all:
        print("No CV folds produced predictions (check class balance per fold / source_file assignment).")
        return

    # Pool the held-out predictions of all folds before computing metrics.
    y_true_all = np.concatenate(y_true_all)
    y_prob_all = np.concatenate(y_prob_all)

    print(f"Test rows: {len(y_true_all)}  Positives: {int(y_true_all.sum())}  Negatives: {int(len(y_true_all)-y_true_all.sum())}")
    print(f"Prevalence: {y_true_all.mean():.4f}")

    # Ranking metrics are undefined when only one class is present.
    if len(np.unique(y_true_all)) > 1:
        print("AUROC:", round(roc_auc_score(y_true_all, y_prob_all), 3))
        print("AUPRC:", round(average_precision_score(y_true_all, y_prob_all), 3))

    print("Brier:", round(brier_score_loss(y_true_all, y_prob_all), 4))

    qs = np.quantile(y_prob_all, [0.0, 0.5, 0.9, 0.99, 1.0])
    print(f"Predicted prob quantiles [min,50%,90%,99%,max]: {np.round(qs, 4)}")

    y_pred_05 = (y_prob_all >= 0.5).astype(int)
    print("\n=== Classification report @ threshold = 0.5 ===")
    print(classification_report(y_true_all, y_pred_05, digits=3))

    # NOTE(review): the threshold is tuned on the same pooled held-out predictions
    # it is then evaluated on, so this second report is optimistic.
    thr, best_f1, best_p, best_r = best_f1_threshold(y_true_all, y_prob_all)
    y_pred_best = (y_prob_all >= thr).astype(int)
    print("\n=== Best-F1 threshold from PR curve ===")
    print(f"Chosen threshold: {thr:.4f}  (F1={best_f1:.3f}, Precision={best_p:.3f}, Recall={best_r:.3f})")
    print(classification_report(y_true_all, y_pred_best, digits=3))


# Run the pipeline only when this file is executed directly (script or notebook
# cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

IT IS RUNNING!
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_0.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_1.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_2.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_3.csv
Loading CSV from /Users/colleenohare/Desktop/Bioinformatics_MSC/RETFOUND/Chris_Sainsbury/Inspect_Dataset/inspect_data_csv/data_4.csv
Combined rows: 10502789  Files loaded: 5
Unique subjects: 946
Dataset size: 62171  Positives: 4058  Negatives: 58113
Test rows: 62171  Positives: 4058  Negatives: 58113
Prevalence: 0.0653
AUROC: 0.641
AUPRC: 0.118
Brier: 0.2176
Predicted prob quantiles [min,50%,90%,99%,max]: [7.000e-04 4.039e-01 6.743e-01 8.738e-01 9.9