In [3]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
from torchvision.datasets import DatasetFolder
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

## Training CNN on EEG spectrum data 

In [4]:
class EEGDataset(Dataset):
    def __init__(self, csv_file, transform=None):

        self.data = pd.read_csv(csv_file)
        self.diagnosis_map ={'MDD': 1.0, 'Health': 0.0, 'Health-2': 0.0}
        self.transform =transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):

        raw = self.data.iloc[idx]
        eeg =np.load(raw['file'])
        if eeg.shape[0]==15:
            eeg =np.delete(eeg, 7, 0)
        label =self.diagnosis_map[raw['diagnosis']]
      
        if self.transform:
            eeg = self.transform(eeg)

        return eeg, label

In [5]:
train_dataset =EEGDataset('cross_spectrum_short/train.csv')
val_dataset =EEGDataset('cross_spectrum_short/val.csv')

In [40]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last = False)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last = True)

In [41]:
class EEGCNN(nn.Module):
    def __init__(self):
        super(EEGCNN, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(4, 16, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))
            
        self.cnn2 = nn.Sequential(
            nn.Conv2d(16, 8, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))

        
        self.pool = nn.AdaptiveAvgPool2d((3,3))
        self.fc = nn.Sequential(nn.Linear(8*3*3, 2))


    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        x=self.pool(x)
        classes =self.fc(x.view(-1, 8*3*3))
        
        return classes

In [42]:
class EEGCNNv2(nn.Module):
    def __init__(self):
        super(EEGCNN, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(15, 15, (23,1), stride=1, padding=1, groups=15),
            nn.BatchNorm2d(15),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))
            
        self.cnn2 = nn.Sequential(
            nn.Conv2d(15, 15, (1,1), stride=1, padding=1),
            nn.BatchNorm2d(15),
            nn.ELU(True),
            nn.MaxPool2d(2, stride=2))

        
        self.pool = nn.AdaptiveAvgPool2d((1,6))
        self.fc = nn.Sequential(nn.Linear(15*1*6, 2))


    def forward(self, x):
        x = self.cnn1(x)
        x = self.cnn2(x)
        x=self.pool(x)
        classes =self.fc(x.view(-1, *1*6))
        
        return classes

In [43]:
28/4

7.0

learning_rate =0.001
num_epochs =800

In [44]:
learning_rate =0.0001
num_epochs =300

In [45]:
## 0.001 -the best 

In [46]:
model = EEGCNN().cuda()

In [47]:
model

