In [1]:
# Imports for Tensor
import bisect
import math
import os
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

In [2]:
# Data locations. Several later cells build paths as `ear_eeg_base_path + '...'`,
# so the base path must keep its trailing '/'.
# Earlier locations (Google Drive / a local Windows checkout), kept for reference:
#!ls 'gdrive/My Drive/Muller Group Drive/Ear EEG/Drowsiness_Detection/classifier_TBME'
#!ls C:\Users\arya_bastani\Documents\ear_eeg\data\ear_eeg_data
ear_eeg_base_path = '/data/shared/signal-diffusion/'
# Directory of the ear EEG recordings ('ear_eeg_clean').
# NOTE(review): only used by the %ls below; the read-in cell builds its own data_filepath.
ear_eeg_data_path = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/ear_eeg_clean'

# IPython line magic: list the recordings as a quick check that the path exists.
%ls {ear_eeg_data_path}

In [3]:
# imports
import numpy as np

#import support scripts: pull_data
import support_scripts.read_in_ear_eeg as read_in_ear_eeg
import support_scripts.read_in_labels as read_in_labels
import support_scripts.eeg_filter as eeg_filter

In [4]:
##################
# READ-IN EAR EEG
##################
# Reads every trial selected by the details spreadsheet through the project
# helper read_in_ear_eeg.read_in_clean_data.
# NOTE, this takes a long time to run.
# (It could be parallelized to reduce runtime)

# name of spreadsheet with experiment details
# details_spreadsheet = 'gdrive/My Drive/Muller Group Drive/Ear EEG/Drowsiness_Detection/classifier_TBME/classification_scripts/trial_details_spreadsheet_basic.csv'
details_spreadsheet = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/trial_details_spreadsheet_good.csv'

# Directory of the ear EEG recordings. Keep the trailing separator (the original
# note asked for the Windows form r'filepath\\'); presumably the helper appends
# file names directly to it.
#data_filepath = r'C:\Users\Carolyn\OneDrive\Documents\school\berkeley\research\ear_eeg_classification_framework\experimental_recordings\drowsiness_studies\ear_eeg\\'
data_filepath = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/ear_eeg_clean/'

# user number or all users ('all', 'ryan', 'justin', 'carolyn', 'ashwin', 'connor')
input_users = 'all'

# EEG channels to read in for each trial (must include 5 and 11 if re-referencing is enabled in the next block)
# NOTE(review): no re-referencing step is visible in this notebook.
data_chs = [1,2,3,4,5,7,8,9,10,11]

# sampling frequency of system in Hz (fs=1000 for wandmini)
# NOTE(review): `sample_rate` in the constants cell repeats this value; keep them in sync.
fs = 1000

# plot eeg data that is read in
plot_raw_data_enable = False

# Call the read-in helper. `filenames` and `data_lengths` are passed on to the
# label read-in below; `file_users` and `refs` are not used later in this notebook.
all_raw_data, filenames, data_lengths, file_users, refs = read_in_ear_eeg.read_in_clean_data(details_spreadsheet, data_filepath, input_users, data_chs, fs, plot_raw_data_enable)
# Left as returned rather than stacked; presumably recordings differ in length.
#all_raw_data = np.array(all_raw_data)

In [5]:
# Sanity check: display one recording (index 21) as returned by the read-in.
all_raw_data[21]

In [6]:
#################
# READ-IN LABELS
#################

# Note: label read in will match Ear EEG read in
# (same trials will be read in, and the experiment lengths will be the same)

# Directory of the label files. Keep the trailing separator
# (the original note asked for the Windows form r'filepath\\').
#label_filepath = r'C:\Users\Carolyn\OneDrive\Documents\school\berkeley\research\ear_eeg_classification_framework\experimental_recordings\drowsiness_studies\labels\\'
label_filepath = ear_eeg_base_path + 'eeg_classification_data/ear_eeg_data/labels//'

# plot the labels that are read in
plot_labels_enable = False

# call read in labels (one label sequence per recording read above)
all_labels = read_in_labels.read_in_labels(filenames, data_lengths, label_filepath, plot_labels_enable)
# Keep one numpy array per recording instead of stacking them: np.array() on
# sequences of different lengths raises ValueError on NumPy >= 1.24 (older
# versions build an object array with a warning). `all_labels[i]` indexing
# used below behaves the same either way.
all_labels = [np.asarray(recording_labels) for recording_labels in all_labels]

In [7]:
# Sanity check: labels of one recording (index 21).
print(all_labels[21])

In [8]:
# Apply the project's EEG filter (eeg_filter.filter_studies) to every recording.
filtered_data = eeg_filter.filter_studies(all_raw_data)


