<a href="https://colab.research.google.com/github/nehasharmn/Hybrid_CNN_Transformer_CHBMIT/blob/main/tiny_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install Required EEG and Deep Learning Libraries
This block installs MNE for EEG processing, Braindecode for EEG deep learning utilities, and PyTorch for model training.


In [116]:
!pip install mne braindecode torch torchvision torchaudio



###  Mount Google Drive and Set Data Directory
Mounts Google Drive so the notebook can load EDF files from the CHB-MIT dataset.


In [117]:
from google.colab import drive
drive.mount('/content/drive')

DATA_ROOT = "/content/drive/MyDrive/chbmit_data"
print("Using:", DATA_ROOT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using: /content/drive/MyDrive/chbmit_data


### Required Summary Files for Seizure Labels
Defines which CHB-MIT summary text files must be present for seizure extraction


In [118]:
import os

required = [
    "RECORDS-WITH-SEIZURES",
    "chb01-summary.txt",
    "chb02-summary.txt",
    "chb03-summary.txt",
    "chb13-summary.txt",
    "chb20-summary.txt"
]

print("Checking summary files...")
for f in required:
    path = os.path.join(DATA_ROOT, f)
    print(f"{f}:", "OK" if os.path.exists(path) else "MISSING")


Checking summary files...
RECORDS-WITH-SEIZURES: OK
chb01-summary.txt: OK
chb02-summary.txt: OK
chb03-summary.txt: OK
chb13-summary.txt: OK
chb20-summary.txt: OK


###  Parse CHB-MIT Summary Files to Extract Seizure Intervals
This function reads .txt summary files and extracts seizure start and end timestamps for each EDF recording


In [119]:
def parse_summary_file(path):
    events = {}
    current_file = None
    start = None

    with open(path, "r") as f:
        for line in f:
            line = line.strip()

            if line.startswith("File Name:"):
                current_file = line.split(":")[1].strip().lower()
                events[current_file] = []
                continue

            if "Start Time" in line and "Seizure" in line:
                start = float(
                    line.split(":")[1].replace("seconds", "").strip()
                )
                continue

            if "End Time" in line and "Seizure" in line:
                end = float(
                    line.split(":")[1].replace("seconds", "").strip()
                )
                events[current_file].append((start, end))
                continue

    return events

SUMMARY = {
    "chb01": parse_summary_file(os.path.join(DATA_ROOT, "chb01-summary.txt")),
    "chb02": parse_summary_file(os.path.join(DATA_ROOT, "chb02-summary.txt")),
    "chb03": parse_summary_file(os.path.join(DATA_ROOT, "chb03-summary.txt")),
    "chb13": parse_summary_file(os.path.join(DATA_ROOT, "chb13-summary.txt")),
    "chb20": parse_summary_file(os.path.join(DATA_ROOT, "chb20-summary.txt")),
}

def load_seizure_intervals(patient, edf_filename):
    return SUMMARY.get(patient, {}).get(edf_filename.lower(), [])

### Load EDF Files and Normalize to a Fixed 23-Channels
Different EDF recordings have different channel sets. This block maps them into a consistent channel order or fills missing channels with zeros.


In [120]:
import mne
import numpy as np

CHB_CHANNELS = [
    "FP1-F7","F7-T7","T7-P7","P7-O1",
    "FP1-F3","F3-C3","C3-P3","P3-O1",
    "FP2-F4","F4-C4","C4-P4","P4-O2",
    "FP2-F8","F8-T8","T8-P8","P8-O2",
    "FZ-CZ","CZ-PZ","P7-T7",
    "T7-FT9","FT9-FT10","FT10-T8","T8-P8"
]

def load_raw_fixed_channels(edf_path):
    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)

    current = [ch.lower() for ch in raw.ch_names]

    data_list = []
    for ch in CHB_CHANNELS:
        name = ch.lower()
        if name in current:
            idx = current.index(name)
            data_list.append(raw.get_data(picks=[idx]))
        else:
            data_list.append(np.zeros((1, raw.n_times)))

    return np.vstack(data_list), raw.info["sfreq"]


