# EEG Memory Recognition Analysis - ERP Analysis

This notebook implements epoching, artifact rejection, and Event-Related Potential (ERP) analysis.

## Analysis Goals
Based on Delorme et al. (2018), we analyze:

1. **Familiarity Effect**: ERP differences between familiar vs. novel images
2. **Repetition Effect**: How recognition ERPs change across repeated exposures
3. **Category Effect**: Animal vs. non-animal stimulus processing differences

## ROI Analysis
- **Frontal ROI**: F3, FZ, F4 - Early recognition components (N200, P300)
- **Parietal ROI**: P3, PZ, P4 - Late recognition components (P600, LPC)

## Statistical Analysis
- T-tests with FDR correction for multiple comparisons
- Time-window analysis for significant effects
- Visualization matching the original study

## 1. Setup and Imports

In [4]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mne
import yaml
import logging
from tqdm import tqdm
from pathlib import Path
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

notebook_dir = Path.cwd()
project_root = (notebook_dir / "..").resolve()
src_dir = project_root / "src"

if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

try:
    from utils.pathing import ensure_src_on_path, project_paths
    ensure_src_on_path()
    from utils.data_loader import EEGDataLoader
    from preprocessing.quality_assessment import EEGQualityAssessment
    print("‚úÖ Imports successful!")
except ImportError as e:
    print(f"‚ö†Ô∏è Import note: {e}")

print(f"Project root: {project_root}")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
sns.set_style("whitegrid")

print("‚úì Setup complete")
print(f"MNE version: {mne.__version__}")

‚úÖ Imports successful!
Project root: /Users/leeyelim/Documents/EEG
‚úì Setup complete
MNE version: 1.8.0


## 2. Load Configuration

In [5]:
config_path = project_root / 'config' / 'analysis_config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

selected_subjects = config['subjects']['selected']
manual_ica_subject = config['subjects']['manual_ica_subject']

print("‚úÖ Configuration loaded!")
print(f"- Selected subjects: {len(selected_subjects)}")
print(f"- Epoching: {config['preprocessing']['epoching']['tmin']} to {config['preprocessing']['epoching']['tmax']} s")
print(f"- ROI channels: {config['erp_analysis']['roi']}")

data_loader = EEGDataLoader(config_path=str(config_path))
print("\n‚úì Data loader initialized")

2025-11-18 21:43:00,847 - INFO - EEGDataLoader initialized
  Project root: /Users/leeyelim/Documents/EEG
  Config: /Users/leeyelim/Documents/EEG/config/analysis_config.yaml
  Raw dir: /Users/leeyelim/Documents/EEG/ds002680 (exists=True)
  Preprocessed dir: /Users/leeyelim/Documents/EEG/data/preprocessed (exists=True)
  Derivatives dir: /Users/leeyelim/Documents/EEG/data/derivatives (exists=True)


‚úÖ Configuration loaded!
- Selected subjects: 10
- Epoching: -0.1 to 0.6 s
- ROI channels: {'frontal': ['FP1', 'FP2'], 'parieto_occipital': ['P3', 'P3"', 'P4', 'P4"', 'PZ', 'PZ"', 'CZ']}

‚úì Data loader initialized


## 3. Load ICA-Cleaned Data

Load the ICA-cleaned data for the manual subject to demonstrate ERP analysis.

In [6]:
# Load ICA-cleaned data
print(f"\nüî¨ Loading ICA-cleaned data for {manual_ica_subject}...")

ica_dir = project_root / 'data' / 'preprocessed' / 'after_ica'
subject_dir = ica_dir / manual_ica_subject

if subject_dir.exists():
    fif_files = list(subject_dir.rglob('*ica_cleaned.fif'))
    if fif_files:
        raw_file = fif_files[0]
        print(f"Loading: {raw_file}")
        raw = mne.io.read_raw_fif(str(raw_file), preload=True, verbose=False)
        print(f"‚úÖ Loaded: {raw.info['nchan']} channels, {raw.times[-1]:.1f}s")
        
        # Get session and run info from filename
        parts = raw_file.stem.split('_')
        session = parts[1]
        run = parts[3]
        print(f"Session: {session}, Run: {run}")
    else:
        print("‚ùå No ICA-cleaned files found")
        print("Please run 02_manual_ica_review.ipynb first")
        raw = None
