In [None]:
from typing import Dict, List, Optional, Union

import mne
import numpy as np


def select_epochs_by_annotations(
    epochs: mne.Epochs,
    descriptors: List[str],
    reject_log: Optional[np.ndarray] = None,  # (n_epochs, n_channels) with {0,1,2} or NaN
    *,
    min_overlap: float = 0.5,
    fallback_duration_s: float = 1.0,
    allow_interpolated: bool = True,
    max_interp_frac: float = 0.3,
    max_bad_chns: int = 3,
    min_run: int = 1,
    case_sensitive: bool = False,
    return_indices: bool = False,
) -> Dict[str, Union[np.ndarray, mne.Epochs]]:
    """
    Select epochs per descriptor by annotation overlap, after a reject_log quality filter.

    1. Quality filter: ``reject_log`` holds one row per epoch with values
       0 (good), 1 (bad) or 2 (interpolated); NaN is treated as good (0). An epoch
       is usable when it has at most ``max_bad_chns`` bad channels and, if
       ``allow_interpolated``, an interpolated-channel fraction of at most
       ``max_interp_frac`` (otherwise no interpolated channels at all).
    2. Overlap test: an annotation belongs to descriptor ``d`` when ``d`` is a
       substring of its description (case-insensitive unless ``case_sensitive``).
       Matching intervals are merged, and a usable epoch is selected for ``d`` when
       it overlaps one merged interval by at least ``min_overlap`` of its length
       (and always by at least one sample).
    3. Optionally keep only runs of >= ``min_run`` consecutive epoch indices.

    Parameters
    ----------
    epochs : mne.Epochs
    descriptors : list of str
        Substrings searched in annotation descriptions; also the result keys.
    reject_log : array-like, shape (n_epochs, n_channels), optional
    min_overlap : float
        Required overlap as a fraction of the epoch length.
    fallback_duration_s : float
        Duration (s) used for annotations whose duration is <= 0 or not finite.
    allow_interpolated, max_interp_frac, max_bad_chns : quality-filter settings.
    min_run : int
        Minimum length of a run of consecutive selected epoch indices.
    case_sensitive : bool
    return_indices : bool
        If True, return integer index arrays instead of Epochs subsets.

    Returns
    -------
    dict
        ``descriptor -> np.ndarray[int]`` if ``return_indices`` else
        ``descriptor -> mne.Epochs``.

    Raises
    ------
    ValueError
        If ``reject_log`` is not 2-D with one row per epoch.

    Notes
    -----
    Annotation onsets are converted with ``round(onset * sfreq)`` and compared with
    ``epochs.events[:, 0]``. NOTE(review): this assumes both share the same time
    origin (e.g. onsets that include ``first_samp``); confirm for your recordings.
    """
    n_epochs = len(epochs)
    if n_epochs == 0:
        empty = np.array([], dtype=int) if return_indices else epochs[:0]
        return {d: empty for d in descriptors}

    sf = float(epochs.info["sfreq"])

    # ---------- 1) Usable-epoch mask from reject_log (NaN -> good) ----------
    usable = np.ones(n_epochs, dtype=bool)
    if reject_log is not None:
        reject_log = np.asarray(reject_log)  # also accept nested lists
        if reject_log.ndim != 2 or reject_log.shape[0] != n_epochs:
            raise ValueError("reject_log must be (n_epochs, n_channels) with values {0,1,2} or NaN.")
        # Treat NaN as good (0)
        rlog = np.nan_to_num(reject_log.astype(float), nan=0.0)
        bad_counts = (rlog == 1.0).sum(axis=1)
        interp_counts = (rlog == 2.0).sum(axis=1)
        usable &= bad_counts <= int(max_bad_chns)
        if allow_interpolated:
            # max(..., 1) avoids 0/0 -> NaN when reject_log has no channel columns.
            usable &= interp_counts / max(rlog.shape[1], 1) <= float(max_interp_frac)
        else:
            usable &= interp_counts == 0

    # ---------- 2) Build & merge annotation intervals per descriptor ----------
    queries = [d if case_sensitive else str(d).lower() for d in descriptors]
    buckets = {d: [] for d in descriptors}

    def _merge(iv):
        """Merge overlapping or touching (start, end) sample intervals."""
        if not iv:
            return []
        iv = sorted(iv)
        out = [list(iv[0])]
        for s, e in iv[1:]:
            if s <= out[-1][1]:
                out[-1][1] = max(out[-1][1], e)
            else:
                out.append([s, e])
        return [(s, e) for s, e in out]

    if epochs.annotations is not None and len(epochs.annotations) > 0:
        for ann in epochs.annotations:
            desc = ann["description"] or ""
            if not case_sensitive:
                desc = desc.lower()
            onset_s = float(ann["onset"])
            dur_s = float(ann["duration"])
            # Point events (0 s), negative or NaN durations get a fallback length;
            # a NaN would otherwise crash int(round(...)) below.
            if not np.isfinite(dur_s) or dur_s <= 0:
                dur_s = float(fallback_duration_s)
            s = int(round(onset_s * sf))
            e = max(int(round((onset_s + dur_s) * sf)), s + 1)  # at least one sample
            for q, key in zip(queries, descriptors):
                if q in desc:
                    buckets[key].append((s, e))

    for k in buckets:
        buckets[k] = _merge(buckets[k])

    # ---------- 3) Epoch bounds (absolute samples, end exclusive) ----------
    ev = epochs.events[:, 0].astype(int)
    ep_s = ev + int(round(epochs.tmin * sf))
    # An epoch holds len(epochs.times) samples from tmin to tmax *inclusive*, so the
    # exclusive end is start + n_times (event + tmax * sfreq is the last sample).
    ep_e = ep_s + len(epochs.times)
    ep_len = (ep_e - ep_s).astype(int)
    need = np.maximum(0, (min_overlap * ep_len).astype(int))

    # ---------- 4) Overlap test ----------
    out_idxs = {k: [] for k in descriptors}
    for i in np.flatnonzero(usable):
        i = int(i)
        s0, e0 = int(ep_s[i]), int(ep_e[i])
        if e0 <= s0:
            continue
        # Require at least one overlapping sample so that min_overlap=0 does not
        # select epochs that merely touch an annotation interval.
        req = max(int(need[i]), 1)
        for name, ivs in buckets.items():
            if any((min(e0, e1) - max(s0, s1)) >= req for (s1, e1) in ivs):
                out_idxs[name].append(i)

    # ---------- 5) Consecutive run filter ----------
    def _keep_runs(idxs: List[int], kmin: int) -> List[int]:
        """Keep only indices belonging to runs of >= kmin consecutive integers."""
        if kmin <= 1 or not idxs:
            return idxs
        idxs = np.array(sorted(set(idxs)), dtype=int)
        cuts = np.where(np.diff(idxs) != 1)[0] + 1
        starts = np.r_[0, cuts]
        ends = np.r_[cuts, len(idxs)]
        keep = []
        for a, b in zip(starts, ends):
            run = idxs[a:b]
            if len(run) >= kmin:
                keep.extend(run.tolist())
        return keep

    if min_run > 1:
        for k in out_idxs:
            out_idxs[k] = _keep_runs(out_idxs[k], int(min_run))

    # ---------- 6) Return ----------
    if return_indices:
        return {k: np.asarray(v, dtype=int) for k, v in out_idxs.items()}
    # Index with an explicit integer array so that an empty selection is unambiguous.
    return {k: epochs[np.asarray(v, dtype=int)] for k, v in out_idxs.items()}

