In [1]:
#!/usr/bin/env python3
"""
Small EEG discrete-emotion demo (ready-to-paste)

- Default: generates a tiny synthetic EEG dataset and maps trials to four discrete
  emotion classes: 'Anxious', 'Depressed', 'Sad', 'Fear'.
- Computes band-power features (delta/theta/alpha/beta/gamma) per channel.
- Performs Leave-One-Subject-Out (LOSO) classification with RandomForest.
- Prints metrics and saves confusion matrix + feature importances into ./output_small_eeg_demo

Notes:
- This is a small demo (seconds to a few minutes depending on CPU).
- The "Anxious/Depressed/Sad/Fear" labels in the synthetic data are illustrative only.
"""

import os
import sys
import subprocess
import importlib
import math
from collections import OrderedDict

# -------------------- Auto-install missing packages (optional) --------------------
_REQ_PKGS = {
    'numpy': 'numpy',
    'scipy': 'scipy',
    'sklearn': 'scikit-learn',
    'matplotlib': 'matplotlib',
    'pandas': 'pandas',
    'tqdm': 'tqdm',
    'requests': 'requests',
    'seaborn': 'seaborn',
}

def ensure_packages(pack_map):
    """Install (via pip) every package in *pack_map* that cannot be imported.

    Parameters
    ----------
    pack_map : Mapping[str, str]
        Maps the importable module name (e.g. ``'sklearn'``) to the name pip
        installs it under (e.g. ``'scikit-learn'``).

    Returns
    -------
    list[str]
        The pip names that were missing and passed to ``pip install``; empty
        when every module imported successfully.

    Raises
    ------
    subprocess.CalledProcessError
        If the ``pip install`` subprocess exits with a non-zero status.
    """
    missing = []
    for import_name, pip_name in pack_map.items():
        try:
            importlib.import_module(import_name)
        except ImportError:
            missing.append(pip_name)

    if not missing:
        print("All required packages present.")
        return missing

    print("Installing missing packages via pip:", missing)
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # The import system caches directory listings; without this, packages that
    # pip just installed into this running interpreter may still not be found.
    importlib.invalidate_caches()
    return missing

# Auto-install runs by default: anything in _REQ_PKGS that fails to import is
# installed with pip. Set the environment variable EEG_DEMO_AUTO_INSTALL=0 to
# skip this and manage packages yourself. Install errors are reported but not
# fatal here; the imports below will fail loudly if a package is still missing.
if os.environ.get("EEG_DEMO_AUTO_INSTALL", "1") != "0":
    try:
        ensure_packages(_REQ_PKGS)
    except Exception as e:
        print("Package auto-install failed or skipped. Make sure required packages are installed.")
        print("Error:", e)

# -------------------- Imports --------------------
import numpy as np
from scipy.signal import welch
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report,
    roc_auc_score
)
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import seaborn as sns

# -------------------- Configuration --------------------
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)  # global NumPy seed; the generator and RF also take RANDOM_STATE explicitly

# Dataset size -- kept tiny so the whole demo finishes quickly
N_SUBJECTS = 8                  # synthetic subjects (each becomes one LOSO fold)
TRIALS_PER_SUBJECT = 20         # 8 subjects x 20 trials = 160 trials overall
N_CHANNELS = 8                  # few channels keep the feature vector short
FS = 128                        # sampling rate in Hz
DURATION_SEC = 4.0              # length of one trial in seconds
N_SAMPLES = int(DURATION_SEC * FS)
EMOTION_CATS = ['Anxious', 'Depressed', 'Sad', 'Fear']  # class names; list position is the integer label
OUTPUT_DIR = os.path.abspath("./output_small_eeg_demo")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Frequency bands (Hz) for band-power features. Insertion order fixes the
# per-channel feature order.
BANDS = OrderedDict()
BANDS["delta"] = (1, 4)
BANDS["theta"] = (4, 8)
BANDS["alpha"] = (8, 13)
BANDS["beta"] = (13, 30)
BANDS["gamma"] = (30, 45)

# Welch segment length: 256 samples, capped at the trial length so that
# nperseg never exceeds n_samples
DEFAULT_NPERSEG = N_SAMPLES if N_SAMPLES < 256 else 256

