In [14]:
import os
import numpy as np
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
import mne
from mne import Epochs, pick_types, events_from_annotations
from mne.datasets import eegbci
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf,read_raw_edf,read_raw_gdf,read_raw_fif
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.filter import construct_iir_filter,create_filter
from sklearn.model_selection import train_test_split
import torch
from scipy.signal import iirfilter, sosfiltfilt
import torch.optim as optim  
from torch.utils.data import Dataset, DataLoader  
from torch.utils.data import Subset  
from torch import nn  
import torch.nn.functional as F  
from torch.utils.data import TensorDataset

In [15]:

def get_physio(path="/root/EEG_Model/dataset/physio/MNE-eegbci-data/files/eegmmidb/1.0.0",
               runs=('11', '12'), n_subjects=10):
    """Collect EDF file paths for the validation runs of the first subjects.

    Parameters
    ----------
    path : str
        Root of the local EEGMMIDB copy. It is expected to contain one
        sub-folder per subject, each holding files named
        ``<folder>R<run>.edf``.
    runs : sequence of str
        Two-digit run numbers to collect for every subject. Defaults to runs
        11 and 12, the only runs the original selection kept.
    n_subjects : int
        Maximum number of subject folders to use (default 10).

    Returns
    -------
    list of str
        File paths ordered subject-major, then in the order given by ``runs``.
    """
    # os.listdir order is arbitrary: sort so the same subjects are chosen on
    # every machine, and skip non-directory entries, which cannot contain a
    # subject's EDF files.
    subject_dirs = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )[:n_subjects]
    return [
        os.path.join(path, folder, "{}R{}.edf".format(folder, run))
        for folder in subject_dirs
        for run in runs
    ]

