In [41]:
from dotenv import load_dotenv
import os
from glob import glob
import mne
import numpy as np
import torch
import torch.nn as nn
import gc 
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.metrics import accuracy_score
import joblib
from scipy.signal import welch, butter, lfilter

from preprocess_eeg_signal import butter_bandpass, bandpass_filter, band_power_envelope, multiband_features 

from sklearn.preprocessing import LabelEncoder

load_dotenv()
root_dir = os.getenv("ROOT_DIR")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# Model definitions
class InnerSpeechDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx].unsqueeze(0)

In [43]:
# Model definitions
class EEGInnerSpeechDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [44]:
class EEGInnerSpeechClassifier(nn.Module):
    def __init__(self, num_classes=4, num_bands=4, num_channels=128):
        super(EEGInnerSpeechClassifier, self).__init__()
        # Conv1: Process channels across bands
        self.conv1 = nn.Conv1d(in_channels=num_bands, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Conv2: Separable convolution (depthwise + pointwise)
        self.conv2_depth = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2, groups=32)
        self.conv2_point = nn.Conv1d(32, 64, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Conv3: Standard convolution
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        # Linear layers
        self.fc1 = nn.Linear(128 * 16, 256)  # 128 channels * 16 (after pooling)
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.4)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Input: (batch, 4, 128)
        x = self.conv1(x)  # (batch, 32, 128)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)  # (batch, 32, 64)
        x = self.conv2_depth(x)  # (batch, 32, 64)
        x = self.conv2_point(x)  # (batch, 64, 64)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)  # (batch, 64, 32)
        x = self.conv3(x)  # (batch, 128, 32)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)  # (batch, 128, 16)k
        x = x.view(x.size(0), -1)  # (batch, 128*16)
        x = self.relu4(self.fc1(x))  # (batch, 256)
        x = self.dropout(x)
        x = self.fc2(x)  # (batch, 4)
        return x

In [5]:
def train_model(model, device, train_loader, val_loader=None, epochs=50, batch_size=16, model_name="EEGInnerSpeechCNN", example_input=None, checkpoint_dir="models/", verbose=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs("models/checkpoints", exist_ok=True)
    os.makedirs("models/runs", exist_ok=True)
    GRAD_CLIP = 1.0
    patience = 5  # epochs
    warmup_epochs = 3
    initial_lr = 1e-5
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    model_optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)
    scheduler = ReduceLROnPlateau(model_optimizer, mode="min", factor=0.5, patience=patience, min_lr=1e-5)
    writer = SummaryWriter(log_dir='runs/' + model_name)

    best_val_loss = float('inf')
    early_stop_counter = 0
    
    if example_input is not None:
        writer.add_graph(model, example_input.to(device))

    for epoch in range(epochs):
        # Leanring rate warmup
        if epoch < warmup_epochs:
            lr = initial_lr + (5e-4 - initial_lr) * (epoch / warmup_epochs)
            for param_group in model_optimizer.param_groups:
                param_group["lr"] = lr

        model.train()
        running_train_loss = 0.0
        all_train_preds = []
        all_train_targets = []
        
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device).squeeze().long()
            if verbose: print(f"y_batch.shape: {y_batch.shape}")
            model_optimizer.zero_grad()
            logits = model(X_batch)  # logits, shape (batch, 4)
            if verbose: print(f"logits.shape: {logits.shape}")
            loss = criterion(logits, y_batch)
            loss.backward()

            # Gradient Clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            
            model_optimizer.step()
            
            running_train_loss += loss.item() * X_batch.size(0)
            
            predicted_classes = logits.argmax(dim=1)
            all_train_preds.append(predicted_classes.cpu())
            all_train_targets.append(y_batch.cpu())
        
        avg_train_loss = running_train_loss / len(train_loader.dataset)
        train_preds = torch.cat(all_train_preds).numpy()
        train_targets = torch.cat(all_train_targets).numpy()
        train_acc = accuracy_score(train_targets, train_preds)
        
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)
        writer.add_scalar("Accuracy/Train", train_acc, epoch)
        writer.add_scalar("Learning Rate", model_optimizer.param_groups[0]['lr'], epoch)

        if val_loader is not None:
            model.eval()
            running_val_loss = 0.0
            all_val_preds = []
            all_val_targets = []
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(device)
                    y_batch = y_batch.to(device).squeeze().long()
                    if verbose: print(f"y_batch.shape: {y_batch.shape}")
                    logits = model(X_batch)
                    if verbose: print(f"logits.shape: {logits.shape}")
                    loss = criterion(logits, y_batch)
                    running_val_loss += loss.item() * X_batch.size(0)
                    all_val_preds.append(logits.argmax(dim=1).cpu())
                    all_val_targets.append(y_batch.cpu())
            
            avg_val_loss = running_val_loss / len(val_loader.dataset)
            val_preds = torch.cat(all_val_preds).numpy()
            val_targets = torch.cat(all_val_targets).numpy()
            val_acc = accuracy_score(val_targets, val_preds)
            
            writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
            writer.add_scalar("Accuracy/Validation", val_acc, epoch)

            if epoch >= warmup_epochs:
                scheduler.step(avg_val_loss)

            print(f"{model_name} Epoch {epoch+1}/{epochs} | "
                  f"Train Loss: {avg_train_loss:.6f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {avg_val_loss:.6f} | Val Acc: {val_acc:.4f}")

            # scheduler.step(avg_val_loss)

            # Save best model checkpoint
            if avg_val_loss < best_val_loss - 1e-5:
                best_val_loss = avg_val_loss
                early_stop_counter = 0
                print(f"Model Checkpoint | epoch: {epoch} | best_val_loss: {best_val_loss}")
                torch.save(model.state_dict(), checkpoint_dir + model_name + ".pth")
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

    writer.close()


