In [None]:
import scipy.io
import torch
from thoi.measures.gaussian_copula import nplets_measures
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from functools import partial
import h5py
from collections import defaultdict
import time
import itertools
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.feature_selection import f_classif
import logging
from tqdm import tqdm
import ast

# Root directory holding the covariance HDF5 file and every result CSV.
results_path = "C:/CAMILO/Brain_Entropy/RESULTS"

# Run torch work on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.cuda.empty_cache()


def print_time(t_i, t_f):
    """Print the span between two timestamps as ``Elapsed time: HH:MM:SS``.

    Parameters
    ----------
    t_i, t_f : float
        Start and end times in seconds (e.g. values from ``time.time()``).
        Each field is truncated to an integer before printing.
    """
    total = t_f - t_i
    hours, remainder = divmod(total, 3600)
    minutes = remainder // 60
    seconds = total % 60
    print(f"Elapsed time: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}")


def load_covariance_dict(filepath, lazy=False):
    """Read an HDF5 file laid out as ``/<dataset>/<state>`` into nested dicts.

    Parameters
    ----------
    filepath : str
        Path to the .h5 file.
    lazy : bool
        False (default): every HDF5 dataset is copied into memory with ``[:]``
        and the file is closed before returning.
        True: the values are the live ``h5py.Dataset`` objects; the open file
        object is stored under the key ``"_h5file"`` so it stays referenced.

    Returns
    -------
    covs : collections.defaultdict
        Two-level mapping ``covs[dataset][state]`` -> ndarray (eager) or
        h5py.Dataset (lazy).
    """
    covs = defaultdict(dict)

    def _fill(handle, materialize):
        # Walk both levels of the file and store each leaf node.
        for ds_name, group in handle.items():
            for state_name, node in group.items():
                covs[ds_name][state_name] = node[:] if materialize else node

    if not lazy:
        with h5py.File(filepath, "r") as handle:
            _fill(handle, materialize=True)
        return covs

    handle = h5py.File(filepath, "r")
    # Keep a reference to the file object so it is not garbage-collected
    # (and closed) while the returned Dataset handles are still in use.
    covs["_h5file"] = handle
    _fill(handle, materialize=False)
    return covs


