**Note**: This notebook should be run from the `high-order-anesthesia` folder to ensure the correct imports and file paths are used.

In [1]:
from pathlib import Path
import os
def ensure_project_root(target_name: str = "high-order-anesthesia") -> Path:
    """Locate the project root directory and make it the working directory.

    The current working directory is checked first, then each of its
    ancestors (nearest first), for a directory named ``target_name``.

    Args:
        target_name: Directory name that identifies the project root.

    Returns:
        The resolved path of the matching directory.

    Raises:
        RuntimeError: If neither the working directory nor any of its
            ancestors is named ``target_name``.
    """
    here = Path.cwd().resolve()
    match = next(
        (folder for folder in (here, *here.parents) if folder.name == target_name),
        None,
    )
    if match is None:
        raise RuntimeError(
            f"Could not find '{target_name}' in current path or parents. "
            "Please run the notebook from inside the project."
        )
    # Only change directory when the root is an ancestor; when we are already
    # inside it the working directory is left untouched.
    if match != here:
        os.chdir(match)
    return match
# Resolve the project root once (this changes the working directory when the
# notebook was started from a sub-folder) and report the folder we ended up in.
ROOT = ensure_project_root("high-order-anesthesia")
print(f"Now in: {ROOT.name}")


Now in: high-order-anesthesia


In [3]:
import gzip
import logging
import math
import pickle
from typing import List, Tuple

import numpy as np
import torch

#### Custom libraries

In [None]:
from src.hoi_anesthesia.io import load_covariance_dict, save_results
from src.hoi_anesthesia.utils import O_PR_AUC, generate_X,analyze_order


In [None]:
# Run configuration shared by the cells below.
results_path = "results"  # folder where result files are written
data_path = "results"  # folder holding the covariance file (currently the same folder)
all_covs = load_covariance_dict(f"{data_path}/covariance_matrices_gc.h5")
N = 82  # NOTE(review): not referenced in the visible cells; same value as R — confirm it is still needed
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available, otherwise CPU
Np_max = 500_000  # upper cap on the tail size requested per order
M = 50_000_000  # forwarded unchanged to analyze_order as M — TODO document its meaning
R = 82  # used as n in math.comb(R, k); presumably the number of brain regions — confirm
batch_size = 1000  # forwarded to analyze_order
p = 0.05  # 5% of the total number of nplets
eval_func = O_PR_AUC  # scoring function handed to analyze_order as eval_fc

MA: order 4 (synergy brain), order 7 (redundancy brain)

In [None]:
# Condition groups for the MA dataset: the conscious group versus the
# non-responsive group.
conscious_states = {"MA": ["MA_awake"]}
nonresponsive_states = {
    "MA": [
        "deep_propofol",
        "ketamine",
        "moderate_propofol",
        "ts_selv2",
        "ts_selv4",
    ],
}
X_tensor, n_c, n_nr = generate_X(conscious_states, nonresponsive_states, all_covs)

torch.cuda.empty_cache()
for k in (4, 7):
    n_combinations = math.comb(R, k)
    # Request a fraction p of all k-plets, capped at Np_max.
    tail_size = int(min(n_combinations * p, Np_max))
    print(
        f"MA, order {k} | tail size: {tail_size:.0f} (of {n_combinations:.0f} combinations)"
    )
    top_c, top_nr = analyze_order(
        X_tensor=X_tensor,
        n_c=n_c,
        k=k,
        M=M,
        Np=tail_size,
        batch_size=batch_size,
        R=R,
        device=device,
        eval_fc=eval_func,
    )
    # top_c: high O -> C (best PR with C positive)
    # top_nr: high O -> NR (best PR with NR positive)
    # One file per order: each saved dict holds only the current k.
    results = {"MA": {k: {"C_positive": top_c, "NR_positive": top_nr}}}
    save_results(results, f"{results_path}/nplet_tails_PRAUC_MA_{k}.pkl.gz")
torch.cuda.empty_cache()

DBS: order 3 (synergy brain), order 9 (redundancy brain)

In [None]:
torch.cuda.empty_cache()

# Condition groups for the DBS dataset: the conscious group versus the
# non-responsive group.
conscious_states = {"DBS": ["DBS_awake", "ts_on_5V"]}
nonresponsive_states = {
    "DBS": [
        "ts_off",
        "ts_on_3V_control",
        "ts_on_5V_control",
    ],
}
X_tensor, n_c, n_nr = generate_X(conscious_states, nonresponsive_states, all_covs)