In [6]:
## Convolutional Neural Network Model
# --- CNN Model ---
class EcogClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2),  # Down to (256, 576)
            nn.Conv1d(256, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),  # Down to (128, 288)
            nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # Output: (64, 1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),      # (64,)
            nn.Linear(64, 4)   # 3 output classes
        )

    def forward(self, x):
        return self.classifier(self.net(x))


In [7]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

In [8]:
def bandpass_filter(data, lowcut, highcut, fs=256.0):
    b, a = butter_bandpass(lowcut, highcut, fs)
    filtered = lfilter(b, a, data, axis=2)
    return filtered

In [9]:
def compute_psd(data, fs=256.0, bands=[(0.5, 4), (4, 8), (8, 12), (12, 30)], batch_size=100):
    num_trials, num_channels, num_samples = data.shape
    num_bands = len(bands)
    psd_features = np.zeros((num_trials, num_bands, num_channels), dtype=np.float32)

    for start_idx in range(0, num_trials, batch_size):
        end_idx = min(start_idx + batch_size, num_trials)
        batch_data = data[start_idx:end_idx]

        for band_idx, (low, high) in enumerate(bands):
            filtered = bandpass_filter(batch_data, low, high, fs)
            freqs, psd = welch(filtered, fs=fs, axis=2, nperseg=256)
            psd_mean = np.mean(psd, axis=2)
            psd_features[start_idx:end_idx, band_idx, :] = psd_mean
            del filtered, psd, psd_mean
    
        del batch_data
    return psd_features


In [10]:
def scale_channels_fit(X, name):
    scaler = RobustScaler()
    X_reshaped = X.reshape(-1, X.shape[-1])
    scaler.fit(X_reshaped)
    return scaler

In [11]:
def scale_channels_transform(X, scaler):
    X_reshaped = X.reshape(-1, X.shape[-1])
    X_scaled = scaler.transform(X_reshaped)
    return X_scaled.reshape(X.shape)

