# Transfer Entropy (Standalone, Colab-Ready)

- Standalone pipeline based on IDTxl and the spectral multivariate transfer entropy method by Edoardo Pinzuti:
- Fast delay scan (Gaussian-TE/GC) to select optimal source→target lag per link.
- Band-wise TE on filter bank: delta, theta, alpha, beta, gamma, broadband.
- Phase-randomized surrogates for source-only and target-only p-values; FDR-corrected.
- Heatmaps (TE and significance) and an optional "winning band" map.
- Batch mode: analyze one or many EDFs; match annotation TXT by EDF prefix (before first underscore).
- Per-EDF subfolder: saves delay scan results, figures, winning band map, and a text summary.

## References
- IDTxl GitHub: https://github.com/pwollstadt/IDTxl
- Spectral TE branch: https://github.com/pwollstadt/IDTxl/tree/feature_spectral_te/
- Wollstadt, P., Lizier, J. T., Vicente, R., Finn, C., Martinez-Zarzuela, M., Mediano, P., Novelli, L., & Wibral, M. (2019). IDTxl: The Information Dynamics Toolkit xl: a Python package for the efficient analysis of multivariate information dynamics in networks. Journal of Open Source Software, 4(34), 1081. https://doi.org/10.21105/joss.01081

## Mount Google Drive

In [None]:
# Mount Google Drive when running inside Colab; outside Colab there is
# nothing to mount, so report it instead of failing the whole notebook run.
try:
    from google.colab import drive
except ImportError:
    print('google.colab not available; skipping Google Drive mount')
else:
    drive.mount('/content/drive')

## Install Dependencies

In [None]:
# Install dependencies (Colab)
# NOTE: the `!pip` line below is IPython shell syntax, so this cell only runs
# inside a notebook kernel, not as a plain Python script.
import sys
if 'google.colab' in sys.modules:  # install only when the Colab runtime is loaded
    !pip -q install mne numpy scipy matplotlib seaborn joblib numba

## Imports and Global Configuration

In [None]:
# Imports and global config
import os, re, json, math, itertools, warnings
from dataclasses import dataclass
import numpy as np
import scipy as sp
from scipy import signal
import matplotlib.pyplot as plt
import seaborn as sns
from joblib import Parallel, delayed
from typing import List, Tuple, Dict
warnings.filterwarnings('ignore')
try:
    import mne
except Exception as e:
    print('mne not available; please install in Colab cell above')
import re
from typing import List, Tuple

np.set_printoptions(suppress=True, precision=5)
plt.rcParams.update({'figure.dpi': 120})

# Default bands (Hz), as (low_cut, high_cut) pairs
DEFAULT_BANDS = {
    'delta': (0.5, 4.0),
    'theta': (4.0, 8.0),
    'alpha': (8.0, 13.0),
    'beta': (13.0, 30.0),
    'gamma': (30.0, 80.0),
    'broadband': (0.5, None),  # None => up to Nyquist
}

# Runtime config
# os.cpu_count() may return None when the count cannot be determined; treat
# that as a single CPU so N_JOBS is always a valid positive integer.
N_JOBS = max(1, (os.cpu_count() or 1) - 1)
MAX_DELAY_MS = 200  # delay scan up to 200 ms
EMBED_PAST = 1     # past embedding for Gaussian TE
N_SURR = 100        # surrogates per test (source-only and target-only)
ALPHA = 0.05        # significance level
DOWNSAMPLE_BAND = True  # decimate to 2x band high-cut (when possible)
# Limit analysis duration to first N seconds to avoid OOM on large recordings
MAX_DURATION_S = 300  # 300 seconds = 5 minutes (set to None to disable)

# Plot style
CMAP_TE = 'magma'
CMAP_BAND = 'tab10'

## Helper Functions (Filtering, Surrogates, Gaussian TE)

In [None]:
# Helpers: filtering, surrogates, Gaussian TE (equiv. to GC under Gaussianity)
def get_sos_band(fs: float, low: float, high: float, order=4):
    """Design a Butterworth filter (second-order sections) for one band.

    Parameters:
      fs: sampling rate in Hz.
      low: lower cut-off in Hz; None or a value <= 0 means "no lower edge".
      high: upper cut-off in Hz; None means "up to (just below) Nyquist".
      order: Butterworth order passed to scipy.signal.butter.

    Returns:
      SOS array: low-pass when only an upper edge is usable, high-pass when
      the upper edge reaches the usable maximum (0.99 * Nyquist), band-pass
      otherwise.

    Raises:
      ValueError: if the band is empty or lies above the usable range.
    """
    nyq = 0.5 * fs
    max_edge = 0.99 * nyq  # keep a margin below Nyquist
    low_edge = 0.0 if low is None else float(low)
    high_edge = max_edge if high is None else min(float(high), max_edge)

    # No lower edge but a real upper edge: plain low-pass. This must be
    # decided before the lower edge is clamped to a small positive value.
    if low_edge <= 0 and high_edge < max_edge:
        return signal.butter(order, high_edge / nyq, btype='lowpass', output='sos')

    low_edge = max(0.001, low_edge)
    if low_edge >= high_edge:
        raise ValueError(
            f'Band ({low}, {high}) Hz is empty or above the usable range for fs={fs} Hz')

    if high_edge >= max_edge:
        # Upper edge is at the usable maximum: only the lower edge matters.
        return signal.butter(order, low_edge / nyq, btype='highpass', output='sos')
    return signal.butter(order, [low_edge / nyq, high_edge / nyq], btype='bandpass', output='sos')

def bandpass(ts: np.ndarray, fs: float, band: Tuple[float, float], order=4, decimate=True):
    """Zero-phase filter `ts` to `band` and optionally decimate the result.

    Parameters:
      ts: samples to filter (filtered along axis 0).
      fs: sampling rate in Hz.
      band: (low, high) cut-offs in Hz; a high of None disables decimation.
      order: filter order forwarded to get_sos_band.
      decimate: when True, reduce the rate by an integer factor so the new
        rate stays at or above max(2.5 * high, 2 * high + 1).

    Returns:
      (filtered_samples, effective_sampling_rate)
    """
    samples = np.asarray(ts, float)
    sections = get_sos_band(fs, band[0], band[1], order=order)
    filtered = signal.sosfiltfilt(sections, samples, axis=0)

    upper = band[1]
    factor = 1
    if decimate and upper is not None and upper > 0:
        wanted_fs = min(fs, max(2.5 * upper, 2 * upper + 1))
        factor = max(1, int(math.floor(fs / wanted_fs)))

    if factor <= 1:
        return filtered, fs
    reduced = signal.decimate(filtered, factor, axis=0, ftype='fir', zero_phase=True)
    return reduced, fs / factor

def phase_randomize(ts: np.ndarray, rng: np.random.Generator):
    """Return an FFT phase-randomized surrogate of a 1D series.

    Every Fourier amplitude is kept; a uniform random offset in [-pi, pi) is
    added to each phase, except the DC bin and (for even length) the Nyquist
    bin, which keep their original phase.
    """
    series = np.asarray(ts, float)
    n_samples = len(series)
    spectrum = np.fft.rfft(series)
    amplitude = np.abs(spectrum)
    phase = np.angle(spectrum)

    offsets = rng.uniform(-np.pi, np.pi, size=phase.shape)
    offsets[0] = 0.0  # DC bin untouched
    if n_samples % 2 == 0:
        offsets[-1] = 0.0  # Nyquist bin untouched for even lengths

    surrogate_spectrum = amplitude * np.exp(1j * (phase + offsets))
    return np.fft.irfft(surrogate_spectrum, n=n_samples)

def build_reg_mats(xpast: np.ndarray, ypast: np.ndarray, yt: np.ndarray):
    """Return the two regression designs used for Gaussian TE.

    Returns (X1, X2, yt): X1 is the restricted design (target past only),
    X2 is the full design (target past followed by source past columns),
    and yt is handed back unchanged.
    """
    restricted = ypast
    full = np.hstack((ypast, xpast))
    return restricted, full, yt

