# EEG Emotion Classification with EEGNet

This notebook walks through the complete pipeline for classifying emotions from EEG data using deep learning.

## Overview

We use a **Muse 2** headband to record EEG signals from 4 channels (TP9, AF7, AF8, TP10) and train an **EEGNet** deep learning model to classify emotional states:
- Happy
- Calm
- Stressed
- Focused

The pipeline consists of:
1. **Recording** EEG data while experiencing different emotions
2. **Loading & Augmenting** the recorded data
3. **Training** an EEGNet model
4. **Live Prediction** with real-time visualization and MQTT publishing

## Step 1: Imports and Setup

First, we import all necessary libraries for EEG data acquisition, processing, and deep learning.

In [2]:
import time
import numpy as np
import pandas as pd
from pathlib import Path
import joblib

# BrainFlow for Muse 2 data acquisition
from brainflow.board_shim import BoardShim, BrainFlowInputParams, BoardIds
from brainflow.data_filter import DataFilter, FilterTypes

# For real-time visualization and state machine
from collections import deque
import matplotlib.pyplot as plt

# MQTT for communication with robot
import paho.mqtt.client as mqtt
import json

# Deep Learning with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Machine Learning utilities
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix

## Step 2: Constants and Device Configuration

We define key constants for the Muse 2 board, configure the computation device (CUDA GPU, Apple Silicon MPS, or CPU), and define the EEGNet model used in the following steps.

In [3]:
# Muse 2 Board Configuration
BOARD_ID = BoardIds.MUSE_2_BOARD.value
SAMPLING_RATE = BoardShim.get_sampling_rate(BOARD_ID)  # 256 Hz
EEG_CHANNEL_NAMES = BoardShim.get_eeg_names(BOARD_ID)  # ['TP9', 'AF7', 'AF8', 'TP10']
EEG_CHANNELS_LIST = BoardShim.get_eeg_channels(BOARD_ID)

# Window size for prediction (2 seconds = 512 samples at 256 Hz)
WINDOW_SEC = 2
WINDOW_SAMPLES = int(WINDOW_SEC * SAMPLING_RATE)

# Data directories
BASE_DATA_DIR = Path("eeg_data")
BASE_MODELS_DIR = Path("models")

# Device configuration (use GPU if available, MPS for Apple Silicon, else CPU)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
print(f"Sampling rate: {SAMPLING_RATE} Hz")
print(f"Window size: {WINDOW_SAMPLES} samples ({WINDOW_SEC}s)")
print(f"EEG channels: {EEG_CHANNEL_NAMES}")

Using device: mps
Sampling rate: 256 Hz
Window size: 512 samples (2s)
EEG channels: ['TP9', 'AF7', 'AF8', 'TP10']


In [4]:
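class EEGNet(nn.Module):
    """Compact CNN for EEG classification.

    Expects input of shape (batch, 1, Chans, Samples) and returns unnormalized
    class scores (logits) of shape (batch, nb_classes).
    """

    def __init__(self, nb_classes, Chans=4, Samples=WINDOW_SAMPLES,
                 dropoutRate=0.5, kernLength=64, F1=8,
                 D=2, F2=16, norm_rate=0.25):
        super().__init__()
        # Note: norm_rate is accepted but not used; no max-norm constraint is applied.

        # Block 1: temporal convolution (F1 filters of length kernLength), then a
        # depthwise convolution spanning all channels (D spatial filters per temporal filter)
        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, kernLength), padding=(0, kernLength // 2), bias=False),
            nn.BatchNorm2d(F1),
            nn.Conv2d(F1, F1 * D, (Chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1 * D),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),  # downsample time by 4
            nn.Dropout(dropoutRate)
        )

        # Block 2: separable convolution = depthwise temporal conv (1, 16) + pointwise (1, 1) conv
        self.block2 = nn.Sequential(
            nn.Conv2d(F1 * D, F1 * D, (1, 16), padding=(0, 16 // 2), groups=F1 * D, bias=False),
            nn.Conv2d(F1 * D, F2, (1, 1), bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),  # downsample time by 8
            nn.Dropout(dropoutRate)
        )

        # Classifier: infer the flattened feature size by passing a dummy input through both blocks
        demo_input = torch.randn(1, 1, Chans, Samples)
        with torch.no_grad():
            flattened_size = self._get_flattened_size(demo_input)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, nb_classes)
        )

    def _get_flattened_size(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x.view(1, -1).size(1)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.classifier(x)  # logits; apply softmax to get class probabilities
        return x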

## Step 3: Recording EEG Data

The `record()` function connects to the Muse 2, records EEG data, and performs quality checks.

**Process:**
1. Connect to Muse 2 headband via BrainFlow
2. Give the user a 5-second countdown to focus on the target emotion
3. Record EEG data for the specified duration (default 30 s)
4. Check data quality: a recording is flagged as noisy if the standard deviation of any channel exceeds 80 µV
5. Save clean recordings to `eeg_data/<user>/` and noisy ones to `eeg_data/<user>/rejected/`, as CSV files named `<label>_<unix timestamp>.csv`
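A minimal NumPy sketch of the quality rule from step 4, run on synthetic data so it works without a headset. `flag_noisy_channels` is a hypothetical helper; the 80 µV threshold and the channel names are the ones used in `record()`. Like `record()`, it works on the unfiltered signal, so slow drift also raises a channel's standard deviation.

```python
import numpy as np

ARTIFACT_THRESHOLD_STD = 80.0  # microvolts, same threshold as record()


def flag_noisy_channels(eeg, channel_names, threshold=ARTIFACT_THRESHOLD_STD):
    """Return the names of channels whose standard deviation exceeds the threshold.

    eeg: array of shape (n_channels, n_samples), in microvolts.
    """
    stds = np.std(eeg, axis=1)  # one standard deviation per channel (row)
    return [name for name, s in zip(channel_names, stds) if s > threshold]


# Synthetic 30 s recording at 256 Hz: three quiet channels and one with large artifacts
rng = np.random.default_rng(0)
names = ["TP9", "AF7", "AF8", "TP10"]
eeg = rng.normal(0.0, 20.0, size=(4, 30 * 256))
eeg[3] += rng.normal(0.0, 150.0, size=30 * 256)  # simulate a poorly seated TP10 electrode

noisy = flag_noisy_channels(eeg, names)
print(noisy)  # ['TP10']
is_noisy = len(noisy) > 0  # the whole recording goes to 'rejected' if any channel is noisy
```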

In [5]:
def record(user_id, label, duration=30):
    """Record EEG data for a specific emotion."""
    params = BrainFlowInputParams()
    board = BoardShim(BOARD_ID, params)

    # Create directories
    data_dir = BASE_DATA_DIR / user_id
    rejected_dir = data_dir / "rejected"
    data_dir.mkdir(parents=True, exist_ok=True)
    rejected_dir.mkdir(exist_ok=True)
    
    try:
        # Start Muse 2 session
        board.prepare_session()
        board.start_stream(45000 * 2)
        
        # Countdown before recording
        print(f"\nRecording '{label.upper()}' in 5 seconds...")
        for i in range(5, 0, -1):
            print(f"{i}...")
            time.sleep(1)

        # Discard everything streamed during the countdown so only the target period is saved
        board.get_board_data()

        # Record for the specified duration
        print(f"RECORDING '{label.upper()}' for {duration}s...")
        for i in range(duration):
            time.sleep(1)
            print(f"{i+1}/{duration}s", end='\r')

        # Get all data recorded since the buffer was cleared
        data = board.get_board_data()
        
    finally:
        if board.is_prepared():
            print("\nRecording complete")
            board.stop_stream()
            board.release_session()

    # Quality check: measure noise on each channel
    eeg_data = data[EEG_CHANNELS_LIST]
    ARTIFACT_THRESHOLD_STD = 80.0  # microvolts
    
    is_noisy = False
    for i in range(eeg_data.shape[0]):
        channel_std = np.std(eeg_data[i])
        if channel_std > ARTIFACT_THRESHOLD_STD:
            print(f"WARNING: Channel {EEG_CHANNEL_NAMES[i]} exceeds threshold ({channel_std:.1f} > {ARTIFACT_THRESHOLD_STD})")
            is_noisy = True
    
    # Save to appropriate directory
    timestamp = int(time.time())
    if is_noisy:
        print("Recording NOISY - saved to 'rejected' folder")
        file_path = rejected_dir / f"{label}_{timestamp}.csv"
    else:
        print("Recording CLEAN")
        file_path = data_dir / f"{label}_{timestamp}.csv"
    
    # Save as CSV with channel names as columns
    df_to_save = pd.DataFrame(np.transpose(eeg_data))
    df_to_save.columns = EEG_CHANNEL_NAMES
    df_to_save.to_csv(file_path, index=False)
    print(f"Saved {eeg_data.shape[1]} samples to {file_path}")

### Example: Record Calm State

Uncomment and run to record 30 seconds of EEG while in a calm state:

In [7]:
# record(user_id="yk", label="calm", duration=30)

## Step 4: Data Loading and Augmentation

The `load_and_augment_data()` function prepares training data by:

**Windowing:**
- Splits continuous EEG into 2-second windows (512 samples)
- Uses overlapping windows with a 0.25 s stride (64 samples), so consecutive windows overlap by 1.75 s; this increases the number of training windows
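A vectorized sketch of the same windowing, using NumPy's `sliding_window_view` instead of an explicit loop. `make_windows` is a hypothetical helper and the recording is random data; the sketch counts every full 2 s window, including the one that ends exactly at the last sample, so a 30 s recording yields 113 windows.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

SAMPLING_RATE = 256
WINDOW_SAMPLES = 2 * SAMPLING_RATE    # 512 samples = 2 s
STRIDE = int(0.25 * SAMPLING_RATE)    # 64 samples = 0.25 s step


def make_windows(data, window=WINDOW_SAMPLES, stride=STRIDE):
    """Cut (n_channels, n_samples) data into (n_windows, n_channels, window) windows."""
    # All length-`window` views along the time axis: (n_channels, n_positions, window)
    views = sliding_window_view(data, window, axis=1)
    # Keep every `stride`-th start position and put the window index first
    return views[:, ::stride, :].transpose(1, 0, 2).copy()


recording = np.random.default_rng(0).normal(size=(4, 30 * SAMPLING_RATE))  # 30 s, 4 channels
windows = make_windows(recording)
print(windows.shape)  # (113, 4, 512)

overlap = 1 - STRIDE / WINDOW_SAMPLES
print(f"overlap between consecutive windows: {overlap:.1%}")  # 87.5%
```

The `.copy()` matters: `sliding_window_view` returns a read-only view, while the augmentation and BrainFlow filters modify windows in place.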

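**Data Augmentation** (each window has a 50% chance of being replaced by an augmented version using one randomly chosen transform):
- **Noise injection**: Add Gaussian noise whose standard deviation is 5% of the window's standard deviation
- **Polarity reversal**: Multiply the signal by -1
- **Amplitude scaling**: Multiply by a random factor between 0.9 and 1.1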
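The augmentation rule can be sketched as a standalone function. This version uses a seeded `numpy.random.Generator` for reproducibility, which is an assumption (the notebook uses the global `np.random` state), and `augment_window` is a hypothetical name. Because augmented windows replace originals rather than being added, about half of the windows stay unchanged and the dataset size does not grow.

```python
import numpy as np


def augment_window(window, rng, noise_factor=0.05):
    """Return a possibly augmented copy of a (n_channels, n_samples) window.

    With probability 0.5 the window is returned unchanged; otherwise one of
    noise injection, polarity reversal, or amplitude scaling is applied.
    """
    out = window.copy()
    if rng.random() > 0.5:
        aug_type = rng.integers(0, 3)  # 0, 1 or 2
        if aug_type == 0:    # Gaussian noise with std = 5% of the window's std
            out += rng.normal(0.0, np.std(out) * noise_factor, out.shape)
        elif aug_type == 1:  # polarity reversal
            out *= -1
        else:                # amplitude scaling by a factor in [0.9, 1.1)
            out *= rng.uniform(0.9, 1.1)
    return out


rng = np.random.default_rng(42)
window = rng.normal(0.0, 20.0, size=(4, 512))
augmented = [augment_window(window, rng) for _ in range(1000)]
unchanged = sum(np.array_equal(a, window) for a in augmented)
print(f"{unchanged} of 1000 windows left unchanged")  # roughly 500
```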

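**Filtering** (Butterworth filters of order 4, zero-phase, applied to each channel of each window after augmentation):
- Bandpass filter: 1-45 Hz (removes DC offset and slow drift, attenuates high-frequency noise)
- Bandstop filter: 48-52 Hz (attenuates 50 Hz powerline interference; 60 Hz mains hum already lies above the 45 Hz bandpass cutoff)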

**Output:**
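- `X`: array of shape `(n_windows, 1, 4, 512)` (windows, a singleton input-channel axis, EEG channels, time points), ready for EEGNet
- `y`: integer labels; `LabelEncoder` sorts class names alphabetically, so calm=0, focused=1, happy=2, stressed=3. Labels are read from the file name up to the first `_`, so emotion names must not contain underscores
- `le`: the fitted `LabelEncoder`, used to convert between integer labels and emotion names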

In [8]:
def load_and_augment_data(user_id):
    """Load all recordings and create augmented training set."""
    data_dir = BASE_DATA_DIR / user_id
    if not data_dir.exists() or not any(data_dir.glob('*.csv')):
        print(f"Error: Data directory '{data_dir}' is empty or does not exist.")
        return None, None, None

    features, labels = [], []
    stride = int(0.25 * SAMPLING_RATE)  # 64-sample (0.25 s) step: consecutive windows overlap by 1.75 s
    noise_factor = 0.05

    print("Loading and preprocessing data...")
    for file_path in data_dir.rglob("*.csv"):
        # Skip noisy recordings
        if 'rejected' in file_path.parts: 
            continue
        
        # Extract emotion label from filename (e.g., "calm_1234567890.csv" -> "calm")
        label = file_path.stem.split('_')[0]
        data_df = pd.read_csv(file_path)
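        # (channels, samples) as C-contiguous float64, as BrainFlow's filters expect
        data = data_df[EEG_CHANNEL_NAMES].to_numpy(dtype=np.float64).T.copy()

        # Create overlapping windows (the +1 also includes the last full window)
        for i in range(0, data.shape[1] - WINDOW_SAMPLES + 1, stride):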
            window = data[:, i:i + WINDOW_SAMPLES].copy()
            
            # Apply random augmentation (50% chance)
            if np.random.rand() > 0.5:
                aug_type = np.random.randint(0, 3)
                if aug_type == 0:  # Noise injection
                    window += np.random.normal(0, np.std(window) * noise_factor, window.shape)
                elif aug_type == 1:  # Polarity reversal
                    window *= -1
                elif aug_type == 2:  # Amplitude scaling
                    window *= np.random.uniform(0.9, 1.1)

            # Apply bandpass and notch filters
            for ch in range(window.shape[0]):
                DataFilter.perform_bandpass(window[ch], SAMPLING_RATE, 1.0, 45.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)
                DataFilter.perform_bandstop(window[ch], SAMPLING_RATE, 48.0, 52.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)
            
            features.append(window)
            labels.append(label)

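    if not features:
        print("Error: no recording is long enough for a single 2-second window.")
        return None, None, None

    # Convert to numpy arrays
    X = np.array(features)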
    le = LabelEncoder()
    y = le.fit_transform(labels)
    
    # Reshape for EEGNet: (samples, 1, channels, timepoints)
    X = X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])

    print(f"Total samples: {X.shape[0]}")
    print(f"Classes: {le.classes_}")
    return X, y, le

### Example: Load Data

Load all recordings for user "yk":

In [9]:
X, y, le = load_and_augment_data("yk")
print(f"Data shape: {X.shape}")
print(f"Labels shape: {y.shape}")
print(f"Emotion classes: {le.classes_}")

Loading and preprocessing data...
Total samples: 13480
Classes: ['calm' 'focused' 'happy' 'stressed']
Data shape: (13480, 1, 4, 512)
Labels shape: (13480,)
Emotion classes: ['calm' 'focused' 'happy' 'stressed']


## Step 5: Training the EEGNet Model

EEGNet is a compact convolutional neural network designed specifically for EEG classification.

**Training Process:**
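1. **Data Split**: 64% train, 16% validation, 20% test. First, 20% of all windows are held out for testing; then 20% of the remainder is used for validation. Both splits are stratified by label
2. **Model**: EEGNet with a dropout rate of 0.75 (the class default is 0.5) to reduce overfitting
3. **Optimizer**: Adam with learning rate 0.001 and weight decay 1e-4 (L2 regularization), cross-entropy loss, batch size 32
4. **Early Stopping**: Trains for at most 100 epochs, saves the model whenever validation accuracy improves, and stops after 10 consecutive epochs without improvement
5. **Evaluation**: Reloads the best checkpoint and prints overall accuracy, a confusion matrix (rows: true class, columns: predicted class), and per-class accuracy on the test set. Because windows are split at random, overlapping windows from the same recording can appear in both training and test data, so this accuracy likely overestimates performance on new recording sessions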
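Splitting by recording instead of by window keeps all overlapping windows of one CSV file together in a single split. The NumPy sketch below is offered as an alternative, not as what the notebook does: it assumes a per-window recording id (`groups`), which `load_and_augment_data()` does not currently return, and `split_by_recording` is a hypothetical helper. scikit-learn's `GroupShuffleSplit` implements the same idea as a ready-made splitter.

```python
import numpy as np


def split_by_recording(groups, test_fraction=0.2, seed=42):
    """Return boolean (train_mask, test_mask) so that no recording appears in both splits.

    groups: array of shape (n_windows,) holding the recording id of each window.
    """
    rng = np.random.default_rng(seed)
    recordings = np.unique(groups)
    rng.shuffle(recordings)
    n_test = max(1, int(round(test_fraction * len(recordings))))
    test_mask = np.isin(groups, recordings[:n_test])
    return ~test_mask, test_mask


# Hypothetical example: 10 recordings with 113 windows each
groups = np.repeat(np.arange(10), 113)
train_mask, test_mask = split_by_recording(groups)
print(train_mask.sum(), test_mask.sum())                   # 904 226
print(set(groups[train_mask]) & set(groups[test_mask]))   # set()
```

With only a few recordings per emotion, a random group split can leave a class out of the test set; holding out one recording per class is a simple way to keep every emotion represented.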

**Model Architecture:**
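- Temporal convolution: learns frequency-selective filters along the time axis
- Depthwise convolution: learns spatial filters, i.e. weighted combinations of the 4 electrodes
- Separable convolution (depthwise temporal + pointwise): extracts and mixes higher-level temporal features
- Average pooling and dropout after each block to shorten the time axis and reduce overfitting
- Linear (dense) layer mapping the flattened features to class scores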
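For the default hyperparameters (`kernLength=64`, pooling factors 4 and 8, `F2=16`) the tensor shapes can be traced by hand, which shows where the classifier's input size comes from. The pure-Python sketch below applies the standard length formula for stride-1 convolution and non-overlapping pooling; `conv_len` and `pool_len` are hypothetical helpers, while `EEGNet._get_flattened_size()` arrives at the same number by running a dummy input through the blocks.

```python
def conv_len(n, kernel, padding):
    """Output length of a stride-1 convolution along one axis."""
    return n + 2 * padding - kernel + 1


def pool_len(n, k):
    """Output length of non-overlapping average pooling (PyTorch rounds down by default)."""
    return n // k


chans, samples = 4, 512
kern_length, F2 = 64, 16

# Block 1: temporal conv (1, 64) with padding 32 -> depthwise conv (Chans, 1) -> AvgPool (1, 4)
t = conv_len(samples, kern_length, kern_length // 2)  # 513: the even kernel adds one sample
h = conv_len(chans, chans, 0)                          # 1: the spatial filter spans all channels
t = pool_len(t, 4)                                      # 128

# Block 2: depthwise conv (1, 16) with padding 8 -> pointwise (1, 1) -> AvgPool (1, 8)
t = conv_len(t, 16, 16 // 2)                            # 129
t = pool_len(t, 8)                                      # 16

flattened = F2 * h * t
print(f"output of block 2: ({F2}, {h}, {t}) -> Linear input size {flattened}")
```

The final `nn.Linear` therefore maps 256 features to 4 class scores, which keeps the classifier small relative to the convolutional layers.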

In [10]:
def train(user_id):
    """Train EEGNet model on user's recordings."""
    user_model_dir = BASE_MODELS_DIR / user_id
    user_model_dir.mkdir(parents=True, exist_ok=True)
    
    # Load and prepare data
    X, y, le = load_and_augment_data(user_id)
    if X is None: 
        return

    # Save label encoder for later use in prediction
    joblib.dump(le, user_model_dir / "label_encoder_eegnet_4.pkl")
    
    # Split into train/validation/test sets
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.2, random_state=42, stratify=y_train_val
    )
    
    # Create PyTorch data loaders
    train_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).long()), 
        batch_size=32, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_val).float(), torch.from_numpy(y_val).long()), 
        batch_size=32
    )
    test_loader = DataLoader(
        TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test).long()), 
        batch_size=32
    )
    
    # Initialize EEGNet model
    model = EEGNet(
        nb_classes=len(le.classes_), 
        Chans=X_train.shape[2],  # 4 channels
        Samples=X_train.shape[3],  # 512 samples
        dropoutRate=0.75
    ).to(DEVICE)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

    # Training loop with early stopping
    best_val_accuracy = 0
    patience_counter = 0
    patience = 10
    
    print(f"\nTraining EEGNet on device: {DEVICE}")
    for epoch in range(100):
        # Training phase
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Validation phase
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_accuracy = 100 * val_correct / val_total
        print(f'Epoch {epoch+1:02d}.. Val Accuracy: {val_accuracy:.2f}%')

        # Save best model and check for early stopping
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            patience_counter = 0
            torch.save(model.state_dict(), user_model_dir / "emotion_model_eegnet_4.pth")
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Test on held-out test set
    print("\nTesting on final data...")
    model.load_state_dict(torch.load(user_model_dir / "emotion_model_eegnet_4.pth"))
    model.eval()
    
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate and display metrics
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    final_accuracy = 100 * np.sum(all_predictions == all_labels) / len(all_labels)
    
    print(f"\nFINAL ACCURACY: {final_accuracy:.2f}%")
    
    # Print confusion matrix (rows: true class, columns: predicted class)
    cm = confusion_matrix(all_labels, all_predictions)
    class_names = le.classes_
    
    print(f"\n{'':>12}", end="")
    for name in class_names:
        print(f"{name:>12}", end="")
    print()
    
    for i, name in enumerate(class_names):
        print(f"{name:>12}", end="")
        for j in range(len(class_names)):
            print(f"{cm[i,j]:>12}", end="")
        print()
    
    # Per-class accuracy (recall: share of each true class that was predicted correctly)
    print()
    for i, class_name in enumerate(class_names):
        class_mask = all_labels == i
        class_correct = np.sum((all_predictions == all_labels) & class_mask)
        class_total = np.sum(class_mask)
        class_accuracy = 100 * class_correct / class_total if class_total > 0 else 0
        print(f"{class_name.upper():>10}: {class_accuracy:>6.2f}%")
    
    print(f"\nModel saved to {user_model_dir / 'emotion_model_eegnet_4.pth'}")

### Example: Train Model

Train the model on all recordings of user "yk":

In [11]:
train("yk")

Loading and preprocessing data...
Total samples: 13480
Classes: ['calm' 'focused' 'happy' 'stressed']

Training EEGNet on device: mps
Epoch 01.. Val Accuracy: 40.52%
Epoch 02.. Val Accuracy: 46.92%
Epoch 03.. Val Accuracy: 51.27%
Epoch 04.. Val Accuracy: 76.68%
Epoch 05.. Val Accuracy: 84.89%
Epoch 06.. Val Accuracy: 86.00%
Epoch 07.. Val Accuracy: 86.32%
Epoch 08.. Val Accuracy: 85.81%
Epoch 09.. Val Accuracy: 87.34%
Epoch 10.. Val Accuracy: 90.03%
Epoch 11.. Val Accuracy: 89.66%
Epoch 12.. Val Accuracy: 90.77%
Epoch 13.. Val Accuracy: 90.26%
Epoch 14.. Val Accuracy: 91.19%
Epoch 15.. Val Accuracy: 90.59%
Epoch 16.. Val Accuracy: 90.54%
Epoch 17.. Val Accuracy: 90.77%
Epoch 18.. Val Accuracy: 90.91%
Epoch 19.. Val Accuracy: 90.77%
Epoch 20.. Val Accuracy: 91.33%
Epoch 21.. Val Accuracy: 91.52%
Epoch 22.. Val Accuracy: 91.10%
Epoch 23.. Val Accuracy: 91.24%
Epoch 24.. Val Accuracy: 92.07%
Epoch 25.. Val Accuracy: 91.24%
Epoch 26.. Val Accuracy: 91.79%
Epoch 27.. Val Accuracy: 92.63%

## Step 6: Live Prediction with Visualization

The `predict()` function performs real-time emotion classification with:

**Real-time Pipeline:**
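1. **Data Acquisition**: Continuously read EEG from the Muse 2
2. **Visualization**: Plot the filtered 4-channel EEG in real time (5-second rolling window)
3. **Prediction**: Every 2 seconds, filter the most recent 2-second window exactly as in training (1-45 Hz bandpass, 48-52 Hz bandstop) and classify it
4. **State Machine**: Keep the last 10 predictions (about 20 s of history) and switch the stable state, which starts as `neutral`, only when at least 70% of them (7 of 10) agree; this suppresses jitter
5. **MQTT Publishing**: Publish each state change to the robot via MQTT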
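The voting rule can be tested without a headset. A minimal sketch of the same logic as a small class; `StableState` is a hypothetical name, and `predict()` inlines this code using `np.unique` rather than `collections.Counter`. The two pick a different winner only on ties, and a tie can never reach the 70% threshold, so they produce the same state changes.

```python
from collections import Counter, deque


class StableState:
    """Majority-vote smoother: switch state only when enough recent predictions agree."""

    def __init__(self, size=10, threshold=0.7, initial="neutral"):
        self.buffer = deque(maxlen=size)
        self.threshold = threshold
        self.state = initial

    def update(self, label):
        """Add one raw prediction; return the new state if it changed, else None."""
        self.buffer.append(label)
        if len(self.buffer) < self.buffer.maxlen:
            return None  # wait until the buffer is full
        majority_label, count = Counter(self.buffer).most_common(1)[0]
        if majority_label != self.state and count / self.buffer.maxlen >= self.threshold:
            self.state = majority_label
            return self.state
        return None


sm = StableState()
stream = ["calm"] * 7 + ["happy"] * 7
changes = [(i, sm.update(label)) for i, label in enumerate(stream)]
print([c for c in changes if c[1] is not None])  # [(9, 'calm'), (13, 'happy')]
```

When the state becomes `calm` at index 9, the buffer already holds three `happy` votes, so four more (indices 10 to 13) are needed before the state switches; at one prediction every 2 s, that is 8 s of consistent output.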

**Outputs:**
- Live matplotlib plot showing EEG signals
- Console output showing predicted states
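- MQTT messages to the `/set_state` topic on each state change, with format `{"state": "calm", "confidence": 0.85}`; `confidence` is the softmax probability of the most recent raw prediction, not a measure of the vote
- Newly received, filtered EEG samples published to the `/eeg` topic on every loop iteration, as one list per channel (keys `TP9`, `AF7`, `AF8`, `TP10`)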
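Both payloads are plain JSON, so a consumer such as the robot only needs `json.loads`. A sketch of building and decoding them with the standard library; the helper names are hypothetical, while the keys follow `predict()`. In a paho-mqtt `on_message` callback, the raw bytes arrive as `msg.payload` and can be passed to `json.loads` directly.

```python
import json

EEG_CHANNEL_NAMES = ["TP9", "AF7", "AF8", "TP10"]


def state_payload(state, confidence):
    """JSON body published on /set_state when the stable state changes."""
    return json.dumps({"state": state, "confidence": float(confidence)})


def eeg_payload(samples_per_channel):
    """JSON body published on /eeg: one list of filtered samples per channel."""
    return json.dumps({name: [float(v) for v in samples]
                       for name, samples in zip(EEG_CHANNEL_NAMES, samples_per_channel)})


# Robot side: decode a /set_state message
msg = json.loads(state_payload("calm", 0.85))
print(msg["state"], msg["confidence"])  # calm 0.85

eeg = json.loads(eeg_payload([[1.0, 2.0], [0.5, 0.25], [0.0, -1.0], [3.0, 4.0]]))
print(sorted(eeg))  # ['AF7', 'AF8', 'TP10', 'TP9']
```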

In [12]:
def predict(user_id, mqtt_ip="192.168.1.13"):
    """Run live emotion classification with real-time visualization."""
    print(f"Using device: {DEVICE}")
    
    # Setup MQTT connection (optional)
    mqtt_connected = False
    client = None
    
    if mqtt_ip:
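        MQTT_TOPIC = "/set_state"
        # paho-mqtt >= 2.0 requires an explicit callback API version; 1.x has no such argument
        if hasattr(mqtt, "CallbackAPIVersion"):
            client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
        else:
            client = mqtt.Client()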
        
        try:
            client.connect(mqtt_ip, 1883, 60)
            client.loop_start() 
            print("MQTT Connected")
            mqtt_connected = True
        except Exception as e:
            print(f"MQTT Failed: {e}")
    else:
        print("Running without MQTT")

    # Load trained model and label encoder
    user_model_dir = BASE_MODELS_DIR / user_id
    le_path = user_model_dir / "label_encoder_eegnet_4.pkl"
    model_path = user_model_dir / "emotion_model_eegnet_4.pth"
    
    if not le_path.exists() or not model_path.exists():
        print("Error: Model or label encoder not found")
        return
    
    le = joblib.load(le_path)
    model = EEGNet(nb_classes=len(le.classes_), Chans=4, Samples=WINDOW_SAMPLES).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    print("Model loaded")

    # Setup BrainFlow board
    params = BrainFlowInputParams()
    board = BoardShim(BOARD_ID, params)

    # State machine for stable predictions
    prediction_buffer = deque(maxlen=10)  # Last 10 predictions
    state_threshold = 0.7  # 70% majority needed to change state
    current_stable_state = "neutral"
    
    # Setup visualization (5-second rolling window)
    seconds_to_show = 5
    buffer_size = seconds_to_show * SAMPLING_RATE
    data_buffers = np.zeros((len(EEG_CHANNELS_LIST), buffer_size))
    
    plt.ion()
    fig, axs = plt.subplots(len(EEG_CHANNELS_LIST), 1, figsize=(12, 8), sharex=True)
    lines = [ax.plot(np.zeros(buffer_size))[0] for ax in axs]
    
    for i, ax in enumerate(axs):
        ax.set_title(EEG_CHANNEL_NAMES[i])
        ax.set_ylim(-200, 200)  # Typical EEG range in µV
        ax.grid(True)
    axs[-1].set_xlabel('Time (s)')
    fig.text(0.06, 0.5, 'Voltage (uV)', va='center', rotation='vertical')
    
    time_axis = np.arange(-seconds_to_show, 0, 1.0/SAMPLING_RATE)

    # Color map for different states
    color_map = {
        'calm': '#3498db',
        'focused': '#f39c12',
        'happy': '#2ecc71',
        'stressed': '#e74c3c',
        'neutral': '#95a5a6'
    }

    # Prediction timing
    last_prediction_time = time.time()
    prediction_interval = WINDOW_SEC  # Predict every 2 seconds
    prediction_label = "INIT"
    confidence = 0.0

    try:
        board.prepare_session()
        board.start_stream(45000)
        time.sleep(2)  # Let buffer fill
        print("Starting live prediction...")
        
        while True:
            # Get new data from Muse
            new_data = board.get_board_data()
            
            if new_data.shape[1] > 0:
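                # Keep at most buffer_size samples in case more arrived than the rolling buffer holds
                new_eeg_data = new_data[EEG_CHANNELS_LIST][:, -buffer_size:]
                num_new_samples = new_eeg_data.shape[1]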
                
                # Update rolling buffer
                data_buffers = np.roll(data_buffers, -num_new_samples, axis=1)
                data_buffers[:, -num_new_samples:] = new_eeg_data
                
                # Filter for visualization
                viz_buffers = data_buffers.copy()
                for ch in range(viz_buffers.shape[0]):
                    DataFilter.perform_bandpass(viz_buffers[ch], SAMPLING_RATE, 1.0, 45.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)
                    DataFilter.perform_bandstop(viz_buffers[ch], SAMPLING_RATE, 48.0, 52.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)

                # Update plots
                for i, line in enumerate(lines):
                    line.set_ydata(viz_buffers[i])
                    if line.get_xdata().shape[0] != time_axis.shape[0]:
                        line.set_xdata(time_axis)
                
                for ax in axs:
                    ax.relim()
                    ax.autoscale_view(scalex=False, scaley=True)

                # Publish filtered EEG via MQTT
                if mqtt_connected:
                    new_filtered_data = viz_buffers[:, -num_new_samples:]
                    eeg_payload = {
                        name: new_filtered_data[i].tolist() 
                        for i, name in enumerate(EEG_CHANNEL_NAMES)
                    }
                    client.publish("/eeg", json.dumps(eeg_payload))

                # Run prediction every 2 seconds
                if (time.time() - last_prediction_time) > prediction_interval:
                    last_prediction_time = time.time()
                    
                    # Extract 2-second window for model
                    analysis_window = np.ascontiguousarray(data_buffers[:, -WINDOW_SAMPLES:])
                    
                    if analysis_window.shape[1] == WINDOW_SAMPLES:
                        # Filter the analysis window
                        for ch in range(analysis_window.shape[0]):
                            DataFilter.perform_bandpass(analysis_window[ch], SAMPLING_RATE, 1.0, 45.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)
                            DataFilter.perform_bandstop(analysis_window[ch], SAMPLING_RATE, 48.0, 52.0, 4, FilterTypes.BUTTERWORTH_ZERO_PHASE, 0)

                        # Run EEGNet inference
                        input_tensor = torch.tensor(analysis_window, dtype=torch.float32).reshape(1, 1, 4, WINDOW_SAMPLES).to(DEVICE)
                        with torch.no_grad():
                            outputs = model(input_tensor)
                            probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
                            pred_idx = np.argmax(probs)
                            prediction_label = le.inverse_transform([pred_idx])[0]
                            confidence = probs[pred_idx]

                        # State machine with majority voting
                        prediction_buffer.append(prediction_label)
                        if len(prediction_buffer) == prediction_buffer.maxlen:
                            unique_labels, counts = np.unique(list(prediction_buffer), return_counts=True)
                            majority_label = unique_labels[np.argmax(counts)]
                            majority_percent = np.max(counts) / prediction_buffer.maxlen

                            # Change state only if 70%+ of buffer agrees
                            if majority_label != current_stable_state and majority_percent >= state_threshold:
                                current_stable_state = majority_label
                                
                                # Publish state change via MQTT (if connected), then log it
                                if mqtt_connected:
                                    payload = json.dumps({
                                        "state": current_stable_state,
                                        "confidence": float(confidence)
                                    })
                                    client.publish(MQTT_TOPIC, payload)
                                print(f"\n>>> STATE: {current_stable_state.upper()} ({confidence*100:.0f}%)")

                # Update plot title with current state
                buffer_viz = " ".join([l[:3].upper() for l in prediction_buffer])
                title_str = (
                    f"ROBOT STATE: {current_stable_state.upper()}\n"
                    f"Buffer: [{buffer_viz}] | Raw: {prediction_label.upper()} ({confidence*100:.0f}%)"
                )
                fig.suptitle(
                    title_str,
                    fontsize=16,
                    fontweight='bold',
                    color=color_map.get(current_stable_state, '#000000')
                )

                plt.pause(0.001)
            else:
                time.sleep(0.01)

    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        if board.is_prepared():
            board.stop_stream()
            board.release_session()
        if mqtt_connected and client:
            client.loop_stop()
            client.disconnect()
        plt.ioff()
        plt.show()

### Example: Run Live Prediction

Start real-time emotion classification (requires Muse 2 to be connected):

In [14]:
# predict(user_id="yk", mqtt_ip="192.168.1.13")

## Summary

This notebook demonstrated the complete EEG emotion classification pipeline:

1. **Recording**: Captured EEG data for different emotional states with quality checking
2. **Preprocessing**: Loaded data, applied augmentation, and filtered signals
3. **Training**: Trained EEGNet model with early stopping and evaluated performance
4. **Prediction**: Performed real-time classification with state machine filtering and MQTT integration

### Key Design Decisions:

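- **2-second windows**: Balance temporal resolution against having enough data per classification; 512 samples cover two full cycles of the slowest (1 Hz) component kept by the bandpass filter
- **Overlapping windows**: Increase the number of training examples without collecting more recordings, although neighboring windows share most of their samples and are strongly correlated
- **Data augmentation**: Helps the model generalize to variations in signal amplitude, polarity, and noise
- **State machine**: Prevents jitter by requiring consistent predictions before a state change
- **Bandpass filtering**: Removes slow drift and attenuates high-frequency noise while keeping the 1-45 Hz band (delta through low gamma); artifacts inside that band, such as eye blinks, are not removed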

### Next Steps:

- Record more data for improved accuracy
- Experiment with different emotions
- Try different window sizes or models
- Integrate with your robot for emotion-driven behaviors!