### Extract Overlapping Windows With Balanced Seizure/Non-Seizure Sampling
This block slices EEG into fixed length windows, associates each with seizure labels, and down samples non-seizure windows to reduce class imbalance.


In [121]:
def extract_windows_balanced(
    edf_path, patient,
    window_size=4.0,
    stride=0.5,
    non_seizure_fraction=0.01
):
    data, sfreq = load_raw_fixed_channels(edf_path)
    data = (data - data.mean(axis=1, keepdims=True)) / (data.std(axis=1, keepdims=True) + 1e-6)
    n_samples = data.shape[1]

    intervals = load_seizure_intervals(patient, os.path.basename(edf_path))
    seizure_segments = [(int(s * sfreq), int(e * sfreq)) for s, e in intervals]

    win_len = int(window_size * sfreq)
    step = int(stride * sfreq)

    seizure_X, seizure_y = [], []
    non_X, non_y = [], []

    for start in range(0, n_samples - win_len, step):
        end = start + win_len
        seg = data[:, start:end]

        is_seizure = any(not (end < s_start or start > s_end)
                         for s_start, s_end in seizure_segments)

        if is_seizure:
            seizure_X.append(seg)
            seizure_y.append(1)
        else:
            non_X.append(seg)
            non_y.append(0)

    if len(non_X) > 0:
        keep = int(len(non_X) * non_seizure_fraction)
        idx = np.random.choice(len(non_X), keep, replace=False)
        non_X = np.array(non_X)[idx]
        non_y = np.array(non_y)[idx]
    else:
        non_X = np.array([])
        non_y = np.array([])

    if len(seizure_X) > 0:
        seizure_X = np.array(seizure_X)
        seizure_y = np.array(seizure_y)

        X = np.concatenate([seizure_X, non_X], axis=0)
        y = np.concatenate([seizure_y, non_y], axis=0)
    else:
        X = non_X
        y = non_y

    return X, y


###  Build Training and Validation Sets From  EDF Files
Windows from selected patients are concatenated into training and validation datasets.


In [122]:
train_patients = {"chb01", "chb02", "chb03", "chb13"}
val_patient = "chb20"

train_X, train_y = [], []
val_X, val_y = [], []

for edf in sorted(os.listdir(DATA_ROOT)):
    if not edf.endswith(".edf"):
        continue

    patient = edf[:5].lower()
    path = os.path.join(DATA_ROOT, edf)

    X, y = extract_windows_balanced(path, patient, window_size= 2)

    if patient in train_patients:
        train_X.append(X)
        train_y.append(y)
    elif patient == val_patient:
        val_X.append(X)
        val_y.append(y)

train_X = np.vstack(train_X)
train_y = np.hstack(train_y)

val_X = np.vstack(val_X)
val_y = np.hstack(val_y)

print("Train:", train_X.shape, np.unique(train_y, return_counts=True))
print("Val:", val_X.shape, np.unique(val_y, return_counts=True))


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=Fa

KeyboardInterrupt: 

### Analyze Class Imbalance in Validation Set
Shows how many windows contain seizures vs. non-seizures.


In [86]:
import numpy as np

total = len(val_y)
seizures = np.sum(val_y == 1)
non_seizures = np.sum(val_y == 0)

print("Total val windows:", total)
print("Seizure windows:", seizures)
print("Non-seizure windows:", non_seizures)
print("Percent seizure: {:.4f}%".format(100 * seizures / total))


Total val windows: 1020
Seizure windows: 628
Non-seizure windows: 392
Percent seizure: 61.5686%


### Analyze Class Imbalance in Training Set
Shows how many windows contain seizures vs. non-seizures.


In [87]:
import numpy as np

total = len(train_y)
seizures = np.sum(train_y == 1)
non_seizures = np.sum(train_y == 0)

print("Total train windows:", total)
print("Seizure windows:", seizures)
print("Non-seizure windows:", non_seizures)
print("Percent seizure: {:.4f}%".format(100 * seizures / total))


Total train windows: 9041
Seizure windows: 3247
Non-seizure windows: 5794
Percent seizure: 35.9142%


### Tiny EEG Transformer Architecture
A hybrid CNN + Transformer model that extracts spatial-temporal EEG features for seizure detection.