for k in (3, 9):
    n_combinations = math.comb(R, k)
    # Request a fraction p of all k-plets, capped at Np_max.
    tail_size = int(min(n_combinations * p, Np_max))
    print(
        f"DBS, order {k} | tail size: {tail_size:.0f} (of {n_combinations:.0f} combinations)"
    )
    top_c, top_nr = analyze_order(
        X_tensor=X_tensor,
        n_c=n_c,
        k=k,
        M=M,
        Np=tail_size,
        batch_size=batch_size,
        R=R,
        device=device,
        eval_fc=eval_func,
    )
    # One file per order: each saved dict holds only the current k.
    results = {"DBS": {k: {"C_positive": top_c, "NR_positive": top_nr}}}
    save_results(results, f"{results_path}/nplet_tails_PRAUC_DBS_{k}.pkl.gz")
torch.cuda.empty_cache()


In [None]:

@torch.no_grad()
def attach_delta_O_to_tail(
    X_tensor: torch.Tensor,
    n_c: int,
    tail: List[Tuple[float, Tuple[int, ...]]],
    k: int,
    device: torch.device,
    batch_size: int = 2048,
) -> List[Tuple[float, float, Tuple[int, ...]]]:
    """Augment every tail entry with its ΔO value.

    For each ``(PR_AUC, nplet)`` pair the O-information of the n-plet is
    computed from ``X_tensor`` (passed to ``nplets_measures`` as precomputed
    covariances), reduced to a ΔO value by ``delta_O`` and inserted between
    the score and the n-plet.

    Args:
        X_tensor: Covariance tensor handed to ``nplets_measures``; moved to
            ``device`` as float64 before use.
        n_c: Forwarded to ``delta_O`` (presumably the number of conscious
            subjects — confirm against ``generate_X``).
        tail: Sequence of ``(PR_AUC, nplet)`` pairs.
        k: Order of the n-plets; only used in the progress-bar label.
        device: Torch device on which the measures are computed.
        batch_size: Number of n-plets evaluated per ``nplets_measures`` call.

    Returns:
        ``(PR_AUC, ΔO, nplet)`` triples in the same order as ``tail``; an
        empty list when ``tail`` is empty.

    NOTE(review): ``tqdm``, ``nplets_measures`` and ``delta_O`` are not
    imported in any visible cell — make sure they are in scope before this
    cell runs.
    """
    if len(tail) == 0:
        return []

    # Split the pairs into a score column and an n-plet column.
    score_column, nplets = zip(*tail)
    scores = np.array(score_column, dtype=float)
    index_matrix = torch.tensor(nplets, dtype=torch.long)  # (n_total, k)

    X_tensor = X_tensor.to(device=device, dtype=torch.float64)

    n_total = index_matrix.shape[0]
    deltas = np.empty(n_total, dtype=float)

    progress = tqdm(
        range(0, n_total, batch_size),
        desc=f"Computing ΔO for tail (k={k}, B={n_total})",
    )
    for lo in progress:
        hi = min(lo + batch_size, n_total)

        measures = nplets_measures(
            X=X_tensor,
            covmat_precomputed=True,
            nplets=index_matrix[lo:hi].to(device=device),
            device=device,
            verbose=logging.WARNING,
        )

        # Index 2 of the last axis is taken as the O-information.
        O_vals = measures[..., 2].detach().cpu().numpy()

        # Put the axis whose length matches X_tensor.shape[0] first.
        # NOTE(review): when a batch happens to contain exactly
        # X_tensor.shape[0] n-plets the array is square and is left as-is,
        # so the orientation cannot be told apart — confirm the layout
        # returned by nplets_measures.
        if O_vals.shape[0] != X_tensor.shape[0]:
            O_vals = O_vals.T

        batch_deltas, _ = delta_O(O_vals, n_c)
        deltas[lo:hi] = batch_deltas

    # Re-assemble the tail with ΔO between the score and the n-plet.
    return [
        (float(score), float(delta), combo)
        for score, delta, combo in zip(scores, deltas, nplets)
    ]


# ---------------------------------------------------------------------
# Step 2 main: load all tails, merge, compute ΔO, save merged dict
# ---------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Orders whose step-1 tails were saved, per dataset.
ORDERS_BY_DATASET = {"MA": [4, 7], "DBS": [3, 9]}

# Conscious / non-responsive state lists per dataset, identical to step 1.
STATES_BY_DATASET = {
    "MA": (
        ["MA_awake"],
        ["deep_propofol", "ketamine", "moderate_propofol", "ts_selv2", "ts_selv4"],
    ),
    "DBS": (
        ["DBS_awake", "ts_on_5V"],
        ["ts_off", "ts_on_3V_control", "ts_on_5V_control"],
    ),
}


