### Experiment : Classify Spampinato eeg data
**Goal:** Verify that the Spampinato EEG data can be classified
<br>
**Model:** Bidirectional LSTM with attention
<br>
**Data:** Spampinato EEG
<br>
**Result:** <font color='green'>Acc 48%</font>
<br>
**Conclusion:** The accuracy is better than that of EEGNet and the plain LSTM

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
from datetime import datetime
import sys, os
import torch.nn as nn
import torch
from torchvision import datasets

from libs.EEGModels.EEGNet import EEGNet
from libs.Classification_Training import Classification_Training
from libs.utilities import save_result_csv

from libs.utilities import get_freer_gpu
# Pick the compute device: the GPU reported by get_freer_gpu() when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(get_freer_gpu())
else:
    device = torch.device("cpu")

# Wall-clock reference used at the end of the notebook to report total runtime.
start_time = time.time()
print(device)

cuda:1


In [2]:
print("======================================================")

# Mode detection: CLI mode is assumed as soon as a third positional argument
# (model_round) is present; anything shorter runs with the debug defaults
# below. This is the same rule as probing sys.argv[3] and catching IndexError,
# written as an explicit length check.
# NOTE(review): inside a Jupyter kernel sys.argv usually holds only the
# launcher arguments, which is what makes the notebook land in debug mode --
# confirm for your launcher.
is_debug = len(sys.argv) <= 3
if not is_debug:
    print(sys.argv[3])

print("---------------Running agument------------------")
print(f'Debug mode is : {is_debug}')


if is_debug:
    participant_id  = "all"
    model_name      = "LSTM_attention"                  #         EEGNet , CNN1D
    model_round     = 1
    result_file     = f"result_{model_name}"
    n_epochs        = 500
else:
    # Usage: <script> <participant_id|all> <model_name> <model_round> [result_file]
    participant_id      = sys.argv[1]
    if participant_id != "all":
        participant_id = int(participant_id)
    model_name          = sys.argv[2]
    model_round         = int(sys.argv[3])
    # result_file is optional: previously a missing 4th argument crashed with a
    # bare IndexError even though CLI mode had already been selected. Fall back
    # to the same naming scheme the debug branch uses.
    result_file         = sys.argv[4] if len(sys.argv) > 4 else f"result_{model_name}"
    n_epochs            = 300


print("Device = ", device)
print("participant_id :" , participant_id)
print("model_name :" , model_name)
print("result_file :" , result_file)
print("n_epochs :", n_epochs)

# Output location and data-loading settings.
model_save_path    = "save/weight/Spam_EEG_classify"
workers   =  8 
batch_size  = 256

# Run identifiers handed to the training helper through `env`.
running_model= f"Spam_EEG_classify_{model_name}-p{participant_id}"
running_time = datetime.today().strftime('%Y%m%d-t%H%M%S')
env = [running_model,running_time ]

---------------Running agument------------------
Debug mode is : True
Device =  cuda:1
participant_id : all
model_name : LSTM_attention
result_file : result_LSTM_attention
n_epochs : 500


In [3]:
# Model hyper-parameters; they are passed positionally to the model
# constructor inside get_acc_save_result().

# Feature size of every time step fed to nn.LSTM (the last axis of the input).
input_dim     = 128
# Hidden size of each LSTM direction.
hidden_dim    = 32
# Number of stacked LSTM layers.
num_layers    = 1
# Size of the classifier output (number of target classes).
num_classes   = 40
# Run the LSTM in both time directions.
bidirectional = True
# Dropout forwarded to nn.LSTM.
dropout       = 0
# Learning rate handed to Classification_Training.
learning_rate = 1e-3

# Scheduler name forwarded to Classification_Training.do_training().
scheduler = "StepLR"

In [4]:
from torch.utils.data import Dataset, DataLoader
import pickle