In [88]:
import torch
import torch.nn as nn
from einops import rearrange

class EEGTransformer(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4, num_layers=2):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv1d(23, embed_dim, kernel_size=7, stride=4, padding=3),
            nn.ReLU(),
            nn.Conv1d(embed_dim, embed_dim, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True,
            dim_feedforward=embed_dim * 4,
            dropout=0.1
        )

        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln = nn.LayerNorm(embed_dim)

        self.cls = nn.Linear(embed_dim, 1)

    def forward(self, x):
        x = self.conv(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.ln(x)
        x = self.transformer(x)
        self.dropout = nn.Dropout(0.3)
        x = self.dropout(x[:, -1, :])
        logit = self.cls(x).squeeze(1)
        return logit


### Define Training and Eval Functions
Implements the model training loop, validation loop, and uses BCEWithLogitsLoss with class weighting to handle imbalance.


In [89]:
from torch.utils.data import Dataset, DataLoader

class EEGDataset(Dataset):
    def __init__(self, X, y):
      self.X = torch.tensor(X, dtype=torch.float32)
      self.y = torch.tensor(y, dtype=torch.float32)

      self.X = (self.X - self.X.mean(dim=-1, keepdim=True)) / \
         (self.X.std(dim=-1, keepdim=True) + 1e-6)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]




train_loader = DataLoader(
    EEGDataset(train_X, train_y),
    batch_size=64,
    shuffle=True,
    drop_last=True
)

val_loader = DataLoader(
    EEGDataset(val_X, val_y),
    batch_size=64,
    shuffle=False
)


### Train EEG Transformer on Seizure Windows
Runs full training for 30 epochs and reports training and validation accuracy

In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


device = "cuda" if torch.cuda.is_available() else "cpu"
model = EEGTransformer().to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_epoch(model, loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * X.size(0)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).long()

        correct += (preds == y).sum().item()
        total += y.size(0)

    return running_loss / total, correct / total


def eval_epoch(model, loader, criterion):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = criterion(logits, y)
            running_loss += loss.item() * X.size(0)

            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).long()

            correct += (preds == y).sum().item()
            total += y.size(0)

    return running_loss / total, correct / total


EPOCHS = 30
for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, epoch)
    val_loss, val_acc = eval_epoch(model, val_loader, criterion)

    print(f"Epoch {epoch:02d} | "
          f"Train Loss {train_loss:.4f} Acc {train_acc:.4f} | "
          f"Val Loss {val_loss:.4f} Acc {val_acc:.4f}")


Epoch 01 | Train Loss 2.0940 Acc 0.4108 | Val Loss 2.2000 Acc 0.6127
Epoch 02 | Train Loss 1.0640 Acc 0.7367 | Val Loss 8.1392 Acc 0.6255
Epoch 03 | Train Loss 0.8363 Acc 0.8116 | Val Loss 6.4531 Acc 0.6480
Epoch 04 | Train Loss 0.5546 Acc 0.8795 | Val Loss 12.6001 Acc 0.5676
Epoch 05 | Train Loss 0.4938 Acc 0.8824 | Val Loss 25.1341 Acc 0.5784
Epoch 06 | Train Loss 0.6015 Acc 0.8460 | Val Loss 5.7044 Acc 0.6255
Epoch 07 | Train Loss 0.4590 Acc 0.8911 | Val Loss 24.5319 Acc 0.5618
Epoch 08 | Train Loss 0.3401 Acc 0.9239 | Val Loss 37.5178 Acc 0.5020
Epoch 09 | Train Loss 0.3236 Acc 0.9244 | Val Loss 25.4264 Acc 0.5902
Epoch 10 | Train Loss 0.2742 Acc 0.9352 | Val Loss 60.8807 Acc 0.4804
Epoch 11 | Train Loss 0.3854 Acc 0.9113 | Val Loss 17.3183 Acc 0.6000
Epoch 12 | Train Loss 0.2176 Acc 0.9506 | Val Loss 62.6671 Acc 0.4804
Epoch 13 | Train Loss 0.3313 Acc 0.9254 | Val Loss 36.3385 Acc 0.5490
Epoch 14 | Train Loss 0.2360 Acc 0.9545 | Val Loss 3.2025 Acc 0.6157
Epoch 15 | Train Loss 0.5

