In [1]:
from dotenv import load_dotenv
import os
from glob import glob
import mne
import numpy as np
import torch
import torch.nn as nn
import gc 
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import joblib

load_dotenv()
root_dir = os.getenv("ROOT_DIR")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2025-05-27 23:01:27.317931: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-27 23:01:27.328344: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748401287.339246   24966 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748401287.342450   24966 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1748401287.351445   24966 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [2]:
# Model definitions
class InnerSpeechDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx].unsqueeze(0)

In [3]:
# CNN/LSTM hybrid
class InnerSpeechModel(nn.Module):
    def __init__(self):
        super().__init__()

        # CNN component: outputs 256 channels
        self.convolv = nn.Sequential(
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding=1),  # Fixed to 128 channels
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            # nn.Dropout(p=0.1),
            # nn.MaxPool1d(kernel_size=4, stride=4)
        )

        # hidden_size = 32

        # Bi-LSTM component (2 Layers)
        # self.lstm = nn.LSTM(input_size=64, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True)

        # self.attn_weight = nn.Linear(2 * hidden_size, 1, bias=False)

        # Fully connected layer
        self.fc = nn.Sequential(
            # nn.Linear(2*hidden_size, 32),
            # nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(64, 4)  # Matches hidden_size=64
        )

    def forward(self, x):
        # Input shape: (batch, 128, 1153)
        x = self.convolv(x)      # Shape: (batch, 128, 288)
        # print(f"after convolve: {x.shape}")
        # print(f"x.shape after convolv: {x.shape}")
        # x = x.permute(0, 2, 1)   # Shape: (batch, 288, 128)
        # print(f"after permute: {x.shape}")
        # lstm_out, (h_n, c_n) = self.lstm(x)  # lstm_out shape: (batch, 288, 128)
        # print(f"lstm_out.shape after lstm: {lstm_out.shape}")
        # print("\n")
        
        # Global Average Pooling
        x = torch.mean(x, dim=2)  # Global average pooling over time

        # Compute attention scores
        # Flatten across features: attn_score[i, t] = wT * h_{i, t}
        # Then softmax over t to get α_{i, t}
        # attn_scores = self.attn_weight(lstm_out).squeeze(-1)
        # attn_weights = torch.softmax(attn_scores, dim=1)
        # Weighted sum of LSTM outputs:
        # attn_applied = torch.bmm(attn_weights.unsqueeze(1), lstm_out).squeeze(1)

        # Classification
        output = self.fc(x)
        return output