In [None]:
class EEGInnerSpeechClassifierUpdated(nn.Module):
    def __init__(self, num_classes=4, num_bands=4):
        super(EEGInnerSpeechClassifierUpdated, self).__init__()
        self.conv1 = nn.Conv1d(num_bands, 32, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(2)

        self.depthwise = nn.Conv1d(32, 32, kernel_size=5, padding=2, groups=32)
        self.pointwise = nn.Conv1d(32, 64, kernel_size=1)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.AdaptiveAvgPool1d(1)

        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.pointwise(self.depthwise(x)))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)
        x = self.dropout(self.fc1(x))
        return self.fc2(x)



### Load Data

In [53]:
X_train = torch.load("data/X_train.pth")
y_train = torch.load("data/y_train.pth")
# X_test = torch.load("data/X_test.pth")
# y_test = torch.load("data/y_test.pth")

In [54]:
# Encode the raw strings as integer labels
label_encoder = LabelEncoder()
y_train = torch.tensor(label_encoder.fit_transform(y_train), device=device)  # integer labels: 0, 1, 2, 3

In [55]:
y_train.__len__()

3200

In [56]:
X_train.shape

torch.Size([3200, 128, 640])

In [57]:
y_train

tensor([1, 1, 0,  ..., 1, 2, 0], device='cuda:0')

In [58]:
y_train = y_train.long()
# y_test = y_test.long()

In [59]:
# Create training and validation sets
y_train = y_train.cpu().numpy()
X_train = X_train.cpu().numpy()

In [61]:
X_train.shape

(3200, 128, 640)

In [66]:
X_train

