In [None]:
# Install the numerical stack and the SINDy tooling used by this notebook.
pip install numpy scipy pandas scikit-learn cvxpy group-lasso pysindy
# Upgrade two supporting packages, then force a clean reinstall of pysindy
# on top of them.
!pip install --upgrade importlib-metadata
!pip install --upgrade derivative
!pip install --upgrade --force-reinstall pysindy


In [None]:
# Colab-only: module used to attach Google Drive to the runtime filesystem.
from google.colab import drive

# Mount Drive under /content/drive, requesting a remount via force_remount=True.
drive.mount('/content/drive', force_remount=True)
# List everything under the data directory so the CSV path used below can be checked.
!find /content/drive/MyDrive/data


In [None]:
#!/usr/bin/env python3
# -----------------------------------------------------------
# Enhanced hypergraph reconstruction with:
# 1. Taylor expansion centering (dynamic per window)
# 2. Contribution ratios (ρ) for each interaction order
# 3. Better handling of higher-order interactions
# -----------------------------------------------------------
import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
import pysindy as ps
from joblib import Parallel, delayed
import os
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import json # Import json for saving summary

# ─────────────── USER PARAMETERS ───────────────────────────
CSV_FILE        = '/content/drive/MyDrive/data/eeg/EEG_TD_31_EEGdata.csv'
FS              = 256.0          # sampling rate [Hz]; dt = 1 / FS is derived from it
WIN_SG, ORDER   = 13, 3          # SavGol window (odd) & polynomial order
# Highest polynomial degree in the SINDy library. A degree-d monomial in d
# distinct channels is combined with its target channel when mapped to a
# hyperedge, so D_MAX = 3 already allows 4-way interactions (the contribution
# ratios are computed with max_order = D_MAX + 1).
D_MAX           = 3
THRESH_SINDY    = 0.1            # sparsity threshold λ (STLSQ)
HEDGE_THRESH    = 0.05           # cut-off ε when mapping coefs → edges

WIN_LEN         = 1024           # samples per sliding window (4.0 s at FS = 256 Hz)
STRIDE          = 512            # hop between windows (2.0 s at FS = 256 Hz)
# Cap on samples per window during SINDy. Subsampling only triggers when a
# window holds more than MAX_ROWS samples, so it is inactive for WIN_LEN = 1024.
MAX_ROWS        = 3000
N_JOBS          = -1             # joblib cores (-1 = all)

# New parameters for enhanced version
# 'mean' or 'median' selects the Taylor expansion center; any other value
# makes fit_window_enhanced use a zero center (no centering).
CENTER_METHOD   = 'median'
USE_GLOBAL_NORM = True           # Whether to use global normalization
COMPUTE_CONTRIB = True           # Whether to compute contribution ratios

# Output directory for edge lists, plots and the JSON summary; created when
# this cell runs if it does not exist yet.
OUTDIR          = 'dyn_graphs_enhanced'
os.makedirs(OUTDIR, exist_ok=True)
# ───────────────────────────────────────────────────────────
# ───────────────────────────────────────────────────────────

# 1 ─── READ (T × n → n × T)
# The CSV is read without a header row and only the first 31 columns are kept.
raw = pd.read_csv(
        CSV_FILE,
        header=None,
        nrows=None,            # no row limit: every row of the file is read
        usecols=range(31)      # 31 channels
      ).values.astype(np.float64)

# Drop constant/all-zero channels (peak-to-peak range not above tol)
tol = 1e-12
nz_mask = (np.ptp(raw, axis=0) > tol)
if (~nz_mask).any():
    dropped = np.where(~nz_mask)[0] + 1
    print(f'⚠️  Dropping constant channels (1-based): {dropped.tolist()}')
# NOTE(review): the surviving channels are re-indexed from 0 after this filter,
# so node ids in the exported hyperedges are positions in the filtered array,
# not original CSV columns, whenever a channel was dropped.
raw = raw[:, nz_mask]
raw = raw.T                       # rows = channels, cols = time
n, T_raw = raw.shape
dt = 1 / FS                       # sample spacing in seconds

