In [None]:
pip install numpy scipy pandas scikit-learn cvxpy group-lasso pysindy
!pip install --upgrade importlib-metadata
!pip install --upgrade derivative
!pip install --upgrade --force-reinstall pysindy


In [None]:
from google.colab import drive  # Colab-only module; this cell is not runnable outside Google Colab

# Mount Google Drive at /content/drive. force_remount=True is passed so that
# re-running this cell remounts instead of reusing an existing mount.
drive.mount('/content/drive', force_remount=True)
!find /content/drive/MyDrive/data


In [None]:
#!/usr/bin/env python3
# -----------------------------------------------------------
# Full Pipeline (per-window closed complexes + fixed global thresholds):
# 1. Data Loading & Per-channel Normalization
# 2. Robust Locality Pre-selection (PC-based MAD metrics)
# 3. Dynamic SINDy Fitting on Selected Local Points (Taylor Centering)
# 4. Per-window simplicial complex (unweighted) with closure + counts
# 5. Optional global structural/topology summary (unions across windows)
# 6. Save per-window counts + summary JSON
# -----------------------------------------------------------
import os
import math
import json
from collections import defaultdict
from itertools import combinations

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
import pysindy as ps
from joblib import Parallel, delayed
from collections import defaultdict
import heapq

# ─────────────── USER PARAMETERS ───────────────────────────
# Module-level configuration read by name throughout the rest of the script.
CSV_FILE        = '/content/drive/MyDrive/data/eeg/EEG_TD_31_EEGdata.csv'
FS              = 256.0          # sampling rate [Hz]; dt = 1 / FS
WIN_SG, ORDER   = 13, 3          # SavGol window (odd) & polynomial order
D_MAX           = 2              # degree of the polynomial SINDy library (2 → pairwise products → triangles)
                                 # NOTE(review): the rebuild step always stores an empty quads_4 set,
                                 # so D_MAX=3 does not yield tetrahedra as the code stands.
THRESH_SINDY    = 0.2            # STLSQ sparsity threshold (model-fitting phase)

# Adaptive mapping threshold (per window) for |coef| → simplex.
# NOTE(review): TAU_MODE / TAU_PCTL / TAU_KMAD are read only by _tau_from_abs,
# and nothing in the code visible here calls that helper; complexes are built
# from the fixed global thresholds computed after fitting.
TAU_MODE        = 'percentile'   # 'percentile' or 'mad'
TAU_PCTL        = 97.5           # if TAU_MODE='percentile'
TAU_KMAD        = 6.0            # if TAU_MODE='mad' → tau = median + KMAD * MAD
ENFORCE_CLOSURE = True           # make each window's output a simplicial complex

WIN_LEN         = 1024           # samples per sliding window
STRIDE          = 512            # hop between windows
MAX_ROWS        = 3000           # cap rows per window during SINDy
N_JOBS          = -1             # joblib cores (-1 = all)

# Locality Analysis Parameters (PC-based); all are thresholds on the
# per-channel RMS robust z-distance computed in robust_window_distances.
R_TARGET_PC  = 1.0     # ~1 z-unit RMS per channel kept
R_DROP95_PC  = 3.5     # drop only if window is wildly nonlocal
BAD_KEPT_PC  = 2.0     # if kept subset is still too wide, skip
K_MIN        = 600     # minimum kept timestamps
K_MAX        = 1000    # optional cap

# Analysis Options
# NOTE(review): select_local_indices always centers on the per-channel median;
# CENTER_METHOD is only written into output file headers and the summary JSON.
CENTER_METHOD   = 'median'       # 'mean' or 'median' for Taylor expansion center

# Output directory (relative to the working directory); created if missing.
OUTDIR          = 'dyn_graphs_full_pipeline_pc_norm_counts'
os.makedirs(OUTDIR, exist_ok=True)
# ───────────────────────────────────────────────────────────

# 1 ─── READ (T × n → n × T)
# The CSV is read without a header row and only its first 31 columns are used;
# at this point rows are time samples and columns are channels.
raw = pd.read_csv(
        CSV_FILE,
        header=None,
        nrows=None,
        usecols=range(31)      # 31 channels
      ).values.astype(np.float64)

# Drop constant/all-zero channels: a channel is kept only if its peak-to-peak
# range over time exceeds tol.
tol = 1e-12
nz_mask = (np.ptp(raw, axis=0) > tol)
if (~nz_mask).any():
    dropped = np.where(~nz_mask)[0] + 1
    print(f'⚠️  Dropping constant channels (1-based): {dropped.tolist()}')
# After this filter, channel indices used downstream (node ids in edges and
# triangles) count only the kept channels, not the original CSV columns.
raw = raw[:, nz_mask]
raw = raw.T                       # rows = channels, cols = time
n, T_raw = raw.shape              # n = kept channels, T_raw = samples
dt = 1 / FS                       # sample spacing in seconds

print(f"Loaded data: {n} channels × {T_raw} samples ({T_raw/FS:.1f} seconds)")

