<a href="https://colab.research.google.com/github/nehasharmn/Hybrid_CNN_Transformer_CHBMIT/blob/main/CNN__Baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the packages this notebook depends on (MNE, Braindecode, PyTorch stack) into the current runtime.
!pip install mne braindecode torch torchvision torchaudio

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Folder that the later cells read the summary files and .edf recordings from.
DATA_ROOT = "/content/drive/MyDrive/chbmit_data"
print("Using:", DATA_ROOT)


In [None]:
import os

# Files that are expected to sit directly under DATA_ROOT.
required = [
    "RECORDS-WITH-SEIZURES",
    "chb01-summary.txt",
    "chb02-summary.txt",
    "chb03-summary.txt",
    "chb13-summary.txt",
    "chb20-summary.txt"
]

# Report presence/absence of each file; nothing is raised when one is missing.
print("Checking summary files...")
for filename in required:
    if os.path.exists(os.path.join(DATA_ROOT, filename)):
        status = "OK"
    else:
        status = "MISSING"
    print(f"{filename}:", status)


In [None]:
def parse_summary_file(path):
    """Parse a ``*-summary.txt`` annotation file into per-recording seizure intervals.

    Three kinds of lines are recognised; everything else is ignored:

    * ``File Name: <name>`` starts a new recording (the name is lower-cased).
    * a line containing both ``Seizure`` and ``Start Time`` sets the pending
      seizure start, e.g. ``Seizure Start Time: 2996 seconds`` or
      ``Seizure 1 Start Time: 2996 seconds``.
    * a line containing both ``Seizure`` and ``End Time`` closes the pending
      seizure and records ``(start, end)`` for the current recording.

    Args:
        path: Path of the summary text file.

    Returns:
        Dict mapping each lower-cased recording name to a list of
        ``(start_seconds, end_seconds)`` float tuples, in file order.
        Recordings without seizures map to an empty list.

    Raises:
        ValueError: If a seizure end line appears before any ``File Name:``
            line or without a preceding seizure start line in the same
            recording, or if a time value is not a number.
    """
    events = {}
    current_file = None
    start = None

    with open(path, "r") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            # Everything after the first colon is the value of the line.
            value = line.partition(":")[2]

            if line.startswith("File Name:"):
                current_file = value.strip().lower()
                events[current_file] = []
                # A start left over from the previous recording must never
                # be paired with an end belonging to this one.
                start = None
                continue

            mentions_seizure = "Seizure" in line

            if mentions_seizure and "Start Time" in line:
                start = float(value.replace("seconds", "").strip())
                continue

            if mentions_seizure and "End Time" in line:
                if current_file is None or start is None:
                    # Fail loudly: a silently wrong or dropped interval would
                    # mislabel training windows further down the pipeline.
                    raise ValueError(
                        f"{path}:{line_no}: seizure end without a matching "
                        f"file name / seizure start: {line!r}"
                    )
                end = float(value.replace("seconds", "").strip())
                events[current_file].append((start, end))
                start = None
                continue

    return events

# Seizure annotations for every patient used in this notebook, keyed by
# patient id; each value is the dict returned by parse_summary_file for that
# patient's "<id>-summary.txt" under DATA_ROOT.
SUMMARY = {
    patient_id: parse_summary_file(
        os.path.join(DATA_ROOT, f"{patient_id}-summary.txt")
    )
    for patient_id in ("chb01", "chb02", "chb03", "chb13", "chb20")
}

def load_seizure_intervals(patient, edf_filename):
    """Look up the annotated seizure intervals of one recording.

    Args:
        patient: Patient id used as key into ``SUMMARY``.
        edf_filename: Recording file name; matched case-insensitively.

    Returns:
        The list of ``(start, end)`` tuples stored for that recording, or an
        empty list when the patient or the recording is unknown.
    """
    recording_key = edf_filename.lower()
    try:
        per_recording = SUMMARY[patient]
    except KeyError:
        return []
    return per_recording.get(recording_key, [])

In [None]:
import mne
import numpy as np