# Compare recording count and first-recording shape before and after filtering.
print(len(all_raw_data))
print(all_raw_data[0].shape)

print(len(filtered_data))
print(filtered_data[0].shape)

In [9]:
# No longer have a need for the original raw data so we delete it; its memory
# can be reclaimed unless filtered_data still references the same arrays.
del all_raw_data

In [10]:
# Data constants
# Presumably the recordings belonging to each subject (positions in filtered_data);
# not referenced elsewhere in this notebook.
carolyn_indices = [0,1,2,3,4]
ryan_indices = [5,6,7,8,9]
justin_indices = [10,11,12,13,14]
conor_indices = [15,16,17,18,19]
avi_indices = [20,21]
# NOTE(review): these fractions are not used anywhere in this notebook, and the
# hand-picked split below is 13/6/3 recordings (~59%/27%/14%), not 55/30/15.
train_perc, val_perc, test_perc = 0.55, 0.30, .15
# Hand-picked recording indices per split; they should be disjoint and together
# cover every recording (currently 0..21).
train_ind = [2,3,4,8,9,12,13,14,15,17,18,19,21]
val_ind = [1,6,11,16,20,7]
test_ind = [0,5,10]

# Model Constants
window_size = 10 # Seconds
sample_rate = 1000 # Hertz (same value as `fs` in the read-in cell)

# Model Params
# Samples per window: each recording is cut into windows of this many samples.
seq_size = window_size * sample_rate

In [11]:
# Split up into train, val, and test datasets using the hand-picked index lists.
# Every recording must be listed in exactly one of train_ind / val_ind / test_ind.
# (Previously anything not in train/val fell through to test, so test_ind was
# never consulted and an unlisted recording silently became test data.)
train_data, val_data, test_data = [],[],[]
train_labels, val_labels, test_labels = [],[],[]

_split_overlap = (
    (set(train_ind) & set(val_ind))
    | (set(train_ind) & set(test_ind))
    | (set(val_ind) & set(test_ind))
)
if _split_overlap:
    raise ValueError(f'recordings listed in more than one split: {sorted(_split_overlap)}')

_unassigned = []
for i in range(len(filtered_data)):
    if i in train_ind:
        train_data.append(filtered_data[i])
        train_labels.append(all_labels[i])
    elif i in val_ind:
        val_data.append(filtered_data[i])
        val_labels.append(all_labels[i])
    elif i in test_ind:
        test_data.append(filtered_data[i])
        test_labels.append(all_labels[i])
    else:
        _unassigned.append(i)

if _unassigned:
    raise ValueError(f'recordings not listed in any split: {_unassigned}')

In [12]:
# Size of the second axis of the first training recording
# (transpose, take the first column, count its entries == data.shape[1]).
len(train_data[0].T[:,0])

In [13]:
# Drops the name only: the split lists above still reference every recording,
# so no array memory is released here.
del filtered_data

In [14]:
def formatt_data(data_set, seq_len, max_len=128):
    """Cut each recording into consecutive, non-overlapping windows.

    Args:
        data_set: sequence of arrays, one per recording, windowed along axis 0
            (the original note describes them as (samples, channels),
            e.g. (2,400,000, 10)).
        seq_len: number of samples per window.
        max_len: keep only the first ``max_len`` samples of each window
            (a temporary size reduction; 128 is the previous hard-coded value).

    Returns:
        (formatted_datasets, index_sample_count_map): formatted_datasets[i] has
        shape (num_windows_i, min(seq_len, max_len), *recording.shape[1:]);
        index_sample_count_map is an OrderedDict {recording index: num_windows_i}.
        Trailing samples that do not fill a whole window are dropped.
    """
    index_sample_count_map = OrderedDict()
    formatted_datasets = []
    for i, data in enumerate(data_set):
        num_seqs = data.shape[0] // seq_len
        index_sample_count_map[i] = num_seqs
        # Trim the incomplete tail before windowing (as formatt_labels does).
        # np.split on the untrimmed recording raised ValueError when its length
        # was not a multiple of num_seqs, and otherwise produced windows longer
        # than seq_len that no longer lined up with the label windows.
        windows = data[:num_seqs * seq_len].reshape(num_seqs, seq_len, *data.shape[1:])
        # Copy the kept slice so it does not pin the full-length windows in memory.
        formatted_datasets.append(windows[:, :max_len].copy())
    return formatted_datasets, index_sample_count_map

