In [1]:
import os
import io

import numpy as np
import pandas as pd
import scipy.io as sio
from google.cloud import storage




In [2]:
# Object to fetch: GCS_OBJECT_MAIN when it is set, otherwise subject S001.
mat_object = os.getenv('GCS_OBJECT_MAIN', 'Wearable SSVEP Dataset/S001.mat')
# The subject id is the object's file name without its extension.
mat_filename = os.path.basename(mat_object)
subject_id = os.path.splitext(mat_filename)[0]

# Service-account key file and bucket name both come from the environment.
credentials_path = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
bucket_name = os.getenv('GCP_BUCKET_NAME')
client = storage.Client.from_service_account_json(credentials_path)
bucket = client.bucket(bucket_name)

print(f'Loading {mat_object} for subject {subject_id}')
# Download the whole object into memory and parse it as a MATLAB file.
blob = bucket.blob(mat_object)
data_bytes = blob.download_as_bytes()
mat = sio.loadmat(io.BytesIO(data_bytes))

# 'data' is expected to be (channels, time, electrode, block, target).
data = mat['data']
print('Loaded data shape:', data.shape)

Loading Wearable SSVEP Dataset/S001.mat for subject S001
Loaded data shape: (8, 710, 2, 10, 12)


In [3]:
# Flatten one subject's recording into an epoch table: one row for every
# (electrode, block, target) combination, each carrying its (8, 710) signal.
expected_shape = (8, 710, 2, 10, 12)
if data.shape != expected_shape:
    raise ValueError(f'Unexpected data shape {data.shape}, expected {expected_shape}')

electrode_map = {0: 'wet', 1: 'dry'}
n_electrodes, n_blocks, n_targets = data.shape[2:]

# Row order: electrode varies slowest, target fastest.
rows = [
    {
        'subject': subject_id,
        'electrode': electrode_map.get(e_idx, str(e_idx)),
        'block': b_idx + 1,   # stored 1-based
        'target': t_idx + 1,  # stored 1-based
        'signal': data[:, :, e_idx, b_idx, t_idx],  # (8, 710) slice of `data`
    }
    for e_idx in range(n_electrodes)
    for b_idx in range(n_blocks)
    for t_idx in range(n_targets)
]

all_epochs = pd.DataFrame(rows)
print('DataFrame shape:', all_epochs.shape)
print(all_epochs.head())
print('Electrode counts:')
electrode_counts = all_epochs['electrode'].value_counts()
print(electrode_counts)

DataFrame shape: (240, 5)
  subject electrode  block  target  \
0    S001       wet      1       1   
1    S001       wet      1       2   
2    S001       wet      1       3   
3    S001       wet      1       4   
4    S001       wet      1       5   

                                              signal  