# -------------------- Utility: bandpower features --------------------
def compute_bandpower_features(eeg_epoch, fs=FS, bands=BANDS, nperseg=DEFAULT_NPERSEG,
                               total_range=(1, 45)):
    """Return log relative band-power features for one EEG epoch.

    Parameters
    ----------
    eeg_epoch : array_like, shape (n_channels, n_samples) or (n_samples,)
        One trial. A 1-D input is treated as a single channel.
    fs : float
        Sampling rate in Hz.
    bands : Mapping[str, tuple[float, float]]
        Band name -> (low, high) edges in Hz; both edges are inclusive.
    nperseg : int
        Welch segment length passed to ``scipy.signal.welch``.
    total_range : tuple[float, float]
        Inclusive frequency range (Hz) whose integrated power normalizes every
        band (default 1-45 Hz, as before).

    Returns
    -------
    ndarray of float, shape (n_channels * len(bands),)
        ``log(band_power / total_power + 1e-12)`` in channel-major order:
        all bands of channel 0, then all bands of channel 1, and so on.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available and fall back for older NumPy releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    epoch = np.atleast_2d(np.asarray(eeg_epoch, dtype=float))
    freqs, psd = welch(epoch, fs=fs, nperseg=nperseg, axis=1)

    def band_area(low, high):
        # Trapezoidal integral of every channel's PSD over [low, high] -> (n_channels,)
        mask = (freqs >= low) & (freqs <= high)
        return trapezoid(psd[:, mask], freqs[mask], axis=-1)

    band_edges = list(bands.values())
    if not band_edges:
        return np.zeros(0, dtype=float)

    # Epsilon keeps the ratio finite for a flat (all-zero) channel.
    total_power = band_area(*total_range) + 1e-12
    band_power = np.stack([band_area(low, high) for low, high in band_edges], axis=1)
    # Row-major ravel of (n_channels, n_bands) gives channel-major feature order.
    return np.log(band_power / total_power[:, np.newaxis] + 1e-12).ravel()

# -------------------- Synthetic EEG data generator (small) --------------------
def generate_synthetic_eeg(n_subjects=N_SUBJECTS,
                           trials_per_subject=TRIALS_PER_SUBJECT,
                           n_channels=N_CHANNELS,
                           fs=FS, duration=DURATION_SEC,
                           classes=EMOTION_CATS, random_state=RANDOM_STATE,
                           bands=BANDS):
    """
    Produce a small synthetic EEG dataset (time-series) and corresponding labels/groups.

    Each (trial, channel) signal is the sum, over every band in *bands*, of 1-2
    sinusoids with random frequency (uniform inside the band) and phase. Their
    amplitudes are a class-specific per-band multiplier x a fixed per-channel
    gain x a ~15% random jitter. White Gaussian noise is then added. Trials are
    assigned to subjects in consecutive blocks of *trials_per_subject*; no
    subject-specific signal component is simulated, so subjects only define the
    LOSO grouping.

    Parameters
    ----------
    n_subjects, trials_per_subject, n_channels : int
        Dataset dimensions; n_trials = n_subjects * trials_per_subject.
    fs : float
        Sampling rate in Hz.
    duration : float
        Trial length in seconds; n_samples = int(fs * duration).
    classes : sequence of str
        Class names. Labels are drawn near-balanced (the list is repeated,
        truncated to n_trials and shuffled). Names missing from the internal
        multiplier table get a gain of 1.0 in every band.
    random_state : int
        Seed for the local ``np.random.RandomState``; output is deterministic.
    bands : Mapping[str, tuple[float, float]]
        Band name -> (low, high) Hz used to place the sinusoid components.

    Returns
    -------
    signals : ndarray, shape (n_trials, n_channels, n_samples)
    labels : list[str], length n_trials
    groups : ndarray of int, length n_trials (subject index per trial)

    Raises
    ------
    ValueError
        If *classes* is empty.
    """
    if len(classes) == 0:
        raise ValueError("classes must contain at least one label")

    rng = np.random.RandomState(random_state)
    n_samples = int(fs * duration)
    n_trials = n_subjects * trials_per_subject
    signals = np.zeros((n_trials, n_channels, n_samples), dtype=float)
    labels = []
    groups = np.zeros(n_trials, dtype=int)

    # Near-balanced labels: repeat the class list past n_trials, truncate, then
    # shuffle. list() makes tuple/array inputs shuffleable in place.
    n_repeats = n_trials // len(classes) + 1
    labels_cycle = (list(classes) * n_repeats)[:n_trials]
    rng.shuffle(labels_cycle)

    # class-specific band multipliers (heuristic / illustrative)
    class_band_multipliers = {
        'Anxious':  {'delta':0.9,'theta':1.0,'alpha':0.9,'beta':1.6,'gamma':1.5},
        'Depressed':{'delta':1.0,'theta':1.4,'alpha':0.6,'beta':0.8,'gamma':0.7},
        'Sad':      {'delta':1.1,'theta':1.2,'alpha':0.8,'beta':0.9,'gamma':0.8},
        'Fear':     {'delta':0.9,'theta':1.1,'alpha':0.8,'beta':1.7,'gamma':1.6},
    }

    # Fixed per-channel gain, applied to every trial regardless of class:
    # channels 0-2 -> 1.15, channels 3-4 -> 1.05, all remaining -> 0.95.
    channel_topography = np.full(n_channels, 0.95)
    channel_topography[:3] = 1.15
    channel_topography[3:5] = 1.05

    t = np.arange(n_samples) / fs

    # NOTE: the order of rng calls below determines the generated data; keep it
    # unchanged to stay reproducible for a given random_state.
    for i in range(n_trials):
        cls = labels_cycle[i]
        labels.append(cls)
        groups[i] = i // trials_per_subject  # consecutive block of trials per subject
        band_gains = class_band_multipliers.get(cls, {})

        for ch in range(n_channels):
            sig = np.zeros(n_samples, dtype=float)
            topo = channel_topography[ch]
            for band, (low, high) in bands.items():
                # 1 or 2 sinusoidal components per band
                n_comps = rng.choice([1, 2])
                for _ in range(n_comps):
                    freq = rng.uniform(low, high)
                    phase = rng.uniform(0, 2 * np.pi)
                    class_mul = band_gains.get(band, 1.0)
                    # base amplitude 1.0 x class gain x channel gain x random jitter
                    amp = 1.0 * class_mul * topo * (1.0 + 0.15 * rng.randn())
                    sig += amp * np.sin(2 * np.pi * freq * t + phase)
            # Per-trial, per-channel white noise with a jittered level (~0.5, +/-20% sd)
            noise_level = 0.5 * (1.0 + 0.2 * rng.randn())
            sig += noise_level * rng.randn(n_samples)
            signals[i, ch, :] = sig

    return signals, labels, groups

# -------------------- LOSO evaluation (multiclass) --------------------
def loso_multiclass_evaluate(X, y, groups, class_names, classifier_type='rf', random_state=RANDOM_STATE):
    """
    Perform Leave-One-Subject-Out multiclass evaluation.

    Each distinct value in *groups* is held out once. A StandardScaler and the
    classifier are fit on the remaining rows only, so no held-out data leaks
    into preprocessing.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
    y : array_like of int, shape (n_samples,)
        Class indices in ``range(len(class_names))``.
    groups : array_like, shape (n_samples,)
        Subject id for every row.
    class_names : sequence of str
        Display name of each class index.
    classifier_type : str
        Only ``'rf'`` (RandomForestClassifier, 200 trees) is supported.
    random_state : int
        Seed for the classifier.

    Returns
    -------
    dict
        'preds' (int per row), 'probs' (n_samples, n_classes), 'acc', 'bacc',
        'auc_macro' (NaN when undefined), 'class_report' (str), 'cm'
        (n_classes x n_classes, rows = true class) and 'clf' (the classifier
        from the last fold).

    Raises
    ------
    ValueError
        If *classifier_type* is not ``'rf'``.
    """
    if classifier_type != 'rf':
        raise ValueError("Unsupported classifier_type, only 'rf' implemented in demo.")

    X = np.asarray(X)
    y = np.asarray(y)
    groups = np.asarray(groups)
    n_classes = len(class_names)
    all_labels = list(range(n_classes))
    preds = np.zeros(X.shape[0], dtype=int)
    probs = np.zeros((X.shape[0], n_classes), dtype=float)

    last_clf = None
    for train_idx, test_idx in LeaveOneGroupOut().split(X, y, groups):
        # scale based only on training data
        scaler = StandardScaler().fit(X[train_idx])
        Xtr = scaler.transform(X[train_idx])
        Xte = scaler.transform(X[test_idx])

        clf = RandomForestClassifier(n_estimators=200, n_jobs=-1, random_state=random_state)
        clf.fit(Xtr, y[train_idx])
        preds[test_idx] = clf.predict(Xte)
        # predict_proba columns follow clf.classes_, i.e. only the classes seen in
        # this fold's training data. Scatter them into the global class columns so
        # a class missing from training cannot shift or break the alignment.
        probs[np.ix_(test_idx, clf.classes_.astype(int))] = clf.predict_proba(Xte)
        last_clf = clf

    # Metrics. Passing labels= keeps the report and confusion matrix aligned with
    # class_names even when some class never occurs in y or preds.
    acc = accuracy_score(y, preds)
    bacc = balanced_accuracy_score(y, preds)
    cls_report = classification_report(y, preds, labels=all_labels, target_names=class_names,
                                       zero_division=0)
    cm = confusion_matrix(y, preds, labels=all_labels)

    # ROC AUC: macro one-vs-rest for >2 classes; positive-class score for 2.
    try:
        if n_classes == 2:
            # label_binarize returns a single column for two classes, which would
            # not match the two probability columns.
            auc_macro = roc_auc_score(y, probs[:, 1])
        else:
            y_bin = label_binarize(y, classes=all_labels)
            auc_macro = roc_auc_score(y_bin, probs, average='macro', multi_class='ovr')
    except ValueError:
        # e.g. a class with no positive samples: AUC is undefined
        auc_macro = float('nan')

    return {
        'preds': preds,
        'probs': probs,
        'acc': acc,
        'bacc': bacc,
        'auc_macro': auc_macro,
        'class_report': cls_report,
        'cm': cm,
        'clf': last_clf
    }

# -------------------- Small plotting helpers --------------------
def plot_confusion_matrix(cm, class_names, outpath):
    """Draw *cm* as an annotated integer heatmap and write it to *outpath*.

    The y axis is labelled 'True label' and the x axis 'Predicted label', both
    ticked with *class_names*. The figure is closed after saving so repeated
    calls do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_title('Confusion matrix (LOSO)')
    fig.tight_layout()
    fig.savefig(outpath)
    plt.close(fig)

