In [1]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
from torchvision.datasets import DatasetFolder
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

## Training CNN on EEG spectrum data 

In [2]:
class EEGDataset(Dataset):
    def __init__(self, csv_file, transform=None):

        self.data = pd.read_csv(csv_file)
        self.diagnosis_map ={'MDD': 1.0, 'Health': 0.0, 'Health-2': 0.0}
        self.transform =transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):

        raw = self.data.iloc[idx]
        eeg =np.load(raw['file'])
        if eeg.shape[0]==15:
            eeg =np.delete(eeg, 7, 0)
        label =self.diagnosis_map[raw['diagnosis']]
      
        if self.transform:
            eeg = self.transform(eeg)

        return eeg, label

In [204]:
train_dataset =EEGDataset('spectrums_short/train2.csv')
val_dataset =EEGDataset('spectrums_short/val2.csv')

In [205]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last = False)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last = True)

In [225]:
class EEGCNN(nn.Module):
    def __init__(self):
        super(EEGCNN, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(14, 14, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(14),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))
            
        self.cnn2 = nn.Sequential(
            nn.Conv2d(14, 8, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))

        
        self.pool = nn.AdaptiveAvgPool2d((3,9))
        self.fc = nn.Sequential(nn.Linear(8*3*9, 2))


    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        x=self.pool(x)
        classes =self.fc(x.view(-1, 8*3*9))
        
        return classes

In [226]:
class EEGCNNv2(nn.Module):
    def __init__(self):
        super(EEGCNN, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(15, 15, (23,1), stride=1, padding=1, groups=15),
            nn.BatchNorm2d(15),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))
            
        self.cnn2 = nn.Sequential(
            nn.Conv2d(15, 15, (1,1), stride=1, padding=1),
            nn.BatchNorm2d(15),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))

        
        self.pool = nn.AdaptiveAvgPool2d((1,6))
        self.fc = nn.Sequential(nn.Linear(15*1*6, 2))


    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        x=self.pool(x)
        classes =self.fc(x.view(-1, *1*6))
        
        return classes

In [227]:
28/4

7.0

learning_rate =0.001
num_epochs =800

In [228]:
learning_rate =0.00001
num_epochs =300

In [229]:
## 0.001 -the best 

In [230]:
model = EEGCNN().cuda()

In [231]:
model