print(f"Loaded data: {n} channels × {T_raw} samples ({T_raw/FS:.1f} seconds)")

# 2 ─── SAVITZKY–GOLAY smooth + derivative
half = (WIN_SG - 1) // 2          # half-width of the SavGol window, in samples
# Smooth the raw traces along time, then trim `half` samples from each end.
X_smooth = savgol_filter(raw, WIN_SG, ORDER, axis=1, mode='interp')
X_smooth = X_smooth[:, half:-half]

# The derivative comes from a second SavGol pass (deriv=1) over the already
# smoothed and trimmed trace, so X_smooth and dXdt_raw have the same shape.
# delta=dt expresses the derivative per second.
dXdt_raw = savgol_filter(X_smooth, WIN_SG, ORDER,
                         deriv=1, delta=dt, axis=1, mode='interp')

# 3 ─── Normalisation
# One scalar (mean absolute smoothed amplitude over all channels and samples)
# rescales both the state and its derivative.
if USE_GLOBAL_NORM:
    scale = np.mean(np.abs(X_smooth))
    X = X_smooth / scale
    Y = dXdt_raw / scale
else:
    X, Y = X_smooth, dXdt_raw




print(f"Data after preprocessing: {X.shape[0]} channels × {X.shape[1]} samples")

# ─────────── Helper functions ────────────

def indices_from_term(term_str):
    """
    Expand a polynomial feature name into the variable indices it contains.

    The name is split on whitespace into factors of the form ``x<j>`` or
    ``x<j>^<p>``; index ``j`` is emitted ``p`` times (once when no power is
    given). The bias term ``'1'`` has no indices.

    '1'           -> []
    'x0'          -> [0]
    'x0^2'        -> [0, 0]
    'x1^3 x4'     -> [1, 1, 1, 4]
    """
    if term_str == '1':
        return []

    expanded = []
    for factor in term_str.split():
        pieces = factor.split('^')
        # Drop the leading variable letter; the remainder is the channel index.
        variable = int(pieces[0][1:])
        repeat = int(pieces[1]) if len(pieces) > 1 else 1
        expanded += [variable] * repeat
    return expanded

def compute_contribution_ratios(coefs, feature_names, X_window, max_order=4):
    """
    Compute the share of coupling strength carried by each interaction order.

    For every target channel (row of ``coefs``) and every library term with a
    non-negligible coefficient, the interaction order is the number of distinct
    nodes in ``{target} ∪ term variables``. Terms in which any node repeats
    (powers such as ``x0^2``, or a term containing the target itself) are
    skipped, as are terms whose order exceeds ``max_order``. Each remaining
    term contributes ``|coef| * mean(|product of its variables over the
    window|)`` to its order.

    Parameters
    ----------
    coefs : 2-D array
        Coefficient matrix indexed as ``coefs[target, term]``.
    feature_names : sequence of str
        Term names aligned with the columns of ``coefs``; parsed with
        ``indices_from_term``.
    X_window : 2-D array
        Data indexed as ``X_window[channel, sample]``, used to weight each
        coefficient by the typical magnitude of its term.
    max_order : int, default 4
        Largest interaction order reported.

    Returns
    -------
    dict
        ``rho[k]`` for ``k = 2 .. max_order`` (fractions of the total
        contribution, all 0.0 when the total is below 1e-10) plus
        ``rho['higher']``, the sum of ``rho[k]`` for ``k >= 3``.
    """
    n_channels = coefs.shape[0]
    contributions = defaultdict(float)

    # Both the parsed indices and the mean magnitude of a term depend only on
    # the term (column), not on the target row, so compute them once per
    # column and only for columns that are actually used.
    parsed_by_col = {}
    magnitude_by_col = {}

    for target in range(n_channels):
        for col, (coef, term) in enumerate(zip(coefs[target], feature_names)):
            if abs(coef) < 1e-10:
                continue

            if col not in parsed_by_col:
                parsed_by_col[col] = indices_from_term(term)
            indices = parsed_by_col[col]
            if not indices:
                continue  # bias term: no variables involved

            all_nodes = [target] + indices
            order = len(set(all_nodes))

            # Any repeated node (x0^2, x1^3, or the target inside its own
            # term) means this is not a genuine interaction among distinct nodes.
            if order != len(all_nodes) or order > max_order:
                continue

            if col not in magnitude_by_col:
                term_values = np.ones(X_window.shape[1])
                for idx in indices:
                    term_values *= X_window[idx]
                magnitude_by_col[col] = np.mean(np.abs(term_values))

            contributions[order] += abs(coef) * magnitude_by_col[col]

    # Normalize to get ratios
    total = sum(contributions.values())
    rho = {}

    if total > 1e-10:
        for order in range(2, max_order + 1):
            rho[order] = contributions[order] / total
    else:
        for order in range(2, max_order + 1):
            rho[order] = 0.0

    # Compute higher-order ratio (3-way and above)
    rho['higher'] = sum(rho.get(o, 0) for o in range(3, max_order + 1))

    return rho