def build_X_for_dataset(dataset):
    """Rebuild the step-1 inputs for one dataset.

    Calls ``generate_X`` with the same conscious / non-responsive state lists
    that step 1 used for ``dataset`` and the covariances in ``all_covs``.

    NOTE(review): this helper was called but not defined in any visible cell;
    it is defined here following the "same as step 1" description. If a
    project-level implementation exists, import that one instead.

    Args:
        dataset: ``"MA"`` or ``"DBS"``.

    Returns:
        ``(X_tensor, n_c)`` as returned by ``generate_X`` (its third return
        value is not needed here).

    Raises:
        KeyError: If ``dataset`` is not a known dataset name.
    """
    conscious, nonresponsive = STATES_BY_DATASET[dataset]
    X, n_conscious, _ = generate_X(
        {dataset: conscious}, {dataset: nonresponsive}, all_covs
    )
    return X, n_conscious


# 1) Load step-1 tails and merge into a single dict
merged_results = {"MA": {}, "DBS": {}}

for dataset, orders in ORDERS_BY_DATASET.items():
    for k in orders:
        path = f"{results_path}/nplet_tails_PRAUC_{dataset}_{k}.pkl.gz"
        print(f"Loading {dataset} order {k} tails from: {path}")
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were written by step 1 of this notebook.
        with gzip.open(path, "rb") as f:
            # {dataset: {k: {"C_positive": [...], "NR_positive": [...]}}}
            res = pickle.load(f)
        merged_results[dataset][k] = res[dataset][k]

# 2) + 3) For each dataset, rebuild X_tensor / n_c (same as step 1) and attach
#         ΔO to every order / tail.
for dataset in ["MA", "DBS"]:
    print(f"\n=== Attaching ΔO for dataset: {dataset} ===")
    X_tensor, n_c = build_X_for_dataset(dataset)
    X_tensor = X_tensor.to(device=device, dtype=torch.float64)

    for k, tails in merged_results[dataset].items():
        print(f"  Order k={k}")

        for key in ["C_positive", "NR_positive"]:
            tail = tails.get(key, [])
            if not tail:
                # Missing or empty tail: store an empty list so downstream
                # code always finds both keys.
                merged_results[dataset][k][key] = []
                continue

            print(f"    Tail: {key}, n={len(tail)}")
            tail_with_delta = attach_delta_O_to_tail(
                X_tensor=X_tensor,
                n_c=n_c,
                tail=tail,
                k=k,
                device=device,
                batch_size=2048,
            )
            # Now each element: (PR_AUC, ΔO, nplet_tuple)
            merged_results[dataset][k][key] = tail_with_delta

# 4) Save merged dict with ΔO attached for all datasets/orders/tails
out_path = f"{results_path}/B_2_nplet_tails_PRAUC_with_deltaO_ALL.pkl.gz"
print(f"\nSaving merged results with ΔO to: {out_path}")
save_results(merged_results, out_path)


In [None]:
import gzip
import pickle
import numpy as np



merged_path = f"{results_path}/B_2_nplet_tails_PRAUC_with_deltaO_ALL.pkl.gz"
print(f"Loading merged tails (with ΔO) from: {merged_path}")
# NOTE: pickle.load can execute arbitrary code; only load files that were
# written by the previous step of this notebook.
with gzip.open(merged_path, "rb") as f:
    merged_results = pickle.load(f)
# Structure:
#   merged_results[dataset][k]["C_positive"]  = [(pr_auc, delta_O, nplet), ...]
#   merged_results[dataset][k]["NR_positive"] = [(pr_auc, delta_O, nplet), ...]
# Either list may be empty; the region-map code below handles that case, so no
# first-element peeks are done here (they would raise IndexError on an empty
# tail).