# Fixed channel layout: one output row is produced per entry, in this order.
# The list has 23 entries, matching the default ``num_channels=23`` of the
# model defined later. "T8-P8" intentionally appears twice (rows 15 and 23
# counting from 1), so both rows are filled from the same lookup.
CHB_CHANNELS = [
    "FP1-F7", "F7-T7", "T7-P7", "P7-O1",
    "FP1-F3", "F3-C3", "C3-P3", "P3-O1",
    "FP2-F4", "F4-C4", "C4-P4", "P4-O2",
    "FP2-F8", "F8-T8", "T8-P8", "P8-O2",
    "FZ-CZ", "CZ-PZ", "P7-T7",
    "T7-FT9", "FT9-FT10", "FT10-T8", "T8-P8",
]

def load_raw_fixed_channels(edf_path):
    """Load an EDF recording and arrange it into the fixed ``CHB_CHANNELS`` layout.

    Each entry of ``CHB_CHANNELS`` yields one output row. Channel names are
    matched case-insensitively. A channel that is absent from the recording
    is filled with zeros so every recording has the same number of rows.

    MNE makes duplicate channel labels unique by appending a running number
    (``"T8-P8"`` recorded twice becomes ``"T8-P8-0"`` / ``"T8-P8-1"``). An
    exact-name lookup would therefore miss such channels and zero-fill them,
    so names are also indexed with a trailing ``-<digits>`` suffix removed.
    Exact names always win over suffix-stripped ones, and the first match in
    recording order is used.

    Args:
        edf_path: Path of the ``.edf`` file to read.

    Returns:
        Tuple ``(data, sfreq)`` where ``data`` has shape
        ``(len(CHB_CHANNELS), n_times)`` and ``sfreq`` is
        ``raw.info["sfreq"]``.
    """
    import re

    raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)

    recorded = [ch.lower() for ch in raw.ch_names]

    # First pass: exact (lower-cased) names; first occurrence wins, which is
    # what list.index() returned in the original lookup.
    index_by_name = {}
    for idx, name in enumerate(recorded):
        index_by_name.setdefault(name, idx)
    # Second pass: fall back to names with MNE's de-duplication suffix removed.
    for idx, name in enumerate(recorded):
        index_by_name.setdefault(re.sub(r"-\d+$", "", name), idx)

    data_list = []
    for ch in CHB_CHANNELS:
        idx = index_by_name.get(ch.lower())
        if idx is None:
            # Channel not recorded in this file: keep the row, fill with zeros.
            data_list.append(np.zeros((1, raw.n_times)))
        else:
            data_list.append(raw.get_data(picks=[idx]))

    return np.vstack(data_list), raw.info["sfreq"]


In [None]:
def extract_windows_balanced(
    edf_path, patient,
    window_size=4.0,
    stride=0.5,
    non_seizure_fraction=0.01
):
    """Cut one recording into labelled windows, keeping only a sample of the negatives.

    The recording is loaded in the fixed channel layout, z-scored per channel
    over the whole file, and scanned with a sliding window. A window is
    labelled 1 when it touches any annotated seizure interval, else 0. All
    seizure windows are kept; of the non-seizure windows a random
    ``non_seizure_fraction`` is drawn without replacement via
    ``np.random.choice`` (seed the global NumPy RNG for reproducibility).

    Compared with a naive implementation this version

    * includes the last full window (the scan range is inclusive of the final
      valid start position),
    * only copies the windows that are actually kept instead of materialising
      every non-seizure window first, and
    * always returns arrays of a consistent shape/dtype, even when the
      recording yields no window at all.

    Args:
        edf_path: Path of the ``.edf`` recording.
        patient: Patient id used to look up the seizure annotations.
        window_size: Window length in seconds.
        stride: Hop between consecutive window starts in seconds.
        non_seizure_fraction: Fraction of non-seizure windows to keep
            (rounded down).

    Returns:
        Tuple ``(X, y)``: ``X`` has shape ``(n_windows, n_channels, win_len)``
        with seizure windows first (in time order) followed by the sampled
        non-seizure windows (in draw order); ``y`` is an ``int64`` vector of
        the matching 1/0 labels. Both are empty (``n_windows == 0``) when
        nothing is kept.
    """
    data, sfreq = load_raw_fixed_channels(edf_path)
    data = (data - data.mean(axis=1, keepdims=True)) / (data.std(axis=1, keepdims=True) + 1e-6)
    n_channels, n_samples = data.shape

    intervals = load_seizure_intervals(patient, os.path.basename(edf_path))
    seizure_segments = [(int(s * sfreq), int(e * sfreq)) for s, e in intervals]

    win_len = int(window_size * sfreq)
    step = int(stride * sfreq)

    # Only the start offsets are collected here; the data is sliced once the
    # final selection is known.
    seizure_starts, non_starts = [], []

    # "+ 1" so that a window ending exactly at the last sample is included.
    for start in range(0, n_samples - win_len + 1, step):
        end = start + win_len

        is_seizure = any(not (end < s_start or start > s_end)
                         for s_start, s_end in seizure_segments)

        if is_seizure:
            seizure_starts.append(start)
        else:
            non_starts.append(start)

    if non_starts:
        keep = int(len(non_starts) * non_seizure_fraction)
        idx = np.random.choice(len(non_starts), keep, replace=False)
        kept_non_starts = [non_starts[i] for i in idx]
    else:
        kept_non_starts = []

    selected = seizure_starts + kept_non_starts

    X = np.empty((len(selected), n_channels, win_len), dtype=data.dtype)
    for row, start in enumerate(selected):
        X[row] = data[:, start:start + win_len]

    y = np.concatenate([
        np.ones(len(seizure_starts), dtype=np.int64),
        np.zeros(len(kept_non_starts), dtype=np.int64),
    ])

    return X, y


