## Read Labels

In [1]:
import pandas as pd

labels = pd.read_csv("nmt_scalp_eeg_dataset/Labels.csv")
labels_dict = {}

subjects = labels.recordname
labels = labels.label
labels = [0 if l.lower() == "normal" else 1 for l in labels]

for sub, label in zip(subjects, labels):
    labels_dict[sub] = label

In [2]:
!ls nmt_scalp_eeg_dataset/normal

[34meval[m[m  [34mtrain[m[m


In [3]:
labels_dict['0000001.edf']

0

In [4]:
# number of subjects
len(subjects)

2417

## Read edf data

In [5]:
!echo $CONDA_DEFAULT_ENV

xai


In [6]:
%%capture
import numpy as np
from tqdm import tqdm
import glob

import scipy.io

import os
import mne
import matplotlib.pyplot as plt

In [7]:
def load_data(file_path):
    datax=mne.io.read_raw_edf(file_path,preload=True)
    datax.set_eeg_reference()
    datax.filter(l_freq=1,h_freq=45)
    epochs=mne.make_fixed_length_epochs(datax,duration=10,overlap=0)
    epochs=(epochs.get_data() * 1e6).astype(np.float32)
    return epochs #trials,channel,length

In [50]:
test_abnormal_files = glob.glob('./nmt_scalp_eeg_dataset/abnormal/eval/*.edf')
test_normal_files = glob.glob('./nmt_scalp_eeg_dataset/normal/eval/*.edf')
train_normal_files = glob.glob('./nmt_scalp_eeg_dataset/normal/train/*.edf')
train_abnormal_files = glob.glob('./nmt_scalp_eeg_dataset/abnormal/train/*.edf')

In [51]:
import random
val_normal_files = random.sample(train_normal_files, 10)
val_abnormal_files = random.sample(train_abnormal_files, 10)
train_normal_files = random.sample(train_normal_files, 50)
train_abnormal_files = random.sample(train_abnormal_files, 50)

In [52]:
len(train_abnormal_files), len(train_normal_files)

(50, 50)

In [53]:
%%capture

train_normal_features=[load_data(f) for f in train_normal_files]
train_abnormal_features=[load_data(f) for f in train_abnormal_files]
train_normal_labels=[0 for f in train_normal_files]
train_abnormal_labels=[1 for f in train_abnormal_files]

val_normal_features=[load_data(f) for f in val_normal_files]
val_abnormal_features=[load_data(f) for f in val_abnormal_files]
val_normal_labels=[0 for f in val_normal_files]
val_abnormal_labels=[1 for f in val_abnormal_files]

In [54]:
train_normal_labels=[len(i)*[0] for i in train_normal_features]
train_abnormal_labels=[len(i)*[1] for i in train_abnormal_features]

val_normal_labels=[len(i)*[0] for i in val_normal_features]
val_abnormal_labels=[len(i)*[1] for i in val_abnormal_features]

In [55]:
train_features = train_normal_features + train_abnormal_features
train_labels = train_normal_labels + train_abnormal_labels

val_features = val_normal_features + val_abnormal_features
val_labels = val_normal_labels + val_abnormal_labels

In [56]:
del train_normal_features
del train_abnormal_features

In [57]:
train_features = np.vstack(train_features)
train_labels = np.hstack(train_labels)

In [58]:
val_features = np.vstack(val_features)
val_labels = np.hstack(val_labels)

In [59]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import GroupKFold
gkf=GroupKFold()
from sklearn.base import TransformerMixin,BaseEstimator
from sklearn.preprocessing import StandardScaler
#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    #batch, sequence, channels
    def __init__(self):
        self.scaler = StandardScaler()

    def fit(self,X,y=None):
        self.scaler.fit(X.reshape(-1, X.shape[2]))
        return self

    def transform(self,X):
        return self.scaler.transform(X.reshape( -1,X.shape[2])).reshape(X.shape)

In [60]:
scaler = StandardScaler3D()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)

In [61]:
from tqdm import tqdm

In [62]:
import torch.nn as nn
import torch
from torch.autograd import Variable

class Block(nn.Module):
  def __init__(self,inplace):
    super().__init__()
    self.conv1=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=2,stride=2,padding=0)
    self.conv2=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=4,stride=2,padding=1)
    self.conv3=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=8,stride=2,padding=3)
    self.relu=nn.ReLU()

  def forward(self,x):
    x1=self.relu(self.conv1(x))
    x2=self.relu(self.conv2(x))
    x3=self.relu(self.conv3(x))
    x=torch.cat([x1,x3,x3],dim=1)
    return x

class ChronoNet(nn.Module):
  def __init__(self,channel):
    super().__init__()
    self.block1=Block(channel)
    self.block2=Block(96)
    self.block3=Block(96)
    self.gru1=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    self.gru2=nn.GRU(input_size=32,hidden_size=32,batch_first=True)
    self.gru3=nn.GRU(input_size=64,hidden_size=32,batch_first=True)
    self.gru4=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    self.gru_linear=nn.Linear(250,1)
    self.flatten=nn.Flatten()
    self.fc1=nn.Linear(32,1)
    self.relu=nn.ReLU()

  def forward(self,x):
    x = x.squeeze()
    x=self.block1(x)
    x=self.block2(x)
    x=self.block3(x)
    x=x.permute(0,2,1)
    gru_out1,_=self.gru1(x)
    gru_out2,_=self.gru2(gru_out1)
    gru_out=torch.cat([gru_out1,gru_out2],dim=2)
    gru_out3,_=self.gru3(gru_out)
    gru_out=torch.cat([gru_out1,gru_out2,gru_out3],dim=2)
    linear_out=self.relu(self.gru_linear(gru_out.permute(0,2,1)))
    gru_out4,_=self.gru4(linear_out.permute(0,2,1))
    x=self.flatten(gru_out4)
    x=self.fc1(x)
    out = torch.sigmoid(x)
    return out

In [63]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

train_features = torch.Tensor(train_features).to(device)
train_labels = torch.Tensor(train_labels).to(device)

batch_size = 16
train_data = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

val_features = torch.Tensor(val_features).to(device)
val_labels = torch.Tensor(val_labels).to(device)

batch_size = 16
val_data = torch.utils.data.TensorDataset(val_features, val_labels)
val_iter = torch.utils.data.DataLoader(val_data, batch_size, shuffle=True)

In [64]:
train_features[0].shape

torch.Size([21, 2000])

In [65]:
def evaluate_model(model, loss_func, data_iter):
    model.eval()
    loss_sum, n = 0, 0
    with torch.no_grad():
        for x, y in data_iter:
            y_pred = model(x)
            y_pred = y_pred.squeeze()
            loss = loss_func(y_pred,y)
            loss_sum += loss.item()
            n += 1
        return loss_sum / n

In [66]:
n_chans = 21
model=ChronoNet(n_chans)
loss_func = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 10

for epoch in range(1, epochs + 1):
    print("epoch", epoch) 
    loss_sum, n = 0.0, 0
    model.train()
    for t, (x, y) in enumerate(tqdm(train_iter)):
        y_pred = model(x)
        y_pred = y_pred.squeeze()
        loss = loss_func(y_pred, y)
        loss.backward()
        loss_sum += loss.item()
        optimizer.step()
        optimizer.zero_grad()
    
    val_loss = evaluate_model(model, loss_func, val_iter)
    print("Train loss:", loss_sum / (t+1))
    print("Val loss:", val_loss)

epoch 1


100%|█████████████████████████████████████████| 431/431 [01:43<00:00,  4.15it/s]


Train loss: 0.6859325429126447
Val loss: 0.6846961152553558
epoch 2


100%|█████████████████████████████████████████| 431/431 [01:43<00:00,  4.16it/s]


Train loss: 0.6660683229047019
Val loss: 0.6759910082817078
epoch 3


100%|█████████████████████████████████████████| 431/431 [01:45<00:00,  4.10it/s]


Train loss: 0.6682151165334919
Val loss: 0.630389387011528
epoch 4


100%|█████████████████████████████████████████| 431/431 [01:44<00:00,  4.13it/s]


Train loss: 0.6234215364517026
Val loss: 0.6270960444211959
epoch 5


100%|█████████████████████████████████████████| 431/431 [01:44<00:00,  4.14it/s]


Train loss: 0.6106176890795856
Val loss: 0.6167118436098099
epoch 6


100%|█████████████████████████████████████████| 431/431 [01:44<00:00,  4.13it/s]


Train loss: 0.6031776645897159
Val loss: 0.6060078346729278
epoch 7


100%|█████████████████████████████████████████| 431/431 [01:43<00:00,  4.15it/s]


Train loss: 0.6017707470813096
Val loss: 0.6176279726624488
epoch 8


100%|█████████████████████████████████████████| 431/431 [01:43<00:00,  4.15it/s]


Train loss: 0.5959621176918809
Val loss: 0.6084504339098931
epoch 9


 29%|███████████▉                             | 125/431 [00:30<01:14,  4.10it/s]


KeyboardInterrupt: 

In [49]:
loss_sum / (t+1)

0.634364660177581

In [68]:
with torch.no_grad():
    y_hat = model(val_features)

yhat = [0 if i<0.5 else 1 for i in y_hat]
ytrue = val_labels.numpy()
ypreds = yhat

from sklearn.metrics import accuracy_score
print("Accuracy: ", accuracy_score(ytrue, ypreds))

Accuracy:  0.70825456836799


In [69]:
from sklearn.metrics import confusion_matrix
confusion_matrix(ytrue, ypreds)

array([[616, 112],
       [351, 508]])