def coefs_to_simplices_enhanced(coefs, feature_names, thresh=1e-3):
    """
    Map a SINDy coefficient matrix to weighted 2-, 3- and 4-node simplices.

    A coefficient ``coefs[target, term]`` with ``|w| >= thresh`` becomes a
    simplex on ``{target} ∪ term variables`` provided every node in that list
    is distinct. When several coefficients map to the same simplex, the
    largest ``|w|`` is kept. Simplices with more than four nodes are ignored.

    Returns ``(edge_w, tri_w, quad_w, {})``: three dicts keyed by frozenset of
    node ids, followed by an empty dict.
    """
    edge_w, tri_w, quad_w = {}, {}, {}
    by_size = {2: edge_w, 3: tri_w, 4: quad_w}

    # Parse every feature name up front; the bias term is marked with None.
    term_indices = [
        None if name == '1' else indices_from_term(name)
        for name in feature_names
    ]

    n_targets = coefs.shape[0]
    for target in range(n_targets):
        row = coefs[target]
        for w, idxs in zip(row, term_indices):
            if idxs is None or abs(w) < thresh:
                continue

            members = [target] + idxs
            distinct = set(members)

            # Reject self-loops and any term in which a node repeats: those
            # are not genuine multi-way interactions.
            if len(distinct) < 2 or len(distinct) < len(members):
                continue

            simplex = frozenset(distinct)
            bucket = by_size.get(len(simplex))
            if bucket is not None:
                bucket[simplex] = max(bucket.get(simplex, 0.0), abs(w))

    return edge_w, tri_w, quad_w, {}
# Single polynomial feature library (degree D_MAX, with bias term) that
# fit_window_enhanced passes to every per-window SINDy model.
GLOBAL_LIBRARY = ps.PolynomialLibrary(degree=D_MAX, include_bias=True)