In [None]:
# Patient-level split: four patients for training, one held out for validation.
train_patients = {"chb01", "chb02", "chb03", "chb13"}
val_patient = "chb20"

train_X, train_y = [], []
val_X, val_y = [], []

for edf in sorted(os.listdir(DATA_ROOT)):
    if not edf.endswith(".edf"):
        continue

    # The first five characters of the file name are used as the patient id.
    patient = edf[:5].lower()

    # Decide the destination first so recordings of patients that belong to
    # neither split are never loaded or windowed.
    if patient in train_patients:
        split_X, split_y = train_X, train_y
    elif patient == val_patient:
        split_X, split_y = val_X, val_y
    else:
        continue

    path = os.path.join(DATA_ROOT, edf)

    # 2-second windows; the default stride and negative sampling rate apply.
    X, y = extract_windows_balanced(path, patient, window_size= 2)

    # A recording that contributes no window would break the stacking below.
    if len(y) == 0:
        continue

    split_X.append(X)
    split_y.append(y)

train_X = np.vstack(train_X)
train_y = np.hstack(train_y)

val_X = np.vstack(val_X)
val_y = np.hstack(val_y)

print("Train:", train_X.shape, np.unique(train_y, return_counts=True))
print("Val:", val_X.shape, np.unique(val_y, return_counts=True))


In [None]:
import numpy as np

# Class balance of the validation labels (1 = seizure window, 0 = non-seizure).
total = len(val_y)
seizures = (val_y == 1).sum()
non_seizures = (val_y == 0).sum()
percent_seizure = 100 * seizures / total

print("Total val windows:", total)
print("Seizure windows:", seizures)
print("Non-seizure windows:", non_seizures)
print("Percent seizure: {:.4f}%".format(percent_seizure))


In [None]:
import numpy as np

# Class balance of the training labels (1 = seizure window, 0 = non-seizure).
total = len(train_y)
seizures = (train_y == 1).sum()
non_seizures = (train_y == 0).sum()
percent_seizure = 100 * seizures / total

print("Total train windows:", total)
print("Seizure windows:", seizures)
print("Non-seizure windows:", non_seizures)
print("Percent seizure: {:.4f}%".format(percent_seizure))


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
import time

