# Direction Thought Detection

In this notebook we will train a model to determine whether a person is thinking `left`, `right`, or `none`.

## Running a Survey

First, we import our library and create a survey whose recordings we can train a model on. The participant first settles into a comfortable position, then thinks about `left`, `right`, and `none` in turn (30 seconds each), repeating this cycle three times.

In [None]:
%matplotlib notebook
# Reload external source files when they change
%load_ext autoreload
%autoreload 2
import sys
from datetime import timedelta, datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

sys.path.append("../src")
from recorder import Muse2EEGRecorder
from survey import Survey

# Each schedule entry is (duration, label, prompt shown to the participant, flag).
# NOTE(review): the meaning of the final flag is defined by Survey; the intro uses
# False and the thinking steps use True — presumably "record this step as labelled
# data". Confirm against survey.py.
step_duration = timedelta(seconds=30)
left_step = (step_duration, "left", "Think about going left.", True)
right_step = (step_duration, "right", "Think about going right.", True)
# Resting step: must be labelled "none" (not "right"), otherwise resting data is
# mixed into the "right" class.
none_step = (step_duration, "none", "Just breathe.", True)

# Intro step, then three left -> right -> none cycles.
n_cycles = 3
direction_schedule = [
    (step_duration, "intro", "Just breathe normally, gently relax any tension, and get into a comfortable position.", False),
] + [left_step, right_step, none_step] * n_cycles

# Short 5-second schedule, presumably for checking the recorder setup; not used below.
test_schedule = [
    (timedelta(seconds=5), "intro", "Just breathe normally, gently relax any tension, get in a comfortable position.", True)
]

participant = "Jared"
muse2_recorder = Muse2EEGRecorder()
direction_survey = Survey(muse2_recorder, "Left-Right-Thoughts", "Thinking 'left' for 30s, then 'right', then 'none' - repeat 3x.", direction_schedule)
direction_survey.record(participant)

## Preparing Data for Learning

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader

from eeg_preprocessing import preprocess_eeg_channel
from eegdata import EEGSurveyDataset, ChunkedDataset, MultiDataset

def transform_normalize(data):
    """Preprocess and z-score every column of ``data`` in place, then return it.

    Each column (channel) is first passed through ``preprocess_eeg_channel`` and
    written back. It is then shifted to zero mean and scaled to unit standard
    deviation. A column whose standard deviation is exactly 0 after preprocessing
    is left un-scaled, which avoids dividing by zero.

    Args:
        data: 2-D array-like indexed as ``data[:, channel]``; presumably shaped
            (samples, channels) — NOTE(review): confirm against EEGSurveyDataset.

    Returns:
        The same ``data`` object, modified in place.
    """
    n_channels = data.shape[1]
    for channel in range(n_channels):
        cleaned = preprocess_eeg_channel(data[:, channel])
        data[:, channel] = cleaned
        spread = data[:, channel].std()
        if spread == 0:
            # Flat channel: nothing to scale.
            continue
        centre = data[:, channel].mean()
        data[:, channel] = (data[:, channel] - centre) / spread
    return data

batch_size = 35

# NOTE(review): this notebook is about left/right/none thoughts, but every recording
# below is an "Eyes open-closed" survey. Point these at the "Left-Right-Thoughts"
# recordings produced by the survey cell once they exist.
survey_dir = "../data/muse2-recordings/surveys"
survey_names = [
    "Eyes open-closed Jared 2021-06-27 15:35:13.033803",
    "Eyes open-closed Jared 2021-06-27 15:42:18.386982",
    "Eyes open-closed Jared 2021-06-27 16:30:50.691073",
    "Eyes open-closed Jared 2021-06-27 16:35:02.095245",
    "Eyes open-closed Jared 2021-06-27 16:50:59.722742",
    "Eyes open-closed Jared 2021-06-27 16:55:04.568128",
    "Eyes open-closed Jared 2021-06-27 17:56:48.686557",
    "Eyes open-closed Jared 2021-06-27 18:00:38.062883",
    "Eyes open-closed Jared 2021-06-27 18:04:25.933181",
]
# Second argument passed to EEGSurveyDataset for every recording; presumably the
# number of samples kept per recording — NOTE(review): confirm in eegdata.py.
recording_length = 7665

# Create PyTorch Datasets: one per recording, each wrapped in a ChunkedDataset.
# NOTE(review): batch_size is also passed to ChunkedDataset. 7665 / 35 = 219, which
# matches the [35, 4, 219] batch shape seen when the model runs. That suggests the
# argument is a chunk count rather than a batch size — confirm against ChunkedDataset
# before changing either value.
ds1, ds2, ds3, ds4, ds5, ds6, ds7, ds8, ds9 = [
    ChunkedDataset(
        EEGSurveyDataset(os.path.join(survey_dir, name), recording_length, transform=transform_normalize),
        batch_size,
    )
    for name in survey_names
]

# Hold out two whole recordings for testing so test windows never come from a
# recording that was trained on.
train_dataset = MultiDataset([ds1, ds2, ds3, ds4, ds5, ds9, ds8])
test_dataset = MultiDataset([ds6, ds7])

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

## PyTorch Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