# ─────────── Per-window SINDy with Taylor centering ────────────
def fit_window_enhanced(w_start):
    """
    Fit one sliding window with SINDy and summarise it as a hypergraph slice.

    The window covers samples ``w_start .. w_start + WIN_LEN`` of the global
    arrays ``X`` (states) and ``Y`` (derivatives). States are shifted by a
    per-channel center ``x0`` chosen by ``CENTER_METHOD``; derivatives are
    left untouched. The fitted coefficients are converted to weighted
    simplices and, when ``COMPUTE_CONTRIB`` is set, to contribution ratios.

    Returns a dict with keys ``'t_mid'``, ``'edges_2'``, ``'tris_3'``,
    ``'quads_4'``, ``'rho'``, ``'x0'``, ``'n_samples'`` and
    ``'term_contributions'``.
    """
    # ---- Slice window -------------------------------------------------
    w_end = w_start + WIN_LEN
    span = slice(w_start, w_end)
    Xw = X[:, span]
    Yw = Y[:, span]

    # ---- Expansion center (per channel, shape (channels, 1)) ----------
    if CENTER_METHOD == 'mean':
        x0 = np.mean(Xw, axis=1, keepdims=True)
    elif CENTER_METHOD == 'median':
        x0 = np.median(Xw, axis=1, keepdims=True)
    else:
        # Any other setting means "no centering".
        x0 = np.zeros((n, 1))

    # Shifting the state by a constant leaves its time derivative unchanged,
    # so only X is centered.
    Xw_centered = Xw - x0

    # ---- Optional evenly spaced subsample of the window ---------------
    n_cols = Xw_centered.shape[1]
    if n_cols > MAX_ROWS:
        idx = np.linspace(0, n_cols - 1, MAX_ROWS, dtype=int)
        Xw_use = Xw_centered[:, idx]
        Yw_use = Yw[:, idx]
    else:
        Xw_use = Xw_centered
        Yw_use = Yw

    # ---- SINDy fit on the centered data (samples along rows) ----------
    model = ps.SINDy(
        feature_library=GLOBAL_LIBRARY,
        optimizer=ps.STLSQ(alpha=1e-3, threshold=THRESH_SINDY),
    )
    model.fit(Xw_use.T, t=dt, x_dot=Yw_use.T, quiet=True)

    # ---- Coefficients → simplices -------------------------------------
    coefs_matrix = model.coefficients()
    feature_names = model.get_feature_names()
    edges_2, tris_3, quads_4, term_contribs = coefs_to_simplices_enhanced(
        coefs_matrix, feature_names, thresh=HEDGE_THRESH
    )

    # ---- Contribution ratios (weighted with the full centered window) -
    if COMPUTE_CONTRIB:
        rho = compute_contribution_ratios(
            coefs_matrix, feature_names, Xw_centered, max_order=D_MAX+1
        )
    else:
        rho = {}

    # ---- Package the slice; t_mid is the window midpoint in seconds ---
    t_mid = (w_start + w_end) / 2 * dt
    return {
        't_mid': t_mid,
        'edges_2': edges_2,
        'tris_3': tris_3,
        'quads_4': quads_4,
        'rho': rho,
        'x0': x0,  # centering point used for this window
        'n_samples': Xw_use.shape[1],
        'term_contributions': term_contribs
    }

# 4 ─── Run over sliding windows (parallel) ─────────
# Start index of every window that fits entirely inside the preprocessed data;
# the range is empty when the data is shorter than WIN_LEN.
starts = range(0, X.shape[1] - WIN_LEN + 1, STRIDE)
n_windows = len(starts)
print(f'⏳  Fitting {n_windows} windows with Taylor centering ({CENTER_METHOD})...')

# One joblib task per window start; `results` holds the dict returned by
# fit_window_enhanced for each window.
results = Parallel(n_jobs=N_JOBS, verbose=5)(
    delayed(fit_window_enhanced)(ws) for ws in starts
)
print('    done ✔')

# 5 ─── Analyze contribution ratios across windows ─────────
# Maps an interaction order (2, 3, 4, ... or 'higher') to the list of
# per-window ρ values. It is created outside the COMPUTE_CONTRIB branch because
# the summary section further down reads it unconditionally.
all_rho = defaultdict(list)