In [20]:
save_path = "/content/drive/MyDrive/chbmit_data/eeg_transformer_model.pth"
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, save_path)

print("Model saved to:", save_path)


Model saved to: /content/drive/MyDrive/chbmit_data/eeg_transformer_model.pth


### Compute Seizure Detection Metrics
Evaluates sensitivity, precision, and confusion statistics using probability thresholds


In [126]:
def evaluate_seizure_detection(model, loader):
    model.eval()
    preds_all = []
    labels_all = []

    with torch.no_grad():
        for X, y in loader:
            X = X.to(device)
            logits = model(X)
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > 0.3).astype(int)
            preds_all.extend(preds)
            labels_all.extend(y.cpu().numpy())

    preds_all = np.array(preds_all)
    labels_all = np.array(labels_all)


    tp = np.sum((preds_all == 1) & (labels_all == 1))
    fp = np.sum((preds_all == 1) & (labels_all == 0))
    fn = np.sum((preds_all == 0) & (labels_all == 1))

    sensitivity = tp / (tp + fn + 1e-8)
    precision   = tp / (tp + fp + 1e-8)

    print("TP:", tp)
    print("FP:", fp)
    print("FN:", fn)
    print("Sensitivity (Recall):", sensitivity)
    print("Precision:", precision)


In [101]:
print("\n=== SEIZURE DETECTION METRICS ===")
evaluate_seizure_detection(model, val_loader)



=== SEIZURE DETECTION METRICS ===
TP: 103
FP: 21
FN: 525
Sensitivity (Recall): 0.16401273885089152
Precision: 0.8306451612233351


### Benchmark Inference Latency, Throughput, and Accuracy
Measures how fast the model runs on GPU/CPU and how many samples per second it can process.


In [127]:
def benchmark_model(model, val_loader):
    import time
    model.eval()

    latencies = []
    with torch.no_grad():
        for X, _ in val_loader:
            X = X.to(device)
            start = time.time()
            _ = model(X)
            end = time.time()
            latencies.append(end - start)
    avg_latency = np.mean(latencies)

    throughput = len(val_loader.dataset) / sum(latencies)

    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            preds = (torch.sigmoid(logits) > 0.5).long()
            correct += (preds == y).sum().item()
            total += y.size(0)
    accuracy = correct / total

    return {
        "latency": avg_latency,
        "throughput": throughput,
        "accuracy": accuracy
    }


### Compare Window Sizes (1s, 2s, 4s) on Performance
Retrains the model using various window lengths and evaluates accuracy, latency, and performance.


In [128]:


window_sizes = [1, 2, 4]

results = {}

for win in window_sizes:

    print(f"Training model with window size = {win} sec")


    train_X, train_y = [], []
    val_X, val_y = [], []

    for edf in sorted(os.listdir(DATA_ROOT)):
        if not edf.endswith(".edf"):
            continue

        patient = edf[:5].lower()
        path = os.path.join(DATA_ROOT, edf)

        X, y = extract_windows_balanced(path, patient, window_size=win)

        if patient in train_patients:
            train_X.append(X)
            train_y.append(y)
        elif patient == val_patient:
            val_X.append(X)
            val_y.append(y)


    train_X = np.vstack(train_X)
    train_y = np.hstack(train_y)
    val_X = np.vstack(val_X)
    val_y = np.hstack(val_y)


    train_loader = DataLoader(EEGDataset(train_X, train_y), batch_size=64, shuffle=True)
    val_loader   = DataLoader(EEGDataset(val_X, val_y), batch_size=64, shuffle=False)


    model = EEGTransformer().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.0]).to(device))


    for epoch in range(10):
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, epoch)
        val_loss, val_acc = eval_epoch(model, val_loader, criterion)
        print(f"Epoch {epoch+1} | Train Acc {train_acc:.3f} | Val Acc {val_acc:.3f}")



    print("\nEvaluating seizure performance...")
    evaluate_seizure_detection(model, val_loader)


    bench = benchmark_model(model, val_loader)


    results[win] = {
        "model": model,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "latency": bench["latency"],
        "throughput": bench["throughput"],
        "benchmark_accuracy": bench["accuracy"]
    }

    print(f"\n--- Benchmark for {win}-sec model ---")
    print(f"Latency: {bench['latency']:.6f}s")
    print(f"Throughput: {bench['throughput']:.2f} samples/sec")
    print(f"Accuracy: {bench['accuracy']:.3f}")


