In [1]:
from glob import glob
import scipy.io
import torch.nn as nn
import torch
import numpy as np
import mne

In [2]:
# Directories holding the cleaned resting-state recordings (.mat files) of each cohort.
# Each variable must point at the directory of the same name: idd_subject is labelled 1
# (patients) and tdc_subject 0 (controls) further down, so crossing the two paths
# silently inverts every label.
IDD_data_path='./data/Data/CleanData/CleanData_IDD/Rest'
TDC_data_path='./data/Data/CleanData/CleanData_TDC/Rest'

In [3]:
def convertmat2mne(data, sampling_freq=128, l_freq=1, h_freq=30, epoch_duration=4):
  """Convert one recording into re-referenced, band-pass-filtered, fixed-length epochs.

  Parameters
  ----------
  data : array-like, shape (14, n_samples)
      Channel-major samples, one row per entry of ``ch_names`` below (same order).
      NOTE(review): MNE interprets EEG values as volts; confirm the unit stored
      in the .mat files.
  sampling_freq : float
      Sampling rate of ``data`` in Hz.
  l_freq, h_freq : float
      Lower / upper edge of the band-pass filter in Hz.
  epoch_duration : float
      Length of each non-overlapping epoch in seconds.

  Returns
  -------
  numpy.ndarray
      Epoched signal as returned by ``Epochs.get_data()``:
      (n_epochs, n_channels, n_samples_per_epoch).
  """
  ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
  # Derive the type list from the channel list so the two can never drift apart.
  ch_types = ['eeg'] * len(ch_names)
  info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sampling_freq)
  info.set_montage('standard_1020')
  raw = mne.io.RawArray(data, info)
  # No arguments: MNE's default reference (average over the EEG channels).
  raw.set_eeg_reference()
  raw.filter(l_freq=l_freq, h_freq=h_freq)
  epochs = mne.make_fixed_length_epochs(raw, duration=epoch_duration, overlap=0)
  return epochs.get_data()

In [4]:
%%capture
# Epoch every IDD recording, one array per subject.
# glob() returns paths in arbitrary, filesystem-dependent order; sorting keeps the
# subject order (and the group ids later derived from list position) reproducible.
idd_files=sorted(glob(IDD_data_path+'/*.mat'))
if not idd_files:
  # Fail here with a clear message instead of later inside np.concatenate.
  raise FileNotFoundError('No .mat files found in '+IDD_data_path)
idd_subject=[]
for idd in idd_files:
  data=scipy.io.loadmat(idd)['clean_data']
  data=convertmat2mne(data)
  idd_subject.append(data)

In [5]:
%%capture
# Epoch every TDC recording, one array per subject.
# glob() returns paths in arbitrary, filesystem-dependent order; sorting keeps the
# subject order (and the group ids later derived from list position) reproducible.
tdc_files=sorted(glob(TDC_data_path+'/*.mat'))
if not tdc_files:
  # Fail here with a clear message instead of later inside np.concatenate.
  raise FileNotFoundError('No .mat files found in '+TDC_data_path)
tdc_subject=[]
for tdc in tdc_files:
  data=scipy.io.loadmat(tdc)['clean_data']
  data=convertmat2mne(data)
  tdc_subject.append(data)

In [6]:
# One label per epoch, kept grouped per subject: 0 for controls (TDC), 1 for patients (IDD).
control_epochs_labels = [[0] * len(subject_epochs) for subject_epochs in tdc_subject]
patients_epochs_labels = [[1] * len(subject_epochs) for subject_epochs in idd_subject]
# Number of subjects in each cohort.
print(len(control_epochs_labels), len(patients_epochs_labels))

7 7


In [7]:
# Controls first, then patients; labels are stacked in the same order.
data_list = tdc_subject + idd_subject
label_list = control_epochs_labels + patients_epochs_labels
# Every epoch carries the index of the subject it came from, so that a grouped
# splitter can keep all epochs of one subject on the same side of a split.
groups_list = [
    [subject_idx] * len(subject_epochs)
    for subject_idx, subject_epochs in enumerate(data_list)
]
print(len(data_list), len(label_list), len(groups_list))

14 14 14


