In [1]:
# ChronoNet: A Deep Recurrent Neural Network for Abnormal EEG Identification

![ChronoNet model architecture](img/model.png)

import torch 
import torch.nn as nn
import numpy as np

# Seed torch's RNG so the dummy input below (and the model weights created in
# later cells) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dummy batch used only to smoke-test the model: (batch=3, channels=22, length=15000).
# Axis order follows nn.Conv1d's (N, C, L); the length axis is presumably time samples.
input_state = torch.randn(3, 22, 15000)

In [3]:
class Block(nn.Module):
    """ChronoNet-style network: three multi-scale Conv1d stages, densely
    connected GRUs, a linear collapse over the time axis, then a GRU + FC head.

    Shapes through ``forward`` (N = batch, L = input length, H = Gru_Hid[0]):
      input                    (N, ConvIn[0], L)
      each conv stage          three parallel Conv1d branches (kernel 2/4/8,
                               stride 2, 32 filters each) concatenated on the
                               channel axis -> 96 channels, length floor(L/2)
      after three stages       (N, 96, T) with T = floor(L/8), permuted to (N, T, 96)
      gru1 / gru2 / gru3       dense: gru2 sees out1, gru3 sees [out1, out2];
                               [out1, out2, out3] -> (N, T, 3H)
      linear1 over time        (N, 3H, T) -> (N, 3H, 1), permuted to (N, 1, 3H)
      gru4 -> flatten -> fc1   (N, 1, H) -> (N, H) -> (N, 1)

    Args:
        ConvIn: in_channels of the three conv stages. Stages 2 and 3 consume the
            96-channel concatenation of the previous stage, so ConvIn[1] and
            ConvIn[2] must be 96.
        Gru_In: input_size of gru1..gru4. Given the dense concatenation these
            must be [96, H, 2H, 3H].
        Gru_Hid: hidden sizes; only Gru_Hid[0] is used (all four GRUs share it).
        Linear_In: time length T after the three conv stages (floor(L / 8)).
            Defaults to 1875, i.e. an input length of 15000.
    """

    def __init__(self, ConvIn, Gru_In, Gru_Hid, Linear_In=1875):
        super().__init__()
        # Modules are created in the same order as before (conv1..conv9, gru1..gru4,
        # linear1, flatten, fc1, relu). This keeps state_dict keys and weight-init
        # RNG consumption unchanged.
        self.conv1, self.conv2, self.conv3 = self._make_stage(ConvIn[0])
        self.conv4, self.conv5, self.conv6 = self._make_stage(ConvIn[1])
        self.conv7, self.conv8, self.conv9 = self._make_stage(ConvIn[2])

        hidden = Gru_Hid[0]
        # batch_first=True -> GRUs take (N, seq, feature)
        self.gru1 = nn.GRU(input_size=Gru_In[0], hidden_size=hidden, batch_first=True)
        self.gru2 = nn.GRU(input_size=Gru_In[1], hidden_size=hidden, batch_first=True)
        self.gru3 = nn.GRU(input_size=Gru_In[2], hidden_size=hidden, batch_first=True)
        self.gru4 = nn.GRU(input_size=Gru_In[3], hidden_size=hidden, batch_first=True)

        # Maps the time axis (length T) to a single step; T depends on input length.
        self.linear1 = nn.Linear(in_features=Linear_In, out_features=1)
        self.flatten = nn.Flatten()
        # gru4 emits one step of size `hidden`, so fc1 must take `hidden` features
        # (it was hard-coded to 32, which broke any other hidden size).
        self.fc1 = nn.Linear(hidden, 1)
        # NOTE(review): not applied in forward(); kept so references to
        # `self.relu` keep working.
        self.relu = nn.ReLU()

    @staticmethod
    def _make_stage(in_channels, out_channels=32):
        """Build the three parallel stride-2 Conv1d branches of one stage.

        The kernel/padding pairs (2,0), (4,1), (8,3) all give output length
        floor(L/2), so branch outputs can be concatenated on the channel axis.
        """
        return tuple(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel, stride=2, padding=pad)
            for kernel, pad in ((2, 0), (4, 1), (8, 3))
        )

    @staticmethod
    def _run_stage(x, branches):
        """Apply each branch to ``x`` and concatenate the results on dim 1 (channels)."""
        return torch.cat([branch(x) for branch in branches], dim=1)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, ConvIn[0], L); returns (N, 1)."""
        x = self._run_stage(x, (self.conv1, self.conv2, self.conv3))
        x = self._run_stage(x, (self.conv4, self.conv5, self.conv6))
        x = self._run_stage(x, (self.conv7, self.conv8, self.conv9))
        # (N, 96, T) -> (N, T, 96) for the batch_first GRUs
        x = torch.permute(x, (0, 2, 1))

        out1, _ = self.gru1(x)
        out2, _ = self.gru2(out1)
        out3, _ = self.gru3(torch.cat([out1, out2], dim=2))
        dense = torch.cat([out1, out2, out3], dim=2)  # (N, T, 3H)

        # Collapse the time axis: (N, 3H, T) -> (N, 3H, 1)
        collapsed = self.linear1(torch.permute(dense, (0, 2, 1)))
        # (N, 1, 3H): a single-step sequence for gru4
        out4, _ = self.gru4(collapsed.permute(0, 2, 1))
        return self.fc1(self.flatten(out4))
        
        

In [4]:
# Model for the 22-channel dummy batch. Conv stages 2 and 3 receive the
# 96-channel concat of the previous stage. GRU input sizes are 96 (conv
# channels), then 32, 64, 96 from densely concatenating 32-unit GRU outputs.
blok = Block(
    ConvIn=[22, 96, 96],
    Gru_In=[96, 32, 64, 96],
    Gru_Hid=[32],
)

In [5]:
# Smoke-test forward pass on the dummy batch.
out1 = blok(x=input_state)

In [6]:
# Expect (batch, 1): one logit per example.
out1.size()

torch.Size([3, 1])

# Data Preparation

In [7]:
import mne
from glob import glob
import scipy.io

# Resting-state recordings for each group; each folder holds one .mat file per subject.
IDD, TDC = (
    "./EGG/Data/CleanData/CleanData_IDD/Rest",
    "./EGG/Data/CleanData/CleanData_TDC/Rest",
)

In [8]:
def ConvertMat2Mne(input_data, sfreq=128, l_freq=1, h_freq=32, duration=4):
    """Turn one subject's cleaned EEG matrix into fixed-length epochs.

    Processing steps:
    - wrap the data in an MNE RawArray for the 14 channels listed below;
    - attach the ``standard_1020`` montage;
    - apply an average EEG reference;
    - band-pass filter;
    - cut non-overlapping fixed-length epochs.

    Args:
        input_data: array of shape (n_channels, n_times), rows ordered as
            ``channel_name`` below (14 rows). Values are passed to MNE
            unchanged; presumably volts — TODO confirm against the .mat export.
        sfreq: sampling frequency in Hz (default 128).
        l_freq: lower band-pass edge in Hz (default 1).
        h_freq: upper band-pass edge in Hz (default 32).
        duration: epoch length in seconds (default 4, no overlap).

    Returns:
        ndarray of shape (n_epochs, n_channels, n_samples_per_epoch).
    """
    channel_name = ["AF3", "F7", "F3", "FC5", "T7", "P7", "O1", "O2", "P8", "T8", "FC6", "F4", "F8", "AF4"]
    # Derive the type list from the channel list instead of a separate magic 14.
    info = mne.create_info(channel_name, ch_types=["eeg"] * len(channel_name), sfreq=sfreq)
    info.set_montage("standard_1020")
    raw = mne.io.RawArray(data=input_data, info=info)
    raw.set_eeg_reference()  # no ref_channels given -> average reference
    raw.filter(l_freq=l_freq, h_freq=h_freq)  # modifies `raw` in place
    epochs = mne.make_fixed_length_epochs(raw, duration=duration, overlap=0)
    return epochs.get_data()

In [9]:
# One (n_epochs, n_channels, n_samples) array per IDD subject.
# sorted() fixes subject order; glob's order depends on the filesystem.
idd_subject = [
    ConvertMat2Mne(scipy.io.loadmat(mat_path)["clean_data"])
    for mat_path in sorted(glob(IDD + "/*.mat"))
]
assert idd_subject, f"no .mat files found in {IDD}"

Creating RawArray with float64 data, n_channels=14, n_times=15360
    Range : 0 ... 15359 =      0.000 ...   119.992 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 32 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 32.00 Hz
- Upper transition bandwidth: 8.00 Hz (-6 dB cutoff frequency: 36.00 Hz)
- Filter length: 423 samples (3.305 sec)

Not setting metadata
Not setting metadata
30 matching events found
No baseline correction applied
0 projection items activated
Loading data for 30 events and 512 original time points ...
0 bad epoch

In [10]:
# One (n_epochs, n_channels, n_samples) array per TDC subject.
# sorted() fixes subject order; glob's order depends on the filesystem.
tdc_subject = [
    ConvertMat2Mne(scipy.io.loadmat(mat_path)["clean_data"])
    for mat_path in sorted(glob(TDC + "/*.mat"))
]
assert tdc_subject, f"no .mat files found in {TDC}"

Creating RawArray with float64 data, n_channels=14, n_times=15360
    Range : 0 ... 15359 =      0.000 ...   119.992 secs
Ready.
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 32 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 32.00 Hz
- Upper transition bandwidth: 8.00 Hz (-6 dB cutoff frequency: 36.00 Hz)
- Filter length: 423 samples (3.305 sec)

Not setting metadata
Not setting metadata
30 matching events found
No baseline correction applied
0 projection items activated
Loading data for 30 events and 512 original time points ...
0 bad epoch

In [16]:
# (subjects, epochs, channels, samples per epoch)
np.asarray(tdc_subject).shape

(7, 30, 14, 512)