# 2 ─── SAVITZKY–GOLAY smooth + derivative (aligned, same length)
half = (WIN_SG - 1) // 2          # half-width of the SavGol window, in samples
X_smooth_full = savgol_filter(raw, WIN_SG, ORDER, axis=1, mode='interp')
# Same filter with deriv=1 and delta=dt: first derivative along the time axis.
dXdt_full     = savgol_filter(raw, WIN_SG, ORDER, deriv=1, delta=dt, axis=1, mode='interp')

# Trim ends equally to reduce edge artifacts and keep X and dXdt aligned
# (`half` samples are removed from each end of both arrays).
X_smooth = X_smooth_full[:, half:-half]
dXdt_raw = dXdt_full[:,       half:-half]

# 3 ─── Per-channel Normalization (X to z-score; Y scaled by same sig)
print("Applying per-channel normalization...")
eps = 1e-8                                       # keeps sig strictly positive
mu  = X_smooth.mean(axis=1, keepdims=True)       # per-channel mean over time
sig = X_smooth.std(axis=1, keepdims=True) + eps  # per-channel std over time

X = (X_smooth - mu) / sig
# The derivative is divided by sig but not shifted, so Y stays the time
# derivative of the normalized X.
Y = dXdt_raw / sig  # scale only

# Kept for reference; not referenced again in the code visible here.
norm_params = {"mean": mu.squeeze().tolist(), "std": sig.squeeze().tolist()}
print("Normalization complete.")
print(f"Data after preprocessing: {X.shape[0]} channels × {X.shape[1]} samples")
print("-" * 34)

# ─────────── Locality Analysis Helpers (PC-based) ───────────
def robust_window_distances(Xw_centered, eps=1e-8):
    """Return per-timestamp RMS robust z-distances and the per-channel scale.

    Each row (channel) of ``Xw_centered`` is divided by a robust scale,
    ``1.4826 * MAD`` floored at ``eps``. The MAD is taken around the row's own
    median, but the values themselves are scaled as given (no re-centering).

    Returns:
        d_pc:  1-D array with one entry per column (timestamp): the L2 norm of
               the scaled column divided by sqrt(number of channels).
        sigma: column vector (channels x 1) holding the scale used per channel.
    """
    row_median = np.median(Xw_centered, axis=1, keepdims=True)
    abs_dev = np.abs(Xw_centered - row_median)
    robust_scale = 1.4826 * np.median(abs_dev, axis=1, keepdims=True)
    sigma = np.maximum(eps, robust_scale)

    scaled = Xw_centered / sigma
    n_channels = scaled.shape[0]
    d_pc = np.sqrt((scaled ** 2).sum(axis=0)) / np.sqrt(n_channels)
    return d_pc, sigma

def select_local_indices(Xw):
    """Choose the timestamps of one window that lie close to its median state.

    The window is centered on its per-channel median and every timestamp is
    scored with ``robust_window_distances``. Timestamps within ``R_TARGET_PC``
    are kept; if that yields fewer than ``K_MIN`` (or more than ``K_MAX``)
    points, the ``K_MIN`` (``K_MAX``) closest timestamps are kept instead.

    The window is flagged ``skip`` only when both the 95th-percentile distance
    of the whole window exceeds ``R_DROP95_PC`` and that of the kept subset
    exceeds ``BAD_KEPT_PC``.

    Returns a dict with keys ``skip``, ``x0`` (center), ``sigma``, ``p95_raw``,
    ``p95_kept``, ``used`` (selection mode) and, unless skipped, ``keep``
    (sorted column indices).
    """
    center = np.median(Xw, axis=1, keepdims=True)
    dist, sigma = robust_window_distances(Xw - center)
    p95_all = float(np.percentile(dist, 95))

    chosen = np.where(dist <= R_TARGET_PC)[0]
    mode = "radius_pc"
    if chosen.size < K_MIN:
        # too few points inside the radius: fall back to the K_MIN nearest
        chosen, mode = np.argsort(dist)[:K_MIN], "kmin_pc"
    elif chosen.size > K_MAX:
        # too many points inside the radius: cap at the K_MAX nearest
        chosen, mode = np.argsort(dist)[:K_MAX], "kmax_pc"

    if chosen.size > 0:
        p95_chosen = float(np.percentile(dist[chosen], 95))
    else:
        p95_chosen = np.nan

    window_is_nonlocal = p95_all > R_DROP95_PC
    subset_is_wide = p95_chosen > BAD_KEPT_PC
    if window_is_nonlocal and subset_is_wide:
        return {"skip": True, "x0": center, "sigma": sigma,
                "p95_raw": p95_all, "p95_kept": p95_chosen, "used": "drop_pc"}

    return {"skip": False, "x0": center, "sigma": sigma, "keep": np.sort(chosen),
            "p95_raw": p95_all, "p95_kept": p95_chosen, "used": mode}