def evaluate_nplet_batched(
    idx,
    all_covs,
    conscious_states,
    nonresponsive_states,
    selected_dataset,
    optimal_nplet,
    state_c,
    state_nr,
    subject_c,
    subject_nr,
    device,
):
    """
    Score how well one n-plet separates conscious from nonresponsive subjects
    within a single dataset.

    All subject covariance matrices of ``selected_dataset`` (minus the held-out
    discovery subject of each group) go through one batched
    ``nplets_measures`` call. That call yields TC, DTC, O and S, and
    ``norm_O = O / S`` is derived from them. Each of the five features is then
    scored against the labels conscious = 1 / nonresponsive = 0.

    Parameters
    ----------
    idx : hashable
        Row identifier copied into every output record as ``row_idx``.
    all_covs : mapping
        ``all_covs[dataset][state]`` -> array-like of shape (n_subj, N, N).
    conscious_states, nonresponsive_states : mapping
        ``dataset -> list of state names`` for each group.
    selected_dataset : str
        Dataset key whose states are evaluated.
    optimal_nplet : sequence of int
        Variable indices forming the n-plet.
    state_c, state_nr : object
        State holding each group's discovery subject. It is compared to each
        state name with ``==``.
        NOTE(review): earlier comments described these as "tuple-like
        (ds, st)". A tuple never equals a state-name string, so if the column
        actually holds tuples, no subject is left out. Confirm the stored
        format.
    subject_c, subject_nr : int
        Subject index within that state to leave out of each group.
    device : str or torch.device
        Device forwarded to ``nplets_measures``.

    Returns
    -------
    list of dict
        One record per measure ("TC", "DTC", "O", "S", "norm_O"), each with
        keys ``row_idx``, ``measure``, ``F_score``, ``PR_AUC``, ``PR_AUC_inv``.

    Raises
    ------
    ValueError
        If either group is empty after leaving out the discovery subject.
    """
    measure_list = ["TC", "DTC", "O", "S", "norm_O"]

    # --- Gather covariance matrices per group, leaving out the discovery subject
    covs_conscious = []
    for st in conscious_states[selected_dataset]:
        covs = all_covs[selected_dataset][st]  # (n_subj, N, N)
        for subj_idx in range(covs.shape[0]):
            if st == state_c and subj_idx == subject_c:
                continue
            covs_conscious.append(np.asarray(covs[subj_idx]))

    covs_nonresp = []
    for st in nonresponsive_states[selected_dataset]:
        covs = all_covs[selected_dataset][st]
        for subj_idx in range(covs.shape[0]):
            if st == state_nr and subj_idx == subject_nr:
                continue
            covs_nonresp.append(np.asarray(covs[subj_idx]))

    n_c = len(covs_conscious)
    n_nr = len(covs_nonresp)

    if n_c == 0 or n_nr == 0:
        raise ValueError(
            "Empty group after skipping discovery subject(s). Check inputs."
        )

    # --- Single batched call: conscious subjects first, then nonresponsive.
    # The label vector below relies on exactly this order.
    X_tensor = torch.tensor(np.array(covs_conscious + covs_nonresp))
    measures = nplets_measures(
        X_tensor,
        nplets=[optimal_nplet],  # single n-plet
        covmat_precomputed=True,
        T=None,
        device=device,
        verbose=logging.WARNING,
    )
    # First axis indexes the n-plet (only one here). Keep the first four
    # measure columns: TC, DTC, O, S.
    vals = measures[0, :, :4].detach().cpu().numpy()  # (D, 4)

    # norm_O = O / S. A zero S yields inf/nan; silence the warning and let the
    # scoring below turn such columns into NaN scores.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = vals[:, 2] / vals[:, 3]
    vals = np.column_stack((vals, ratio))  # (D, 5)

    # --- Labels: 1 for conscious, 0 for nonresponsive
    y = np.concatenate([np.ones(n_c, dtype=int), np.zeros(n_nr, dtype=int)])

    # One-way ANOVA F per feature. A constant feature has an undefined F, so
    # report 0 for it. This covers every column, norm_O included.
    F_vals, _ = f_classif(vals, y)  # (5,)
    for j in range(vals.shape[1]):
        if np.allclose(vals[:, j].var(), 0):
            F_vals[j] = 0.0

    # Average precision with the raw score (high => conscious) and with the
    # negated score (low => conscious). Both values are computed before
    # appending, so a failure in either call never misaligns the lists.
    pr_aucs = []
    neg_pr_aucs = []
    for j in range(vals.shape[1]):
        xj = vals[:, j]
        try:
            ap = average_precision_score(y, xj)
            ap_inv = average_precision_score(y, -xj)
        except ValueError:
            ap = ap_inv = np.nan
        pr_aucs.append(ap)
        neg_pr_aucs.append(ap_inv)

    return [
        {
            "row_idx": idx,
            "measure": measure_,
            "F_score": F_vals[jdx],
            "PR_AUC": pr_aucs[jdx],
            "PR_AUC_inv": neg_pr_aucs[jdx],
        }
        for jdx, measure_ in enumerate(measure_list)
    ]