0  [[-52325.52005800724, -53157.22841961553, -554...  
1  [[-56806.91044371786, -58819.819022406286, -60...  
2  [[-58424.81901793594, -57290.15513003797, -551...  
3  [[-59719.47668310861, -61462.01857698074, -609...  
4  [[-57667.36314698238, -60079.16093344456, -608...  
Electrode counts:
electrode
wet    120
dry    120
Name: count, dtype: int64


In [4]:
# Build full multi-subject epoch DataFrame (102 subjects x 240 epochs)

import re


# Bucket handles keyed by (credentials path, bucket name). Reusing a handle
# avoids building a new authenticated client for every file that is downloaded.
_BUCKET_CACHE = {}


def _get_default_bucket():
    """Return the bucket named by the environment, creating its client only once."""
    key = (os.getenv('GOOGLE_APPLICATION_CREDENTIALS'), os.getenv('GCP_BUCKET_NAME'))
    if key not in _BUCKET_CACHE:
        client = storage.Client.from_service_account_json(key[0])
        _BUCKET_CACHE[key] = client.bucket(key[1])
    return _BUCKET_CACHE[key]


def load_mat_from_gcs(filepath, bucket=None, expected_shape=(8, 710, 2, 10, 12)):
    """Download a .mat from GCS and return the 'data' array.

    Parameters
    ----------
    filepath : str
        Object name inside the bucket.
    bucket : optional
        Object exposing ``blob(name)`` whose result has ``download_as_bytes()``.
        When omitted, the bucket is built from the GOOGLE_APPLICATION_CREDENTIALS
        and GCP_BUCKET_NAME environment variables and cached for later calls.
    expected_shape : tuple or None
        Shape the 'data' array must have. Pass ``None`` to skip the check.

    Returns
    -------
    numpy.ndarray
        The array stored under the 'data' variable of the .mat file.

    Raises
    ------
    KeyError
        If the file has no 'data' variable.
    ValueError
        If the array's shape differs from ``expected_shape``.
    """
    if bucket is None:
        bucket = _get_default_bucket()

    blob = bucket.blob(filepath)
    bytes_data = blob.download_as_bytes()
    mat = sio.loadmat(io.BytesIO(bytes_data))
    if 'data' not in mat:
        raise KeyError(f"'data' variable missing in {filepath}")

    arr = mat['data']
    if expected_shape is not None and arr.shape != tuple(expected_shape):
        raise ValueError(f"Unexpected shape {arr.shape} in {filepath}")
    return arr


def build_subject_df(subject_id, data):
    """Flatten one subject's 5-D recording into an epoch-level DataFrame.

    `data` must have shape (8, 710, 2, 10, 12); otherwise ValueError is raised.
    The result has 240 rows (2 electrodes x 10 blocks x 12 targets) with the
    columns subject, electrode ('wet' for index 0, 'dry' for index 1),
    block (1-based), target (1-based) and signal (the (8, 710) slice of `data`).
    Rows are ordered with electrode varying slowest and target fastest.
    """
    expected = (8, 710, 2, 10, 12)
    if data.shape != expected:
        raise ValueError(f"Unexpected shape {data.shape}, expected {expected}")

    n_electrodes, n_blocks, n_targets = expected[2:]
    electrode_names = ('wet', 'dry')

    records = [
        {
            'subject': subject_id,
            'electrode': electrode_names[electrode],
            'block': block + 1,
            'target': target + 1,
            'signal': data[:, :, electrode, block, target],
        }
        for electrode in range(n_electrodes)
        for block in range(n_blocks)
        for target in range(n_targets)
    ]
    return pd.DataFrame(records)


# Enumerate the subject recordings stored under the dataset prefix. Only
# objects whose base name looks like "S<digits>.mat" are kept, which leaves
# out auxiliary files such as Impedance.mat.
client = storage.Client.from_service_account_json(
    os.getenv('GOOGLE_APPLICATION_CREDENTIALS')
)
bucket = client.bucket(os.getenv('GCP_BUCKET_NAME'))
prefix = 'Wearable SSVEP Dataset/'

mat_files = []
for gcs_blob in bucket.list_blobs(prefix=prefix):
    base = os.path.basename(gcs_blob.name)
    is_subject_file = base.endswith('.mat') and re.match(r'S\d+\.mat$', base)
    if is_subject_file:
        mat_files.append(gcs_blob.name)
mat_files.sort()
print(f'Found {len(mat_files)} subject .mat files')

# Load each subject and collect its epoch table. A file whose 'data' variable
# is missing or has the wrong shape is reported and skipped.
all_subject_dfs = []
for fp in mat_files:
    subject_id = os.path.splitext(os.path.basename(fp))[0]
    try:
        data_arr = load_mat_from_gcs(fp)
    except (KeyError, ValueError) as exc:
        print(f'Skipping {fp}: {exc}')
    else:
        all_subject_dfs.append(build_subject_df(subject_id, data_arr))

full_df = pd.concat(all_subject_dfs, ignore_index=True)

# Quick sanity summaries of the combined table.
print('full_df shape:', full_df.shape)
print(full_df.head())
for label, column in (('Subject counts:', 'subject'), ('Electrode counts:', 'electrode')):
    print(label)
    print(full_df[column].value_counts())


Found 102 subject .mat files
full_df shape: (24480, 5)
  subject electrode  block  target  \
0    S001       wet      1       1   
1    S001       wet      1       2   
2    S001       wet      1       3   
3    S001       wet      1       4   
4    S001       wet      1       5   

                                              signal  
0  [[-52325.52005800724, -53157.22841961553, -554...  
1  [[-56806.91044371786, -58819.819022406286, -60...  
2  [[-58424.81901793594, -57290.15513003797, -551...  
3  [[-59719.47668310861, -61462.01857698074, -609...  
4  [[-57667.36314698238, -60079.16093344456, -608...  
Subject counts:
subject
S001    240
S065    240
S075    240
S074    240
S073    240
       ... 
S032    240
S031    240
S030    240
S029    240
S102    240
Name: count, Length: 102, dtype: int64
Electrode counts:
electrode
wet    12240
dry    12240
Name: count, dtype: int64


In [5]:
import numpy as np
from scipy.signal import butter, filtfilt, iirnotch

# Default `fs` for bandpass() and notch() below. The fixed sample offsets used
# in preprocess_epoch (125, 160, 660) are written for this sampling rate.
FS = 250  # sampling rate in Hz (given by the dataset paper)


def bandpass(epoch, low=8, high=90, fs=FS, order=4):
    """
    Butterworth band-pass applied forward and backward (filtfilt) along the last axis.

    epoch: array of shape (n_channels, n_samples)
    low, high: pass-band edges in Hz
    fs: sampling rate in Hz
    order: order passed to scipy.signal.butter
    Returns the filtered array, same shape as `epoch`.
    """
    half_rate = fs / 2
    # butter() takes the band edges as fractions of the Nyquist frequency here.
    edges = [low / half_rate, high / half_rate]
    numer, denom = butter(order, edges, btype="band")
    return filtfilt(numer, denom, epoch, axis=-1)


def notch(epoch, freq=50, fs=FS, q=30):
    """
    IIR notch applied forward and backward (filtfilt) along the last axis,
    suppressing a narrow band around `freq` Hz (e.g. 50 Hz mains noise).

    freq: centre frequency of the notch in Hz
    fs: sampling rate in Hz
    q: quality factor handed to scipy.signal.iirnotch
    """
    # iirnotch() takes the centre frequency as a fraction of Nyquist here.
    w0 = freq / (fs / 2)
    numer, denom = iirnotch(w0, Q=q)
    return filtfilt(numer, denom, epoch, axis=-1)


def preprocess_epoch(epoch_raw):
    """
    Run the complete preprocessing chain on one trial.

    Input:  epoch_raw shape (8, 710)
    Output: preprocessed epoch shape (8, 500)

    Steps, in order: cast to float64, band-pass 8–90 Hz, notch at 50 Hz and
    100 Hz, subtract the mean of the first 125 samples, keep samples 160:660,
    then z-score every channel over the kept window.
    """
    baseline_len = 125               # 0.5 s at 250 Hz
    win_start, win_stop = 160, 660   # 0.5 s pre-stimulus + ~0.14 s visual delay, then 2 s

    # Filtering is done in float64 on the full-length epoch.
    filtered = bandpass(epoch_raw.astype(np.float64), low=8, high=90, fs=FS, order=4)
    for mains_freq in (50, 100):     # mains hum and its first harmonic
        filtered = notch(filtered, freq=mains_freq, fs=FS, q=30)

    # Baseline correction with the pre-stimulus segment.
    filtered = filtered - filtered[:, :baseline_len].mean(axis=-1, keepdims=True)

    # Keep the main SSVEP window: shape (8, 500).
    window = filtered[:, win_start:win_stop]

    # Per-channel z-score; the small constant guards against division by zero.
    centered = window - window.mean(axis=-1, keepdims=True)
    scale = window.std(axis=-1, keepdims=True) + 1e-8
    return centered / scale


In [6]:
# Preprocess every raw epoch and store the result next to it.
# NOTE: pp_full_df is a second name for full_df (no copy is made), so the new
# 'signal_pp' column is visible through both names.
pp_full_df = full_df
pp_full_df['signal_pp'] = pp_full_df['signal'].apply(preprocess_epoch)
pp_full_df

Unnamed: 0,subject,electrode,block,target,signal,signal_pp
0,S001,wet,1,1,"[[-52325.52005800724, -53157.22841961553, -554...","[[1.70329534663277, 1.005246832872975, -0.6968..."
1,S001,wet,1,2,"[[-56806.91044371786, -58819.819022406286, -60...","[[2.1173829193890454, 0.23705557026530585, -1...."
2,S001,wet,1,3,"[[-58424.81901793594, -57290.15513003797, -551...","[[-3.4062449388976734, -0.6271521744047096, 3...."
3,S001,wet,1,4,"[[-59719.47668310861, -61462.01857698074, -609...","[[0.2641990410447794, -0.39355795229029994, -0..."
4,S001,wet,1,5,"[[-57667.36314698238, -60079.16093344456, -608...","[[0.633041792724005, -0.27745217316847137, -0...."
...,...,...,...,...,...,...
24475,S102,dry,10,8,"[[10971.978960751234, 11017.442406263495, 1104...","[[-1.371439566731729, -1.0397680194580703, -2...."
24476,S102,dry,10,9,"[[11072.293583887433, 11040.330591221487, 1100...","[[0.19956229808600792, 0.5625620121231513, -1...."
24477,S102,dry,10,10,"[[11305.064636770763, 11329.919775123582, 1130...","[[-0.25180437619401635, -0.1536912581035672, 0..."
24478,S102,dry,10,11,"[[11089.77264700965, 11050.031247736886, 11049...","[[1.9501752759802595, 1.570491017693854, 1.201..."


In [7]:
# Keep only the epochs recorded with dry electrodes.
mask = pp_full_df['electrode'].eq('dry')
df_dry = pp_full_df[mask].reset_index(drop=True)

# Feature tensor: preprocessed epochs stacked to (n_trials, n_channels, n_samples).
signal_column = df_dry['signal_pp'].values
X = np.stack(signal_column)

# Labels: the dataset numbers targets 1..12; shift to 0..11 so they can index arrays.
target_column = df_dry['target'].values
y = target_column - 1

# Subject id per trial, kept for optional per-subject analysis.
subjects = df_dry['subject'].values

print("X shape:", X.shape)
print("y shape:", y.shape, "unique targets:", np.unique(y))

X shape: (12240, 8, 500)
y shape: (12240,) unique targets: [ 0  1  2  3  4  5  6  7  8  9 10 11]


In [8]:

FS = 250          # sampling rate (Hz); repeats the value set in the preprocessing cell
N_HARMONICS = 5   # number of harmonics to use in references (one sin/cos pair each)

# Target-wise stimulation frequency and phase, per the inline note taken from
# stimulation_information.pdf. Index k of each array describes dataset
# target k + 1, i.e. label k once targets are shifted to 0-based.

# Frequency of each target in Hz; the character after the arrow is the key label.
FREQ_PER_TARGET = np.array([
    9.25,   # target 1  → '1'
    11.25,  # target 2  → '2'
    13.25,  # target 3  → '3'
    9.75,   # target 4  → '4'
    11.75,  # target 5  → '5'
    13.75,  # target 6  → '6'
    10.25,  # target 7  → '7'
    12.25,  # target 8  → '8'
    14.25,  # target 9  → '9'
    10.75,  # target 10 → '0'
    12.75,  # target 11 → '*'
    14.75   # target 12 → '#'
], dtype=float)

# Phase of each target in radians (multiples of 0.5π, constant within each group of three).
PHASE_PER_TARGET = np.array([
    0.0 * np.pi,  # target 1
    0.0 * np.pi,  # target 2
    0.0 * np.pi,  # target 3
    0.5 * np.pi,  # target 4
    0.5 * np.pi,  # target 5
    0.5 * np.pi,  # target 6
    1.0 * np.pi,  # target 7
    1.0 * np.pi,  # target 8
    1.0 * np.pi,  # target 9
    1.5 * np.pi,  # target 10
    1.5 * np.pi,  # target 11
    1.5 * np.pi   # target 12
], dtype=float)

In [9]:
def make_ref_signals_targetwise(freqs, phases, n_samples, fs=FS, n_harmonics=N_HARMONICS):
    """
    Create the sinusoidal reference matrix of every target from its own
    frequency and phase.

    For harmonic h = 1..n_harmonics the rows sin(2π·h·f·t + φ) and
    cos(2π·h·f·t + φ) are added, in that order, with t = arange(n_samples) / fs.
    The same phase φ is used for every harmonic.

    Returns: list of length n_targets
        refs[k] has shape (2 * n_harmonics, n_samples)
    """
    t = np.arange(n_samples) / fs
    harmonics = range(1, n_harmonics + 1)

    refs = []
    for f, phi in zip(freqs, phases):
        rows = []
        for h in harmonics:
            # Shared argument of the sin/cos pair for this harmonic.
            arg = 2 * np.pi * h * f * t + phi
            rows.extend((np.sin(arg), np.cos(arg)))
        refs.append(np.stack(rows, axis=0))  # (2 * n_harmonics, n_samples)
    return refs

In [10]:
from scipy.linalg import svd
from sklearn.model_selection import StratifiedKFold

# Filter-bank definition: five contiguous (low, high) pass-bands in Hz that
# together span 8–50 Hz. The first four are 8 Hz wide; the last one (40–50 Hz)
# is 10 Hz wide. fbcca_classify passes each pair to bandpass() as low/high.
FILTER_BANKS = [
    (8, 16),
    (16, 24),
    (24, 32),
    (32, 40),
    (40, 50),
]


def cca_on_trial(X_trial, Y_ref, reg=1e-6):
    """Largest canonical correlation between a trial and a reference set.

    X_trial and Y_ref are 2-D arrays with variables in rows and samples in
    columns (same number of columns). `reg` is a ridge term added to both
    auto-covariances and also used as the floor for their eigenvalues.
    Returns a Python float clipped to [0, 1].
    """
    def _whitener(cov):
        # Symmetric inverse square root via eigen-decomposition.
        eigvals, eigvecs = np.linalg.eigh(cov)
        floored = np.clip(eigvals, reg, None)
        return eigvecs @ np.diag(1.0 / np.sqrt(floored)) @ eigvecs.T

    # Remove the mean of every row.
    x_centered = X_trial - X_trial.mean(axis=1, keepdims=True)
    y_centered = Y_ref - Y_ref.mean(axis=1, keepdims=True)
    n = x_centered.shape[1]

    # Sample covariances; the ridge keeps the auto-covariances well conditioned.
    cov_xx = (x_centered @ x_centered.T) / (n - 1) + reg * np.eye(x_centered.shape[0])
    cov_yy = (y_centered @ y_centered.T) / (n - 1) + reg * np.eye(y_centered.shape[0])
    cov_xy = (x_centered @ y_centered.T) / (n - 1)

    # The singular values of the whitened cross-covariance are the canonical
    # correlations; svd returns them in descending order.
    whitened_cross = _whitener(cov_xx) @ cov_xy @ _whitener(cov_yy)
    singular_values = svd(whitened_cross, full_matrices=False)[1]
    return float(np.clip(singular_values[0], 0.0, 1.0))


def fbcca_classify(X_trial, target_refs, filter_bank_params=FILTER_BANKS, fs=FS):
    """Predict the target index of one preprocessed trial with filter-bank CCA.

    For every (low, high) band in `filter_bank_params` the trial is band-pass
    filtered, notched at 50 Hz and 100 Hz, and correlated (top canonical
    correlation) with each reference in `target_refs`. Squared correlations
    are combined across bands with weights n ** -1.25 + 0.25 (n = 1-based band
    position), and the index of the highest combined score is returned.
    """
    band_rhos = []
    for f_low, f_high in filter_bank_params:
        sub_band = bandpass(X_trial, low=f_low, high=f_high, fs=fs, order=4)
        for mains_freq in (50, 100):
            sub_band = notch(sub_band, freq=mains_freq, fs=fs, q=30)
        # One correlation per target for this band.
        band_rhos.append([cca_on_trial(sub_band, ref) for ref in target_refs])

    rho_matrix = np.asarray(band_rhos)  # (n_bands, n_targets)
    n_bands = len(filter_bank_params)
    weights = np.array([(k ** -1.25) + 0.25 for k in range(1, n_bands + 1)])
    scores = (weights[:, None] * (rho_matrix ** 2)).sum(axis=0)
    return int(np.argmax(scores))


In [11]:
# One reference matrix per target, built for 500-sample epochs (the length
# preprocess_epoch returns).
ref_signals = make_ref_signals_targetwise(
    freqs=FREQ_PER_TARGET, phases=PHASE_PER_TARGET, n_samples=500, fs=FS
)

In [12]:
# Split into train+val and test sets (80/20 stratified) used by BOTH models
from sklearn.model_selection import StratifiedShuffleSplit

# Single stratified 80/20 shuffle split; the fixed random_state keeps the
# held-out indices identical from run to run.
# NOTE(review): the split is made over individual trials and does not use
# `subjects`, so trials of one subject can land on both sides.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2025)
outer_splits = list(sss.split(X, y))
trainval_idx, test_idx = outer_splits[0]

X_trainval = X[trainval_idx]
y_trainval = y[trainval_idx]
X_test = X[test_idx]
y_test = y[test_idx]

print("Train+Val:", X_trainval.shape)
print("Test:", X_test.shape)


Train+Val: (9792, 8, 500)
Test: (2448, 8, 500)


In [13]:
# Classify every held-out trial with FBCCA (no training involved) and score it.
fbcca_test_preds = []
for trial in X_test:
    fbcca_test_preds.append(fbcca_classify(trial, ref_signals))

fbcca_hits = np.array(fbcca_test_preds) == y_test
fbcca_test_acc = fbcca_hits.mean()

print(f"FBCCA test accuracy: {fbcca_test_acc:.4f}")

FBCCA test accuracy: 0.7594


In [14]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# -----------------------------
# Compact EEGNet-like CNN model
# -----------------------------
class EEGNet(nn.Module):
    """Compact EEGNet-style CNN for inputs shaped (batch, 1, n_chans, n_samples).

    Pipeline:
      1. temporal convolution, 8 filters of size (1, 64), then BatchNorm + ELU;
      2. depthwise convolution spanning all `n_chans` rows (2 maps per temporal
         filter -> 16 maps), BatchNorm + ELU, average pooling by 4, dropout;
      3. separable convolution ((1, 16) depthwise followed by (1, 1) pointwise),
         BatchNorm + ELU, average pooling by 8, dropout;
      4. linear classifier on the flattened feature maps.

    Args:
        n_chans: number of EEG channels (height of the input).
        n_samples: number of time samples (width of the input); only used to
            size the classifier.
        n_classes: number of output logits.
        dropout: dropout probability used after both pooling stages.
    """

    def __init__(self, n_chans=8, n_samples=500, n_classes=12, dropout=0.5):
        super().__init__()
        self.n_chans = n_chans
        self.n_samples = n_samples

        self.conv_time = nn.Conv2d(1, 8, kernel_size=(1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(8)

        # Depthwise convolution across channels
        self.depthwise = nn.Conv2d(8, 16, kernel_size=(n_chans, 1), groups=8, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4))
        self.dropout1 = nn.Dropout(dropout)

        # Separable convolution (depthwise temporal + pointwise)
        self.separable = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=(1, 16), padding=(0, 8), groups=16, bias=False),
            nn.Conv2d(16, 16, kernel_size=(1, 1), bias=False),
        )
        self.bn3 = nn.BatchNorm2d(16)
        self.pool2 = nn.AvgPool2d(kernel_size=(1, 8))
        self.dropout2 = nn.Dropout(dropout)

        # Infer the flattened feature size with a dummy forward pass. The probe
        # runs in eval mode: in training mode the BatchNorm layers would fold
        # the all-zero dummy batch into their running statistics (this happens
        # even under no_grad) and the dropout layers would be active.
        dummy = torch.zeros(1, 1, n_chans, n_samples)
        was_training = self.training
        self.eval()
        with torch.no_grad():
            feat = self._forward_features(dummy)
        self.train(was_training)
        self.classifier = nn.Linear(feat.shape[1], n_classes)

    def _forward_features(self, x):
        """Run the convolutional stages and return features flattened to (batch, n_features)."""
        x = self.conv_time(x)
        x = self.bn1(x)
        x = F.elu(x)

        x = self.depthwise(x)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        x = self.separable(x)
        x = self.bn3(x)
        x = F.elu(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        return x.flatten(start_dim=1)

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self._forward_features(x)
        return self.classifier(x)


In [15]:
# train_eegnet.py (or one notebook cell)

import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import StratifiedShuffleSplit

# -------------------
# Hyperparameters
# -------------------
EPOCHS = 40  # maximum number of training epochs; also the cosine schedule's T_max
BATCH_SIZE = 32  # mini-batch size of both the training and the validation loader
LR = 1e-3  # initial learning rate handed to Adam
WEIGHT_DECAY = 1e-3  # weight_decay handed to Adam
DROPOUT = 0.3  # dropout probability passed to EEGNet
PATIENCE = 6  # non-improving epochs in a row tolerated before early stopping
CKPT_PATH = "eegnet_tuned.pth"  # file that receives the best-validation state_dict

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------
# Inner train/val split (ONLY on X_trainval)
# -------------------
# The held-out test partition is never touched here: the validation set used
# for early stopping is a stratified 20% of the train+val partition.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2024)
inner_splits = list(sss.split(X_trainval, y_trainval))
train_idx, val_idx = inner_splits[0]

# Insert a singleton axis at position 1 -> (N, 1, 8, 500), the layout Conv2d expects.
X_tr = X_trainval[train_idx][:, np.newaxis]
y_tr = y_trainval[train_idx]
X_va = X_trainval[val_idx][:, np.newaxis]
y_va = y_trainval[val_idx]

# float32 features / int64 labels, as required by the model and CrossEntropyLoss.
train_ds = TensorDataset(torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).long())
val_ds = TensorDataset(torch.from_numpy(X_va).float(), torch.from_numpy(y_va).long())

# Only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# -------------------
# Model + optimiser
# -------------------
# Seed PyTorch's global RNG before the model is built so that the weight
# initialisation, the dropout masks and the batch order drawn later by the
# shuffling training loader are repeatable from run to run. The value mirrors
# the random_state of the inner train/validation split.
# NOTE: some CUDA kernels are non-deterministic, so GPU runs may still differ
# slightly between executions.
SEED = 2024
torch.manual_seed(SEED)

model = EEGNet(
    n_chans=8,
    n_samples=500,
    n_classes=12,
    dropout=DROPOUT
).to(DEVICE)

# Adam with L2 weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY
)

# Cosine learning-rate decay spread over the maximum number of epochs;
# stepped once per epoch by the training loop.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)

# Multi-class cross-entropy on the raw logits.
criterion = torch.nn.CrossEntropyLoss()

# -------------------
# Training loop
# -------------------
# Early stopping on validation accuracy: the weights are checkpointed whenever
# the accuracy improves, and training halts after PATIENCE epochs without a
# new best.
best_val, wait = 0.0, 0

for epoch in range(1, EPOCHS + 1):
    # ---- Optimisation pass over the training set ----
    model.train()
    for xb, yb in train_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # ---- Validation accuracy (no gradients, eval-mode BatchNorm/Dropout) ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            preds = model(xb).argmax(1)
            correct += (preds == yb).sum().item()
            total += yb.numel()

    val_acc = correct / total
    print(f"Epoch {epoch:02d} | val acc {val_acc:.4f}")

    # ---- Early stopping bookkeeping ----
    if val_acc > best_val:
        # New best: reset the patience counter and overwrite the checkpoint.
        best_val, wait = val_acc, 0
        torch.save(model.state_dict(), CKPT_PATH)
        continue

    wait += 1
    if wait >= PATIENCE:
        break

print(f"✅ Training complete. Best val acc: {best_val:.4f}")

Epoch 01 | val acc 0.5472
Epoch 02 | val acc 0.6830
Epoch 03 | val acc 0.7289
Epoch 04 | val acc 0.7412
Epoch 05 | val acc 0.7432
Epoch 06 | val acc 0.7611
Epoch 07 | val acc 0.7534
Epoch 08 | val acc 0.7611
Epoch 09 | val acc 0.7703
Epoch 10 | val acc 0.7693
Epoch 11 | val acc 0.7698
Epoch 12 | val acc 0.7754
Epoch 13 | val acc 0.7728
Epoch 14 | val acc 0.7769
Epoch 15 | val acc 0.7739
Epoch 16 | val acc 0.7774
Epoch 17 | val acc 0.7779
Epoch 18 | val acc 0.7759
Epoch 19 | val acc 0.7769
Epoch 20 | val acc 0.7800
Epoch 21 | val acc 0.7800
Epoch 22 | val acc 0.7805
Epoch 23 | val acc 0.7785
Epoch 24 | val acc 0.7820
Epoch 25 | val acc 0.7769
Epoch 26 | val acc 0.7820
Epoch 27 | val acc 0.7876
Epoch 28 | val acc 0.7861
Epoch 29 | val acc 0.7876
Epoch 30 | val acc 0.7846
Epoch 31 | val acc 0.7856
Epoch 32 | val acc 0.7846
Epoch 33 | val acc 0.7887
Epoch 34 | val acc 0.7836
Epoch 35 | val acc 0.7831
Epoch 36 | val acc 0.7841
Epoch 37 | val acc 0.7851
Epoch 38 | val acc 0.7846
Epoch 39 | v

In [16]:
# evaluate_eegnet.py

import torch
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------
# Load trained model
# -------------------
# Rebuild the architecture, restore the checkpointed weights, then move the
# network to the device and switch it to inference mode.
model = EEGNet(n_chans=8, n_samples=500, n_classes=12)
state_dict = torch.load("eegnet_tuned.pth", map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

# -------------------
# Predict on test set
# -------------------
# Add the singleton axis the network expects and move the whole test set to
# the device as float32.
X_test_t = torch.from_numpy(np.expand_dims(X_test, 1)).float().to(DEVICE)

# Predict in chunks of 64 trials, collecting class indices on the CPU.
chunk_starts = range(0, len(X_test_t), 64)
with torch.no_grad():
    preds = [
        model(X_test_t[start:start + 64]).argmax(1).cpu().numpy()
        for start in chunk_starts
    ]

y_pred = np.concatenate(preds)

# -------------------
# Metrics
# -------------------
# Overall accuracy, per-class precision/recall/F1 and the confusion matrix
# for the held-out test set.
acc = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred, digits=3)

print(f"✅ Test accuracy: {acc:.4f}")
print(report)
print("Confusion matrix:\n", cm)

✅ Test accuracy: 0.7831
              precision    recall  f1-score   support

           0      0.729     0.779     0.754       204
           1      0.779     0.794     0.786       204
           2      0.850     0.779     0.813       204
           3      0.741     0.770     0.755       204
           4      0.731     0.799     0.763       204
           5      0.820     0.828     0.824       204
           6      0.810     0.838     0.824       204
           7      0.805     0.770     0.787       204
           8      0.765     0.686     0.724       204
           9      0.829     0.833     0.831       204
          10      0.816     0.848     0.832       204
          11      0.729     0.672     0.699       204

    accuracy                          0.783      2448
   macro avg      0.784     0.783     0.783      2448
weighted avg      0.784     0.783     0.783      2448

Confusion matrix:
 [[159   5   2   8  10   4   4   5   0   4   1   2]
 [  6 162   2   7   4   6   4   4   2  

In [17]:
# inference.py

import torch
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the network and restore the checkpointed weights for inference.
model = EEGNet(n_chans=8, n_samples=500, n_classes=12)
state_dict = torch.load("eegnet_tuned.pth", map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

def predict_epoch(epoch):
    """
    Classify a single epoch with the module-level `model`.

    epoch: numpy array (8, 500), preprocessed exactly like training
    returns: predicted class index (0..11)
    """
    # Prepend batch and "image channel" axes -> (1, 1, 8, 500).
    batch = epoch[np.newaxis, np.newaxis, ...]
    with torch.no_grad():
        logits = model(torch.from_numpy(batch).float().to(DEVICE))
    return int(logits.argmax(1).cpu().item())