In [None]:
def train_model(model, device, train_loader, val_loader=None, epochs=20, model_name="model", example_input=None, checkpoint_dir="models/", verbose=False):
    GRAD_CLIP = 1.0
    patience = 3  # epochs
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    model_optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = ReduceLROnPlateau(model_optimizer, mode="min", factor=0.5, patience=patience)
    writer = SummaryWriter(log_dir='runs/' + model_name)
    best_val_loss = float('inf')
    early_stop_counter = 0
    
    if example_input is not None:
        writer.add_graph(model, example_input.to(device))

    for epoch in range(epochs):
        model.train()
        running_train_loss = 0.0
        all_train_preds = []
        all_train_targets = []
        
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device).squeeze().long()
            if verbose: print(f"y_batch.shape: {y_batch.shape}")
            model_optimizer.zero_grad()
            logits = model(X_batch)  # logits, shape (batch, 4)
            if verbose: print(f"logits.shape: {logits.shape}")
            loss = criterion(logits, y_batch)
            loss.backward()

            # Gradient Clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            
            model_optimizer.step()
            
            running_train_loss += loss.item() * X_batch.size(0)
            
            predicted_classes = logits.argmax(dim=1)
            all_train_preds.append(predicted_classes.cpu())
            all_train_targets.append(y_batch.cpu())
        
        avg_train_loss = running_train_loss / len(train_loader.dataset)
        train_preds = torch.cat(all_train_preds).numpy()
        train_targets = torch.cat(all_train_targets).numpy()
        train_acc = accuracy_score(train_targets, train_preds)
        
        writer.add_scalar("Loss/Train", avg_train_loss, epoch)
        writer.add_scalar("Accuracy/Train", train_acc, epoch)
        writer.add_scalar("Learning Rate", model_optimizer.param_groups[0]['lr'], epoch)

        if val_loader is not None:
            model.eval()
            running_val_loss = 0.0
            all_val_preds = []
            all_val_targets = []
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(device)
                    y_batch = y_batch.to(device).squeeze().long()
                    if verbose: print(f"y_batch.shape: {y_batch.shape}")
                    logits = model(X_batch)
                    if verbose: print(f"logits.shape: {logits.shape}")
                    loss = criterion(logits, y_batch)
                    running_val_loss += loss.item() * X_batch.size(0)
                    all_val_preds.append(logits.argmax(dim=1).cpu())
                    all_val_targets.append(y_batch.cpu())
            
            avg_val_loss = running_val_loss / len(val_loader.dataset)
            val_preds = torch.cat(all_val_preds).numpy()
            val_targets = torch.cat(all_val_targets).numpy()
            val_acc = accuracy_score(val_targets, val_preds)
            
            writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
            writer.add_scalar("Accuracy/Validation", val_acc, epoch)

            print(f"{model_name} Epoch {epoch+1}/{epochs} | "
                  f"Train Loss: {avg_train_loss:.6f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {avg_val_loss:.6f} | Val Acc: {val_acc:.4f}")

            scheduler.step(avg_val_loss)

            # Save best model checkpoint
            if avg_val_loss < best_val_loss - 1e-5:
                best_val_loss = avg_val_loss
                early_stop_counter = 0
                print(f"Model Checkpoint | epoch: {epoch} | best_val_loss: {best_val_loss}")
                torch.save(model.state_dict(), checkpoint_dir + model_name + ".pth")
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        else:
            print(f"{model_name} Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.6f} | Train Acc: {train_acc:.4f}")

    writer.close()


In [5]:
## Convolutional Neural Network Model
# --- CNN Model ---
class EcogClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2),  # Down to (256, 576)
            nn.Conv1d(256, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),  # Down to (128, 288)
            nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # Output: (64, 1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),      # (64,)
            nn.Linear(64, 4)   # 3 output classes
        )

    def forward(self, x):
        return self.classifier(self.net(x))


In [6]:
def scale_channels_fit(X, scaler_name_prefix=None):
    """
    This is used for precise per-channel scaling.
    This will improve model performance.

    Args:
        X (numpy.ndarray): shape (3000, 128, 1153)
        scaler_name_prefix (str): prefix to identify the scaler.

    Returns:
        numpy.ndarray: X_scaled training data of zero-mean & unit variance.
    """
    n_samples, n_channels, n_timespoints = X.shape
    X_scaled = np.zeros_like(X)
    scales = []
    for channel in range(n_channels):
        scaler = StandardScaler()
        X_ch = X[:, channel, :]
        X_scaled[:, channel, :] = scaler.fit_transform(X_ch)
        scales.append(scaler)
    if scaler_name_prefix is not None: 
        joblib.dump(scales, scaler_name_prefix + "_scales.pkl")
    return X_scaled

def scale_channels_transform(X, scaler_name_prefix = None):
    """This is used to scale X test with the same scaler that X_train was scaled with.

    Args:
        X (numpy.ndarray): shape (1000, 128, 1153)
        scaler_name_prefix (str): prefix to identify the scaler.

    Returns:
        numpy.ndarray: X_scaled test data of zero-mean & unit variance.
    """
    if scaler_name_prefix is not None:
        scalers = joblib.load(f"{scaler_name_prefix}_scales.pkl")
    n_samples, n_channels, n_timespoints = X.shape
    X_scaled = np.zeros_like(X)
    for channel in range(n_channels):
        X_ch = X[:, channel, :]
        X_scaled[:, channel, :] = scalers[channel].transform(X_ch)
    return X_scaled



### Load Data

In [7]:
X_train = torch.load("data/X_train.pth")
y_train = torch.load("data/y_train.pth")
X_test = torch.load("data/X_test.pth")
y_test = torch.load("data/y_test.pth")