def evaluate_nplet_batched_both_datasets(
    idx,
    all_covs,
    conscious_states,
    nonresponsive_states,
    optimal_nplet,
    state_c,
    state_nr,
    subject_c,
    subject_nr,
    device,
):
    """
    Score how well one n-plet separates conscious from nonresponsive subjects,
    pooling subjects across every dataset listed in the state mappings.

    Every subject covariance matrix of every (dataset, state) pair in
    ``conscious_states`` / ``nonresponsive_states`` goes through one batched
    ``nplets_measures`` call; the held-out discovery subject is skipped. That
    call yields TC, DTC, O and S, and ``norm_O = O / S`` is derived from them.
    Each of the five features is then scored against the labels
    conscious = 1 / nonresponsive = 0.

    Parameters
    ----------
    idx : hashable
        Row identifier copied into every output record as ``row_idx``.
    all_covs : mapping
        ``all_covs[dataset][state]`` -> array-like of shape (n_subj, N, N).
    conscious_states, nonresponsive_states : mapping
        ``dataset -> list of state names`` for each group; all datasets are
        pooled.
    optimal_nplet : sequence of int
        Variable indices forming the n-plet.
    state_c, state_nr : object
        State holding each group's discovery subject. It is compared to each
        state name with ``==`` only; the dataset key is not checked. A state
        name shared by two datasets would therefore drop that subject index
        from both.
        NOTE(review): earlier comments described these as "tuple-like
        (ds, st)". A tuple never equals a state-name string. Confirm the
        stored format.
    subject_c, subject_nr : int
        Subject index within that state to leave out of each group.
    device : str or torch.device
        Device forwarded to ``nplets_measures``.

    Returns
    -------
    list of dict
        One record per measure ("TC", "DTC", "O", "S", "norm_O"), each with
        keys ``row_idx``, ``measure``, ``F_score``, ``PR_AUC``, ``PR_AUC_inv``.

    Raises
    ------
    ValueError
        If either group is empty after leaving out the discovery subject.
    """
    measure_list = ["TC", "DTC", "O", "S", "norm_O"]

    # --- Gather covariance matrices per group across all datasets
    covs_conscious = []
    for ds, states in conscious_states.items():
        for st in states:
            covs = all_covs[ds][st]  # (n_subj, N, N)
            for subj_idx in range(covs.shape[0]):
                if st == state_c and subj_idx == subject_c:
                    continue
                covs_conscious.append(np.asarray(covs[subj_idx]))

    covs_nonresp = []
    for ds, states in nonresponsive_states.items():
        for st in states:
            covs = all_covs[ds][st]
            for subj_idx in range(covs.shape[0]):
                if st == state_nr and subj_idx == subject_nr:
                    continue
                covs_nonresp.append(np.asarray(covs[subj_idx]))

    n_c = len(covs_conscious)
    n_nr = len(covs_nonresp)

    if n_c == 0 or n_nr == 0:
        raise ValueError(
            "Empty group after skipping discovery subject(s). Check inputs."
        )

    # --- Single batched call: conscious subjects first, then nonresponsive.
    # The label vector below relies on exactly this order.
    X_tensor = torch.tensor(np.array(covs_conscious + covs_nonresp))
    measures = nplets_measures(
        X_tensor,
        nplets=[optimal_nplet],  # single n-plet
        covmat_precomputed=True,
        T=None,
        device=device,
        verbose=logging.WARNING,
    )
    # First axis indexes the n-plet (only one here). Keep the first four
    # measure columns: TC, DTC, O, S.
    vals = measures[0, :, :4].detach().cpu().numpy()  # (D, 4)

    # norm_O = O / S. A zero S yields inf/nan; silence the warning and let the
    # scoring below turn such columns into NaN scores.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = vals[:, 2] / vals[:, 3]
    vals = np.column_stack((vals, ratio))  # (D, 5)

    # --- Labels: 1 for conscious, 0 for nonresponsive
    y = np.concatenate([np.ones(n_c, dtype=int), np.zeros(n_nr, dtype=int)])

    # One-way ANOVA F per feature. A constant feature has an undefined F, so
    # report 0 for it. This covers every column, norm_O included.
    F_vals, _ = f_classif(vals, y)  # (5,)
    for j in range(vals.shape[1]):
        if np.allclose(vals[:, j].var(), 0):
            F_vals[j] = 0.0

    # Average precision with the raw score (high => conscious) and with the
    # negated score (low => conscious). Both values are computed before
    # appending, so a failure in either call never misaligns the lists.
    pr_aucs = []
    neg_pr_aucs = []
    for j in range(vals.shape[1]):
        xj = vals[:, j]
        try:
            ap = average_precision_score(y, xj)
            ap_inv = average_precision_score(y, -xj)
        except ValueError:
            ap = ap_inv = np.nan
        pr_aucs.append(ap)
        neg_pr_aucs.append(ap_inv)

    return [
        {
            "row_idx": idx,
            "measure": measure_,
            "F_score": F_vals[jdx],
            "PR_AUC": pr_aucs[jdx],
            "PR_AUC_inv": neg_pr_aucs[jdx],
        }
        for jdx, measure_ in enumerate(measure_list)
    ]