class Spam_EEG_ds(Dataset):
    """Dataset wrapper around the pickled Spampinato EEG recordings.

    The pickle at ``file_location`` is indexed by participant id; each entry is
    an ``(eeg, label)`` pair of tensors sharing their first (sample) axis.
    Only the currently selected participant is exposed through
    ``__getitem__`` / ``__len__``.
    """

    file_location = 'dataset/spampinoto_dataset/Spam_EEG_14_70.pickle'

    def __init__(self, device, participant_id=1):
        """Load the whole pickle and select ``participant_id``.

        Args:
            device: torch device every returned sample is moved to.
            participant_id: key into the pickled mapping (an int or "all" in
                this notebook).
        """
        super(Spam_EEG_ds, self).__init__()

        self.participant_id = participant_id
        # NOTE(review): pickle.load can execute arbitrary code -- only point
        # file_location at trusted files.
        # The context manager closes the file handle, which the previous
        # `pickle.load(open(...))` left open.
        with open(self.file_location, "rb") as pickle_file:
            self.whole_data = pickle.load(pickle_file)
        self.curr_participant_data = self.whole_data[self.participant_id]
        self.device = device

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for sample ``idx``, moved to ``self.device``."""
        eeg   = self.curr_participant_data[0][idx].to(self.device)
        label = self.curr_participant_data[1][idx].to(self.device)
        return eeg, label

    def __len__(self):
        """Number of samples of the current participant (length of the labels)."""
        return len(self.curr_participant_data[1])

    def get_name(self):
        """Return the dataset's short name."""
        return "spampinato_data"

    def change_participant_id(self, participant_id=1):
        """Switch the exposed data to another participant of the loaded pickle.

        The data is re-read from ``whole_data``, so any layout change applied
        by ``set_eeg_shape`` is discarded.
        """
        self.participant_id        = participant_id
        self.curr_participant_data = self.whole_data[self.participant_id]

    def get_eeg_shape(self):
        """Return the shape of the current participant's EEG tensor."""
        return self.curr_participant_data[0].shape

    def _replace_eeg(self, new_eeg):
        """Swap in a re-laid-out EEG tensor while keeping the labels."""
        self.curr_participant_data = (new_eeg, self.curr_participant_data[1])

    def set_eeg_shape(self, cnn_type) :
        """Re-arrange the EEG tensor into the layout expected by ``cnn_type``.

        * ``"CNN1D"``, ``"CNN2D"``, ``"EEGNet"``: swap axes 1 and 2 of the 4-D
          tensor.
        * any name containing ``"LSTM"``: drop singleton axes and swap the last
          two axes of the resulting 3-D tensor.
        * anything else: the data is left untouched.

        The operation is NOT idempotent -- calling it twice swaps the axes
        back -- so callers should invoke it once per loaded participant.
        """
        if cnn_type == "CNN1D":
            print(cnn_type)
            print(type( self.curr_participant_data[0]))
            self._replace_eeg(self.curr_participant_data[0].permute(0 ,2 ,1 ,3))

        elif cnn_type == "CNN2D" or cnn_type=="EEGNet":
            print(cnn_type)
            self._replace_eeg(self.curr_participant_data[0].permute(0 ,2 ,1 ,3))

        elif "LSTM" in cnn_type :
            # NOTE(review): torch.squeeze drops every size-1 axis, so a
            # participant with a single sample would also lose its batch
            # axis -- confirm this cannot happen with the real data.
            self._replace_eeg(torch.squeeze(self.curr_participant_data[0]).permute(0, 2, 1))

In [5]:
from torch.utils.data import DataLoader, Subset, Dataset, TensorDataset
from sklearn.model_selection import train_test_split


def train_test_split_ds( dataset_in , test_size=0.1, batch_size=10, seed=42 ):
    """Build stratified train / validation DataLoaders over ``dataset_in``.

    Only sample indices are split; both loaders wrap ``Subset`` views of the
    same dataset object, so later changes to the dataset's current data are
    visible through the loaders as well.

    Args:
        dataset_in: dataset exposing ``whole_data``, ``participant_id`` and
            ``__len__`` (e.g. ``Spam_EEG_ds``).
        test_size: share of the samples held out for validation, forwarded to
            ``sklearn.model_selection.train_test_split``.
        batch_size: batch size of both loaders.
        seed: random state of the split (default 42, the value that used to be
            hard-coded), so the same seed always yields the same partition.

    Returns:
        ``(train_iterator, val_iterator)`` -- the training loader reshuffles
        every epoch, the validation loader keeps a fixed order.
    """
    # Labels of the selected participant; they drive the stratification so the
    # class proportions are kept in both parts.
    _labels = dataset_in.whole_data[dataset_in.participant_id][1]

    # generate indices: instead of the actual data we pass in integers instead
    train_indices, test_indices, _, _ = train_test_split(
                                        range(len(dataset_in))   ,
                                        _labels                   ,
                                        stratify     = _labels    ,
                                        test_size    = test_size ,
                                        random_state = seed
                                    )

    # generate subset based on indices
    train_split = Subset(dataset_in, train_indices)
    test_split  = Subset(dataset_in, test_indices)

    # create batches
    train_iterator = DataLoader( train_split, batch_size=batch_size, shuffle=True  )
    val_iterator   = DataLoader( test_split,  batch_size=batch_size, shuffle=False )

    return train_iterator, val_iterator

In [6]:
from torch.nn import functional as F