def preselect_windows(X, WIN_LEN, STRIDE):
    """Run the locality pre-selection on every sliding window of ``X``.

    Windows of ``WIN_LEN`` columns are taken every ``STRIDE`` columns; a
    trailing partial window is not produced. Each window is passed to
    ``select_local_indices`` and one metadata dict per window is collected.

    Side effects: prints a progress line per window and a final summary, and
    writes ``locality_stats_pc_perchannelnorm.csv`` into the module-level
    ``OUTDIR``. Window mid-times use the module-level ``dt``.

    Args:
        X: 2-D array, channels x samples.
        WIN_LEN: window length in samples (shadows the module constant).
        STRIDE: hop between window starts in samples (shadows the module constant).

    Returns:
        List of dicts, one per window, with keys ``w_start``, ``t_mid``,
        ``skip``, ``keep``, ``p95_raw``, ``p95_kept``, ``kept_pct``, ``used``.
        Non-skipped windows additionally carry ``x0`` and ``sigma`` as flat
        lists; skipped windows have an empty ``keep`` list.
    """
    metas = []
    starts = range(0, X.shape[1] - WIN_LEN + 1, STRIDE)
    nW = len(starts)
    print("\n--- Performing Robust Locality Pre-selection ---")
    print("="*40)
    print(f"Analyzing {nW} windows with WIN_LEN={WIN_LEN}, STRIDE={STRIDE}...")
    print(f"Locality Params: R_TARGET_PC={R_TARGET_PC}, K_MIN={K_MIN}, K_MAX={K_MAX}, "
          f"R_DROP95_PC={R_DROP95_PC}, BAD_KEPT_PC={BAD_KEPT_PC}\n")

    # wi is 1-based and used only for progress messages
    for wi, w_start in enumerate(starts, 1):
        w_end = w_start + WIN_LEN
        Xw = X[:, w_start:w_end]
        out = select_local_indices(Xw)
        t_mid = (w_start + w_end) / 2 * dt  # window mid-point in seconds

        if out["skip"]:
            print(f"Window {wi}/{nW} (t={t_mid:.3f}s): SKIP ({out['used']}) "
                  f"p95_raw={out['p95_raw']:.2f}, p95_kept={out['p95_kept']:.2f}")
            # Skipped windows are still recorded so the CSV covers every window.
            metas.append({
                "w_start": w_start, "t_mid": t_mid, "skip": True, "keep": [],
                "p95_raw": out["p95_raw"], "p95_kept": out["p95_kept"],
                "kept_pct": 0.0, "used": out["used"]
            })
            continue

        keep = out["keep"]
        kept_pct = 100.0 * keep.size / WIN_LEN
        print(f"Window {wi}/{nW} (t={t_mid:.3f}s): kept={keep.size}/{WIN_LEN} ({kept_pct:.1f}%), "
              f"p95_raw={out['p95_raw']:.2f}, p95_kept={out['p95_kept']:.2f}, mode={out['used']}")

        # Arrays are converted to plain lists so the meta dict holds only
        # built-in Python types.
        metas.append({
            "w_start": w_start, "t_mid": t_mid, "skip": False, "keep": keep.tolist(),
            "x0": out["x0"].flatten().tolist(), "sigma": out["sigma"].flatten().tolist(),
            "p95_raw": out["p95_raw"], "p95_kept": out["p95_kept"],
            "kept_pct": float(kept_pct), "used": out["used"]
        })

    kept_counts = [len(m["keep"]) for m in metas if not m["skip"]]
    drops = sum(m["skip"] for m in metas)
    print("\n" + "="*40)
    print("--- Locality Pre-selection Summary ---")
    print("="*40)
    print(f"Total Windows: {nW}   Dropped (nonlocal): {drops}")
    if kept_counts:
        print(f"Points kept per window (min/mean/median/max): "
              f"{min(kept_counts)}/{np.mean(kept_counts):.1f}/{np.median(kept_counts):.1f}/{max(kept_counts)}")

    # Save per-window locality diagnostics (one CSV row per window, skipped or not)
    df = pd.DataFrame([{
        "w_start": m["w_start"], "t_mid": m["t_mid"], "skip": m["skip"],
        "kept": len(m["keep"]), "kept_pct": m["kept_pct"],
        "p95_raw": m["p95_raw"], "p95_kept": m["p95_kept"], "mode": m["used"]
    } for m in metas])
    df.to_csv(f"{OUTDIR}/locality_stats_pc_perchannelnorm.csv", index=False)
    print(f"✓ Saved locality stats to {OUTDIR}/locality_stats_pc_perchannelnorm.csv")
    return metas

# ─────────── SINDy + hyperedge extraction ───────────
def indices_from_term(term_str):
    """
    Expand a polynomial feature name into its list of variable indices,
    repeating an index once per unit of its exponent.
      '1'          -> []
      'x0'         -> [0]
      'x0^2'       -> [0, 0]
      'x1 x4'      -> [1, 4]
      'x1^3 x4'    -> [1, 1, 1, 4]
    Factors are whitespace-separated; the first character of each factor is
    dropped and the rest (up to an optional '^') is read as the index.
    """
    if term_str == '1':
        return []

    expanded = []
    for factor in term_str.split():
        pieces = factor.split('^')
        var_index = int(pieces[0][1:])
        exponent = int(pieces[1]) if len(pieces) > 1 else 1
        expanded += [var_index] * exponent
    return expanded

