<a href="https://colab.research.google.com/github/henry880127/EEGNet_related/blob/main/EEGNet_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import scipy
import scipy.io as sio # cannot use for v7.3 mat file
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import os
import pickle
from EEGNet_function import EEGNet
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import KFold
import pandas as pd

## Functions

### Data preprocessing

In [2]:
def loadPickle(pklDir):
    """Load and return the object stored in the pickle file at ``pklDir``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing; only point this at files you produced or trust.
    """
    with open(pklDir, mode='rb') as handle:
        payload = pickle.load(handle)
    return payload

def reshape2Input(a):
    """Rearrange a (EEG_channels, sample_points, trials) array into model input layout.

    Parameters
    ----------
    a : array_like, 3-D
        EEG data whose axes are (EEG_channels, sample_points, trials).

    Returns
    -------
    numpy.ndarray
        Array of shape (trials, 1, EEG_channels, sample_points) such that
        ``out[t, 0] == a[:, :, t]`` for every trial ``t``.
    """
    # A plain ``reshape`` to (trials, 1, channels, samples) only reinterprets
    # the flat element order, scrambling samples across channels and trials.
    # Permuting the axes moves the trial axis to the front while keeping each
    # trial's channel x sample matrix intact; then insert the singleton axis.
    return np.transpose(a, (2, 0, 1))[:, np.newaxis, :, :]

def saveResult(df, save_dir):
    """Write a DataFrame to a CSV file.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to persist; written with pandas' default ``to_csv`` options.
    save_dir : str or path-like
        Destination CSV file path (a file, not a directory, despite the name).
    """
    destination = save_dir
    df.to_csv(path_or_buf=destination)

In [3]:
def random_trials(full_data, output_size, axis_to_choose, rng=None):
    """Randomly pick ``output_size`` distinct slices of ``full_data`` along one axis.

    Parameters
    ----------
    full_data : numpy.ndarray
        Source array.
    output_size : int
        Number of slices to keep; must not exceed
        ``full_data.shape[axis_to_choose]`` (sampling is without replacement,
        so numpy raises ``ValueError`` otherwise).
    axis_to_choose : int
        Axis along which slices are drawn (e.g. 0 for the trial axis).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness, for reproducible selection. When ``None`` the
        global ``np.random`` state is used, as before.

    Returns
    -------
    numpy.ndarray
        Copy of ``full_data`` with ``output_size`` entries along
        ``axis_to_choose``, in the (random) order they were drawn.
    """
    sampler = np.random if rng is None else rng
    # Draw distinct positions along the chosen axis.
    random_indices = sampler.choice(
        full_data.shape[axis_to_choose], output_size, replace=False)
    # Gather those slices; np.take keeps every other axis untouched.
    return np.take(full_data, indices=random_indices, axis=axis_to_choose)

### EEGNet

In [4]:
def EEGNet_fit(folder_dir, random_select_trials = False, pkl_name='unname', logs_dir='logs', K=None, startPt=0, epoch_length=1000):
    '''data'''
    dictData = loadPickle(f"{folder_dir}/{pkl_name}")
    targetEEG_train = reshape2Input(dictData['targetEEG_train'])[
        :, :, :, startPt:startPt+epoch_length]
    targetEEG_test = reshape2Input(dictData['targetEEG_test'])[
        :, :, :, startPt:startPt+epoch_length]
    nontargetEEG_train = reshape2Input(dictData['nontargetEEG_train'])[
        :, :, :, startPt:startPt+epoch_length]
    nontargetEEG_test = reshape2Input(dictData['nontargetEEG_test'])[
        :, :, :, startPt:startPt+epoch_length]

    # Decide which of the above data has more trials,
    # and randomly choose the trials form larger one.
    if(random_select_trials==True):
        if(targetEEG_train.shape[0] > nontargetEEG_train.shape[0]):
            targetEEG_train = random_trials(targetEEG_train, nontargetEEG_train.shape[0], 0)
        elif(targetEEG_train.shape[0] < nontargetEEG_train.shape[0]):
            nontargetEEG_train = random_trials(nontargetEEG_train, targetEEG_train.shape[0], 0)
        else:
            pass
        if(targetEEG_test.shape[0] > nontargetEEG_test.shape[0]):
            targetEEG_test = random_trials(targetEEG_test, nontargetEEG_test.shape[0], 0)
        elif(targetEEG_test.shape[0] < nontargetEEG_test.shape[0]):
            nontargetEEG_test = random_trials(nontargetEEG_test, targetEEG_test.shape[0], 0)
        else:
            pass
    epochs_train = np.concatenate((targetEEG_train, nontargetEEG_train), axis=0)
    epochs_test = np.concatenate((targetEEG_test, nontargetEEG_test), axis=0)

    # Labeling
    encoder = OneHotEncoder(sparse=False)
    y_train = np.ones((targetEEG_train.shape[0]))
    y_train = np.concatenate((y_train, np.ones(nontargetEEG_train.shape[0])+1))
    y_train = encoder.fit_transform(y_train.reshape(-1, 1))
    y_test = np.ones((targetEEG_test.shape[0]))
    y_test = np.concatenate((y_test, np.ones(nontargetEEG_test.shape[0])+1))
    y_test = encoder.fit_transform(y_test.reshape(-1, 1))

    # Shuffle
    num_samples = epochs_train.shape[0]
    shuffled_indices = np.arange(num_samples)
    np.random.shuffle(shuffled_indices)
    epochs_train = epochs_train[shuffled_indices, :, :, :]
    y_train = y_train[shuffled_indices,:]
    num_samples = epochs_test.shape[0]
    shuffled_indices = np.arange(num_samples)
    np.random.shuffle(shuffled_indices)
    epochs_test = epochs_test[shuffled_indices, :, :, :]
    y_test = y_test[shuffled_indices,:]
    print('y_train.shape:', y_train.shape)
    print('y_test.shape:', y_test.shape)
    print('epochs_train.shape:', epochs_train.shape)
    print('epochs_test.shape:', epochs_test.shape)
    # savingFoldername = 'data_check'
    # saveResult(pd.DataFrame(y_train), f'./results/{savingFoldername}/y_train.csv')
    # saveResult(pd.DataFrame(y_test), f'./results/{savingFoldername}/y_test.csv')