def plot_feature_importances(importances, feature_names, outpath, topk=20):
    """Save a horizontal bar chart of the *topk* largest importances to *outpath*.

    Parameters
    ----------
    importances : array_like of float, shape (n_features,)
        Importance per feature (list or ndarray).
    feature_names : sequence of str
        Name per feature, aligned with *importances*.
    outpath : str
        Destination image path.
    topk : int
        Number of bars to draw (the largest values).

    The largest importance is drawn at the top. Figure height grows with the
    number of bars (0.25 in per bar, minimum 3 in). The figure is closed after saving.
    """
    # asarray lets callers pass plain lists; fancy indexing below needs an ndarray.
    importances = np.asarray(importances, dtype=float)
    top_idx = np.argsort(importances)[::-1][:topk]
    # barh places the first bar at the bottom, so reverse to put the largest on top.
    ranked = top_idx[::-1]
    top_names = [feature_names[i] for i in ranked]
    top_vals = importances[ranked]

    fig, ax = plt.subplots(figsize=(7, max(3, 0.25 * len(top_names))))
    y_pos = np.arange(len(top_names))
    ax.barh(y_pos, top_vals, align='center')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_names)
    ax.set_xlabel("Importance")
    ax.set_title("Top feature importances (RF)")
    fig.tight_layout()
    fig.savefig(outpath)
    plt.close(fig)