EEGCNN(
  (cnn1): Sequential(
    (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (cnn2): Sequential(
    (0): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (pool): AdaptiveAvgPool2d(output_size=(3, 3))
  (fc): Sequential(
    (0): Linear(in_features=72, out_features=2, bias=True)
  )
)

In [48]:
model = EEGCNN().cuda()
criterion = nn.MSELoss()
cr_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-1)

In [49]:
batch, labels =next(iter(train_loader))

In [50]:
batch =batch.float().cuda()

In [51]:
batch.shape

torch.Size([4, 4, 15, 15])

In [52]:
def plot_batch_and_outputs(batch, outputs, show_type ='img'):
    Y_ticks =np.linspace(0, 15, num=15)
    X_ticks =np.linspace(0, 15, num=15)
    random_channel =np.random.randint(batch.shape[1])
    levels =45
    
    if show_type =='img':
        print('True: ')
        spectrum = plt.imshow(batch[0,random_channel,:,:].detach().cpu().numpy(),cmap='jet')
        plt.show()
        print('Predicted: ')
        spectrum = plt.imshow(outputs[0,random_channel,:,:].detach().cpu().numpy(), cmap='jet')
        plt.show()
    else:
        spectrum = plt.contourf(Y_ticks,X_ticks,batch[0,random_channel,:,:].detach().cpu().numpy(),levels, cmap='jet')
        plt.show()
        print('Predicted: ')
        spectrum = plt.contourf(Y_ticks,X_ticks,outputs[0,random_channel,:,:].detach().cpu().numpy(),levels, cmap='jet')
        plt.show()

In [53]:
pred =model(batch)

In [54]:
pred.shape

torch.Size([4, 2])

In [55]:
#plot_batch_and_outputs(batch, output,show_type ='c')

In [56]:
batch.shape

torch.Size([4, 4, 15, 15])

In [57]:
writer = SummaryWriter('runs/cross-corr-cnn')

In [58]:
for epoch in range(num_epochs):
    train_loss = 0
    train_correct =0
    train_total =0
    for data in train_loader:
        img, label = data
        img = img.float().cuda()
        label=label.long().cuda()
        # ===================forward=====================
        preds = model(img)
        _, predicted = torch.max(preds.data, 1)
        train_correct += (predicted == label).sum().item()
        train_total+= labels.size(0)
        
        loss =cr_loss(preds,label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.data
    # ===================log========================
    if epoch % 1 == 0:
        with torch.no_grad():
            val_loss =0
            val_correct =0
            val_total =0
            for data in val_loader: 
                img, label = data
                img = img.float().cuda()
                label=label.long().cuda()
                preds  = model(img)
                _, predicted = torch.max(preds.data, 1)
                val_correct += (predicted == label).sum().item()
                val_total+= labels.size(0)
                loss =cr_loss(preds,label)
                val_loss += loss.data
            print(train_total)
            print(val_total)
            writer.add_scalar('training loss',
                            train_loss/train_total,
                            epoch)
            writer.add_scalar('validation loss',
                            val_loss/val_total,
                            epoch)
            writer.add_scalar('train accuracy',
                            train_correct/ train_total,
                            epoch)
            writer.add_scalar('val accuracy',
                            val_correct/ val_total,
                            epoch)
            print('epoch [{}/{}], train loss:{:.4f}'.format(epoch+1, num_epochs, train_loss/train_total))    
            print('epoch [{}/{}], val loss:{:.4f}'.format(epoch+1, num_epochs,val_loss/ val_total))
            print('epoch [{}/{}], train accuracy:{:.4f}'.format(epoch+1, num_epochs, train_correct/ train_total)) 
            print('epoch [{}/{}], val accuracy:{:.4f}'.format(epoch+1, num_epochs, val_correct/ val_total)) 
            #plot_batch_and_outputs(img, output,show_type ='c')


196
48
epoch [1/300], train loss:0.1722
epoch [1/300], val loss:0.1692
epoch [1/300], train accuracy:0.6020
epoch [1/300], val accuracy:0.6250
196
48
epoch [2/300], train loss:0.1583
epoch [2/300], val loss:0.1710
epoch [2/300], train accuracy:0.6684
epoch [2/300], val accuracy:0.5833
196
48
epoch [3/300], train loss:0.1494
epoch [3/300], val loss:0.1670
epoch [3/300], train accuracy:0.6633
epoch [3/300], val accuracy:0.5625
196
48
epoch [4/300], train loss:0.1479
epoch [4/300], val loss:0.1810
epoch [4/300], train accuracy:0.6735
epoch [4/300], val accuracy:0.5833
196
48
epoch [5/300], train loss:0.1438
epoch [5/300], val loss:0.1626
epoch [5/300], train accuracy:0.7143
epoch [5/300], val accuracy:0.5833
196
48
epoch [6/300], train loss:0.1350
epoch [6/300], val loss:0.1697
epoch [6/300], train accuracy:0.7653
epoch [6/300], val accuracy:0.5208
196
48
epoch [7/300], train loss:0.1305
epoch [7/300], val loss:0.1689
epoch [7/300], train accuracy:0.7704
epoch [7/300], val accuracy:0.5208

196
48
epoch [57/300], train loss:0.0356
epoch [57/300], val loss:0.2325
epoch [57/300], train accuracy:0.9694
epoch [57/300], val accuracy:0.4792
196
48
epoch [58/300], train loss:0.0332
epoch [58/300], val loss:0.2287
epoch [58/300], train accuracy:0.9796
epoch [58/300], val accuracy:0.5208
196
48
epoch [59/300], train loss:0.0294
epoch [59/300], val loss:0.2311
epoch [59/300], train accuracy:0.9796
epoch [59/300], val accuracy:0.5625
196
48
epoch [60/300], train loss:0.0300
epoch [60/300], val loss:0.2455
epoch [60/300], train accuracy:0.9847
epoch [60/300], val accuracy:0.5208
196
48
epoch [61/300], train loss:0.0332
epoch [61/300], val loss:0.2484
epoch [61/300], train accuracy:0.9796
epoch [61/300], val accuracy:0.4792
196
48
epoch [62/300], train loss:0.0410
epoch [62/300], val loss:0.2411
epoch [62/300], train accuracy:0.9643
epoch [62/300], val accuracy:0.4583
196
48
epoch [63/300], train loss:0.0342
epoch [63/300], val loss:0.2432
epoch [63/300], train accuracy:0.9847
epoch [

196
48
epoch [113/300], train loss:0.0313
epoch [113/300], val loss:0.2558
epoch [113/300], train accuracy:0.9694
epoch [113/300], val accuracy:0.5208
196
48
epoch [114/300], train loss:0.0236
epoch [114/300], val loss:0.2711
epoch [114/300], train accuracy:0.9898
epoch [114/300], val accuracy:0.3750
196
48
epoch [115/300], train loss:0.0255
epoch [115/300], val loss:0.2690
epoch [115/300], train accuracy:0.9898
epoch [115/300], val accuracy:0.4583
196
48
epoch [116/300], train loss:0.0251
epoch [116/300], val loss:0.2646
epoch [116/300], train accuracy:0.9847
epoch [116/300], val accuracy:0.5000
196
48
epoch [117/300], train loss:0.0241
epoch [117/300], val loss:0.2760
epoch [117/300], train accuracy:0.9847
epoch [117/300], val accuracy:0.4792
196
48
epoch [118/300], train loss:0.0281
epoch [118/300], val loss:0.2565
epoch [118/300], train accuracy:0.9847
epoch [118/300], val accuracy:0.4167
196
48
epoch [119/300], train loss:0.0238
epoch [119/300], val loss:0.2761
epoch [119/300], tr

196
48
epoch [168/300], train loss:0.0236
epoch [168/300], val loss:0.2561
epoch [168/300], train accuracy:0.9898
epoch [168/300], val accuracy:0.5417
196
48
epoch [169/300], train loss:0.0228
epoch [169/300], val loss:0.2328
epoch [169/300], train accuracy:0.9847
epoch [169/300], val accuracy:0.5417
196
48
epoch [170/300], train loss:0.0248
epoch [170/300], val loss:0.2217
epoch [170/300], train accuracy:0.9745
epoch [170/300], val accuracy:0.5417
196
48
epoch [171/300], train loss:0.0247
epoch [171/300], val loss:0.2361
epoch [171/300], train accuracy:0.9847
epoch [171/300], val accuracy:0.5833
196
48
epoch [172/300], train loss:0.0222
epoch [172/300], val loss:0.2389
epoch [172/300], train accuracy:0.9847
epoch [172/300], val accuracy:0.4583
196
48
epoch [173/300], train loss:0.0232
epoch [173/300], val loss:0.2311
epoch [173/300], train accuracy:0.9898
epoch [173/300], val accuracy:0.5417
196
48
epoch [174/300], train loss:0.0227
epoch [174/300], val loss:0.2254
epoch [174/300], tr

196
48
epoch [223/300], train loss:0.0237
epoch [223/300], val loss:0.2378
epoch [223/300], train accuracy:0.9898
epoch [223/300], val accuracy:0.5417
196
48
epoch [224/300], train loss:0.0214
epoch [224/300], val loss:0.2145
epoch [224/300], train accuracy:0.9847
epoch [224/300], val accuracy:0.5000
196
48
epoch [225/300], train loss:0.0206
epoch [225/300], val loss:0.2392
epoch [225/300], train accuracy:0.9898
epoch [225/300], val accuracy:0.5417
196
48
epoch [226/300], train loss:0.0188
epoch [226/300], val loss:0.2361
epoch [226/300], train accuracy:0.9847
epoch [226/300], val accuracy:0.4792
196
48
epoch [227/300], train loss:0.0192
epoch [227/300], val loss:0.2193
epoch [227/300], train accuracy:0.9898
epoch [227/300], val accuracy:0.5625
196
48
epoch [228/300], train loss:0.0196
epoch [228/300], val loss:0.2295
epoch [228/300], train accuracy:0.9898
epoch [228/300], val accuracy:0.5833
196
48
epoch [229/300], train loss:0.0245
epoch [229/300], val loss:0.2285
epoch [229/300], tr

196
48
epoch [278/300], train loss:0.0221
epoch [278/300], val loss:0.2338
epoch [278/300], train accuracy:0.9847
epoch [278/300], val accuracy:0.6250
196
48
epoch [279/300], train loss:0.0260
epoch [279/300], val loss:0.2532
epoch [279/300], train accuracy:0.9745
epoch [279/300], val accuracy:0.5417
196
48
epoch [280/300], train loss:0.0289
epoch [280/300], val loss:0.2160
epoch [280/300], train accuracy:0.9847
epoch [280/300], val accuracy:0.5208
196
48
epoch [281/300], train loss:0.0275
epoch [281/300], val loss:0.2218
epoch [281/300], train accuracy:0.9898
epoch [281/300], val accuracy:0.4583
196
48
epoch [282/300], train loss:0.0239
epoch [282/300], val loss:0.2373
epoch [282/300], train accuracy:0.9898
epoch [282/300], val accuracy:0.4583
196
48
epoch [283/300], train loss:0.0178
epoch [283/300], val loss:0.2538
epoch [283/300], train accuracy:0.9898
epoch [283/300], val accuracy:0.5417
196
48
epoch [284/300], train loss:0.0231
epoch [284/300], val loss:0.2390
epoch [284/300], tr