In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.transforms as T
import numpy as np
import h5py
import logging
import os
import sys
import datetime
import csv
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x1554dbe42230>

In [2]:
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, x, y, train):
        super(EEGDataset).__init__()
        assert x.shape[0] == y.size
        self.x = x
        #temp_y = np.zeros((y.size, 2))
        #for i in range(y.size):
        #    temp_y[i, y[i]] = 1
        #self.y = temp_y
        self.y = [y[i][0] for i in range(y.size)]
        self.train = train

    def __getitem__(self,key):
        return (self.x[key], self.y[key])

    def __len__(self):
        return len(self.y)

In [3]:
# Load EEG data
transform = T.Compose([
    T.ToTensor()
])
f = h5py.File('child_mind_x_train_v2.mat', 'r')
x_train = f['X_train']
x_train = np.reshape(x_train,(-1,1,24,256))
print('X_train shape: ' + str(x_train.shape))
f = h5py.File('child_mind_y_train_v2.mat', 'r')
y_train = f['Y_train']
print('Y_train shape: ' + str(y_train.shape))
train_data = EEGDataset(x_train, y_train, True)


f = h5py.File('child_mind_x_val_v2.mat', 'r')
x_val = f['X_val']
x_val = np.reshape(x_val,(-1,1,24,256))
print('X_val shape: ' + str(x_val.shape))
f = h5py.File('child_mind_y_val_v2.mat', 'r')
y_val = f['Y_val']
print('Y_val shape: ' + str(y_val.shape))
val_data = EEGDataset(x_val, y_val, True)

X_train shape: (71300, 1, 24, 256)
Y_train shape: (71300, 1)
X_val shape: (39868, 1, 24, 256)
Y_val shape: (39868, 1)


In [4]:
print(np.histogram(y_train))

(array([35626,     0,     0,     0,     0,     0,     0,     0,     0,
       35674]), array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]))


In [5]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cuda


In [6]:
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [7]:
def train(model, optimizer, epochs=1):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))
                check_accuracy(loader_val, model)
                print()