def one_hot_encode(input, num_classes=None):
    """One-hot encode an array of non-negative integer class labels.

    Args:
        input: numpy array of integer labels (``.size``/``.max()`` are used).
        num_classes: width of the encoding. Defaults to ``input.max() + 1``;
            pass it explicitly so arrays that lack the highest class still get
            the common width.

    Returns:
        float array of shape (input.size, num_classes).
    """
    if num_classes is None:
        num_classes = int(input.max() + 1)
    one_hot_labels = np.zeros((int(input.size), int(num_classes)))
    one_hot_labels[np.arange(input.size), input] = 1
    return one_hot_labels

def formatt_labels(labels_set, seq_len, max_len=128, num_classes=None):
    """Window per-sample labels exactly like formatt_data windows the signals.

    Args:
        labels_set: sequence of integer label arrays, one per recording.
        seq_len: samples per window (must match formatt_data).
        max_len: labels kept per window (must match formatt_data).
        num_classes: one-hot width; defaults to (largest label over *all*
            recordings) + 1 so every recording gets the same width.

    Returns:
        list of float arrays, each (num_windows_i, min(seq_len, max_len), num_classes).
    """
    if num_classes is None:
        # Encoding each recording with its own max()+1 produced a narrower array
        # for any recording that never reached the top class.
        num_classes = max((int(np.max(labels)) + 1 for labels in labels_set), default=0)
    formatted_labels = []
    for labels in labels_set:
        one_hot_labels = one_hot_encode(labels, num_classes)
        num_seqs = one_hot_labels.shape[0] // seq_len
        windows = one_hot_labels[:num_seqs * seq_len].reshape(num_seqs, seq_len, num_classes)
        formatted_labels.append(windows[:, :max_len].copy())
    return formatted_labels

In [15]:
# Per-recording shapes of the training split (original and transposed).
# np.array(train_data) copied every recording into one new array just to read
# its shape, and raises ValueError on NumPy >= 1.24 when recordings differ in length.
for recording in train_data:
    print(recording.shape, recording.T.shape)

In [16]:
# Format the data: cut each training recording into seq_size-sample windows.
# train_seq_count_map records how many windows each recording produced; the
# Dataset built later uses it to map a global sample index to a recording.
proc_train_X, train_seq_count_map = formatt_data(train_data, seq_size)
# Validation formatting is disabled for now.
#proc_val_X, val_seq_count_map = formatt_data(val_data, seq_size)

In [17]:
# Format the labels with the same window size so they line up with the data windows.
proc_train_y = formatt_labels(train_labels, seq_size)

#proc_val_y = formatt_labels(val_labels, seq_size)

In [18]:
# Alignment check: the first two axes (windows, samples per window) should match.
print(proc_train_X[0].shape)
proc_train_y[0].shape

In [19]:
# Save each training recording's windows to its own file so the Dataset can load
# them lazily. (THIS DATA IS FORMATTED REALLY POORLY, fix after the model stops getting overloaded)
batch_ids = []
preproc_path = os.path.join(ear_eeg_base_path, "saved_tensors")
os.makedirs(preproc_path, exist_ok=True)
for index, recording_windows in enumerate(proc_train_X):
    # Same path the Dataset class reads: <preproc_path>/recording_<index>.pt
    batch_id = preproc_path + '/recording_' + str(index) + '.pt'
    batch_ids.append(index)

    # Save a compact torch tensor instead of the raw numpy array:
    # - torch.load defaults to weights_only=True since PyTorch 2.6 and refuses
    #   to unpickle numpy arrays, so the saved files could not be read back;
    # - ascontiguousarray copies just the kept samples, since a tensor built on
    #   a strided view would serialize the larger underlying buffer.
    torch.save(torch.from_numpy(np.ascontiguousarray(recording_windows)), batch_id)