array([[[ 1.71176738e-06,  1.16846070e-06,  3.31827414e-06, ...,
          1.45639676e-05,  1.36433307e-05,  9.70088202e-06],
        [ 1.59351507e-06,  3.84599622e-07,  3.03085951e-06, ...,
          1.46348033e-05,  1.18493852e-05,  8.63002326e-06],
        [ 2.50736617e-06,  1.45015636e-06,  4.11861471e-06, ...,
          1.63883400e-05,  1.36700965e-05,  1.03945050e-05],
        ...,
        [ 8.63118801e-07, -9.67614975e-07,  7.31853769e-06, ...,
          3.82829762e-06,  1.74064473e-06,  3.74019630e-06],
        [ 1.06878054e-06, -7.16682192e-07,  7.37907106e-06, ...,
          3.53623059e-06,  1.96710643e-06,  3.74723605e-06],
        [ 7.90739272e-07,  3.03287683e-07,  8.61955922e-06, ...,
          1.08448432e-05,  8.78994442e-06,  7.95353232e-06]],

       [[-1.19148411e-05, -1.36932596e-05, -1.51152241e-05, ...,
          7.38314402e-06,  8.36708480e-06,  7.61167036e-06],
        [-1.37663030e-05, -1.48070165e-05, -1.48490914e-05, ...,
          6.93859585e-06,  8.07943282e

## Transforming X data into power spectral frequency bands

In [None]:
# X_test = X_test.cpu().numpy()
# y_test = y_test.cpu().numpy()

In [None]:
# fs = 256.0
# X_train_psd = compute_psd(X_train, fs=fs, batch_size=100)

# Per-channel Scaling

# scaler = scale_channels_fit(X_train_psd, "X_train")
# X_train_scaled = scale_channels_transform(X_train_psd, scaler)
# del X_train, 
# # X_test
# gc.collect()

31

In [22]:
y_train = y_train.reshape(-1)
assert y_train.shape[0] == X_train_scaled.shape[0]

In [23]:
X_train_split, X_val_split, y_train_split, y_validation_split = train_test_split(
    X_train_scaled, y_train, test_size=0.2, random_state=42, stratify=y_train
)

In [24]:
y_train_split.shape

(2560,)

In [25]:
train_dataset = EEGInnerSpeechDataset(X_train_split, y_train_split)
val_dataset = EEGInnerSpeechDataset(X_val_split, y_validation_split)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle = False)

In [None]:
# del train_dataset, val_dataset, train_loader, val_loader
# gc.collect()

### Train model

In [None]:
# X_batch.shape: torch.Size([32, 128, 1153])
# X_batch.shape: ([16, 4, 128])
# y_batch.squeeze().long().shape: torch.Size([32])

# for X_batch, y_batch in train_loader:
    # print(X_batch.shape)

In [None]:
# Training on a small batch
# X_train_split_small_batch_for_model_testing = X_train_split[:10]
# y_train_split_small_batch_for_model_testing = y_train_split[:10]
# 
# train_dataset = InnerSpeechDataset(X_train_split_small_batch_for_model_testing, y_train_split_small_batch_for_model_testing)
# val_dataset = InnerSpeechDataset(X_val_split, y_validation_split)
# 
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle = True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle = True)
# model = InnerSpeechModel()
# train_model(model, device, train_loader, epochs=100, model_name="InnerSpeechModel_v0", verbose=False)


In [None]:
# # EcogClassifier
# model = EcogClassifier()

# train_model(model, device, train_loader, val_loader, epochs=100, example_input=torch.randn(1, 128, 1153), model_name="EcogClassifier_v0", verbose=False)

# EcogClassifier_v0 Epoch 1/100 | Train Loss: 1.394683 | Train Acc: 0.2350 | Val Loss: 1.387976 | Val Acc: 0.2600
# Model Checkpoint | epoch: 0 | best_val_loss: 1.3879758723576863
# EcogClassifier_v0 Epoch 2/100 | Train Loss: 1.387742 | Train Acc: 0.2450 | Val Loss: 1.386153 | Val Acc: 0.2417
# Model Checkpoint | epoch: 1 | best_val_loss: 1.3861531829833984
# EcogClassifier_v0 Epoch 3/100 | Train Loss: 1.387694 | Train Acc: 0.2462 | Val Loss: 1.387894 | Val Acc: 0.2533
# EcogClassifier_v0 Epoch 4/100 | Train Loss: 1.387689 | Train Acc: 0.2437 | Val Loss: 1.386056 | Val Acc: 0.2467
# Model Checkpoint | epoch: 3 | best_val_loss: 1.386056129137675
# EcogClassifier_v0 Epoch 5/100 | Train Loss: 1.385896 | Train Acc: 0.2396 | Val Loss: 1.386433 | Val Acc: 0.2600
# EcogClassifier_v0 Epoch 6/100 | Train Loss: 1.384991 | Train Acc: 0.2558 | Val Loss: 1.390329 | Val Acc: 0.2467
# EcogClassifier_v0 Epoch 7/100 | Train Loss: 1.383838 | Train Acc: 0.2671 | Val Loss: 1.387250 | Val Acc: 0.2467
# Early stopping at epoch 7

In [None]:
model = EEGInnerSpeechClassifierUpdated()

_ = train_model(model, device, train_loader, val_loader, epochs=100, example_input=torch.randn([16, 4, 128]), model_name="EEG_InnerSpeech_CNN_Classifier_v0", verbose=False)


### Validating the model will function: Overfitting on a small batch of data



In [None]:

# Try to overfit a small batch
small_X = torch.tensor(X_train_scaled[:32], dtype=torch.float32).to(device)
small_y = torch.tensor(y_train[:32], dtype=torch.long).to(device)


X_train_split, X_val_split, y_train_split, y_validation_split = train_test_split(
    X_train_scaled[:32], y_train[:32], test_size=0.2, random_state=42, stratify=y_train[:32]
)

train_dataset = EEGInnerSpeechDataset(small_X, small_y)
val_dataset = EEGInnerSpeechDataset(X_val_split, y_validation_split)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle = False)

model = EEGInnerSpeechClassifierUpdated().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

for i in range(100):
    model.train()
    logits = model(small_X)
    loss = loss_fn(logits, small_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    preds = logits.argmax(dim=1)
    acc = (preds == small_y).float().mean().item()
    print(f"Epoch {i}: Loss = {loss.item():.4f} | Acc = {acc:.2f}")
    if acc == 1.0:
        break