class EEGConvNet(nn.Module):
    """Lightweight 1-D CNN for seizure detection on fixed-length windows.

    Three Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d stages are followed by a
    two-layer classifier with dropout applied before each linear layer.

    Args:
        num_channels: Number of input channels (rows of each window).
        num_classes: Number of output logits.
        dropout: Dropout probability used in the classifier head.
        input_sequence_length: Number of time samples per window. The size of
            the first linear layer is derived from it, so the model must be
            fed windows of exactly this length.

    Input to ``forward`` is a tensor of shape
    ``(batch, num_channels, input_sequence_length)``; the output has shape
    ``(batch, num_classes)`` and contains raw logits (no softmax).
    """
    def __init__(self, num_channels=23, num_classes=2, dropout=0.5, input_sequence_length=512):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv1d(num_channels, 64, kernel_size=10, stride=2, padding=4),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),

            nn.Conv1d(64, 128, kernel_size=8, stride=2, padding=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),

            nn.Conv1d(128, 256, kernel_size=6, stride=2, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )

        # Derive the flattened feature size with a dummy forward pass. The
        # probe runs in eval mode: in train mode BatchNorm would fold the
        # all-zero dummy batch into its running statistics (and bump
        # num_batches_tracked) before any real data has been seen.
        self.conv_layers.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, num_channels, input_sequence_length)
            flat_output = self.conv_layers(dummy_input).view(1, -1)
            self.flat_size = flat_output.size(1)
        self.conv_layers.train()

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.flat_size, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits for a batch of windows."""
        x = self.conv_layers(x)
        # Flatten (batch, channels, time) to (batch, features) for the head.
        x = x.view(x.size(0), -1)

        x = self.dropout(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

def train_epoch(model, loader, opt, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        opt: Optimizer stepping the model parameters.
        criterion: Loss taking ``(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy)`` where the mean is weighted by batch
        size and accuracy is the fraction of argmax predictions equal to the
        targets.
    """
    model.train()
    running_loss, n_correct, n_seen = 0.0, 0, 0

    for batch_no, (inputs, targets) in enumerate(loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        opt.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        opt.step()

        batch_size = targets.size(0)
        batch_loss_value = batch_loss.item()

        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += batch_loss_value * batch_size
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size

        # Progress line every 100 batches.
        if batch_no % 100 == 0:
            print(f"  Batch {batch_no}: loss={batch_loss_value:.4f}")

    return running_loss / n_seen, n_correct / n_seen


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return loss plus classification metrics.

    Predictions are the argmax over the logits; the score used for AUC is the
    softmax probability of class 1.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss taking ``(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Dict with keys ``'loss'`` (batch-size weighted mean), ``'accuracy'``,
        ``'precision'``, ``'recall'``, ``'f1'`` and ``'auc'``. ``'auc'`` is
        ``nan`` when the labels contain fewer than two classes, because ROC
        AUC is undefined in that case.

    Raises:
        ValueError: If ``loader`` yields no batch.
    """

    model.eval()
    loss_sum = 0.0
    pred_chunks = []
    label_chunks = []
    prob_chunks = []

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)

            logits = model(X)
            loss = criterion(logits, y)

            loss_sum += loss.item() * y.size(0)

            probs = F.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)

            # Collect whole arrays per batch instead of extending Python
            # lists one scalar at a time.
            pred_chunks.append(preds.cpu().numpy())
            label_chunks.append(y.cpu().numpy())
            prob_chunks.append(probs[:, 1].cpu().numpy())

    if not label_chunks:
        raise ValueError("validate() received an empty loader; nothing to evaluate")

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks)

    epoch_loss = loss_sum / len(all_labels)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    # ROC AUC needs both classes among the labels; report nan rather than
    # letting one degenerate validation set abort the whole training run.
    if np.unique(all_labels).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(all_labels, all_probs)

    return {
        'loss': epoch_loss,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }


def train_model(model, train_loader, val_loader, epochs=20, lr=1e-3, device='cuda', train_y=None):
    """Train ``model`` with class-weighted cross-entropy and keep the best-F1 weights.

    Args:
        model: Network to train (already placed on ``device``).
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Initial Adam learning rate; halved by ``ReduceLROnPlateau`` after
            3 epochs without validation-F1 improvement.
        device: Device batches and class weights are placed on.
        train_y: Training labels (tensor or array) used to derive class
            weights. When omitted, an unweighted loss is used.

    Returns:
        Tuple ``(model, history)``. ``model`` carries the weights of the
        epoch with the highest validation F1 (or the final weights if F1
        never exceeded 0). ``history`` maps metric names to per-epoch lists.
    """

    if train_y is not None:
        if isinstance(train_y, torch.Tensor):
            train_y_np = train_y.cpu().numpy()
        else:
            train_y_np = train_y

        unique, counts = np.unique(train_y_np, return_counts=True)
        class_counts = dict(zip(unique, counts))

        # Inverse-frequency weights, rescaled so the negative class has
        # weight 1; a class missing from the labels falls back to 1.0.
        total = len(train_y_np)
        weights = torch.tensor([total / class_counts[i] if i in class_counts else 1.0
                               for i in range(2)], device=device, dtype=torch.float32)
        weights = weights / weights[0]

        print(f"\n{'='*60}")
        print(f"Class distribution: {class_counts}")
        print(f"Class weights: [negative={weights[0]:.4f}, positive={weights[1]:.4f}]")
        print(f"Positive class weighted {weights[1]:.2f}x more than negative")
        print(f"{'='*60}\n")

        criterion = nn.CrossEntropyLoss(weight=weights)
    else:
        print("WARNING: train_y not provided, using unweighted loss")
        criterion = nn.CrossEntropyLoss()

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # mode='max' because the monitored quantity (validation F1) should grow.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode='max', factor=0.5, patience=3
    )

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': [],
        'val_auc': []
    }

    best_val_f1 = 0
    best_model_state = None

    for epoch in range(epochs):
        print(f"\n--- Epoch {epoch+1}/{epochs} ---")

        train_loss, train_acc = train_epoch(model, train_loader, opt, criterion, device)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        val_metrics = validate(model, val_loader, criterion, device)
        print(f"Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val Acc: {val_metrics['accuracy']:.4f}, Prec: {val_metrics['precision']:.4f}, "
              f"Rec: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_precision'].append(val_metrics['precision'])
        history['val_recall'].append(val_metrics['recall'])
        history['val_f1'].append(val_metrics['f1'])
        history['val_auc'].append(val_metrics['auc'])

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            # state_dict() hands out references to the live tensors and
            # dict.copy() is shallow, so each tensor must be cloned;
            # otherwise later optimizer steps overwrite the "best" snapshot
            # and the restore below would be a no-op.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"✓ New best F1: {best_val_f1:.4f}")

        scheduler.step(val_metrics['f1'])

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\nRestored best model with F1: {best_val_f1:.4f}")

    return model, history



if __name__ == "__main__":
    # Use the GPU when one is available; later cells read this `device` name.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # NumPy arrays are converted (float32 inputs, int64 labels); anything that
    # is not a NumPy array is passed through untouched.
    train_X_torch = train_X
    if isinstance(train_X, np.ndarray):
        train_X_torch = torch.from_numpy(train_X).float()

    train_y_torch = train_y
    if isinstance(train_y, np.ndarray):
        train_y_torch = torch.from_numpy(train_y).long()

    val_X_torch = val_X
    if isinstance(val_X, np.ndarray):
        val_X_torch = torch.from_numpy(val_X).float()

    val_y_torch = val_y
    if isinstance(val_y, np.ndarray):
        val_y_torch = torch.from_numpy(val_y).long()

    print(f"Train data shape: {train_X_torch.shape}, labels shape: {train_y_torch.shape}")
    print(f"Val data shape: {val_X_torch.shape}, labels shape: {val_y_torch.shape}")
    print(f"Train class distribution: {np.unique(train_y_torch.cpu().numpy(), return_counts=True)}")
    print(f"Val class distribution: {np.unique(val_y_torch.cpu().numpy(), return_counts=True)}")

    # Only the training loader is shuffled.
    train_dataset = TensorDataset(train_X_torch, train_y_torch)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=0)

    val_dataset = TensorDataset(val_X_torch, val_y_torch)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=0)

    model = EEGConvNet(num_channels=23, num_classes=2, dropout=0.5).to(device)

    # Passing the labels enables the class-weighted loss inside train_model.
    model, history = train_model(
        model,
        train_loader,
        val_loader,
        epochs=20,
        lr=1e-3,
        device=device,
        train_y=train_y_torch,
    )

    # Persist the weights returned by train_model.
    torch.save(model.state_dict(), 'seizure_model.pth')
    print("✓ Model saved as 'seizure_model.pth'")

In [None]:
def evaluate_seizure_detection(model, loader, threshold=0.3, device=None):
    """Print seizure-detection counts and rates at a fixed probability threshold.

    A window is predicted as seizure when the softmax probability of class 1
    exceeds ``threshold``.

    Args:
        model: Network producing two-class logits; switched to eval mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        threshold: Decision threshold on the class-1 probability.
        device: Device the inputs are moved to. When ``None`` it is taken
            from the model's first parameter (CPU for a parameter-free
            model), so the function does not rely on a module-level
            ``device`` variable.

    Returns:
        None. TP, FP, FN, sensitivity and precision are printed.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for X, y in loader:
            X = X.to(device)
            logits = model(X)
            probs = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
            pred_chunks.append((probs > threshold).astype(int))
            label_chunks.append(y.cpu().numpy())

    # An empty loader yields empty arrays and therefore all-zero counts.
    preds_all = np.concatenate(pred_chunks) if pred_chunks else np.array([], dtype=int)
    labels_all = np.concatenate(label_chunks) if label_chunks else np.array([], dtype=int)

    tp = np.sum((preds_all == 1) & (labels_all == 1))
    fp = np.sum((preds_all == 1) & (labels_all == 0))
    fn = np.sum((preds_all == 0) & (labels_all == 1))

    # The small epsilon keeps the ratios defined when a denominator is zero.
    sensitivity = tp / (tp + fn + 1e-8)
    precision   = tp / (tp + fp + 1e-8)

    print("TP:", tp)
    print("FP:", fp)
    print("FN:", fn)
    print("Sensitivity (Recall):", sensitivity)
    print("Precision:", precision)

In [None]:
# Report thresholded detection metrics for the trained model on the validation loader.
print("\n=== SEIZURE DETECTION METRICS ===")
evaluate_seizure_detection(model, val_loader)

In [None]:
def benchmark_model(model, val_loader, device=None):
    """Measure forward-pass latency, throughput and accuracy over ``val_loader``.

    Only the model call is timed; moving the batch to the device and the
    accuracy bookkeeping happen outside the timed region. On CUDA the device
    is synchronised before each clock read, because kernels are launched
    asynchronously and the timer would otherwise mostly measure launch
    overhead. Timing and accuracy are gathered in a single pass.

    Args:
        model: Network producing class logits; switched to eval mode here.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU for a parameter-free
            model), so the function does not rely on a module-level
            ``device`` variable.

    Returns:
        Dict with ``"latency"`` (mean seconds per batch), ``"throughput"``
        (samples per second of timed forward passes) and ``"accuracy"``
        (fraction of argmax predictions equal to the labels). ``latency`` and
        ``accuracy`` are ``nan`` when the loader yields nothing.
    """
    import time

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    device = torch.device(device)
    on_cuda = device.type == "cuda"

    model.eval()

    latencies = []
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)

            if on_cuda:
                # Make sure the host-to-device copy is not billed to the model.
                torch.cuda.synchronize(device)
            start = time.perf_counter()
            logits = model(X)
            if on_cuda:
                # Wait for the queued kernels so the interval covers the compute.
                torch.cuda.synchronize(device)
            latencies.append(time.perf_counter() - start)

            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    elapsed = sum(latencies)
    avg_latency = np.mean(latencies) if latencies else float("nan")
    throughput = total / elapsed if elapsed > 0 else float("inf")
    accuracy = correct / total if total else float("nan")

    return {
        "latency": avg_latency,
        "throughput": throughput,
        "accuracy": accuracy
    }

