In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.transforms as T
import numpy as np
import h5py
import logging
import os
import sys
import datetime
import csv
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x1554dbe41230>

In [2]:
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, x, y, train):
        super(EEGDataset).__init__()
        assert x.shape[0] == y.size
        self.x = x
        #temp_y = np.zeros((y.size, 2))
        #for i in range(y.size):
        #    temp_y[i, y[i]] = 1
        #self.y = temp_y
        self.y = [y[i][0] for i in range(y.size)]
        self.train = train

    def __getitem__(self,key):
        return (self.x[key], self.y[key])

    def __len__(self):
        return len(self.y)

In [13]:
# Load EEG data
path = '/expanse/projects/nemar/child-mind-dtyoung/'
transform = T.Compose([
    T.ToTensor()
])
f = h5py.File(path + 'child_mind_x_train_v2.mat', 'r')
x_train = f['X_train']
x_train = np.reshape(x_train,(-1,1,24,256))
print('X_train shape: ' + str(x_train.shape))
f = h5py.File(path + 'child_mind_y_train_v2.mat', 'r')
y_train = f['Y_train']
print('Y_train shape: ' + str(y_train.shape))
train_data = EEGDataset(x_train, y_train, True)


f = h5py.File(path + 'child_mind_x_val_v2.mat', 'r')
x_val = f['X_val']
x_val = np.reshape(x_val,(-1,1,24,256))
print('X_val shape: ' + str(x_val.shape))
f = h5py.File(path + 'child_mind_y_val_v2.mat', 'r')
y_val = f['Y_val']
print('Y_val shape: ' + str(y_val.shape))
val_data = EEGDataset(x_val, y_val, True)

X_train shape: (71300, 1, 24, 256)
Y_train shape: (71300, 1)
X_val shape: (39868, 1, 24, 256)
Y_val shape: (39868, 1)


In [4]:
print(np.histogram(y_train))

(array([35626,     0,     0,     0,     0,     0,     0,     0,     0,
       35674]), array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]))


In [5]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cuda


In [6]:
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [7]:
def train(model, optimizer, epochs=1):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))
                check_accuracy(loader_val, model)
                print()