In [None]:
from pathlib import Path

import mne
import numpy as np

from mushroom_hyperscanning.data import load_eeg

SUBJECT = "01"
CEREMONY = "ceremony1"
BIDS_ROOT = "../data/004_autoreject"

raw = load_eeg(SUBJECT, CEREMONY, root=BIDS_ROOT, preload=True)
# raw.crop(tmax=60 * 20)  # use 20 minutes of data to test

# Band-pass filter the EEG channels, 1–60 Hz, zero-phase.
# NOTE(review): the epochs below are re-read from disk with mne.read_epochs, so this
# filtering does NOT reach them; in this cell `raw` only serves to locate the
# epochs / reject-log files. Filter the epochs themselves if that is intended.
raw.filter(l_freq=1, h_freq=60, picks="eeg", phase="zero")

# Notch the 60 Hz mains line (Québec/Canada); only the fundamental is removed.
raw.notch_filter(freqs=[60], picks="eeg", phase="zero")

# Derived files sit next to the raw file, named by replacing "eeg" in its file name
# ("..._eeg.fif" -> "..._epochs.fif" / "..._rejectlog.npy").
p = Path(raw.filenames[0]).resolve()
epochs_path = p.with_name(p.name.replace("eeg", "epochs"))
epochs = mne.read_epochs(epochs_path, preload=True)