def ls_res_var(X: np.ndarray, y: np.ndarray):
    """Least-squares fit of y on X; return the residual variance as a float.

    The variance is np.var of the residuals with ddof = (columns - 1) when
    there are more rows than columns, and ddof = 1 otherwise.
    NOTE(review): a textbook OLS estimate would use ddof = columns; confirm
    that the (columns - 1) correction is intended.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    coef = np.linalg.lstsq(X, y, rcond=None)[0]
    residual = y - X @ coef
    if n_rows > n_cols:
        dof_penalty = min(n_cols, n_rows) - 1
    else:
        dof_penalty = 1
    return float(np.var(residual, ddof=dof_penalty))

def gaussian_te(x: np.ndarray, y: np.ndarray, delay_samples: int, k_past=1):
    """Gaussian transfer entropy TE_{x->y} = 0.5 * ln(var(e1) / var(e2)).

    e1 are the residuals of y[t] regressed on its own past
    (y[t-1..t-k_past]); e2 additionally uses the delayed source past
    (x[t-d], ..., x[t-d-k_past+1]).

    Parameters:
      x: source series (1D).
      y: target series (1D).
      delay_samples: source->target delay d in samples (must be >= 1).
      k_past: number of past samples used for both source and target.

    Returns:
      TE in nats, or NaN when the delay/embedding is not positive, fewer than
      6 aligned samples remain, or a residual variance is not positive.
    """
    d = int(delay_samples)
    k = int(k_past)
    if d <= 0 or k <= 0:
        return np.nan
    x = np.asarray(x, float)
    y = np.asarray(y, float)

    # The source must come from the PAST of the target: x[t-d] helps predict
    # y[t]. The first usable t therefore needs d + k - 1 earlier source
    # samples (and at least k earlier target samples, which is implied).
    n = min(len(x), len(y))
    start = k + d - 1
    T = n - start
    if T <= 5:
        return np.nan
    idx_t = np.arange(start, n)
    Yt = y[idx_t]  # current target sample
    Ypast = np.column_stack([y[idx_t - i] for i in range(1, k + 1)])
    Xpast = np.column_stack([x[idx_t - d - i + 1] for i in range(1, k + 1)])
    X1, X2, Y = build_reg_mats(Xpast, Ypast, Yt)
    v1 = ls_res_var(X1, Y)
    v2 = ls_res_var(X2, Y)
    if v1 <= 0 or v2 <= 0:
        return np.nan
    return 0.5 * np.log(v1 / v2)

def fdr_bh(pvals: np.ndarray, alpha=0.05):
    """Benjamini-Hochberg FDR control.

    Parameters:
      pvals: array-like of p-values (any shape). NaN entries are treated as
        "not tested": they do not count towards the number of tests and are
        never marked significant.
      alpha: target false discovery rate.

    Returns:
      (reject, crit): boolean array with the shape of `pvals`, and the largest
      p-value that passed (0.0 when nothing passed).
    """
    p_arr = np.asarray(pvals, dtype=float)
    tested = np.sort(p_arr[~np.isnan(p_arr)])
    n = tested.size
    nothing = np.zeros(p_arr.shape, dtype=bool)
    if n == 0:
        return nothing, 0.0

    thresh = alpha * (np.arange(1, n + 1) / n)
    passed = np.nonzero(tested <= thresh)[0]
    if passed.size == 0:
        return nothing, 0.0

    crit = float(tested[passed[-1]])
    # NaN <= crit evaluates to False, which keeps untested cells unrejected.
    with np.errstate(invalid='ignore'):
        reject = p_arr <= crit
    return reject, crit

## EDF/FIF Loader and Batch Runner

In [None]:
# EDF/FIF + annotations + batch runner
import csv


def find_annotations_for_edf_like(path: str, ann_folder: str):
    """Locate the annotation TXT in `ann_folder` that belongs to a recording.

    The recording prefix is the file name without extension, up to the first
    underscore. A TXT whose own prefix equals it exactly is preferred; if none
    exists, the first TXT (alphabetically) whose name merely starts with the
    prefix is used.

    Returns:
      Full path of the matching TXT, or None when nothing matches.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    prefix = stem.split('_')[0]
    # Sort so the choice does not depend on directory listing order.
    candidates = sorted(
        fn for fn in os.listdir(ann_folder)
        if fn.lower().endswith('.txt') and fn.startswith(prefix)
    )
    if not candidates:
        return None
    for fn in candidates:
        # Exact prefix match keeps 'P1' from picking up 'P10_*.txt'.
        if os.path.splitext(fn)[0].split('_')[0] == prefix:
            return os.path.join(ann_folder, fn)
    return os.path.join(ann_folder, candidates[0])