def evaluate_nplet_batched_FC(
    idx,
    all_covs,
    conscious_states,
    nonresponsive_states,
    selected_dataset,
    optimal_nplet,
    state_c,
    state_nr,
    subject_c,
    subject_nr,
    device,
):
    """
    Score one n-plet with higher-order measures and classical FC within a
    single dataset.

    All subject covariance matrices of ``selected_dataset`` (minus the held-out
    discovery subject of each group) are stacked into one tensor on
    ``device``. From that tensor the function computes:

    * TC, DTC, O, S via one batched ``nplets_measures`` call;
    * ``norm_O = O / S``;
    * ``FC_mean_z``: the mean Fisher-z (``atanh``) of the pairwise Pearson
      correlations among the n-plet's variables, taken from the covariance
      sub-matrix.

    Each of the six features is then scored against the labels
    conscious = 1 / nonresponsive = 0.

    Parameters
    ----------
    idx : hashable
        Row identifier copied into every output record as ``row_idx``.
    all_covs : mapping
        ``all_covs[dataset][state]`` -> array-like of shape (n_subj, N, N).
    conscious_states, nonresponsive_states : mapping
        ``dataset -> list of state names`` for each group.
    selected_dataset : str
        Dataset key whose states are evaluated.
    optimal_nplet : sequence of int
        Variable indices forming the n-plet. With fewer than two indices there
        are no pairs, and ``FC_mean_z`` is the mean of an empty tensor.
    state_c, state_nr : object
        State holding each group's discovery subject. It is compared to each
        state name with ``==``.
        NOTE(review): earlier comments described these as "tuple-like
        (ds, st)". A tuple never equals a state-name string. Confirm the
        stored format.
    subject_c, subject_nr : int
        Subject index within that state to leave out of each group.
    device : str or torch.device
        Device for the stacked tensor and for ``nplets_measures``.

    Returns
    -------
    list of dict
        One record per measure ("TC", "DTC", "O", "S", "norm_O", "FC_mean_z"),
        each with keys ``row_idx``, ``measure``, ``F_score``, ``PR_AUC``,
        ``PR_AUC_inv``.

    Raises
    ------
    ValueError
        If either group is empty after leaving out the discovery subject.
    """
    # HOI measures + classical FC
    measure_list = ["TC", "DTC", "O", "S", "norm_O", "FC_mean_z"]

    # --- Gather covariance matrices per group, leaving out the discovery subject
    covs_conscious = []
    for st in conscious_states[selected_dataset]:
        covs = all_covs[selected_dataset][st]  # (n_subj, N, N)
        for subj_idx in range(covs.shape[0]):
            if st == state_c and subj_idx == subject_c:
                continue
            covs_conscious.append(np.asarray(covs[subj_idx]))

    covs_nonresp = []
    for st in nonresponsive_states[selected_dataset]:
        covs = all_covs[selected_dataset][st]
        for subj_idx in range(covs.shape[0]):
            if st == state_nr and subj_idx == subject_nr:
                continue
            covs_nonresp.append(np.asarray(covs[subj_idx]))

    n_c = len(covs_conscious)
    n_nr = len(covs_nonresp)

    if n_c == 0 or n_nr == 0:
        raise ValueError(
            "Empty group after skipping discovery subject(s). Check inputs."
        )

    # --- One stacked tensor: conscious subjects first, then nonresponsive.
    # The label vector below relies on exactly this order.
    X_array = np.array(covs_conscious + covs_nonresp)  # (D, N, N)
    # torch.as_tensor keeps the NumPy dtype and places the data on `device`.
    X_tensor = torch.as_tensor(X_array, device=device)

    # --- HOI measures via THOI (TC, DTC, O, S)
    measures = nplets_measures(
        X_tensor,
        nplets=[optimal_nplet],  # single n-plet
        covmat_precomputed=True,
        T=None,
        device=device,
        verbose=logging.WARNING,
    )
    # First axis indexes the n-plet (only one here); keep TC, DTC, O, S.
    vals_hoi = measures[0, :, :4].detach().cpu().numpy()  # (D, 4)

    # norm_O = O / S. A zero S yields inf/nan; silence the warning and let the
    # scoring below turn such columns into NaN scores.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = vals_hoi[:, 2] / vals_hoi[:, 3]  # (D,)

    # --- Classical FC: mean Fisher-z correlation within the n-plet (on device)
    with torch.no_grad():
        idx_t = torch.as_tensor(optimal_nplet, dtype=torch.long, device=device)
        # Sub-covariance restricted to the n-plet: (D, k, k)
        cov_sub = X_tensor.index_select(1, idx_t).index_select(2, idx_t)

        # Covariance -> correlation; eps guards against zero variances.
        var = torch.diagonal(cov_sub, dim1=-2, dim2=-1)  # (D, k)
        eps = 1e-12
        std = torch.sqrt(torch.clamp(var, min=eps))  # (D, k)
        denom = std.unsqueeze(-1) * std.unsqueeze(-2)  # (D, k, k)
        corr = cov_sub / torch.clamp(denom, min=eps)

        # Keep |r| < 1 so atanh stays finite.
        corr = torch.clamp(corr, -0.999999, 0.999999)

        k = idx_t.numel()
        # Strict upper triangle: each unordered pair once, diagonal excluded.
        triu_mask = torch.triu(
            torch.ones(k, k, dtype=torch.bool, device=device), diagonal=1
        )
        corr_pairs = corr[:, triu_mask]  # (D, k(k-1)/2)
        z_pairs = torch.atanh(corr_pairs)
        fc_mean_z_tensor = z_pairs.mean(dim=-1)  # (D,)

    fc_mean_z = fc_mean_z_tensor.detach().cpu().numpy()  # (D,)

    # --- Stack all features: [TC, DTC, O, S, norm_O, FC_mean_z]
    vals = np.column_stack((vals_hoi, ratio, fc_mean_z))  # (D, 6)

    # --- Labels: 1 for conscious, 0 for nonresponsive
    y = np.concatenate([np.ones(n_c, dtype=int), np.zeros(n_nr, dtype=int)])

    # One-way ANOVA F per feature; a constant feature's undefined F becomes 0.
    F_vals, _ = f_classif(vals, y)  # (6,)
    for j in range(vals.shape[1]):
        if np.allclose(vals[:, j].var(), 0):
            F_vals[j] = 0.0

    # Average precision with the raw score (high => conscious) and with the
    # negated score (low => conscious). Both values are computed before
    # appending, so a failure in either call never misaligns the lists.
    pr_aucs = []
    neg_pr_aucs = []
    for j in range(vals.shape[1]):
        xj = vals[:, j]
        try:
            ap = average_precision_score(y, xj)
            ap_inv = average_precision_score(y, -xj)
        except ValueError:
            ap = ap_inv = np.nan
        pr_aucs.append(ap)
        neg_pr_aucs.append(ap_inv)

    return [
        {
            "row_idx": idx,
            "measure": measure_,
            "F_score": F_vals[jdx],
            "PR_AUC": pr_aucs[jdx],
            "PR_AUC_inv": neg_pr_aucs[jdx],
        }
        for jdx, measure_ in enumerate(measure_list)
    ]