rejectlog_path = p.with_name(p.name.replace("eeg.fif", "rejectlog.npy"))
rejectlog = np.load(rejectlog_path)

In [None]:
# Condition labels, matched as substrings of the annotation descriptions.
var1 = "eyes closed"
var2 = "eyes open"
desc_list = [var1, var2]

selected = select_epochs_by_annotations(epochs, desc_list, reject_log=rejectlog)

# One Epochs subset per label, in desc_list order.
epochs_1, epochs_2 = (selected[label] for label in desc_list)
selected

In [None]:
import matplotlib.pyplot as plt
import mne
import numpy as np

# ---------------------------------------------------------------------
# Inputs: epochs_1 and epochs_2 must already exist. As selected above,
# epochs_1 holds "eyes closed" and epochs_2 "eyes open"; epochs_1 is the
# baseline ("cond1") of the contrast computed below.
# ---------------------------------------------------------------------

# Config
FMIN_ALL, FMAX_ALL = 1.0, 45.0  # Hz: PSD range, also the "total power" denominator
N_JOBS = 1  # parallel jobs for compute_psd; set >1 if you want
BANDS = {
    "delta (1–4 Hz)": (1.0, 4.0),
    "theta (4–8 Hz)": (4.0, 8.0),
    "alpha (8–13 Hz)": (8.0, 13.0),
    "beta (13–30 Hz)": (13.0, 30.0),
    "gamma (30–45 Hz)": (30.0, 45.0),
}

# Attach standard 10-20 electrode positions (set_montage replaces any positions
# already present). plot_topomap below needs positions, so a failure is reported
# as a warning instead of being silently swallowed.
import warnings

montage = mne.channels.make_standard_montage("standard_1020")
for ep in (epochs_1, epochs_2):
    try:
        ep.set_montage(montage, match_case=False)
    except Exception as err:  # best effort: keep whatever positions already exist
        warnings.warn(f"Could not apply the standard_1020 montage: {err}")

# ---------------------------------------------------------------------
# Make both Epochs have the same EEG channels in the same order
# ---------------------------------------------------------------------
ep1_eeg = epochs_1.copy().pick(picks="eeg")
ep2_eeg = epochs_2.copy().pick(picks="eeg")

# Keep epochs_1's channel order. Build the lookup set once; rebuilding it inside
# the comprehension made this quadratic in the number of channels.
ep2_names = set(ep2_eeg.ch_names)
common_eeg = [ch for ch in ep1_eeg.ch_names if ch in ep2_names]
if not common_eeg:
    raise ValueError("No common EEG channels between the two Epochs.")
epochs_1 = ep1_eeg.copy().pick(picks=common_eeg)
epochs_2 = ep2_eeg.copy().pick(picks=common_eeg)

conds = {"cond1": epochs_1, "cond2": epochs_2}
cond_order = list(conds.keys())  # contrast will be cond_order[1] - cond_order[0] (relative %)

# ---------------------------------------------------------------------
# Welch PSD per condition over FMIN_ALL–FMAX_ALL, converted to µV^2/Hz
# ---------------------------------------------------------------------
psd_data = {}  # cond -> (psd (n_epochs, n_ch, n_freqs) in µV^2/Hz, freqs (n_freqs,), info)
for cond_name, cond_epochs in conds.items():
    spectrum = cond_epochs.compute_psd(
        method="welch",
        fmin=FMIN_ALL,
        fmax=FMAX_ALL,
        n_overlap=0,
        n_jobs=N_JOBS,
    )
    psd_v2, freqs = spectrum.get_data(return_freqs=True)
    # V^2/Hz * (1e6 µV/V)^2 -> µV^2/Hz
    psd_data[cond_name] = (psd_v2 * 1e12, freqs, cond_epochs.info)