class LSTM_attention(nn.Module):
    '''
    (Bi)LSTM encoder followed by dot-product attention over its time steps.

    The final hidden state of the last LSTM layer is used as the query: every
    time-step output is scored against it, the scores are soft-maxed over the
    sequence axis and the weighted sum of the outputs is fed to a linear
    classifier.

    Expected Input Shape: (batch, seq_len, channels)
    Output Shape:         (batch, num_classes), raw logits (no softmax applied)
    '''
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes, bidirectional, dropout):
        super(LSTM_attention, self).__init__()
        self.hidden_dim     = hidden_dim
        self.num_layers     = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.lstm       = nn.LSTM(input_dim, 
                                  hidden_dim, 
                                  num_layers, 
                                  bidirectional=bidirectional, 
                                  dropout=dropout, 
                                  batch_first=True
                                 )
        # The attention context has the width of ONE layer's output
        # (hidden_dim * num_directions); it does not grow with num_layers.
        # The former `hidden_dim * num_layers * 2` only matched for a
        # single-layer bidirectional LSTM.
        self.fc      = nn.Linear(hidden_dim * self.num_directions, num_classes)
        # Kept for attribute compatibility; forward() returns raw logits
        # because the notebook trains with nn.CrossEntropyLoss.
        self.softmax = nn.LogSoftmax(dim=1)
        # The unfinished, unused lin_Q / lin_K / lin_V projections were
        # removed: `nn.Linear()` without sizes raises TypeError, which made
        # the class impossible to instantiate.

    def attention_net(self, lstm_output, final_state):
        '''
        Dot-product attention of the final hidden state over all time steps.

        lstm_output : [batch_size, seq_len, hidden_dim * num_directions]
        final_state : [batch_size, hidden_dim * num_directions]

        Returns (context, soft_attn_weights):
            context           : [batch_size, hidden_dim * num_directions]
            soft_attn_weights : [batch_size, seq_len], each row sums to 1
        '''
        hidden = final_state.unsqueeze(2)                         # [batch_size, hidden_dim * num_directions, 1]

        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # [batch_size, seq_len]

        soft_attn_weights = F.softmax(attn_weights, 1)
        # [batch_size, hidden_dim * num_directions, seq_len] * [batch_size, seq_len, 1]
        #   = [batch_size, hidden_dim * num_directions, 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)

        # forward() unpacks two values; the previous bare `return` yielded None.
        return context, soft_attn_weights

    def forward(self, x):
        # nn.LSTM initialises (h0, c0) to zeros on the input's device when no
        # state is passed, so no explicit zero tensors -- and no dependency on
        # a global `device` -- are needed.
        out, (hn, cn) = self.lstm(x)    # out: [batch_size, seq_len, hidden_dim * num_directions]

        # hn: [num_layers * num_directions, batch_size, hidden_dim]; the last
        # entries belong to the top layer.
        if self.num_directions == 2:
            hn = torch.cat((hn[-2,:,:], hn[-1,:,:]), dim = 1)   # forward + backward state
        else:
            hn = hn[-1,:,:]
        # hn: [batch_size, hidden_dim * num_directions]

        attn_output, attention = self.attention_net(out, hn)
        out = self.fc( attn_output  )

        return out

In [7]:
def get_acc_save_result( ) :
    """Build the model selected by ``model_name``, train it and log the result.

    Works entirely on notebook-level globals: ``dataset``, ``train_loader``,
    ``val_loader``, the hyper-parameters, ``device``, ``env``, ``is_debug``,
    ``scheduler``, ``n_epochs``, ``participant_id``, ``model_round`` and
    ``result_file``.

    Side effects:
        * may re-arrange the EEG tensor held by ``dataset`` (``set_eeg_shape``),
        * appends one row to ``f"{result_file}.csv"`` via ``save_result_csv``.

    Raises:
        ValueError: if ``model_name`` does not select a known model.
    """
    # ==============================================
    # modelling
    print("...... Modelling")
    # Report the class count that is actually handed to the model instead of a
    # second hard-coded copy of it.
    n_class         = num_classes
    print("n_class :" , n_class)

    model_to_train  = None
    eeg_shape = dataset.get_eeg_shape()

    if model_name == "Conv1D_LSTM" :
        print(f"select {model_name}")

        if  eeg_shape[2] != 1 or eeg_shape[1] != 128 :
            dataset.set_eeg_shape('Conv1D_LSTM')

        # NOTE(review): Conv1D_LSTM is not defined in the visible part of this
        # notebook -- confirm it is imported/defined before using this branch.
        model_to_train = Conv1D_LSTM(input_dim, hidden_dim, num_layers, num_classes, bidirectional, dropout)
        model_to_train = model_to_train.float()

    elif  "LSTM" in model_name :
        # set_eeg_shape('LSTM') swaps the last two axes, so calling it on data
        # that is already (batch, seq_len, input_dim) would flip it back.
        # Guarding keeps the cell safe to re-run on the same dataset object.
        if len(eeg_shape) != 3 or eeg_shape[-1] != input_dim :
            dataset.set_eeg_shape('LSTM')

        print(f"select {model_name}")
        print("dataset.get_eeg_shape() " , dataset.get_eeg_shape() )

        model_to_train = LSTM_attention(input_dim, hidden_dim, num_layers, num_classes, bidirectional, dropout)
        model_to_train = model_to_train.float() #define precision as float to reduce running time

    else :
        # Previously an unknown name fell through and None was handed to the
        # trainer; fail early with an explicit message instead.
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    training_obj                 = Classification_Training( model_obj = model_to_train ,  device= device ,learning_rate=learning_rate, criterion=nn.CrossEntropyLoss()  , env=env, visdom_update=False )
    training_obj.is_debug        = is_debug
    training_obj.is_plot_graph   = is_debug
    training_obj.do_training(   train_loader, val_loader , scheduler=scheduler,  n_epochs=n_epochs )