def load_signal_any(path: str, max_duration_s=None):
    """Load .edf or .fif and optionally crop to first max_duration_s seconds to limit memory use.

    The file is opened without preloading, cropped, and only then read into
    memory, so the duration limit actually bounds how much data is loaded.

    Returns: ch_names, data (n_channels x n_samples), fs
    Raises: ValueError for any extension other than .edf / .fif.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.edf':
        raw = mne.io.read_raw_edf(path, preload=False, verbose='ERROR')
    elif ext == '.fif':
        raw = mne.io.read_raw_fif(path, preload=False, verbose='ERROR')
    else:
        raise ValueError(f'Unsupported file type: {ext}. Expected .edf or .fif')

    # Optionally crop to limit duration (helps avoid GPU/CPU OOM)
    if max_duration_s is not None:
        try:
            # Clamp to the recording length: crop() rejects a tmax past the end.
            tmax = min(float(max_duration_s), float(raw.times[-1]))
            if tmax > 0:
                # crop modifies Raw in-place; keep from 0 to tmax
                raw.crop(tmin=0.0, tmax=tmax)
        except Exception as exc:
            # Best effort: if cropping fails, continue with the full recording.
            print(f'Warning: could not crop {os.path.basename(path)} to {max_duration_s} s ({exc}); using full recording')
    raw.load_data()

    # Attempt unit normalization to Volts (MNE stores SI units typically already in Volts)
    try:
        units = raw.get_units()
        # If microvolts were used in EDF, ensure conversion to Volts
        for ch, u in units.items():
            if isinstance(u, str) and u.lower() in ('uv', 'µv', 'microvolts', 'microvolt'):
                raw.apply_function(lambda x: x * 1e-6)
                break
    except Exception as exc:
        # Best effort: unit lookup is optional, data are used as loaded.
        print(f'Warning: unit check skipped for {os.path.basename(path)} ({exc})')
    data = raw.get_data()  # n_channels x n_samples
    fs = float(raw.info['sfreq'])
    return raw.ch_names, data, fs


# Refined override: parse SOZ/NIZ annotations (3-column) with de-dup and diagnostics
# Format per line: <region>\t<channel_name>\t<label>
# label 1 => SOZ (targets), 0 => NIZ (sources)

def parse_links_from_txt(txt_path: str, ch_names: List[str]) -> List[Tuple[int, int]]:
    """Parse (source_index, target_index) links from an annotation TXT.

    Three formats are tried in order:
      1) 'sources: a, b' and 'targets: c, d' lines -> every source x target.
      2) 3-column lines '<region> <channel> <label>' (label <= 0.5 => source,
         label > 0.5 => target) -> every source x target, de-duplicated.
      3) One 'src tgt' pair per line.

    Channel names that are not in `ch_names` are ignored; self-links are
    dropped. Returns a list of index pairs (possibly empty).
    """
    links: List[Tuple[int, int]] = []
    with open(txt_path, 'r') as f:
        lines = [ln.strip() for ln in f if ln.strip()]

    idx = {n: i for i, n in enumerate(ch_names)}

    def _columns(ln: str) -> List[str]:
        # Prefer tab-separated columns so a region such as 'Left Temporal'
        # stays one field; otherwise split on tabs, commas and spaces.
        if '\t' in ln:
            cols = [c.strip() for c in ln.split('\t') if c.strip()]
            if len(cols) >= 3:
                return cols
        return re.split(r'[\t, ]+', ln)

    # 1) sources:/targets: format
    srcs = None
    tgts = None
    for ln in lines:
        if ln.lower().startswith('sources:'):
            srcs = [s.strip() for s in re.split('[,;]+', ln.split(':', 1)[1]) if s.strip()]
        elif ln.lower().startswith('targets:'):
            tgts = [s.strip() for s in re.split('[,;]+', ln.split(':', 1)[1]) if s.strip()]
    if srcs is not None and tgts is not None:
        for s in srcs:
            for t in tgts:
                if s in idx and t in idx and idx[s] != idx[t]:
                    links.append((idx[s], idx[t]))
        if links:
            print(f"Parsed {len(links)} links from sources:/targets: format")
            return links

    # 2) SOZ/NIZ 3-column format
    labeled = {}  # channel -> label (last wins if duplicated)
    missing: List[str] = []  # labeled channels absent from the recording
    for ln in lines:
        parts = _columns(ln)
        if len(parts) >= 3:
            chan = parts[1].strip()
            try:
                lab = float(parts[2])
            except Exception:
                continue
            if chan in idx:
                labeled[chan] = lab
            elif chan not in missing:
                missing.append(chan)
    if labeled:
        src_list = sorted({c for c, l in labeled.items() if l <= 0.5})  # NIZ
        tgt_list = sorted({c for c, l in labeled.items() if l > 0.5})   # SOZ
        pair_set = {(idx[s], idx[t]) for s in src_list for t in tgt_list if idx[s] != idx[t]}
        links = sorted(pair_set)
        print(f"Parsed {len(links)} links from SOZ/NIZ format ({len(src_list)} sources x {len(tgt_list)} targets)")
        if missing:
            print(f"Warning: {len(missing)} channels not found in recording and were ignored: {missing[:10]}{'...' if len(missing)>10 else ''}")
        return links

    # 3) Fallback: pair-per-line
    pair_set = set()
    for ln in lines:
        parts = re.split(r'[\t, ]+', ln)
        if len(parts) >= 2 and parts[0] in idx and parts[1] in idx and parts[0] != parts[1]:
            pair_set.add((idx[parts[0]], idx[parts[1]]))
    links = sorted(pair_set)
    if links:
        print(f"Parsed {len(links)} links from pair-per-line format")
    return links

# Cell-level status message printed once the parser definition above has run.
print('Annotation parser override (SOZ/NIZ) active.')


def save_delay_results(path, links, delays, te_traces, fs):
    """Write the per-link delay scan result as a tab-separated text file.

    Parameters:
      path: output file path (overwritten).
      links: sequence of (src, tgt) pairs.
      delays: best delay in samples for each link.
      te_traces: per-link TE traces; only used to pair up rows (not written).
      fs: sampling rate in Hz, used to convert samples to milliseconds.
    """
    with open(path, 'w') as f:
        f.write('# Delay scan results\n')
        # Column names are tab-separated, matching the data rows below.
        f.write('# format: src\ttgt\tbest_delay_samples\tbest_delay_ms\n')
        for (s, t), d, _trace in zip(links, delays, te_traces):
            f.write(f"{s}\t{t}\t{d}\t{1000.0 * d / fs:.3f}\n")


def save_summary_txt(path, names, TE, pS, pT, sigS, sigT):
    """Dump band names, the TE matrix, both p-value matrices and both
    significance masks (as 0/1) into one plain-text file at `path`."""
    float_sections = (('\nP_source:\n', pS), ('\nP_target:\n', pT))
    mask_sections = (('\nSig_source (FDR):\n', sigS), ('\nSig_target (FDR):\n', sigT))
    with open(path, 'w') as out:
        out.write('Bands: ' + ','.join(names) + '\n')
        out.write('TE matrix (rows=src, cols=tgt):\n')
        np.savetxt(out, TE, fmt='%.6f')
        for header, values in float_sections:
            out.write(header)
            np.savetxt(out, values, fmt='%.6f')
        for header, mask in mask_sections:
            out.write(header)
            np.savetxt(out, mask.astype(int), fmt='%d')


def process_one_file(path: str, ann_folder: str, bands=DEFAULT_BANDS, out_root=None):
    """
    Process one EDF/FIF file and produce per-link outputs and one aggregated file-level summary.

    Parameters:
      path: recording to analyze (.edf or .fif).
      ann_folder: folder searched for the matching annotation TXT.
      bands: mapping band name -> (low, high) Hz; must contain a 'broadband' entry.
      out_root: parent folder for the output subfolder; defaults to the
        recording's own folder.

    Returns:
      Path of the output folder (named after the file prefix before the first underscore).

    Raises:
      FileNotFoundError: no annotation TXT matches the recording.
      ValueError: the annotation yields no usable (src, tgt) links.
    """
    # Load only up to MAX_DURATION_S seconds to limit memory/GPU usage
    ch_names, data, fs = load_signal_any(path, max_duration_s=MAX_DURATION_S)
    ann = find_annotations_for_edf_like(path, ann_folder)
    if ann is None:
        raise FileNotFoundError(f'No annotation TXT matching prefix for {path}')
    links = parse_links_from_txt(ann, ch_names)
    if not links:
        raise ValueError(f'No valid (src,tgt) links parsed from {ann}')

    # Output folder is named after the part of the file name before the first underscore
    base = os.path.basename(path)
    prefix = base.split('_')[0]
    out_dir = os.path.join(os.path.dirname(path) if out_root is None else out_root, prefix)
    # Ensure the output directory exists before saving files
    os.makedirs(out_dir, exist_ok=True)

    # Delay scan per link (broadband)
    bb_lo, bb_hi = bands['broadband']
    x_bb = {}
    # Filter each channel that appears in any link exactly once (mean removed, no decimation)
    for idx in set([s for s, _ in links] + [t for _, t in links]):
        x = data[idx] - np.mean(data[idx])
        xf, _ = bandpass(x, fs, (bb_lo, bb_hi), order=4, decimate=False)
        x_bb[idx] = xf
    delays = []
    te_traces = []
    for (s, t) in links:
        d, trace = delay_scan(x_bb[s], x_bb[t], fs, max_delay_ms=MAX_DELAY_MS, k_past=EMBED_PAST)
        delays.append(int(d))
        te_traces.append(trace)
    save_delay_results(os.path.join(out_dir, 'delay_scan_results.txt'), links, delays, te_traces, fs)

    # Band-wise TE per link + plots; collect results for aggregation
    aggregated_signifs = []
    for (s, t), d in zip(links, delays):
        # NOTE(review): a fresh generator with the same seed (123) is created for every link,
        # so each link starts from the same random stream — confirm this is intended.
        names, TE, pS, pT, sigS, sigT, critS, critT = compute_te_grid_with_surrogates(
            data[s] - np.mean(data[s]), data[t] - np.mean(data[t]), fs, bands, delay_samples=d,
            k_past=EMBED_PAST, n_surr=N_SURR, alpha=ALPHA, downsample=DOWNSAMPLE_BAND, rng=np.random.default_rng(123))
        base_out = os.path.join(out_dir, f'link_{s}_{t}')
        plot_te_heatmap(names, TE, pS, pT, base_out)
        wins = plot_winning_band_map(names, TE, pS, pT, ALPHA, base_out + '_winning_band.png')
        np.savez_compressed(base_out + '.npz', names=np.array(names), TE=TE, p_source=pS, p_target=pT, sig_source=sigS, sig_target=sigT, delay_samples=int(d), fs=float(fs), wins=wins)
        save_summary_txt(base_out + '.txt', names, TE, pS, pT, sigS, sigT)
        aggregated_signifs.append({'pS': pS, 'pT': pT, 'names': names})

    # Aggregate across links to form file-level summary
    if not aggregated_signifs:
        print('No link results to aggregate.')
        return out_dir

    band_names = aggregated_signifs[0]['names']
    B = len(band_names)
    total_src_counts = np.zeros(B, dtype=int)
    total_tgt_counts = np.zeros(B, dtype=int)
    for res in aggregated_signifs:
        # A cell counts when BOTH surrogate p-values are <= ALPHA.
        # NOTE(review): this thresholds the raw p-values; the sigS/sigT masks saved above
        # are not used here — confirm whether the corrected masks should drive the counts.
        signif = (res['pS'] <= ALPHA) & (res['pT'] <= ALPHA)
        total_src_counts += np.sum(signif, axis=1)  # summed along each row
        total_tgt_counts += np.sum(signif, axis=0)  # summed along each column

    # Save aggregated counts CSV (use csv writer to avoid extra deps)
    base_out_file = os.path.join(out_dir, prefix)
    with open(base_out_file + '_band_counts.csv', 'w', newline='') as cf:
        writer = csv.writer(cf)
        writer.writerow(['band', 'total_src_count', 'total_tgt_count'])
        for nm, s, t in zip(band_names, total_src_counts, total_tgt_counts):
            writer.writerow([nm, int(s), int(t)])

    # Save nested pie summary
    plot_summary_pie(band_names, total_src_counts, total_tgt_counts, base_out_file + '_summary_pie.png')

    # Save text summary
    with open(base_out_file + '_summary.txt', 'w') as fh:
        fh.write(f'# Aggregated significant band counts for {os.path.basename(path)}\n')
        fh.write('# format: band, total_source_count, total_target_count\n')
        for nm, s, t in zip(band_names, total_src_counts, total_tgt_counts):
            fh.write(f"{nm},{s},{t}\n")

    print('Saved per-link results and file-wide summary to:', out_dir)
    return out_dir


def run_batch():
    """Interactive entry point: ask for one file or a folder, then process.

    Answer '1' processes a single EDF/FIF; any other answer processes every
    .edf/.fif in a folder (sorted by path), reporting per-file failures
    without stopping. Raises FileNotFoundError when a given path is missing.
    """
    choice = input('Analyze (1) one file or (2) multiple files in a folder? Enter 1 or 2: ').strip()
    single = choice == '1'

    if single:
        target = input('Enter full path to the EDF/FIF: ').strip()
        if not os.path.isfile(target):
            raise FileNotFoundError('File not found')
    else:
        target = input('Enter folder path containing EDF/FIF: ').strip()
        if not os.path.isdir(target):
            raise FileNotFoundError('Folder not found')

    ann_folder = input('Enter folder path containing annotation TXTs: ').strip()
    if not os.path.isdir(ann_folder):
        raise FileNotFoundError('Annotation folder not found')

    if single:
        process_one_file(target, ann_folder)
        return

    recordings = sorted(
        os.path.join(target, fn)
        for fn in os.listdir(target)
        if fn.lower().endswith(('.edf', '.fif'))
    )
    for rec in recordings:
        try:
            process_one_file(rec, ann_folder, out_root=target)
        except Exception as e:
            # Keep going: one bad recording must not abort the batch.
            print('Failed for', rec, '->', e)

# Cell-level hint printed after the definitions above have been executed.
print('Ready. Run run_batch() to start interactive processing for EDF/FIF.')

## Time Delay Scanner

In [None]:
# Delay scan helper: TE_{x->y} over delays to select best lag (coarse by default)
import numpy as np

def delay_scan(x, y, fs, max_delay_ms=200, k_past=1, step_ms=None, n_steps=10, min_delay_ms=1, detrend=True, zscore=True):
    """
    Evaluate TE_{x->y} on a grid of positive delays and return the best one.

    Inputs:
      x, y: 1D arrays (flattened; truncated to the shorter length)
      fs: sampling rate (Hz)
      max_delay_ms / min_delay_ms: delay range to test (ms)
      k_past: past embedding passed to gaussian_te
      step_ms: fixed grid step in ms; when given it replaces n_steps
      n_steps: number of evenly spaced delays when step_ms is None
      detrend: subtract the (NaN-aware) mean from both series
      zscore: divide each series by its (NaN-aware) standard deviation if > 0
    Returns:
      (best_delay_samples, te_trace): te_trace holds one TE value per tested
      delay; the tested delays themselves are not returned. With fewer than
      10 samples the result is (1, [nan]); when no TE value is finite the
      first tested delay is returned.
    Notes:
      - Any exception raised by gaussian_te for a delay is recorded as NaN.
    """
    src = np.asarray(x, dtype=float).ravel()
    tgt = np.asarray(y, dtype=float).ravel()
    n_common = min(src.size, tgt.size)
    if n_common < 10:
        return 1, np.array([np.nan])
    src, tgt = src[:n_common], tgt[:n_common]

    if detrend:
        src = src - np.nanmean(src)
        tgt = tgt - np.nanmean(tgt)
    if zscore:
        sd_src = np.nanstd(src)
        sd_tgt = np.nanstd(tgt)
        if sd_src > 0:
            src = src / sd_src
        if sd_tgt > 0:
            tgt = tgt / sd_tgt

    def _ms_to_samples(ms):
        return int(round(ms * fs / 1000.0))

    lo = max(1, _ms_to_samples(min_delay_ms))
    hi = max(lo, _ms_to_samples(max_delay_ms))

    if step_ms is None:
        # Evenly-spaced integer sample delays (duplicates collapsed)
        n_steps = max(1, int(n_steps))
        candidates = np.unique(np.linspace(lo, hi, num=n_steps, dtype=int))
        if candidates.size == 0:
            candidates = np.array([lo], dtype=int)
    else:
        stride = max(1, _ms_to_samples(step_ms))
        candidates = np.arange(lo, hi + 1, stride, dtype=int)

    trace = np.full(candidates.shape[0], np.nan, dtype=float)
    for pos, lag in enumerate(candidates):
        try:
            trace[pos] = gaussian_te(src, tgt, int(lag), k_past=k_past)
        except Exception:
            trace[pos] = np.nan

    if np.any(np.isfinite(trace)):
        winner = int(np.nanargmax(trace))
        return int(candidates[winner]), trace
    return int(candidates[0]), trace

## Visualization Helpers

In [None]:
# Visualization helpers
def plot_te_heatmap(names, TE, pS, pT, out_png_base):
    """Save three band-by-band heatmaps next to `out_png_base`:
    '<base>_TE.png', '<base>_p_source.png' and '<base>_p_target.png'."""

    def _label_and_save(axis, title, out_path):
        # Shared axis labelling and figure finalisation for all three maps.
        axis.set_xlabel('Target band')
        axis.set_ylabel('Source band')
        axis.set_title(title)
        plt.tight_layout()
        plt.savefig(out_path, bbox_inches='tight')
        plt.close()

    plt.figure(figsize=(8, 6))
    te_axis = sns.heatmap(TE, xticklabels=names, yticklabels=names, cmap=CMAP_TE, annot=False)
    _label_and_save(te_axis, 'Cross-band TE (Gaussian)', out_png_base + '_TE.png')

    # p-value heatmaps
    for label, pvals in (('p_source', pS), ('p_target', pT)):
        plt.figure(figsize=(8, 6))
        p_axis = sns.heatmap(pvals, xticklabels=names, yticklabels=names, cmap='viridis_r', vmin=0, vmax=1)
        _label_and_save(p_axis, f'{label} (fraction of surrogates ≥ obs)', out_png_base + f'_{label}.png')


def plot_winning_band_map(names, TE, pS, pT, alpha, out_png):
    """Plot, for every source band (row), the target band with the largest
    significant TE, and save the figure to `out_png`.

    A cell is significant when both p-values are <= alpha. Rows without any
    significant finite TE are drawn gray.

    Returns:
      wins: int array of length len(names); winning target index per source
      band, or -1 when there is no significant winner.
    """
    te_arr = np.asarray(TE)
    p_src = np.asarray(pS)
    p_tgt = np.asarray(pT)
    # Keep TE only where both tests pass; everything else becomes NaN.
    masked_te = np.where((p_src <= alpha) & (p_tgt <= alpha), te_arr, np.nan)

    n_bands = len(names)
    wins = np.full(n_bands, -1, dtype=int)
    # nanargmax is only called on rows that hold at least one finite value.
    for row in np.flatnonzero(np.isfinite(masked_te).any(axis=1)):
        wins[row] = int(np.nanargmax(masked_te[row, :]))

    # Palette: slot 0 = gray ("no winner"), slots 1..B = one color per target band
    band_colors = plt.get_cmap(CMAP_BAND)(np.linspace(0, 1, n_bands))
    palette = [(0.7, 0.7, 0.7, 1.0)] + [tuple(rgba) for rgba in band_colors]
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    column = (wins + 1).reshape(-1, 1)

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.imshow(column, aspect='auto', cmap=ListedColormap(palette), vmin=0, vmax=n_bands)
    ax.set_yticks(range(n_bands))
    ax.set_yticklabels(names)
    ax.set_xticks([0])
    ax.set_xticklabels(['win tgt'])
    ax.set_title('Winning band (significant-only argmax)')

    # Side note explaining how to read the column
    ax.text(1.05, 0.5, 'Rows = source bands\nColor = winning target band\nGray = no significant target', transform=ax.transAxes,
            fontsize=8, va='center', ha='left', bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

    # Legend: gray swatch first, then one swatch per target band
    legend_items = [Patch(facecolor=palette[0], edgecolor='black', label='no significant target (gray)')]
    legend_items.extend(
        Patch(facecolor=palette[j + 1], edgecolor='black', label=f'target: {nm}')
        for j, nm in enumerate(names)
    )
    ax.legend(handles=legend_items, bbox_to_anchor=(1.1, 1.0), loc='upper left', fontsize=8, frameon=False)

    plt.tight_layout()
    plt.savefig(out_png, bbox_inches='tight')
    plt.close()

    return wins


def plot_summary_pie(names, src_counts, tgt_counts, out_png):
    """Save a two-ring pie of per-band counts to `out_png`.

    Outer ring = source counts (lighter tint of the band color), inner ring =
    target counts (darker tint). When both count vectors sum to zero, a
    text-only placeholder figure is saved instead.
    """
    import numpy as _np
    from matplotlib.patches import Patch as _Patch

    labels = list(names)
    n_bands = len(labels)
    outer_counts = _np.asarray(src_counts, dtype=float)
    inner_counts = _np.asarray(tgt_counts, dtype=float)

    # Nothing significant anywhere: write an informational figure and stop
    if outer_counts.sum() == 0 and inner_counts.sum() == 0:
        fig, ax = plt.subplots(figsize=(5, 4))
        ax.text(0.5, 0.5, 'No significant source-target pairs found', ha='center', va='center')
        ax.axis('off')
        plt.tight_layout()
        plt.savefig(out_png, bbox_inches='tight')
        plt.close()
        return

    palette = plt.get_cmap(CMAP_BAND)(_np.linspace(0, 1, n_bands))
    white = _np.array([1.0, 1.0, 1.0])
    # Source tint is blended towards white; target tint is slightly darkened
    light_colors = [tuple(0.65 * rgba[:3] + 0.35 * white) + (1.0,) for rgba in palette]
    dark_colors = [tuple(0.92 * rgba[:3]) + (1.0,) for rgba in palette]

    # A ring whose counts sum to zero is drawn with equal slices
    outer_sizes = _np.ones(n_bands) if outer_counts.sum() == 0 else outer_counts.copy()
    inner_sizes = _np.ones(n_bands) if inner_counts.sum() == 0 else inner_counts.copy()

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.axis('equal')

    ring_style = dict(width=0.4, edgecolor='white')
    # Outer ring = source
    ax.pie(outer_sizes, radius=1.3, colors=light_colors, startangle=90, wedgeprops=dict(ring_style))
    # Inner ring = target
    ax.pie(inner_sizes, radius=0.9, colors=dark_colors, startangle=90, wedgeprops=dict(ring_style))

    # Legend uses the dark swatches; the text box explains the two rings
    swatches = [
        _Patch(facecolor=dark_colors[i], edgecolor='black', label=labels[i])
        for i in range(n_bands)
    ]
    ax.legend(handles=swatches, bbox_to_anchor=(1.1, 0.8), title='Bands (dark=target)', fontsize=8)

    ax.text(1.05, 0.4, 'Outer = source (lighter)\nInner = target (darker)', transform=ax.transAxes,
            fontsize=8, va='center', ha='left', bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))

    plt.title('Summary counts: source (outer) and target (inner)')
    plt.tight_layout()
    plt.savefig(out_png, bbox_inches='tight')
    plt.close()

## Notes
- This pipeline uses Gaussian TE (equivalent to Granger causality for Gaussian data) for speed and robustness.
- Surrogates are FFT phase-randomized within each band to break directional structure while preserving power spectra.
- Delay scan is done on broadband; band-wise TE then uses the chosen per-link delay.
- Outputs per EDF: delay_scan_results.txt; per-link TE heatmaps, p-value maps, winning-band map, .npz and .txt summaries.
- Customize bands via DEFAULT_BANDS at the top.

## GPU and Parallel Enhancements

- Optional GPU acceleration via CuPy/cuSignal for FFT-based surrogates (falls back to CPU automatically).
- CPU speed-ups via numba JIT for Gaussian TE and joblib parallelization over links and band-pairs.
- These overrides shadow earlier helpers; run this section after the main setup cells.


In [None]:
# Try GPU (CuPy/cuSignal) and enable parallel + JIT
import os, sys

# Start from "no GPU" and switch on only after CuPy imports AND a CUDA device
# is visible; otherwise later GPU code paths would reference a missing `cp`.
USE_GPU = False
try:
    import cupy as cp
    import cupyx.scipy.signal as cpsignal
    if cp.cuda.runtime.getDeviceCount() < 1:
        raise RuntimeError('no CUDA device visible')
    USE_GPU = True
    print('CuPy/cuSignal available: GPU acceleration ON')
except Exception:
    USE_GPU = False
    print('CuPy/cuSignal not available: using CPU')

# Numba is optional; NUMBA_ON records whether the JIT decorator is usable.
try:
    from numba import njit
    NUMBA_ON = True
except Exception:
    NUMBA_ON = False

from joblib import Parallel, delayed

# Pure-Python fallback for TE (mirrors earlier gaussian_te)
def _gaussian_te_py(x, y, delay_samples, k_past=1):
    d = int(delay_samples)
    if d <= 0:
        return np.nan
    T = min(len(x) - d - k_past, len(y) - k_past)
    if T <= 5:
        return np.nan
    idx_y = np.arange(k_past, k_past + T)
    idx_x = np.arange(k_past + d, k_past + d + T)
    Yt = y[idx_y]
    Ypast = np.column_stack([y[idx_y - i] for i in range(1, k_past + 1)])
    Xpast = np.column_stack([x[idx_x - i] for i in range(1, k_past + 1)])
    # LS residual variances
    beta1, *_ = np.linalg.lstsq(Ypast, Yt, rcond=None)
    r1 = Yt - Ypast @ beta1
    n1, p1 = Ypast.shape
    v1 = float(np.var(r1, ddof=min(p1, n1) - 1 if n1 > p1 else 1))
    X2 = np.hstack([Ypast, Xpast])
    beta2, *_ = np.linalg.lstsq(X2, Yt, rcond=None)
    r2 = Yt - X2 @ beta2
    n2, p2 = X2.shape
    v2 = float(np.var(r2, ddof=min(p2, n2) - 1 if n2 > p2 else 1))
    if v1 <= 0 or v2 <= 0:
        return np.nan
    return 0.5 * np.log(v1 / v2)

if NUMBA_ON:
    @njit(fastmath=True)
    def _ls_res_var_jit(X, y):
        # Residual variance of the least-squares fit y ~ X, solved through
        # ridge-regularized normal equations (Numba-friendly).
        XT = X.T
        XTX = XT @ X
        eps = 1e-8
        for i in range(XTX.shape[0]):
            XTX[i, i] = XTX[i, i] + eps
        XTy = XT @ y
        beta = np.linalg.solve(XTX, XTy)
        # residuals
        resid = y - X @ beta
        n = X.shape[0]
        p = X.shape[1]
        denom = n - p
        if denom < 1:
            denom = n
        rss = 0.0
        for i in range(resid.shape[0]):
            rss += resid[i] * resid[i]
        return rss / denom

    @njit(fastmath=True)
    def _gaussian_te_jit(x, y, delay_samples, k_past):
        # JIT kernel for Gaussian TE_{x->y}: y[t] is predicted from its own
        # past y[t-1..t-k] with and without the delayed source past
        # x[t-d..t-d-k+1].
        d = int(delay_samples)
        if d <= 0 or k_past <= 0:
            return np.nan
        # The source must lie in the PAST of the target, so the first usable
        # target index needs d + k - 1 earlier source samples.
        n = min(len(x), len(y))
        start = k_past + d - 1
        T = n - start
        if T <= 5:
            return np.nan
        Yt = np.empty(T, dtype=np.float64)
        Ypast = np.empty((T, k_past), dtype=np.float64)
        Xpast = np.empty((T, k_past), dtype=np.float64)
        for t in range(T):
            it = start + t
            Yt[t] = y[it]
            for i in range(k_past):
                Ypast[t, i] = y[it - (i + 1)]
                Xpast[t, i] = x[it - d - i]
        v1 = _ls_res_var_jit(Ypast, Yt)
        X2 = np.empty((T, 2*k_past), dtype=np.float64)
        for t in range(T):
            for i in range(k_past):
                X2[t, i] = Ypast[t, i]
                X2[t, k_past + i] = Xpast[t, i]
        v2 = _ls_res_var_jit(X2, Yt)
        if v1 <= 0 or v2 <= 0:
            return np.nan
        return 0.5 * np.log(v1 / v2)

    def gaussian_te(x, y, delay_samples, k_past=1):
        """Gaussian TE_{x->y}; shadows the earlier definition when Numba is available."""
        # Try JIT; fall back to Python on any failure (e.g., worker compile issues)
        try:
            return _gaussian_te_jit(np.asarray(x, np.float64), np.asarray(y, np.float64), int(delay_samples), int(k_past))
        except Exception:
            return _gaussian_te_py(np.asarray(x, np.float64), np.asarray(y, np.float64), int(delay_samples), int(k_past))

# GPU-aware FFT phase randomization (batched)
def phase_randomize_batch(X, rng, use_gpu=None):
    """Phase-randomize every row of a 2D array (one surrogate per row).

    Parameters:
      X: array of shape (n_series, n_samples).
      rng: numpy random Generator; the random phases are always drawn on the
        CPU from this generator, so CPU and GPU runs consume the same stream.
      use_gpu: True/False to request/forbid the CuPy path; None (default)
        follows the module-level USE_GPU flag at call time. The GPU path is
        taken only when CuPy was actually imported as `cp`; otherwise the
        function falls back to NumPy.

    Returns:
      NumPy array with the shape of X. Fourier amplitudes are preserved; the
      DC bin and (for even length) the Nyquist bin keep their phase.
    """
    X = np.asarray(X, np.float64)
    if X.size == 0:
        return np.empty_like(X)  # Return empty array if input is empty
    if use_gpu is None:
        use_gpu = bool(globals().get('USE_GPU', False))
    # Only use the GPU when the CuPy module is really available.
    gpu = globals().get('cp') if use_gpu else None
    xp = np if gpu is None else gpu

    _, T = X.shape
    Xf = xp.fft.rfft(xp.asarray(X), axis=1)
    R = rng.uniform(-np.pi, np.pi, size=Xf.shape)
    R[:, 0] = 0.0
    if T % 2 == 0:
        R[:, -1] = 0.0
    # Adding R to each phase is the same as rotating each coefficient by exp(iR).
    Yf = Xf * xp.exp(1j * xp.asarray(R))
    Y = xp.fft.irfft(Yf, n=T, axis=1)
    return Y if gpu is None else Y.get()

# Override compute_te_grid_with_surrogates with parallel + batched surrogates
def compute_te_grid_with_surrogates(src, tgt, fs, bands, delay_samples, k_past=1, n_surr=100, alpha=0.05, downsample=True, rng=None):
    """Compute a band-by-band Gaussian-TE grid with surrogate p-values.

    Both signals are band-passed once per band. For every (source band i,
    target band j) cell the two filtered signals are brought to the lower of
    their sampling rates, TE is evaluated at the rescaled delay, and two
    surrogate tests are run with ``max(1, n_surr // 2)`` phase-randomized
    surrogates each: one randomizing only the source, one only the target.
    Each p-value is the fraction of finite surrogate TE values that are
    ``>=`` the observed TE (so it can be exactly 0).

    Parameters
    ----------
    src, tgt : array-like
        Source and target signals, passed to ``bandpass``.
    fs : float
        Sampling rate of ``src``/``tgt``; ``delay_samples`` is expressed at
        this rate.
    bands : dict
        Maps band name to the band specification handed to ``bandpass``.
    delay_samples : int
        Source-to-target delay in samples at rate ``fs``.
    k_past : int
        Embedding length forwarded to ``gaussian_te``.
    n_surr : int
        Total surrogate budget; half is used for each of the two tests.
    alpha : float
        Level forwarded to ``fdr_bh``.
    downsample : bool
        Forwarded to ``bandpass`` as ``decimate``.
    rng : numpy.random.Generator or RandomState, optional
        Seeds the per-cell generators. Defaults to ``default_rng(123)``.

    Returns
    -------
    tuple
        ``(names, TE, pS, pT, sigS, sigT, critS, critT)`` where ``TE``,
        ``pS`` and ``pT`` are (B, B) arrays indexed ``[source band, target
        band]`` (NaN where TE was not finite) and the last four values are
        whatever ``fdr_bh`` returns for ``pS`` and ``pT`` (criteria cast to
        float).
    """
    if rng is None:
        rng = np.random.default_rng(123)
    names = list(bands.keys())
    B = len(names)
    TE = np.full((B, B), np.nan, float)
    pS = np.full((B, B), np.nan, float)
    pT = np.full((B, B), np.nan, float)

    # Pre-filter cache
    filt_src, fs_src = {}, {}
    filt_tgt, fs_tgt = {}, {}
    for b in names:
        filt_src[b], fs_src[b] = bandpass(src, fs, bands[b], order=4, decimate=downsample)
        filt_tgt[b], fs_tgt[b] = bandpass(tgt, fs, bands[b], order=4, decimate=downsample)

    # Draw one seed per cell here, in the calling thread. Sharing a single
    # generator between worker threads would make the order of the random
    # draws (and therefore the p-values) depend on thread scheduling.
    draw_ints = getattr(rng, 'integers', None) or rng.randint
    cell_seeds = np.asarray(draw_ints(0, 2**31 - 1, size=(B, B)))

    def surrogate_p(surr_vals, v_obs):
        # A NaN surrogate compares False against v_obs, so it must be removed
        # explicitly; otherwise failed surrogates would deflate the p-value.
        finite = np.isfinite(surr_vals)
        if not finite.any():
            return np.nan
        return float(np.mean(surr_vals[finite] >= v_obs))

    def compute_cell(i, j):
        xs = filt_src[names[i]]; fs_x = fs_src[names[i]]
        yt = filt_tgt[names[j]]; fs_y = fs_tgt[names[j]]
        fs_c = min(fs_x, fs_y)
        # NOTE(review): int() truncates non-integer sampling rates before resampling.
        if abs(fs_x - fs_c) > 1e-6:
            xs = sp.signal.resample_poly(xs, up=int(fs_c), down=int(fs_x))
        if abs(fs_y - fs_c) > 1e-6:
            yt = sp.signal.resample_poly(yt, up=int(fs_c), down=int(fs_y))
        # Rescale the delay from the original rate to the common rate.
        d_samp = max(1, int(round(delay_samples * (fs_c / fs))))
        v_obs = gaussian_te(xs, yt, d_samp, k_past=k_past)
        if not np.isfinite(v_obs):
            return (i, j, np.nan, np.nan, np.nan)
        cell_rng = np.random.default_rng(int(cell_seeds[i, j]))
        # Batched surrogates (source-only and target-only)
        batch = max(1, n_surr // 2)
        XS_s = phase_randomize_batch(np.tile(xs, (batch, 1)), cell_rng, use_gpu=USE_GPU)
        PS = np.array([gaussian_te(XS_s[k], yt, d_samp, k_past=k_past) for k in range(batch)], dtype=float)
        p_src = surrogate_p(PS, v_obs)
        YT_s = phase_randomize_batch(np.tile(yt, (batch, 1)), cell_rng, use_gpu=USE_GPU)
        PT = np.array([gaussian_te(xs, YT_s[k], d_samp, k_past=k_past) for k in range(batch)], dtype=float)
        p_tgt = surrogate_p(PT, v_obs)
        return (i, j, v_obs, p_src, p_tgt)

    # Fill grid in parallel (use threads to avoid Numba compile issues in subprocesses)
    results = Parallel(n_jobs=N_JOBS, prefer="threads")(delayed(compute_cell)(i, j) for i in range(B) for j in range(B))
    for (i, j, v, ps, pt) in results:
        TE[i, j] = v
        pS[i, j] = ps
        pT[i, j] = pt

    # FDR per grid
    sigS, critS = fdr_bh(pS, alpha=alpha)
    sigT, critT = fdr_bh(pT, alpha=alpha)
    return names, TE, pS, pT, sigS, sigT, float(critS), float(critT)

# Status banner emitted when this cell runs, after the override definitions above.
print('GPU/Parallel overrides active. Set N_JOBS at top to control CPU parallelism.')

## Main Script Launcher

In [None]:
# Main entry: simple launcher for interactive processing
import sys, traceback

# Show the menu, then normalise the reply (trimmed, lower-cased).
print("Transfer Entropy – Main Menu")
print("1) Run (you'll select single-file or folder next)\nQ) Quit")
choice = input("Choose an option [1/Q]: ").strip().lower()

if choice not in ("1", "", "r", "run"):
    # Anything other than a run request (including Q) leaves without doing work.
    print("Exit.")
else:
    # run_batch is looked up dynamically so this cell still runs when the
    # runner cell has not been executed yet.
    rb = globals().get("run_batch")
    if not callable(rb):
        print("run_batch is not defined. Please run the EDF/FIF runner cell above (the one that defines process_one_file and run_batch).")
    else:
        try:
            print(f"Using run_batch defined at: {rb.__code__.co_filename}, line {rb.__code__.co_firstlineno}")
        except Exception:
            # Purely informational; callables without __code__ are still run.
            pass
        rb()

## Summary Visualisations Launcher

In [None]:
# OPTIONAL: Build 5% and 1% (alpha=0.05 and 0.01) summary visualizations from existing per-link .npz files
# Modes: 1) Single folder, 2) Batch over subfolders, 3) Meta-analysis across all subfolders

def _read_link_npz(fpath):
    """Read one per-link .npz file; return ``(names, pS, pT)`` or None if unusable.

    The reason is printed whenever a file is skipped, and the archive handle
    is closed before returning.
    """
    try:
        # NOTE(security): allow_pickle=True unpickles object arrays, which can run
        # arbitrary code. Only point this at result folders you trust.
        arr = np.load(fpath, allow_pickle=True)
    except Exception as e:
        print('Skipping', fpath, '->', e)
        return None
    try:
        try:
            names_raw = arr['names']
        except Exception:
            print('Skipping', fpath, ': no names field')
            return None
        names = [n.decode('utf-8') if isinstance(n, (bytes, bytearray)) else str(n)
                 for n in np.asarray(names_raw).ravel()]
        try:
            pS = arr['p_source']
            pT = arr['p_target']
        except Exception:
            print('Skipping', fpath, ': missing p_source/p_target')
            return None
        return names, pS, pT
    finally:
        # np.load may hand back an object without close(); guard for that.
        closer = getattr(arr, 'close', None)
        if callable(closer):
            closer()


def _top_source_target_dists(total_src_counts, all_pS, all_pT, alpha, top_k=3):
    """Return the top source-band indices and their normalised target distributions.

    Sources are ranked by ``total_src_counts`` (descending). For each one, the
    number of links significant in both tests at ``alpha`` is counted per
    target band and normalised to sum to 1; a source with no significant
    cells gets a uniform distribution.
    """
    n_bands = len(total_src_counts)
    top_indices = np.argsort(total_src_counts)[-top_k:][::-1]
    dists = []
    for i in top_indices:
        tgt_dist = np.zeros(n_bands)
        for ps, pt in zip(all_pS, all_pT):
            tgt_dist += ((ps[i, :] <= alpha) & (pt[i, :] <= alpha)).astype(int)
        if tgt_dist.sum() == 0:
            tgt_dist = np.ones(n_bands)  # uniform if no data
        dists.append(tgt_dist / tgt_dist.sum())
    return top_indices, dists


def _draw_alpha_panels(pie_ax, bar_ax, data, alpha):
    """Draw the nested pie (``pie_ax``) and top-3 stacked bars (``bar_ax``) for one alpha."""
    from matplotlib.patches import Patch
    band_names = data['band_names']
    total_src_counts = data['total_src_counts']
    total_tgt_counts = data['total_tgt_counts']
    B = len(band_names)
    pie_ax.axis('equal')
    src = np.asarray(total_src_counts, dtype=float)
    tgt = np.asarray(total_tgt_counts, dtype=float)
    if src.sum() == 0 and tgt.sum() == 0:
        # Nothing to draw: leave a note on the pie axes and skip the bar panel.
        pie_ax.text(0.5, 0.5, 'No significant source-target pairs found', ha='center', va='center', transform=pie_ax.transAxes)
        return
    base_colors = plt.get_cmap(CMAP_BAND)(np.linspace(0, 1, B))
    light_colors = [tuple(0.65 * c[:3] + 0.35 * np.array([1.0, 1.0, 1.0])) + (1.0,) for c in base_colors]
    dark_colors = [tuple(0.92 * c[:3]) + (1.0,) for c in base_colors]
    src_sizes = src.copy()
    tgt_sizes = tgt.copy()
    if src_sizes.sum() == 0:
        src_sizes = np.ones(B)
    if tgt_sizes.sum() == 0:
        tgt_sizes = np.ones(B)
    # Two rings: inner (targets, dark), outer (sources, light). Percent labels
    # are placed by hand below, so no autopct is requested from pie().
    wedges_inner = pie_ax.pie(tgt_sizes, radius=0.9, colors=dark_colors, startangle=90, wedgeprops=dict(width=0.4, edgecolor='white'))[0]
    wedges_outer = pie_ax.pie(src_sizes, radius=1.3, colors=light_colors, startangle=90, wedgeprops=dict(width=0.4, edgecolor='white'))[0]
    # Inner ring: percentage text inside the ring.
    for wedge in wedges_inner:
        theta = np.deg2rad((wedge.theta1 + wedge.theta2) / 2)
        r = 0.45  # inside inner ring
        pct = (wedge.theta2 - wedge.theta1) / 360 * 100
        if pct > 1:  # only label if >1%
            pie_ax.text(r * np.cos(theta), r * np.sin(theta), f'{pct:.1f}%', ha='center', va='center', fontsize=8, color='black')
    # Outer ring: percentage text outside the pie, joined to the wedge by a thin line.
    for wedge in wedges_outer:
        theta = np.deg2rad((wedge.theta1 + wedge.theta2) / 2)
        r_mid = 1.1  # midpoint of outer ring
        r_label = 1.5  # outside
        pct = (wedge.theta2 - wedge.theta1) / 360 * 100
        if pct > 1:  # only label if >1%
            pie_ax.annotate(f'{pct:.1f}%',
                            xy=(r_mid * np.cos(theta), r_mid * np.sin(theta)),
                            xytext=(r_label * np.cos(theta), r_label * np.sin(theta)),
                            arrowprops=dict(arrowstyle='-', color='black', lw=0.5, shrinkA=0, shrinkB=0),
                            ha='center', va='center', fontsize=8)
    # Legend for pie
    legend_patches = [Patch(facecolor=dark_colors[i], edgecolor='black', label=band_names[i]) for i in range(B)]
    pie_ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 0.8), title='Bands', fontsize=8)
    pie_ax.set_title(f'Alpha={alpha*100:.0f}%: Inner=Targets (dark), Outer=Sources (light)')
    # Horizontal stacked bars for the top-3 sources' target distributions.
    top3_indices, dists = _top_source_target_dists(total_src_counts, data['all_pS'], data['all_pT'], alpha)
    y_positions = np.arange(len(top3_indices))
    for y_pos, tgt_dist in zip(y_positions, dists):
        left = 0
        for j in range(B):
            bar_ax.barh(y_pos, tgt_dist[j], left=left, color=dark_colors[j], edgecolor='white', height=0.8)
            left += tgt_dist[j]
    bar_ax.set_yticks(y_positions)
    bar_ax.set_yticklabels([band_names[i] for i in top3_indices])
    bar_ax.set_xlabel('Proportion of Target Bands')
    bar_ax.set_title(f'Top 3 Sources\' Target Distributions (Alpha={alpha*100:.0f}%)')
    bar_ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 0.8), title='Target Bands', fontsize=8)


def process_single_folder(out_dir):
    """Build the alpha=0.05 / alpha=0.01 summary outputs for one per-EDF output folder.

    Reads every ``link_*.npz`` in ``out_dir`` (falling back to
    ``*_results.npz``), each expected to hold ``names``, ``p_source`` and
    ``p_target``. A band-pair cell counts as significant when both p-values
    are ``<= alpha``. Written into ``out_dir`` (``<base>`` is the folder name):

    - ``<base>_enhanced_summary.png``: nested pie + top-3 stacked bars per alpha
    - ``<base>_enhanced_summary.txt``: the same percentages as tab-separated tables
    - ``<base>_alphaNN_band_counts.csv`` and ``<base>_alphaNN_summary.txt``

    Files that cannot be read, lack a required field, or whose grids do not
    match the band count of the first usable file are skipped with a message.
    Returns None; when the folder or usable inputs are missing it prints a
    message and writes nothing.
    """
    import glob, csv
    if not os.path.isdir(out_dir):
        print('Folder not found:', out_dir)
        return
    npz_files = sorted(glob.glob(os.path.join(out_dir, 'link_*.npz')))
    if not npz_files:
        npz_files = sorted(glob.glob(os.path.join(out_dir, '*_results.npz')))
    if not npz_files:
        print('No per-link .npz files found in', out_dir)
        return
    base = os.path.basename(out_dir.rstrip(os.sep))
    patient_name = base.split('_')[0]
    alphas = [0.05, 0.01]

    # Load every file once; the p-value grids are identical for both alphas.
    band_names = None
    B = 0
    all_pS, all_pT = [], []
    for fpath in npz_files:
        loaded = _read_link_npz(fpath)
        if loaded is None:
            continue
        names, pS, pT = loaded
        if band_names is None:
            # The first readable file fixes the band list for the whole folder.
            band_names = names
            B = len(band_names)
        if pS.shape != (B, B) or pT.shape != (B, B):
            print('Skipping', fpath, ': shape mismatch pS', getattr(pS, 'shape', None))
            continue
        all_pS.append(pS)
        all_pT.append(pT)

    summaries = {}
    for alpha in alphas:
        if band_names is None:
            print(f'No usable .npz files found for alpha={alpha}.')
            continue
        total_src_counts = np.zeros(B, dtype=int)
        total_tgt_counts = np.zeros(B, dtype=int)
        for pS, pT in zip(all_pS, all_pT):
            signif = (pS <= alpha) & (pT <= alpha)
            total_src_counts += np.sum(signif, axis=1)  # rows index the source band
            total_tgt_counts += np.sum(signif, axis=0)  # columns index the target band
        summaries[alpha] = {
            'band_names': band_names,
            'total_src_counts': total_src_counts,
            'total_tgt_counts': total_tgt_counts,
            'all_pS': all_pS,
            'all_pT': all_pT
        }
    if not summaries:
        print('No summaries could be built.')
        return

    # Figure with 2x2 subplots: top row pies, bottom row stacked bars for top3 sources
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        for idx, alpha in enumerate(alphas):
            if alpha in summaries:
                _draw_alpha_panels(axes[0, idx], axes[1, idx], summaries[alpha], alpha)
        fig.suptitle(f'Drivers of communication between NIZ and SOZ for {patient_name}', fontsize=14)
        fig.tight_layout()
        png_path = os.path.join(out_dir, f'{base}_enhanced_summary.png')
        fig.savefig(png_path, bbox_inches='tight')
    finally:
        # Release the figure even if plotting fails (batch mode keeps going).
        plt.close(fig)
    print('Saved enhanced summary figure:', png_path)

    # Save TXT with percentages table
    txt_summary_path = os.path.join(out_dir, f'{base}_enhanced_summary.txt')
    with open(txt_summary_path, 'w') as f:
        f.write(f'# Enhanced Summary for {patient_name}\n')
        f.write('# Percentages from Pie and Bar Charts\n\n')
        for alpha in alphas:
            if alpha not in summaries:
                continue
            data = summaries[alpha]
            names_a = data['band_names']
            src_counts = data['total_src_counts']
            tgt_counts = data['total_tgt_counts']
            f.write(f'# Alpha {alpha*100:.0f}%\n')
            f.write('# Pie Chart Percentages\n')
            # Header order matches the rows below: source % first, then target %.
            f.write('Band\tOuter (Sources %)\tInner (Targets %)\n')
            src_total = src_counts.sum()
            tgt_total = tgt_counts.sum()
            for nm, s, t in zip(names_a, src_counts, tgt_counts):
                src_pct = (s / src_total * 100) if src_total > 0 else 0
                tgt_pct = (t / tgt_total * 100) if tgt_total > 0 else 0
                f.write(f'{nm}\t{src_pct:.1f}\t{tgt_pct:.1f}\n')
            f.write('\n')
            f.write('# Bar Chart: Top 3 Sources\' Target Distributions\n')
            f.write('Source Band\tTarget Band\tProportion\n')
            top3_indices, dists = _top_source_target_dists(src_counts, data['all_pS'], data['all_pT'], alpha)
            for i, tgt_dist in zip(top3_indices, dists):
                for j, nm_tgt in enumerate(names_a):
                    f.write(f'{names_a[i]}\t{nm_tgt}\t{tgt_dist[j]:.3f}\n')
            f.write('\n')
    print('Saved enhanced summary TXT:', txt_summary_path)

    # Also save CSVs and TXTs for each alpha
    for alpha in alphas:
        if alpha not in summaries:
            continue
        data = summaries[alpha]
        rows = list(zip(data['band_names'], data['total_src_counts'], data['total_tgt_counts']))
        base_out_file = os.path.join(out_dir, f'{base}_alpha{int(alpha*100):02d}')
        csv_path = base_out_file + '_band_counts.csv'
        with open(csv_path, 'w', newline='') as cf:
            writer = csv.writer(cf)
            writer.writerow(['band', 'total_src_count', 'total_tgt_count'])
            for nm, s, t in rows:
                writer.writerow([nm, int(s), int(t)])
        txt_path = base_out_file + '_summary.txt'
        with open(txt_path, 'w') as fh:
            fh.write(f'# Aggregated significant band counts at alpha={alpha}\n')
            fh.write('# format: band,total_source_count,total_target_count\n')
            for nm, s, t in rows:
                fh.write(f"{nm},{s},{t}\n")
        print(f'Saved alpha={alpha*100:.0f}% files: CSV={csv_path}, TXT={txt_path}')


# Pie-table header found in older enhanced-summary files. It labels the columns
# target-then-source although the rows beneath it hold source-then-target values.
_LEGACY_PIE_HEADER = ('Band', 'Inner (Targets %)', 'Outer (Sources %)')


def _pie_column_order(header_cols):
    """Return ``(source_index, target_index)`` for pie rows under the given header cells."""
    if tuple(header_cols) == _LEGACY_PIE_HEADER:
        return 1, 2  # mislabelled legacy header: data are source, then target
    if len(header_cols) == 3:
        col1 = header_cols[1].lower()
        col2 = header_cols[2].lower()
        if 'target' in col1 and 'source' in col2:
            return 2, 1
    # Explicit source/target header, unknown header or no header: source first.
    return 1, 2


def _parse_enhanced_summary(lines, alpha=0.01):
    """Extract the pie and bar rows of one alpha section from an enhanced-summary TXT.

    Returns ``(pie_rows, bar_rows)`` with ``pie_rows`` as
    ``(band, source_pct, target_pct)`` tuples and ``bar_rows`` as
    ``(source_band, target_band, proportion)`` tuples. Rows whose numbers
    cannot be parsed are dropped.
    """
    pie_rows, bar_rows = [], []
    current_alpha = None
    section = None
    src_col, tgt_col = 1, 2
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith('# Alpha'):
            try:
                current_alpha = float(line.split()[2].rstrip('%')) / 100
            except (IndexError, ValueError):
                current_alpha = None
            section = None
            src_col, tgt_col = 1, 2
            continue
        if line.startswith('# Pie Chart Percentages'):
            section = 'pie'
            continue
        if line.startswith('Band\t') and section == 'pie':
            src_col, tgt_col = _pie_column_order(line.split('\t'))
            continue
        if line.startswith('# Bar Chart'):
            section = 'bar'
            continue
        if line.startswith('Source Band\t') and section == 'bar':
            continue
        if '\t' not in line or current_alpha is None or abs(current_alpha - alpha) > 1e-12:
            continue
        parts = line.split('\t')
        if len(parts) != 3:
            continue
        if section == 'pie':
            try:
                src_pct = float(parts[src_col].rstrip('%'))
                tgt_pct = float(parts[tgt_col].rstrip('%'))
            except ValueError:
                continue
            pie_rows.append((parts[0], src_pct, tgt_pct))
        elif section == 'bar':
            try:
                prop = float(parts[2])
            except ValueError:
                continue
            bar_rows.append((parts[0], parts[1], prop))
    return pie_rows, bar_rows


def meta_analyze(parent_dir):
    """Aggregate the alpha=1% enhanced summaries of all subfolders into two figures.

    Every subfolder ``<sub>`` of ``parent_dir`` containing
    ``<sub>_enhanced_summary.txt`` contributes its alpha=1% pie percentages
    and top-source target proportions. Per-band medians are taken across
    subfolders, source percentages are compared across bands with a
    Kruskal-Wallis test, and two PNGs are written into ``parent_dir``:
    ``meta_analysis_traditional.png`` (nested pie + stacked bars for the three
    sources with the highest median percentage) and
    ``meta_analysis_summary_table.png``.

    Returns None; prints a message and writes nothing when no summary file or
    no alpha=1% data is found.
    """
    txt_files = []
    # Sorted so the band order (taken from the first file read) is deterministic.
    for sub in sorted(os.listdir(parent_dir)):
        sub_path = os.path.join(parent_dir, sub)
        if os.path.isdir(sub_path):
            txt_file = os.path.join(sub_path, f'{sub}_enhanced_summary.txt')
            if os.path.isfile(txt_file):
                txt_files.append((sub, txt_file))
    if not txt_files:
        print('No enhanced summary TXT files found.')
        return
    # Aggregate data
    all_src_pcts = {}
    all_tgt_pcts = {}
    all_src_dists = {}
    band_names = []
    for _patient, txt_file in txt_files:
        with open(txt_file, 'r') as f:
            lines = f.readlines()
        pie_rows, bar_rows = _parse_enhanced_summary(lines, alpha=0.01)
        for band, src_pct, tgt_pct in pie_rows:
            if band not in band_names:
                band_names.append(band)
            all_src_pcts.setdefault(band, []).append(src_pct)
            all_tgt_pcts.setdefault(band, []).append(tgt_pct)
        for src, tgt, prop in bar_rows:
            all_src_dists.setdefault(src, {}).setdefault(tgt, []).append(prop)
    if not band_names:
        print('No data to aggregate.')
        return
    src_medians = {b: float(np.median(all_src_pcts[b])) for b in band_names}
    tgt_medians = {b: float(np.median(all_tgt_pcts[b])) for b in band_names}
    # Top 3 sources by median source percentage
    top3_src = sorted(src_medians, key=src_medians.get, reverse=True)[:3]
    # Median target proportions and the leading target for each top source
    median_props = {}
    top_targets = {}
    for src in top3_src:
        if src in all_src_dists:
            tgt_props = {tgt: float(np.median(all_src_dists[src].get(tgt, [0.0]))) for tgt in band_names}
            median_props[src] = tgt_props
            top_targets[src] = max(tgt_props, key=tgt_props.get)
    # Statistical test: Kruskal-Wallis for differences in src pcts across all bands
    from scipy.stats import kruskal
    groups = [all_src_pcts[b] for b in band_names if all_src_pcts.get(b)]
    p = 1.0
    if len(groups) > 1:
        try:
            stat, p = kruskal(*groups)
            test_name = f'Kruskal-Wallis H={stat:.2f}, p={p:.3f}'
        except Exception:
            # e.g. kruskal rejects input where all values are identical
            test_name = 'Kruskal-Wallis failed'
            p = 1.0
    else:
        test_name = 'No test (insufficient data)'
    # Traditional visualizations: nested pie and horizontal stacked bars (horizontal layout)
    from matplotlib.patches import Patch
    B = len(band_names)
    fig, (pie_ax, bar_ax) = plt.subplots(1, 2, figsize=(16, 8))
    try:
        # Nested pie chart (left)
        pie_ax.axis('equal')
        src_vals = np.array([src_medians[b] for b in band_names])
        tgt_vals = np.array([tgt_medians[b] for b in band_names])
        if src_vals.sum() == 0:
            src_vals = np.ones(B)
        if tgt_vals.sum() == 0:
            tgt_vals = np.ones(B)
        base_colors = plt.get_cmap('viridis')(np.linspace(0, 1, B))
        light_colors = [tuple(0.6 * c[:3] + 0.4 * np.array([1.0, 1.0, 1.0])) + (1.0,) for c in base_colors]
        dark_colors = [tuple(0.8 * c[:3]) + (1.0,) for c in base_colors]
        # Inner pie (targets, dark)
        pie_ax.pie(tgt_vals, radius=0.7, colors=dark_colors, startangle=90, wedgeprops=dict(width=0.3, edgecolor='black', linewidth=1), autopct=lambda pct: f'{pct:.1f}%' if pct > 1 else '', pctdistance=0.8, textprops={'fontsize': 10, 'color': 'white'})
        # Outer pie (sources, light)
        pie_ax.pie(src_vals, radius=1.0, colors=light_colors, startangle=90, wedgeprops=dict(width=0.3, edgecolor='black', linewidth=1), autopct=lambda pct: f'{pct:.1f}%' if pct > 1 else '', pctdistance=0.85, textprops={'fontsize': 10, 'color': 'black'})
        # Legend
        legend_patches = [Patch(facecolor=dark_colors[i], edgecolor='black', label=f'{band_names[i]} (Target)') for i in range(B)] + [Patch(facecolor=light_colors[i], edgecolor='black', label=f'{band_names[i]} (Source)') for i in range(B)]
        pie_ax.legend(handles=legend_patches, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9, title='Bands')
        pie_ax.set_title(f'Median Source & Target % (Nested Pie)\np = {p:.3f}', fontsize=14, fontweight='bold')
        # Horizontal stacked bars (right)
        y_positions = np.arange(len(top3_src))
        bar_ax.set_yticks(y_positions)
        bar_ax.set_yticklabels(top3_src)
        bar_ax.set_xlabel('Proportion of Target Bands', fontsize=12)
        bar_ax.set_title('Top 3 Sources\' Target Distributions', fontsize=14, fontweight='bold')
        for row, src in enumerate(top3_src):
            if src not in median_props:
                continue
            props = [median_props[src].get(b, 0.0) for b in band_names]
            if sum(props) == 0:
                props = [1.0 / B] * B
            top_tgt = top_targets.get(src)
            star_x = None
            left = 0
            for j, prop in enumerate(props):
                bar_ax.barh(row, prop, left=left, color=dark_colors[j], edgecolor='black', height=0.6)
                if band_names[j] == top_tgt:
                    star_x = left + prop / 2
                left += prop
            # Mark the leading target on its own segment of this bar (not at a
            # position left over from whichever bar was drawn last).
            if star_x is not None:
                bar_ax.text(star_x, row, '*', ha='center', va='center', fontsize=16, color='red', fontweight='bold')
        bar_ax.legend(handles=[Patch(facecolor=dark_colors[j], edgecolor='black', label=band_names[j]) for j in range(B)] + [plt.Line2D([0], [0], marker='*', color='red', markersize=10, linestyle='None', label='Top Target (p < 0.01)')], bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9, title='Target Bands')
        # Overall title and test info
        fig.suptitle(f'Meta-Analysis: Transfer Entropy Across Patients\n{test_name}', fontsize=16, fontweight='bold')
        fig.tight_layout()
        meta_png = os.path.join(parent_dir, 'meta_analysis_traditional.png')
        fig.savefig(meta_png, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
    print('Saved meta-analysis traditional figure:', meta_png)
    # Second figure: Summary table
    # NOTE(review): the two "Links" rows count per-band percentage entries across
    # summary files (and those above zero), not individual channel links - confirm labels.
    total_links = sum(len(all_src_pcts.get(b, [])) for b in band_names)
    signif_links = sum(1 for b in band_names for pct in all_src_pcts.get(b, []) if pct > 0)
    table_data = [
        ['Metric', 'Value'],
        ['Total Patients', str(len(txt_files))],
        ['Total Links', str(total_links)],
        ['Significant Links', str(signif_links)],
        ['Top Sources', ', '.join(top3_src)],
        ['Top Targets', ', '.join([top_targets.get(s, 'N/A') for s in top3_src])],
        ['Kruskal-Wallis p-value', f'{p:.3f}'],
        ['Significant? (p < 0.01)', 'Yes' if p < 0.01 else 'No']
    ]
    fig2, ax2 = plt.subplots(figsize=(12, 6))
    try:
        ax2.axis('off')
        table = ax2.table(cellText=table_data, colLabels=None, cellLoc='center', loc='center', bbox=[0, 0, 1, 1])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)
        ax2.set_title('Meta-Analysis Summary Table', fontsize=14, fontweight='bold')
        meta_table_png = os.path.join(parent_dir, 'meta_analysis_summary_table.png')
        fig2.savefig(meta_table_png, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig2)
    print('Saved meta-analysis summary table:', meta_table_png)


In [None]:
# Visible Summary Visualisations Launcher (compact)
# Asks for a mode, then a folder, and dispatches to process_single_folder
# (one folder, or every subfolder in batch mode) or to meta_analyze.
print('Summary Visualisations Launcher')
print('Choose mode: (1) Single folder, (2) Batch subfolders, (3) Meta-analysis, Q to quit')
mode = input('Mode [1/2/3/Q]: ').strip().lower()
if mode in ('q', 'quit'):
    print('Cancelled.')
else:
    try:
        if mode in ('1', 'single'):
            out_dir = input('Enter per-EDF output folder (full path): ').strip()
            if not out_dir:
                print('No folder provided.')
            else:
                process_single_folder(out_dir)
        elif mode in ('2', 'batch'):
            parent_dir = input('Enter parent directory containing output folders: ').strip()
            if not parent_dir or not os.path.isdir(parent_dir):
                print('Invalid directory.')
            else:
                for sub in sorted(os.listdir(parent_dir)):
                    sub_path = os.path.join(parent_dir, sub)
                    if os.path.isdir(sub_path):
                        print(f'Processing {sub}...')
                        # One failing subfolder must not stop the rest of the batch.
                        try:
                            process_single_folder(sub_path)
                        except Exception as e:
                            print(f'Error processing {sub}: {e}')
        elif mode in ('3', 'meta'):
            parent_dir = input('Enter parent directory for meta-analysis: ').strip()
            if not parent_dir or not os.path.isdir(parent_dir):
                print('Invalid directory.')
            else:
                meta_analyze(parent_dir)
        else:
            print('Invalid mode.')
    except NameError as ne:
        print('Required function not found in the notebook. Make sure you ran the cell that defines process_single_folder and meta_analyze.')
        # The missing name may be a global used inside those functions rather
        # than the functions themselves, so show which name was undefined.
        print('Details:', ne)
    except Exception as e:
        import traceback
        traceback.print_exc()
        print('Launcher failed with:', e)