all_covs = load_covariance_dict(f"{results_path}/covariance_matrices_gc.h5")

# Per-dataset state lists: conscious (label 1) vs. nonresponsive (label 0).
conscious_states = {
    "MA": ["MA_awake"],
    # "DBS": ["DBS_awake", "ts_on_5V"],
    "DBS": ["DBS_awake"],
    # "MA_DBS": ["MA_awake", "DBS_awake", "ts_on_5V"],
}
nonresponsive_states = {
    "MA": ["deep_propofol", "ketamine", "moderate_propofol", "ts_selv2", "ts_selv4"],
    "DBS": [
        "ts_off",
        "ts_on_3V_control",
        "ts_on_5V_control",
    ],
    # "MA_DBS": [
    #     "deep_propofol",
    #     "ketamine",
    #     "moderate_propofol",
    #     "ts_selv2",
    #     "ts_selv4",
    #     "ts_off",
    #     "ts_on_3V_control",
    #     "ts_on_5V_control",
    # ],
}

# n-plet orders whose per-order result CSVs are concatenated below.
NPLET_ORDERS = (2, 3, 4, 5, 6, 7, 8, 9)
# Write an intermediate CSV whenever the row index is a multiple of this.
CHECKPOINT_EVERY = 10000
# Shared CSV dialect (semicolon-separated, comma decimal, UTF-8 with BOM).
CSV_KWARGS = {"encoding": "utf-8-sig", "sep": ";", "decimal": ","}