def coefs_to_simplex_sets(coefs, feature_names, *, thresh, enforce_closure=True, return_direct=False):
    """
    Turn ONE WINDOW's coefficient matrix into UNWEIGHTED simplices.

    A coefficient in row `target` for a feature term contributes the simplex
    {target} ∪ {variables of the term} when
      • |coef| >= thresh,
      • the term is not the bias '1', and
      • all nodes are distinct (no repeated variable, target not in the term).
    Simplices with 2, 3 or 4 nodes are collected; other sizes are ignored.

    With enforce_closure=True every face of a collected simplex is added, so
    the result is closed under taking faces.

    Returns (edges, triangles, quads) as sets of frozensets; with
    return_direct=True the three pre-closure sets are appended to the tuple.
    """
    term_vars = [None if name == '1' else indices_from_term(name) for name in feature_names]

    # direct simplices bucketed by node count
    direct = {2: set(), 3: set(), 4: set()}
    for target in range(coefs.shape[0]):
        for weight, variables in zip(coefs[target], term_vars):
            if variables is None or abs(weight) < thresh:
                continue
            nodes = [target, *variables]
            # one distinctness test covers both a repeated variable and the
            # target appearing inside its own term
            if len(set(nodes)) != len(nodes):
                continue
            bucket = direct.get(len(nodes))
            if bucket is not None:
                bucket.add(frozenset(nodes))
    edges_dir, tris_dir, quads_dir = direct[2], direct[3], direct[4]

    edges_cl, tris_cl, quads_cl = set(edges_dir), set(tris_dir), set(quads_dir)
    if enforce_closure:
        tris_cl |= {frozenset(face) for quad in quads_cl for face in combinations(quad, 3)}
        # the 2-faces of a quad are exactly the 2-faces of its 3-faces, so
        # closing the (extended) triangle set covers them too
        edges_cl |= {frozenset(face) for tri in tris_cl for face in combinations(tri, 2)}

    if return_direct:
        return edges_cl, tris_cl, quads_cl, edges_dir, tris_dir, quads_dir
    return edges_cl, tris_cl, quads_cl

# Global polynomial library shared by every window fit: all monomials up to
# degree D_MAX, without the constant (bias) term. interaction_only=False means
# pure powers are not excluded; the score-extraction helpers below classify
# each feature by its parsed indices and ignore terms that are neither a
# single variable nor a product of two distinct variables.
GLOBAL_LIBRARY = ps.PolynomialLibrary(
    degree=D_MAX,
    include_bias=False,
    interaction_only=False
)

def _tau_from_abs(absA):
    """Derive a single threshold from a matrix of |coef| values.

    If TAU_MODE (case-insensitive) is 'mad', returns
    median + TAU_KMAD * 1.4826 * MAD over all entries; any other TAU_MODE
    value falls back to the TAU_PCTL-th percentile of all entries.
    """
    flat = absA.ravel()
    if TAU_MODE.lower() != 'mad':
        return float(np.percentile(flat, TAU_PCTL))

    center = np.median(flat)
    spread = 1.4826 * np.median(np.abs(flat - center))
    return float(center + TAU_KMAD * spread)

def fit_window_from_meta(meta):
    """Fit one SINDy model for the window described by ``meta``.

    ``meta`` is one entry produced by ``preselect_windows``. The window is cut
    from the module-level ``X`` / ``Y`` arrays, shifted by the stored center
    ``x0``, restricted to the kept timestamps and, if still longer than
    ``MAX_ROWS``, thinned to ``MAX_ROWS`` evenly spaced columns. The fit uses
    STLSQ with the shared ``GLOBAL_LIBRARY``.

    Returns a dict. For skipped windows (flagged in ``meta`` or with an empty
    keep list) it holds ``t_mid``, ``skip=True`` and ``reason``. Otherwise it
    holds the raw, un-thresholded edge scores ``S2`` and triangle scores
    ``S3`` from ``extract_edge_triangle_scores`` plus the locality
    diagnostics copied from ``meta``.
    """
    t_mid = meta["t_mid"]
    if meta["skip"]:
        return {"t_mid": t_mid, "skip": True, "reason": "nonlocal window"}

    start = meta["w_start"]
    window = slice(start, start + WIN_LEN)
    center = np.array(meta["x0"]).reshape(-1, 1)
    kept = np.array(meta["keep"])
    if kept.size == 0:
        return {"t_mid": t_mid, "skip": True, "reason": "empty keep list"}

    # states are expressed relative to the window center; derivatives are not shifted
    states = (X[:, window] - center)[:, kept]
    derivs = Y[:, window][:, kept]

    n_cols = states.shape[1]
    if n_cols > MAX_ROWS:
        pick = np.linspace(0, n_cols - 1, MAX_ROWS, dtype=int)
        states, derivs = states[:, pick], derivs[:, pick]

    model = ps.SINDy(
        feature_library=GLOBAL_LIBRARY,
        optimizer=ps.STLSQ(alpha=1e-3, threshold=THRESH_SINDY),
    )
    model.fit(states.T, t=dt, x_dot=derivs.T, quiet=True)

    # raw scores only; thresholding happens later with the global taus
    S2, S3 = extract_edge_triangle_scores(
        model.coefficients(), model.get_feature_names(), n
    )

    return {
        't_mid': t_mid,
        'S2': S2,                         # (n x n) symmetric edge scores
        'S3': S3,                         # dict of triangle scores
        'x0': center.flatten().tolist(),
        'n_samples_fit': int(states.shape[1]),
        'n_samples_kept': len(meta["keep"]),
        'kept_pct': meta["kept_pct"],
        'locality_mode': meta["used"],
        'p95_raw': meta["p95_raw"],
        'p95_kept': meta["p95_kept"],
        'skip': False
    }