# ---------------------------------------------------------------------
# Relative band power (band / total FMIN_ALL–FMAX_ALL) in percent
# ---------------------------------------------------------------------
# np.trapz is deprecated since NumPy 2.0 (renamed np.trapezoid); use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Total power per condition does not depend on the band, so integrate it once.
total_power_by_cond = {
    name: _trapezoid(P, f, axis=2).mean(axis=0) + np.finfo(float).eps  # µV^2 per channel
    for name, (P, f, _info) in psd_data.items()
}

rel_maps = {b: {} for b in BANDS}  # band -> cond -> (n_ch,)
rel_vlims = {}  # band -> (vmin, vmax), shared by all conditions of that band

for band_name, (fmin, fmax) in BANDS.items():
    all_vals = []
    for name, (P, f, info) in psd_data.items():
        fi = np.where((f >= fmin) & (f <= fmax))[0]
        if fi.size < 2:
            raise RuntimeError(f"Insufficient frequency bins for {band_name}")

        band_power = _trapezoid(P[:, :, fi], f[fi], axis=2).mean(axis=0)  # µV^2 per channel
        rel_pct = 100.0 * (band_power / total_power_by_cond[name])  # %
        rel_maps[band_name][name] = rel_pct
        all_vals.append(rel_pct)

    # Robust color limits shared across conditions (5th–95th percentile).
    all_vals = np.concatenate(all_vals)
    vmin = float(np.percentile(all_vals, 5))
    vmax = float(np.percentile(all_vals, 95))
    if vmin == vmax:
        vmax = vmin + 1e-6
    rel_vlims[band_name] = (vmin, vmax)

# ---------------------------------------------------------------------
# Contrast (relative change): (cond2 - cond1) / cond1 * 100 (%)
# ---------------------------------------------------------------------
contrast_maps = {}  # band -> (n_ch,)
for band_name in BANDS:
    per_cond = rel_maps[band_name]
    if len(cond_order) < 2:
        # Nothing to contrast against: zeros shaped like the single map.
        contrast_maps[band_name] = np.zeros_like(next(iter(per_cond.values())))
        continue
    reference = per_cond[cond_order[0]]
    delta = per_cond[cond_order[1]] - reference
    contrast_maps[band_name] = 100.0 * (delta / (reference + np.finfo(float).eps))

# Robust symmetric color limits per band: +/- max(|5th pct|, |95th pct|)
contrast_vlims = {}
for band_name, vals in contrast_maps.items():
    lo, hi = np.percentile(vals, [5, 95])
    half_range = float(max(abs(lo), abs(hi)))
    contrast_vlims[band_name] = (-half_range, half_range) if half_range > 0 else (-1e-6, 1e-6)

# ---------------------------------------------------------------------
# Plot: one row per condition (relative %, color limits shared per band),
# plus a final contrast row (relative % change of cond_order[1] vs
# cond_order[0]) when there are at least two conditions.
# ---------------------------------------------------------------------
has_contrast = len(cond_order) >= 2
n_rows = len(cond_order) + int(has_contrast)
fig, axes = plt.subplots(n_rows, len(BANDS), figsize=(4 * len(BANDS), 3.6 * n_rows), squeeze=False)

# Condition rows (relative %)
for r, cname in enumerate(cond_order):
    info = psd_data[cname][2]
    for c, band_name in enumerate(BANDS):
        vmin, vmax = rel_vlims[band_name]
        ax = axes[r, c]
        im, _ = mne.viz.plot_topomap(
            rel_maps[band_name][cname], info, axes=ax, show=False, contours=0, cmap="viridis", vlim=(vmin, vmax)
        )
        if r == 0:
            ax.set_title(band_name, fontsize=11)
        ax.set_xlabel(f"{cname} (rel. %)", fontsize=9)
        cb = fig.colorbar(im, ax=ax, shrink=0.7)
        cb.set_label("Relative power (%)", rotation=90, fontsize=9)

# Contrast row (relative % change), placed right after the condition rows
if has_contrast:
    info = psd_data[cond_order[0]][2]  # same channel layout/order for all conditions
    r = len(cond_order)
    for c, band_name in enumerate(BANDS):
        vmin, vmax = contrast_vlims[band_name]
        ax = axes[r, c]
        im, _ = mne.viz.plot_topomap(
            contrast_maps[band_name], info, axes=ax, show=False, contours=0, cmap="RdBu_r", vlim=(vmin, vmax)
        )
        ax.set_xlabel(f"{cond_order[1]} vs {cond_order[0]} (rel. %)", fontsize=9)
        cb = fig.colorbar(im, ax=ax, shrink=0.7)
        cb.set_label("Δ relative power (%)", rotation=90, fontsize=9)