else:
    print(f"‚ùå Directory not found: {subject_dir}")
    print("Please run 01 and 02 notebooks first")    raw = None



üî¨ Loading ICA-cleaned data for sub-003...
Loading: /Users/leeyelim/Documents/EEG/data/preprocessed/after_ica/sub-003/ses-02/sub-003_ses-02_run-3_preprocessed_ica_cleaned.fif
‚úÖ Loaded: 31 channels, 209.3s
Session: ses-02, Run: preprocessed


## 4. Load Events

In [9]:
if raw is not None:
    # Load events for this session/run
    print("\nüìÖ Loading events...")
    try:
        events, event_id = data_loader.load_events(
            manual_ica_subject, 
            session=session, 
            run=run, 
            task='gonogo'
        )
        
        if events is not None:
            print(f"‚úÖ Loaded {len(events)} events")
            print(f"Event types: {event_id}")
        else:
            print("‚ö†Ô∏è No events found")
    except Exception as e:
        print(f"‚ùå Error loading events: {e}")
        events = None
        event_id = None
else:
    print("‚ö†Ô∏è Skipping: No data loaded")
    events = None
    event_id = None


üìÖ Loading events...
‚ùå Error loading events: Could not find EEG file for sub-003 with session=ses-02 run=preprocessed. Found 25 candidate .set files under /Users/leeyelim/Documents/EEG/ds002680/sub-003. Examples: ['ses-02/eeg/sub-003_ses-02_task-gonogo_run-5_eeg.set', 'ses-02/eeg/sub-003_ses-02_task-gonogo_run-11_eeg.set', 'ses-02/eeg/sub-003_ses-02_task-gonogo_run-4_eeg.set', 'ses-02/eeg/sub-003_ses-02_task-gonogo_run-10_eeg.set', 'ses-02/eeg/sub-003_ses-02_task-gonogo_run-12_eeg.set']


## 5. Create Epochs

Epoch the data from -100 to 600 ms around stimulus onset.

In [None]:
if raw is not None and events is not None:
    # Create epochs
    print("\nüîÑ Creating epochs...")
    
    tmin = config['preprocessing']['epoching']['tmin']
    tmax = config['preprocessing']['epoching']['tmax']
    baseline = tuple(config['preprocessing']['epoching']['baseline'])
    
    epochs = mne.Epochs(
        raw, events, event_id=event_id,
        tmin=tmin, tmax=tmax,
        baseline=baseline,
        preload=True,
        reject=dict(eeg=100e-6),  # Reject epochs with amplitude > 100 ¬µV
        verbose=False
    )
    
    print(f"‚úÖ Created {len(epochs)} epochs")
    print(f"  Time range: {tmin}s to {tmax}s")
    print(f"  Baseline: {baseline}")
    
    # Display event counts
    print("\nüìä Epochs per condition:")
    for condition in epochs.event_id.keys():
        n_epochs = len(epochs[condition])
        print(f"  {condition}: {n_epochs} epochs")
else:
    print("‚ö†Ô∏è Skipping: No data or events loaded")
    epochs = None

## 6. Compute ERPs

In [None]:
if epochs is not None:
    # Compute ERPs (evoked responses) for each condition
    print("\nüîÑ Computing ERPs...")
    
    erps = {}
    for condition in epochs.event_id.keys():
        if condition in epochs:
            erp = epochs[condition].average()
            erps[condition] = erp
            print(f"  ‚úÖ {condition}: {len(epochs[condition])} trials averaged")
    
    print(f"\n‚úÖ Computed ERPs for {len(erps)} conditions")
else:
    print("‚ö†Ô∏è Skipping: No epochs available")
    erps = {}

## 7. Visualize ERPs

