In [1]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv1d(1,16,5)
    self.conv2 = nn.Conv1d(16,32,5)
    self.conv3 = nn.Conv1d(32,64,5)

    self.fc1 = nn.Linear(159744,2)

    self.norm1 = nn.BatchNorm1d(1)

  def forward(self,x):

    x = self.norm1(x)
    x = F.relu(self.conv1(x))
    x = F.max_pool1d(x,2)
    x = F.relu(self.conv2(x))
    x = F.max_pool1d(x,2)
    x = F.relu(self.conv3(x))
    x = F.max_pool1d(x,2)
    x = torch.flatten(x,start_dim=1)
    x = F.log_softmax(self.fc1(x),dim=1)
    
    return x

net = Net()
net = net.double()

In [22]:
import os
import random
import numpy as np
import torch.optim as optim
import torch.tensor as tensor
import torch

optimizer = optim.Adam(net.parameters(),lr=0.001)

filenames = [line[2] for line in os.walk('/content/drive/MyDrive/fft')][0]

batch_size = 30
batch_number = len(filenames)/batch_size
epochs = 1

for epoch in range(epochs):

  random.shuffle(filenames)
  batches = [filenames[i:i+batch_size] for i in range(0,len(filenames),batch_size)]
  count = 0

  for batch in batches:

    count += 1
    labels = tensor([int('preictal' in i) for i in batch])
    n1 = torch.count_nonzero(labels)
    n0 = batch_size-n1
    x = tensor([np.load('/content/drive/MyDrive/fft/' + file)[:20000] for file in batch])
    x = x.unsqueeze(1)
    
    net.zero_grad()
    output = net(x)
    # print(output)
    loss = F.nll_loss(output,labels,weight=tensor([batch_size/(2*n0),batch_size/(2*n1)]).double())
    loss.backward()
    optimizer.step()
    print('Epoch number %d/%d' % (epoch+1,epochs))
    print('Batch number %d/%d' % (count,batch_number))
    print('Loss is %f' % (float(loss)))
    print('\n')


Epoch number 1/1
Batch number 1/82
Loss is 0.535302
Epoch number 1/1
Batch number 2/82
Loss is 0.608122
Epoch number 1/1
Batch number 3/82
Loss is 0.573932
Epoch number 1/1
Batch number 4/82
Loss is 0.847997
Epoch number 1/1
Batch number 5/82
Loss is 0.516464
Epoch number 1/1
Batch number 6/82
Loss is 0.602619
Epoch number 1/1
Batch number 7/82
Loss is 0.651250
Epoch number 1/1
Batch number 8/82
Loss is 0.824269
Epoch number 1/1
Batch number 9/82
Loss is 0.615098
Epoch number 1/1
Batch number 10/82
Loss is 1.038185


KeyboardInterrupt: ignored

In [None]:
with torch.no_grad():

    correct = 0
    preictal = 0
    interictal = 0
    totals = [0,0]
    count = 0

    for batch in batches:
      count += 1
      print(count)
      labels = [int('preictal' in i) for i in batch]
      x = tensor([np.load('/content/drive/MyDrive/fft/' + file)[:20000] for file in batch])
      x = x.unsqueeze(1)
      output = net(x)

      for idx,i in enumerate(output):
        totals[labels[idx]] += 1
        if torch.argmax(i) == labels[idx]:
          correct += 1
          if labels[idx] == 0:
            interictal += 1
          else:
            preictal += 1

In [17]:
accuracy = correct/len(filenames)
print(accuracy)

0.765040650406504


In [20]:
acc_interictal = interictal/totals[0]
acc_preictal = preictal/totals[1]
print(acc_interictal,acc_preictal)

0.9357224118316269 0.33760683760683763