if COMPUTE_CONTRIB:
    print("\n" + "="*60)
    print("CONTRIBUTION RATIO ANALYSIS (ρ values)")
    print("="*60)

    # Collect the ρ values of every window whose rho dict is non-empty
    for res in results:
        if res['rho']:
            for order, value in res['rho'].items():
                all_rho[order].append(value)

    # Compute statistics (only orders 2, 3, 4 and 'higher' are reported)
    print("\nAverage contribution by interaction order:")
    for order in [2, 3, 4, 'higher']:
        if order in all_rho and all_rho[order]:
            values = np.array(all_rho[order])
            print(f"  ρ_{order}: {np.mean(values):.3f} ± {np.std(values):.3f}")
            print(f"       median: {np.median(values):.3f}, "
                  f"range: [{np.min(values):.3f}, {np.max(values):.3f}]")

    # Create visualization: a 2 × 2 grid of diagnostic plots
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Plot 1: Distribution of each ρ (one box per order that has data)
    ax = axes[0, 0]
    data_for_plot = []
    labels = []
    for order in [2, 3, 4]:
        if order in all_rho and all_rho[order]:
            data_for_plot.append(all_rho[order])
            labels.append(f'ρ_{order}')

    if data_for_plot:
        bp = ax.boxplot(data_for_plot, labels=labels)
        ax.set_ylabel('Contribution Ratio')
        ax.set_title('Distribution of Contribution Ratios')
        ax.grid(True, alpha=0.3)

    # Plot 2: Time evolution of contributions; windows lacking an order plot as 0
    ax = axes[0, 1]
    time_points = [res['t_mid'] for res in results]
    for order, color in [(2, 'blue'), (3, 'orange'), (4, 'green')]:
        if order in all_rho:
            values = [res['rho'].get(order, 0) for res in results]
            ax.plot(time_points, values, label=f'ρ_{order}', color=color, alpha=0.7)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Contribution Ratio')
    ax.set_title('Temporal Evolution of Contributions')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: Histogram of the higher-order share, with its mean marked
    ax = axes[1, 0]
    if 'higher' in all_rho:
        higher_values = all_rho['higher']
        ax.hist(higher_values, bins=30, edgecolor='black', alpha=0.7)
        ax.axvline(np.mean(higher_values), color='red', linestyle='--',
                  label=f'Mean = {np.mean(higher_values):.3f}')
        ax.set_xlabel('Higher-Order Contribution (ρ₃ + ρ₄)')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution of Higher-Order Contributions')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 4: Scatter plot of ρ₂ vs ρ_higher
    ax = axes[1, 1]
    if 2 in all_rho and 'higher' in all_rho:
        ax.scatter(all_rho[2], all_rho['higher'], alpha=0.5)
        ax.set_xlabel('ρ₂ (Pairwise)')
        ax.set_ylabel('ρ_higher (3-way+)')
        ax.set_title('Pairwise vs Higher-Order Contributions')
        ax.plot([0, 1], [1, 0], 'k--', alpha=0.3)  # Reference line ρ₂ + ρ_higher = 1
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{OUTDIR}/contribution_ratios.png', dpi=150)
    print(f"\n✓ Saved contribution ratio analysis to {OUTDIR}/contribution_ratios.png")

    # Share of windows whose higher-order contribution exceeds fixed cut-offs
    # (descriptive fractions only; no statistical test is performed here)
    if 'higher' in all_rho:
        higher_vals = np.array(all_rho['higher'])
        prop_significant = np.mean(higher_vals > 0.3)  # Arbitrary threshold
        print(f"\n{prop_significant*100:.1f}% of windows have >30% higher-order contribution")

        # Compare to paper's finding (>60% from higher-order)
        prop_dominant = np.mean(higher_vals > 0.6)
        print(f"{prop_dominant*100:.1f}% of windows have >60% higher-order contribution")

# 6 ─── Save enhanced edge lists with metadata ───────────
def pretty_edge(e):
    """Format a collection of node ids as '{a,b,c}' with the ids in sorted order."""
    members = [str(node) for node in sorted(e)]
    return '{' + ','.join(members) + '}'

print(f"\n📁 Saving {len(results)} time slices to '{OUTDIR}/'...")

# Run-level summary written next to the per-window edge lists
summary = {
    'n_windows': len(results),
    'window_length_s': WIN_LEN / FS,
    'stride_s': STRIDE / FS,
    'center_method': CENTER_METHOD,
    'degree_max': D_MAX,
    'n_channels': n,
    'contribution_stats': {}
}

# Only orders that actually collected values get an entry
for order in [2, 3, 4, 'higher']:
    if order in all_rho and all_rho[order]:
        summary['contribution_stats'][f'rho_{order}'] = {
            'mean': float(np.mean(all_rho[order])),
            'std': float(np.std(all_rho[order])),
            'median': float(np.median(all_rho[order])),
            'min': float(np.min(all_rho[order])),
            'max': float(np.max(all_rho[order]))
        }