#     #=================================================
#     # save model weight 
#     print("..... Save model weight  " )

#     training_obj.save_state_dict(  f'{model_save_path}' )

    # ==============================================
    print("..... Save result to csv file")

    result_row = training_obj.get_result()
    # [ type(self.best_model).__name__ ,  self.best_train_acc, self.best_val_acc, self.best_test_acc , self.best_epoch ]
    data_name     = "Spampinato"
    par_name      =  participant_id
    # Insert order matters: after the two leading inserts the model name sits
    # at index 2, so model_round goes to index 3 to line up with `headers`.
    result_row.insert(0,data_name)
    result_row.insert(1,par_name)
    result_row.insert(3,model_round)

    headers =  ['data_name', 'participant',  'model_name', 'model_round', 'train_acc', 'val_acc', 'test_acc', 'best_epoch']
    save_result_csv(headers, result_row, f"{result_file}.csv")

    print(result_row)

In [None]:
# Load the selected participant's data and build the stratified train /
# validation loaders. The loaders wrap Subset views of `dataset`, and
# get_acc_save_result() may re-arrange the dataset's EEG tensor afterwards --
# presumably the loaders then serve the re-arranged samples; verify.
dataset = Spam_EEG_ds(device,  participant_id=participant_id)
train_loader, val_loader  = train_test_split_ds(dataset, batch_size=batch_size)

# %%
# Shape before any model-specific reshaping; not used again in the visible
# part of the notebook.
org_eeg_shape = dataset.get_eeg_shape()
# Train the model and append the result row to the CSV file.
get_acc_save_result()


...... Modelling
n_class : 40
select LSTM_attention
dataset.get_eeg_shape()  torch.Size([11965, 440, 128])
Model train device: cuda:1
The model LSTM_attention has 44,072 trainable parameters
...Do trainning for : 500 epochs
Epoch: 01/500 |	Train Loss: 3.68794   | Train Acc: 3.37%   |	 Val. Loss: 3.66987  | Val. Acc: 4.34%   |	 LR: 0.001 |	Best epoch : None
Epoch: 02/500 |	Train Loss: 3.63613   | Train Acc: 5.67%   |	 Val. Loss: 3.62475  | Val. Acc: 5.10%   |	 LR: 0.001 |	Best epoch : 0
Epoch: 03/500 |	Train Loss: 3.57322   | Train Acc: 7.41%   |	 Val. Loss: 3.55721  | Val. Acc: 6.35%   |	 LR: 0.001 |	Best epoch : 1
Epoch: 04/500 |	Train Loss: 3.47579   | Train Acc: 8.58%   |	 Val. Loss: 3.46529  | Val. Acc: 7.85%   |	 LR: 0.001 |	Best epoch : 2
Epoch: 05/500 |	Train Loss: 3.35846   | Train Acc: 9.83%   |	 Val. Loss: 3.33128  | Val. Acc: 8.86%   |	 LR: 0.001 |	Best epoch : 3
Epoch: 06/500 |	Train Loss: 3.23967   | Train Acc: 11.25%   |	 Val. Loss: 3.22088  | Val. Acc: 10.03%   |	 LR: 0.

In [None]:
# Elapsed wall-clock time since `start_time` was taken, in minutes.
print(f"Model running time : { (time.time()-start_time)/(60)}") 

In [None]:
# Echo which participant selection this run used.
print("participant_id :", participant_id)