n_channels = 4
# NOTE(review): the introduction describes a three-way task (left / right / none),
# but the recordings loaded for training are "Eyes open-closed" surveys and this
# head has two outputs. Change this together with the dataset paths.
n_outputs = 2

# Pull one batch so the first fully connected layer can be sized from the real
# window length. Previously it was hard-coded (1696) and had to be updated by hand
# whenever the input shape changed.
sample_features, sample_labels = next(iter(train_dataloader))
n_samples = sample_features.shape[-1]  # Conv1d consumes (batch, channels, length)

# Convolutional feature extractor.
conv_layers = [
    # 1D convolution with a kernel size of 3, followed by the activation function.
    nn.Conv1d(n_channels, 32, 3),
    nn.ReLU(),

    # 1D convolution with a kernel size of 2, activation, then max-pool with a
    # window of 2.
    nn.Conv1d(32, 32, 2),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),

    # Same as the previous stage.
    nn.Conv1d(32, 32, 2),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),

    # Flatten the convolutions. Input shape: (a, b, c), Output shape: (a, b*c)
    nn.Flatten(),
]

# Run a dummy input through the feature extractor to find the flattened size
# (1696 for 219-sample windows).
with torch.no_grad():
    n_flat_features = nn.Sequential(*conv_layers)(torch.zeros(1, n_channels, n_samples)).shape[1]

# Create model. The layer order (and therefore the state_dict keys) is unchanged
# from earlier versions of this notebook.
net = nn.Sequential(
    *conv_layers,
    nn.Linear(n_flat_features, 512),
    # NOTE(review): there is no activation between the two Linear layers, so together
    # they are equivalent to a single linear map. Adding nn.ReLU() here would likely
    # help, but it shifts the state_dict key of the final layer.
    nn.Linear(512, n_outputs),
    # No Softmax here: nn.CrossEntropyLoss expects raw logits and applies
    # log-softmax itself. Use torch.softmax(net(x), dim=1) to get probabilities.
)

print("Model input shape:", sample_features.shape)
out = net(sample_features.float())
print("Model output shape:", out.shape)

## Training the Model

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Train for n epochs
n = 80
test_while_training = True


def evaluate_accuracy(model, dataloader):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    The predicted class is the argmax over dim 1 of the model output. Argmax is the
    same whether the output is logits, probabilities or log-probabilities, so no
    normalisation is applied. The model is put in eval mode for the pass and then
    restored to whichever mode it was in before.

    Raises:
        ValueError: if ``dataloader`` yields no samples (accuracy is undefined).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for features, labels in dataloader:
                preds = model(features.float()).argmax(dim=1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return correct / total


loss_history = []   # one entry per training batch
eval_history = []   # one test accuracy per epoch (when test_while_training)
count = 0           # total optimizer steps taken
net.train()
for epoch in tqdm(range(n)):
    for features, labels in train_dataloader:
        # Zero gradients
        optimizer.zero_grad()

        # Forward
        predictions = net(features.float())

        # Compute loss (CrossEntropyLoss wants integer class indices)
        loss = criterion(predictions, labels.long())
        loss_history.append(loss.item())

        # Backward
        loss.backward()

        # Optimize
        optimizer.step()

        count += 1

    # Evaluate the model against the test dataset; the helper restores train mode.
    if test_while_training:
        eval_history.append(evaluate_accuracy(net, test_dataloader))

%matplotlib inline
plt.title("Training session")
print("Passes per epoch:", count / n)
print("Final Loss:", loss_history[-1])
plt.plot(np.linspace(0, 1, len(loss_history)), loss_history, label="Loss")
if test_while_training:
    print(f"Final Accuracy: {int(100*eval_history[-1])}%")
    plt.plot(np.linspace(0, 1, len(eval_history)), eval_history, label="Accuracy")
plt.legend()
plt.show()

In [None]:
# Final accuracy on the held-out recordings. The model stays in eval mode afterwards.
net.eval()
total = 0
correct = 0
with torch.no_grad():
    for features, labels in test_dataloader:
        out = net(features.float())
        # argmax is unchanged by (log-)softmax, so the raw outputs pick the class directly.
        preds = out.argmax(dim=1)
        total += labels.size(0)
        correct += (preds == labels).sum().item()

if total == 0:
    raise ValueError("test_dataloader yielded no samples; cannot compute accuracy")
model_accuracy = int(correct / total * 100)
print("Correct:", correct, "/", total, "-", f"{model_accuracy}%")

## Saving the Model

In [None]:
# Timestamp without spaces or colons: str(datetime.now()) contains both, and colons
# are not valid in Windows file names.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# NOTE(review): the "eyes_open" tag was kept from the eyes open/closed notebook. It
# matches the data currently loaded, but not this notebook's left/right/none task;
# rename it once the direction recordings are used.
model_dir = "../models"
model_path_stem = f"{model_dir}/{timestamp}-Muse_EEG_eyes_open-{model_accuracy}percent"
os.makedirs(model_dir, exist_ok=True)  # torch.save / open fail if the directory is missing
torch.save(net.state_dict(), f"{model_path_stem}.pt")  # trained weights
with open(f"{model_path_stem}.model", "w", encoding="utf-8") as f:
    f.write(str(net))  # human-readable architecture summary