fig.tight_layout()
plt.show()

## Connectivity

In [None]:
import mne
from mne_bids import BIDSPath

BIDS_ROOT = "../data/004_autoreject_15min"  # NOTE: not the root used in the PSD section
CEREMONY = "ceremony1"


def get_epochs(subject, ceremony, condition, root):
    """Load a subject's cleaned epochs and keep those annotated with ``condition``.

    Finds the single ``*eeg.fif`` raw file of task ``psilo`` for ``subject`` /
    ``ceremony`` under ``root``, then reads the sibling epochs file (``eeg`` ->
    ``epochs`` in the name) and reject log (``eeg.fif`` -> ``rejectlog.npy``).

    Parameters
    ----------
    subject : str
        BIDS subject label, e.g. "01".
    ceremony : str
        BIDS session label.
    condition : str
        Descriptor passed to ``select_epochs_by_annotations``.
    root : str | pathlib.Path
        BIDS root directory.

    Returns
    -------
    mne.Epochs
        The epochs selected for ``condition``.

    Raises
    ------
    AssertionError
        If not exactly one ``*eeg.fif`` file matches.
    """
    matches = BIDSPath(subject=subject, session=ceremony, task="psilo", datatype="eeg", root=root).match()
    # Only the raw recording: the sibling names below are derived from its
    # "...eeg.fif" name, so any other match (e.g. a sidecar) would yield wrong paths.
    paths = [bp.fpath for bp in matches if bp.fpath.name.endswith("eeg.fif")]
    assert len(paths) == 1, f"Didn't get one *eeg.fif path for sub-{subject} ses-{ceremony}, got {len(paths)}"

    p = paths[0].resolve()
    epochs = mne.read_epochs(p.with_name(p.name.replace("eeg", "epochs")), preload=True)
    rejectlog = np.load(p.with_name(p.name.replace("eeg.fif", "rejectlog.npy")))

    selected = select_epochs_by_annotations(
        epochs=epochs,
        descriptors=[condition],
        reject_log=rejectlog,
    )
    return selected[condition]


# Dyad member 1: subject "01"
epo1_cond1 = get_epochs("01", CEREMONY, "eyes open", BIDS_ROOT)
epo1_cond2 = get_epochs("01", CEREMONY, "eyes closed", BIDS_ROOT)

# Dyad member 2: subject "03"
epo2_cond1 = get_epochs("03", CEREMONY, "eyes open", BIDS_ROOT)
epo2_cond2 = get_epochs("03", CEREMONY, "eyes closed", BIDS_ROOT)

# NOTE: here cond1 = "eyes open" and cond2 = "eyes closed" (the reverse of
# epochs_1 / epochs_2 in the PSD section).
epo1_cond1, epo1_cond2, epo2_cond1, epo2_cond2

In [None]:
def match_epochs_by_index(epo1_cond1, epo1_cond2, epo2_cond1, epo2_cond2):
    """Return the event codes present in all four Epochs objects.

    Despite the name, this compares event *codes* (``events[:, 2]``), not epoch
    positions or sample times. When every epoch carries the same code (typical
    for fixed-length epochs), the result contains that single code and the
    subsequent filtering keeps every epoch, i.e. it is effectively a no-op.

    Returns
    -------
    list of int
        Sorted event codes shared by every input; empty if none are shared.
    """
    code_sets = [set(ep.events[:, 2].tolist()) for ep in (epo1_cond1, epo1_cond2, epo2_cond1, epo2_cond2)]
    # Sorted for a deterministic order (set iteration order is arbitrary).
    return sorted(set.intersection(*code_sets))


# Step 2: keep only epochs whose event code is shared by all four inputs
common_indices = match_epochs_by_index(epo1_cond1, epo1_cond2, epo2_cond1, epo2_cond2)


def _keep_common(ep, codes):
    """Return the subset of ``ep`` whose event codes (events[:, 2]) are in ``codes``."""
    return ep[np.isin(ep.events[:, 2], codes)]