def _parse_feature_types_via_indices(feature_names):
    """Classify each feature name as 'lin', 'cross' or 'other'.

    Returns three lists aligned with ``feature_names``:
      kinds      - 'lin' for a single variable, 'cross' for a product of two
                   distinct variables, 'other' for everything else
      lin_var    - the variable index for 'lin' features, else None
      cross_pair - the sorted (i, j) index pair for 'cross' features, else None
    """
    kinds, lin_var, cross_pair = [], [], []
    for name in feature_names:
        variables = [] if name == '1' else indices_from_term(name)

        kind, single, pair = 'other', None, None
        if len(variables) == 1:
            kind, single = 'lin', variables[0]
        elif len(variables) == 2 and variables[0] != variables[1]:
            kind, pair = 'cross', tuple(sorted(variables))

        kinds.append(kind)
        lin_var.append(single)
        cross_pair.append(pair)
    return kinds, lin_var, cross_pair

def extract_edge_triangle_scores(A, feature_names, n):
    """
    Convert a coefficient matrix into raw (un-thresholded) simplex scores.

    Return:
      S2: (n x n) symmetric matrix; S2[i, j] is the larger of |coef| of x_j in
          equation i and |coef| of x_i in equation j (linear features only,
          diagonal left at 0).
      S3: defaultdict {frozenset({i,j,k}): score}; the score is the largest
          |coef| of a cross term x_j x_k in equation i over all ways the three
          nodes can play those roles. Every candidate triple is recorded, so
          triples whose coefficients are all zero appear with score 0.0.
    """
    kinds, lin_of, cross_pair = _parse_feature_types_via_indices(feature_names)
    magnitude = np.abs(A)
    linear_terms = [(col, lin_of[col]) for col, kind in enumerate(kinds) if kind == 'lin']
    cross_terms = [(col, cross_pair[col]) for col, kind in enumerate(kinds) if kind == 'cross']

    # directed edge scores: row = target equation, column = source variable
    directed = np.zeros((n, n), dtype=float)
    for target in range(n):
        for col, source in linear_terms:
            if source is None or source == target:
                continue
            directed[target, source] = max(directed[target, source], magnitude[target, col])
    S2 = np.maximum(directed, directed.T)  # undirected

    # triangle scores from cross terms that do not contain the target itself
    S3 = defaultdict(float)
    for target in range(n):
        for col, (j, k) in cross_terms:
            if target == j or target == k:
                continue
            tri = frozenset((target, j, k))
            # reading S3[tri] registers the triple even when the score stays 0.0
            S3[tri] = max(S3[tri], magnitude[target, col])
    return S2, S3

# --- maximum bottleneck spanning tree bottleneck value (per window) ---
def mst_bottleneck_value(S2):
    """
    Return the smallest edge weight on a maximum spanning tree of the complete
    graph whose weights are given by the (n x n) matrix S2.

    The tree is grown from node 0 with Prim's algorithm, always taking the
    heaviest edge that reaches a node not yet in the tree. Zero weights are
    ordinary edges, so an all-zero matrix yields 0.
    For n < 2 there is no edge and +inf is returned.
    """
    n_nodes = S2.shape[0]
    if n_nodes < 2:
        return float('inf')

    in_tree = [False] * n_nodes
    in_tree[0] = True

    # heapq is a min-heap, so weights are stored negated
    frontier = []
    for v in range(1, n_nodes):
        heapq.heappush(frontier, (-S2[0, v], 0, v))

    tree_weights = []
    while frontier and len(tree_weights) < n_nodes - 1:
        neg_weight, _, node = heapq.heappop(frontier)
        if in_tree[node]:
            continue  # stale entry: node was reached through a heavier edge
        in_tree[node] = True
        tree_weights.append(-neg_weight)
        for other, already_in in enumerate(in_tree):
            if not already_in:
                heapq.heappush(frontier, (-S2[node, other], node, other))

    if len(tree_weights) != n_nodes - 1:
        # spanning tree could not be completed -> bottleneck 0
        return 0.0
    return min(tree_weights)

