# Muse EEG

In this notebook we will train a model to associate EEG signals from the Muse 2 headset with the wearer's eyes being open or closed.

## Running a Survey

First we can import our library and create a survey, so we can train a model on the resulting data. We'll ask the participant to first get into a comfortable position, then open eyes for 30s, close for 30s, etc.

In [2]:
%matplotlib notebook
# Reload external source files when they change
%load_ext autoreload
%autoreload 2
import sys
from datetime import timedelta
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
sys.path.append("../src")
from recorder import Muse2EEGRecorder
from survey import Survey

eyes_open_step = (timedelta(seconds=30), "eyes_open", "Please open your eyes.", True)
eyes_closed_step = (timedelta(seconds=30), "eyes_closed", "Please close your eyes.", True)
eyes_schedule = [
    (timedelta(seconds=30), "intro", "Just breathe normally, gently relax any tension, get in a comfortable position.", False),
    eyes_open_step,
    eyes_closed_step,
    eyes_open_step,
    eyes_closed_step,
    eyes_open_step,
    eyes_closed_step
]

test_schedule = [
    (timedelta(seconds=5), "intro", "Just breathe normally, gently relax any tension, get in a comfortable position.", True)
]

#muse2_recorder = Muse2EEGRecorder()
#eyes_survey = Survey(muse2_recorder, "Eyes open-closed", "Eyes open for 30, closed for 30 - repeat 3x.", eyes_schedule)
#eyes_survey.record("Jared")

## Preparing Data for Learning

We need to transform our raw survey data into a format suitable for supervised learning. We will create an input tensor with the shape required by PyTorch - `(batch_size, kernel_size, seq_len)`, aka `(Samples, Variables, Length / time or sequence steps)`, or `[batch_size, channels, num_features (aka: H * W)]`.

1. Batch size can be tuned. We will start with `64`.
2. The second index is the number of features per batch. In this case, we have four EEG sensors, so that will be `4`.
3. The number of samples included for each feature in each batch.

The PyTorch `DataLoader` is an alternative to batching out data manually. After creating a `Dataset`, the DataLoader will batch data with a given batch size.

In [39]:
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

# Load EEG CSV
eyes_closed_1 = pd.read_csv("../data/muse2-recordings/surveys/Eyes open-closed 2021-06-24 13:13:38.263066/2_eyes_closed-eeg_raw.csv")
eyes_open_1 = pd.read_csv("../data/muse2-recordings/surveys/Eyes open-closed 2021-06-24 13:13:38.263066/3_eyes_open-eeg_raw.csv")
print("Eyes closed shape:", eyes_closed_1.shape)
print("Eyes open shape:", eyes_open_1.shape)

# Create features ndarray
closed_x = eyes_closed_1[["eeg1", "eeg2", "eeg3", "eeg4"]].to_numpy()
open_x = eyes_open_1[["eeg1", "eeg2", "eeg3", "eeg4"]].to_numpy()
X = np.concatenate((closed_x, open_x))
print("Input features shape:", X.shape)
# TODO Scale data??

# Create one-hot labels ndarray
Y = np.vstack((
    # Eyes are open column
    np.concatenate((np.zeros((eyes_closed_1.shape[0])), np.ones((eyes_open_1.shape[0])))),
    # Eyes are closed column
    np.concatenate((np.ones((eyes_open_1.shape[0])), np.zeros((eyes_closed_1.shape[0]))))
)).T
print("Input labels shape:", Y.shape)

# Split into train, test
X_train, X_test, Y_train, Y_test = train_test_split(X, Y)
print("Training data shape (features, labels):", (X_train.shape, Y_train.shape))

# Define a generic Dataset subclass
class EEGDataset(Dataset):
    def __init__(self, data, labels, num_features, transform=None, target_transform=None):
        self.data = data
        self.labels = labels
        self.num_features = num_features
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return int(len(self.labels) / self.num_features)

    def __getitem__(self, idx):
        start_i = int(idx * self.num_features)
        end_i = start_i + self.num_features
        datum = self.data[start_i:end_i]
        #label = self.labels[start_i:end_i]
        label = self.labels[end_i]
        if self.transform:
            datum = self.transform(datum)
        if self.target_transform:
            label = self.target_transform(label)
        return datum.T, label
    
class EEGSurveyDataset(Dataset):
    """
    Given a survey path, load each eeg_raw csv file and load it as a datum.
    """
    def __init__(self, survey_path, max_size, transform=None, target_transform=None):
        self.data_files = [survey_path + "/" + f for f in os.listdir(survey_path) if f.endswith("eeg_raw.csv")]
        self.label_map = {f:self._parse_filename(f)[1] for f in self.data_files}
        self.ilabel_map = self._create_ilabel_map()
        print(self.ilabel_map)
        self.max_size = max_size
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        data_file = self.data_files[idx]
        data = pd.read_csv(data_file)[["eeg1", "eeg2", "eeg3", "eeg4"]].to_numpy()[:self.max_size]
        label = self.ilabel_map[self.label_map[data_file]]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data.T, label
    
    def _create_ilabel_map(self):
        ilabel_map = {}
        for key in self.label_map.keys():
            if key not in ilabel_map:
                ilabel_map[self.label_map[key]] = len(ilabel_map.keys())-1
        return ilabel_map
    
    def _parse_filename(self, filename):
        filename = filename.split("/")[-1]
        name, extension = filename.split(".", 1)
        num_tag, typ = name.split("-", 1)
        num, tag = num_tag.split("_", 1)
        return num, tag