In [8]:
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler
# Grouped K-fold splitter; n_splits=5 is scikit-learn's default, spelled out for the reader.
gkf = GroupKFold(n_splits=5)
from sklearn.base import TransformerMixin,BaseEstimator
from sklearn.preprocessing import StandardScaler
#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    """StandardScaler for 3-D input laid out as (batch, sequence, channels).

    The input is flattened to 2-D rows of ``X.shape[2]`` values, handed to a
    plain ``StandardScaler`` (one mean/scale per column of that 2-D view), and
    the result is reshaped back to the original 3-D shape.
    """

    def __init__(self):
        # The wrapped 2-D scaler that holds the fitted statistics.
        self.scaler = StandardScaler()

    def fit(self, X, y=None):
        """Fit the wrapped scaler on ``X`` flattened to (-1, X.shape[2]); returns self."""
        n_channels = X.shape[2]
        flattened = X.reshape(-1, n_channels)
        self.scaler.fit(flattened)
        return self

    def transform(self, X):
        """Scale ``X`` with the fitted statistics and return it in its original shape."""
        original_shape = X.shape
        flattened = X.reshape(-1, original_shape[2])
        scaled = self.scaler.transform(flattened)
        return scaled.reshape(original_shape)

In [9]:
import numpy as np
# Flatten the per-subject lists into single arrays along the epoch axis.
label_array = np.concatenate(label_list)
group_array = np.concatenate(groups_list)
# Exchange axes 1 and 2 of the stacked epochs so the axis that was second
# (channels, per Epochs.get_data()) becomes the last one.
data_array = np.swapaxes(np.concatenate(data_list), 1, 2)

print(data_array.shape, label_array.shape, group_array.shape)

(420, 512, 14) (420,) (420,)


In [10]:
# NOTE(review): every iteration overwrites the train/val variables and nothing is
# appended to `accuracy` in this cell, so after the loop only the split from the
# last fold remains available to the cells below.
accuracy = []
for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):
    train_labels = label_array[train_index]
    val_labels = label_array[val_index]
    # Statistics are fitted on the training epochs only and reused for validation.
    scaler = StandardScaler3D()
    scaler.fit(data_array[train_index])
    train_features = scaler.transform(data_array[train_index])
    val_features = scaler.transform(data_array[val_index])


In [11]:
# Shape of the training features left over from the fold loop above.
train_features.shape


(360, 512, 14)

In [12]:
from tensorflow.keras.layers import Input,Dense,concatenate,Flatten,GRU,Conv1D
from tensorflow.keras.models import Model

In [13]:
def block(input):
  """Apply three parallel stride-2 Conv1D branches to `input` and join them on axis 2.

  The branches use 32 filters each with kernel sizes 2, 4 and 8; the first uses
  "same" padding, the other two "causal" padding.
  """
  # (kernel_size, padding) of each branch, in creation order.
  branch_specs = ((2, "same"), (4, "causal"), (8, "causal"))
  branches = [
      Conv1D(32, kernel_size, strides=2, activation='relu', padding=padding)(input)
      for kernel_size, padding in branch_specs
  ]
  return concatenate(branches, axis=2)

In [14]:
# Network input: 512 time steps x 14 channels per epoch (matches train_features.shape[1:]).
input = Input(shape=(512, 14))
# Three conv blocks chained one after another; each uses stride-2 convolutions.
block1 = block(input)
block2 = block(block1)
block3 = block(block2)

In [15]:
# Stacked GRUs with skip connections: the third GRU reads the first two outputs
# joined together, and the fourth reads all three joined together.
gru_out1 = GRU(32, activation='tanh', return_sequences=True)(block3)
gru_out2 = GRU(32, activation='tanh', return_sequences=True)(gru_out1)

gru_out = concatenate([gru_out1, gru_out2], axis=2)
gru_out3 = GRU(32, activation='tanh', return_sequences=True)(gru_out)

# No axis given here: concatenate falls back to its default axis.
gru_out = concatenate([gru_out1, gru_out2, gru_out3])
# The last GRU does not return sequences, so it emits one vector per epoch.
gru_out4 = GRU(32, activation='tanh')(gru_out)

In [16]:
# Single sigmoid unit on top of the last GRU output for the binary decision.
predictions = Dense(1, activation='sigmoid')(gru_out4)
model = Model(inputs=input, outputs=predictions)

In [17]:
# Binary cross-entropy pairs with the single sigmoid output; accuracy is tracked per epoch.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [18]:
# Train on the split currently held in train_*/val_*; the validation split is
# evaluated at the end of every epoch.
model.fit(
    train_features,
    train_labels,
    epochs=10,
    batch_size=128,
    validation_data=(val_features, val_labels),
)

Epoch 1/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 468ms/step - accuracy: 0.4921 - loss: 0.6981 - val_accuracy: 0.6500 - val_loss: 0.6250
Epoch 2/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 145ms/step - accuracy: 0.7262 - loss: 0.6364 - val_accuracy: 0.9500 - val_loss: 0.5230
Epoch 3/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 133ms/step - accuracy: 0.9411 - loss: 0.5154 - val_accuracy: 1.0000 - val_loss: 0.2950
Epoch 4/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.9890 - loss: 0.3112 - val_accuracy: 1.0000 - val_loss: 0.1257
Epoch 5/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 133ms/step - accuracy: 0.9929 - loss: 0.1320 - val_accuracy: 1.0000 - val_loss: 0.0472
Epoch 6/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 136ms/step - accuracy: 1.0000 - loss: 0.0434 - val_accuracy: 1.0000 - val_loss: 0.0180
Epoch 7/10
[1m3/3[0m [32m━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x2280b4d95e0>