def build_complex_from_scores(S2, S3, tau2, tau3, enforce_closure=True):
    """
    Build a (closed) complex from fixed thresholds tau2 (edges) and tau3
    (triangles).

    A pair {i, j} becomes a direct edge when S2[i, j] >= tau2, and a triple
    becomes a triangle when its S3 score >= tau3. In both cases the score must
    also be strictly positive: a zero score means no supporting coefficient,
    so a threshold of 0 (or below) must not turn every candidate into a
    simplex.

    Args:
        S2: (n x n) symmetric array of edge scores.
        S3: mapping {frozenset({i, j, k}): score} of triangle scores.
        tau2: edge threshold.
        tau3: triangle threshold.
        enforce_closure: if True, add the three edges of every selected
            triangle so the result is closed under taking faces (these face
            edges are added regardless of their own S2 score).

    Returns:
        (edges, tris): sets of frozensets of node indices.
    """
    n = S2.shape[0]
    edges = {
        frozenset((i, j))
        for i in range(n)
        for j in range(i + 1, n)
        if S2[i, j] > 0 and S2[i, j] >= tau2
    }

    tris = {t for t, s in S3.items() if s > 0 and s >= tau3}

    if enforce_closure:
        # add triangle faces
        for t in tris:
            edges.update(frozenset(face) for face in combinations(t, 2))

    return edges, tris


# 1. Preselect windows (also prints diagnostics and writes the locality CSV)
window_meta = preselect_windows(X, WIN_LEN, STRIDE)

# 2. Fit SINDy over selected windows (parallel) → per-window raw edge/triangle
#    scores; thresholding into complexes happens after all fits are done.
fit_metas = [m for m in window_meta if not m["skip"]]
skipped_count = len(window_meta) - len(fit_metas)
print(f"\n--- Proceeding to SINDy Fitting ---")
print(f"Fitting {len(fit_metas)} windows (Skipped {skipped_count} nonlocal windows)")
print("="*40)

# One joblib task per non-skipped window; results come back in input order.
results = Parallel(n_jobs=N_JOBS, verbose=5)(
    delayed(fit_window_from_meta)(m) for m in fit_metas
)
print('    done ✔')

# Keep only successfully fitted windows (fit_window_from_meta can still mark
# a window as skipped, e.g. for an empty keep list)
successful_results = [res for res in results if not res.get("skip", False)]
fit_skipped_count = len(results) - len(successful_results)
print(f"Successfully fitted {len(successful_results)} windows (Skipped {fit_skipped_count} during fit)")
# ---- GLOBAL THRESHOLDS ----
# Edge threshold ensuring connectivity in *every* window: each window's maximum
# spanning tree has all of its edge scores >= that window's bottleneck, and
# tau2_global is the smallest bottleneck over all windows.
bottlenecks = [mst_bottleneck_value(res['S2']) for res in successful_results]
tau2_global = float(min(bottlenecks)) if bottlenecks else 0.0
print(f"\nGlobal edge threshold tau2 (connectivity-guaranteeing): {tau2_global:.6g}")

# Triangle threshold: pick once across all windows (quantile or your own criterion)
# Here: 97.5th percentile of all triangle scores pooled over the fitted windows.
all_tri_scores = [score for res in successful_results for score in res['S3'].values()]
if all_tri_scores:
    tau3_global = float(np.quantile(all_tri_scores, 0.975))
else:
    tau3_global = float('inf')  # no triangles possible
print(f"Global triangle threshold tau3: {tau3_global:.6g}")

# ---- REBUILD PER-WINDOW COMPLEXES with (tau2_global, tau3_global) ----
for res in successful_results:
    # Honour the ENFORCE_CLOSURE setting (it is what the summary JSON reports)
    # instead of hard-coding closure on.
    E2, T3 = build_complex_from_scores(
        res['S2'], res['S3'], tau2_global, tau3_global,
        enforce_closure=ENFORCE_CLOSURE,
    )
    res['edges_2'] = E2
    res['tris_3']  = T3
    res['quads_4'] = set()  # only edges and triangles are scored; no 4-node simplices
    res['n_edges2'] = len(E2)
    res['n_edges3'] = len(T3)
    res['n_edges4'] = 0
    res['tau2_global'] = tau2_global
    res['tau3_global'] = tau3_global

# 3. Locality summary across fitted windows: gather each diagnostic into a
#    list so it can be summarised here and again in the summary JSON.
all_locality_stats = defaultdict(list)
for res in successful_results:
    for k in ('p95_raw','p95_kept','n_samples_fit','n_samples_kept','kept_pct'):
        all_locality_stats[k].append(res[k])

# Print the fixed global thresholds once. The module-level values are used
# (every window stores the same numbers) so that an empty successful_results
# list does not raise IndexError here.
print(f"\nFixed thresholds: tau2_global={tau2_global:.6g}, "
      f"tau3_global={tau3_global:.6g}")


print("\n" + "="*60)
print("LOCALITY ANALYSIS SUMMARY (Fitted Windows)")
print("="*60)
for key, vals in all_locality_stats.items():
    if vals:
        print(f"  {key}: mean={np.mean(vals):.3f}, std={np.std(vals):.3f}, median={np.median(vals):.3f}")