"""
class EEGSurveyDataset2(Dataset):
    '''
    Given a survey path and a batch_size, load each eeg_raw csv file in the path, and
    split it into batches of size `batch_size`.
    '''
    def __init__(self, survey_path, batch_size, transform=None, target_transform=None):
        self.data_files = [survey_path + "/" + f for f in os.listdir(survey_path) if f.endswith("eeg_raw.csv")]
        self.label_map = {f:self._parse_filename(f)[1] for f in self.data_files}
        self.ilabel_map = self._create_ilabel_map()
        print(self.ilabel_map)
        self.batch_size = batch_size
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        data_file = self.data_files[idx]
        data = pd.read_csv(data_file)[["eeg1", "eeg2", "eeg3", "eeg4"]].to_numpy() #[:self.max_size]
        label = self.ilabel_map[self.label_map[data_file]]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        num_samples = int(np.ceil(len(data) / self.batch_size))
        print(data.T.reshape((self.batch_size, 4, num_samples)).shape)
        return data.T.reshape((self.batch_size, 4, num_samples)), label
    
    def _create_ilabel_map(self):
        ilabel_map = {}
        for key in self.label_map.keys():
            if key not in ilabel_map:
                ilabel_map[self.label_map[key]] = len(ilabel_map.keys())-1
        return ilabel_map
    
    def _parse_filename(self, filename):
        filename = filename.split("/")[-1]
        name, extension = filename.split(".", 1)
        num_tag, typ = name.split("-", 1)
        num, tag = num_tag.split("_", 1)
        return num, tag
"""

# Create PyTorch Datasets
#train_dataset, test_dataset = EEGDataset(X_train, Y_train, 256), EEGDataset(X_test, Y_test, 256)
train_dataset, test_dataset = EEGSurveyDataset("../data/muse2-recordings/surveys/Eyes open-closed Jared 2021-06-26 16:01:52.123715", 256), EEGSurveyDataset("../data/muse2-recordings/surveys/Eyes open-closed Jared 2021-06-26 16:33:30.120561", 7666)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

Eyes closed shape: (7936, 7)
Eyes open shape: (7936, 7)
Input features shape: (15872, 4)
Input labels shape: (15872, 2)
Training data shape (features, labels): ((11904, 4), (11904, 2))
{'eyes_open': 0, 'eyes_closed': 1}
{'eyes_open': 0, 'eyes_closed': 1}


## PyTorch Model

Now we can create the neural network we will be training on the collected data. We will model it after the [simple BCI model by Sentdex](https://github.com/Sentdex/NNfSiX). Sentdex doesn't provide the training or validation data, so we will assume the input is simply batches of raw EEG data.

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_layers_size = 64
n_channels = 4
n_outputs = 2

# Create model
net = nn.Sequential(
    # Pass input to a 1D convolutional layer with a kernel size of 3, apply to activation function.
    nn.Conv1d(n_channels, hidden_layers_size, 3),
    nn.ReLU(),

    # Pass previous layer output to a 1D convolutional layer with a kernel size of 2, apply to activation function,
    # and get the max value from each kernel.
    nn.Conv1d(hidden_layers_size, hidden_layers_size, 2),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),

    # Pass previous layer output to a 1D convolutional layer with a kernel size of 2, apply to activation function,
    # and get the max value from each kernel. (same as previous layer)
    nn.Conv1d(hidden_layers_size, hidden_layers_size, 2),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),

    # Flatten the convolutions. Input shape: (a, b, c), Output shape: (a, b*c)
    nn.Flatten(),
    
    # ?
    # XXX: The first number needs to be updated each time the input shapes change. We could instead
    #      Create a class-based Module, and do a single pass through the conv portion of the network
    #      in order to determine the actual size.
    #      (This technique is shown in https://www.youtube.com/watch?v=1gQR24B3ISE&list=PLQVvvaa0QuDdeMyHEYc0gxFpYwHY2Qfdh&index=7)
    nn.Linear(3968, 512),
    #nn.ReLU(),

    # ?
    nn.Linear(512, n_outputs),
    nn.Softmax()
)

print(len(train_dataloader))
for i, data in enumerate(train_dataloader, 0):
    features, labels = data
    print("Model input shape:", features.shape)
    out = net(features.float())
    print("Model output shape:", out.shape)
    print(out)
    break

1
Model input shape: torch.Size([6, 4, 256])
Model output shape: torch.Size([6, 2])
tensor([[0.9774, 0.0226],
        [0.4656, 0.5344],
        [0.8724, 0.1276],
        [0.3181, 0.6819],
        [0.7758, 0.2242],
        [0.1740, 0.8260]], grad_fn=<SoftmaxBackward>)


### Training the model

In [50]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# Train for n epochs
n = 20
net.train()
for epoch in range(n):
    for i, data in enumerate(train_dataloader, 0):
        features, labels = data

        # Zero gradients
        optimizer.zero_grad()

        # Forward
        predictions = net(features.float())
        
        # Compute loss
        loss = criterion(predictions, labels.long())
        
        # Backward
        loss.backward()
        
        # Optimize
        optimizer.step()

        # Print loss
        print(f"loss: {loss.item()}")

print('Finished Training')

loss: 0.9151536822319031
loss: 0.6465950012207031
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132616877555847
loss: 0.8132615685462952
loss: 0.8132615685462952
loss: 0.8132616877555847
loss: 0.8132616877555847
Finished Training


In [24]:
net.eval()
total = 0
correct = 0
with torch.no_grad():
    for i, data in enumerate(test_dataloader, 0):
        features, labels = data
        out = net(features.float())
        preds = F.log_softmax(out, dim=1).argmax(dim=1)
        print(labels.size(), preds.size(), out.size())
        total += labels.size(0)
        correct += (preds == labels).sum().item()

print("Correct:", correct, "/", total)

torch.Size([6]) torch.Size([6]) torch.Size([6, 2])
Correct: 3 / 6