def get_test_epoch(data_path, tmin, tmax, event_id, preprocess=False, ica=False,
                   reject=None, verbose=True):
    """Load EDF recordings and cut them into labelled epochs.

    Parameters
    ----------
    data_path : list of str
        EDF files to read and concatenate (e.g. the output of ``get_physio``).
    tmin, tmax : float
        Epoch window in seconds relative to each event.
    event_id : dict
        Annotation description -> integer code, forwarded to
        ``events_from_annotations``.
    preprocess : bool
        If True, apply a zero-phase IIR band-pass (Butterworth, order=6,
        SOS output, 0.05-75 Hz) followed by a 50 Hz notch filter.
    ica : bool
        If True, fit a 64-component ICA, plot components 1 and 2, and remove
        them from the data.
    reject : dict or None
        Peak-to-peak rejection thresholds forwarded to ``Epochs``. ``None``
        (the default) keeps every epoch, matching the previous behaviour.
    verbose : bool
        Print the channel names and sampling rate while loading.

    Returns
    -------
    tuple
        ``(data, labels, epochs, raw)``: the array from ``Epochs.get_data()``,
        the event code of each epoch, the ``Epochs`` object, and the
        processed ``Raw`` object.

    Raises
    ------
    ValueError
        If ``data_path`` is empty.
    """
    if not data_path:
        raise ValueError("data_path must contain at least one EDF file")

    # The concatenated Raw is built from freshly read files and nothing else
    # holds a reference to it, so it is modified in place instead of copied.
    raw_data = concatenate_raws(
        [read_raw_edf(f, preload=True, verbose='WARNING') for f in data_path])
    eegbci.standardize(raw_data)
    raw_data.set_montage(mne.channels.make_standard_montage('standard_1005'))
    if verbose:
        print(raw_data.info['ch_names'])
    raw_data.rename_channels(lambda name: name.strip('.'))
    if verbose:
        print(raw_data.info['ch_names'])
        print(raw_data.info['sfreq'])

    if preprocess:
        iir_params = dict(order=6, ftype='butter', output='sos')
        raw_data.filter(l_freq=0.05, h_freq=75., method='iir',
                        iir_params=iir_params, phase='zero')
        # NOTE(review): an earlier revision used a 60 Hz notch here; confirm
        # which mains frequency applies to these recordings.
        raw_data.notch_filter(50, filter_length='auto', phase='zero')

    if ica:
        # Separate name so the boolean ``ica`` argument is not shadowed.
        ica_model = mne.preprocessing.ICA(n_components=64, max_iter=100)
        ica_model.fit(raw_data)
        # TODO: components to exclude are hard-coded; choose them from
        # inspection of this dataset.
        ica_model.exclude = [1, 2]
        ica_model.plot_properties(raw_data, picks=ica_model.exclude)
        ica_model.apply(raw_data)

    events, event_id = events_from_annotations(raw_data, event_id=event_id)
    picks = pick_types(raw_data.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')
    epochs = Epochs(raw_data, events, event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True, reject=reject)
    labels = epochs.events[:, -1]
    return epochs.get_data(), labels, epochs, raw_data
    
def get_data():
    """Load the validation epochs for the T1/T2 cues.

    Returns the four values produced by ``get_test_epoch`` for the files
    listed by ``get_physio``: ``(epoch_array, labels, epochs, raw)``. No
    filtering or ICA is applied.
    """
    recordings = get_physio()
    # 0-4 s window after each cue; T1 is coded 0 and T2 is coded 1.
    window_start, window_end = 0, 4
    cue_codes = {'T1': 0, 'T2': 1}
    epoch_array, labels, epochs, raw = get_test_epoch(
        recordings, window_start, window_end, cue_codes, preprocess=False)
    return epoch_array, labels, epochs, raw

# Load validation epochs, labels, the Epochs object and the Raw recording
# once; later cells read these globals.
(valid_epoch, valid_labels, raw_epoch, raw) = get_data()

['FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz']
['FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3', 'AFz', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FT8', 'T7', 'T8', 'T9', 'T10', 'TP7', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'O1', 'Oz', 'O2', 'Iz']
160.0
Used Annotations descriptions: ['T1', 'T2']
Not setting metadata
Not setting metadata
300 matching events found
No baseline correction applied
0 projection item

In [None]:
# valid_epoch is laid out as (event, channel, time). For the data loaded
# above the loader reported 300 matching events across 64 channels, and an
# epoch window of tmin=0 .. tmax=4 s at 160 Hz gives 4 * 160 + 1 = 641 samples.
print(valid_epoch.shape)
print(valid_labels.shape)
print('---------------')
# Insert a singleton axis so the data matches the model's Conv2d input of
# shape (batch, 1, channels, times).
X = valid_epoch[:, np.newaxis, :, :]
y = valid_labels
print(X.shape)
print(valid_epoch.shape[1])
print(valid_labels)
print(raw.info['ch_names'])
print(raw.info['nchan'])

In [3]:
# 2-D convolutions over a (1, channel, time) input: a temporal filter bank,
# a spatial filter spanning all electrodes, then further temporal stages.
class gamenet(nn.Module):
    """CNN classifier for EEG epochs shaped (batch, 1, n_channels, n_samples).

    Parameters
    ----------
    n_channels : int
        Number of electrodes (input height). The spatial convolution in
        ``l2`` spans all of them, collapsing the height to 1. Default 64.
    n_samples : int
        Time samples per epoch (input width). Must be at least 17 so both
        pooling stages produce output. Default 641.
    n_classes : int
        Number of output classes. Default 2.
    dropout : float
        Dropout probability after the flatten stage and every dense layer.

    The defaults reproduce the original hard-coded architecture exactly, so
    checkpoints saved from it keep loading with identical state-dict keys.

    ``forward`` returns class probabilities (softmax over dim 1).
    NOTE(review): if training used ``nn.CrossEntropyLoss``, that loss applies
    its own log-softmax on top of this one; confirm against the training code.
    """

    def __init__(self, n_channels=64, n_samples=641, n_classes=2, dropout=0.15):
        super().__init__()
        # Temporal convolution: 100 filters of length 25 along the time axis.
        self.l1 = nn.Sequential(
            nn.Conv2d(1, 100, kernel_size=(1, 25), stride=1, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(100),
        )
        # Spatial convolution across every electrode ('valid' -> height 1).
        self.l2 = nn.Sequential(
            nn.Conv2d(100, 100, kernel_size=(n_channels, 1), stride=1, padding='valid'),
            nn.ReLU(),
            nn.BatchNorm2d(100),
        )
        # Temporal convolution, 100 -> 50 feature maps, length-30 kernels.
        self.l3 = nn.Sequential(
            nn.Conv2d(100, 50, kernel_size=(1, 30), stride=1, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(50),
        )
        self.maxpooling1 = nn.MaxPool2d(kernel_size=(1, 7), stride=5)
        self.l4 = nn.Sequential(
            nn.Conv2d(50, 50, kernel_size=(1, 30), stride=1, padding='same'),
            nn.ReLU(),
            nn.BatchNorm2d(50),
        )
        self.maxpooling2 = nn.MaxPool2d(kernel_size=(1, 3), stride=2)

        # 50 feature maps x 1 row x pooled width (3150 for 641 samples).
        flat_features = 50 * self._pooled_width(n_samples)
        self.flatten = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(flat_features),
            nn.Dropout(dropout),
        )
        self.fc1 = self._dense(flat_features, 1024, dropout)
        self.fc2 = self._dense(1024, 512, dropout)
        self.fc3 = self._dense(512, 256, dropout)
        self.fc4 = self._dense(256, 128, dropout)
        self.fc5 = self._dense(128, 64, dropout)
        self.fc6 = self._dense(64, 32, dropout)
        # dim=1 is explicit: nn.Softmax() without dim is deprecated; for 2-D
        # input its implicit choice was also dim 1.
        self.softmax = nn.Sequential(
            nn.Linear(32, n_classes),
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _pooled_width(n_samples):
        """Time-axis length after maxpooling1 (k=7, s=5) and maxpooling2 (k=3, s=2)."""
        if n_samples < 17:
            raise ValueError("n_samples must be at least 17, got {}".format(n_samples))
        width = (n_samples - 7) // 5 + 1
        return (width - 3) // 2 + 1

    @staticmethod
    def _dense(in_features, out_features, dropout):
        """Linear -> ReLU -> BatchNorm1d -> Dropout, the order every dense layer uses."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.BatchNorm1d(out_features),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Map (batch, 1, n_channels, n_samples) input to (batch, n_classes) probabilities."""
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out = self.maxpooling1(out)
        out = self.l4(out)
        out = self.maxpooling2(out)
        out = self.flatten(out)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        out = self.fc4(out)
        out = self.fc5(out)
        out = self.fc6(out)
        return self.softmax(out)
        

In [4]:
# Build the network and restore trained weights.
model = gamenet()
path = '/root/EEG_Model/save_weight/Physionet_Gamenet_64elec_executedImagine_Wandb_Newcuda-423-bestacc86.00.pth'
# map_location='cpu' lets a checkpoint that may have been saved from a CUDA
# device load on a CPU-only machine. gamenet() builds its parameters on the
# CPU, so the loaded model is unchanged where CUDA is available.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer torch versions also accept weights_only=True).
model.load_state_dict(torch.load(path, map_location='cpu'))
model.eval()

gamenet(
  (l1): Sequential(
    (0): Conv2d(1, 100, kernel_size=(1, 25), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (l2): Sequential(
    (0): Conv2d(100, 100, kernel_size=(64, 1), stride=(1, 1), padding=valid)
    (1): ReLU()
    (2): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (l3): Sequential(
    (0): Conv2d(100, 50, kernel_size=(1, 30), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpooling1): MaxPool2d(kernel_size=(1, 7), stride=5, padding=0, dilation=1, ceil_mode=False)
  (l4): Sequential(
    (0): Conv2d(50, 50, kernel_size=(1, 30), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpooling2): MaxPool2d(kernel_size=(1, 3), stride=2, padding=0, dilation

In [12]:
# Smoke test: push a random batch shaped like the real epochs,
# (batch, 1, 64 channels, 641 samples), through the model.
# torch.rand draws uniformly from [0, 1), not EEG-scale values, so these
# predictions only check that the forward pass runs.
torch.manual_seed(0)  # reproducible printed output
test_input = torch.rand(50, 1, 64, 641)
with torch.no_grad():  # inference only: skip autograd bookkeeping
    output = model(test_input)
print(test_input.shape)
print(output)

predicted = output.argmax(dim=1)
print(predicted)

torch.Size([50, 1, 64, 641])
tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], grad_fn=<SoftmaxBackward0>)
tensor([0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0