# -------------------- Main pipeline --------------------
def _save_signals_wide(signals, labels_str, groups, output_dir):
    """Best-effort export of the raw trials to CSV and (when possible) Excel.

    One row per trial: ``trial``, ``subject``, ``label``, then the channel-major
    flattened samples ``ch1_s0 ... chN_sM``. Any failure is printed and
    swallowed because the export is only for manual inspection.
    """
    try:
        n_trials, n_channels, n_samples = signals.shape
        sample_cols = [f"ch{ch+1}_s{s}" for ch in range(n_channels) for s in range(n_samples)]
        meta = pd.DataFrame({
            'trial': np.arange(n_trials),
            'subject': np.asarray(groups, dtype=int),
            'label': list(labels_str),
        })
        # C-order reshape flattens channel-major: ch1 samples, then ch2 samples, ...
        samples = pd.DataFrame(signals.reshape(n_trials, n_channels * n_samples), columns=sample_cols)
        df_signals = pd.concat([meta, samples], axis=1)

        csv_signals_path = os.path.join(output_dir, "synthetic_signals_wide.csv")
        df_signals.to_csv(csv_signals_path, index=False)
        print("Saved synthetic dataset (wide CSV) to:", csv_signals_path)

        excel_signals_path = os.path.join(output_dir, "synthetic_signals_wide.xlsx")
        try:
            df_signals.to_excel(excel_signals_path, index=False)
            print("Saved synthetic dataset (Excel) to:", excel_signals_path)
        except Exception as e:
            # Usually a missing optional engine; the CSV above is sufficient.
            print("Failed to save Excel file. Install 'openpyxl' or 'xlsxwriter' to enable Excel export. Error:", e)
    except Exception as e:
        print("Failed to save synthetic dataset to CSV/Excel:", e)


