In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.transforms as T
import numpy as np
import h5py
import logging
import os
import sys
import datetime

In [2]:
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, x, y, train):
        super(EEGDataset).__init__()
        assert x.shape[0] == y.size
        self.x = x
        #temp_y = np.zeros((y.size, 2))
        #for i in range(y.size):
        #    temp_y[i, y[i]] = 1
        #self.y = temp_y
        self.y = [y[i][0] for i in range(y.size)]
        self.train = train

    def __getitem__(self,key):
        return (self.x[key], self.y[key])

    def __len__(self):
        return len(self.y)

In [3]:

# Load EEG data
transform = T.Compose([
    T.ToTensor()
])
f = h5py.File('child_mind_x_train_v2.mat', 'r')
x_train = f['X_train']
x_train = np.reshape(x_train,(-1,1,24,256))
print('X_train shape: ' + str(x_train.shape))
f = h5py.File('child_mind_y_train_v2.mat', 'r')
y_train = f['Y_train']
print('Y_train shape: ' + str(y_train.shape))
train_data = EEGDataset(x_train, y_train, True)


f = h5py.File('child_mind_x_val_v2.mat', 'r')
x_val = f['X_val']
x_val = np.reshape(x_val,(-1,1,24,256))
print('X_val shape: ' + str(x_val.shape))
f = h5py.File('child_mind_y_val_v2.mat', 'r')
y_val = f['Y_val']
print('Y_val shape: ' + str(y_val.shape))
val_data = EEGDataset(x_val, y_val, True)

f = h5py.File('child_mind_x_test_v2.mat', 'r')
x_test = f['X_test']
x_test = np.reshape(x_test,(-1,1,24,256))
print('X_test shape: ' + str(x_test.shape))
f = h5py.File('child_mind_y_test_v2.mat', 'r')
y_test = f['Y_test']
print('Y_test shape: ' + str(y_test.shape))
test_data = EEGDataset(x_test, y_test, False)
loader_test = DataLoader(test_data, batch_size=70)

X_train shape: (71300, 1, 24, 256)
Y_train shape: (71300, 1)
X_val shape: (39868, 1, 24, 256)
Y_val shape: (39868, 1)
X_test shape: (16006, 1, 24, 256)
Y_test shape: (16006, 1)


In [4]:
print(np.histogram(y_train))

(array([35626,     0,     0,     0,     0,     0,     0,     0,     0,
       35674]), array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]))


In [5]:
# Test with MNIST
'''
import torchvision.datasets as dset
NUM_TRAIN = 40000
transform = T.Compose([
                T.ToTensor(),
                T.CenterCrop(24),
                T.Pad((116,0))
            ])
mnist_train = dset.MNIST('./mnist', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(mnist_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

mnist_val = dset.MNIST('./mnist', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(mnist_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 60000)))

mnist_test = dset.MNIST('./mnist', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(mnist_test, batch_size=64)
'''

"\nimport torchvision.datasets as dset\nNUM_TRAIN = 40000\ntransform = T.Compose([\n                T.ToTensor(),\n                T.CenterCrop(24),\n                T.Pad((116,0))\n            ])\nmnist_train = dset.MNIST('./mnist', train=True, download=True,\n                             transform=transform)\nloader_train = DataLoader(mnist_train, batch_size=64, \n                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))\n\nmnist_val = dset.MNIST('./mnist', train=True, download=True,\n                           transform=transform)\nloader_val = DataLoader(mnist_val, batch_size=64, \n                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 60000)))\n\nmnist_test = dset.MNIST('./mnist', train=False, download=True, \n                            transform=transform)\nloader_test = DataLoader(mnist_test, batch_size=64)\n"

In [6]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

using device: cuda


In [7]:
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')   
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [8]:
def train(model, optimizer, epochs=1):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each  parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Iteration %d, loss = %.4f' % (t, loss.item()))
                check_accuracy(loader_val, model)
                print()