epo1_cond1_matched = _keep_common(epo1_cond1, common_indices)
epo1_cond2_matched = _keep_common(epo1_cond2, common_indices)
epo2_cond1_matched = _keep_common(epo2_cond1, common_indices)
epo2_cond2_matched = _keep_common(epo2_cond2, common_indices)

# Step 3: equalize epoch counts between the two participants, per condition
# (equalize_epoch_counts modifies the Epochs objects in place).
mne.epochs.equalize_epoch_counts([epo1_cond1_matched, epo2_cond1_matched])
mne.epochs.equalize_epoch_counts([epo1_cond2_matched, epo2_cond2_matched])

In [None]:
import mne
import numpy as np
from hypyp import analyses, stats, viz

# --- 1) Helpers ---------------------------------------------------------------


def align_epochs_for_hypyp(epo_a: mne.Epochs, epo_b: mne.Epochs, drop_bads=True):
    """Make two Epochs objects shape-compatible for HyPyP.

    Returns copies of ``epo_a`` and ``epo_b`` that:

    * contain only their common channels, in ``epo_a``'s order;
    * optionally exclude the union of both objects' ``info["bads"]``;
    * have equal epoch counts (``equalize_epoch_counts(..., method="truncate")``).

    Raises
    ------
    RuntimeError
        If there are no common channels, or if sampling rates or epoch lengths differ.
    """
    # 1) keep only common channels and enforce identical ordering
    b_names = set(epo_b.ch_names)  # O(1) membership instead of scanning a list
    common = [ch for ch in epo_a.ch_names if ch in b_names]
    if not common:
        raise RuntimeError("No common channels between the two Epochs.")
    a = epo_a.copy().pick_channels(common, ordered=True)
    b = epo_b.copy().pick_channels(common, ordered=True)

    # 2) drop the union of bads from both, so channel lists stay identical
    if drop_bads:
        bads = sorted(set(a.info.get("bads", [])) | set(b.info.get("bads", [])))
        if bads:
            a.drop_channels(bads, on_missing="ignore")
            b.drop_channels(bads, on_missing="ignore")

    # 3) sanity: same sfreq & n_times (read n_times from .times instead of
    #    materialising the full data arrays with get_data())
    if a.info["sfreq"] != b.info["sfreq"]:
        raise RuntimeError(f"Different sfreq: {a.info['sfreq']} vs {b.info['sfreq']}")
    if len(a.times) != len(b.times):
        raise RuntimeError("Different n_times between subjects. Check epoching/cropping.")

    # 4) equalize epoch counts IN PLACE
    mne.epochs.equalize_epoch_counts([a, b], method="truncate")
    return a, b


def compute_interbrain_connectivity(epo_a, epo_b, freq_bands, mode="ccorr", epochs_average=True):
    """
    Compute undirected inter-brain connectivity between two participants with HyPyP.

    Both inputs are first aligned with ``align_epochs_for_hypyp`` (common channels,
    bads dropped, equal epoch counts).

    Parameters
    ----------
    epo_a, epo_b : mne.Epochs
        Participant 1 and participant 2.
    freq_bands : dict
        Band name -> [fmin, fmax] (Hz), forwarded to ``analyses.pair_connectivity``.
    mode : str
        HyPyP connectivity measure, e.g. "ccorr" or "imaginary_coh".
    epochs_average : bool
        Let HyPyP average connectivity over epochs.

    Returns
    -------
    C_inter : np.ndarray
        Inter-brain block (participant-1 channels x participant-2 channels) of the
        HyPyP result, whose last two axes are (2*n_ch, 2*n_ch). Expected shape is
        (n_bands, n_ch, n_ch) when ``epochs_average`` is True; otherwise any
        per-epoch axis of the HyPyP result is preserved.
    a, b : mne.Epochs
        The aligned epochs (for visualization).
    """
    # Ensure aligned
    a, b = align_epochs_for_hypyp(epo_a, epo_b, drop_bads=True)
    sfreq = a.info["sfreq"]
    n_ch = len(a.ch_names)

    # HyPyP expects data shaped (2, n_epochs, n_channels, n_times); a list of both arrays is fine
    data_inter = [a.get_data(), b.get_data()]

    # High-level wrapper does TF + connectivity
    con = analyses.pair_connectivity(
        data_inter, sampling_rate=sfreq, frequencies=freq_bands, mode=mode, epochs_average=epochs_average
    )
    # The channel axes are the last two. Index them with `...` so that an extra
    # per-epoch axis (epochs_average=False) is kept instead of being sliced as channels.
    C_inter = con[..., :n_ch, n_ch : 2 * n_ch]
    return C_inter, a, b  # return aligned epochs too for viz