# 4. Save per-window edge lists + counts (closed complexes)
def pretty_edge(e):
    """Format a simplex as '{a,b,...}' with its members in ascending order."""
    members = ','.join(str(node) for node in sorted(e))
    return '{' + members + '}'

print(f"\n📁 Saving {len(successful_results)} fitted time slices to '{OUTDIR}/'...")
successful_results.sort(key=lambda r: r["t_mid"])

# (file prefix, result key, write even when empty?)
# Edge and triangle files are always written; the 4-node file only when there
# is something to put in it.
simplex_files = (
    ('edges2', 'edges_2', True),
    ('edges3', 'tris_3', True),
    ('edges4', 'quads_4', False),
)

for res in successful_results:
    # zero-padded mid-time so file names sort chronologically
    stamp = f'{res["t_mid"]:010.3f}'

    # identical four-line comment header for every file of this window
    header = (
        f'# Taylor center ({CENTER_METHOD}): {np.array(res["x0"])[:3]}...\n'
        f'# counts (closed): edges2={res["n_edges2"]}, edges3={res["n_edges3"]}, edges4={res["n_edges4"]}\n'
        f'# Locality (PC Norm): P95_raw={res["p95_raw"]:.3f}, P95_kept={res["p95_kept"]:.3f}, Kept_pct={res["kept_pct"]:.1f}%\n'
        f'# global thresholds: tau2_global={tau2_global:.6g}, tau3_global={tau3_global:.6g}\n'
    )

    for prefix, key, always_write in simplex_files:
        simplices = res[key]
        if not (always_write or simplices):
            continue
        with open(f'{OUTDIR}/{prefix}_{stamp}.txt', 'w') as fh:
            fh.write(header)
            # one simplex per line, ordered by sorted member tuple
            for simplex in sorted(simplices, key=lambda s: tuple(sorted(s))):
                fh.write(f'{pretty_edge(simplex)}\n')

# Per-window counts CSV: one row per fitted window, in the order of
# successful_results. The thresholds are re-read from the first result and
# become NaN when there are no fitted windows.
tau2_global = successful_results[0]["tau2_global"] if successful_results else np.nan
tau3_global = successful_results[0]["tau3_global"] if successful_results else np.nan

rows = [
    {
        "t_mid": r["t_mid"],
        "n_edges2": r["n_edges2"],
        "n_edges3": r["n_edges3"],
        "n_edges4": r["n_edges4"],
        "tau2_global": tau2_global,
        "tau3_global": tau3_global,
        "n_samples_fit": r["n_samples_fit"],
        "kept_pct": r["kept_pct"],
        "p95_raw": r["p95_raw"],
        "p95_kept": r["p95_kept"],
        "mode": r["locality_mode"],
    }
    for r in successful_results
]

counts_path = f"{OUTDIR}/hyperedge_counts_per_window.csv"
pd.DataFrame(rows).to_csv(counts_path, index=False)
print(f"✓ Saved per-window counts → {OUTDIR}/hyperedge_counts_per_window.csv")

# --- Optional: Hypergraph Structural/Topology Summary (across windows) ---
def count_hyperedge_statistics(results_list):
    """Print and return simplex totals and unions across all windows.

    Each element of ``results_list`` is a per-window dict that may hold the
    sets 'edges_2', 'tris_3' and 'quads_4'; a missing key counts as empty.

    Returns:
        total_counts:      {'edges_2'|'edges_3'|'edges_4': summed set sizes}
        unique_hyperedges: {same keys: union of the sets over all windows}
    """
    # output label → key under which a window dict stores that simplex set
    source_key = {'edges_2': 'edges_2', 'edges_3': 'tris_3', 'edges_4': 'quads_4'}
    total_counts = {label: 0 for label in source_key}
    unique_hyperedges = {label: set() for label in source_key}

    for window in results_list:
        for label, key in source_key.items():
            simplices = window.get(key, set())
            total_counts[label] += len(simplices)
            unique_hyperedges[label].update(simplices)

    print("\n" + "="*60)
    print("HYPERGRAPH STRUCTURAL STATISTICS (Across Windows)")
    print("="*60)
    n_windows = len(results_list)

    print("\nAverage per window:")
    for label, total in total_counts.items():
        if n_windows > 0:
            print(f"  {label}: {total/n_windows:.1f}")
        else:
            print(f"  {label}: 0.0")

    print("\nStructural proportions (unique across all windows):")
    n_unique_all = sum(len(members) for members in unique_hyperedges.values())
    for label, members in unique_hyperedges.items():
        if n_unique_all > 0:
            share = len(members) / n_unique_all * 100
        else:
            share = 0.0
        print(f"  {label}: {share:.1f}% ({len(members)} unique total)")
    return total_counts, unique_hyperedges