In [None]:
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join('logs', f'log-{datetime.datetime.today()}.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

for i in range(25):
    r = -4 * (1-0.25*np.random.rand())
    lr = 10**r
    #lr = 0.000362
    batch_size = 2**np.random.randint(2,7)
    #batch_size = 4
    logging.info('Learning rate: %f, batch_size: %d' % (lr, batch_size))
    loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_data, batch_size=batch_size)
    model = nn.Sequential(
        nn.Conv2d(1,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,100,3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(100,300,(2,3)),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Dropout(0.25),
        nn.Conv2d(300,300,(1,7)),
        nn.ReLU(),
        nn.MaxPool2d((1,2), stride=1),
        nn.Dropout(0.25),
        nn.Conv2d(300,100,(1,3)),
        nn.Conv2d(100,100,(1,3)),
        nn.Flatten(),
        nn.Linear(1900,6144),
        nn.Linear(6144,2),
    )

    # pred = model(next(iter(loader_train))[0])
    #%%
    # print(pred.shape)
    #%%
    optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
    train(model, optimizer, epochs=5)
    check_accuracy(loader_test, model)
    torch.save(model.state_dict(), f'logs/model_saved-lr{lr}-bs{batch_size}-{datetime.datetime.today()}')

03/15 10:38:38 PM Learning rate: 0.000363, batch_size: 32
Iteration 0, loss = 1.4855
Checking accuracy on validation set
Got 19910 / 39868 correct (49.94)

Iteration 100, loss = 0.5894
Checking accuracy on validation set
Got 20449 / 39868 correct (51.29)

Iteration 200, loss = 0.7102
Checking accuracy on validation set
Got 20024 / 39868 correct (50.23)

Iteration 300, loss = 0.6788
Checking accuracy on validation set
Got 21538 / 39868 correct (54.02)

Iteration 400, loss = 0.6577
Checking accuracy on validation set
Got 20451 / 39868 correct (51.30)

Iteration 500, loss = 0.6818
Checking accuracy on validation set
Got 19988 / 39868 correct (50.14)

Iteration 600, loss = 0.6616
Checking accuracy on validation set
Got 20574 / 39868 correct (51.61)

Iteration 700, loss = 0.6864
Checking accuracy on validation set
Got 20010 / 39868 correct (50.19)

Iteration 800, loss = 0.6850
Checking accuracy on validation set
Got 20015 / 39868 correct (50.20)

Iteration 900, loss = 0.6655
Checking accura

Got 25494 / 39868 correct (63.95)

Iteration 1300, loss = 0.9916
Checking accuracy on validation set
Got 25686 / 39868 correct (64.43)

Iteration 1400, loss = 0.4929
Checking accuracy on validation set
Got 25596 / 39868 correct (64.20)

Iteration 1500, loss = 0.5201
Checking accuracy on validation set
Got 26064 / 39868 correct (65.38)

Iteration 1600, loss = 0.7018
Checking accuracy on validation set
Got 24938 / 39868 correct (62.55)

Iteration 1700, loss = 0.5997
Checking accuracy on validation set
Got 25108 / 39868 correct (62.98)

Iteration 1800, loss = 0.6806
Checking accuracy on validation set
Got 25523 / 39868 correct (64.02)

Iteration 1900, loss = 0.5341
Checking accuracy on validation set
Got 25893 / 39868 correct (64.95)

Iteration 2000, loss = 0.8539
Checking accuracy on validation set
Got 25967 / 39868 correct (65.13)

Iteration 2100, loss = 0.5384
Checking accuracy on validation set
Got 25892 / 39868 correct (64.94)

Iteration 2200, loss = 0.6412
Checking accuracy on valid

Iteration 4700, loss = 0.6520
Checking accuracy on validation set
Got 20355 / 39868 correct (51.06)

Iteration 4800, loss = 0.6679
Checking accuracy on validation set
Got 20686 / 39868 correct (51.89)

Iteration 4900, loss = 2.7602
Checking accuracy on validation set
Got 20953 / 39868 correct (52.56)

Iteration 5000, loss = 0.6753
Checking accuracy on validation set
Got 20302 / 39868 correct (50.92)

Iteration 5100, loss = 0.5653
Checking accuracy on validation set
Got 21750 / 39868 correct (54.56)

Iteration 5200, loss = 0.7132
Checking accuracy on validation set
Got 21320 / 39868 correct (53.48)

Iteration 5300, loss = 0.7013
Checking accuracy on validation set
Got 20849 / 39868 correct (52.30)

Iteration 5400, loss = 0.7353
Checking accuracy on validation set
Got 21550 / 39868 correct (54.05)

Iteration 5500, loss = 0.5975
Checking accuracy on validation set
Got 21901 / 39868 correct (54.93)

Iteration 5600, loss = 0.9917
Checking accuracy on validation set
Got 20899 / 39868 correct

Iteration 12800, loss = 0.7058
Checking accuracy on validation set
Got 20029 / 39868 correct (50.24)

Iteration 12900, loss = 0.7974
Checking accuracy on validation set
Got 20154 / 39868 correct (50.55)

Iteration 13000, loss = 0.7547
Checking accuracy on validation set
Got 20308 / 39868 correct (50.94)

Iteration 13100, loss = 0.7214
Checking accuracy on validation set
Got 20219 / 39868 correct (50.71)

Iteration 13200, loss = 0.6922
Checking accuracy on validation set
Got 20492 / 39868 correct (51.40)

Iteration 13300, loss = 0.6707
Checking accuracy on validation set
Got 20108 / 39868 correct (50.44)

Iteration 13400, loss = 0.7098
Checking accuracy on validation set
Got 20240 / 39868 correct (50.77)

Iteration 13500, loss = 0.7680
Checking accuracy on validation set
Got 20619 / 39868 correct (51.72)

Iteration 13600, loss = 1.0220
Checking accuracy on validation set
Got 22161 / 39868 correct (55.59)

Iteration 13700, loss = 0.6901
Checking accuracy on validation set
Got 23316 / 398

Iteration 3000, loss = 0.7403
Checking accuracy on validation set
Got 22399 / 39868 correct (56.18)

Iteration 3100, loss = 0.7490
Checking accuracy on validation set
Got 20342 / 39868 correct (51.02)

Iteration 3200, loss = 0.6718
Checking accuracy on validation set
Got 19998 / 39868 correct (50.16)

Iteration 3300, loss = 0.6579
Checking accuracy on validation set
Got 20370 / 39868 correct (51.09)

Iteration 3400, loss = 0.7693
Checking accuracy on validation set
Got 20927 / 39868 correct (52.49)

Iteration 3500, loss = 0.7871
Checking accuracy on validation set
Got 21309 / 39868 correct (53.45)

Iteration 3600, loss = 0.8174
Checking accuracy on validation set
Got 20667 / 39868 correct (51.84)

Iteration 3700, loss = 0.7599
Checking accuracy on validation set
Got 20054 / 39868 correct (50.30)

Iteration 3800, loss = 0.5826
Checking accuracy on validation set
Got 20195 / 39868 correct (50.65)

Iteration 3900, loss = 0.6182
Checking accuracy on validation set
Got 20282 / 39868 correct

Got 23674 / 39868 correct (59.38)

Iteration 11200, loss = 0.4340
Checking accuracy on validation set
Got 21881 / 39868 correct (54.88)

Iteration 11300, loss = 0.6711
Checking accuracy on validation set
Got 21695 / 39868 correct (54.42)

Iteration 11400, loss = 0.7846
Checking accuracy on validation set
Got 22701 / 39868 correct (56.94)

Iteration 11500, loss = 0.6917
Checking accuracy on validation set
Got 21798 / 39868 correct (54.68)

Iteration 11600, loss = 0.7319
Checking accuracy on validation set
Got 21991 / 39868 correct (55.16)

Iteration 11700, loss = 0.4602
Checking accuracy on validation set
Got 21929 / 39868 correct (55.00)

Iteration 11800, loss = 0.5700
Checking accuracy on validation set
Got 22672 / 39868 correct (56.87)

Iteration 11900, loss = 0.9042
Checking accuracy on validation set
Got 21600 / 39868 correct (54.18)

Iteration 12000, loss = 0.5914
Checking accuracy on validation set
Got 22224 / 39868 correct (55.74)

Iteration 12100, loss = 0.6114
Checking accurac

Got 23838 / 39868 correct (59.79)

Iteration 1400, loss = 0.4077
Checking accuracy on validation set
Got 23949 / 39868 correct (60.07)

Iteration 1500, loss = 0.2473
Checking accuracy on validation set
Got 23569 / 39868 correct (59.12)

Iteration 1600, loss = 0.3703
Checking accuracy on validation set
Got 22153 / 39868 correct (55.57)

Iteration 1700, loss = 0.4680
Checking accuracy on validation set
Got 23684 / 39868 correct (59.41)

Iteration 1800, loss = 0.3816
Checking accuracy on validation set
Got 21971 / 39868 correct (55.11)

Iteration 1900, loss = 0.4448
Checking accuracy on validation set
Got 24243 / 39868 correct (60.81)

Iteration 2000, loss = 0.1891
Checking accuracy on validation set
Got 23789 / 39868 correct (59.67)

Iteration 2100, loss = 0.4081
Checking accuracy on validation set
Got 23621 / 39868 correct (59.25)

Iteration 2200, loss = 0.4057
Checking accuracy on validation set
Got 23442 / 39868 correct (58.80)

Iteration 2300, loss = 1.0217
Checking accuracy on valid

Iteration 9500, loss = 0.5856
Checking accuracy on validation set
Got 25435 / 39868 correct (63.80)

Iteration 9600, loss = 0.4250
Checking accuracy on validation set
Got 24580 / 39868 correct (61.65)

Iteration 9700, loss = 0.5784
Checking accuracy on validation set
Got 25641 / 39868 correct (64.31)

Iteration 9800, loss = 0.2313
Checking accuracy on validation set
Got 25742 / 39868 correct (64.57)

Iteration 9900, loss = 0.5119
Checking accuracy on validation set
Got 25549 / 39868 correct (64.08)

Iteration 10000, loss = 0.5260
Checking accuracy on validation set
Got 25325 / 39868 correct (63.52)

Iteration 10100, loss = 0.3989
Checking accuracy on validation set
Got 24370 / 39868 correct (61.13)

Iteration 10200, loss = 0.7484
Checking accuracy on validation set
Got 25052 / 39868 correct (62.84)

Iteration 10300, loss = 0.5463
Checking accuracy on validation set
Got 24722 / 39868 correct (62.01)

Iteration 10400, loss = 0.9075
Checking accuracy on validation set
Got 24953 / 39868 co

Got 25686 / 39868 correct (64.43)

Iteration 17600, loss = 0.2805
Checking accuracy on validation set
Got 25584 / 39868 correct (64.17)

Iteration 17700, loss = 0.4931
Checking accuracy on validation set
Got 25129 / 39868 correct (63.03)

Iteration 17800, loss = 0.7044
Checking accuracy on validation set
Got 25896 / 39868 correct (64.95)

Iteration 0, loss = 0.8913
Checking accuracy on validation set
Got 25093 / 39868 correct (62.94)

Iteration 100, loss = 0.4997
Checking accuracy on validation set
Got 25755 / 39868 correct (64.60)

Iteration 200, loss = 0.4955
Checking accuracy on validation set
Got 25723 / 39868 correct (64.52)

Iteration 300, loss = 0.2903
Checking accuracy on validation set
Got 25695 / 39868 correct (64.45)

Iteration 400, loss = 0.6092
Checking accuracy on validation set
Got 25774 / 39868 correct (64.65)

Iteration 500, loss = 0.7131
Checking accuracy on validation set
Got 25555 / 39868 correct (64.10)

Iteration 600, loss = 0.4875
Checking accuracy on validation 

Iteration 7800, loss = 0.4535
Checking accuracy on validation set
Got 25946 / 39868 correct (65.08)

Iteration 7900, loss = 0.9527
Checking accuracy on validation set
Got 25054 / 39868 correct (62.84)

Iteration 8000, loss = 1.3499
Checking accuracy on validation set
Got 25333 / 39868 correct (63.54)

Iteration 8100, loss = 0.5156
Checking accuracy on validation set
Got 25362 / 39868 correct (63.61)

Iteration 8200, loss = 0.5984
Checking accuracy on validation set
Got 25454 / 39868 correct (63.85)

Iteration 8300, loss = 0.3385
Checking accuracy on validation set
Got 24702 / 39868 correct (61.96)

Iteration 8400, loss = 0.3217
Checking accuracy on validation set
Got 25851 / 39868 correct (64.84)

Iteration 8500, loss = 0.5608
Checking accuracy on validation set
Got 25410 / 39868 correct (63.74)

Iteration 8600, loss = 0.2973
Checking accuracy on validation set
Got 25452 / 39868 correct (63.84)

Iteration 8700, loss = 0.0888
Checking accuracy on validation set
Got 25714 / 39868 correct

Got 26349 / 39868 correct (66.09)

Iteration 15900, loss = 0.6889
Checking accuracy on validation set
Got 25687 / 39868 correct (64.43)

Iteration 16000, loss = 0.6466
Checking accuracy on validation set
Got 26194 / 39868 correct (65.70)

Iteration 16100, loss = 0.4233
Checking accuracy on validation set
Got 26566 / 39868 correct (66.63)

Iteration 16200, loss = 0.3813
Checking accuracy on validation set
Got 26088 / 39868 correct (65.44)

Iteration 16300, loss = 0.4226
Checking accuracy on validation set
Got 25691 / 39868 correct (64.44)

Iteration 16400, loss = 0.4764
Checking accuracy on validation set
Got 26022 / 39868 correct (65.27)

Iteration 16500, loss = 0.6578
Checking accuracy on validation set
Got 25674 / 39868 correct (64.40)

Iteration 16600, loss = 0.2579
Checking accuracy on validation set
Got 26179 / 39868 correct (65.66)

Iteration 16700, loss = 0.3992
Checking accuracy on validation set
Got 26263 / 39868 correct (65.87)

Iteration 16800, loss = 0.7278
Checking accurac

Iteration 6600, loss = 0.6978
Checking accuracy on validation set
Got 26247 / 39868 correct (65.83)

Iteration 6700, loss = 0.4289
Checking accuracy on validation set
Got 26652 / 39868 correct (66.85)

Iteration 6800, loss = 0.7422
Checking accuracy on validation set
Got 26448 / 39868 correct (66.34)

Iteration 6900, loss = 0.5266
Checking accuracy on validation set
Got 26412 / 39868 correct (66.25)

Iteration 7000, loss = 0.3225
Checking accuracy on validation set
Got 26548 / 39868 correct (66.59)

Iteration 7100, loss = 0.3972
Checking accuracy on validation set
Got 26611 / 39868 correct (66.75)

Iteration 7200, loss = 0.4784
Checking accuracy on validation set
Got 26390 / 39868 correct (66.19)

Iteration 7300, loss = 0.4147
Checking accuracy on validation set
Got 26469 / 39868 correct (66.39)

Iteration 7400, loss = 0.4228
Checking accuracy on validation set
Got 26163 / 39868 correct (65.62)

Iteration 7500, loss = 0.6697
Checking accuracy on validation set
Got 26618 / 39868 correct

Got 26423 / 39868 correct (66.28)

Iteration 14700, loss = 0.3926
Checking accuracy on validation set
Got 25821 / 39868 correct (64.77)

Iteration 14800, loss = 0.4718
Checking accuracy on validation set
Got 26661 / 39868 correct (66.87)

Iteration 14900, loss = 0.5585
Checking accuracy on validation set
Got 26911 / 39868 correct (67.50)

Iteration 15000, loss = 0.2609
Checking accuracy on validation set
Got 26847 / 39868 correct (67.34)

Iteration 15100, loss = 0.3921
Checking accuracy on validation set
Got 26572 / 39868 correct (66.65)

Iteration 15200, loss = 0.7567
Checking accuracy on validation set
Got 26771 / 39868 correct (67.15)

Iteration 15300, loss = 0.4177
Checking accuracy on validation set
Got 26599 / 39868 correct (66.72)

Iteration 15400, loss = 0.6314
Checking accuracy on validation set
Got 26528 / 39868 correct (66.54)

Iteration 15500, loss = 0.5457
Checking accuracy on validation set
Got 26818 / 39868 correct (67.27)

Iteration 15600, loss = 0.3639
Checking accurac

Got 20281 / 39868 correct (50.87)

Iteration 4800, loss = 0.6624
Checking accuracy on validation set
Got 20907 / 39868 correct (52.44)

Iteration 4900, loss = 0.6365
Checking accuracy on validation set
Got 20565 / 39868 correct (51.58)

Iteration 5000, loss = 0.7051
Checking accuracy on validation set
Got 20029 / 39868 correct (50.24)

Iteration 5100, loss = 0.6618
Checking accuracy on validation set
Got 20167 / 39868 correct (50.58)

Iteration 5200, loss = 0.9230
Checking accuracy on validation set
Got 20059 / 39868 correct (50.31)

Iteration 5300, loss = 0.6436
Checking accuracy on validation set
Got 20176 / 39868 correct (50.61)

Iteration 5400, loss = 0.6952
Checking accuracy on validation set
Got 20230 / 39868 correct (50.74)

Iteration 5500, loss = 0.7705
Checking accuracy on validation set
Got 21109 / 39868 correct (52.95)

Iteration 5600, loss = 0.6011
Checking accuracy on validation set
Got 20645 / 39868 correct (51.78)

Iteration 5700, loss = 0.6756
Checking accuracy on valid

Iteration 3900, loss = 0.6306
Checking accuracy on validation set
Got 20167 / 39868 correct (50.58)

Iteration 4000, loss = 0.8375
Checking accuracy on validation set
Got 20235 / 39868 correct (50.75)

Iteration 4100, loss = 0.6043
Checking accuracy on validation set
Got 20016 / 39868 correct (50.21)

Iteration 4200, loss = 0.6985
Checking accuracy on validation set
Got 20340 / 39868 correct (51.02)

Iteration 4300, loss = 0.6673
Checking accuracy on validation set
Got 20597 / 39868 correct (51.66)

Iteration 4400, loss = 0.6749
Checking accuracy on validation set
Got 20770 / 39868 correct (52.10)

Iteration 4500, loss = 0.6747
Checking accuracy on validation set
Got 20240 / 39868 correct (50.77)

Iteration 4600, loss = 0.6619
Checking accuracy on validation set
Got 20626 / 39868 correct (51.74)

Iteration 4700, loss = 0.6913
Checking accuracy on validation set
Got 20867 / 39868 correct (52.34)

Iteration 4800, loss = 0.6918
Checking accuracy on validation set
Got 20567 / 39868 correct

Got 23350 / 39868 correct (58.57)

Iteration 3100, loss = 0.7801
Checking accuracy on validation set
Got 22950 / 39868 correct (57.56)

Iteration 3200, loss = 0.5043
Checking accuracy on validation set
Got 22901 / 39868 correct (57.44)

Iteration 3300, loss = 0.3085
Checking accuracy on validation set
Got 22839 / 39868 correct (57.29)

Iteration 3400, loss = 0.5103
Checking accuracy on validation set
Got 22916 / 39868 correct (57.48)

Iteration 3500, loss = 0.5937
Checking accuracy on validation set


In [12]:
state_dict = torch.load()

Iteration 0, loss = 1.1569
Checking accuracy on validation set
Got 19769 / 39868 correct (49.59)

Iteration 100, loss = 0.8040
Checking accuracy on validation set
Got 19896 / 39868 correct (49.90)

Iteration 200, loss = 0.6957
Checking accuracy on validation set
Got 19963 / 39868 correct (50.07)

Iteration 300, loss = 0.7002
Checking accuracy on validation set
Got 19961 / 39868 correct (50.07)

Iteration 400, loss = 0.9732
Checking accuracy on validation set
Got 19910 / 39868 correct (49.94)

Iteration 500, loss = 0.6898
Checking accuracy on validation set
Got 19958 / 39868 correct (50.06)

Iteration 600, loss = 0.6917
Checking accuracy on validation set
Got 19958 / 39868 correct (50.06)

Iteration 700, loss = 0.7093
Checking accuracy on validation set
Got 19958 / 39868 correct (50.06)

Iteration 800, loss = 0.6994
Checking accuracy on validation set
Got 19911 / 39868 correct (49.94)

Iteration 900, loss = 0.7499
Checking accuracy on validation set
Got 19911 / 39868 correct (49.94)

It

In [13]:
best_model = model
check_accuracy(loader_test, best_model)

Checking accuracy on test set
Got 8020 / 16006 correct (50.11)


In [115]:
def count_parameters_in_MB(model):
  return np.sum([v.size() for name, v in model.named_parameters() if "auxiliary" not in name])

In [116]:
count_parameters_in_MB(model)

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


torch.Size([100, 1, 3, 3, 100, 100, 100, 3, 3, 100, 300, 100, 2, 3, 300, 300, 300, 1, 7, 300, 100, 300, 1, 3, 100, 100, 100, 1, 3, 100, 6144, 1900, 6144, 2, 6144, 2])