Training model with window size = 1 sec


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=Fa

Epoch 1 | Train Acc 0.541 | Val Acc 0.551
Epoch 2 | Train Acc 0.844 | Val Acc 0.568
Epoch 3 | Train Acc 0.888 | Val Acc 0.575
Epoch 4 | Train Acc 0.916 | Val Acc 0.496
Epoch 5 | Train Acc 0.917 | Val Acc 0.502
Epoch 6 | Train Acc 0.941 | Val Acc 0.602
Epoch 7 | Train Acc 0.945 | Val Acc 0.498
Epoch 8 | Train Acc 0.952 | Val Acc 0.619
Epoch 9 | Train Acc 0.959 | Val Acc 0.514
Epoch 10 | Train Acc 0.955 | Val Acc 0.614

Evaluating seizure performance...
TP: 304
FP: 58
FN: 308
Sensitivity (Recall): 0.4967320261356743
Precision: 0.8397790055016635

--- Benchmark for 1-sec model ---
Latency: 0.001140s
Throughput: 55046.09 samples/sec
Accuracy: 0.611
Training model with window size = 2 sec


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=Fa

Epoch 1 | Train Acc 0.598 | Val Acc 0.630
Epoch 2 | Train Acc 0.881 | Val Acc 0.475
Epoch 3 | Train Acc 0.923 | Val Acc 0.519
Epoch 4 | Train Acc 0.922 | Val Acc 0.519
Epoch 5 | Train Acc 0.952 | Val Acc 0.645
Epoch 6 | Train Acc 0.948 | Val Acc 0.484
Epoch 7 | Train Acc 0.961 | Val Acc 0.485
Epoch 8 | Train Acc 0.967 | Val Acc 0.459
Epoch 9 | Train Acc 0.975 | Val Acc 0.494
Epoch 10 | Train Acc 0.981 | Val Acc 0.438

Evaluating seizure performance...
TP: 89
FP: 16
FN: 539
Sensitivity (Recall): 0.14171974522067324
Precision: 0.8476190475383221

--- Benchmark for 2-sec model ---
Latency: 0.001144s
Throughput: 55707.05 samples/sec
Accuracy: 0.442
Training model with window size = 4 sec


  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
  raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=Fa

Epoch 1 | Train Acc 0.673 | Val Acc 0.472
Epoch 2 | Train Acc 0.911 | Val Acc 0.608
Epoch 3 | Train Acc 0.956 | Val Acc 0.533
Epoch 4 | Train Acc 0.968 | Val Acc 0.424
Epoch 5 | Train Acc 0.966 | Val Acc 0.426
Epoch 6 | Train Acc 0.973 | Val Acc 0.608
Epoch 7 | Train Acc 0.976 | Val Acc 0.476
Epoch 8 | Train Acc 0.984 | Val Acc 0.546
Epoch 9 | Train Acc 0.990 | Val Acc 0.451
Epoch 10 | Train Acc 0.988 | Val Acc 0.492

Evaluating seizure performance...
TP: 161
FP: 19
FN: 499
Sensitivity (Recall): 0.24393939393569788
Precision: 0.894444444394753

--- Benchmark for 4-sec model ---
Latency: 0.001131s
Throughput: 54702.43 samples/sec
Accuracy: 0.491


In [129]:
import pandas as pd

df = pd.DataFrame(results).T
df


Unnamed: 0,model,train_acc,val_acc,latency,throughput,benchmark_accuracy
1,EEGTransformer(\n (conv): Sequential(\n (0...,0.955259,0.613546,0.00114,55046.093724,0.610558
2,EEGTransformer(\n (conv): Sequential(\n (0...,0.980754,0.438235,0.001144,55707.050704,0.442157
4,EEGTransformer(\n (conv): Sequential(\n (0...,0.987764,0.492395,0.001131,54702.434951,0.491445