In [8]:
X_train.shape

torch.Size([3000, 128, 1153])

In [9]:
y_train.shape

torch.Size([15, 200])

In [10]:
y_train = y_train.long()
y_test = y_test.long()

In [11]:
# Create training and validation sets
y_train = y_train.cpu().numpy()
X_train = X_train.cpu().numpy()

In [12]:
X_test = X_test.cpu().numpy()
y_test = y_test.cpu().numpy()

In [13]:
X_train.shape

(3000, 128, 1153)

In [14]:
X_test.shape

(1000, 128, 1153)

In [15]:
# Scaling the Data with StandardScaler
X_train_scaled = scale_channels_fit(X_train, "X_train")
X_test_scaled = scale_channels_transform(X_test, "X_train")

In [16]:
y_train = y_train.reshape(-1)
assert y_train.shape[0] == X_train.shape[0]

In [17]:
y_train.shape

(3000,)

In [18]:
del X_train, X_test
gc.collect()

20

In [19]:
X_train_split, X_val_split, y_train_split, y_validation_split = train_test_split(
    X_train_scaled, y_train, test_size=0.2, random_state=42, stratify=y_train
)

In [20]:
train_dataset = InnerSpeechDataset(X_train_split, y_train_split)
val_dataset = InnerSpeechDataset(X_val_split, y_validation_split)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle = True)

### Train model

In [21]:
# X_batch.shape: torch.Size([32, 128, 1153])
# y_batch.squeeze().long().shape: torch.Size([32])
# for X_batch, y_batch in train_loader:
#     print(X_batch[0])

In [22]:
os.makedirs("models/checkpoints", exist_ok=True)
os.makedirs("models/runs", exist_ok=True)

In [23]:
# Training on a small batch
# X_train_split_small_batch_for_model_testing = X_train_split[:10]
# y_train_split_small_batch_for_model_testing = y_train_split[:10]
# 
# train_dataset = InnerSpeechDataset(X_train_split_small_batch_for_model_testing, y_train_split_small_batch_for_model_testing)
# val_dataset = InnerSpeechDataset(X_val_split, y_validation_split)
# 
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle = True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle = True)
# model = InnerSpeechModel()
# train_model(model, device, train_loader, epochs=100, model_name="InnerSpeechModel_v0", verbose=False)


In [24]:
model = EcogClassifier()

train_model(model, device, train_loader, val_loader, epochs=100, example_input=torch.randn(1, 128, 1153), model_name="EcogClassifier_v0", verbose=False)

EcogClassifier_v0 Epoch 1/100 | Train Loss: 1.393480 | Train Acc: 0.2479 | Val Loss: 1.389577 | Val Acc: 0.2467
Model Checkpoint | epoch: 0 | best_val_loss: 1.3895774030685424
EcogClassifier_v0 Epoch 2/100 | Train Loss: 1.387445 | Train Acc: 0.2558 | Val Loss: 1.388865 | Val Acc: 0.2250
Model Checkpoint | epoch: 1 | best_val_loss: 1.388865172068278
EcogClassifier_v0 Epoch 3/100 | Train Loss: 1.385513 | Train Acc: 0.2621 | Val Loss: 1.388319 | Val Acc: 0.2183
Model Checkpoint | epoch: 2 | best_val_loss: 1.3883187913894652
EcogClassifier_v0 Epoch 4/100 | Train Loss: 1.384113 | Train Acc: 0.2771 | Val Loss: 1.387918 | Val Acc: 0.2117
Model Checkpoint | epoch: 3 | best_val_loss: 1.3879175821940104
EcogClassifier_v0 Epoch 5/100 | Train Loss: 1.383188 | Train Acc: 0.2846 | Val Loss: 1.387807 | Val Acc: 0.2200
Model Checkpoint | epoch: 4 | best_val_loss: 1.3878072277704874
EcogClassifier_v0 Epoch 6/100 | Train Loss: 1.382720 | Train Acc: 0.2958 | Val Loss: 1.387669 | Val Acc: 0.2267
Model Che