def _write_edge_list(path, weights, header_lines=()):
    """Write header lines, then one '<edge> <weight>' line per hyperedge, heaviest first."""
    # Explicit UTF-8: the headers contain non-ASCII characters (ρ₂, ρ₃) that
    # would fail to encode under a non-UTF-8 platform default.
    with open(path, 'w', encoding='utf-8') as fh:
        for line in header_lines:
            fh.write(line)
        for e, w in sorted(weights.items(), key=lambda x: -x[1]):
            fh.write(f'{pretty_edge(e)} {w:.4f}\n')


# Save individual windows; the file stamp is the zero-padded window midpoint in seconds
for res in results:
    stamp = f'{res["t_mid"]:010.3f}'

    # Pairwise edges carry a two-line header with the center and ρ values
    _write_edge_list(
        f'{OUTDIR}/edges2_{stamp}.txt',
        res['edges_2'],
        header_lines=(
            f'# Taylor center ({CENTER_METHOD}): {res["x0"].flatten()[:3]}...\n',
            f'# ρ₂={res["rho"].get(2, 0):.3f}, ρ₃={res["rho"].get(3, 0):.3f}\n',
        ),
    )

    # The 3-way file is always written, even when it ends up empty
    _write_edge_list(f'{OUTDIR}/edges3_{stamp}.txt', res['tris_3'])

    # The 4-way file is only written when the window has 4-way hyperedges
    if res['quads_4']:
        _write_edge_list(f'{OUTDIR}/edges4_{stamp}.txt', res['quads_4'])