In [20]:
# Map-style dataset over the per-recording window files written above.
class Dataset(torch.utils.data.Dataset):
    """Serve individual windows from recordings saved one file per recording.

    Windows of all recordings are numbered consecutively: recording 0's windows
    come first, then recording 1's, and so on. Recording ``r`` is loaded from
    ``data_path + '/recording_' + str(r) + '.pt'`` and its labels are
    ``labels[r]``, where ``r`` is the recording's position in
    ``index_sample_count_map`` (its keys are 0..n-1 when built by formatt_data).

    Args:
        data_path: directory holding the ``recording_<r>.pt`` files.
        index_sample_count_map: ordered mapping {recording: number of windows}.
        labels: indexable per recording; ``labels[r][k]`` is window k's label.
    """

    def __init__(self, data_path, index_sample_count_map, labels):
        """Store the inputs and precompute cumulative window counts."""
        self.data_path = data_path
        self.labels = labels
        self.list_IDs = index_sample_count_map.keys()
        self.index_sample_count_map = index_sample_count_map

        # recording_global_indices[r] = number of windows in recordings 0..r,
        # i.e. the exclusive end of recording r in the global index space.
        recording_global_indices = []
        total_count = 0
        for recording_len in index_sample_count_map.values():
            total_count += recording_len
            recording_global_indices.append(total_count)
        self.recording_global_indices = recording_global_indices

        # One-entry cache of the most recently loaded recording: consecutive
        # indices usually fall in the same recording, so this avoids re-reading
        # the same file for every window.
        self._cached_recording_index = None
        self._cached_recording_X = None

    def __len__(self):
        """Total number of windows across all recordings."""
        # The last cumulative count is the total. Summing the cumulative counts
        # (the previous code) over-counted, so the DataLoader asked for indices
        # past the end of the data.
        return self.recording_global_indices[-1] if self.recording_global_indices else 0

    def __getitem__(self, index):
        """Return ``(X, y)`` for global window ``index`` (negative counts from the end).

        Raises:
            IndexError: if ``index`` is out of range (previously such indices
                silently resolved to recording 0).
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'window index out of range for dataset of {total} windows')

        # First recording whose cumulative end lies past `index`; recordings
        # with zero windows are skipped automatically.
        recording_index = bisect.bisect_right(self.recording_global_indices, index)
        recording_start = self.recording_global_indices[recording_index - 1] if recording_index else 0
        # Offset from the start of this recording. The previous code subtracted
        # the *sum* of the cumulative counts, which is wrong from recording 2 on.
        inside_recording_index = index - recording_start

        if recording_index != self._cached_recording_index:
            # NOTE(review): since PyTorch 2.6 torch.load uses weights_only=True by
            # default, so these files must hold tensors (not pickled numpy arrays).
            self._cached_recording_X = torch.load(self.data_path + '/recording_' + str(recording_index) + '.pt')
            self._cached_recording_index = recording_index
        full_recording_X = self._cached_recording_X
        full_recording_y = self.labels[recording_index]

        X = full_recording_X[inside_recording_index]
        y = full_recording_y[inside_recording_index]
        return X, y

In [21]:
import torch

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN benchmark kernels; every window has the same length, so the choice is reused.
torch.backends.cudnn.benchmark = True

# Parameters
BATCH_SIZE = 256
SHUFFLE = False  # NOTE(review): training windows are served in recording order
NUM_WORKERS = 1

# NOTE(review): not read by the training loop, which uses N_EPOCHS instead.
max_epochs = 10

# Datasets
partition = train_seq_count_map  # {recording index: number of windows}
labels = proc_train_y  # per-recording windowed one-hot labels

# Generators
training_set = Dataset(preproc_path, partition, labels)
training_generator = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE,
                                                 shuffle=SHUFFLE, num_workers=NUM_WORKERS)

# TODO: validation loader. It needs proc_val_X / proc_val_y formatted and saved
# like the training split, then something like:
#   validation_set = Dataset(<val tensor dir>, val_seq_count_map, proc_val_y)
#   validation_generator = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE)
# (This used to be a bare string literal, which was echoed as the cell's output.)

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim

class TransformerClassifier(nn.Module):
    """Transformer-encoder classifier over fixed-length multichannel windows.

    Raw features go straight into the encoder (no token embedding), so the last
    input dimension must equal ``hid_dim`` (the encoder's d_model). The encoder
    output is mean-pooled over the sequence axis and mapped to class logits.

    Args:
        input_dim: per-time-step feature size (kept for API compatibility; the
            head now uses ``hid_dim``, the size it actually receives).
        output_dim: number of classes.
        hid_dim: encoder d_model; must equal the per-time-step feature size.
        n_layers: number of encoder layers.
        n_heads: attention heads (must divide ``hid_dim``).
        pf_dim: feed-forward width inside each encoder layer.
        dropout: dropout probability inside the encoder layers.
        batch_firsty: True for (batch, seq, feature) inputs, False for (seq, batch, feature).
    """
    def __init__(self, input_dim, output_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, batch_firsty):
        super().__init__()
        self.batch_first = batch_firsty

        # Not used by forward(); kept so existing state_dict checkpoints still load.
        #self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(1000, hid_dim)  # position embedding

        self.encoder_layer = nn.TransformerEncoderLayer(hid_dim, n_heads, pf_dim, dropout, batch_first=batch_firsty)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, n_layers)

        # The head receives the pooled encoder output, whose size is hid_dim.
        # It was nn.Linear(input_dim, ...), which only worked while hid_dim == input_dim.
        self.fc = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)  # not used by forward()

    def forward(self, src):
        """Return class logits of shape [batch_size, output_dim].

        src: [batch, seq, hid_dim] if batch_first else [seq, batch, hid_dim].
        No positional information is added before the encoder.
        """
        encoded = self.encoder(src.float())  # same layout as src

        # Average over the *sequence* axis. With batch_first=True that is dim 1;
        # pooling dim 0 averaged across the batch and returned [seq_len, output_dim].
        seq_dim = 1 if self.batch_first else 0
        final_output = encoded.mean(dim=seq_dim)  # [batch_size, hid_dim]
        logits = self.fc(final_output)  # [batch_size, output_dim]

        return logits

In [29]:
# define hyperparameters
INPUT_DIM = 10 # per-time-step feature size, presumably the EEG channel count (data_chs lists 10)
OUTPUT_DIM = 2 # number of classes
# No embedding/projection is applied, so the encoder width must equal the raw feature size.
HID_DIM = INPUT_DIM
N_LAYERS = 4 # encoder layers
N_HEADS = 2 # attention heads; must divide HID_DIM
PF_DIM = 256 # feed-forward width inside each encoder layer
DROPOUT = 0.1
BATCH_FIRST = True # True: (batch, seq, feature). False: (seq, batch, feature)

# create model instance (it stays on the CPU: nothing moves it to `device`)
model = TransformerClassifier(INPUT_DIM, OUTPUT_DIM, HID_DIM, N_LAYERS, N_HEADS, PF_DIM, DROPOUT, BATCH_FIRST)

# define loss function and optimizer
# NOTE(review): each target from the Dataset is a (128, n_classes) block of
# per-sample one-hot labels (formatt_labels), while the model returns a 2-D
# logits tensor; nn.CrossEntropyLoss will reject that pairing until the targets
# are reduced to one class index per window (or the model predicts per sample).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# define training and evaluation functions
def train(model, iterator, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module mapping a batch of inputs to logits.
        iterator: iterable of (src, trg) batches, e.g. a DataLoader.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(model(src), trg)``.

    Returns:
        Mean of the batch losses as a float (matching ``evaluate``), or
        ``float('nan')`` if ``iterator`` yields nothing. Previously only the
        last batch's loss tensor was returned, and an empty iterator raised
        UnboundLocalError.
    """
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for src, trg in iterator:
        optimizer.zero_grad()
        output = model(src)
        loss = criterion(output, trg)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float, so no autograd graph is kept around.
        epoch_loss += loss.item()
        n_batches += 1
    return epoch_loss / n_batches if n_batches else float('nan')

def evaluate(model, iterator, criterion):
    """Return the mean per-batch loss of ``model`` over ``iterator`` without training.

    Args:
        model: module mapping a batch of inputs to logits.
        iterator: iterable of (src, trg) batches; it need not support ``len()``.
        criterion: loss called as ``criterion(model(src), trg)``.

    Returns:
        Mean of the batch losses as a float, or ``float('nan')`` if no batches
        are yielded (previously a ZeroDivisionError).
    """
    model.eval()
    epoch_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for src, trg in iterator:
            output = model(src)
            epoch_loss += criterion(output, trg).item()
            # Count while iterating instead of len(iterator), which unsized
            # iterables (generators, iter(...)) do not provide.
            n_batches += 1
    return epoch_loss / n_batches if n_batches else float('nan')

In [30]:
# define training loop
# Validation is not wired up yet (no validation DataLoader is built above), so
# there is no "best" checkpoint to select; the final weights are saved instead.
N_EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    train_loss = train(model, training_generator, optimizer, criterion)
    # float() accepts both a plain float and a 0-d loss tensor.
    print(f'Epoch {epoch+1}: train loss={float(train_loss):.3f}')
    # TODO: once a validation loader exists, keep the best checkpoint:
    #     valid_loss = evaluate(model, valid_iterator, criterion)
    #     if valid_loss < best_valid_loss:
    #         best_valid_loss = valid_loss
    #         torch.save(model.state_dict(), 'model.pt')

# Save this run's weights. Nothing wrote 'model.pt' before, so the load below
# failed with FileNotFoundError or silently picked up a stale earlier checkpoint.
torch.save(model.state_dict(), 'model.pt')

# load best model and evaluate on test set, but only once a test loader exists:
# `test_iterator` is not defined anywhere above, so this used to end in a NameError.
if 'test_iterator' in globals():
    model.load_state_dict(torch.load('model.pt'))
    test_loss = evaluate(model, test_iterator, criterion)
    print(f'Test loss={test_loss:.3f}')
else:
    print('Skipping test evaluation: no test_iterator has been built yet.')