# EEG Spectrogram Analysis: Frontal and Central Electrodes

This notebook performs time-frequency analysis of single-trial EEG data using short-time spectrograms, **focusing on frontal and central electrodes**, which are commonly associated with decision-making and motor preparation.

**Data Details:**
- Sampling rate: 103 Hz
- Nyquist frequency: 51.5 Hz (highest frequency representable at this rate)
- Trial length: 103 samples (1 second)
- Data shape: (64 channels, 61415 trials, 103 time points)
- Electrode names: 58 (`electrodes_names.npy` lists fewer names than the 64 channel rows; the code pairs names with rows by position)
- Analyzed: 30 frontal + central electrodes (centro-parietal excluded)

**Focus Regions:**
- **Frontal electrodes** (Fp, AF, F): Executive function, decision-making, cognitive control
- **Central electrodes** (FC, C): Motor planning, motor execution, action preparation
- **Fronto-temporal electrodes** (FT7, FT8) are also included, because the `'F'` prefix matches them
- **Excluding** centro-parietal (CP) electrodes, which lie posterior to the central region

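**Note:** 48.86% of all values in `trials` are NaN. They are replaced with 0 before analysis, so a fully missing trial becomes a flat zero signal: z-scoring leaves it unchanged, its power is zero, and it sits at the -100 dB floor (`10*log10(0 + 1e-10)`) in single-trial spectrograms while pulling down any average it is part of. Partially missing trials contain abrupt steps to zero, which can add spurious broadband power.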
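Before choosing between zero-filling and dropping data, it can help to know whether the NaNs come from whole missing trials or from scattered gaps. A minimal sketch on a small synthetic array with the same (electrodes, trials, time) layout; `nan_trial_summary` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def nan_trial_summary(trials):
    """Count (electrode, trial) rows that are fully NaN, partially NaN, or clean.

    trials: array of shape (n_electrodes, n_trials, n_times)
    """
    nan_frac = np.isnan(trials).mean(axis=2)  # fraction of NaN samples per row
    return {
        "fully_nan": int(np.sum(nan_frac == 1.0)),
        "partially_nan": int(np.sum((nan_frac > 0) & (nan_frac < 1))),
        "clean": int(np.sum(nan_frac == 0)),
    }

# Synthetic example: 2 electrodes x 3 trials x 5 samples
demo = np.random.default_rng(0).normal(size=(2, 3, 5))
demo[0, 1, :] = np.nan     # one fully missing row
demo[1, 2, :2] = np.nan    # one partially missing row
summary = nan_trial_summary(demo)
print(summary)  # {'fully_nan': 1, 'partially_nan': 1, 'clean': 4}

# After np.nan_to_num, a fully missing row is all zeros: its power is 0,
# which the notebook's 1e-10 floor maps to -100 dB
floor_db = 10 * np.log10(0.0 + 1e-10)
print(round(floor_db, 6))  # -100.0
```

On the real data, `nan_trial_summary(trials)` would show whether the NaN share is dominated by whole missing trials (which could simply be excluded) or by partial gaps.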

In [50]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.signal import spectrogram, stft
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')  # silences ALL warnings, including NumPy/SciPy deprecation and spectral-window warnings

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from IPython.display import display

plt.style.use('default')
sns.set_style("darkgrid")
sns.set_palette("husl")

SAMPLING_RATE = 103
TRIAL_LENGTH = 103
NYQUIST_FREQ = SAMPLING_RATE / 2
DATA_PATH = Path("../data/raw/05_25")

## 1. Load Data

In [51]:
trials = np.load(DATA_PATH / "trials_dataset.npy")
electrodes = np.load(DATA_PATH / "electrodes_names.npy")
true_labels = np.load(DATA_PATH / "true_labels.npy")
cues = np.load(DATA_PATH / "cues.npy")
primes = np.load(DATA_PATH / "primes.npy")

print(f"Trials shape: {trials.shape}")
print(f"Electrodes: {len(electrodes)} | Trials: {trials.shape[1]}")
print(f"Labels: {np.unique(true_labels)} | Cues: {np.unique(cues)} | Primes: {np.unique(primes)}")

nan_count = np.isnan(trials).sum()
print(f"NaN values: {nan_count} ({100 * nan_count / trials.size:.2f}%)")

Trials shape: (64, 61415, 103)
Electrodes: 58 | Trials: 61415
Labels: [1. 2.] | Cues: [1. 2. 4.] | Primes: [1. 2. 3.]
NaN values: 197814656 (48.86%)


In [52]:
trials_clean = np.nan_to_num(trials, nan=0.0)

frontal_prefixes = ['Fp', 'AF', 'F']
central_prefixes = ['FC', 'C']

frontal_central_indices = []
frontal_central_names = []

for i, elec in enumerate(electrodes):
    elec_str = str(elec)
    if any(elec_str.startswith(prefix) for prefix in frontal_prefixes + central_prefixes):
        if not elec_str.startswith('CP'):
            frontal_central_indices.append(i)
            frontal_central_names.append(elec_str)

frontal_central_indices = np.array(frontal_central_indices)
frontal_central_names = np.array(frontal_central_names)

trials_clean_filtered = trials_clean[frontal_central_indices, :, :]

print(f"Filtered to {len(frontal_central_indices)} frontal/central electrodes (excluded CP)")
print(f"Electrodes: {frontal_central_names}")

Filtered to 30 frontal/central electrodes (excluded CP)
Electrodes: ['Fp1' 'Fpz' 'Fp2' 'AF3' 'AF4' 'F7' 'F5' 'F3' 'F1' 'Fz' 'F2' 'F4' 'F6'
 'F8' 'FT7' 'FC5' 'FC3' 'FC1' 'FCz' 'FC2' 'FC4' 'FC6' 'FT8' 'C5' 'C3' 'C1'
 'Cz' 'C2' 'C4' 'C6']
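The prefix test above uses `str.startswith`, so the bare `'F'` prefix also captures the fronto-temporal sites FT7 and FT8. A more explicit alternative is to split each name into its letter prefix and its site number (or midline `z`) and look the prefix up exactly. A sketch assuming standard 10-10 names; `REGION_OF_PREFIX` and `electrode_region` are hypothetical names.

```python
import re

# Hypothetical lookup from 10-10 letter prefix to region label
REGION_OF_PREFIX = {
    "Fp": "frontal", "AF": "frontal", "F": "frontal",
    "FT": "fronto-temporal",
    "FC": "central", "C": "central",
    "CP": "centro-parietal",
}

def electrode_region(name):
    """Region of a 10-10 name such as 'FC3' or 'Fpz'; None if unmapped."""
    # Letters, then a site number or the midline suffix 'z'
    match = re.fullmatch(r"([A-Za-z]+?)(\d+|z)", name)
    if match is None:
        return None
    return REGION_OF_PREFIX.get(match.group(1))

names = ["Fp1", "Fpz", "AF3", "F7", "FT7", "FC1", "FCz", "Cz", "CP3", "Pz"]
regions = {n: electrode_region(n) for n in names}
selected = [n for n in names if regions[n] in ("frontal", "central")]
print(selected)  # ['Fp1', 'Fpz', 'AF3', 'F7', 'FC1', 'FCz', 'Cz']
```

Accepting `'fronto-temporal'` as well should reproduce the notebook's current 30-electrode list; leaving it out drops FT7 and FT8, making the inclusion an explicit choice rather than a side effect of prefix matching.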


In [53]:
def normalize_data(data, method='none'):
    """Normalize EEG data of shape (electrodes, trials, time).

    'none'             : unmodified copy
    'zscore_trial'     : z-score each electrode/trial over time
    'zscore_electrode' : z-score each electrode over all its trials and samples
    'baseline'         : subtract each electrode/trial's temporal mean
    Zero-variance rows (e.g. all-NaN trials filled with 0) are only mean-centred.
    """
    data = data.copy()

    if method == 'none':
        return data
    if method == 'zscore_trial':
        # Vectorized over electrodes and trials; axis 2 is time
        mean = data.mean(axis=2, keepdims=True)
        std = data.std(axis=2, keepdims=True)
    elif method == 'zscore_electrode':
        # One mean/std per electrode, pooled over trials and time
        mean = data.mean(axis=(1, 2), keepdims=True)
        std = data.std(axis=(1, 2), keepdims=True)
    elif method == 'baseline':
        # Baseline correction: subtract the per-trial temporal mean
        data -= data.mean(axis=2, keepdims=True)
        return data
    else:
        # Fail loudly instead of silently skipping normalization on a typo
        raise ValueError(f"Unknown normalization method: {method!r}")

    std[std == 0] = 1.0  # avoid division by zero for flat rows
    data -= mean         # in-place ops avoid extra full-size copies
    data /= std
    return data

# NORMALIZATION SELECTION
# Options: 'none', 'zscore_trial', 'zscore_electrode', 'baseline'
NORMALIZATION_METHOD = 'zscore_trial'  # Change this to apply normalization

# Apply normalization
trials_clean_filtered = normalize_data(trials_clean_filtered, method=NORMALIZATION_METHOD)
print(f"Normalization applied: {NORMALIZATION_METHOD}")

Normalization applied: zscore_trial
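Per-trial z-scoring sets every trial's variance to 1, so by Parseval's theorem every trial ends up with the same total power. Absolute power differences between trials, and therefore between conditions, are removed; band powers computed afterwards describe how that fixed total is distributed across frequencies. A small synthetic check with two 10 Hz sinusoids whose amplitudes differ five-fold; `zscore` and `total_power` are illustrative helpers, not notebook functions.

```python
import numpy as np

fs = 103                       # sampling rate used in this notebook
t = np.arange(103) / fs        # one 1-second trial
weak = 1.0 * np.sin(2 * np.pi * 10 * t)
strong = 5.0 * np.sin(2 * np.pi * 10 * t)

def zscore(x):
    """Z-score a 1-D trial over time (population SD, as np.std uses)."""
    return (x - x.mean()) / x.std()

def total_power(x):
    """Mean squared amplitude, i.e. total power by Parseval's theorem."""
    return np.mean(x ** 2)

raw_ratio = total_power(strong) / total_power(weak)
z_ratio = total_power(zscore(strong)) / total_power(zscore(weak))
print(round(raw_ratio, 3), round(z_ratio, 3))  # 25.0 1.0
```

If absolute power differences are of interest, `'baseline'` or `'zscore_electrode'` preserve them across trials.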


## 2. Define Frequency Bands

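Standard EEG frequency bands (only gamma is truncated by the 51.5 Hz Nyquist limit). The functional associations listed are the commonly cited ones and are only indicative:
- **Delta (δ)**: 0.5-4 Hz - Deep sleep, unconscious processes
- **Theta (θ)**: 4-8 Hz - Drowsiness, meditation, memory
- **Alpha (α)**: 8-13 Hz - Relaxed wakefulness, eyes closed
- **Beta (β)**: 13-30 Hz - Active thinking, concentration, anxiety
- **Low Gamma (γ)**: 30-50 Hz - Cognitive processing, attention (key `'Gamma'` in the code; higher gamma lies above the Nyquist limit)

Frequency resolution is limited by the 1-second trials. Bin spacing is `fs / nperseg`: about 3.2 Hz for the default 32-sample spectrogram window (so the delta band contains a single bin, near 3.2 Hz) and 1 Hz for the 103-sample window used by `compute_band_power`. Band edges are inclusive in `compute_band_power`, so a bin lying exactly on a boundary (e.g. 4 Hz) is counted in both adjacent bands.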

In [54]:
freq_bands = {
    'Delta': (0.5, 4),
    'Theta': (4, 8),
    'Alpha': (8, 13),
    'Beta': (13, 30),
    'Gamma': (30, 50)
}

band_colors = {
    'Delta': '#1f77b4',
    'Theta': '#ff7f0e',
    'Alpha': '#2ca02c',
    'Beta': '#d62728',
    'Gamma': '#9467bd'
}

## 3. Single Trial Spectrogram Analysis

In [55]:
def compute_spectrogram(signal_data, fs=SAMPLING_RATE, nperseg=None, noverlap=None):
    if nperseg is None:
        nperseg = min(32, len(signal_data))
    if noverlap is None:
        noverlap = nperseg // 2
    f, t, Sxx = spectrogram(signal_data, fs=fs, nperseg=nperseg, 
                            noverlap=noverlap, window='hann')
    return f, t, Sxx


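def compute_band_power(signal_data, fs, band_range):
    """Relative power in band_range (inclusive edges) from the time-averaged
    squared STFT magnitude. Units are arbitrary, not a calibrated PSD.
    A band containing a single frequency bin integrates to 0."""
    freqs, t, Zxx = stft(signal_data, fs=fs, nperseg=min(256, len(signal_data)))
    psd = np.abs(Zxx) ** 2
    psd_mean = np.mean(psd, axis=1)  # average over STFT time segments
    idx_band = np.logical_and(freqs >= band_range[0], freqs <= band_range[1])
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    band_power = trapezoid(psd_mean[idx_band], freqs[idx_band])
    return band_power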


def compute_average_spectrogram(trials_data, electrode_idx, trial_indices=None, max_trials=1000):
    if trial_indices is None:
        trial_indices = np.arange(trials_data.shape[1])
    if len(trial_indices) > max_trials:
        trial_indices = np.random.choice(trial_indices, max_trials, replace=False)
    
    spectrograms = []
    for trial_idx in trial_indices:
        signal_data = trials_data[electrode_idx, trial_idx, :]
        f, t, Sxx = compute_spectrogram(signal_data)
        spectrograms.append(Sxx)
    
    Sxx_avg = np.mean(spectrograms, axis=0)
    return f, t, Sxx_avg


def compute_all_band_powers(trials_data, electrode_idx, trial_indices=None, max_trials=1000):
    if trial_indices is None:
        trial_indices = np.arange(trials_data.shape[1])
    if len(trial_indices) > max_trials:
        trial_indices = np.random.choice(trial_indices, max_trials, replace=False)
    
    band_powers = {band: [] for band in freq_bands.keys()}
    
    for trial_idx in trial_indices:
        signal_data = trials_data[electrode_idx, trial_idx, :]
        for band, band_range in freq_bands.items():
            power = compute_band_power(signal_data, SAMPLING_RATE, band_range)
            band_powers[band].append(power)
    
    band_powers = {band: np.array(powers) for band, powers in band_powers.items()}
    return band_powers
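For comparison, a common way to get band power from one short trial is Welch's method, which with a single segment reduces to a Hann-windowed periodogram in units²/Hz. This swaps in a different estimator than `compute_band_power` (which averages squared STFT magnitudes from zero-padded boundary segments), so absolute values will not match the notebook's bar plots. `welch_band_power` is a hypothetical helper; it uses half-open bands so no bin is counted twice.

```python
import numpy as np
from scipy.signal import welch

FS = 103  # Hz, as SAMPLING_RATE

def welch_band_power(x, fs, band):
    """Band power from a one-segment Welch PSD (a Hann-windowed periodogram).

    Integrates the PSD over the half-open interval [low, high) with the
    rectangle rule, so a boundary bin is never counted in two bands.
    """
    f, psd = welch(x, fs=fs, window='hann', nperseg=len(x))
    df = f[1] - f[0]
    in_band = (f >= band[0]) & (f < band[1])
    return float(np.sum(psd[in_band]) * df)

bands = {'Delta': (0.5, 4), 'Theta': (4, 8), 'Alpha': (8, 13),
         'Beta': (13, 30), 'Gamma': (30, 50)}

# Synthetic 1-second trial: 10 Hz (alpha) sinusoid plus weak noise
rng = np.random.default_rng(0)
t = np.arange(103) / FS
x = np.sin(2 * np.pi * 10 * t) + 0.1 * rng.normal(size=t.size)

powers = {name: welch_band_power(x, FS, edges) for name, edges in bands.items()}
print(max(powers, key=powers.get))  # Alpha
```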

In [63]:
def interactive_averaged_similar_trials(electrode_name, prime, cue, label, 
                                       freq_min, freq_max, show_bands, 
                                       nperseg, colormap, n_trials_to_avg):
    """
    Interactive visualization of averaged spectrograms for trials with the same prime, cue, and label.
    
    Parameters:
    -----------
    electrode_name : str
        Name of the electrode to analyze
    prime : float
        Prime condition value
    cue : float
        Cue condition value
    label : float
        Label value (1=Left, 2=Right)
    freq_min, freq_max : float
        Frequency range to display
    show_bands : bool
        Whether to show frequency band markers
    nperseg : int
        Window size for spectrogram
    colormap : str
        Colormap for the spectrogram
    n_trials_to_avg : int
        Maximum number of trials to average
    """
    # Get electrode index
    electrode_idx = np.where(frontal_central_names == electrode_name)[0][0]
    
    # Find trials with matching conditions
    matching_trials = np.where(
        (primes == prime) & 
        (cues == cue) & 
        (true_labels == label)
    )[0]
    
    n_matching = len(matching_trials)
    
    if n_matching == 0:
        print(f"No trials found with Prime={int(prime)}, Cue={int(cue)}, Label={int(label)}")
        return
    
    # Limit number of trials if needed
    if n_matching > n_trials_to_avg:
        selected_trials = np.random.choice(matching_trials, n_trials_to_avg, replace=False)
    else:
        selected_trials = matching_trials
    
    # Compute individual spectrograms for all selected trials
    spectrograms = []
    signals = []
    
    for trial_idx in selected_trials:
        signal_data = trials_clean_filtered[electrode_idx, trial_idx, :]
        signals.append(signal_data)
        f, t, Sxx = compute_spectrogram(signal_data, nperseg=nperseg)
        spectrograms.append(Sxx)
    
    # Average the spectrograms
    Sxx_avg = np.mean(spectrograms, axis=0)
    Sxx_std = np.std(spectrograms, axis=0)
    Sxx_avg_db = 10 * np.log10(Sxx_avg + 1e-10)
    Sxx_std_db = 10 * np.log10(Sxx_std + 1e-10)
    
    # Average the time series
    signals_avg = np.mean(signals, axis=0)
    signals_std = np.std(signals, axis=0)
    
    # Create the plot with 5 subplots
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    # Plot 1: Averaged time series
    time_axis = np.arange(len(signals_avg)) / SAMPLING_RATE
    axes[0, 0].plot(time_axis, signals_avg, linewidth=1.5, color='steelblue', label='Mean')
    axes[0, 0].fill_between(time_axis, 
                           signals_avg - signals_std, 
                           signals_avg + signals_std,
                           alpha=0.3, color='steelblue', label='±1 SD')
    axes[0, 0].set_xlabel('Time (s)')
    axes[0, 0].set_ylabel('Amplitude (z-score)' if NORMALIZATION_METHOD.startswith('zscore') else 'Amplitude (μV)')
    axes[0, 0].set_title(f'Averaged Signal: {electrode_name} (n={len(selected_trials)} trials)', fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].legend(loc='upper right')
    
    # Add condition information
    condition_text = f"Prime: {int(prime)} | Cue: {int(cue)} | Label: {int(label)}\nTotal matching: {n_matching} | Averaged: {len(selected_trials)}"
    axes[0, 0].text(0.02, 0.95, condition_text, transform=axes[0, 0].transAxes, 
                   verticalalignment='top', fontsize=9,
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))
    
    # Plot 2: Averaged spectrogram
    im1 = axes[0, 1].pcolormesh(t, f, Sxx_avg_db, shading='gouraud', cmap=colormap)
    axes[0, 1].set_ylabel('Frequency (Hz)')
    axes[0, 1].set_xlabel('Time (s)')
    axes[0, 1].set_title('Averaged Spectrogram (dB)', fontweight='bold')
    axes[0, 1].set_ylim([freq_min, min(freq_max, NYQUIST_FREQ)])
    
    if show_bands:
        for band, (low, high) in freq_bands.items():
            if low >= freq_min and high <= freq_max:
                axes[0, 1].axhline(y=low, color='white', linestyle='--', alpha=0.5, linewidth=1)
                axes[0, 1].axhline(y=high, color='white', linestyle='--', alpha=0.5, linewidth=1)
                axes[0, 1].text(0.02, (low + high) / 2, band, transform=axes[0, 1].get_yaxis_transform(),
                              color='white', fontsize=9, fontweight='bold',
                              bbox=dict(boxstyle='round', facecolor='black', alpha=0.5))
    
    plt.colorbar(im1, ax=axes[0, 1], label='Power (dB)')
    
    # Plot 3: Standard deviation of spectrogram
    im2 = axes[0, 2].pcolormesh(t, f, Sxx_std_db, shading='gouraud', cmap='YlOrRd')
    axes[0, 2].set_ylabel('Frequency (Hz)')
    axes[0, 2].set_xlabel('Time (s)')
    axes[0, 2].set_title('Std Dev of Spectrogram (dB)', fontweight='bold')
    axes[0, 2].set_ylim([freq_min, min(freq_max, NYQUIST_FREQ)])
    plt.colorbar(im2, ax=axes[0, 2], label='Std Dev (dB)')
    
    # Plot 4: Time-averaged power spectrum, with SD across trials
    trial_spectra = np.mean(spectrograms, axis=2)  # (n_trials, n_freqs)
    avg_power = trial_spectra.mean(axis=0)
    std_power = trial_spectra.std(axis=0)
    
    axes[1, 0].plot(f, avg_power, linewidth=2, color='darkred', label='Mean Power')
    axes[1, 0].fill_between(f, avg_power - std_power, avg_power + std_power, 
                           alpha=0.3, color='red', label='±1 SD (across trials)')
    axes[1, 0].set_xlabel('Frequency (Hz)')
    axes[1, 0].set_ylabel('Power')
    axes[1, 0].set_title('Averaged Power Spectral Density', fontweight='bold')
    axes[1, 0].set_xlim([freq_min, min(freq_max, NYQUIST_FREQ)])
    axes[1, 0].grid(True, alpha=0.3)
    
    if show_bands:
        for band, (low, high) in freq_bands.items():
            if low >= freq_min and high <= freq_max:
                axes[1, 0].axvspan(low, high, alpha=0.15, color=band_colors[band], label=band)
        axes[1, 0].legend(loc='upper right', fontsize=8)
    
    # Plot 5: Band power analysis
    band_powers_mean = []
    band_powers_std = []
    band_names = []
    
    for band, band_range in freq_bands.items():
        if band_range[0] >= freq_min and band_range[1] <= freq_max:
            # Compute band power for each trial
            powers = []
            for trial_idx in selected_trials:
                signal_data = trials_clean_filtered[electrode_idx, trial_idx, :]
                power = compute_band_power(signal_data, SAMPLING_RATE, band_range)
                powers.append(power)
            
            band_powers_mean.append(np.mean(powers))
            band_powers_std.append(np.std(powers))
            band_names.append(band)
    
    x_pos = np.arange(len(band_names))
    bars = axes[1, 1].bar(x_pos, band_powers_mean, yerr=band_powers_std, 
                         capsize=5, alpha=0.7, 
                         color=[band_colors[b] for b in band_names],
                         edgecolor='black', linewidth=1.5)
    axes[1, 1].set_xlabel('Frequency Band')
    axes[1, 1].set_ylabel('Power')
    axes[1, 1].set_title('Band Power Distribution', fontweight='bold')
    axes[1, 1].set_xticks(x_pos)
    axes[1, 1].set_xticklabels(band_names, rotation=45, ha='right')
    axes[1, 1].grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for bar, mean, std in zip(bars, band_powers_mean, band_powers_std):
        height = bar.get_height()
        axes[1, 1].text(bar.get_x() + bar.get_width()/2., height + std,
                       f'{mean:.1f}',
                       ha='center', va='bottom', fontsize=8)
    
    # Remove the unused subplot
    fig.delaxes(axes[1, 2])
    
    plt.tight_layout()
    plt.show()

# Get unique values for dropdowns
unique_primes = sorted(np.unique(primes))
unique_cues = sorted(np.unique(cues))
unique_labels = sorted(np.unique(true_labels))

interact(
    interactive_averaged_similar_trials,
    electrode_name=widgets.Dropdown(
        options=list(frontal_central_names),
        value=frontal_central_names[0],
        description='Electrode:',
        style={'description_width': '150px'}
    ),
    prime=widgets.Dropdown(
        options=unique_primes,
        value=unique_primes[0],
        description='Prime:',
        style={'description_width': '150px'}
    ),
    cue=widgets.Dropdown(
        options=unique_cues,
        value=unique_cues[0],
        description='Cue:',
        style={'description_width': '150px'}
    ),
    label=widgets.Dropdown(
        options=unique_labels,
        value=unique_labels[0],
        description='Label (1=L, 2=R):',
        style={'description_width': '150px'}
    ),
    freq_min=widgets.FloatSlider(
        value=0,
        min=0,
        max=25,
        step=0.5,
        description='Min Freq (Hz):',
        style={'description_width': '150px'},
        continuous_update=False
    ),
    freq_max=widgets.FloatSlider(
        value=50,
        min=5,
        max=NYQUIST_FREQ,
        step=1,
        description='Max Freq (Hz):',
        style={'description_width': '150px'},
        continuous_update=False
    ),
    show_bands=widgets.Checkbox(
        value=True,
        description='Show Bands',
        style={'description_width': '150px'}
    ),
    nperseg=widgets.IntSlider(
        value=32,
        min=16,
        max=64,
        step=8,
        description='Window Size:',
        style={'description_width': '150px'},
        continuous_update=False
    ),
    colormap=widgets.Dropdown(
        options=['viridis', 'plasma', 'inferno', 'magma', 'cividis', 'jet', 'hot', 'cool'],
        value='viridis',
        description='Colormap:',
        style={'description_width': '150px'}
    ),
    n_trials_to_avg=widgets.IntSlider(
        value=100,
        min=10,
        max=1000,
        step=10,
        description='Max Trials:',
        style={'description_width': '150px'},
        continuous_update=False
    )
);


## 4. Multi-Electrode Spectrogram Comparison

### 4.1 Interactive Multi-Electrode Comparison

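Compare the raw signal and spectrogram of a single trial at three electrodes side by side. The three spectrograms share one color scale (`vmin`/`vmax` taken over all three panels), so power can be compared directly across electrodes.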
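In the cell below, `vmin` is the minimum over all three panels, so one zero-filled channel (sitting at the -100 dB floor) stretches the shared color scale and compresses the contrast in the other two. A possible refinement, offered as an assumption rather than the notebook's behaviour, is to ignore floor values and take pooled percentiles. `shared_db_limits` is a hypothetical helper whose result could replace the `vmin`/`vmax` lines.

```python
import numpy as np

def shared_db_limits(spectra_db, floor_db=-100.0, lower_pct=1, upper_pct=99):
    """Common (vmin, vmax) for several dB spectrograms.

    Values at the power floor (zero-filled data, 10*log10(1e-10) = -100 dB)
    are ignored; the rest are pooled and clipped to the given percentiles.
    """
    pooled = np.concatenate([np.ravel(s) for s in spectra_db])
    pooled = pooled[pooled > floor_db + 1e-6]
    if pooled.size == 0:  # every panel is flat: fall back to a 1 dB window
        return floor_db, floor_db + 1.0
    return (float(np.percentile(pooled, lower_pct)),
            float(np.percentile(pooled, upper_pct)))

rng = np.random.default_rng(1)
panel_a = rng.uniform(-20, 0, size=(17, 5))
panel_b = rng.uniform(-30, -5, size=(17, 5))
flat = np.full((17, 5), 10 * np.log10(0.0 + 1e-10))  # zero-filled trial
vmin, vmax = shared_db_limits([panel_a, panel_b, flat])
print(round(vmin, 1), round(vmax, 1))
```

The flat panel then renders at the bottom of the color scale instead of setting it.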

In [57]:
def interactive_multi_electrode_comparison(electrode1, electrode2, electrode3, 
                                          trial_idx, freq_max, colormap):
    electrode_names_list = [electrode1, electrode2, electrode3]
    electrode_indices = [np.where(frontal_central_names == e)[0][0] for e in electrode_names_list]
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    all_Sxx = []
    for idx in electrode_indices:
        signal_data = trials_clean_filtered[idx, trial_idx, :]
        f, t, Sxx = compute_spectrogram(signal_data)
        all_Sxx.append(Sxx)
    
    all_Sxx_db = [10 * np.log10(Sxx + 1e-10) for Sxx in all_Sxx]
    vmin = np.min([Sxx_db.min() for Sxx_db in all_Sxx_db])
    vmax = np.max([Sxx_db.max() for Sxx_db in all_Sxx_db])
    
    for i, (idx, elec_name) in enumerate(zip(electrode_indices, electrode_names_list)):
        signal_data = trials_clean_filtered[idx, trial_idx, :]
        Sxx_db = all_Sxx_db[i]
        
        time_axis = np.arange(len(signal_data)) / SAMPLING_RATE
        axes[0, i].plot(time_axis, signal_data, linewidth=0.8, color=f'C{i}')
        axes[0, i].set_title(f'{elec_name}', fontweight='bold')
        axes[0, i].set_xlabel('Time (s)')
        if i == 0:
            axes[0, i].set_ylabel('Amplitude (z-score)' if NORMALIZATION_METHOD.startswith('zscore') else 'Amplitude (μV)')
        axes[0, i].grid(True, alpha=0.3)
        
        im = axes[1, i].pcolormesh(t, f, Sxx_db, shading='gouraud', cmap=colormap,
                                   vmin=vmin, vmax=vmax)
        axes[1, i].set_xlabel('Time (s)')
        if i == 0:
            axes[1, i].set_ylabel('Frequency (Hz)')
        axes[1, i].set_ylim([0, min(freq_max, NYQUIST_FREQ)])
    
    fig.colorbar(im, ax=axes[1, :], label='Power (dB)', pad=0.02)
    
    fig.suptitle(f'Trial {trial_idx} | Label: {int(true_labels[trial_idx])} | Cue: {int(cues[trial_idx])} | Prime: {int(primes[trial_idx])}', 
                 fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

interact(
    interactive_multi_electrode_comparison,
    electrode1=widgets.Dropdown(
        options=list(frontal_central_names),
        value='Fz' if 'Fz' in frontal_central_names else frontal_central_names[0],
        description='Electrode 1:',
        style={'description_width': '100px'}
    ),
    electrode2=widgets.Dropdown(
        options=list(frontal_central_names),
        value='FCz' if 'FCz' in frontal_central_names else frontal_central_names[1],
        description='Electrode 2:',
        style={'description_width': '100px'}
    ),
    electrode3=widgets.Dropdown(
        options=list(frontal_central_names),
        value='Cz' if 'Cz' in frontal_central_names else frontal_central_names[2],
        description='Electrode 3:',
        style={'description_width': '100px'}
    ),
    trial_idx=widgets.IntSlider(
        value=1000,
        min=0,
        max=trials_clean_filtered.shape[1] - 1,
        step=1,
        description='Trial:',
        style={'description_width': '100px'},
        continuous_update=False
    ),
    freq_max=widgets.FloatSlider(
        value=50,
        min=10,
        max=NYQUIST_FREQ,
        step=5,
        description='Max Freq (Hz):',
        style={'description_width': '100px'},
        continuous_update=False
    ),
    colormap=widgets.Dropdown(
        options=['viridis', 'plasma', 'inferno', 'magma', 'cividis'],
        value='viridis',
        description='Colormap:',
        style={'description_width': '100px'}
    )
);


## 6. Condition-Dependent Spectrogram Analysis

### 6.1 Interactive Condition Comparison

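Compare trial-averaged spectrograms and time-averaged power spectra at one electrode between two levels of a condition (label, cue, or prime). Up to `n_trials` trials per condition are drawn at random without a fixed seed, so results vary slightly between runs.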
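A condition difference is easiest to read as a signed contrast. Subtracting two dB maps equals 10*log10 of the power ratio, so positive values mean more power in condition 2 and negative values mean less, which is what a diverging colormap such as `RdBu_r` expects. Taking the dB of the absolute linear difference discards the sign. A numeric check on made-up per-bin powers; `to_db` mirrors the notebook's 1e-10 floor.

```python
import numpy as np

def to_db(power):
    """Convert power to dB with the notebook's 1e-10 floor."""
    return 10 * np.log10(power + 1e-10)

p1 = np.array([1.0, 2.0, 4.0])   # condition 1 power per bin (arbitrary units)
p2 = np.array([2.0, 2.0, 1.0])   # condition 2 power per bin

signed_db = to_db(p2) - to_db(p1)      # 10*log10(p2/p1): sign shows direction
magnitude_db = to_db(np.abs(p2 - p1))  # dB of |p2 - p1|: sign information lost

print(np.round(signed_db, 2))
print(np.round(magnitude_db, 2))
```

In the third bin condition 2 has a quarter of the power, yet the magnitude version reports a positive value (about 4.8 dB), while the signed version correctly reports about -6 dB.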

In [58]:
def interactive_condition_comparison(electrode_name, condition_type, condition_val1, 
                                    condition_val2, n_trials, freq_max, colormap):
    electrode_idx = np.where(frontal_central_names == electrode_name)[0][0]
    
    if condition_type == 'Label (Left/Right)':
        condition_labels = true_labels
        cond_name1, cond_name2 = f"Label {condition_val1}", f"Label {condition_val2}"
    elif condition_type == 'Cue':
        condition_labels = cues
        cond_name1, cond_name2 = f"Cue {condition_val1}", f"Cue {condition_val2}"
    else:
        condition_labels = primes
        cond_name1, cond_name2 = f"Prime {condition_val1}", f"Prime {condition_val2}"
    
    cond1_trials = np.where(condition_labels == condition_val1)[0]
    cond2_trials = np.where(condition_labels == condition_val2)[0]
    
    if len(cond1_trials) == 0 or len(cond2_trials) == 0:
        print(f"Error: No trials for condition {condition_val1} or {condition_val2}")
        return
    
    f, t, Sxx_avg1 = compute_average_spectrogram(trials_clean_filtered, electrode_idx, 
                                                  cond1_trials, n_trials)
    _, _, Sxx_avg2 = compute_average_spectrogram(trials_clean_filtered, electrode_idx, 
                                                  cond2_trials, n_trials)
    
    Sxx_avg1_db = 10 * np.log10(Sxx_avg1 + 1e-10)
    Sxx_avg2_db = 10 * np.log10(Sxx_avg2 + 1e-10)
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    im1 = axes[0, 0].pcolormesh(t, f, Sxx_avg1_db, shading='gouraud', cmap=colormap)
    axes[0, 0].set_title(f'{cond_name1}', fontweight='bold')
    axes[0, 0].set_ylabel('Frequency (Hz)')
    axes[0, 0].set_xlabel('Time (s)')
    axes[0, 0].set_ylim([0, min(freq_max, NYQUIST_FREQ)])
    plt.colorbar(im1, ax=axes[0, 0], label='Power (dB)')
    
    im2 = axes[0, 1].pcolormesh(t, f, Sxx_avg2_db, shading='gouraud', cmap=colormap)
    axes[0, 1].set_title(f'{cond_name2}', fontweight='bold')
    axes[0, 1].set_xlabel('Time (s)')
    axes[0, 1].set_ylim([0, min(freq_max, NYQUIST_FREQ)])
    plt.colorbar(im2, ax=axes[0, 1], label='Power (dB)')
    
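    # Signed contrast in dB (= 10*log10 of the power ratio): positive where condition 2 has more power
    diff_db = Sxx_avg2_db - Sxx_avg1_db
    vlim = max(np.max(np.abs(diff_db)), 1e-6)  # symmetric limits so 0 dB maps to the colormap centre
    im3 = axes[0, 2].pcolormesh(t, f, diff_db, shading='gouraud', cmap='RdBu_r',
                                vmin=-vlim, vmax=vlim)
    axes[0, 2].set_title(f'Difference: {cond_name2} - {cond_name1}', fontweight='bold')
    axes[0, 2].set_xlabel('Time (s)')
    axes[0, 2].set_ylim([0, min(freq_max, NYQUIST_FREQ)])
    plt.colorbar(im3, ax=axes[0, 2], label='Diff (dB)')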
    
    avg_power1 = np.mean(Sxx_avg1, axis=1)
    avg_power2 = np.mean(Sxx_avg2, axis=1)
    
    axes[1, 0].plot(f, avg_power1, linewidth=2, label=cond_name1, color='blue')
    axes[1, 0].fill_between(f, avg_power1, alpha=0.3, color='blue')
    axes[1, 0].set_xlabel('Frequency (Hz)')
    axes[1, 0].set_ylabel('Power')
    axes[1, 0].set_xlim([0, min(freq_max, NYQUIST_FREQ)])
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].legend()
    
    axes[1, 1].plot(f, avg_power2, linewidth=2, label=cond_name2, color='red')
    axes[1, 1].fill_between(f, avg_power2, alpha=0.3, color='red')
    axes[1, 1].set_xlabel('Frequency (Hz)')
    axes[1, 1].set_xlim([0, min(freq_max, NYQUIST_FREQ)])
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].legend()
    
    axes[1, 2].plot(f, avg_power1, linewidth=2, label=cond_name1, color='blue', alpha=0.7)
    axes[1, 2].plot(f, avg_power2, linewidth=2, label=cond_name2, color='red', alpha=0.7)
    axes[1, 2].fill_between(f, avg_power1, alpha=0.2, color='blue')
    axes[1, 2].fill_between(f, avg_power2, alpha=0.2, color='red')
    axes[1, 2].set_xlabel('Frequency (Hz)')
    axes[1, 2].set_xlim([0, min(freq_max, NYQUIST_FREQ)])
    axes[1, 2].grid(True, alpha=0.3)
    axes[1, 2].legend()
    axes[1, 2].set_title('Overlay', fontweight='bold')
    
    for band, (low, high) in freq_bands.items():
        if low < freq_max:
            axes[1, 2].axvspan(low, high, alpha=0.1, color=band_colors[band])
    
    fig.suptitle(f'{electrode_name}', fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

interact(
    interactive_condition_comparison,
    electrode_name=widgets.Dropdown(
        options=list(frontal_central_names),
        value='FCz' if 'FCz' in frontal_central_names else frontal_central_names[0],
        description='Electrode:',
        style={'description_width': '130px'}
    ),
    condition_type=widgets.Dropdown(
        options=['Label (Left/Right)', 'Cue', 'Prime'],
        value='Label (Left/Right)',
        description='Type:',
        style={'description_width': '130px'}
    ),
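    condition_val1=widgets.Dropdown(
        # All values seen in labels, cues and primes (Cue 4 and Prime 3 included);
        # a value absent from the chosen condition type prints an error message
        options=np.unique(np.concatenate([true_labels, cues, primes])).tolist(),
        value=1.0,
        description='Condition 1:',
        style={'description_width': '130px'}
    ),
    condition_val2=widgets.Dropdown(
        options=np.unique(np.concatenate([true_labels, cues, primes])).tolist(),
        value=2.0,
        description='Condition 2:',
        style={'description_width': '130px'}
    ),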
    n_trials=widgets.IntSlider(
        value=500,
        min=50,
        max=2000,
        step=50,
        description='Trials:',
        style={'description_width': '130px'},
        continuous_update=False
    ),
    freq_max=widgets.FloatSlider(
        value=50,
        min=10,
        max=NYQUIST_FREQ,
        step=5,
        description='Max Freq (Hz):',
        style={'description_width': '130px'},
        continuous_update=False
    ),
    colormap=widgets.Dropdown(
        options=['viridis', 'plasma', 'inferno', 'magma', 'cividis'],
        value='viridis',
        description='Colormap:',
        style={'description_width': '130px'}
    )
);


## 7. Frequency Band Power Analysis

### 7.1 Interactive Band Power Analyzer

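Compute power within a user-defined frequency range for every level of the selected condition type, using up to `n_trials` trials per level. With `show_distribution` enabled, the analyzer also runs pairwise independent-samples (Student's) t-tests and reports Cohen's d; the p-values are not corrected for multiple comparisons.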
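The statistics table computes Cohen's d from the plain average of the two population variances, which approximates the conventional pooled SD only when both conditions contribute about the same number of trials. A sketch of the degrees-of-freedom weighted version, paired with Welch's t-test (`scipy.stats.ttest_ind(..., equal_var=False)`), which drops the equal-variance assumption of the default Student's test. `cohens_d` is an illustrative helper, and using Welch's test is a suggested alternative, not the notebook's method.

```python
import numpy as np
from scipy import stats

def cohens_d(a, b):
    """Cohen's d with the degrees-of-freedom weighted pooled standard deviation."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    na, nb = a.size, b.size
    pooled_var = ((na - 1) * a.var(ddof=1) + (nb - 1) * b.var(ddof=1)) / (na + nb - 2)
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)

group1 = [1.0, 2.0, 3.0]
group2 = [2.0, 3.0, 4.0]
d = cohens_d(group1, group2)  # -1.0: means differ by one pooled SD
t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)  # Welch's t-test
print(round(d, 3), round(t_stat, 3), round(p_value, 3))
```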

In [None]:
def interactive_band_power_analyzer(electrode_name, freq_min, freq_max, n_trials, 
                                   condition_type, show_distribution):
    electrode_idx = np.where(frontal_central_names == electrode_name)[0][0]
    
    if condition_type == 'Label (Left/Right)':
        condition_labels = true_labels
        condition_values = sorted(np.unique(true_labels))
        condition_names = [f"Label {int(v)}" for v in condition_values]
    elif condition_type == 'Cue':
        condition_labels = cues
        condition_values = sorted(np.unique(cues))
        condition_names = [f"Cue {int(v)}" for v in condition_values]
    else:
        condition_labels = primes
        condition_values = sorted(np.unique(primes))
        condition_names = [f"Prime {int(v)}" for v in condition_values]
    
    custom_band = (freq_min, freq_max)
    
    all_powers = {}
    for cond_val, cond_name in zip(condition_values, condition_names):
        cond_trials = np.where(condition_labels == cond_val)[0]
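        # Random subset, consistent with the other analyses (a first-n slice may reflect recording order)
        cond_trials_sample = np.random.choice(cond_trials, min(n_trials, len(cond_trials)), replace=False)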
        
        powers = []
        for trial_idx in cond_trials_sample:
            signal_data = trials_clean_filtered[electrode_idx, trial_idx, :]
            power = compute_band_power(signal_data, SAMPLING_RATE, custom_band)
            powers.append(power)
        
        all_powers[cond_name] = np.array(powers)
    
    if show_distribution:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    else:
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
        axes = list(axes) + [None]
    
    means = [np.mean(all_powers[name]) for name in condition_names]
    stds = [np.std(all_powers[name]) for name in condition_names]
    colors = plt.cm.Set2(np.linspace(0, 1, len(condition_names)))
    
    x = np.arange(len(condition_names))
    bars = axes[0].bar(x, means, yerr=stds, capsize=10, alpha=0.7, 
                       color=colors, edgecolor='black', linewidth=1.5)
    axes[0].set_xlabel('Condition', fontweight='bold')
    axes[0].set_ylabel('Power', fontweight='bold')
    axes[0].set_title(f'{freq_min}-{freq_max} Hz | {electrode_name}', fontweight='bold')
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(condition_names)
    axes[0].grid(True, alpha=0.3, axis='y')
    
    for bar, mean, std in zip(bars, means, stds):
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height + std,
                    f'{mean:.2f}±{std:.2f}',
                    ha='center', va='bottom', fontsize=9, fontweight='bold')
    
    positions = np.arange(len(condition_names))
    parts = axes[1].violinplot([all_powers[name] for name in condition_names], 
                               positions=positions, showmeans=True, showmedians=True)
    
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i])
        pc.set_alpha(0.7)
        pc.set_edgecolor('black')
        pc.set_linewidth(1.5)
    
    axes[1].set_xlabel('Condition', fontweight='bold')
    axes[1].set_ylabel('Distribution', fontweight='bold')
    axes[1].set_title('Power Distribution', fontweight='bold')
    axes[1].set_xticks(positions)
    axes[1].set_xticklabels(condition_names)
    axes[1].grid(True, alpha=0.3, axis='y')
    
    if show_distribution and len(condition_names) >= 2:
        from scipy import stats as sp_stats
        
        results = []
        for i in range(len(condition_names)):
            for j in range(i + 1, len(condition_names)):
                name1, name2 = condition_names[i], condition_names[j]
                powers1, powers2 = all_powers[name1], all_powers[name2]
                
                t_stat, p_value = sp_stats.ttest_ind(powers1, powers2)
                pooled_std = np.sqrt((np.var(powers1) + np.var(powers2)) / 2)
                cohens_d = (np.mean(powers1) - np.mean(powers2)) / pooled_std
                
                results.append({
                    'Comparison': f'{name1} vs {name2}',
                    'Mean Diff': np.mean(powers1) - np.mean(powers2),
                    't-stat': t_stat,
                    'p-value': p_value,
                    "Cohen's d": cohens_d,
                    'Sig': '***' if p_value < 0.001 else '**' if p_value < 0.01 else '*' if p_value < 0.05 else 'ns'
                })
        
        df_results = pd.DataFrame(results)
        axes[2].axis('tight')
        axes[2].axis('off')
        
        table_data = []
        for _, row in df_results.iterrows():
            table_data.append([
                row['Comparison'],
                f"{row['Mean Diff']:.3f}",
                f"{row['t-stat']:.3f}",
                f"{row['p-value']:.4f}",
                format(row["Cohen's d"], '.3f'),  # avoids a backslash inside an f-string (SyntaxError before Python 3.12)
                row['Sig']
            ])
        
        table = axes[2].table(cellText=table_data,
                            colLabels=['Comparison', 'Diff', 't-stat', 'p-value', "Cohen's d", 'Sig'],
                            cellLoc='center',
                            loc='center',
                            bbox=[0, 0, 1, 1])
        
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1, 2)
        
        for i, row in df_results.iterrows():
            if row['p-value'] < 0.05:
                table[(i+1, 5)].set_facecolor('#90EE90')
            else:
                table[(i+1, 5)].set_facecolor('#FFB6C1')
        
        axes[2].set_title('Statistics', fontweight='bold')
        
        print(df_results.to_string(index=False))
        print("\nSig: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")
    
    plt.tight_layout()
    plt.show()

interact(
    interactive_band_power_analyzer,
    electrode_name=widgets.Dropdown(
        options=list(frontal_central_names),
        value='Fz' if 'Fz' in frontal_central_names else frontal_central_names[0],
        description='Electrode:',
        style={'description_width': '140px'}
    ),
    freq_min=widgets.FloatSlider(
        value=13,
        min=0.5,
        max=45,
        step=0.5,
        description='Min Freq (Hz):',
        style={'description_width': '140px'},
        continuous_update=False
    ),
    freq_max=widgets.FloatSlider(
        value=30,
        min=5,
        max=NYQUIST_FREQ,
        step=1,
        description='Max Freq (Hz):',
        style={'description_width': '140px'},
        continuous_update=False
    ),
    n_trials=widgets.IntSlider(
        value=500,
        min=100,
        max=2000,
        step=100,
        description='Trials:',
        style={'description_width': '140px'},
        continuous_update=False
    ),
    condition_type=widgets.Dropdown(
        options=['Label (Left/Right)', 'Cue', 'Prime'],
        value='Label (Left/Right)',
        description='Type:',
        style={'description_width': '140px'}
    ),
    show_distribution=widgets.Checkbox(
        value=True,
        description='Show Stats',
        style={'description_width': '140px'}
    )
);

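The analyzer above and `statistical_band_comparison` both compute Cohen's d with `np.sqrt((np.var(a) + np.var(b)) / 2)` as the denominator. That is an equal-weight mean of the two population variances (`ddof=0`). The textbook pooled standard deviation instead weights each group's sample variance (`ddof=1`) by its degrees of freedom. The sketch below contrasts the two on small made-up samples. `cohens_d_pooled` and `cohens_d_equal_weight` are illustrative helper names and are not defined elsewhere in this notebook.

```python
import numpy as np


def cohens_d_pooled(a, b):
    """Cohen's d with the degrees-of-freedom-weighted pooled SD (illustrative helper)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n_a, n_b = a.size, b.size
    # Sample variances (ddof=1), each weighted by its degrees of freedom
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)


def cohens_d_equal_weight(a, b):
    """The variant used in this notebook: mean of the two ddof=0 variances."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    pooled_std = np.sqrt((np.var(a) + np.var(b)) / 2)
    return (a.mean() - b.mean()) / pooled_std


# Small made-up samples (not EEG data)
left = [1.0, 2.0, 3.0, 4.0, 5.0]
right = [3.0, 4.0, 5.0, 6.0, 7.0]
print(cohens_d_pooled(left, right))        # about -1.265
print(cohens_d_equal_weight(left, right))  # about -1.414
```

For equal group sizes n, the two variance estimates differ only by a factor of n/(n-1). With hundreds or thousands of trials per condition the choice barely matters. The gap grows for small samples, or when the groups differ in both size and variance.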

## 8. Band Power Across Frontal and Central Electrodes

In [60]:
def plot_topographic_spectrogram(trials_data, electrode_names, trial_idx, 
                                 time_window=None, freq_range=(8, 13),
                                 figsize=(16, 12)):
    """
    Rank FRONTAL AND CENTRAL electrodes by band power for a single trial.
    Shows the spatial distribution of decision-making related activity as
    ranked bar charts (one panel per band in `freq_bands`, plus one for
    `freq_range`). Despite the name, no scalp topography is drawn.
    
    Parameters:
    -----------
    trials_data : ndarray
        EEG data shaped (n_electrodes, n_trials, n_samples).
    electrode_names : array-like
        Electrode labels aligned with the first axis of `trials_data`.
    trial_idx : int
        Index of the trial to analyze.
    time_window : tuple, optional
        (start_time, end_time) in seconds. If None, uses entire trial.
    freq_range : tuple
        (low_freq, high_freq) in Hz for the extra custom-band panel.
    """
    n_electrodes = trials_data.shape[0]
    electrode_names = np.asarray(electrode_names)  # needed for fancy indexing below
    top_n = min(15, n_electrodes)  # never request more bars than electrodes
    
    def electrode_band_powers(band_range):
        """Band power of every electrode for the selected trial and time window."""
        band_powers = []
        for elec_idx in range(n_electrodes):
            signal_data = trials_data[elec_idx, trial_idx, :]
            
            # Extract time window if specified (seconds -> sample indices)
            if time_window is not None:
                start_idx = int(time_window[0] * SAMPLING_RATE)
                end_idx = int(time_window[1] * SAMPLING_RATE)
                signal_data = signal_data[start_idx:end_idx]
            
            band_powers.append(compute_band_power(signal_data, SAMPLING_RATE, band_range))
        return np.array(band_powers)
    
    # One panel per standard band, plus one for the custom freq_range
    # (previously computed but never plotted)
    panels = list(freq_bands.items()) + [('Custom', tuple(freq_range))]
    
    # Create figure
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    axes = axes.flatten()
    
    print(f"Computing band power for {n_electrodes} frontal/central electrodes...")
    for idx, (band, band_range) in enumerate(panels[:len(axes)]):
        band_powers = electrode_band_powers(band_range)
        
        # Sort by power (descending) and plot the top electrodes
        sorted_indices = np.argsort(band_powers)[::-1][:top_n]
        
        axes[idx].barh(np.arange(top_n), band_powers[sorted_indices],
                       color=band_colors.get(band, 'gray'), alpha=0.7)
        axes[idx].set_yticks(np.arange(top_n))
        axes[idx].set_yticklabels(electrode_names[sorted_indices])
        axes[idx].set_xlabel('Power')
        axes[idx].set_title(f'{band} ({band_range[0]}-{band_range[1]} Hz)')
        axes[idx].grid(True, alpha=0.3, axis='x')
        axes[idx].invert_yaxis()
    
    # Remove any unused subplots
    for ax in axes[len(panels):]:
        fig.delaxes(ax)
    
    time_str = f" (t={time_window[0]}-{time_window[1]}s)" if time_window else ""
    fig.suptitle(f'Power Distribution Across Frontal/Central Electrodes - Trial {trial_idx}{time_str}', 
                 fontsize=14, y=0.995)
    plt.tight_layout()
    return fig


# Example: Topographic view for a single trial (focusing on decision-making regions)
example_trial = np.random.randint(0, trials_clean_filtered.shape[1])
print(f"\nTopographic analysis for trial {example_trial}")
print(f"Analyzing power distribution across {len(frontal_central_names)} frontal/central electrodes")
print(f"Label: {true_labels[example_trial]}, Cue: {cues[example_trial]}, Prime: {primes[example_trial]}")


Topographic analysis for trial 28857
Analyzing power distribution across 30 frontal/central electrodes
Label: 2.0, Cue: 2.0, Prime: 3.0
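`plot_topographic_spectrogram` ranks electrodes by band power with horizontal bar charts rather than drawing a scalp map. It calls `compute_band_power` (defined earlier in the notebook) once per electrode and band. The same ranking can be sketched in vectorized form. This version assumes a single Hann-windowed FFT periodogram, averaged over the bins in a half-open band `[low, high)`. The notebook's `compute_band_power` may use a different estimator or scaling, so absolute values will not match. `band_power_fft` and `rank_electrodes` are hypothetical names.

```python
import numpy as np


def band_power_fft(signals, fs, band):
    """Mean Hann-windowed periodogram power in [low, high) Hz along the last axis.

    signals: array shaped (..., n_samples), e.g. (n_electrodes, n_samples).
    Assumption: a simple single-segment periodogram, not the notebook's own
    compute_band_power.
    """
    signals = np.asarray(signals, dtype=float)
    n = signals.shape[-1]
    window = np.hanning(n)
    spectrum = np.abs(np.fft.rfft(signals * window, axis=-1)) ** 2
    # Same values as np.fft.rfftfreq(n, d=1/fs), computed as k * fs / n
    freqs = np.arange(n // 2 + 1) * fs / n
    mask = (freqs >= band[0]) & (freqs < band[1])
    return spectrum[..., mask].mean(axis=-1)


def rank_electrodes(signals, names, fs, band, top_n=15):
    """Return (name, power) pairs sorted from highest to lowest band power."""
    powers = band_power_fft(signals, fs, band)
    order = np.argsort(powers)[::-1][:min(top_n, len(powers))]
    return [(names[i], powers[i]) for i in order]


# Synthetic check: three "electrodes" carrying a 10 Hz sine of amplitude 1, 2 and 3
fs = 103
t = np.arange(103) / fs
signals = np.stack([amp * np.sin(2 * np.pi * 10 * t) for amp in (1.0, 2.0, 3.0)])
for name, power in rank_electrodes(signals, ["E1", "E2", "E3"], fs, (8, 13)):
    print(name, round(power, 3))
```

Because power scales with amplitude squared, the synthetic channels come out in the order E3, E2, E1, and E3 has 2.25 times the alpha-band power of E2.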


## 9. Statistical Comparison of Band Powers
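`statistical_band_comparison` calls `stats.ttest_ind` with its default `equal_var=True`, which is Student's t-test. When the two conditions differ in trial count and variance, Welch's test (`equal_var=False` in SciPy) does not rely on the equal-variance assumption. The sketch below computes Welch's t statistic and the Welch-Satterthwaite degrees of freedom with NumPy only, for illustration. `welch_t` is a hypothetical helper, and the samples are made up.

```python
import numpy as np


def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Squared standard errors of each mean, from sample variances (ddof=1)
    se2_a = a.var(ddof=1) / a.size
    se2_b = b.var(ddof=1) / b.size
    t = (a.mean() - b.mean()) / np.sqrt(se2_a + se2_b)
    df = (se2_a + se2_b) ** 2 / (se2_a ** 2 / (a.size - 1) + se2_b ** 2 / (b.size - 1))
    return t, df


a = [1.0, 2.0, 3.0, 4.0, 5.0]
b = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
t, df = welch_t(a, b)
print(f"t = {t:.3f}, df = {df:.2f}")
```

With SciPy available, `stats.ttest_ind(a, b, equal_var=False)` returns the same statistic together with a p-value.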

In [61]:
def statistical_band_comparison(electrode_idx, trials_data, electrode_names, condition_labels,
                                condition_values, condition_names, max_trials=50000):
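    """
    Compare band powers between two conditions at a single electrode.
    
    For each band in `freq_bands`, runs an independent-samples t-test
    (scipy.stats.ttest_ind with its default equal_var=True) and computes
    Cohen's d. Only the first two entries of `condition_values` /
    `condition_names` are compared, and p-values are not corrected for
    testing several bands.
    """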
    print(f"\nStatistical Comparison - Electrode {electrode_names[electrode_idx]}")
    print("="*70)
    
    # Compute band powers for each condition
    all_band_powers = {}
    for cond_val, cond_name in zip(condition_values, condition_names):
        cond_trials = np.where(condition_labels == cond_val)[0]
        band_powers = compute_all_band_powers(trials_data, electrode_idx, 
                                             cond_trials, max_trials)
        all_band_powers[cond_name] = band_powers
    
    # Perform t-tests for each band
    results = []
    
    for band in freq_bands.keys():
        powers_cond1 = all_band_powers[condition_names[0]][band]
        powers_cond2 = all_band_powers[condition_names[1]][band]
        
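        # Independent-samples t-test (Student's; equal variances assumed)
        t_stat, p_value = stats.ttest_ind(powers_cond1, powers_cond2)
        
        # Effect size (Cohen's d). The pooled SD here is the equal-weight mean of the
        # two ddof=0 variances, close to the textbook pooled SD for large, similarly sized groups
        pooled_std = np.sqrt((np.var(powers_cond1) + np.var(powers_cond2)) / 2)
        cohens_d = (np.mean(powers_cond1) - np.mean(powers_cond2)) / pooled_std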
        
        results.append({
            'Band': band,
            'Freq Range': f"{freq_bands[band][0]}-{freq_bands[band][1]} Hz",
            f'{condition_names[0]} Mean': np.mean(powers_cond1),
            f'{condition_names[1]} Mean': np.mean(powers_cond2),
            't-statistic': t_stat,
            'p-value': p_value,
            "Cohen's d": cohens_d,
            'Significant': 'Yes' if p_value < 0.05 else 'No'  # uncorrected threshold
        })
    
    # Create DataFrame
    df_results = pd.DataFrame(results)
    
    # Display results
    print("\nBand Power Comparison:")
    print(df_results.to_string(index=False))
    
    # Highlight significant differences
    significant_bands = df_results[df_results['Significant'] == 'Yes']['Band'].tolist()
    if significant_bands:
        print(f"\n*** Significant differences found in: {', '.join(significant_bands)} ***")
    else:
        print("\n*** No significant differences found ***")
    
    return df_results


# Example: Statistical comparison for left vs. right responses on Fz
fz_idx = np.where(frontal_central_names == 'Fz')[0][0] if 'Fz' in frontal_central_names else 0
print(f"\nStatistical analysis for electrode {frontal_central_names[fz_idx]}")
print("This frontal electrode is critical for decision-making and executive control")
df_stats = statistical_band_comparison(fz_idx, trials_clean_filtered, frontal_central_names,
                                      true_labels, [1.0, 2.0], ['Left', 'Right'],
                                      max_trials=50000)


Statistical analysis for electrode Fz
This frontal electrode is critical for decision-making and executive control

Statistical Comparison - Electrode Fz

Band Power Comparison:
 Band Freq Range  Left Mean  Right Mean  t-statistic      p-value  Cohen's d Significant
Delta   0.5-4 Hz   0.155984    0.161735   -11.055199 2.200481e-28  -0.089670         Yes
Theta     4-8 Hz   0.073581    0.071461     6.490559 8.616736e-11   0.052723         Yes
Alpha    8-13 Hz   0.049941    0.047585     7.203117 5.952464e-13   0.058542         Yes
 Beta   13-30 Hz   0.052959    0.052091     3.016098 2.561556e-03   0.024482         Yes
Gamma   30-50 Hz   0.023400    0.023602    -1.013802 3.106811e-01  -0.008211          No

*** Significant differences found in: Delta, Theta, Alpha, Beta ***
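The five band-wise t-tests above are each judged at p < 0.05 with no correction for testing five bands. With this many trials, even the small effects in the table (|d| < 0.1 in every band) reach significance. One common remedy is the Holm-Bonferroni step-down adjustment. The minimal NumPy sketch below applies it to the p-values printed above. `holm_adjust` is a hypothetical helper; if statsmodels is available, `statsmodels.stats.multitest.multipletests(p, method='holm')` provides the same correction.

```python
import numpy as np


def holm_adjust(p_values):
    """Holm-Bonferroni adjusted p-values (step-down), returned in input order."""
    p = np.asarray(p_values, dtype=float)
    m = p.size
    order = np.argsort(p)
    # Multiply the k-th smallest p-value by (m - k), for k = 0..m-1
    scaled = (m - np.arange(m)) * p[order]
    # Enforce monotonicity of the adjusted values and cap them at 1
    adjusted_sorted = np.minimum(1.0, np.maximum.accumulate(scaled))
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted
    return adjusted


# p-values copied from the Fz Left vs Right table above
bands = ["Delta", "Theta", "Alpha", "Beta", "Gamma"]
p_raw = [2.200481e-28, 8.616736e-11, 5.952464e-13, 2.561556e-03, 3.106811e-01]
for band, p_adj in zip(bands, holm_adjust(p_raw)):
    print(f"{band:<6} adjusted p = {p_adj:.3e}  {'significant' if p_adj < 0.05 else 'ns'}")
```

For these particular values the significant/ns pattern does not change after correction. Effect sizes this small still call for caution when interpreting the differences.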


## 10. Summary and Export Results
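The summary cell below reports a 32-sample spectrogram window at 103 Hz, which gives a frequency resolution of about 3.22 Hz. It is worth checking how many FFT bin centres that leaves inside each band. The sketch assumes half-open bands `[low, high)` with the band edges listed in the summary. It counts bins both for the 32-sample window and for a full 103-sample trial. The notebook's own band-power code may treat the band edges differently, and `bins_per_band` is a hypothetical helper.

```python
import numpy as np

FREQ_BANDS = {
    "Delta": (0.5, 4),
    "Theta": (4, 8),
    "Alpha": (8, 13),
    "Beta": (13, 30),
    "Gamma": (30, 50),
}


def bins_per_band(n_fft, fs, bands):
    """Count one-sided FFT bin centres inside each half-open band [low, high)."""
    # k * fs / n_fft for k = 0..n_fft//2 (same values as np.fft.rfftfreq(n_fft, d=1/fs))
    freqs = np.arange(n_fft // 2 + 1) * fs / n_fft
    return {name: int(np.sum((freqs >= lo) & (freqs < hi))) for name, (lo, hi) in bands.items()}


for n_fft in (32, 103):
    print(f"n_fft={n_fft:>3} (resolution {103 / n_fft:.2f} Hz):", bins_per_band(n_fft, 103, FREQ_BANDS))
```

With 32-sample windows, Delta and Theta are each represented by a single bin (about 3.2 Hz and 6.4 Hz respectively), so spectrogram-based power in those bands rests on one value per window. A full 103-sample FFT gives 1 Hz spacing but no time resolution within the trial.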

In [62]:
print("\n" + "="*70)
print("SPECTROGRAM ANALYSIS SUMMARY")
print("="*70)
print(f"\nDataset: {DATA_PATH}")
print(f"Total electrodes in dataset: {len(electrodes)}")
print(f"ANALYZED electrodes (frontal + central): {len(frontal_central_names)}")
n_central = sum(1 for e in frontal_central_names
                if str(e).startswith(('FC', 'C')) and not str(e).startswith('CP'))
# Every selected site is frontal or central; FC sites count as central, so frontal = the rest
n_frontal = len(frontal_central_names) - n_central
print(f"  - Frontal electrodes: {n_frontal}")
print(f"  - Central electrodes: {n_central}")
print(f"Number of trials: {trials_clean_filtered.shape[1]}")
print(f"Trial length: {TRIAL_LENGTH} samples ({TRIAL_LENGTH/SAMPLING_RATE} seconds)")
print(f"Sampling rate: {SAMPLING_RATE} Hz")
print(f"Nyquist frequency: {NYQUIST_FREQ} Hz")

print("\nData Quality:")
print(f"  - NaN values: {nan_count} ({100 * nan_count / trials.size:.2f}%)")
print(f"  - WARNING: High NaN percentage may affect analysis quality")

print("\nFrequency Bands Analyzed:")
for band, (low, high) in freq_bands.items():
    print(f"  - {band}: {low}-{high} Hz")
print(f"  Note: All bands limited by Nyquist frequency of {NYQUIST_FREQ} Hz")

print("\nConditions:")
print(f"  - Labels: {np.unique(true_labels)} ")
print(f"  - Cues: {np.unique(cues)}")
print(f"  - Primes: {np.unique(primes)}")

print("\nMethodology:")
print(f"  - Spectrogram window size: Adaptive (default 32 samples)")
print(f"  - Time resolution: ~{32/SAMPLING_RATE:.3f}s per window")
print(f"  - Frequency resolution: ~{SAMPLING_RATE/32:.2f} Hz")
print(f"  - NaN values replaced with 0 (may affect power estimates)")

print("\n" + "="*70)


SPECTROGRAM ANALYSIS SUMMARY

Dataset: ../data/raw/05_25
Total electrodes in dataset: 58
ANALYZED electrodes (frontal + central): 30
  - Frontal electrodes: 16
  - Central electrodes: 14
Number of trials: 61415
Trial length: 103 samples (1.0 seconds)
Sampling rate: 103 Hz
Nyquist frequency: 51.5 Hz

Data Quality:
  - NaN values: 197814656 (48.86%)

Frequency Bands Analyzed:
  - Delta: 0.5-4 Hz
  - Theta: 4-8 Hz
  - Alpha: 8-13 Hz
  - Beta: 13-30 Hz
  - Gamma: 30-50 Hz
  Note: All bands limited by Nyquist frequency of 51.5 Hz

Conditions:
  - Labels: [1. 2.] 
  - Cues: [1. 2. 4.]
  - Primes: [1. 2. 3.]

Methodology:
  - Spectrogram window size: Adaptive (default 32 samples)
  - Time resolution: ~0.311s per window
  - Frequency resolution: ~3.22 Hz
  - NaN values replaced with 0 (may affect power estimates)
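The summary records that NaN samples were replaced with 0 before analysis. For power estimates this is not neutral. A zeroed sample adds nothing to the sum of squares but still counts toward the number of samples, so a channel or trial with missing data looks weaker than it is. The toy example below uses made-up data and a hypothetical helper, `nan_fraction_per_electrode`. It shows the effect on mean squared amplitude and contrasts it with a NaN-aware mean via `np.nanmean`. Spectral estimates are affected further, because the abrupt jumps to 0 also spread energy across frequencies, which this example does not model.

```python
import numpy as np


def nan_fraction_per_electrode(trials):
    """Fraction of NaN samples per electrode for data shaped (electrodes, trials, samples)."""
    return np.isnan(trials).mean(axis=(1, 2))


# Toy data: 2 electrodes x 1 trial x 100 samples of constant amplitude 2
toy = np.full((2, 1, 100), 2.0)
toy[1, 0, :50] = np.nan  # second electrode: half of its samples missing

print("NaN fraction:", nan_fraction_per_electrode(toy))  # 0.0 and 0.5

zero_filled = np.nan_to_num(toy, nan=0.0)
print("mean power, zero-filled:", (zero_filled ** 2).mean(axis=(1, 2)))  # 4.0 and 2.0
print("mean power, NaN-aware:  ", np.nanmean(toy ** 2, axis=(1, 2)))    # 4.0 and 4.0
```

Applied to the notebook's `trials` array, `nan_fraction_per_electrode(trials)` would show whether the roughly 49% missing data is spread evenly or concentrated in particular electrodes.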

