In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pickle
from tqdm import tqdm

## Load Preprocessed Data

In [None]:
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles produced by this project's own preprocessing step.
with open('preprocessed/preprocessed_eeg_data.pkl', 'rb') as f:
    data = pickle.load(f)

X = data['X']
y = data['y']

# Invariants the rest of the notebook relies on: EEGDataset indexes X and y
# with the same sample index, and the model is built from X.shape[1], [2], [3]
# and fed to Conv2d, so X must be 4-D (samples, channels, freq bins, time steps).
assert len(X) == len(y), f"X has {len(X)} samples but y has {len(y)}"
assert len(X.shape) == 4, f"expected 4-D X, got shape {X.shape}"

print(f"Data shape: {X.shape}")
print(f"Labels shape: {y.shape}")

## Dataset Class

In [None]:
class EEGDataset(Dataset):
    """Map-style dataset pairing per-sample features with integer class labels.

    Args:
        X: array-like of features; sample ``i`` is ``X[i]``. Stored as float32.
        y: array-like of integer labels, one per sample. Stored as int64.

    Raises:
        ValueError: if X and y do not contain the same number of samples.
    """

    def __init__(self, X, y):
        # torch.tensor copies the data, so later in-place edits to the source
        # arrays do not leak into the dataset.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        # __len__ counts labels while __getitem__ indexes both tensors, so a
        # mismatch would otherwise surface only as an IndexError mid-epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X has {len(self.X)} samples but y has {len(self.y)} labels"
            )

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]

## CNN + RNN Model

In [None]:
class EEGNet(nn.Module):
    """Two Conv2d blocks followed by a 2-layer LSTM over the pooled time axis.

    The input is expected as ``(batch, num_channels, num_freq_bins,
    num_time_steps)``: ``num_channels`` is the Conv2d input-channel count and
    the frequency/time grid is the 2-D spatial extent. After the CNN, each
    pooled time step becomes one LSTM step whose features are all conv
    channels times all pooled frequency bins.

    Args:
        num_channels: input channels (``conv1`` in_channels).
        num_freq_bins: size of input dim 2.
        num_time_steps: size of input dim 3.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``num_freq_bins`` or ``num_time_steps`` is below 4,
            because the two 2x2 max-pools would shrink that axis to zero.
    """

    def __init__(self, num_channels, num_freq_bins, num_time_steps, num_classes=109):
        super(EEGNet, self).__init__()

        # CNN for spatial-frequency features (2 layers). A 3x3 kernel with
        # padding=1 keeps H/W unchanged; each 2x2 max-pool floors them by 2.
        self.conv1 = nn.Conv2d(num_channels, 32, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))

        # Spatial size after the two pools (floor(floor(n / 2) / 2) == n // 4).
        h_out = num_freq_bins // 4
        w_out = num_time_steps // 4
        if h_out == 0 or w_out == 0:
            raise ValueError(
                f"num_freq_bins and num_time_steps must be >= 4 "
                f"(got {num_freq_bins}, {num_time_steps}); two 2x2 max-pools "
                f"would reduce them to zero"
            )
        # Features per LSTM step: 64 conv channels x pooled frequency bins.
        self.lstm_input_size = 64 * h_out
        # Total flattened CNN output size; kept as an attribute for reference.
        self.flat_size = self.lstm_input_size * w_out

        # RNN over the w_out pooled time steps.
        self.lstm = nn.LSTM(input_size=self.lstm_input_size, hidden_size=256,
                            num_layers=2, batch_first=True, dropout=0.3)

        # Classification head
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256, num_classes)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # CNN layers
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))

        # (B, 64, H', W') -> (B, W', 64 * H'): one LSTM step per pooled time
        # step, so the LSTM sees a real sequence instead of a single step.
        # permute makes x non-contiguous, hence reshape rather than view.
        batch_size, conv_channels, freq_bins, time_steps = x.shape
        x = x.permute(0, 3, 1, 2).reshape(batch_size, time_steps, conv_channels * freq_bins)

        # LSTM; keep only the output at the last time step.
        x, _ = self.lstm(x)
        x = x[:, -1, :]

        # Classification
        x = self.dropout(x)
        return self.fc(x)

## Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: module to train; switched to train mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss taking ``(outputs, labels)``; assumed to return the
            mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Tuple of the per-sample mean loss over the epoch and the percentage of
        samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(dataloader, desc='Training', leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the mean.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total, 100 * correct / total

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` without gradients and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: module to evaluate; switched to eval mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss taking ``(outputs, labels)``; assumed to return the
            mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        device: device each batch is moved to.

    Returns:
        Tuple of the per-sample mean loss over the dataset and the percentage
        of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc='Validating', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total, 100 * correct / total

## Set Up Training

In [None]:
# Seed once so the train/val split (and, when cells run in order, the later
# weight init and DataLoader shuffling) is repeatable under Restart & Run All.
SEED = 42
BATCH_SIZE = 32
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

dataset = EEGDataset(X, y)

# 80/20 random split; the dedicated seeded generator pins which samples land
# in each subset independently of any other RNG consumption.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

## Initialize Model

In [None]:
# Model geometry comes straight from the preprocessed array:
# dims 1, 2 and 3 are channels, frequency bins and time steps.
num_channels, num_freq_bins, num_time_steps = X.shape[1], X.shape[2], X.shape[3]

model = EEGNet(num_channels, num_freq_bins, num_time_steps, num_classes=109).to(device)

# Cross-entropy over class logits; Adam with L2 weight decay; the scheduler
# halves the learning rate when the monitored loss plateaus (patience=3).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count}")

## Train Model

In [None]:
num_epochs = 50
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Starting at 0 meant a run stuck at 0% val accuracy saved nothing,
# and the evaluation cell would then load a stale best_model.pth (or fail).
best_val_acc = float('-inf')
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # ReduceLROnPlateau (mode='min') tracks validation loss.
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} - "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Strict '>' keeps the earliest epoch among ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'preprocessed/best_model.pth')
        print(f"Model saved with validation accuracy: {val_acc:.2f}%")

print(f"\nBest validation accuracy: {best_val_acc:.2f}%")

## Save Training History

In [None]:
# Persist per-epoch loss/accuracy curves for later plotting or analysis.
history_path = 'preprocessed/training_history.pkl'
with open(history_path, mode='wb') as history_file:
    pickle.dump(history, history_file)

print("Training complete. Model and history saved.")

## Evaluate on Validation Set

In [None]:
# NOTE(review): the checkpoint was selected by val_acc on this same
# val_loader, so metrics computed from these predictions are optimistically
# biased. A separate held-out test split would give an unbiased estimate.
# map_location lets the checkpoint load even if it was saved on another device.
model.load_state_dict(torch.load('preprocessed/best_model.pth', map_location=device))
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs.to(device))
        predicted = outputs.argmax(dim=1)

        # Move predictions back to the CPU; labels were never moved off it.
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

predictions = {
    'y_true': np.array(all_labels),
    'y_pred': np.array(all_preds)
}

with open('preprocessed/predictions.pkl', 'wb') as f:
    pickle.dump(predictions, f)

print("Predictions saved for performance analysis.")