def _load_order_results(dataset, orders=NPLET_ORDERS):
    """Concatenate ``N_1_A_max_O_diff_<dataset>_<order>.csv`` for each order.

    The result gets a fresh 0..n-1 index. The evaluation loop uses that index
    both as ``row_idx`` and for the checkpoint test.
    """
    frames = [
        pd.read_csv(
            f"{results_path}/N_1_A_max_O_diff_{dataset}_{order}.csv", **CSV_KWARGS
        )
        for order in orders
    ]
    return pd.concat(frames).reset_index(drop=True)


def _save_eval(results_df, eval_results, result_name):
    """Join metric records onto their source rows and write the evaluation CSV.

    Returns
    -------
    (metrics_df, results_eval_df)
        The raw metric records, and the merged table that was written.
        The merged table has one row per (source row, measure) pair.
    """
    metrics_df = pd.DataFrame(eval_results)
    # Each record's "row_idx" holds the index label of its source row.
    results_eval_df = results_df.merge(
        metrics_df,
        left_index=True,
        right_on="row_idx",
    ).drop(columns=["row_idx"])
    results_eval_df.to_csv(
        f"{results_path}/A_2_C_nplet_eval_{result_name}_AWAKE_VS_off_5vctrl_3vctrl.csv",
        index=False,
        **CSV_KWARGS,
    )
    return metrics_df, results_eval_df


results_df_MA = _load_order_results("MA")
results_df_DBS = _load_order_results("DBS")

# Exploratory summaries; they are only displayed when evaluated as the last
# expression of a notebook cell.
results_df_MA.value_counts(["task"])
results_df_DBS.value_counts(["task"])
results_df_DBS.query("order==5").value_counts(["state_c", "state_nr", "task"])

O_diff_df = pd.read_csv(
    f"{results_path}/N_1_A_max_O_differences_2_9.csv", **CSV_KWARGS
)
minmax_df = pd.read_csv(
    f"{results_path}/N_1_A_minmax_differences_2_9.csv", **CSV_KWARGS
)

################################################################################################################################################
# ANOVA F-score and PR AUC (raw and inverted score) per n-plet and measure
################################################################################################################################################

results_dict = {
    # "diff_OA": O_diff_df,
    # "minmax": minmax_df,
    "DBS": results_df_DBS,
    # "MA": results_df_MA,
}

for result_name, results_df in results_dict.items():
    eval_results = []
    for idx, row in tqdm(
        results_df.iterrows(),
        total=len(results_df),
        desc="Evaluating n-plets",
    ):
        # Alternative: evaluate_nplet_batched_both_datasets(...) pools every
        # dataset in the state mappings (it takes no dataset argument).
        metrics = evaluate_nplet_batched_FC(
            idx,
            all_covs,
            conscious_states,
            nonresponsive_states,
            result_name,
            ast.literal_eval(row["optimal_nplet"]),
            row["state_c"],
            row["state_nr"],
            row["subject_c"],
            row["subject_nr"],
            device=device,
        )
        eval_results.extend(metrics)
        if idx % CHECKPOINT_EVERY == 0:
            # Periodic checkpoint (also fires at idx 0). It writes the same
            # path as the final save below, which overwrites it.
            metrics_df, results_eval_df = _save_eval(
                results_df, eval_results, result_name
            )
    metrics_df, results_eval_df = _save_eval(results_df, eval_results, result_name)
