<a href="https://colab.research.google.com/github/iamsoroush/DeepEEGAbstractor/blob/master/cv_hmdd_6s.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title # Clone the repository and upgrade Keras {display-mode: "form"}

!git clone https://github.com/iamsoroush/DeepEEGAbstractor.git
!pip install --upgrade keras

In [0]:
#@title # Imports {display-mode: "form"}

import os
# NOTE(review): pickle is not used in any of the visible cells.
import pickle
import sys
# Make the cloned repository importable as the `src` package. The path is
# relative, so this assumes the working directory is the one the clone cell
# ran in (the Colab default) — TODO confirm if run elsewhere.
sys.path.append('DeepEEGAbstractor')

# NOTE(review): np is not used in any of the visible cells.
import numpy as np

# Project modules from the cloned repository; these imports only resolve
# after the sys.path.append above.
# NOTE(review): Splitter is not used in any of the visible cells.
from src.helpers import CrossValidator
from src.models import EEGNet, ESTCNNModel
from src.dataset import DataLoader, Splitter, FixedLenGenerator

# Mount Google Drive at /content/gdrive; the data and result directories used
# below live under its "My Drive" folder.
from google.colab import drive
drive.mount('/content/gdrive')

In [0]:
#@title # Set data path {display-mode: "form"}

#@markdown ---
#@markdown Type the name of the folder in your Google Drive that contains the numpy _data_ folder:

parent_dir = 'soroush'#@param {type:"string"}
# Build paths from the absolute mount point passed to drive.mount() instead of
# a path relative to the kernel's current working directory.
gdrive_path = os.path.abspath(os.path.join('/content/gdrive/My Drive', parent_dir))
data_dir = os.path.join(gdrive_path, 'data')
cv_results_dir = os.path.join(gdrive_path, 'cross_validation')

# Fail fast on a mistyped folder name instead of erroring deep inside the
# data loader (and instead of creating a results folder in the wrong place).
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        'Data directory not found: {} (check `parent_dir`)'.format(data_dir))

# Idempotent: safe to re-run when the results folder already exists.
os.makedirs(cv_results_dir, exist_ok=True)

print('Data directory: ', data_dir)
print('Cross validation results dir: ', cv_results_dir)

In [0]:
#@title ## Set Parameters

# --- Dataset selection ---------------------------------------------------
task = 'hmdd'
data_mode = 'cross_subject'

# --- Recording properties (editable through the Colab form) --------------
sampling_rate = 256 #@param {type:"number"}
n_channels = 20 #@param {type:"number"}

# --- Instance windowing ---------------------------------------------------
# Passed as `duration` / `overlap` to FixedLenGenerator and to DataLoader
# (presumably in seconds — TODO confirm against src.dataset).
instance_duration = 6
instance_overlap = 1.5

# --- Training and cross-validation ----------------------------------------
epochs = 50
batch_size = 80
# Forwarded to CrossValidator as `k` and `t` (presumably folds and
# repetitions respectively — TODO confirm against src.helpers).
k = 10
t = 10

In [0]:
#@title ## EEGNet

model_name = 'EEGNet'

# Train and test generators share the same windowing settings; only the batch
# size and the train/test mode differ.
window_kwargs = {'duration': instance_duration,
                 'overlap': instance_overlap,
                 'sampling_rate': sampling_rate}
test_batch_size = 8

train_generator = FixedLenGenerator(batch_size=batch_size,
                                    is_train=True,
                                    **window_kwargs)

test_generator = FixedLenGenerator(batch_size=test_batch_size,
                                   is_train=False,
                                   **window_kwargs)

params = {'task': task,
          'data_mode': data_mode,
          'main_res_dir': cv_results_dir,
          'model_name': model_name,
          'epochs': epochs,
          'train_generator': train_generator,
          'test_generator': test_generator,
          't': t,
          'k': k,
          'channel_drop': True}

validator = CrossValidator(**params)

dataloader = DataLoader(data_dir,
                        task,
                        data_mode,
                        sampling_rate,
                        instance_duration,
                        instance_overlap)
data, labels = dataloader.load_data()

# Window length in samples. The Colab number form may hand back a float
# sampling_rate, and a model input shape needs whole-number dimensions, so
# reject fractional lengths and pass plain ints.
window_len = sampling_rate * instance_duration
if not float(window_len).is_integer():
    raise ValueError('sampling_rate * instance_duration must be a whole '
                     'number of samples, got {}'.format(window_len))
input_shape = (int(window_len),
               int(n_channels))

model_obj = EEGNet(input_shape,
                   model_name=model_name)

scores = validator.do_cv(model_obj,
                         data,
                         labels)
# Keep this model's scores under their own name: the E-ST-CNN cell also
# assigns `scores`, which would otherwise overwrite these in the kernel.
eegnet_scores = scores

In [0]:
#@title ## E-ST-CNN

model_name = 'ESTCNN'

# Train and test generators share the same windowing settings; only the batch
# size and the train/test mode differ.
window_kwargs = {'duration': instance_duration,
                 'overlap': instance_overlap,
                 'sampling_rate': sampling_rate}
test_batch_size = 8

train_generator = FixedLenGenerator(batch_size=batch_size,
                                    is_train=True,
                                    **window_kwargs)

test_generator = FixedLenGenerator(batch_size=test_batch_size,
                                   is_train=False,
                                   **window_kwargs)

params = {'task': task,
          'data_mode': data_mode,
          'main_res_dir': cv_results_dir,
          'model_name': model_name,
          'epochs': epochs,
          'train_generator': train_generator,
          'test_generator': test_generator,
          't': t,
          'k': k,
          'channel_drop': True}

validator = CrossValidator(**params)

dataloader = DataLoader(data_dir,
                        task,
                        data_mode,
                        sampling_rate,
                        instance_duration,
                        instance_overlap)
data, labels = dataloader.load_data()

# Window length in samples. The Colab number form may hand back a float
# sampling_rate, and a model input shape needs whole-number dimensions, so
# reject fractional lengths and pass plain ints.
window_len = sampling_rate * instance_duration
if not float(window_len).is_integer():
    raise ValueError('sampling_rate * instance_duration must be a whole '
                     'number of samples, got {}'.format(window_len))
input_shape = (int(window_len),
               int(n_channels))

model_obj = ESTCNNModel(input_shape,
                        model_name=model_name)

scores = validator.do_cv(model_obj,
                         data,
                         labels)
# Keep this model's scores under their own name so they survive any later
# cell that reuses `scores`.
estcnn_scores = scores