In [15]:
# toggle between learning rate and batch size values 
lr = 0.002 # original
# lr = 0.000363 
batch_size = 70 # original
# batch_size = 64
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=batch_size)
model = nn.Sequential(
    nn.Conv2d(1,100,3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(100,100,3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(100,300,(2,3)),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Dropout(0.25),
    nn.Conv2d(300,300,(1,7)),
    nn.ReLU(),
    nn.MaxPool2d((1,2), stride=1),
    nn.Dropout(0.25),
    nn.Conv2d(300,100,(1,3)),
    nn.Conv2d(100,100,(1,3)),
    nn.Flatten(),
    nn.Linear(1900,6144),
    nn.Linear(6144,2),
)

optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
train(model, optimizer, epochs=100)

Epoch 0, Iteration 0, loss = 0.7131
Checking accuracy on validation set
Got 19909 / 39868 correct (49.94)

Epoch 0, Iteration 100, loss = 0.6828
Checking accuracy on validation set
Got 19910 / 39868 correct (49.94)

Epoch 0, Iteration 200, loss = 0.6898
Checking accuracy on validation set
Got 19971 / 39868 correct (50.09)

Epoch 0, Iteration 300, loss = 0.6888
Checking accuracy on validation set
Got 19914 / 39868 correct (49.95)

Epoch 0, Iteration 400, loss = 0.6986
Checking accuracy on validation set
Got 19908 / 39868 correct (49.93)

Epoch 0, Iteration 500, loss = 0.6921
Checking accuracy on validation set
Got 19910 / 39868 correct (49.94)

Epoch 0, Iteration 600, loss = 0.6930
Checking accuracy on validation set
Got 19975 / 39868 correct (50.10)

Epoch 0, Iteration 700, loss = 0.6922
Checking accuracy on validation set
Got 19975 / 39868 correct (50.10)

Epoch 0, Iteration 800, loss = 0.6924
Checking accuracy on validation set
Got 19975 / 39868 correct (50.10)

Epoch 0, Iteration 90

Got 25361 / 39868 correct (63.61)

Epoch 6, Iteration 1000, loss = 0.4894
Checking accuracy on validation set
Got 26915 / 39868 correct (67.51)

Epoch 7, Iteration 0, loss = 0.5542
Checking accuracy on validation set
Got 26566 / 39868 correct (66.63)

Epoch 7, Iteration 100, loss = 0.5479
Checking accuracy on validation set
Got 26773 / 39868 correct (67.15)

Epoch 7, Iteration 200, loss = 0.5493
Checking accuracy on validation set
Got 26106 / 39868 correct (65.48)

Epoch 7, Iteration 300, loss = 0.6009
Checking accuracy on validation set
Got 25681 / 39868 correct (64.42)

Epoch 7, Iteration 400, loss = 0.6276
Checking accuracy on validation set
Got 25375 / 39868 correct (63.65)

Epoch 7, Iteration 500, loss = 0.5225
Checking accuracy on validation set
Got 26004 / 39868 correct (65.23)

Epoch 7, Iteration 600, loss = 0.6329
Checking accuracy on validation set
Got 26836 / 39868 correct (67.31)

Epoch 7, Iteration 700, loss = 0.5831
Checking accuracy on validation set
Got 26842 / 39868 co

Got 26972 / 39868 correct (67.65)

Epoch 13, Iteration 800, loss = 0.4638
Checking accuracy on validation set
Got 27067 / 39868 correct (67.89)

Epoch 13, Iteration 900, loss = 0.6204
Checking accuracy on validation set
Got 25749 / 39868 correct (64.59)

Epoch 13, Iteration 1000, loss = 0.4972
Checking accuracy on validation set
Got 27081 / 39868 correct (67.93)

Epoch 14, Iteration 0, loss = 0.5628
Checking accuracy on validation set
Got 27310 / 39868 correct (68.50)

Epoch 14, Iteration 100, loss = 0.5892
Checking accuracy on validation set
Got 26442 / 39868 correct (66.32)

Epoch 14, Iteration 200, loss = 0.5033
Checking accuracy on validation set
Got 27176 / 39868 correct (68.16)

Epoch 14, Iteration 300, loss = 0.5413
Checking accuracy on validation set
Got 27482 / 39868 correct (68.93)

Epoch 14, Iteration 400, loss = 0.5467
Checking accuracy on validation set
Got 26427 / 39868 correct (66.29)

Epoch 14, Iteration 500, loss = 0.5269
Checking accuracy on validation set
Got 27268 /

Got 26790 / 39868 correct (67.20)

Epoch 20, Iteration 600, loss = 0.5332
Checking accuracy on validation set
Got 25975 / 39868 correct (65.15)

Epoch 20, Iteration 700, loss = 0.4975
Checking accuracy on validation set
Got 27381 / 39868 correct (68.68)

Epoch 20, Iteration 800, loss = 0.5702
Checking accuracy on validation set
Got 26907 / 39868 correct (67.49)

Epoch 20, Iteration 900, loss = 0.4829
Checking accuracy on validation set
Got 26109 / 39868 correct (65.49)

Epoch 20, Iteration 1000, loss = 0.5155
Checking accuracy on validation set
Got 25324 / 39868 correct (63.52)

Epoch 21, Iteration 0, loss = 0.4201
Checking accuracy on validation set
Got 27887 / 39868 correct (69.95)

Epoch 21, Iteration 100, loss = 0.5139
Checking accuracy on validation set
Got 25641 / 39868 correct (64.31)

Epoch 21, Iteration 200, loss = 0.5885
Checking accuracy on validation set
Got 26350 / 39868 correct (66.09)

Epoch 21, Iteration 300, loss = 0.5616
Checking accuracy on validation set
Got 26539 /

Got 26788 / 39868 correct (67.19)

Epoch 27, Iteration 400, loss = 0.5158
Checking accuracy on validation set
Got 24717 / 39868 correct (62.00)

Epoch 27, Iteration 500, loss = 0.4866
Checking accuracy on validation set
Got 26157 / 39868 correct (65.61)

Epoch 27, Iteration 600, loss = 0.3875
Checking accuracy on validation set
Got 26989 / 39868 correct (67.70)

Epoch 27, Iteration 700, loss = 0.5919
Checking accuracy on validation set
Got 25549 / 39868 correct (64.08)

Epoch 27, Iteration 800, loss = 0.5760
Checking accuracy on validation set
Got 27382 / 39868 correct (68.68)

Epoch 27, Iteration 900, loss = 0.5617
Checking accuracy on validation set
Got 27683 / 39868 correct (69.44)

Epoch 27, Iteration 1000, loss = 0.5947
Checking accuracy on validation set
Got 26788 / 39868 correct (67.19)

Epoch 28, Iteration 0, loss = 0.4955
Checking accuracy on validation set
Got 27120 / 39868 correct (68.02)

Epoch 28, Iteration 100, loss = 0.4810
Checking accuracy on validation set
Got 26553 /

Got 26712 / 39868 correct (67.00)

Epoch 34, Iteration 200, loss = 0.6181
Checking accuracy on validation set
Got 26385 / 39868 correct (66.18)

Epoch 34, Iteration 300, loss = 0.4842
Checking accuracy on validation set
Got 27193 / 39868 correct (68.21)

Epoch 34, Iteration 400, loss = 0.5343
Checking accuracy on validation set
Got 26277 / 39868 correct (65.91)

Epoch 34, Iteration 500, loss = 0.4172
Checking accuracy on validation set
Got 25555 / 39868 correct (64.10)

Epoch 34, Iteration 600, loss = 0.5352
Checking accuracy on validation set
Got 26955 / 39868 correct (67.61)

Epoch 34, Iteration 700, loss = 0.4086
Checking accuracy on validation set
Got 26168 / 39868 correct (65.64)

Epoch 34, Iteration 800, loss = 0.4295
Checking accuracy on validation set
Got 26203 / 39868 correct (65.72)

Epoch 34, Iteration 900, loss = 0.5827
Checking accuracy on validation set
Got 26326 / 39868 correct (66.03)

Epoch 34, Iteration 1000, loss = 0.4281
Checking accuracy on validation set
Got 25542

Got 26546 / 39868 correct (66.58)

Epoch 41, Iteration 0, loss = 0.4962
Checking accuracy on validation set
Got 26560 / 39868 correct (66.62)

Epoch 41, Iteration 100, loss = 0.4961
Checking accuracy on validation set
Got 26802 / 39868 correct (67.23)

Epoch 41, Iteration 200, loss = 0.4133
Checking accuracy on validation set
Got 27146 / 39868 correct (68.09)

Epoch 41, Iteration 300, loss = 0.4200
Checking accuracy on validation set
Got 26451 / 39868 correct (66.35)

Epoch 41, Iteration 400, loss = 0.5503
Checking accuracy on validation set
Got 27605 / 39868 correct (69.24)

Epoch 41, Iteration 500, loss = 0.5041
Checking accuracy on validation set
Got 27106 / 39868 correct (67.99)

Epoch 41, Iteration 600, loss = 0.5597
Checking accuracy on validation set
Got 26539 / 39868 correct (66.57)

Epoch 41, Iteration 700, loss = 0.4103
Checking accuracy on validation set
Got 26335 / 39868 correct (66.06)

Epoch 41, Iteration 800, loss = 0.4931
Checking accuracy on validation set
Got 26344 / 

Got 27730 / 39868 correct (69.55)

Epoch 47, Iteration 900, loss = 0.4598
Checking accuracy on validation set
Got 27479 / 39868 correct (68.92)

Epoch 47, Iteration 1000, loss = 0.6880
Checking accuracy on validation set
Got 27891 / 39868 correct (69.96)

Epoch 48, Iteration 0, loss = 0.4973
Checking accuracy on validation set
Got 28526 / 39868 correct (71.55)

Epoch 48, Iteration 100, loss = 0.5656
Checking accuracy on validation set
Got 27648 / 39868 correct (69.35)

Epoch 48, Iteration 200, loss = 0.4313
Checking accuracy on validation set
Got 27965 / 39868 correct (70.14)

Epoch 48, Iteration 300, loss = 0.4800
Checking accuracy on validation set
Got 27745 / 39868 correct (69.59)

Epoch 48, Iteration 400, loss = 0.4936
Checking accuracy on validation set
Got 27291 / 39868 correct (68.45)

Epoch 48, Iteration 500, loss = 0.4783
Checking accuracy on validation set
Got 27194 / 39868 correct (68.21)

Epoch 48, Iteration 600, loss = 0.3977
Checking accuracy on validation set
Got 27920 /

Got 27707 / 39868 correct (69.50)

Epoch 54, Iteration 700, loss = 0.5387
Checking accuracy on validation set
Got 28047 / 39868 correct (70.35)

Epoch 54, Iteration 800, loss = 0.3796
Checking accuracy on validation set
Got 27111 / 39868 correct (68.00)

Epoch 54, Iteration 900, loss = 0.4308
Checking accuracy on validation set
Got 26875 / 39868 correct (67.41)

Epoch 54, Iteration 1000, loss = 0.5813
Checking accuracy on validation set
Got 27568 / 39868 correct (69.15)

Epoch 55, Iteration 0, loss = 0.3103
Checking accuracy on validation set
Got 27956 / 39868 correct (70.12)

Epoch 55, Iteration 100, loss = 0.4062
Checking accuracy on validation set
Got 27948 / 39868 correct (70.10)

Epoch 55, Iteration 200, loss = 0.5448
Checking accuracy on validation set
Got 27928 / 39868 correct (70.05)

Epoch 55, Iteration 300, loss = 0.4326
Checking accuracy on validation set
Got 27819 / 39868 correct (69.78)

Epoch 55, Iteration 400, loss = 0.4651
Checking accuracy on validation set
Got 28449 /

Got 27871 / 39868 correct (69.91)

Epoch 61, Iteration 500, loss = 0.5492
Checking accuracy on validation set
Got 27505 / 39868 correct (68.99)

Epoch 61, Iteration 600, loss = 0.5320
Checking accuracy on validation set
Got 28135 / 39868 correct (70.57)

Epoch 61, Iteration 700, loss = 0.4680
Checking accuracy on validation set
Got 27286 / 39868 correct (68.44)

Epoch 61, Iteration 800, loss = 0.4308
Checking accuracy on validation set
Got 28144 / 39868 correct (70.59)

Epoch 61, Iteration 900, loss = 0.4893
Checking accuracy on validation set
Got 28023 / 39868 correct (70.29)

Epoch 61, Iteration 1000, loss = 0.4734
Checking accuracy on validation set
Got 28006 / 39868 correct (70.25)

Epoch 62, Iteration 0, loss = 0.4243
Checking accuracy on validation set
Got 28148 / 39868 correct (70.60)

Epoch 62, Iteration 100, loss = 0.3508
Checking accuracy on validation set
Got 27803 / 39868 correct (69.74)

Epoch 62, Iteration 200, loss = 0.4913
Checking accuracy on validation set
Got 28379 /

Got 28710 / 39868 correct (72.01)

Epoch 68, Iteration 300, loss = 0.4055
Checking accuracy on validation set
Got 27601 / 39868 correct (69.23)

Epoch 68, Iteration 400, loss = 0.4309
Checking accuracy on validation set
Got 28680 / 39868 correct (71.94)

Epoch 68, Iteration 500, loss = 0.4507
Checking accuracy on validation set
Got 26049 / 39868 correct (65.34)

Epoch 68, Iteration 600, loss = 0.6069
Checking accuracy on validation set
Got 27966 / 39868 correct (70.15)

Epoch 68, Iteration 700, loss = 0.3844
Checking accuracy on validation set
Got 27250 / 39868 correct (68.35)

Epoch 68, Iteration 800, loss = 0.4974
Checking accuracy on validation set
Got 27709 / 39868 correct (69.50)

Epoch 68, Iteration 900, loss = 0.5849
Checking accuracy on validation set
Got 28328 / 39868 correct (71.05)

Epoch 68, Iteration 1000, loss = 0.4254
Checking accuracy on validation set
Got 27757 / 39868 correct (69.62)

Epoch 69, Iteration 0, loss = 0.3709
Checking accuracy on validation set
Got 28700 /

Got 28515 / 39868 correct (71.52)

Epoch 75, Iteration 100, loss = 0.5210
Checking accuracy on validation set
Got 28774 / 39868 correct (72.17)

Epoch 75, Iteration 200, loss = 0.3704
Checking accuracy on validation set
Got 27687 / 39868 correct (69.45)

Epoch 75, Iteration 300, loss = 0.3096
Checking accuracy on validation set
Got 28851 / 39868 correct (72.37)

Epoch 75, Iteration 400, loss = 0.3969
Checking accuracy on validation set
Got 27884 / 39868 correct (69.94)

Epoch 75, Iteration 500, loss = 0.4496
Checking accuracy on validation set
Got 28265 / 39868 correct (70.90)

Epoch 75, Iteration 600, loss = 0.6005
Checking accuracy on validation set
Got 28752 / 39868 correct (72.12)

Epoch 75, Iteration 700, loss = 0.3574
Checking accuracy on validation set
Got 28521 / 39868 correct (71.54)

Epoch 75, Iteration 800, loss = 0.5025
Checking accuracy on validation set
Got 28663 / 39868 correct (71.89)

Epoch 75, Iteration 900, loss = 0.4651
Checking accuracy on validation set
Got 28186 

Got 28499 / 39868 correct (71.48)

Epoch 81, Iteration 1000, loss = 0.5165
Checking accuracy on validation set
Got 28207 / 39868 correct (70.75)

Epoch 82, Iteration 0, loss = 0.5179
Checking accuracy on validation set
Got 28769 / 39868 correct (72.16)

Epoch 82, Iteration 100, loss = 0.4237
Checking accuracy on validation set
Got 28458 / 39868 correct (71.38)

Epoch 82, Iteration 200, loss = 0.3672
Checking accuracy on validation set
Got 28655 / 39868 correct (71.87)

Epoch 82, Iteration 300, loss = 0.4180
Checking accuracy on validation set
Got 28923 / 39868 correct (72.55)

Epoch 82, Iteration 400, loss = 0.3728
Checking accuracy on validation set
Got 28194 / 39868 correct (70.72)

Epoch 82, Iteration 500, loss = 0.3830
Checking accuracy on validation set
Got 28902 / 39868 correct (72.49)

Epoch 82, Iteration 600, loss = 0.3582
Checking accuracy on validation set
Got 29001 / 39868 correct (72.74)

Epoch 82, Iteration 700, loss = 0.4280
Checking accuracy on validation set
Got 29122 /

Got 28341 / 39868 correct (71.09)

Epoch 88, Iteration 800, loss = 0.5780
Checking accuracy on validation set
Got 28611 / 39868 correct (71.76)

Epoch 88, Iteration 900, loss = 0.3377
Checking accuracy on validation set
Got 28639 / 39868 correct (71.83)

Epoch 88, Iteration 1000, loss = 0.3841
Checking accuracy on validation set
Got 28924 / 39868 correct (72.55)

Epoch 89, Iteration 0, loss = 0.4376
Checking accuracy on validation set
Got 28955 / 39868 correct (72.63)

Epoch 89, Iteration 100, loss = 0.4412
Checking accuracy on validation set
Got 28873 / 39868 correct (72.42)

Epoch 89, Iteration 200, loss = 0.5041
Checking accuracy on validation set
Got 28986 / 39868 correct (72.70)

Epoch 89, Iteration 300, loss = 0.4052
Checking accuracy on validation set
Got 28832 / 39868 correct (72.32)

Epoch 89, Iteration 400, loss = 0.4915
Checking accuracy on validation set
Got 29050 / 39868 correct (72.87)

Epoch 89, Iteration 500, loss = 0.5630
Checking accuracy on validation set
Got 28843 /

Got 28809 / 39868 correct (72.26)

Epoch 95, Iteration 600, loss = 0.3722
Checking accuracy on validation set
Got 28857 / 39868 correct (72.38)

Epoch 95, Iteration 700, loss = 0.4516
Checking accuracy on validation set
Got 28331 / 39868 correct (71.06)

Epoch 95, Iteration 800, loss = 0.4451
Checking accuracy on validation set
Got 29054 / 39868 correct (72.88)

Epoch 95, Iteration 900, loss = 0.3217
Checking accuracy on validation set
Got 29033 / 39868 correct (72.82)

Epoch 95, Iteration 1000, loss = 0.4219
Checking accuracy on validation set
Got 29324 / 39868 correct (73.55)

Epoch 96, Iteration 0, loss = 0.5142
Checking accuracy on validation set
Got 28760 / 39868 correct (72.14)

Epoch 96, Iteration 100, loss = 0.3626
Checking accuracy on validation set
Got 29060 / 39868 correct (72.89)

Epoch 96, Iteration 200, loss = 0.4299
Checking accuracy on validation set
Got 29096 / 39868 correct (72.98)

Epoch 96, Iteration 300, loss = 0.4129
Checking accuracy on validation set
Got 29112 /

In [26]:
def test_model(model, test_data, subj_csv):
    # one-segment test
    print('Testing model accuracy using 1-segment metric')
    loader_test = DataLoader(test_data, batch_size=70)
    check_accuracy(loader_test, model)

    # 40-segment test
    print('Testing model accuracy using 40-segment per subject metric')
    with open(subj_csv, newline='') as csvfile:
        spamreader = csv.reader(csvfile, delimiter=' ', quotechar='|')
        subjIDs = [row[0] for row in spamreader]
    unique_subjs,indices = np.unique(subjIDs,return_index=True)

    iterable_test_data = list(iter(DataLoader(test_data, batch_size=1)))
    num_correct = []
    for subj,idx in zip(unique_subjs,indices):
    #     print(f'Subj {subj} - gender {iterable_test_data[idx][1]}')
        data = iterable_test_data[idx:idx+40]
        #print(np.sum([y for _,y in data]))
        assert 40 == np.sum([y for _,y in data]) or 0 == np.sum([y for _,y in data])
        preds = []
        correct = 0
        with torch.no_grad():
            for x,y in data:
                x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
                correct = y
                scores = model(x)
                _, pred = scores.max(1)
                preds.append(pred)
        final_pred = (torch.mean(torch.FloatTensor(preds)) > 0.5).sum()
        num_correct.append((final_pred == correct).sum())
    #print(len(num_correct))
    acc = float(np.sum(num_correct)) / len(unique_subjs)
    print('Got %d / %d correct (%.2f)' % (np.sum(num_correct), len(unique_subjs), 100 * acc))

In [27]:
# Testing
# Balanced-class test set
print('Testing on balanced test set')
f = h5py.File('child_mind_x_test_v2.mat', 'r')
x_test = f['X_test']
x_test = np.reshape(x_test,(-1,1,24,256))
print('X_test shape: ' + str(x_test.shape))
f = h5py.File('child_mind_y_test_v2.mat', 'r')
y_test = f['Y_test']
print('Y_test shape: ' + str(y_test.shape))
test_data = EEGDataset(x_test, y_test, False)
test_model(model, test_data, 'test_subjIDs.csv')

print()

# All-male test set
print('Testing on all-male test set')
f = h5py.File('child_mind_x_test_v3.mat', 'r')
x_test = f['X_test']
x_test = np.reshape(x_test,(-1,1,24,256))
print('X_test shape: ' + str(x_test.shape))
f = h5py.File('child_mind_y_test_v3.mat', 'r')
y_test = f['Y_test']
print('Y_test shape: ' + str(y_test.shape))
test_data = EEGDataset(x_test, y_test, False)
test_model(model, test_data, 'test_subjIDs_more_test.csv')

Testing on balanced test set
X_test shape: (16006, 1, 24, 256)
Y_test shape: (16006, 1)
Testing model accuracy using 1-segment metric
Checking accuracy on test set
Got 11793 / 16006 correct (73.68)
Testing model accuracy using 40-segment per subject metric
Got 163 / 198 correct (82.32)

Testing on all-male test set
X_test shape: (52377, 1, 24, 256)
Y_test shape: (52377, 1)
Testing model accuracy using 1-segment metric
Checking accuracy on test set
Got 41209 / 52377 correct (78.68)
Testing model accuracy using 40-segment per subject metric
Got 572 / 649 correct (88.14)


In [None]:
# log_format = '%(asctime)s %(message)s'
# logging.basicConfig(stream=sys.stdout, level=logging.INFO,
#                     format=log_format, datefmt='%m/%d %I:%M:%S %p')
# fh = logging.FileHandler(os.path.join('logs', f'log-{datetime.datetime.today()}.txt'))
# fh.setFormatter(logging.Formatter(log_format))
# logging.getLogger().addHandler(fh)

# logging.info('Learning rate: %f, batch_size: %d' % (lr, batch_size))

# r = -4 * (1-0.25*np.random.rand())
# batch_size = 2**np.random.randint(2,7)
#     torch.save(model.state_dict(), f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}')

# state_dict = torch.load()