def _extract_features(signals):
    """Stack per-trial band-power feature vectors into an (n_trials, n_features) matrix."""
    print("Computing band-power features (per trial) ...")
    feats = [compute_bandpower_features(epoch)
             for epoch in tqdm(signals, desc="Feature extraction")]
    return np.vstack(feats)


def main(use_real_dataset=False, real_dataset_path=None):
    """
    Run the demo end to end: data -> features -> LOSO evaluation -> saved outputs.

    Parameters
    ----------
    use_real_dataset : bool
        Request a real dataset. Loading one is not implemented: together with a
        *real_dataset_path* this raises; without a path the synthetic data is
        used and a note is printed.
    real_dataset_path : str or None
        Path to a real dataset (only checked, never read, in this demo).

    Raises
    ------
    NotImplementedError
        If *use_real_dataset* is true and *real_dataset_path* is given.

    Prints progress and metrics, and writes CSV/PNG outputs into OUTPUT_DIR.
    """
    # 1) Load or generate dataset
    if use_real_dataset and real_dataset_path is not None:
        # Placeholder: user-supplied loader could be implemented here.
        raise NotImplementedError("Real dataset loading is not implemented in this demo. "
                                  "Please supply a preprocessed feature matrix or request a custom loader.")
    if use_real_dataset:
        print("use_real_dataset=True but no real_dataset_path given; using synthetic data instead.")

    print("Generating small synthetic EEG dataset (fast demo).")
    signals, labels_str, groups = generate_synthetic_eeg()
    n_trials, n_channels, n_samples = signals.shape
    print(f"Generated {n_trials} trials, {n_channels} channels, {n_samples} samples per trial.")

    # Save raw synthetic dataset to CSV/Excel for review (wide format: one trial per row)
    _save_signals_wide(signals, labels_str, groups, OUTPUT_DIR)

    X = _extract_features(signals)
    # label encoding: position in EMOTION_CATS is the integer class
    class_to_idx = {c: i for i, c in enumerate(EMOTION_CATS)}
    y = np.array([class_to_idx[c] for c in labels_str], dtype=int)
    groups = np.asarray(groups, dtype=int)

    print("Feature matrix X shape:", X.shape)
    print("Label distribution:", {name: int(np.sum(y == i)) for i, name in enumerate(EMOTION_CATS)})
    # Same order as compute_bandpower_features: channel-major, bands inner.
    feature_names = [f"{band}_ch{ch+1}" for ch in range(n_channels) for band in BANDS]

    # 2) LOSO evaluation
    print("\nRunning LOSO multiclass evaluation (RandomForest)...")
    results = loso_multiclass_evaluate(X, y, groups, EMOTION_CATS, classifier_type='rf')
    print(f"Accuracy: {results['acc']:.4f}")
    print(f"Balanced Accuracy: {results['bacc']:.4f}")
    print(f"Macro ROC AUC (OVR): {results['auc_macro']:.4f}")
    print("\nClassification report:\n")
    print(results['class_report'])
    print("Confusion matrix:\n", results['cm'])

    cm_path = os.path.join(OUTPUT_DIR, "confusion_matrix_loso.png")
    plot_confusion_matrix(results['cm'], EMOTION_CATS, cm_path)
    print("Saved confusion matrix to:", cm_path)

    # 3) Fit on all data only to rank features (importances are in-sample, approximate)
    print("\nTraining RandomForest on the full feature set to obtain feature importances (approx).")
    X_full = StandardScaler().fit_transform(X)
    rf_full = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=RANDOM_STATE)
    rf_full.fit(X_full, y)
    importances = rf_full.feature_importances_
    topk = min(20, len(importances))
    top_idx = np.argsort(importances)[::-1][:topk]
    print(f"Top {topk} features (name, importance):")
    for i in top_idx:
        print(f"  {feature_names[i]:20s}  {importances[i]:.6f}")

    fi_path = os.path.join(OUTPUT_DIR, "feature_importances.png")
    plot_feature_importances(importances, feature_names, fi_path, topk=topk)
    print("Saved feature importances to:", fi_path)

    # 4) Trial-level LOSO predictions
    df = pd.DataFrame({
        'subject': groups,
        'label': [EMOTION_CATS[int(lbl)] for lbl in y],
        'pred': [EMOTION_CATS[int(p)] for p in results['preds']]
    })
    csv_path = os.path.join(OUTPUT_DIR, "trial_predictions.csv")
    df.to_csv(csv_path, index=False)
    print("Saved trial-level predictions to:", csv_path)

    print("\nDone. Outputs are in:", OUTPUT_DIR)
    print("Reminder: synthetic labels are illustrative. For real data, provide a loader or dataset path.")