def compute_interbrain_connectivity_directional(
    epo_a, epo_b, freq_bands, measure="pdc", *, mvar_order=5, n_fft=512, random_state=42
):
    """
    Compute directional connectivity (MVAR-based) for two participants with HyPyP.

    The epochs are aligned with ``align_epochs_for_hypyp``, band-filtered into
    analytic signals with ``analyses.compute_freq_bands`` and passed to
    ``analyses.compute_conn_mvar``.

    Parameters
    ----------
    epo_a, epo_b : mne.Epochs
        Participant 1 and participant 2.
    freq_bands : dict
        Band name -> [fmin, fmax] (Hz).
    measure : str
        Directional measure name, "pdc" or "dtf".
    mvar_order, n_fft, random_state : int, keyword-only
        MVAR model order, FFT length of the measure, and ICA random state
        (defaults match the previously hard-coded values).

    Returns
    -------
    conn
        The full ``compute_conn_mvar`` output. Unlike
        ``compute_interbrain_connectivity``, it is NOT sliced to the inter-brain
        block. NOTE(review): confirm its shape before indexing.
    a, b : mne.Epochs
        The aligned epochs.
    """
    # Ensure aligned
    a, b = align_epochs_for_hypyp(epo_a, epo_b, drop_bads=True)
    sfreq = a.info["sfreq"]

    # Prepare data for MVAR analysis
    data_inter = [a.get_data(), b.get_data()]

    # Compute analytic signal in frequency bands
    complex_signal = analyses.compute_freq_bands(
        data_inter, sfreq, freq_bands, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto"
    )

    mvar_params = {"mvar_order": mvar_order, "fitting_method": "default", "delta": 0}
    ica_params = {"method": "infomax_extended", "random_state": random_state}
    measure_params = {"name": measure, "n_fft": n_fft}  # 'pdc' or 'dtf'

    # Compute directional connectivity
    conn = analyses.compute_conn_mvar(complex_signal, mvar_params, ica_params, measure_params, check_stability=True)

    return conn, a, b


# --- 2) Prepare your matched epochs (use your own matched objects) ------------
# You already built:
#   epo1_cond1_matched, epo1_cond2_matched, epo2_cond1_matched, epo2_cond2_matched
# IMPORTANT: your earlier match-by-event-id doesn’t work with fixed-length events
# (all event IDs are identical). We rely on equalize_epoch_counts inside aligner.

In [None]:
# Frequency bands in Hz; their order defines axis 0 of the connectivity arrays.
freq_bands = {
    "Alpha-Low": [7.5, 11],
    "Alpha-High": [11.5, 13],
    "Theta": [4, 7.5],
    "Delta": [1, 4],
    "Low-Beta": [13, 20],
    "High-Beta": [20, 30],
    "Gamma1": [30, 45],
}
mode = "imaginary_coh"

# Undirected inter-brain connectivity, averaged over epochs, one call per condition
# (cond1 = "eyes open", cond2 = "eyes closed" as loaded above).
C_cond1, epo1_c1_aligned, epo2_c1_aligned = compute_interbrain_connectivity(
    epo_a=epo1_cond1_matched,
    epo_b=epo2_cond1_matched,
    freq_bands=freq_bands,
    mode=mode,
    epochs_average=True,
)
C_cond2, epo1_c2_aligned, epo2_c2_aligned = compute_interbrain_connectivity(
    epo_a=epo1_cond2_matched,
    epo_b=epo2_cond2_matched,
    freq_bands=freq_bands,
    mode=mode,
    epochs_average=True,
)

In [None]:
freq_bands = {
    "Alpha-Low": [7.5, 11],
    "Alpha-High": [11.5, 13],
    "Theta": [4, 7.5],
    "Delta": [1, 4],
    "Low-Beta": [13, 20],
    "High-Beta": [20, 30],
    "Gamma1": [30, 45],
}
# Directional (MVAR-based) connectivity
measure = "pdc"  # or 'dtf'