In [None]:
# Train and benchmark one CNN per window length (in seconds). Each iteration
# rebuilds the train/validation sets, so the module-level train_X/val_X/...
# names end up holding the data of the last window size.
window_sizes = [1, 2, 4]
results = {}

for win in window_sizes:

    print(f"Training CNN baseline with window size = {win} sec")

    train_X, train_y = [], []
    val_X, val_y = [], []

    # NOTE(review): the model input length is derived from a hard-coded
    # 256 Hz, while the windows are cut using each file's own sampling rate.
    # Confirm every recording is sampled at 256 Hz.
    sfreq = 256
    current_sequence_length = int(win * sfreq)

    for edf in sorted(os.listdir(DATA_ROOT)):
        if not edf.endswith(".edf"):
            continue

        patient = edf[:5].lower()

        # Decide the destination first so recordings of patients that belong
        # to neither split are never loaded or windowed.
        if patient in train_patients:
            split_X, split_y = train_X, train_y
        elif patient == val_patient:
            split_X, split_y = val_X, val_y
        else:
            continue

        path = os.path.join(DATA_ROOT, edf)

        X, y = extract_windows_balanced(path, patient, window_size=win)

        # A recording that contributes no window would break the stacking below.
        if len(y) == 0:
            continue

        split_X.append(X)
        split_y.append(y)

    train_X = np.vstack(train_X)
    train_y = np.hstack(train_y)
    val_X   = np.vstack(val_X)
    val_y   = np.hstack(val_y)

    train_X_torch = torch.from_numpy(train_X).float()
    train_y_torch = torch.from_numpy(train_y).long()
    val_X_torch = torch.from_numpy(val_X).float()
    val_y_torch = torch.from_numpy(val_y).long()

    train_loader = DataLoader(TensorDataset(train_X_torch, train_y_torch), batch_size=64, shuffle=True)
    val_loader   = DataLoader(TensorDataset(val_X_torch, val_y_torch), batch_size=64, shuffle=False)

    model = EEGConvNet(input_sequence_length=current_sequence_length).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Inverse-frequency class weights, rescaled so the negative class has
    # weight 1; a class missing from the labels falls back to 1.0.
    unique, counts = np.unique(train_y, return_counts=True)
    class_counts = dict(zip(unique, counts))
    total = len(train_y)
    weights = torch.tensor([total / class_counts[i] if i in class_counts else 1.0
                           for i in range(2)], device=device, dtype=torch.float32)
    weights = weights / weights[0]
    criterion = nn.CrossEntropyLoss(weight=weights)

    print(f"\nClass distribution: {class_counts}")
    print(f"Class weights: [negative={weights[0]:.4f}, positive={weights[1]:.4f}]")
    print(f"Positive class weighted {weights[1]:.2f}x more than negative\n")

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_precision': [], 'val_recall': [],
        'val_f1': [], 'val_auc': []
    }
    best_val_f1 = 0
    best_model_state = None

    for epoch in range(10):
        print(f"--- Epoch {epoch+1}/10 ---")
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics = validate(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_precision'].append(val_metrics['precision'])
        history['val_recall'].append(val_metrics['recall'])
        history['val_f1'].append(val_metrics['f1'])
        history['val_auc'].append(val_metrics['auc'])

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, Prec: {val_metrics['precision']:.4f}, Rec: {val_metrics['recall']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            # state_dict() hands out references to the live tensors and
            # dict.copy() is shallow, so each tensor must be cloned;
            # otherwise later optimizer steps overwrite the "best" snapshot
            # and the restore below would be a no-op.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"✓ New best F1: {best_val_f1:.4f}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\nRestored best model with F1: {best_val_f1:.4f}")

    print("\nSeizure detection metrics:")
    evaluate_seizure_detection(model, val_loader)

    bench = benchmark_model(model, val_loader)

    # NOTE: train_acc / val_acc are the LAST epoch's values, whereas `model`
    # and the benchmark accuracy refer to the restored best-F1 weights.
    results[f"{win}sec_window"] = {
        "model": model,
        "train_acc": history['train_acc'][-1],
        "val_acc": history['val_acc'][-1],
        "latency": bench["latency"],
        "throughput": bench["throughput"],
        "benchmark_accuracy": bench["accuracy"]
    }

    print(f"\n--- Benchmark for {win}-sec CNN model ---")
    print(f"Latency: {bench['latency']:.6f}s")
    print(f"Throughput: {bench['throughput']:.2f} samples/sec")
    print(f"Accuracy: {bench['accuracy']:.3f}")

In [None]:
import pandas as pd

# One row per window size (the keys of `results`), one column per recorded field;
# the "model" column holds the trained model objects themselves.
df = pd.DataFrame(results).T
df