In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.transforms as T
import numpy as np
import h5py
import logging
import os
import sys
import datetime

In [3]:
class EEGDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each sample in ``x`` with one scalar label.

    Args:
        x: Array-like of samples indexed along its first axis (must expose
            ``.shape``). Stored as-is, not copied.
        y: Array-like of labels with shape ``(N, 1)`` or ``(N,)`` where
            ``N == x.shape[0]``. Read fully into memory at construction.
        train: Flag stored on the dataset; ``check_accuracy`` reads it to pick
            which message it prints.
    """

    def __init__(self, x, y, train):
        # Zero-argument super(): the previous ``super(EEGDataset).__init__()``
        # only created an unbound super object and never reached the base class.
        super().__init__()
        # One bulk read + flatten instead of ``y[i][0]`` per sample, which is
        # very slow when ``y`` is an on-disk (e.g. h5py) dataset.
        labels = np.asarray(y).reshape(-1)
        assert x.shape[0] == labels.size, (
            'x has %d samples but y has %d labels' % (x.shape[0], labels.size))
        self.x = x
        # Kept as a list of NumPy scalars, the same element type the
        # per-element ``y[i][0]`` reads used to produce.
        self.y = list(labels)
        self.train = train

    def __getitem__(self, key):
        """Return the ``(sample, label)`` pair stored at position ``key``."""
        return (self.x[key], self.y[key])

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.y)

In [9]:

# Load EEG data
transform = T.Compose([
    T.ToTensor()
])


def load_eeg_split(split):
    """Read one split's samples and labels from its pair of .mat (HDF5) files.

    Args:
        split: Suffix used in both the file names and the dataset keys:
            'train', 'val' or 'test'.

    Returns:
        Tuple ``(x, y)`` of in-memory NumPy arrays; ``x`` is reshaped to
        ``(-1, 1, 24, 256)`` and ``y`` is returned as stored in the file.
    """
    # The context managers close each file as soon as its contents are in
    # memory; previously all six h5py handles stayed open and ``f`` was
    # simply rebound, leaking the earlier handles.
    with h5py.File(f'child_mind_x_{split}_v2.mat', 'r') as x_file:
        x = np.reshape(x_file[f'X_{split}'][()], (-1, 1, 24, 256))
    print(f'X_{split} shape: ' + str(x.shape))
    with h5py.File(f'child_mind_y_{split}_v2.mat', 'r') as y_file:
        # [()] materialises the labels so they remain usable after the close.
        y = y_file[f'Y_{split}'][()]
    print(f'Y_{split} shape: ' + str(y.shape))
    return x, y


x_train, y_train = load_eeg_split('train')
train_data = EEGDataset(x_train, y_train, True)

x_val, y_val = load_eeg_split('val')
# train=True is deliberate: check_accuracy() prints 'validation set' for any
# dataset whose ``train`` flag is truthy and 'test set' otherwise.
val_data = EEGDataset(x_val, y_val, True)

x_test, y_test = load_eeg_split('test')
test_data = EEGDataset(x_test, y_test, False)
loader_test = DataLoader(test_data, batch_size=70)

X_train shape: (71300, 1, 24, 256)
Y_train shape: (71300, 1)
X_val shape: (39868, 1, 24, 256)
Y_val shape: (39868, 1)
X_test shape: (16006, 1, 24, 256)
Y_test shape: (16006, 1)


In [5]:
# Class-balance check: histogram of the training labels. The printed tuple is
# (counts per bin, bin edges); with no `bins` argument NumPy uses 10 bins.
print(np.histogram(y_train))

(array([35626,     0,     0,     0,     0,     0,     0,     0,     0,
       35674]), array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]))


In [5]:
# Test with MNIST
# NOTE: the whole experiment below is disabled by wrapping it in a string
# literal. Nothing inside it executes; the cell only echoes the string back.
'''
import torchvision.datasets as dset
NUM_TRAIN = 40000
transform = T.Compose([
                T.ToTensor(),
                T.CenterCrop(24),
                T.Pad((116,0))
            ])
mnist_train = dset.MNIST('./mnist', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(mnist_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

mnist_val = dset.MNIST('./mnist', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(mnist_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 60000)))

mnist_test = dset.MNIST('./mnist', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(mnist_test, batch_size=64)
'''

"\nimport torchvision.datasets as dset\nNUM_TRAIN = 40000\ntransform = T.Compose([\n                T.ToTensor(),\n                T.CenterCrop(24),\n                T.Pad((116,0))\n            ])\nmnist_train = dset.MNIST('./mnist', train=True, download=True,\n                             transform=transform)\nloader_train = DataLoader(mnist_train, batch_size=64, \n                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))\n\nmnist_val = dset.MNIST('./mnist', train=True, download=True,\n                           transform=transform)\nloader_val = DataLoader(mnist_val, batch_size=64, \n                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 60000)))\n\nmnist_test = dset.MNIST('./mnist', train=False, download=True, \n                            transform=transform)\nloader_test = DataLoader(mnist_test, batch_size=64)\n"

In [4]:
# Runtime configuration shared by the training and evaluation helpers.
USE_GPU = True

# Floating-point type that input batches are cast to before the forward pass.
dtype = torch.float32

# CUDA is selected only when it is both requested and actually available.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')

# Number of training iterations between progress reports.
print_every = 100

print('using device:', device)

using device: cuda


In [5]:
def check_accuracy(loader, model):
    """Print the classification accuracy of ``model`` over every batch in ``loader``.

    Relies on the notebook-level globals ``device`` and ``dtype``.

    Args:
        loader: Iterable of ``(x, y)`` batches that also exposes
            ``loader.dataset.train``; that flag only selects whether the
            'validation set' or the 'test set' message is printed.
        model: Module returning per-class scores with classes along dim 1.

    Returns:
        None; the result is printed. The model is left in evaluation mode.
    """
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # Accumulated as a tensor so the device is not synchronised on
            # every batch; converted to a Python int once, after the loop.
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
    num_correct = int(num_correct)
    # An empty loader used to raise ZeroDivisionError; report 0% instead.
    acc = num_correct / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [19]:
def train(model, optimizer, epochs=1):
    """
    Train a model on the batches of the global ``loader_train`` using the
    PyTorch Module API.

    Relies on notebook-level globals: ``loader_train``, ``loader_val``,
    ``device``, ``dtype``, ``print_every`` and ``check_accuracy``.

    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for

    Returns: Nothing, but prints the loss and validation accuracy every
    ``print_every`` iterations and overwrites the checkpoint
    'logs/model_saved-temp' every 10th epoch (starting with epoch 0).
    """
    # torch.save() does not create missing directories; without this the
    # first checkpoint (end of epoch 0) fails on a fresh checkout.
    os.makedirs('logs', exist_ok=True)

    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            # Re-enabled on every step because check_accuracy() below leaves
            # the model in evaluation mode.
            model.train()
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))
                check_accuracy(loader_val, model)
                print()
        if e % 10 == 0:
            print('Save model at epoch %d' % (e))
            torch.save(model.state_dict(), 'logs/model_saved-temp')

In [None]:
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')

# The log file and every checkpoint below are written into logs/; create it
# up front so neither FileHandler nor torch.save fails on a fresh checkout.
os.makedirs('logs', exist_ok=True)
log_path = os.path.abspath(os.path.join('logs', 'log-fine-tune-0.000363-32.txt'))
root_logger = logging.getLogger()
# Re-running this cell must not stack another FileHandler on the root logger,
# otherwise every message is written to the file once per re-run.
if not any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path
           for h in root_logger.handlers):
    fh = logging.FileHandler(log_path)
    fh.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(fh)

# Random search: 25 independent runs, each training a freshly initialised model.
for i in range(25):
    # Exponent drawn from (-4, -3], so lr is log-uniform between 1e-4 and 1e-3.
    r = -4 * (1-0.25*np.random.rand())
    lr = 10**r
    #lr = 0.000363
    # Batch size is a power of two between 4 and 64.
    batch_size = 2**np.random.randint(2,7)
    #batch_size = 32
    logging.info('Learning rate: %f, batch_size: %d', lr, batch_size)
    loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_data, batch_size=batch_size)
    # For a (1, 24, 256) input the feature map is 100 x 1 x 19 before Flatten,
    # which is where the 1900 input features of the first Linear come from.
    model = nn.Sequential(
        nn.Conv2d(1,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,300,(2,3)),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(300,300,(1,7)),
        nn.ReLU(),
        nn.MaxPool2d((1,2), stride=1),
        nn.Dropout(0.25),
        nn.Conv2d(300,100,(1,3)),
        nn.Conv2d(100,100,(1,3)),
        nn.Flatten(),
        nn.Linear(1900,6144),
        nn.Linear(6144,2),
    )

    optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
    train(model, optimizer, epochs=25)
    check_accuracy(loader_test, model)
    torch.save(model.state_dict(), f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}')

03/15 10:38:38 PM Learning rate: 0.000363, batch_size: 32
Iteration 0, loss = 1.4855
Checking accuracy on validation set
Got 19910 / 39868 correct (49.94)

Iteration 100, loss = 0.5894
Checking accuracy on validation set
Got 20449 / 39868 correct (51.29)

Iteration 200, loss = 0.7102
Checking accuracy on validation set
Got 20024 / 39868 correct (50.23)

Iteration 300, loss = 0.6788
Checking accuracy on validation set
Got 21538 / 39868 correct (54.02)

Iteration 400, loss = 0.6577
Checking accuracy on validation set
Got 20451 / 39868 correct (51.30)

Iteration 500, loss = 0.6818
Checking accuracy on validation set
Got 19988 / 39868 correct (50.14)

Iteration 600, loss = 0.6616
Checking accuracy on validation set
Got 20574 / 39868 correct (51.61)

Iteration 700, loss = 0.6864
Checking accuracy on validation set
Got 20010 / 39868 correct (50.19)

Iteration 800, loss = 0.6850
Checking accuracy on validation set
Got 20015 / 39868 correct (50.20)

Iteration 900, loss = 0.6655
Checking accura

Got 25494 / 39868 correct (63.95)

Iteration 1300, loss = 0.9916
Checking accuracy on validation set
Got 25686 / 39868 correct (64.43)

Iteration 1400, loss = 0.4929
Checking accuracy on validation set
Got 25596 / 39868 correct (64.20)

Iteration 1500, loss = 0.5201
Checking accuracy on validation set
Got 26064 / 39868 correct (65.38)

Iteration 1600, loss = 0.7018
Checking accuracy on validation set
Got 24938 / 39868 correct (62.55)

Iteration 1700, loss = 0.5997
Checking accuracy on validation set
Got 25108 / 39868 correct (62.98)

Iteration 1800, loss = 0.6806
Checking accuracy on validation set
Got 25523 / 39868 correct (64.02)

Iteration 1900, loss = 0.5341
Checking accuracy on validation set
Got 25893 / 39868 correct (64.95)

Iteration 2000, loss = 0.8539
Checking accuracy on validation set
Got 25967 / 39868 correct (65.13)

Iteration 2100, loss = 0.5384
Checking accuracy on validation set
Got 25892 / 39868 correct (64.94)

Iteration 2200, loss = 0.6412
Checking accuracy on valid

###### Fine tune model
lr = 0.000363
batch_size = 64
* Train on top of 5 epochs of batch_size 32
* Train on top of 20 epochs of batch_size 64

In [21]:
# Rebuild the exact architecture the checkpoint was trained with; the layer
# order must match for load_state_dict() to line up the parameter names.
model = nn.Sequential(
        nn.Conv2d(1,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,300,(2,3)),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(300,300,(1,7)),
        nn.ReLU(),
        nn.MaxPool2d((1,2), stride=1),
        nn.Dropout(0.25),
        nn.Conv2d(300,100,(1,3)),
        nn.Conv2d(100,100,(1,3)),
        nn.Flatten(),
        nn.Linear(1900,6144),
        nn.Linear(6144,2),
    )
model = model.to(device=device)
# map_location remaps the stored tensors onto the active device. Without it a
# checkpoint saved from CUDA cannot be deserialised on a CPU-only machine,
# even though `device` already falls back to CPU in that case.
model.load_state_dict(torch.load('logs/model_saved-lr0.000363-bs64-2021-03-16 00:24:41.969544',
                                 map_location=device))
# Evaluation mode (disables dropout); as the last expression this also makes
# the cell display the module summary.
model.eval()


Sequential(
  (0): Conv2d(1, 100, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Dropout(p=0.25, inplace=False)
  (4): Conv2d(100, 100, kernel_size=(3, 3), stride=(1, 1))
  (5): ReLU()
  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (7): Dropout(p=0.25, inplace=False)
  (8): Conv2d(100, 300, kernel_size=(2, 3), stride=(1, 1))
  (9): ReLU()
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Dropout(p=0.25, inplace=False)
  (12): Conv2d(300, 300, kernel_size=(1, 7), stride=(1, 1))
  (13): ReLU()
  (14): MaxPool2d(kernel_size=(1, 2), stride=1, padding=0, dilation=1, ceil_mode=False)
  (15): Dropout(p=0.25, inplace=False)
  (16): Conv2d(300, 100, kernel_size=(1, 3), stride=(1, 1))
  (17): Conv2d(100, 100, kernel_size=(1, 3), stride=(1, 1))
  (18): Flatten()
  (19): Linear(in_features=1900, out_features=6144, bias=True)
  (20)

In [22]:
# Confirm the state of the reloaded model: report its accuracy on the test
# loader so it can be compared with the figure recorded for this checkpoint.
check_accuracy(loader_test, model)

Checking accuracy on test set
Got 11730 / 16006 correct (73.29)


In [23]:
# Fine-tuning run: continue training the currently loaded model with Adamax.
lr = 0.000363
batch_size = 64

# Training batches are reshuffled every epoch; validation order stays fixed.
loader_train = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(dataset=val_data, batch_size=batch_size)

optimizer = optim.Adamax(params=model.parameters(), lr=lr)
train(model=model, optimizer=optimizer, epochs=100)

# Report the final test accuracy, then persist the weights under a name that
# records the learning rate, batch size and current timestamp.
check_accuracy(loader=loader_test, model=model)
torch.save(
    model.state_dict(),
    f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}',
)

Epoch 0, Iteration 0, loss = 0.3125
Checking accuracy on validation set
Got 20270 / 39868 correct (50.84)

Epoch 0, Iteration 100, loss = 0.3041
Checking accuracy on validation set
Got 29072 / 39868 correct (72.92)

Epoch 0, Iteration 200, loss = 0.3190
Checking accuracy on validation set
Got 29201 / 39868 correct (73.24)

Epoch 0, Iteration 300, loss = 0.4306
Checking accuracy on validation set
Got 29371 / 39868 correct (73.67)

Epoch 0, Iteration 400, loss = 0.3536
Checking accuracy on validation set
Got 29379 / 39868 correct (73.69)

Epoch 0, Iteration 500, loss = 0.2745
Checking accuracy on validation set
Got 29171 / 39868 correct (73.17)

Epoch 0, Iteration 600, loss = 0.2626
Checking accuracy on validation set
Got 29241 / 39868 correct (73.34)

Epoch 0, Iteration 700, loss = 0.4312
Checking accuracy on validation set
Got 28702 / 39868 correct (71.99)

Epoch 0, Iteration 800, loss = 0.2887
Checking accuracy on validation set
Got 29159 / 39868 correct (73.14)

Epoch 0, Iteration 90

Epoch 6, Iteration 300, loss = 0.4802
Checking accuracy on validation set
Got 29271 / 39868 correct (73.42)

Epoch 6, Iteration 400, loss = 0.3985
Checking accuracy on validation set
Got 29358 / 39868 correct (73.64)

Epoch 6, Iteration 500, loss = 0.3977
Checking accuracy on validation set
Got 29215 / 39868 correct (73.28)

Epoch 6, Iteration 600, loss = 0.4127
Checking accuracy on validation set
Got 29261 / 39868 correct (73.39)

Epoch 6, Iteration 700, loss = 0.2165
Checking accuracy on validation set
Got 29256 / 39868 correct (73.38)

Epoch 6, Iteration 800, loss = 0.2912
Checking accuracy on validation set
Got 29422 / 39868 correct (73.80)

Epoch 6, Iteration 900, loss = 0.3081
Checking accuracy on validation set
Got 28812 / 39868 correct (72.27)

Epoch 6, Iteration 1000, loss = 0.3902
Checking accuracy on validation set
Got 29225 / 39868 correct (73.30)

Epoch 6, Iteration 1100, loss = 0.3034
Checking accuracy on validation set
Got 29178 / 39868 correct (73.19)

Epoch 7, Iteratio

Got 29232 / 39868 correct (73.32)

Epoch 12, Iteration 600, loss = 0.3458
Checking accuracy on validation set
Got 29403 / 39868 correct (73.75)

Epoch 12, Iteration 700, loss = 0.4713
Checking accuracy on validation set
Got 29313 / 39868 correct (73.53)

Epoch 12, Iteration 800, loss = 0.3145
Checking accuracy on validation set
Got 29319 / 39868 correct (73.54)

Epoch 12, Iteration 900, loss = 0.5066
Checking accuracy on validation set
Got 29310 / 39868 correct (73.52)

Epoch 12, Iteration 1000, loss = 0.2869
Checking accuracy on validation set
Got 29267 / 39868 correct (73.41)

Epoch 12, Iteration 1100, loss = 0.5002
Checking accuracy on validation set
Got 29524 / 39868 correct (74.05)

Epoch 13, Iteration 0, loss = 0.3728
Checking accuracy on validation set
Got 29239 / 39868 correct (73.34)

Epoch 13, Iteration 100, loss = 0.3893
Checking accuracy on validation set
Got 29445 / 39868 correct (73.86)

Epoch 13, Iteration 200, loss = 0.3193
Checking accuracy on validation set
Got 29254 

Got 29250 / 39868 correct (73.37)

Epoch 18, Iteration 900, loss = 0.2780
Checking accuracy on validation set
Got 29132 / 39868 correct (73.07)

Epoch 18, Iteration 1000, loss = 0.3115
Checking accuracy on validation set
Got 29335 / 39868 correct (73.58)

Epoch 18, Iteration 1100, loss = 0.3334
Checking accuracy on validation set
Got 29412 / 39868 correct (73.77)

Epoch 19, Iteration 0, loss = 0.4193
Checking accuracy on validation set
Got 29054 / 39868 correct (72.88)

Epoch 19, Iteration 100, loss = 0.4486
Checking accuracy on validation set
Got 29247 / 39868 correct (73.36)

Epoch 19, Iteration 200, loss = 0.4078
Checking accuracy on validation set
Got 29446 / 39868 correct (73.86)

Epoch 19, Iteration 300, loss = 0.2796
Checking accuracy on validation set
Got 29443 / 39868 correct (73.85)

Epoch 19, Iteration 400, loss = 0.4273
Checking accuracy on validation set
Got 29410 / 39868 correct (73.77)

Epoch 19, Iteration 500, loss = 0.3325
Checking accuracy on validation set
Got 29126 

Epoch 24, Iteration 1100, loss = 0.3625
Checking accuracy on validation set
Got 29234 / 39868 correct (73.33)

Epoch 25, Iteration 0, loss = 0.1977
Checking accuracy on validation set
Got 29094 / 39868 correct (72.98)

Epoch 25, Iteration 100, loss = 0.2431
Checking accuracy on validation set
Got 29399 / 39868 correct (73.74)

Epoch 25, Iteration 200, loss = 0.3414
Checking accuracy on validation set
Got 29327 / 39868 correct (73.56)

Epoch 25, Iteration 300, loss = 0.2513
Checking accuracy on validation set
Got 29340 / 39868 correct (73.59)

Epoch 25, Iteration 400, loss = 0.3361
Checking accuracy on validation set
Got 29350 / 39868 correct (73.62)

Epoch 25, Iteration 500, loss = 0.3231
Checking accuracy on validation set
Got 29412 / 39868 correct (73.77)

Epoch 25, Iteration 600, loss = 0.2525
Checking accuracy on validation set
Got 29483 / 39868 correct (73.95)

Epoch 25, Iteration 700, loss = 0.3054
Checking accuracy on validation set
Got 29159 / 39868 correct (73.14)

Epoch 25, I

Got 29169 / 39868 correct (73.16)

Epoch 31, Iteration 200, loss = 0.2733
Checking accuracy on validation set
Got 29094 / 39868 correct (72.98)

Epoch 31, Iteration 300, loss = 0.3721
Checking accuracy on validation set
Got 29319 / 39868 correct (73.54)

Epoch 31, Iteration 400, loss = 0.3178
Checking accuracy on validation set
Got 29269 / 39868 correct (73.41)

Epoch 31, Iteration 500, loss = 0.4016
Checking accuracy on validation set
Got 29274 / 39868 correct (73.43)

Epoch 31, Iteration 600, loss = 0.3762
Checking accuracy on validation set
Got 29270 / 39868 correct (73.42)

Epoch 31, Iteration 700, loss = 0.2963
Checking accuracy on validation set
Got 29167 / 39868 correct (73.16)

Epoch 31, Iteration 800, loss = 0.4542
Checking accuracy on validation set
Got 29243 / 39868 correct (73.35)

Epoch 31, Iteration 900, loss = 0.2732
Checking accuracy on validation set
Got 29279 / 39868 correct (73.44)

Epoch 31, Iteration 1000, loss = 0.2222
Checking accuracy on validation set
Got 29256

Got 29210 / 39868 correct (73.27)

Epoch 37, Iteration 500, loss = 0.2775
Checking accuracy on validation set
Got 28976 / 39868 correct (72.68)

Epoch 37, Iteration 600, loss = 0.2183
Checking accuracy on validation set
Got 28939 / 39868 correct (72.59)

Epoch 37, Iteration 700, loss = 0.2929
Checking accuracy on validation set
Got 29146 / 39868 correct (73.11)

Epoch 37, Iteration 800, loss = 0.3149
Checking accuracy on validation set
Got 29136 / 39868 correct (73.08)

Epoch 37, Iteration 900, loss = 0.2832
Checking accuracy on validation set
Got 29207 / 39868 correct (73.26)

Epoch 37, Iteration 1000, loss = 0.2560
Checking accuracy on validation set
Got 29156 / 39868 correct (73.13)

Epoch 37, Iteration 1100, loss = 0.2724
Checking accuracy on validation set
Got 29159 / 39868 correct (73.14)

Epoch 38, Iteration 0, loss = 0.3152
Checking accuracy on validation set
Got 29011 / 39868 correct (72.77)

Epoch 38, Iteration 100, loss = 0.3610
Checking accuracy on validation set
Got 29283 

Epoch 43, Iteration 700, loss = 0.3486
Checking accuracy on validation set
Got 29331 / 39868 correct (73.57)

Epoch 43, Iteration 800, loss = 0.3460
Checking accuracy on validation set
Got 29305 / 39868 correct (73.51)

Epoch 43, Iteration 900, loss = 0.2718
Checking accuracy on validation set
Got 28813 / 39868 correct (72.27)

Epoch 43, Iteration 1000, loss = 0.3599
Checking accuracy on validation set
Got 29190 / 39868 correct (73.22)

Epoch 43, Iteration 1100, loss = 0.2543
Checking accuracy on validation set
Got 28872 / 39868 correct (72.42)

Epoch 44, Iteration 0, loss = 0.2100
Checking accuracy on validation set
Got 29326 / 39868 correct (73.56)

Epoch 44, Iteration 100, loss = 0.2854
Checking accuracy on validation set
Got 28900 / 39868 correct (72.49)

Epoch 44, Iteration 200, loss = 0.2354
Checking accuracy on validation set
Got 29019 / 39868 correct (72.79)

Epoch 44, Iteration 300, loss = 0.2827
Checking accuracy on validation set
Got 28939 / 39868 correct (72.59)

Epoch 44, 

Got 29372 / 39868 correct (73.67)

Epoch 49, Iteration 1000, loss = 0.2621
Checking accuracy on validation set
Got 28955 / 39868 correct (72.63)

Epoch 49, Iteration 1100, loss = 0.2123
Checking accuracy on validation set
Got 29150 / 39868 correct (73.12)

Epoch 50, Iteration 0, loss = 0.3510
Checking accuracy on validation set
Got 28253 / 39868 correct (70.87)

Epoch 50, Iteration 100, loss = 0.3071
Checking accuracy on validation set
Got 29296 / 39868 correct (73.48)

Epoch 50, Iteration 200, loss = 0.2609
Checking accuracy on validation set
Got 29144 / 39868 correct (73.10)

Epoch 50, Iteration 300, loss = 0.2347
Checking accuracy on validation set
Got 28292 / 39868 correct (70.96)

Epoch 50, Iteration 400, loss = 0.3236
Checking accuracy on validation set
Got 28733 / 39868 correct (72.07)

Epoch 50, Iteration 500, loss = 0.2274
Checking accuracy on validation set
Got 29430 / 39868 correct (73.82)

Epoch 50, Iteration 600, loss = 0.3061
Checking accuracy on validation set
Got 29129 

Epoch 56, Iteration 0, loss = 0.3449
Checking accuracy on validation set
Got 29203 / 39868 correct (73.25)

Epoch 56, Iteration 100, loss = 0.2843
Checking accuracy on validation set
Got 28960 / 39868 correct (72.64)

Epoch 56, Iteration 200, loss = 0.1990
Checking accuracy on validation set
Got 28999 / 39868 correct (72.74)

Epoch 56, Iteration 300, loss = 0.3141
Checking accuracy on validation set
Got 29060 / 39868 correct (72.89)

Epoch 56, Iteration 400, loss = 0.1719
Checking accuracy on validation set
Got 28914 / 39868 correct (72.52)

Epoch 56, Iteration 500, loss = 0.3068
Checking accuracy on validation set
Got 29015 / 39868 correct (72.78)

Epoch 56, Iteration 600, loss = 0.2817
Checking accuracy on validation set
Got 29150 / 39868 correct (73.12)

Epoch 56, Iteration 700, loss = 0.2202
Checking accuracy on validation set
Got 29158 / 39868 correct (73.14)

Epoch 56, Iteration 800, loss = 0.2270
Checking accuracy on validation set
Got 28699 / 39868 correct (71.99)

Epoch 56, It

Got 29191 / 39868 correct (73.22)

Epoch 62, Iteration 300, loss = 0.3072
Checking accuracy on validation set
Got 29185 / 39868 correct (73.20)

Epoch 62, Iteration 400, loss = 0.2899
Checking accuracy on validation set
Got 29200 / 39868 correct (73.24)

Epoch 62, Iteration 500, loss = 0.2754
Checking accuracy on validation set
Got 28915 / 39868 correct (72.53)

Epoch 62, Iteration 600, loss = 0.3031
Checking accuracy on validation set
Got 28980 / 39868 correct (72.69)

Epoch 62, Iteration 700, loss = 0.3207
Checking accuracy on validation set
Got 28969 / 39868 correct (72.66)

Epoch 62, Iteration 800, loss = 0.2537
Checking accuracy on validation set
Got 29000 / 39868 correct (72.74)

Epoch 62, Iteration 900, loss = 0.2450
Checking accuracy on validation set
Got 29121 / 39868 correct (73.04)

Epoch 62, Iteration 1000, loss = 0.2520
Checking accuracy on validation set
Got 28840 / 39868 correct (72.34)

Epoch 62, Iteration 1100, loss = 0.1695
Checking accuracy on validation set
Got 2925

Got 28948 / 39868 correct (72.61)

Epoch 68, Iteration 600, loss = 0.1955
Checking accuracy on validation set
Got 29090 / 39868 correct (72.97)

Epoch 68, Iteration 700, loss = 0.2452
Checking accuracy on validation set
Got 28989 / 39868 correct (72.71)

Epoch 68, Iteration 800, loss = 0.2450
Checking accuracy on validation set
Got 28460 / 39868 correct (71.39)

Epoch 68, Iteration 900, loss = 0.2062
Checking accuracy on validation set
Got 28988 / 39868 correct (72.71)

Epoch 68, Iteration 1000, loss = 0.2316
Checking accuracy on validation set
Got 28869 / 39868 correct (72.41)

Epoch 68, Iteration 1100, loss = 0.2260
Checking accuracy on validation set
Got 28986 / 39868 correct (72.70)

Epoch 69, Iteration 0, loss = 0.1224
Checking accuracy on validation set
Got 29025 / 39868 correct (72.80)

Epoch 69, Iteration 100, loss = 0.2410
Checking accuracy on validation set
Got 28640 / 39868 correct (71.84)

Epoch 69, Iteration 200, loss = 0.2810
Checking accuracy on validation set
Got 29230 

Epoch 74, Iteration 800, loss = 0.2900
Checking accuracy on validation set
Got 28701 / 39868 correct (71.99)

Epoch 74, Iteration 900, loss = 0.2060
Checking accuracy on validation set
Got 29050 / 39868 correct (72.87)

Epoch 74, Iteration 1000, loss = 0.2494
Checking accuracy on validation set
Got 28368 / 39868 correct (71.15)

Epoch 74, Iteration 1100, loss = 0.1942
Checking accuracy on validation set
Got 28428 / 39868 correct (71.31)

Epoch 75, Iteration 0, loss = 0.1602
Checking accuracy on validation set
Got 28781 / 39868 correct (72.19)

Epoch 75, Iteration 100, loss = 0.2055
Checking accuracy on validation set
Got 29083 / 39868 correct (72.95)

Epoch 75, Iteration 200, loss = 0.3433
Checking accuracy on validation set
Got 28949 / 39868 correct (72.61)

Epoch 75, Iteration 300, loss = 0.2192
Checking accuracy on validation set
Got 28923 / 39868 correct (72.55)

Epoch 75, Iteration 400, loss = 0.3156
Checking accuracy on validation set
Got 28929 / 39868 correct (72.56)

Epoch 75, 

Got 28916 / 39868 correct (72.53)

Epoch 80, Iteration 1100, loss = 0.2333
Checking accuracy on validation set
Got 28650 / 39868 correct (71.86)

Save model at epoch 80
Epoch 81, Iteration 0, loss = 0.2881
Checking accuracy on validation set
Got 28836 / 39868 correct (72.33)

Epoch 81, Iteration 100, loss = 0.2748
Checking accuracy on validation set
Got 28936 / 39868 correct (72.58)

Epoch 81, Iteration 200, loss = 0.2586
Checking accuracy on validation set
Got 28978 / 39868 correct (72.68)

Epoch 81, Iteration 300, loss = 0.2368
Checking accuracy on validation set
Got 28576 / 39868 correct (71.68)

Epoch 81, Iteration 400, loss = 0.1599
Checking accuracy on validation set
Got 28943 / 39868 correct (72.60)

Epoch 81, Iteration 500, loss = 0.2555
Checking accuracy on validation set
Got 29000 / 39868 correct (72.74)

Epoch 81, Iteration 600, loss = 0.1500
Checking accuracy on validation set
Got 28897 / 39868 correct (72.48)

Epoch 81, Iteration 700, loss = 0.2282
Checking accuracy on val

Epoch 87, Iteration 100, loss = 0.2285
Checking accuracy on validation set
Got 28862 / 39868 correct (72.39)

Epoch 87, Iteration 200, loss = 0.2182
Checking accuracy on validation set
Got 28835 / 39868 correct (72.33)

Epoch 87, Iteration 300, loss = 0.1568
Checking accuracy on validation set
Got 28864 / 39868 correct (72.40)

Epoch 87, Iteration 400, loss = 0.2818
Checking accuracy on validation set
Got 28972 / 39868 correct (72.67)

Epoch 87, Iteration 500, loss = 0.1400
Checking accuracy on validation set
Got 28751 / 39868 correct (72.12)

Epoch 87, Iteration 600, loss = 0.3129
Checking accuracy on validation set
Got 28543 / 39868 correct (71.59)

Epoch 87, Iteration 700, loss = 0.1885
Checking accuracy on validation set
Got 28138 / 39868 correct (70.58)

Epoch 87, Iteration 800, loss = 0.2255
Checking accuracy on validation set
Got 28939 / 39868 correct (72.59)

Epoch 87, Iteration 900, loss = 0.1852
Checking accuracy on validation set
Got 28777 / 39868 correct (72.18)

Epoch 87, 

Got 28807 / 39868 correct (72.26)

Epoch 93, Iteration 400, loss = 0.2215
Checking accuracy on validation set
Got 28685 / 39868 correct (71.95)

Epoch 93, Iteration 500, loss = 0.1368
Checking accuracy on validation set
Got 28940 / 39868 correct (72.59)

Epoch 93, Iteration 600, loss = 0.2724
Checking accuracy on validation set
Got 28682 / 39868 correct (71.94)

Epoch 93, Iteration 700, loss = 0.1747
Checking accuracy on validation set
Got 28874 / 39868 correct (72.42)

Epoch 93, Iteration 800, loss = 0.1943
Checking accuracy on validation set
Got 28854 / 39868 correct (72.37)

Epoch 93, Iteration 900, loss = 0.3321
Checking accuracy on validation set
Got 28745 / 39868 correct (72.10)

Epoch 93, Iteration 1000, loss = 0.2240
Checking accuracy on validation set
Got 28472 / 39868 correct (71.42)

Epoch 93, Iteration 1100, loss = 0.3036
Checking accuracy on validation set
Got 28791 / 39868 correct (72.22)

Epoch 94, Iteration 0, loss = 0.3846
Checking accuracy on validation set
Got 28880 

Got 28812 / 39868 correct (72.27)

Epoch 99, Iteration 700, loss = 0.2445
Checking accuracy on validation set
Got 28761 / 39868 correct (72.14)

Epoch 99, Iteration 800, loss = 0.1224
Checking accuracy on validation set
Got 28715 / 39868 correct (72.03)

Epoch 99, Iteration 900, loss = 0.1105
Checking accuracy on validation set
Got 28611 / 39868 correct (71.76)

Epoch 99, Iteration 1000, loss = 0.1933
Checking accuracy on validation set
Got 28660 / 39868 correct (71.89)

Epoch 99, Iteration 1100, loss = 0.3460
Checking accuracy on validation set
Got 28452 / 39868 correct (71.37)

Checking accuracy on test set
Got 11556 / 16006 correct (72.20)


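The model no longer seems to improve with further training. Why? Over the 100 fine-tuning epochs the validation accuracy drifts from about 73.7% down to about 71.4% (test: 73.29% → 72.20%) while the training loss keeps falling, which looks like overfitting.
Try a lower learning rate (?)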

In [None]:
# Restart from the same saved checkpoint and fine-tune with a 10x smaller
# learning rate. The architecture must match the one the checkpoint was
# trained with so load_state_dict() can line up the parameter names.
model = nn.Sequential(
        nn.Conv2d(1,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,300,(2,3)),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(300,300,(1,7)),
        nn.ReLU(),
        nn.MaxPool2d((1,2), stride=1),
        nn.Dropout(0.25),
        nn.Conv2d(300,100,(1,3)),
        nn.Conv2d(100,100,(1,3)),
        nn.Flatten(),
        nn.Linear(1900,6144),
        nn.Linear(6144,2),
    )
model = model.to(device=device)
# map_location remaps the stored tensors onto the active device. Without it a
# checkpoint saved from CUDA cannot be deserialised on a CPU-only machine,
# even though `device` already falls back to CPU in that case.
model.load_state_dict(torch.load('logs/model_saved-lr0.000363-bs64-2021-03-16 00:24:41.969544',
                                 map_location=device))
model.eval()
# confirm state of model
check_accuracy(loader_test, model) # should be 73.29
lr = 0.0000363
batch_size = 64
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=batch_size)
optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
train(model, optimizer, epochs=20)
check_accuracy(loader_test, model)
torch.save(model.state_dict(), f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}')