In [None]:
if erps:
    # Plot ERPs for all conditions
    print("\nüìä Visualizing ERPs...")
    
    # Get ROI channels
    roi_channels = config['erp_analysis']['roi']
    frontal_chs = roi_channels['frontal']
    parietal_chs = roi_channels['parieto_occipital']
    
    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle(f'Event-Related Potentials - {manual_ica_subject}', 
                 fontsize=16, fontweight='bold')
    
    # Plot 1: All conditions, frontal ROI
    ax = axes[0, 0]
    times = erps[list(erps.keys())[0]].times
    colors = ['blue', 'red', 'green', 'orange']
    
    for i, (condition, erp) in enumerate(erps.items()):
        # Pick frontal channels
        try:
            frontal_data = erp.copy().pick_channels(frontal_chs).get_data()
            frontal_mean = np.mean(frontal_data, axis=0) * 1e6  # Convert to ¬µV
            ax.plot(times, frontal_mean, color=colors[i % len(colors)], 
                   linewidth=2, label=condition)
        except:
            print(f"  ‚ö†Ô∏è Could not plot {condition} for frontal ROI")
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (¬µV)')
    ax.set_title('Frontal ROI')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax.axvline(x=0, color='black', linestyle='--', alpha=0.5, label='Stimulus')
    
    # Plot 2: All conditions, parietal ROI
    ax = axes[0, 1]
    for i, (condition, erp) in enumerate(erps.items()):
        try:
            parietal_data = erp.copy().pick_channels(parietal_chs).get_data()
            parietal_mean = np.mean(parietal_data, axis=0) * 1e6
            ax.plot(times, parietal_mean, color=colors[i % len(colors)], 
                   linewidth=2, label=condition)
        except:
            print(f"  ‚ö†Ô∏è Could not plot {condition} for parietal ROI")
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (¬µV)')
    ax.set_title('Parietal ROI')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax.axvline(x=0, color='black', linestyle='--', alpha=0.5)
    
    # Plot 3: Topographic map at peak (example)
    ax = axes[1, 0]
    if len(erps) > 0:
        first_erp = list(erps.values())[0]
        peak_time = 0.3  # Example: 300ms
        try:
            first_erp.plot_topomap(times=[peak_time], axes=ax, show=False, 
                                  colorbar=True, time_format='%0.3f s')
            ax.set_title(f'Topography at {peak_time}s')
        except:
            ax.text(0.5, 0.5, 'Topography visualization\nrequires standard montage', 
                   ha='center', va='center', transform=ax.transAxes)
    
    # Plot 4: Difference wave (if applicable)
    ax = axes[1, 1]
    # Example: Familiar - New difference
    familiar_conds = [k for k in erps.keys() if 'familiar' in k.lower()]
    new_conds = [k for k in erps.keys() if 'new' in k.lower()]
    
    if familiar_conds and new_conds:
        try:
            fam_data = erps[familiar_conds[0]].copy().pick_channels(parietal_chs).get_data()
            new_data = erps[new_conds[0]].copy().pick_channels(parietal_chs).get_data()
            diff = (np.mean(fam_data, axis=0) - np.mean(new_data, axis=0)) * 1e6
            ax.plot(times, diff, 'purple', linewidth=2, label='Familiar - New')
            ax.fill_between(times, 0, diff, alpha=0.3, color='purple')
        except:
            pass
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude Difference (¬µV)')
    ax.set_title('Difference Wave (Parietal ROI)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax.axvline(x=0, color='black', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    
    # Save figure
    fig_path = project_root / 'results' / 'figures'
    fig_path.mkdir(parents=True, exist_ok=True)
    plt.savefig(fig_path / f'erp_analysis_{manual_ica_subject}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"\nüíæ ERP plot saved to: {fig_path / f'erp_analysis_{manual_ica_subject}.png'}")
else:
    print("‚ö†Ô∏è Skipping: No ERPs available")

## 8. Statistical Analysis (Example)

Perform t-tests to compare conditions.

In [None]:
if epochs is not None and len(erps) >= 2:
    print("\nüìä Statistical Analysis...")
    
    # Example: Compare familiar vs new
    familiar_conds = [k for k in epochs.event_id.keys() if 'familiar' in k.lower()]
    new_conds = [k for k in epochs.event_id.keys() if 'new' in k.lower()]
    
    if familiar_conds and new_conds:
        try:
            # Get data
            fam_epochs = epochs[familiar_conds]
            new_epochs = epochs[new_conds]
            
            # Pick parietal channels
            fam_data = fam_epochs.copy().pick_channels(parietal_chs).get_data()
            new_data = new_epochs.copy().pick_channels(parietal_chs).get_data()
            
            # Average across channels: (n_epochs, n_times)
            fam_mean = np.mean(fam_data, axis=1)
            new_mean = np.mean(new_data, axis=1)
            
            # T-test at each time point
            t_stats = []
            p_values = []
            for t_idx in range(fam_mean.shape[1]):
                t, p = stats.ttest_ind(fam_mean[:, t_idx], new_mean[:, t_idx])
                t_stats.append(t)
                p_values.append(p)
            
            # FDR correction
            from statsmodels.stats.multitest import multipletests
            _, p_corrected, _, _ = multipletests(p_values, method='fdr_bh', alpha=0.05)
            
            # Find significant time points
            sig_times = np.where(p_corrected < 0.05)[0]
            
            print(f"\n‚úÖ Statistical comparison: {familiar_conds[0]} vs {new_conds[0]}")
            print(f"  Significant timepoints: {len(sig_times)} / {len(p_values)}")
            if len(sig_times) > 0:
                times_ms = epochs.times[sig_times] * 1000
                print(f"  Time range: {times_ms.min():.0f} - {times_ms.max():.0f} ms")
            else:
                print("  No significant differences found (FDR corrected)")
                
        except Exception as e:
            print(f"‚ö†Ô∏è Statistical analysis note: {e}")
    else:
        print("‚ö†Ô∏è Need familiar and new conditions for comparison")
else:
    print("‚ö†Ô∏è Skipping: Insufficient data for statistics")

## 9. Save Epochs

In [None]:
if epochs is not None:
    # Save epochs for later analysis
    epochs_dir = project_root / 'data' / 'preprocessed' / 'after_epochs'
    subject_epochs_dir = epochs_dir / manual_ica_subject / session
    subject_epochs_dir.mkdir(parents=True, exist_ok=True)
    
    epochs_filename = f"{manual_ica_subject}_{session}_{run}_epo.fif"
    epochs_path = subject_epochs_dir / epochs_filename
    
    epochs.save(str(epochs_path), overwrite=True, verbose=False)
    print(f"\nüíæ Epochs saved to: {epochs_path}")
else:
    print("‚ö†Ô∏è No epochs to save")

## 10. Next Steps

This demonstrates ERP analysis for one subject. To complete the analysis:

1. **Loop through all subjects** to compute group-level ERPs
2. **Implement full statistical pipeline** with:
   - Repetition effects (1st, 2nd, 3rd presentation)
   - Category effects (animal vs non-animal)
   - Multiple comparison corrections
3. **Create publication-quality figures** matching Delorme et al. (2018)
4. **Create main_analysis.ipynb** for comprehensive results

For production use, consider using the `src/analysis/erp_analysis.py` module.

In [None]:
print("\nüéØ ERP ANALYSIS DEMONSTRATION COMPLETED")
print("=" * 60)
print(f"‚úÖ ERP analysis completed for {manual_ica_subject}")
print("\nüìã NEXT STEPS:")
print("1. Extend to all subjects for group analysis")
print("2. Implement full statistical pipeline")
print("3. Create main_analysis.ipynb for final results")
print("\nüìä Analysis Progress: 4/4 Complete (Demo)")
print("   ‚úì Setup and data exploration")
print("   ‚úì Preprocessing pipeline")
print("   ‚úì Manual ICA review")
print("   ‚úì ERP analysis (demonstration)")
print("   ‚Üí Next: Full group analysis + GitHub presentation")

## 6. Group-level familiarity effect (familiar vs new) with FDR correction

Uses aggregated ERP timecourses to run paired t-tests across subjects at each timepoint for each ROI, correcting p-values (BH-FDR). Saves CSVs under `results/statistical_outputs/` and plots group difference waves with significant intervals highlighted.


## 7. Repetition-wise familiarity difference (1st/2nd/3rd) per ROI

Computes group-level Familiar ‚àí New difference ERPs separately for the 1st, 2nd, and 3rd repetition of the familiar set, and plots mean ¬± SEM for each repetition, matching the reference figure layout (Frontal ROI on the left, Parieto-occipital on the right).


### 7a. Repetition-wise repeated-measures ANOVA (per timepoint, per ROI)

Computes one-way repeated-measures ANOVA across repetitions (1/2/3) at each timepoint for each ROI. Saves CSVs under `results/statistical_outputs/erp_rep_anova_<roi>.csv` with columns: `time_ms, F, p, p_fdr, n_subjects`. Uses BH-FDR within each ROI.


### 7b. Save repetition-wise familiarity difference timecourses (mean ¬± SEM)

Exports per-ROI repetition curves to CSV with columns: `time_ms, repetition, mean_uV, sem_uV, n`. Files are written to `results/statistical_outputs/erp_repetition_diff_<roi>.csv`.


## 8. Category analysis: Animal vs Non-animal

This section reproduces category effects:
- Channel √ó time p-value heatmap for Animal vs Non-animal (paired across subjects)
- ROI panel of Familiar‚àíNew difference per category (Animal vs Non-animal), with mean ¬± SEM and timepoints where categories differ (BH-FDR).


In [None]:
# Build per-subject category ERPs and channel√ótime stats
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

figdir = project_root / 'results' / 'figures'
statdir = project_root / 'results' / 'statistical_outputs'
figdir.mkdir(parents=True, exist_ok=True)
statdir.mkdir(parents=True, exist_ok=True)

roi_cfg = config['erp_analysis']['roi']
FRONTAL = roi_cfg.get('frontal', [])
PAROCC  = roi_cfg.get('parieto_occipital', [])

familiar_labels = {'animal_target','nonanimal_target','easy_target','difficult_target'}
new_labels      = {'animal_distractor','nonanimal_distractor','easy_distractor','difficult_distractor'}

# helper to decide category label
def _is_animal(v: str) -> bool:
    return isinstance(v, str) and ('animal_' in v) and (not v.startswith('nonanimal'))

def _is_nonanimal(v: str) -> bool:
    return isinstance(v, str) and (v.startswith('nonanimal'))

# Collect per-subject per-channel time series of category difference (fam‚àínew) and ROI series
subj_chan_series = {}  # subj -> DataFrame (channels x times) of (animal‚àínonanimal) difference of (fam‚àínew)
roi_series = {'Frontal ROI': {'animal': {}, 'nonanimal': {}},
              'Parieto-occipital ROI': {'animal': {}, 'nonanimal': {}}}

for subj in selected_subjects:
    # aggregate across sessions
    chan_acc_animal = []
    chan_acc_nonanimal = []
    times_ref = None
    ch_names_ref = None
    for ses in list_sessions(subj):
        raw_after = merge_stage(subj, ses, after_ica_root, 'ica_cleaned')
        if raw_after is None:
            continue
        evdf = build_session_events_with_meta(subj, ses)
        if evdf is None or evdf.empty:
            continue
        tmin = config['preprocessing']['epoching']['tmin']
        tmax = config['preprocessing']['epoching']['tmax']
        baseline = tuple(config['preprocessing']['epoching']['baseline'])
        # create events array for familiar/new with categories
        ev_animal_fam = evdf[evdf['value'].apply(_is_animal) & evdf['value'].isin(familiar_labels)]['sample'].astype(int).values
        ev_animal_new = evdf[evdf['value'].apply(_is_animal) & evdf['value'].isin(new_labels)]['sample'].astype(int).values
        ev_non_fam = evdf[evdf['value'].apply(_is_nonanimal) & evdf['value'].isin(familiar_labels)]['sample'].astype(int).values
        ev_non_new = evdf[evdf['value'].apply(_is_nonanimal) & evdf['value'].isin(new_labels)]['sample'].astype(int).values
        def _epochs_from_samples(samples):
            if samples.size == 0:
                return None
            arr = np.column_stack([samples, np.zeros(len(samples), dtype=int), np.ones(len(samples), dtype=int)])
            try:
                ep = mne.Epochs(raw_after, arr, event_id=None, tmin=tmin, tmax=tmax,
                                baseline=baseline, preload=True, verbose='ERROR')
                return ep
            except Exception:
                return None
        ep_af = _epochs_from_samples(ev_animal_fam)
        ep_an = _epochs_from_samples(ev_animal_new)
        ep_nf = _epochs_from_samples(ev_non_fam)
        ep_nn = _epochs_from_samples(ev_non_new)
        if any(e is None or len(e)==0 for e in [ep_af, ep_an, ep_nf, ep_nn]):
            continue
        # channel-level evokeds
        af = ep_af.average(); an = ep_an.average(); nf = ep_nf.average(); nn = ep_nn.average()
        # fam‚àínew per category
        animal_diff = (af.data - an.data) * 1e6  # channels x time (¬µV)
        non_diff   = (nf.data - nn.data) * 1e6
        chan_acc_animal.append(animal_diff)
        chan_acc_nonanimal.append(non_diff)
        times_ref = af.times * 1000.0
        ch_names_ref = af.ch_names
        # ROI series per session
        for roi_name, picks in [('Frontal ROI', FRONTAL), ('Parieto-occipital ROI', PAROCC)]:
            sel = [ch for ch in picks if ch in af.ch_names]
            if not sel:
                continue
            idx = [af.ch_names.index(ch) for ch in sel]
            roi_animal = animal_diff[idx, :].mean(axis=0)
            roi_non    = non_diff[idx, :].mean(axis=0)
            roi_series[roi_name]['animal'].setdefault(subj, []).append(pd.Series(roi_animal, index=times_ref))
            roi_series[roi_name]['nonanimal'].setdefault(subj, []).append(pd.Series(roi_non, index=times_ref))
    # average across sessions for subject at channel level
    if times_ref is not None and ch_names_ref is not None and chan_acc_animal and chan_acc_nonanimal:
        A = np.stack(chan_acc_animal, axis=0).mean(axis=0)
        N = np.stack(chan_acc_nonanimal, axis=0).mean(axis=0)
        subj_chan_series[subj] = {'animal': A, 'nonanimal': N, 'times_ms': times_ref, 'ch_names': ch_names_ref}

# 1) Channel √ó time p-values heatmap (Animal vs Non-animal difference)
if subj_chan_series:
    ch_names = list(next(iter(subj_chan_series.values()))['ch_names'])
    times_ms = next(iter(subj_chan_series.values()))['times_ms']
    # stack subjects
    animal_stack = []
    non_stack = []
    for subj, obj in subj_chan_series.items():
        animal_stack.append(obj['animal'])
        non_stack.append(obj['nonanimal'])
    animal_stack = np.stack(animal_stack, axis=0)  # n x ch x t
    non_stack    = np.stack(non_stack, axis=0)
    # paired t-test across subjects for each ch,t
    from scipy import stats as _stats
    n_subj, n_ch, n_t = animal_stack.shape
    p_mat = np.full((n_ch, n_t), np.nan)
    for ci in range(n_ch):
        x = animal_stack[:, ci, :]
        y = non_stack[:, ci, :]
        for ti in range(n_t):
            xv = x[:, ti]; yv = y[:, ti]
            mask = np.isfinite(xv) & np.isfinite(yv)
            if mask.sum() < 2:
                continue
            t, p = _stats.ttest_rel(xv[mask], yv[mask])
            p_mat[ci, ti] = p
    # save CSV and figure
    heat_df = pd.DataFrame(p_mat, index=ch_names, columns=np.round(times_ms,1))
    heat_csv = statdir / 'erp_category_channel_time_pvals.csv'
    heat_df.to_csv(heat_csv)
    print(f"üíæ Saved heatmap p-values ‚Üí {heat_csv}")
    plt.figure(figsize=(8,6))
    # map p-values to -log10 for better dynamic range
    with np.errstate(divide='ignore'):
        z = -np.log10(p_mat)
    sns.heatmap(z, cmap='inferno', yticklabels=ch_names, xticklabels=200,
                cbar_kws={'label': '-log10(p)'} )
    plt.title('Animal vs Non-animal (paired t-test across subjects): -log10(p)')
    plt.xlabel('Time (ms)')
    plt.ylabel('Channels')
    heat_png = figdir / 'group_category_channel_time_pvals.png'
    plt.tight_layout()
    plt.savefig(heat_png, dpi=200, bbox_inches='tight')
    print(f"üíæ Saved: {heat_png}")
    plt.show()

# 2) ROI panel: category difference curves (Animal vs Non-animal) and FDR significance
try:
    _bh_fdr
except NameError:
    def _bh_fdr(pvals):
        p = np.asarray(pvals, dtype=float); m = p.size
        order = np.argsort(p); ranked = p[order]
        adj = np.empty_like(ranked); prev=1.0
        for i in range(m-1, -1, -1):
            rank = i+1; val = ranked[i]*m/float(rank)
            prev = min(prev, val) if np.isfinite(val) else prev
            adj[i] = prev if np.isfinite(val) else np.nan
        out = np.minimum(1.0, adj); res = np.empty_like(p); res[order]=out; return res

fig, axes = plt.subplots(1, 2, figsize=(14,5), sharey=True)
for idx, (roi_name, ax) in enumerate([('Frontal ROI', axes[0]), ('Parieto-occipital ROI', axes[1])]):
    # subject-averaged per category
    subj_series_an = []
    subj_series_non = []
    common_index = None
    for subj in sorted(set(roi_series[roi_name]['animal'].keys()) | set(roi_series[roi_name]['nonanimal'].keys())):
        s_an = roi_series[roi_name]['animal'].get(subj)
        s_non = roi_series[roi_name]['nonanimal'].get(subj)
        if not s_an or not s_non:
            continue
        # average across sessions within subject
        idx_union = sorted(set().union(*[s.index for s in (s_an+s_non)]))
        an = np.nanmean(np.vstack([s.reindex(idx_union).values for s in s_an]), axis=0)
        nn = np.nanmean(np.vstack([s.reindex(idx_union).values for s in s_non]), axis=0)
        subj_series_an.append(pd.Series(an, index=idx_union))
        subj_series_non.append(pd.Series(nn, index=idx_union))
        common_index = idx_union
    if not subj_series_an or not subj_series_non:
        continue
    A = np.vstack([s.reindex(common_index).values for s in subj_series_an])
    N = np.vstack([s.reindex(common_index).values for s in subj_series_non])
    meanA = np.nanmean(A, axis=0); semA = np.nanstd(A, axis=0, ddof=1)/np.sqrt(np.sum(np.isfinite(A), axis=0).clip(min=1))
    meanN = np.nanmean(N, axis=0); semN = np.nanstd(N, axis=0, ddof=1)/np.sqrt(np.sum(np.isfinite(N), axis=0).clip(min=1))
    # plot curves
    ax.plot(common_index, meanA, color='magenta', lw=2, label='Animal')
    ax.fill_between(common_index, meanA-semA, meanA+semA, color='magenta', alpha=0.25)
    ax.plot(common_index, meanN, color='teal', lw=2, label='Non-animal')
    ax.fill_between(common_index, meanN-semN, meanN+semN, color='teal', alpha=0.25)
    # significance Animal vs Non-animal
    from scipy import stats as _stats
    pvals = []
    for i in range(len(common_index)):
        x = A[:, i]; y = N[:, i]
        mask = np.isfinite(x) & np.isfinite(y)
        if mask.sum() < 2:
            pvals.append(np.nan)
        else:
            t,p = _stats.ttest_rel(x[mask], y[mask])
            pvals.append(p)
    pvals = np.array(pvals)
    valid = np.isfinite(pvals); p_fdr = np.full_like(pvals, np.nan)
    if valid.any():
        p_fdr[valid] = _bh_fdr(pvals[valid])
    sig = (p_fdr < 0.05)
    ymin, ymax = ax.get_ylim(); yspan = ymax - ymin; y_bar = ymin + 0.08*yspan
    if np.any(sig):
        t_ms = np.array(common_index)
        step = np.median(np.diff(t_ms)) if len(t_ms)>1 else 1.0
        start=None; prev=None
        for tt in t_ms[sig]:
            if start is None:
                start=tt; prev=tt; continue
            if abs(tt-prev-step)<1e-6:
                prev=tt; continue
            ax.hlines(y_bar, start, prev, colors='black', linewidth=4)
            start=tt; prev=tt
        if start is not None:
            ax.hlines(y_bar, start, prev, colors='black', linewidth=4)
    ax.axhline(0, color='k', lw=0.8, alpha=0.5)
    ax.axvline(0, color='k', lw=1.0, ls='--', alpha=0.6)
    ax.set_title(roi_name)
    ax.set_xlabel('Time (ms)')
axes[0].set_ylabel('Amplitude difference (¬µV)')
axes[0].legend(loc='upper left')
plt.tight_layout()
roi_png = figdir / 'group_category_roi_diff.png'
plt.savefig(roi_png, dpi=200, bbox_inches='tight')
print(f"üíæ Saved: {roi_png}")
plt.show()

# Save ROI curves to CSV per ROI
for roi_name, A_label, N_label in [('Frontal ROI','frontal_roi','frontal_roi'), ('Parieto-occipital ROI','parieto-occipital_roi','parieto-occipital_roi')]:
    # reuse A,N from loop by re-computing quickly
    # collect again to ensure variables exist here
    subj_series_an = []
    subj_series_non = []
    for subj in sorted(set(roi_series[roi_name]['animal'].keys()) | set(roi_series[roi_name]['nonanimal'].keys())):
        s_an = roi_series[roi_name]['animal'].get(subj)
        s_non = roi_series[roi_name]['nonanimal'].get(subj)
        if not s_an or not s_non:
            continue
        idx_union = sorted(set().union(*[s.index for s in (s_an+s_non)]))
        an = np.nanmean(np.vstack([s.reindex(idx_union).values for s in s_an]), axis=0)
        nn = np.nanmean(np.vstack([s.reindex(idx_union).values for s in s_non]), axis=0)
        subj_series_an.append(pd.Series(an, index=idx_union))
        subj_series_non.append(pd.Series(nn, index=idx_union))
    if not subj_series_an or not subj_series_non:
        continue
    common_index = sorted(set().union(*[s.index for s in (subj_series_an+subj_series_non)]))
    A = np.vstack([s.reindex(common_index).values for s in subj_series_an])
    N = np.vstack([s.reindex(common_index).values for s in subj_series_non])
    meanA = np.nanmean(A, axis=0); semA = np.nanstd(A, axis=0, ddof=1)/np.sqrt(np.sum(np.isfinite(A), axis=0).clip(min=1))
    meanN = np.nanmean(N, axis=0); semN = np.nanstd(N, axis=0, ddof=1)/np.sqrt(np.sum(np.isfinite(N), axis=0).clip(min=1))
    out_rows = []
    for t, mA, sA, mN, sN in zip(common_index, meanA, semA, meanN, semN):
        out_rows.append({'roi': roi_name, 'time_ms': float(t), 'category': 'animal', 'mean_uV': float(mA), 'sem_uV': float(sA)})
        out_rows.append({'roi': roi_name, 'time_ms': float(t), 'category': 'nonanimal', 'mean_uV': float(mN), 'sem_uV': float(sN)})
    out_df = pd.DataFrame(out_rows)
    out_csv = statdir / f'erp_category_roi_diff_{roi_name.replace(" ", "_").lower()}.csv'
    out_df.to_csv(out_csv, index=False)
    print(f"üíæ Saved ROI curves ‚Üí {out_csv}")


## 9. Subject diagnostics ‚Äì sub-003

Computes diagnostics to explain reduced SME/Cohen‚Äôs d:
- Trial counts per condition and repetition
- ROI channel availability and fallbacks
- SME and Cohen‚Äôs d on trial-level ROI amplitudes (config post_window) before vs after cleaning
- Approximate latency jitter (SD of single-trial peak latency in window)

Saves:
- `results/diagnostics/sub-003_diagnostics_summary.csv`
- `results/diagnostics/sub-003_trial_metrics_<stage>_<roi>.csv`
