In [14]:
import mne
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Define a function to load, preprocess, and segment epochs from a subject's runs
def load_epochs_for_subject(subject, runs, label):
    epochs_list = []
    labels_list = []
    
    # Download EDF files for the given subject and runs (files are cached locally)
    edf_files = mne.datasets.eegbci.load_data(subject, runs)
    # Load and concatenate raw data from all runs for this condition
    raw_list = []
    for f in edf_files:
        try:
            raw = mne.io.read_raw_edf(f, preload=True, verbose=False)
            raw_list.append(raw)
        except Exception as e:
            print(f"Error reading {f}: {e}")
            continue
    if not raw_list:
        return None, None
    raw = mne.concatenate_raws(raw_list)
    
    # Select EEG channels and apply bandpass filter (1-40 Hz)
    raw.pick_types(eeg=True)
    raw.filter(1, 40, fir_design='firwin', verbose=False)
    
    # Get data and sampling frequency
    data, times = raw.get_data(return_times=True)  # shape: (n_channels, n_times)
    sfreq = raw.info['sfreq']
    
    # For simplicity, select the first EEG channel
    channel_data = data[0]  # shape: (n_times,)
    
    # Segment the continuous signal into 1-second epochs with 50% overlap
    window_size = int(sfreq * 1)  # 1-second window
    stride = int(sfreq * 0.5)     # 50% overlap
    for start in range(0, len(channel_data) - window_size, stride):
        epoch = channel_data[start:start + window_size]
        epochs_list.append(epoch)
        labels_list.append(label)
    
    if epochs_list:
        epochs_array = np.array(epochs_list)  # shape: (n_epochs, window_size)
        labels_array = np.array(labels_list)  # shape: (n_epochs,)
        return epochs_array, labels_array
    else:
        return None, None

# Define subject range and run numbers for the two conditions
# For the EEG Motor Movement/Imagery dataset:
#   Movement: runs [3, 7, 11]
#   Imagery:   runs [4, 8, 12]
subjects = range(1, 4)  # For demonstration; extend this range to use all subjects
movement_runs = [3, 5, 7, 9, 11, 13]
imagery_runs   = [4, 6, 8, 10, 12, 14]

all_epochs_list = []
all_labels_list = []

# Iterate over each subject and both conditions
for subject in subjects:
    # Movement condition (label 0)
    epochs, labels = load_epochs_for_subject(subject, movement_runs, label=0)
    if epochs is not None:
        all_epochs_list.append(epochs)
        all_labels_list.append(labels)
    
    # Imagery condition (label 1)
    epochs, labels = load_epochs_for_subject(subject, imagery_runs, label=1)
    if epochs is not None:
        all_epochs_list.append(epochs)
        all_labels_list.append(labels)

# Combine data from all subjects and both conditions
if all_epochs_list:
    all_epochs = np.concatenate(all_epochs_list, axis=0)
    all_labels = np.concatenate(all_labels_list, axis=0)
else:
    raise RuntimeError("No epochs loaded. Check dataset download and processing.")

print(f"Total epochs: {all_epochs.shape[0]}")

# Convert epochs to PyTorch tensor with shape (n_epochs, seq_len, input_size)
# Here, input_size=1 because we use one channel per epoch.
epochs_tensor = torch.tensor(all_epochs, dtype=torch.float32).unsqueeze(-1)
labels_tensor = torch.tensor(all_labels, dtype=torch.long)
print("Epochs tensor shape:", epochs_tensor.shape)
print("Labels tensor shape:", labels_tensor.shape)

# Shuffle the dataset (optional but recommended)
perm = torch.randperm(epochs_tensor.size(0))
epochs_tensor = epochs_tensor[perm]
labels_tensor = labels_tensor[perm]

# Split into training and testing sets (e.g., 80-20 split)
split_idx = int(0.8 * epochs_tensor.size(0))
train_X, test_X = epochs_tensor[:split_idx], epochs_tensor[split_idx:]
train_y, test_y = labels_tensor[:split_idx], labels_tensor[split_idx:]
print("Training set shape:", train_X.shape)
print("Test set shape:", test_X.shape)

# Define a simple LSTM-based classifier model in PyTorch
class EEGRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(EEGRNN, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.5)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        # x shape: (batch, seq_len, input_size)
        out, _ = self.lstm(x)
        # Take the output from the last time step
        out = out[:, -1, :]  # shape: (batch, hidden_size)
        out = self.fc(out)   # shape: (batch, num_classes)
        return out

# Hyperparameters
input_size = 32     # one feature per time step (one EEG channel)
hidden_size = 64    # can experiment with this
num_layers = 2      # number of LSTM layers
num_classes = 2     # binary classification: movement vs. imagery
num_epochs = 10
batch_size = 32
learning_rate = 0.0001

# Instantiate the model, loss function, and optimizer
model = EEGRNN(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create data loaders for training and testing sets
train_dataset = torch.utils.data.TensorDataset(train_X, train_y)
test_dataset = torch.utils.data.TensorDataset(test_X, test_y)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X.view(batch_X.size(0), -1, input_size))
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch_X.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Evaluation on test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_X, batch_y in test_loader:
        outputs = model(batch_X.view(batch_X.size(0), -1, input_size))
        _, predicted = torch.max(outputs.data, 1)
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()
        
print(f"Test Accuracy: {100 * correct / total:.2f}%")

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Total epochs: 8940
Epochs tensor shape: torch.Size([8940, 160, 1])
Labels tensor shape: torch.Size([8940])
Training set shape: torch.Size([7152, 160, 1])
Test set shape: torch.Size([1788, 160, 1])
Epoch [1/10], Loss: 0.6934
Epoch [2/10], Loss: 0.6932
Epoch [3/10], Loss: 0.6932
Epoch [4/10], Loss: 0.6933
Epoch [5/10], Loss: 0.6932
Epoch [6/10], Loss: 0.6933
Epoch [7/10], Loss: 0.6932
Epoch [8/10], Loss: 0.6932
Epoch [9/10], Loss: 0.6932
Epoch [10/10], Loss: 0.6933
Test Accuracy: 49.27%