In [8]:
# toggle between learning rate and batch size values 
# lr = 0.002 # original
lr = 0.000363 
# batch_size = 70 # original
batch_size = 64
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=batch_size)
model = nn.Sequential(
    nn.Conv2d(1,100,3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(100,100,3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(100,300,(2,3)),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(300,300,(1,7)),
    nn.ReLU(),
    nn.MaxPool2d((1,2), stride=1),
    nn.Dropout(0.25),
    nn.Conv2d(300,100,(1,3)),
    nn.Conv2d(100,100,(1,3)),
    nn.Flatten(),
    nn.Linear(1900,6144),
    nn.Linear(6144,2),
)

optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
train(model, optimizer, epochs=100)

Epoch 0, Iteration 0, loss = 0.9701
Checking accuracy on validation set
Got 19700 / 39868 correct (49.41)

Epoch 0, Iteration 100, loss = 0.9867
Checking accuracy on validation set
Got 20575 / 39868 correct (51.61)

Epoch 0, Iteration 200, loss = 0.6950
Checking accuracy on validation set
Got 20013 / 39868 correct (50.20)

Epoch 0, Iteration 300, loss = 0.6672
Checking accuracy on validation set
Got 20002 / 39868 correct (50.17)

Epoch 0, Iteration 400, loss = 0.7104
Checking accuracy on validation set
Got 20166 / 39868 correct (50.58)

Epoch 0, Iteration 500, loss = 0.6793
Checking accuracy on validation set
Got 20044 / 39868 correct (50.28)

Epoch 0, Iteration 600, loss = 0.7564
Checking accuracy on validation set
Got 20808 / 39868 correct (52.19)

Epoch 0, Iteration 700, loss = 0.6636
Checking accuracy on validation set
Got 20044 / 39868 correct (50.28)

Epoch 0, Iteration 800, loss = 0.7188
Checking accuracy on validation set
Got 20001 / 39868 correct (50.17)

Epoch 0, Iteration 90

Got 24449 / 39868 correct (61.32)

Epoch 6, Iteration 400, loss = 0.6730
Checking accuracy on validation set
Got 25750 / 39868 correct (64.59)

Epoch 6, Iteration 500, loss = 0.6436
Checking accuracy on validation set
Got 25623 / 39868 correct (64.27)

Epoch 6, Iteration 600, loss = 0.6772
Checking accuracy on validation set
Got 21890 / 39868 correct (54.91)

Epoch 6, Iteration 700, loss = 0.6126
Checking accuracy on validation set
Got 24924 / 39868 correct (62.52)

Epoch 6, Iteration 800, loss = 0.6270
Checking accuracy on validation set
Got 25617 / 39868 correct (64.25)

Epoch 6, Iteration 900, loss = 0.6487
Checking accuracy on validation set
Got 26148 / 39868 correct (65.59)

Epoch 6, Iteration 1000, loss = 0.6899
Checking accuracy on validation set
Got 25731 / 39868 correct (64.54)

Epoch 6, Iteration 1100, loss = 0.6501
Checking accuracy on validation set
Got 25443 / 39868 correct (63.82)

Epoch 7, Iteration 0, loss = 0.6142
Checking accuracy on validation set
Got 25218 / 39868 c

Got 27081 / 39868 correct (67.93)

Epoch 12, Iteration 700, loss = 0.5631
Checking accuracy on validation set
Got 27292 / 39868 correct (68.46)

Epoch 12, Iteration 800, loss = 0.5159
Checking accuracy on validation set
Got 26807 / 39868 correct (67.24)

Epoch 12, Iteration 900, loss = 0.5182
Checking accuracy on validation set
Got 27190 / 39868 correct (68.20)

Epoch 12, Iteration 1000, loss = 0.6020
Checking accuracy on validation set
Got 26878 / 39868 correct (67.42)

Epoch 12, Iteration 1100, loss = 0.5013
Checking accuracy on validation set
Got 26817 / 39868 correct (67.26)

Epoch 13, Iteration 0, loss = 0.5447
Checking accuracy on validation set
Got 27256 / 39868 correct (68.37)

Epoch 13, Iteration 100, loss = 0.4642
Checking accuracy on validation set
Got 27118 / 39868 correct (68.02)

Epoch 13, Iteration 200, loss = 0.5233
Checking accuracy on validation set
Got 27140 / 39868 correct (68.07)

Epoch 13, Iteration 300, loss = 0.4831
Checking accuracy on validation set
Got 27344 

Got 28218 / 39868 correct (70.78)

Epoch 18, Iteration 1000, loss = 0.5667
Checking accuracy on validation set
Got 27391 / 39868 correct (68.70)

Epoch 18, Iteration 1100, loss = 0.4496
Checking accuracy on validation set
Got 28091 / 39868 correct (70.46)

Epoch 19, Iteration 0, loss = 0.5331
Checking accuracy on validation set
Got 27939 / 39868 correct (70.08)

Epoch 19, Iteration 100, loss = 0.4309
Checking accuracy on validation set
Got 28276 / 39868 correct (70.92)

Epoch 19, Iteration 200, loss = 0.4374
Checking accuracy on validation set
Got 28339 / 39868 correct (71.08)

Epoch 19, Iteration 300, loss = 0.5078
Checking accuracy on validation set
Got 28304 / 39868 correct (70.99)

Epoch 19, Iteration 400, loss = 0.4864
Checking accuracy on validation set
Got 28465 / 39868 correct (71.40)

Epoch 19, Iteration 500, loss = 0.4756
Checking accuracy on validation set
Got 28059 / 39868 correct (70.38)

Epoch 19, Iteration 600, loss = 0.4765
Checking accuracy on validation set
Got 28556 

Got 28686 / 39868 correct (71.95)

Epoch 25, Iteration 100, loss = 0.4857
Checking accuracy on validation set
Got 28309 / 39868 correct (71.01)

Epoch 25, Iteration 200, loss = 0.4511
Checking accuracy on validation set
Got 28539 / 39868 correct (71.58)

Epoch 25, Iteration 300, loss = 0.4563
Checking accuracy on validation set
Got 28844 / 39868 correct (72.35)

Epoch 25, Iteration 400, loss = 0.4454
Checking accuracy on validation set
Got 28627 / 39868 correct (71.80)

Epoch 25, Iteration 500, loss = 0.4395
Checking accuracy on validation set
Got 28636 / 39868 correct (71.83)

Epoch 25, Iteration 600, loss = 0.4174
Checking accuracy on validation set
Got 28212 / 39868 correct (70.76)

Epoch 25, Iteration 700, loss = 0.3753
Checking accuracy on validation set
Got 28779 / 39868 correct (72.19)

Epoch 25, Iteration 800, loss = 0.4715
Checking accuracy on validation set
Got 28605 / 39868 correct (71.75)

Epoch 25, Iteration 900, loss = 0.4969
Checking accuracy on validation set
Got 28653 

Got 28728 / 39868 correct (72.06)

Epoch 31, Iteration 400, loss = 0.4333
Checking accuracy on validation set
Got 28939 / 39868 correct (72.59)

Epoch 31, Iteration 500, loss = 0.5809
Checking accuracy on validation set
Got 28294 / 39868 correct (70.97)

Epoch 31, Iteration 600, loss = 0.4515
Checking accuracy on validation set
Got 29088 / 39868 correct (72.96)

Epoch 31, Iteration 700, loss = 0.4310
Checking accuracy on validation set
Got 28952 / 39868 correct (72.62)

Epoch 31, Iteration 800, loss = 0.4285
Checking accuracy on validation set
Got 29159 / 39868 correct (73.14)

Epoch 31, Iteration 900, loss = 0.4838
Checking accuracy on validation set
Got 28976 / 39868 correct (72.68)

Epoch 31, Iteration 1000, loss = 0.4451
Checking accuracy on validation set
Got 29026 / 39868 correct (72.81)

Epoch 31, Iteration 1100, loss = 0.5390
Checking accuracy on validation set
Got 28712 / 39868 correct (72.02)

Epoch 32, Iteration 0, loss = 0.3583
Checking accuracy on validation set
Got 29020 

Got 29186 / 39868 correct (73.21)

Epoch 37, Iteration 700, loss = 0.4374
Checking accuracy on validation set
Got 29214 / 39868 correct (73.28)

Epoch 37, Iteration 800, loss = 0.3619
Checking accuracy on validation set
Got 29105 / 39868 correct (73.00)

Epoch 37, Iteration 900, loss = 0.3045
Checking accuracy on validation set
Got 29314 / 39868 correct (73.53)

Epoch 37, Iteration 1000, loss = 0.4245
Checking accuracy on validation set
Got 29288 / 39868 correct (73.46)

Epoch 37, Iteration 1100, loss = 0.3556
Checking accuracy on validation set
Got 29101 / 39868 correct (72.99)

Epoch 38, Iteration 0, loss = 0.4609
Checking accuracy on validation set
Got 29053 / 39868 correct (72.87)

Epoch 38, Iteration 100, loss = 0.4485
Checking accuracy on validation set
Got 28440 / 39868 correct (71.34)

Epoch 38, Iteration 200, loss = 0.4815
Checking accuracy on validation set
Got 29281 / 39868 correct (73.44)

Epoch 38, Iteration 300, loss = 0.4090
Checking accuracy on validation set
Got 29179 

Got 29263 / 39868 correct (73.40)

Epoch 43, Iteration 1000, loss = 0.4928
Checking accuracy on validation set
Got 29293 / 39868 correct (73.47)

Epoch 43, Iteration 1100, loss = 0.3805
Checking accuracy on validation set
Got 29522 / 39868 correct (74.05)

Epoch 44, Iteration 0, loss = 0.3156
Checking accuracy on validation set
Got 29473 / 39868 correct (73.93)

Epoch 44, Iteration 100, loss = 0.4126
Checking accuracy on validation set
Got 29560 / 39868 correct (74.14)

Epoch 44, Iteration 200, loss = 0.4367
Checking accuracy on validation set
Got 29042 / 39868 correct (72.85)

Epoch 44, Iteration 300, loss = 0.5440
Checking accuracy on validation set
Got 29519 / 39868 correct (74.04)

Epoch 44, Iteration 400, loss = 0.6781
Checking accuracy on validation set
Got 29204 / 39868 correct (73.25)

Epoch 44, Iteration 500, loss = 0.3635
Checking accuracy on validation set
Got 29513 / 39868 correct (74.03)

Epoch 44, Iteration 600, loss = 0.3540
Checking accuracy on validation set
Got 29436 

Got 29325 / 39868 correct (73.56)

Epoch 50, Iteration 100, loss = 0.5955
Checking accuracy on validation set
Got 29244 / 39868 correct (73.35)

Epoch 50, Iteration 200, loss = 0.5255
Checking accuracy on validation set
Got 29546 / 39868 correct (74.11)

Epoch 50, Iteration 300, loss = 0.4310
Checking accuracy on validation set
Got 29673 / 39868 correct (74.43)

Epoch 50, Iteration 400, loss = 0.3326
Checking accuracy on validation set
Got 29349 / 39868 correct (73.62)

Epoch 50, Iteration 500, loss = 0.3052
Checking accuracy on validation set
Got 29356 / 39868 correct (73.63)

Epoch 50, Iteration 600, loss = 0.3711
Checking accuracy on validation set
Got 29487 / 39868 correct (73.96)

Epoch 50, Iteration 700, loss = 0.4002
Checking accuracy on validation set
Got 29298 / 39868 correct (73.49)

Epoch 50, Iteration 800, loss = 0.4621
Checking accuracy on validation set
Got 29710 / 39868 correct (74.52)

Epoch 50, Iteration 900, loss = 0.3326
Checking accuracy on validation set
Got 29428 

Got 29381 / 39868 correct (73.70)

Epoch 56, Iteration 400, loss = 0.3594
Checking accuracy on validation set
Got 29311 / 39868 correct (73.52)

Epoch 56, Iteration 500, loss = 0.3750
Checking accuracy on validation set
Got 29626 / 39868 correct (74.31)

Epoch 56, Iteration 600, loss = 0.3864
Checking accuracy on validation set
Got 29568 / 39868 correct (74.16)

Epoch 56, Iteration 700, loss = 0.4563
Checking accuracy on validation set
Got 29679 / 39868 correct (74.44)

Epoch 56, Iteration 800, loss = 0.4841
Checking accuracy on validation set
Got 29566 / 39868 correct (74.16)

Epoch 56, Iteration 900, loss = 0.3632
Checking accuracy on validation set
Got 29599 / 39868 correct (74.24)

Epoch 56, Iteration 1000, loss = 0.3578
Checking accuracy on validation set
Got 29506 / 39868 correct (74.01)

Epoch 56, Iteration 1100, loss = 0.3833
Checking accuracy on validation set
Got 29727 / 39868 correct (74.56)

Epoch 57, Iteration 0, loss = 0.3723
Checking accuracy on validation set
Got 29778 

Got 29031 / 39868 correct (72.82)

Epoch 62, Iteration 700, loss = 0.3923
Checking accuracy on validation set
Got 29727 / 39868 correct (74.56)

Epoch 62, Iteration 800, loss = 0.3015
Checking accuracy on validation set
Got 29448 / 39868 correct (73.86)

Epoch 62, Iteration 900, loss = 0.5110
Checking accuracy on validation set
Got 29394 / 39868 correct (73.73)

Epoch 62, Iteration 1000, loss = 0.3115
Checking accuracy on validation set
Got 29711 / 39868 correct (74.52)

Epoch 62, Iteration 1100, loss = 0.3587
Checking accuracy on validation set
Got 29765 / 39868 correct (74.66)

Epoch 63, Iteration 0, loss = 0.2481
Checking accuracy on validation set
Got 29520 / 39868 correct (74.04)

Epoch 63, Iteration 100, loss = 0.3393
Checking accuracy on validation set
Got 29113 / 39868 correct (73.02)

Epoch 63, Iteration 200, loss = 0.4152
Checking accuracy on validation set
Got 29639 / 39868 correct (74.34)

Epoch 63, Iteration 300, loss = 0.3748
Checking accuracy on validation set
Got 29501 

Got 29634 / 39868 correct (74.33)

Epoch 68, Iteration 1000, loss = 0.4930
Checking accuracy on validation set
Got 29479 / 39868 correct (73.94)

Epoch 68, Iteration 1100, loss = 0.2916
Checking accuracy on validation set
Got 29477 / 39868 correct (73.94)

Epoch 69, Iteration 0, loss = 0.4406
Checking accuracy on validation set
Got 29707 / 39868 correct (74.51)

Epoch 69, Iteration 100, loss = 0.4954
Checking accuracy on validation set
Got 29618 / 39868 correct (74.29)

Epoch 69, Iteration 200, loss = 0.2255
Checking accuracy on validation set
Got 29807 / 39868 correct (74.76)

Epoch 69, Iteration 300, loss = 0.4748
Checking accuracy on validation set
Got 29576 / 39868 correct (74.18)

Epoch 69, Iteration 400, loss = 0.2952
Checking accuracy on validation set
Got 29611 / 39868 correct (74.27)

Epoch 69, Iteration 500, loss = 0.3971
Checking accuracy on validation set
Got 29578 / 39868 correct (74.19)

Epoch 69, Iteration 600, loss = 0.2392
Checking accuracy on validation set
Got 29610 

Got 29622 / 39868 correct (74.30)

Epoch 75, Iteration 100, loss = 0.3112
Checking accuracy on validation set
Got 29561 / 39868 correct (74.15)

Epoch 75, Iteration 200, loss = 0.2316
Checking accuracy on validation set
Got 29542 / 39868 correct (74.10)

Epoch 75, Iteration 300, loss = 0.3917
Checking accuracy on validation set
Got 29664 / 39868 correct (74.41)

Epoch 75, Iteration 400, loss = 0.2799
Checking accuracy on validation set
Got 29589 / 39868 correct (74.22)

Epoch 75, Iteration 500, loss = 0.3253
Checking accuracy on validation set
Got 29573 / 39868 correct (74.18)

Epoch 75, Iteration 600, loss = 0.3775
Checking accuracy on validation set
Got 29560 / 39868 correct (74.14)

Epoch 75, Iteration 700, loss = 0.4284
Checking accuracy on validation set
Got 29775 / 39868 correct (74.68)

Epoch 75, Iteration 800, loss = 0.2283
Checking accuracy on validation set
Got 29734 / 39868 correct (74.58)

Epoch 75, Iteration 900, loss = 0.3352
Checking accuracy on validation set
Got 29547 

Got 29508 / 39868 correct (74.01)

Epoch 81, Iteration 400, loss = 0.3433
Checking accuracy on validation set
Got 29659 / 39868 correct (74.39)

Epoch 81, Iteration 500, loss = 0.2604
Checking accuracy on validation set
Got 29432 / 39868 correct (73.82)

Epoch 81, Iteration 600, loss = 0.4288
Checking accuracy on validation set
Got 29541 / 39868 correct (74.10)

Epoch 81, Iteration 700, loss = 0.2623
Checking accuracy on validation set
Got 29586 / 39868 correct (74.21)

Epoch 81, Iteration 800, loss = 0.3956
Checking accuracy on validation set
Got 29572 / 39868 correct (74.17)

Epoch 81, Iteration 900, loss = 0.1948
Checking accuracy on validation set
Got 29681 / 39868 correct (74.45)

Epoch 81, Iteration 1000, loss = 0.2856
Checking accuracy on validation set
Got 29652 / 39868 correct (74.38)

Epoch 81, Iteration 1100, loss = 0.3276
Checking accuracy on validation set
Got 29564 / 39868 correct (74.15)

Epoch 82, Iteration 0, loss = 0.3026
Checking accuracy on validation set
Got 29525 

Got 29518 / 39868 correct (74.04)

Epoch 87, Iteration 700, loss = 0.3160
Checking accuracy on validation set
Got 29382 / 39868 correct (73.70)

Epoch 87, Iteration 800, loss = 0.2144
Checking accuracy on validation set
Got 29270 / 39868 correct (73.42)

Epoch 87, Iteration 900, loss = 0.3995
Checking accuracy on validation set
Got 29422 / 39868 correct (73.80)

Epoch 87, Iteration 1000, loss = 0.2933
Checking accuracy on validation set
Got 29405 / 39868 correct (73.76)

Epoch 87, Iteration 1100, loss = 0.3550
Checking accuracy on validation set
Got 29624 / 39868 correct (74.31)

Epoch 88, Iteration 0, loss = 0.4660
Checking accuracy on validation set
Got 29474 / 39868 correct (73.93)

Epoch 88, Iteration 100, loss = 0.3694
Checking accuracy on validation set
Got 29037 / 39868 correct (72.83)

Epoch 88, Iteration 200, loss = 0.3762
Checking accuracy on validation set
Got 29531 / 39868 correct (74.07)

Epoch 88, Iteration 300, loss = 0.3463
Checking accuracy on validation set
Got 29469 

Got 29444 / 39868 correct (73.85)

Epoch 93, Iteration 1000, loss = 0.3165
Checking accuracy on validation set
Got 29373 / 39868 correct (73.68)

Epoch 93, Iteration 1100, loss = 0.3698
Checking accuracy on validation set
Got 29295 / 39868 correct (73.48)

Epoch 94, Iteration 0, loss = 0.3120
Checking accuracy on validation set
Got 29360 / 39868 correct (73.64)

Epoch 94, Iteration 100, loss = 0.3424
Checking accuracy on validation set
Got 29433 / 39868 correct (73.83)

Epoch 94, Iteration 200, loss = 0.3178
Checking accuracy on validation set
Got 29224 / 39868 correct (73.30)

Epoch 94, Iteration 300, loss = 0.2846
Checking accuracy on validation set
Got 29389 / 39868 correct (73.72)

Epoch 94, Iteration 400, loss = 0.3078
Checking accuracy on validation set
Got 29487 / 39868 correct (73.96)

Epoch 94, Iteration 500, loss = 0.2987
Checking accuracy on validation set
Got 29509 / 39868 correct (74.02)

Epoch 94, Iteration 600, loss = 0.2678
Checking accuracy on validation set
Got 29467 

In [11]:
def test_model(model, test_data, subj_csv):
    # one-segment test
    print('Testing model accuracy using 1-segment metric')
    loader_test = DataLoader(test_data, batch_size=70)
    check_accuracy(loader_test, model)

    # 40-segment test
    print('Testing model accuracy using 40-segment per subject metric')
    with open(subj_csv, newline='') as csvfile:
        spamreader = csv.reader(csvfile, delimiter=' ', quotechar='|')
        subjIDs = [row[0] for row in spamreader]
    unique_subjs,indices = np.unique(subjIDs,return_index=True)

    iterable_test_data = list(iter(DataLoader(test_data, batch_size=1)))
    num_correct = []
    for subj,idx in zip(unique_subjs,indices):
    #     print(f'Subj {subj} - gender {iterable_test_data[idx][1]}')
        data = iterable_test_data[idx:idx+40]
        #print(np.sum([y for _,y in data]))
        assert 40 == np.sum([y for _,y in data]) or 0 == np.sum([y for _,y in data])
        preds = []
        correct = 0
        with torch.no_grad():
            for x,y in data:
                x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
                correct = y
                scores = model(x)
                _, pred = scores.max(1)
                preds.append(pred)
        final_pred = (torch.mean(torch.FloatTensor(preds)) > 0.5).sum()
        num_correct.append((final_pred == correct).sum())
    #print(len(num_correct))
    acc = float(np.sum(num_correct)) / len(unique_subjs)
    print('Got %d / %d correct (%.2f)' % (np.sum(num_correct), len(unique_subjs), 100 * acc))

In [12]:
# Testing
# Balanced-class test set
print('Testing on balanced test set')
f = h5py.File(path + 'child_mind_x_test_v2.mat', 'r')
x_test = f['X_test']
x_test = np.reshape(x_test,(-1,1,24,256))
print('X_test shape: ' + str(x_test.shape))
f = h5py.File(path + 'child_mind_y_test_v2.mat', 'r')
y_test = f['Y_test']
print('Y_test shape: ' + str(y_test.shape))
test_data = EEGDataset(x_test, y_test, False)
test_model(model, test_data, 'test_subjIDs.csv')

print()

# All-male test set
print('Testing on all-male test set')
f = h5py.File(path + 'child_mind_x_test_v3.mat', 'r')
x_test = f['X_test']
x_test = np.reshape(x_test,(-1,1,24,256))
print('X_test shape: ' + str(x_test.shape))
f = h5py.File(path + 'child_mind_y_test_v3.mat', 'r')
y_test = f['Y_test']
print('Y_test shape: ' + str(y_test.shape))
test_data = EEGDataset(x_test, y_test, False)
test_model(model, test_data, 'test_subjIDs_more_test.csv')

Testing on balanced test set
X_test shape: (16006, 1, 24, 256)
Y_test shape: (16006, 1)
Testing model accuracy using 1-segment metric
Checking accuracy on test set
Got 11848 / 16006 correct (74.02)
Testing model accuracy using 40-segment per subject metric
Got 166 / 198 correct (83.84)

Testing on all-male test set
X_test shape: (52377, 1, 24, 256)
Y_test shape: (52377, 1)
Testing model accuracy using 1-segment metric
Checking accuracy on test set
Got 40186 / 52377 correct (76.72)
Testing model accuracy using 40-segment per subject metric
Got 567 / 649 correct (87.37)


In [None]:
# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=logging.INFO,
#                     format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join('logs', f'log-{datetime.datetime.today()}.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)

# logging.info('Learning rate: %f, batch_size: %d' % (lr, batch_size))

# r = -4 * (1-0.25*np.random.rand())
# batch_size = 2**np.random.randint(2,7)
#     torch.save(model.state_dict(), f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}')

# state_dict = torch.load()