EEGCNN(
  (cnn1): Sequential(
    (0): Conv2d(14, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (cnn2): Sequential(
    (0): Conv2d(14, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (pool): AdaptiveAvgPool2d(output_size=(3, 9))
  (fc): Sequential(
    (0): Linear(in_features=216, out_features=2, bias=True)
  )
)

In [232]:
model = EEGCNN().cuda()
criterion = nn.MSELoss()
cr_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-1)

In [233]:
batch, labels =next(iter(train_loader))

In [234]:
batch =batch.float().cuda()

In [235]:
batch.shape

torch.Size([4, 14, 24, 98])

In [236]:
def plot_batch_and_outputs(batch, outputs, show_type ='img'):
    Y_ticks =np.linspace(0, 128, num=224)
    X_ticks =np.linspace(0, 45, num=24)
    random_channel =np.random.randint(batch.shape[1])
    levels =45
    
    if show_type =='img':
        print('True: ')
        spectrum = plt.imshow(batch[0,random_channel,:,:].detach().cpu().numpy(),cmap='jet')
        plt.show()
        print('Predicted: ')
        spectrum = plt.imshow(outputs[0,random_channel,:,:].detach().cpu().numpy(), cmap='jet')
        plt.show()
    else:
        spectrum = plt.contourf(Y_ticks,X_ticks,batch[0,random_channel,:,:].detach().cpu().numpy(),levels, cmap='jet')
        plt.show()
        print('Predicted: ')
        spectrum = plt.contourf(Y_ticks,X_ticks,outputs[0,random_channel,:,:].detach().cpu().numpy(),levels, cmap='jet')
        plt.show()

In [237]:
pred =model(batch)

In [238]:
pred.shape

torch.Size([4, 2])

In [239]:
#plot_batch_and_outputs(batch, output,show_type ='c')

In [240]:
batch.shape

torch.Size([4, 14, 24, 98])

In [241]:
writer = SummaryWriter('runs/cnn-novossib2')

In [242]:
for epoch in range(num_epochs):
    train_loss = 0
    train_correct =0
    train_total =0
    for data in train_loader:
        img, label = data
        img = img.float().cuda()
        label=label.long().cuda()
        # ===================forward=====================
        preds = model(img)
        _, predicted = torch.max(preds.data, 1)
        train_correct += (predicted == label).sum().item()
        train_total+= labels.size(0)
        
        loss =cr_loss(preds,label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.data
    # ===================log========================
    if epoch % 1 == 0:
        with torch.no_grad():
            val_loss =0
            val_correct =0
            val_total =0
            for data in val_loader: 
                img, label = data
                img = img.float().cuda()
                label=label.long().cuda()
                preds  = model(img)
                _, predicted = torch.max(preds.data, 1)
                val_correct += (predicted == label).sum().item()
                val_total+= labels.size(0)
                loss =cr_loss(preds,label)
                val_loss += loss.data
            print(train_total)
            print(val_total)
            writer.add_scalar('training loss',
                            train_loss/train_total,
                            epoch)
            writer.add_scalar('validation loss',
                            val_loss/val_total,
                            epoch)
            writer.add_scalar('train accuracy',
                            train_correct/ train_total,
                            epoch)
            writer.add_scalar('val accuracy',
                            val_correct/ val_total,
                            epoch)
            print('epoch [{}/{}], train loss:{:.4f}'.format(epoch+1, num_epochs, train_loss/train_total))    
            print('epoch [{}/{}], val loss:{:.4f}'.format(epoch+1, num_epochs,val_loss/ val_total))
            print('epoch [{}/{}], train accuracy:{:.4f}'.format(epoch+1, num_epochs, train_correct/ train_total)) 
            print('epoch [{}/{}], val accuracy:{:.4f}'.format(epoch+1, num_epochs, val_correct/ val_total)) 
            #plot_batch_and_outputs(img, output,show_type ='c')


1524
484
epoch [1/300], train loss:0.1662
epoch [1/300], val loss:0.1760
epoch [1/300], train accuracy:0.5761
epoch [1/300], val accuracy:0.3616
1524
484
epoch [2/300], train loss:0.1623
epoch [2/300], val loss:0.1821
epoch [2/300], train accuracy:0.6004
epoch [2/300], val accuracy:0.3161
1524
484
epoch [3/300], train loss:0.1601
epoch [3/300], val loss:0.1846
epoch [3/300], train accuracy:0.6280
epoch [3/300], val accuracy:0.2955
1524
484
epoch [4/300], train loss:0.1578
epoch [4/300], val loss:0.1865
epoch [4/300], train accuracy:0.6299
epoch [4/300], val accuracy:0.2893
1524
484
epoch [5/300], train loss:0.1559
epoch [5/300], val loss:0.1896
epoch [5/300], train accuracy:0.6522
epoch [5/300], val accuracy:0.2975
1524
484
epoch [6/300], train loss:0.1527
epoch [6/300], val loss:0.1886
epoch [6/300], train accuracy:0.6759
epoch [6/300], val accuracy:0.3182
1524
484
epoch [7/300], train loss:0.1505
epoch [7/300], val loss:0.1913
epoch [7/300], train accuracy:0.6962
epoch [7/300], val a

1524
484
epoch [57/300], train loss:0.0890
epoch [57/300], val loss:0.2901
epoch [57/300], train accuracy:0.8451
epoch [57/300], val accuracy:0.4463
1524
484
epoch [58/300], train loss:0.0873
epoch [58/300], val loss:0.3064
epoch [58/300], train accuracy:0.8537
epoch [58/300], val accuracy:0.4070
1524
484
epoch [59/300], train loss:0.0873
epoch [59/300], val loss:0.3110
epoch [59/300], train accuracy:0.8596
epoch [59/300], val accuracy:0.4008
1524
484
epoch [60/300], train loss:0.0867
epoch [60/300], val loss:0.3100
epoch [60/300], train accuracy:0.8576
epoch [60/300], val accuracy:0.4050
1524
484
epoch [61/300], train loss:0.0858
epoch [61/300], val loss:0.3178
epoch [61/300], train accuracy:0.8570
epoch [61/300], val accuracy:0.3905
1524
484
epoch [62/300], train loss:0.0840
epoch [62/300], val loss:0.2953
epoch [62/300], train accuracy:0.8694
epoch [62/300], val accuracy:0.4421
1524
484
epoch [63/300], train loss:0.0867
epoch [63/300], val loss:0.3045
epoch [63/300], train accuracy:

1524
484
epoch [112/300], train loss:0.0695
epoch [112/300], val loss:0.2900
epoch [112/300], train accuracy:0.8957
epoch [112/300], val accuracy:0.4277
1524
484
epoch [113/300], train loss:0.0702
epoch [113/300], val loss:0.2622
epoch [113/300], train accuracy:0.8904
epoch [113/300], val accuracy:0.4649
1524
484
epoch [114/300], train loss:0.0689
epoch [114/300], val loss:0.3003
epoch [114/300], train accuracy:0.8911
epoch [114/300], val accuracy:0.4132
1524
484
epoch [115/300], train loss:0.0700
epoch [115/300], val loss:0.2730
epoch [115/300], train accuracy:0.8944
epoch [115/300], val accuracy:0.4380
1524
484
epoch [116/300], train loss:0.0716
epoch [116/300], val loss:0.2646
epoch [116/300], train accuracy:0.8891
epoch [116/300], val accuracy:0.4525
1524
484
epoch [117/300], train loss:0.0696
epoch [117/300], val loss:0.2873
epoch [117/300], train accuracy:0.9029
epoch [117/300], val accuracy:0.4215
1524
484
epoch [118/300], train loss:0.0692
epoch [118/300], val loss:0.2783
epoch

1524
484
epoch [166/300], train loss:0.0623
epoch [166/300], val loss:0.2762
epoch [166/300], train accuracy:0.9049
epoch [166/300], val accuracy:0.4421
1524
484
epoch [167/300], train loss:0.0637
epoch [167/300], val loss:0.2522
epoch [167/300], train accuracy:0.9140
epoch [167/300], val accuracy:0.4814
1524
484
epoch [168/300], train loss:0.0625
epoch [168/300], val loss:0.2777
epoch [168/300], train accuracy:0.9042
epoch [168/300], val accuracy:0.4380
1524
484
epoch [169/300], train loss:0.0631
epoch [169/300], val loss:0.2350
epoch [169/300], train accuracy:0.9140
epoch [169/300], val accuracy:0.5083
1524
484
epoch [170/300], train loss:0.0617
epoch [170/300], val loss:0.2556
epoch [170/300], train accuracy:0.9154
epoch [170/300], val accuracy:0.4731
1524
484
epoch [171/300], train loss:0.0637
epoch [171/300], val loss:0.2698
epoch [171/300], train accuracy:0.9068
epoch [171/300], val accuracy:0.4545
1524
484
epoch [172/300], train loss:0.0633
epoch [172/300], val loss:0.2563
epoch

1524
484
epoch [220/300], train loss:0.0584
epoch [220/300], val loss:0.2427
epoch [220/300], train accuracy:0.9134
epoch [220/300], val accuracy:0.4959
1524
484
epoch [221/300], train loss:0.0603
epoch [221/300], val loss:0.2700
epoch [221/300], train accuracy:0.9140
epoch [221/300], val accuracy:0.4566
1524
484
epoch [222/300], train loss:0.0607
epoch [222/300], val loss:0.2638
epoch [222/300], train accuracy:0.9180
epoch [222/300], val accuracy:0.4628
1524
484
epoch [223/300], train loss:0.0605
epoch [223/300], val loss:0.2221
epoch [223/300], train accuracy:0.9121
epoch [223/300], val accuracy:0.5289
1524
484
epoch [224/300], train loss:0.0601
epoch [224/300], val loss:0.2203
epoch [224/300], train accuracy:0.9147
epoch [224/300], val accuracy:0.5351
1524
484
epoch [225/300], train loss:0.0592
epoch [225/300], val loss:0.2704
epoch [225/300], train accuracy:0.9173
epoch [225/300], val accuracy:0.4545
1524
484
epoch [226/300], train loss:0.0593
epoch [226/300], val loss:0.2398
epoch

1524
484
epoch [274/300], train loss:0.0562
epoch [274/300], val loss:0.2414
epoch [274/300], train accuracy:0.9213
epoch [274/300], val accuracy:0.4855
1524
484
epoch [275/300], train loss:0.0564
epoch [275/300], val loss:0.2252
epoch [275/300], train accuracy:0.9259
epoch [275/300], val accuracy:0.5124
1524
484
epoch [276/300], train loss:0.0560
epoch [276/300], val loss:0.2537
epoch [276/300], train accuracy:0.9193
epoch [276/300], val accuracy:0.4731
1524
484
epoch [277/300], train loss:0.0560
epoch [277/300], val loss:0.2417
epoch [277/300], train accuracy:0.9180
epoch [277/300], val accuracy:0.4876
1524
484
epoch [278/300], train loss:0.0560
epoch [278/300], val loss:0.2347
epoch [278/300], train accuracy:0.9259
epoch [278/300], val accuracy:0.4979
1524
484
epoch [279/300], train loss:0.0589
epoch [279/300], val loss:0.2629
epoch [279/300], train accuracy:0.9173
epoch [279/300], val accuracy:0.4669
1524
484
epoch [280/300], train loss:0.0547
epoch [280/300], val loss:0.2490
epoch