if __name__ == "__main__":
    # Script entry point: run the synthetic-data demo. Loading a real dataset is
    # not implemented -- main(use_real_dataset=True, real_dataset_path=...)
    # raises NotImplementedError -- so a custom loader must be added first.
    main(use_real_dataset=False, real_dataset_path=None)

All required packages present.
Generating small synthetic EEG dataset (fast demo).
Generated 160 trials, 8 channels, 512 samples per trial.
Saved synthetic dataset (wide CSV) to: /Users/stageacomeback/Desktop/output_small_eeg_demo/synthetic_signals_wide.csv
Failed to save Excel file. Install 'openpyxl' or 'xlsxwriter' to enable Excel export. Error: No module named 'openpyxl'
Computing band-power features (per trial) ...


  total_power = np.trapz(ch_psd[total_idx], freqs[total_idx]) + 1e-12
  bp = np.trapz(ch_psd[idx], freqs[idx])
Feature extraction: 100%|██████████| 160/160 [00:00<00:00, 1668.32it/s]

Feature matrix X shape: (160, 40)
Label distribution: {'Anxious': 40, 'Depressed': 40, 'Sad': 40, 'Fear': 40}

Running LOSO multiclass evaluation (RandomForest)...





Accuracy: 0.8250
Balanced Accuracy: 0.8250
Macro ROC AUC (OVR): 0.9660

Classification report:

              precision    recall  f1-score   support

     Anxious       0.74      0.70      0.72        40
   Depressed       0.93      0.93      0.93        40
         Sad       0.93      0.93      0.93        40
        Fear       0.71      0.75      0.73        40

    accuracy                           0.82       160
   macro avg       0.83      0.82      0.82       160
weighted avg       0.83      0.82      0.82       160

Confusion matrix:
 [[28  0  0 12]
 [ 0 37  3  0]
 [ 0  3 37  0]
 [10  0  0 30]]
Saved confusion matrix to: /Users/stageacomeback/Desktop/output_small_eeg_demo/confusion_matrix_loso.png

Training RandomForest on the full feature set to obtain feature importances (approx).
Top 20 features (name, importance):
  theta_ch2             0.063483
  theta_ch6             0.057592
  theta_ch7             0.043826
  alpha_ch5             0.040927
  alpha_ch8             0.040