def analyze_hypergraph_topology(results_list):
    """Print and return the unique simplices and per-node participation.

    Each element of ``results_list`` is a per-window dict that may hold the
    sets 'edges_2', 'tris_3' and 'quads_4'; a missing key counts as empty.

    Returns:
        unique_edges: {'2way'|'3way'|'4way': set of frozensets seen in any window}
        node_degrees: {2|3|4: defaultdict(node -> number of unique simplices
                       of that size containing the node)}
    """
    # (node count, output label, key in the per-window dict)
    orders = ((2, '2way', 'edges_2'), (3, '3way', 'tris_3'), (4, '4way', 'quads_4'))

    unique_edges = {
        label: {frozenset(s) for res in results_list for s in res.get(key, set())}
        for _, label, key in orders
    }

    n2, n3, n4 = (len(unique_edges[label]) for _, label, _ in orders)
    tot = n2 + n3 + n4

    def share(count):
        # percentage of all unique simplices; 0 when there are none
        return count / tot * 100 if tot > 0 else 0

    print(f"\nHYPERGRAPH STRUCTURE (Unique edges):")
    print(f"  2-way: {n2} ({share(n2):.1f}%)")
    print(f"  3-way: {n3} ({share(n3):.1f}%)")
    print(f"  4-way: {n4} ({share(n4):.1f}%)")

    node_degrees = {order: defaultdict(int) for order, _, _ in orders}
    for order, label, _ in orders:
        degree = node_degrees[order]
        for edge in unique_edges[label]:
            for node in edge:
                degree[node] += 1

    print(f"\nNODE PARTICIPATION (Unique edge counts per node):")
    for order in [2, 3, 4]:
        vals = list(node_degrees[order].values())
        if not vals:
            print(f"  {order}-way: No unique edges found.")
            continue
        print(f"  {order}-way: avg degree = {np.mean(vals):.1f}, max = {np.max(vals)}")
    return unique_edges, node_degrees

# Both helpers print their own report; only the second one's return values are
# used below (for the console summary and the JSON file).
structural_counts, unique_edges_counts = count_hyperedge_statistics(successful_results)
unique_edges_topology, node_degrees   = analyze_hypergraph_topology(successful_results)

# Final summary + JSON
print(f"\nFINAL HYPERGRAPH SUMMARY:")
print(f"  Nodes: {n} channels")
print(f"  Unique 2-way edges: {len(unique_edges_topology['2way'])}")
print(f"  Unique 3-way hyperedges: {len(unique_edges_topology['3way'])}")
print(f"  Unique 4-way hyperedges: {len(unique_edges_topology['4way'])}")

# Share of unique simplices (union over all windows) that have more than two
# nodes; defined as 0.0 when there are no simplices at all.
tot_unique_higher = len(unique_edges_topology['3way']) + len(unique_edges_topology['4way'])
tot_unique_edges  = len(unique_edges_topology['2way']) + tot_unique_higher
higher_order_fraction = (tot_unique_higher / tot_unique_edges) if tot_unique_edges > 0 else 0.0
print(f"  Higher-order unique edge fraction: {higher_order_fraction*100:.1f}%")

# Run-level summary written to {OUTDIR}/summary.json: window bookkeeping,
# configuration, the global thresholds, locality statistics of the fitted
# windows and unique simplex counts.
summary = {
    'n_windows_total': len(window_meta),
    'n_windows_skipped_preselection': skipped_count,
    'n_windows_attempted_fit': len(fit_metas),
    'n_windows_skipped_during_fit': fit_skipped_count,
    'n_windows_successful_fit': len(successful_results),
    'window_length_s': WIN_LEN / FS,
    'stride_s': STRIDE / FS,
    'center_method': CENTER_METHOD,
    'degree_max': D_MAX,
    'n_channels': n,
    # NOTE(review): the thresholds can be non-finite in degenerate runs
    # (tau3_global is inf when no triangle scores were pooled; both are NaN
    # when no window was fitted). json.dump then writes Infinity / NaN, which
    # strict JSON parsers reject.
    'thresholding': {
    'strategy': 'fixed_global',
    'edge_tau2_global': float(tau2_global),
    'triangle_tau3_global': float(tau3_global),
    'connectivity_rule': 'tau2_global = min_t bottleneck(MST(S2_t))',
    'enforce_closure': ENFORCE_CLOSURE,
    'sindy_threshold': THRESH_SINDY
    },

    # mean/std/median/min/max of each locality diagnostic over fitted windows;
    # diagnostics with no values are omitted
    'locality_stats_summary_fitted': {
        k: {
            'mean': float(np.mean(v)),
            'std' : float(np.std(v)),
            'median': float(np.median(v)),
            'min': float(np.min(v)),
            'max': float(np.max(v))
        } for k, v in all_locality_stats.items() if v
    },
    'unique_counts': {
        'edges2': len(unique_edges_topology['2way']),
        'edges3': len(unique_edges_topology['3way']),
        'edges4': len(unique_edges_topology['4way']),
    },
    'higher_order_unique_fraction': float(higher_order_fraction)
}

with open(f'{OUTDIR}/summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\n✓ Wrote {len(successful_results)} closed complexes to '{OUTDIR}/'")
print(f"✓ Saved per-window counts CSV and summary JSON to '{OUTDIR}/'")