# Stored under separate names: writing to C_cond1 / C_cond2 would overwrite the
# undirected results, which the visualization cell below indexes as
# (n_bands, n_ch, n_ch) inter-brain matrices. This function instead returns the
# unsliced compute_conn_mvar output.
C_dir_cond1, epo1_c1_aligned, epo2_c1_aligned = compute_interbrain_connectivity_directional(
    epo1_cond1_matched, epo2_cond1_matched, freq_bands, measure=measure
)
C_dir_cond2, epo1_c2_aligned, epo2_c2_aligned = compute_interbrain_connectivity_directional(
    epo1_cond2_matched, epo2_cond2_matched, freq_bands, measure=measure
)

In [None]:
# Quick sanity check of the connectivity arrays computed above (all bands, all pairs).
# Uses C_cond1 / C_cond2 because the per-band slice C1 is only defined in the next cell.
for label, C in (("C_cond1", C_cond1), ("C_cond2", C_cond2)):
    print(f"{label} range: {np.min(C):.3f} to {np.max(C):.3f}")
    print(f"{label} mean: {np.mean(C):.3f}, std: {np.std(C):.3f}")

In [None]:
# ---- Select a band for visualization ----

# C_condX is expected to be (n_bands, n_ch, n_ch): the inter-brain block from
# compute_interbrain_connectivity, with axis 0 following the order of freq_bands.
band_names = list(freq_bands.keys())
band_idx = band_names.index("Low-Beta")  # or 'Alpha-High'
C1 = C_cond1[band_idx]
C2 = C_cond2[band_idx]
if np.ndim(C1) != 2 or np.ndim(C2) != 2:
    raise ValueError(
        f"Expected 2-D inter-brain matrices, got C1 {np.shape(C1)} and C2 {np.shape(C2)}; "
        "C_cond1/C_cond2 must come from compute_interbrain_connectivity."
    )

# --- 3) Visualize -------------------------------------------------------------
# Use the *aligned* epochs returned by the connectivity computation
thresh = "auto"
for cond_label, epo_a, epo_b, C in (
    ("Condition 1", epo1_c1_aligned, epo2_c1_aligned, C1),
    ("Condition 2", epo1_c2_aligned, epo2_c2_aligned, C2),
):
    print(f"{cond_label} – 2D:")
    viz.viz_2D_topomap_inter(epo_a, epo_b, C, threshold=thresh, steps=10, lab=True)
    print(f"{cond_label} – 3D:")
    viz.viz_3D_inter(epo_a, epo_b, C, threshold=thresh, steps=10, lab=False)

# --- 4) (Optional) cluster stats scaffold ------------------------------------
# Cluster stats need multiple observations (e.g., multiple dyads or sessions).
# With a single dyad, permutation tests are not meaningful.
# The snippet below guards against that and shows the expected shapes.

do_stats = False  # set True only if you have >= ~10 observations per condition

if do_stats:
    # Build channel–frequency connectivity (alpha only here).
    # Flatten n_ch x n_ch to a vector per observation if you keep epochs_average=False upstream.
    # With epochs_average=True, you already have one matrix per dyad/condition.
    # Uses the module-level `analyses` / `stats` imports; no extra imports needed.

    # Inter-brain channel-pair indices for metaconn.
    # NOTE(review): mne.concatenate_epochs stacks epochs with the SAME channels; it
    # does not build a two-brain channel set. indices_connectivity_interbrain
    # presumably expects hyper-epochs containing both participants' channels —
    # confirm (and merge channel-wise) before relying on `inter_pairs`.
    merged = mne.concatenate_epochs([epo1_c1_aligned.copy(), epo2_c1_aligned.copy()])
    inter_pairs = analyses.indices_connectivity_interbrain(merged)

    alpha_freqs = np.arange(8, 13)
    conn = stats.con_matrix(epo1_c1_aligned, freqs_mean=alpha_freqs)
    meta = stats.metaconn_matrix_2brains(inter_pairs, conn.ch_con, freqs_mean=alpha_freqs)

    # data_for_stats must be [array(shape=(n_obs, n_features)), array(...)],
    # stacking *many* dyad/session entries, e.g.:
    # stats.statscondCluster(data_for_stats, freqs_mean=alpha_freqs, ch_con_freq=meta.metaconn_freq,
    #                        tail=0, n_permutations=5000, alpha=0.05)

## Directional