# Save summary
with open(f'{OUTDIR}/summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"✓ Wrote {len(results)} dynamic slices to '{OUTDIR}/'")
print(f"✓ Saved summary statistics to '{OUTDIR}/summary.json'")

# --- Structural analysis helpers (called further below) ---

def count_hyperedge_statistics(results):
    """
    Count the hyperedges found per window and across all windows.

    This describes the structure of the hypergraph (how many edges exist),
    not the dynamical contribution ratios.

    Parameters
    ----------
    results : list of dict
        Per-window results; each dict provides ``'edges_2'``, ``'tris_3'``
        and ``'quads_4'`` collections of node-id sets. May be empty.

    Returns
    -------
    (total_counts, unique_hyperedges)
        ``total_counts`` maps ``'edges_2'`` / ``'edges_3'`` / ``'edges_4'`` to
        the number of hyperedges summed over windows (the same hyperedge is
        counted once per window it appears in). ``unique_hyperedges`` maps the
        same keys to the set of distinct hyperedges (as frozensets).
    """
    total_counts = {
        'edges_2': 0,  # 1-simplices (edges)
        'edges_3': 0,  # 2-simplices (triangles)
        'edges_4': 0   # 3-simplices (tetrahedra)
    }

    unique_hyperedges = {
        'edges_2': set(),
        'edges_3': set(),
        'edges_4': set()
    }

    # Result key in each window dict → key used in the two output dicts
    key_pairs = (('edges_2', 'edges_2'), ('tris_3', 'edges_3'), ('quads_4', 'edges_4'))

    for res in results:
        for res_key, out_key in key_pairs:
            total_counts[out_key] += len(res[res_key])
            unique_hyperedges[out_key].update(frozenset(e) for e in res[res_key])

    print("\n" + "="*60)
    print("HYPERGRAPH STRUCTURAL STATISTICS")
    print("="*60)

    print("\nAverage per window:")
    n_windows = len(results)
    for key, count in total_counts.items():
        # With no windows there is nothing to average; report 0.0 rather than
        # dividing by zero.
        average = count / n_windows if n_windows else 0.0
        print(f"  {key}: {average:.1f}")

    print("\nStructural proportions (total unique edges):")
    total_unique_edges_count = sum(len(s) for s in unique_hyperedges.values())
    for key, edges in unique_hyperedges.items():
        proportion = len(edges) / total_unique_edges_count if total_unique_edges_count > 0 else 0
        print(f"  {key}: {proportion*100:.1f}% ({len(edges)} unique total)")

    return total_counts, unique_hyperedges

def analyze_hypergraph_topology(results):
    """
    Summarise the distinct hyperedges seen across all windows.

    Pools the keys of ``'edges_2'``, ``'tris_3'`` and ``'quads_4'`` from every
    window, de-duplicates them, prints the share of each edge size and the
    per-node participation, and returns ``(unique_edges, node_degrees)``:

    * ``unique_edges`` maps ``'2way'`` / ``'3way'`` / ``'4way'`` to a set of
      frozensets of node ids;
    * ``node_degrees`` maps 2 / 3 / 4 to a ``defaultdict(int)`` counting, for
      each node, the distinct hyperedges of that size containing it.
    """
    # Pool hyperedges from every window (duplicates across windows included)
    pooled = {'2way': [], '3way': [], '4way': []}
    for res in results:
        pooled['2way'].extend(res['edges_2'].keys())
        pooled['3way'].extend(res['tris_3'].keys())
        pooled['4way'].extend(res['quads_4'].keys())

    # De-duplicate
    unique_edges = {
        label: {frozenset(edge) for edge in edges}
        for label, edges in pooled.items()
    }

    n2 = len(unique_edges['2way'])
    n3 = len(unique_edges['3way'])
    n4 = len(unique_edges['4way'])
    total = n2 + n3 + n4

    print(f"\nHYPERGRAPH STRUCTURE (Unique edges):")
    if total > 0:
        print(f"  2-way edges: {n2} ({n2/total*100:.1f}%)")
        print(f"  3-way edges: {n3} ({n3/total*100:.1f}%)")
        print(f"  4-way edges: {n4} ({n4/total*100:.1f}%)")
    else:
        print(f"  2-way edges: {n2} (0.0%)")
        print(f"  3-way edges: {n3} (0.0%)")
        print(f"  4-way edges: {n4} (0.0%)")

    # Per-node participation, counted over the distinct edges of each size
    node_degrees = {2: defaultdict(int), 3: defaultdict(int), 4: defaultdict(int)}
    for size, label in ((2, '2way'), (3, '3way'), (4, '4way')):
        tally = node_degrees[size]
        for edge in unique_edges[label]:
            for node in edge:
                tally[node] += 1

    print(f"\nNODE PARTICIPATION (Unique edge counts per node):")
    for order in [2, 3, 4]:
        if not node_degrees[order]:
            print(f"  {order}-way: No unique edges found.")
            continue
        avg_degree = np.mean(list(node_degrees[order].values()))
        max_degree = max(node_degrees[order].values())
        print(f"  {order}-way: avg degree = {avg_degree:.1f}, max = {max_degree}")

    return unique_edges, node_degrees

# --- Run the analysis functions ---

# Both calls print their own report; only the topology result is used below.
structural_counts, unique_edges_counts = count_hyperedge_statistics(results)
unique_edges_topology, node_degrees = analyze_hypergraph_topology(results)

# Final summary of the hypergraph built from the distinct edges of all windows
print(f"\nFINAL HYPERGRAPH SUMMARY:")
print(f"  Nodes: {n} channels")
print(f"  Unique 2-way edges: {len(unique_edges_topology['2way'])}")
print(f"  Unique 3-way hyperedges: {len(unique_edges_topology['3way'])}")
print(f"  Unique 4-way hyperedges: {len(unique_edges_topology['4way'])}")

total_unique_higher_order = len(unique_edges_topology['3way']) + len(unique_edges_topology['4way'])
total_unique_edges = len(unique_edges_topology['2way']) + total_unique_higher_order

# Headline metric: fraction of distinct edges that are 3-way or 4-way
# (reported as 0 when no edges were found at all).
higher_order_fraction = total_unique_higher_order / total_unique_edges if total_unique_edges > 0 else 0
print(f"  Higher-order unique edge fraction: {higher_order_fraction*100:.1f}%")