def build_region_maps_for_tail(
    tail_with_delta,
    mode: str,
    R: int = 82,
    delta_eps: float = 0.0,
):
    """Count how often each region appears in the ΔO-filtered n-plets of a tail.

    Args:
        tail_with_delta: list of ``(pr_auc, delta_O, nplet_tuple)`` entries.
        mode: ``"C_positive"`` keeps entries with ``ΔO > delta_eps``;
            ``"NR_positive"`` keeps entries with ``ΔO < -delta_eps`` when
            ``delta_eps > 0`` and ``ΔO < 0`` otherwise.
        R: number of regions; every index inside an n-plet must be ``< R``.
        delta_eps: margin applied to the ΔO sign filter.

    Returns:
        A 4-tuple of arrays of shape ``(R,)``, always in this order:
            region_counts         : int, number of kept n-plets containing each region
            region_counts_prop    : float, counts divided by the number of kept n-plets
            region_counts_z       : float, counts z-scored across regions (for plotting)
            region_counts_percent : float, ``region_counts_prop * 100``
        All four arrays are zero when the tail is empty or no n-plet survives
        the filter.

    Raises:
        ValueError: if ``mode`` is unknown and the tail is non-empty.
    """

    def _all_zero_maps():
        # Same arity and order as the regular return value, so callers can
        # always unpack four arrays.
        return (
            np.zeros(R, dtype=int),
            np.zeros(R, dtype=float),
            np.zeros(R, dtype=float),
            np.zeros(R, dtype=float),
        )

    if len(tail_with_delta) == 0:
        return _all_zero_maps()

    if mode not in ("C_positive", "NR_positive"):
        raise ValueError(f"Unknown mode: {mode}")

    # 1) Filter n-plets by ΔO sign according to mode.
    # For NR_positive a non-positive delta_eps falls back to a plain "< 0" test.
    nr_threshold = -delta_eps if delta_eps > 0 else 0.0
    nplet_list = [
        nplet
        for _pr_auc, delta_O_val, nplet in tail_with_delta
        if (
            delta_O_val > delta_eps
            if mode == "C_positive"
            else delta_O_val < nr_threshold
        )
    ]

    N_tail_filtered = len(nplet_list)
    print(f"      After ΔO filtering ({mode}): {N_tail_filtered} n-plets")

    if N_tail_filtered == 0:
        return _all_zero_maps()

    # 2) Compute region participation (raw counts)
    region_counts = np.zeros(R, dtype=int)
    for nplet in nplet_list:
        for r in nplet:
            region_counts[r] += 1

    # 3) Counts as proportion / percentage of all kept n-plets
    region_counts_prop = region_counts.astype(float) / float(N_tail_filtered)
    region_counts_percent = region_counts_prop * 100

    # 4) z-score counts across regions (for plotting); all zeros when every
    #    region has the same count, to avoid dividing by zero.
    mu = region_counts.mean()
    sigma = region_counts.std()
    if sigma > 0:
        region_counts_z = (region_counts - mu) / sigma
    else:
        region_counts_z = np.zeros_like(region_counts, dtype=float)

    return region_counts, region_counts_prop, region_counts_z, region_counts_percent


# ---------------------------------------------------------------------
# Build region maps for all dataset / order / tail
# ---------------------------------------------------------------------

# region_maps[dataset][k][tail_key] -> dict with the four per-region arrays
region_maps = {name: {} for name in ("MA", "DBS")}

for dataset in ("MA", "DBS"):
    print(f"\n=== Building region maps for dataset: {dataset} ===")

    for k, tails in merged_results[dataset].items():
        print(f"  Order k={k}")
        order_maps = {}
        region_maps[dataset][k] = order_maps

        for key in ("C_positive", "NR_positive"):
            tail_with_delta = tails.get(key, [])
            print(f"    Tail: {key}, n_raw={len(tail_with_delta)}")

            # The tail key doubles as the filtering mode; delta_eps can be
            # raised above 0 to be stricter on ΔO.
            counts, proportions, z_scores, percents = build_region_maps_for_tail(
                tail_with_delta=tail_with_delta,
                mode=key,
                R=R,
                delta_eps=0.0,
            )

            order_maps[key] = {
                "region_counts": counts,
                "region_counts_prop": proportions,
                "region_counts_z": z_scores,
                "region_counts_percent": percents,
            }


# ---------------------------------------------------------------------
# Save region_maps to disk for plotting
# ---------------------------------------------------------------------

# Write the nested region_maps dict as a gzip-compressed pickle so the
# plotting step can load it without recomputing the counts.
out_maps_path = f"{results_path}/B_3_region_maps_PRAUC_deltaO.pkl.gz"
print(f"\nSaving region maps to: {out_maps_path}")
with gzip.open(out_maps_path, "wb") as f:
    pickle.dump(region_maps, f, protocol=pickle.HIGHEST_PROTOCOL)


import matplotlib.pyplot as plt

# Quick look: indices of the X regions that occur most often in the
# MA / order-4 / NR_positive tail, printed from highest to lowest count.
X = 10  # number of max indices you want
arr = region_maps["MA"][4]["NR_positive"]["region_counts"]
idx = np.argpartition(arr, -X)[-X:]  # gets the indices of the top X values (unordered)
idx_sorted = idx[
    np.argsort(arr[idx])[::-1]
]  # optional: sort them by value (descending)

print(idx_sorted)

# plt.hist(region_maps["MA"][4]['NR_positive']['region_counts'],bins=100)
# plt.show()
