In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import pandas as pd
from numpy import random

In [2]:
# Prepare data for training, validation and testing.
# Raw strings keep the Windows backslashes literal; without the prefix these
# paths depend on invalid escape sequences ("\S", "\M", "\g", ...) that newer
# Python versions warn about.
DATA_DIR = r'D:\Study\Ostfold\MachineLearning\git\data'
TRAIN_FILE = DATA_DIR + '\\ohenc_data_colNames.train'
VAL_FILE = DATA_DIR + '\\ohenc_data_colNames.val'
TEST_FILE = DATA_DIR + '\\ohenc_data_colNames.test'

# The outcome is stored in two columns; one is used as the target and the
# other is discarded.
redundant_label = 'outcome<50K'
label_name = 'outcome>50K'


def load_split(path):
    """Read one space-separated split and separate the features from the target.

    Args:
        path: Location of the file to read.

    Returns:
        A ``(features, labels)`` pair: the frame with both outcome columns
        removed, and the ``label_name`` column as a Series.
    """
    frame = pd.read_table(path, sep=' ')
    frame.pop(redundant_label)
    labels = frame.pop(label_name)
    return frame, labels


# The frames are bound only to the *_x / *_y names, so re-running this cell
# cannot overwrite other notebook objects called `train`, `val` or `test`.
train_x, train_y = load_split(TRAIN_FILE)
val_x, val_y = load_split(VAL_FILE)
test_x, test_y = load_split(TEST_FILE)

# NOTE(review): the preview below includes an 'Unnamed: 0' column, which looks
# like a saved row index being kept as a feature -- confirm whether it should
# be dropped (the networks in this notebook hard-code 108 inputs).
display(train_x.head())
display(train_y.head())

Unnamed: 0,age,workclassMissing,workclassFederal-gov,workclassLocal-gov,workclassNever-worked,workclassPrivate,workclassSelf-emp-inc,workclassSelf-emp-not-inc,workclassState-gov,workclassWithout-pay,...,native-countryPortugal,native-countryPuerto-Rico,native-countryScotland,native-countrySouth,native-countryTaiwan,native-countryThailand,native-countryTrinadad-Tobago,native-countryUnited-States,native-countryVietnam,native-countryYugoslavia
0,-1.286609,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
1,0.395073,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,0.02949,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,-1.286609,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
4,0.833773,0,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,0,1,0,0


0    0
1    0
2    0
3    0
4    1
Name: outcome>50K, dtype: int64

In [3]:
# Wrap each split in a TensorDataset and expose it through DataLoaders.
def _as_dataset(features, labels):
    """Pair a feature frame (cast to FloatTensor) with its label series."""
    inputs = torch.from_numpy(features.values).type(torch.FloatTensor)
    targets = torch.from_numpy(labels.values)
    return torch.utils.data.TensorDataset(inputs, targets)


train = _as_dataset(train_x, train_y)
val = _as_dataset(val_x, val_y)
test = _as_dataset(test_x, test_y)

# Mini-batches of 128 for optimisation.
train_loader = torch.utils.data.DataLoader(train, shuffle=True, batch_size=128)

# Loaders with 10000-sample batches, used when scoring a split.
train_loader_val = torch.utils.data.DataLoader(train, shuffle=True, batch_size=10000)
val_loader = torch.utils.data.DataLoader(val, shuffle=True, batch_size=10000)
test_loader = torch.utils.data.DataLoader(test, shuffle=True, batch_size=10000)

In [4]:
class Net4HiddenLayers(nn.Module):
    """Fully connected classifier with four hidden ReLU layers.

    Maps 108 input features to log-probabilities over 2 classes. Dropout is
    applied after every hidden layer while the module is in training mode.

    Args:
        nodes1: Width of the first hidden layer.
        nodes2: Width of the second hidden layer.
        nodes3: Width of the third hidden layer.
        nodes4: Width of the fourth hidden layer.
        dropout: Probability of zeroing an activation in each dropout step.
    """

    def __init__(self, nodes1, nodes2, nodes3, nodes4, dropout):
        super(Net4HiddenLayers, self).__init__()
        self.fc1 = nn.Linear(108, nodes1)
        self.fc2 = nn.Linear(nodes1, nodes2)
        self.fc3 = nn.Linear(nodes2, nodes3)
        self.fc4 = nn.Linear(nodes3, nodes4)
        self.fc5 = nn.Linear(nodes4, 2)
        # Stored on the instance so forward() uses the rate this network was
        # built with rather than whatever a global `dropout` currently holds.
        self.dropout = dropout

    def forward(self, x):
        """Return log-softmax class scores for a ``(batch, 108)`` input."""
        for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(hidden(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.fc5(x)
        return F.log_softmax(x, dim=1)

class Net3HiddenLayers(nn.Module):
    """Fully connected classifier with three hidden ReLU layers.

    Maps 108 input features to log-probabilities over 2 classes. Dropout is
    applied after every hidden layer while the module is in training mode.

    Args:
        nodes1: Width of the first hidden layer.
        nodes2: Width of the second hidden layer.
        nodes3: Width of the third hidden layer.
        dropout: Probability of zeroing an activation in each dropout step.
    """

    def __init__(self, nodes1, nodes2, nodes3, dropout):
        super(Net3HiddenLayers, self).__init__()
        self.fc1 = nn.Linear(108, nodes1)
        self.fc2 = nn.Linear(nodes1, nodes2)
        self.fc3 = nn.Linear(nodes2, nodes3)
        self.fc4 = nn.Linear(nodes3, 2)
        # Stored on the instance so forward() uses the rate this network was
        # built with rather than whatever a global `dropout` currently holds.
        self.dropout = dropout

    def forward(self, x):
        """Return log-softmax class scores for a ``(batch, 108)`` input."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.fc4(x)
        return F.log_softmax(x, dim=1)
    
class Net2HiddenLayers(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    Maps 108 input features to log-probabilities over 2 classes. Dropout is
    applied after every hidden layer while the module is in training mode.

    Args:
        nodes1: Width of the first hidden layer.
        nodes2: Width of the second hidden layer.
        dropout: Probability of zeroing an activation in each dropout step.
    """

    def __init__(self, nodes1, nodes2, dropout):
        super(Net2HiddenLayers, self).__init__()
        self.fc1 = nn.Linear(108, nodes1)
        self.fc2 = nn.Linear(nodes1, nodes2)
        self.fc3 = nn.Linear(nodes2, 2)
        # Stored on the instance so forward() uses the rate this network was
        # built with rather than whatever a global `dropout` currently holds.
        self.dropout = dropout

    def forward(self, x):
        """Return log-softmax class scores for a ``(batch, 108)`` input."""
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

class Net1HiddenLayer(nn.Module):
    """Fully connected classifier with a single hidden ReLU layer.

    Maps 108 input features to log-probabilities over 2 classes. Dropout is
    applied after the hidden layer while the module is in training mode.

    Args:
        nodes: Width of the hidden layer.
        dropout: Probability of zeroing an activation in the dropout step.
    """

    def __init__(self, nodes, dropout):
        super(Net1HiddenLayer, self).__init__()
        self.fc1 = nn.Linear(108, nodes)
        self.fc2 = nn.Linear(nodes, 2)
        # Stored on the instance so forward() uses the rate this network was
        # built with rather than whatever a global `dropout` currently holds.
        self.dropout = dropout

    def forward(self, x):
        """Return log-softmax class scores for a ``(batch, 108)`` input."""
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(epoch, optimizer, model, log_enable=False):
    """Run one optimisation pass over the module-level ``train_loader``.

    Args:
        epoch: Epoch number; only used in the progress message.
        optimizer: Optimizer to step once per batch (expected to hold
            ``model``'s parameters).
        model: Network whose output is fed to the NLL loss, i.e. it should
            return log-probabilities.
        log_enable: When true, print a progress line every ``log_interval``
            batches (``log_interval`` is read from module scope).
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if log_enable and (batch_idx % log_interval == 0):
            # .item() extracts the Python float; indexing a 0-dim loss with
            # [0] is rejected by current PyTorch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def evaluate(data_loader, data_set="validation", model=None):
    """Print and return the average NLL loss and accuracy over ``data_loader``.

    Args:
        data_loader: Iterable of ``(data, target)`` batches that also exposes
            ``.dataset``; ``len(data_loader.dataset)`` is the denominator for
            both the average loss and the accuracy.
        data_set: Label used as the prefix of the printed summary line.
        model: Network to score. When omitted, the module-level ``model`` is
            used, which keeps the two-argument call form working.

    Returns:
        ``(average_loss, correct)``: the mean per-sample loss as a float and
        the number of correctly classified samples as an int.
    """
    if model is None:
        # The parameter shadows the global name, so look it up explicitly.
        model = globals()['model']

    model.eval()
    test_loss = 0.0
    correct = 0
    # No gradients are needed while scoring.
    with torch.no_grad():
        for data, target in data_loader:
            output = model(data)
            # Sum (not mean) per batch so the final division by the dataset
            # size is correct even when the last batch is smaller.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)
    print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        data_set, test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
    return test_loss, correct

def train_and_eval(optimizer, model, epochs, log_enable=False):
    """Train ``model`` for ``epochs`` epochs, then report its final metrics.

    Args:
        optimizer: Optimizer passed through to ``train``.
        model: Network to train; the same object is the one that gets scored.
        epochs: Number of passes over the training data.
        log_enable: When true, also print training and validation metrics
            after every epoch (and enable batch logging inside ``train``).

    The training split is scored through the module-level ``train_loader_val``
    and the validation split through ``val_loader``.
    """
    for epoch in range(1, epochs + 1):
        train(epoch, optimizer, model, log_enable)
        if log_enable:
            evaluate(train_loader_val, "training", model=model)
            evaluate(val_loader, model=model)
            print("\n")

    evaluate(train_loader_val, "training", model=model)
    evaluate(val_loader, model=model)


In [7]:
# Random search over depth, layer widths, learning rate, dropout and L2 penalty.
log_interval = 1000
epochs = 100
max_count = 50
print("Using Adam optimizer")

hidden_set = [2048, 1024, 512, 256, 128, 64, 32, 16]
# Sampled depth -> class implementing a network of that depth.
model_by_depth = {
    1: Net1HiddenLayer,
    2: Net2HiddenLayers,
    3: Net3HiddenLayers,
    4: Net4HiddenLayers,
}
for count in range(max_count):
    # Log-uniform draw. The bounds are given low-to-high because numpy
    # documents uniform() with high < low as undefined.
    lr = 10**random.uniform(-4, -2)
    dropout = random.uniform(0.1, 0.7)
    # NOTE(review): numpy's randint excludes the upper bound, so this draws
    # 1..3 and the 4-layer entry above is never selected. Raise the bound to 5
    # only after confirming Net4HiddenLayers accepts four layer widths.
    layers = random.randint(1, 4)
    l2_reg = 10**random.uniform(-4, 0)

    # One width per layer, widest first. The values stay plain Python ints
    # (taken straight from the list), so they print cleanly and can be handed
    # to nn.Linear without a numpy .item() conversion.
    hidden_units = sorted(
        (hidden_set[random.randint(0, len(hidden_set))] for _ in range(layers)),
        reverse=True)

    torch.manual_seed(1234)
    print("{}, hidden units{}, lr {}, dropout {}, l2_reg {}".format(
        count, hidden_units, lr, dropout, l2_reg))

    model = model_by_depth[layers](*(hidden_units + [dropout]))

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg)
    train_and_eval(optimizer, model, epochs)
    

Using Adam optimizer
0, hidden units[512, 128, 16], lr 0.008817740148838987, dropout 0.19292064211820775, l2_reg 0.06526312845736566
training set: Average loss: 0.3870, Accuracy: 21584/26048 (82.86%)
validation set: Average loss: 0.3935, Accuracy: 5348/6513 (82.11%)
1, hidden units[512, 32], lr 0.0002197650292535493, dropout 0.12708883969802606, l2_reg 0.00019603703649661065
training set: Average loss: 0.1949, Accuracy: 23778/26048 (91.29%)
validation set: Average loss: 0.3864, Accuracy: 5465/6513 (83.91%)
2, hidden units[512, 512], lr 0.0009722409087384057, dropout 0.6595924273002486, l2_reg 0.00018757377382423246
training set: Average loss: 0.2443, Accuracy: 23106/26048 (88.71%)
validation set: Average loss: 0.3414, Accuracy: 5529/6513 (84.89%)
3, hidden units[512, 256, 128], lr 0.0003445814894676429, dropout 0.6628926041214146, l2_reg 0.00039927980952424554
training set: Average loss: 0.2374, Accuracy: 23268/26048 (89.33%)
validation set: Average loss: 0.3435, Accuracy: 5506/6513 (8

training set: Average loss: 0.5916, Accuracy: 19792/26048 (75.98%)
validation set: Average loss: 0.5932, Accuracy: 4928/6513 (75.66%)
35, hidden units[2048], lr 0.0003249533170129262, dropout 0.381941374409544, l2_reg 0.9372571544652458
training set: Average loss: 0.6152, Accuracy: 19792/26048 (75.98%)
validation set: Average loss: 0.6164, Accuracy: 4928/6513 (75.66%)
36, hidden units[256, 16], lr 0.0005717946142694649, dropout 0.46523891276893126, l2_reg 0.001821698869940675
training set: Average loss: 0.2891, Accuracy: 22689/26048 (87.10%)
validation set: Average loss: 0.3248, Accuracy: 5553/6513 (85.26%)
37, hidden units[1024, 256], lr 0.00525813800072652, dropout 0.6744389032918569, l2_reg 0.7035084210643078
training set: Average loss: 0.6037, Accuracy: 19792/26048 (75.98%)
validation set: Average loss: 0.6051, Accuracy: 4928/6513 (75.66%)
38, hidden units[512, 128, 32], lr 0.006296566518349741, dropout 0.24877006707539231, l2_reg 0.1588067475169318
training set: Average loss: 0.56

In [12]:
# Finer random search: architecture fixed at [256, 16]; only the learning rate,
# dropout and L2 penalty are resampled.
log_interval = 1000
epochs = 100
max_count = 30
print("Using Adam optimizer Finer search")

hidden_set = [2048, 1024, 512, 256, 128, 64, 32, 16]  # not read in this cell
for count in range(max_count):
    # Log-uniform draw. The bounds are given low-to-high because numpy
    # documents uniform() with high < low as undefined.
    lr = 10**random.uniform(-4, -2)
    dropout = random.uniform(0.1, 0.7)
    layers = random.randint(1, 4)  # drawn but not used below: depth is fixed
    hidden_units = [256, 16]
    l2_reg = 10**random.uniform(-4, 0)

    print("{}, hidden units{}, lr {}, dropout {}, l2_reg {}".format(count, hidden_units, lr, dropout, l2_reg))

    torch.manual_seed(1234)
    model = Net2HiddenLayers(hidden_units[0], hidden_units[1], dropout)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg)
    train_and_eval(optimizer, model, epochs)

Using Adam optimizer Finer search
0, hidden units[256, 16], lr 0.005834189526635833, dropout 0.10617107901779305, l2_reg 0.45795014711675003
training set: Average loss: 0.5900, Accuracy: 19792/26048 (75.98%)
validation set: Average loss: 0.5918, Accuracy: 4928/6513 (75.66%)
1, hidden units[256, 16], lr 0.007561882272951254, dropout 0.5915709897889455, l2_reg 0.006874029929299779
training set: Average loss: 0.3434, Accuracy: 21894/26048 (84.05%)
validation set: Average loss: 0.3606, Accuracy: 5415/6513 (83.14%)
2, hidden units[256, 16], lr 0.007372419088753163, dropout 0.14650047699932317, l2_reg 0.0004953537404701623
training set: Average loss: 0.2987, Accuracy: 22430/26048 (86.11%)
validation set: Average loss: 0.3322, Accuracy: 5512/6513 (84.63%)
3, hidden units[256, 16], lr 0.0008110457017224536, dropout 0.1989028782843575, l2_reg 0.0009841837284941307
training set: Average loss: 0.2645, Accuracy: 22949/26048 (88.10%)
validation set: Average loss: 0.3308, Accuracy: 5525/6513 (84.83%

In [11]:
# Long run (1000 epochs, per-epoch logging) of a single fixed configuration.
# The lr / dropout / l2_reg values are the ones printed for trial 36 of the
# random search output.
log_interval = 1000
epochs = 1000
max_count = 1
print("Using Adam optimizer Finer search")

hidden_set = [2048, 1024, 512, 256, 128, 64, 32, 16]
for count in range(max_count):
    hidden_units = [256, 16]
    lr, dropout, l2_reg = (
        0.0005717946142694649,
        0.46523891276893126,
        0.001821698869940675,
    )

    summary = "{}, hidden units{}, lr {}, dropout {}, l2_reg {}".format(count, hidden_units, lr, dropout, l2_reg)
    print(summary)

    # Re-seed before building the model.
    torch.manual_seed(1234)
    model = Net2HiddenLayers(*hidden_units, dropout=dropout)

    optimizer = optim.Adam(model.parameters(), weight_decay=l2_reg, lr=lr)
    train_and_eval(optimizer, model, epochs, log_enable=True)

Using Adam optimizer Finer search
0, hidden units[256, 16], lr 0.0005717946142694649, dropout 0.46523891276893126, l2_reg 0.001821698869940675
training set: Average loss: 0.3229, Accuracy: 22199/26048 (85.22%)
validation set: Average loss: 0.3389, Accuracy: 5467/6513 (83.94%)


training set: Average loss: 0.3095, Accuracy: 22311/26048 (85.65%)
validation set: Average loss: 0.3272, Accuracy: 5503/6513 (84.49%)


training set: Average loss: 0.3055, Accuracy: 22397/26048 (85.98%)
validation set: Average loss: 0.3247, Accuracy: 5523/6513 (84.80%)


training set: Average loss: 0.3034, Accuracy: 22418/26048 (86.06%)
validation set: Average loss: 0.3242, Accuracy: 5520/6513 (84.75%)


training set: Average loss: 0.3042, Accuracy: 22413/26048 (86.04%)
validation set: Average loss: 0.3259, Accuracy: 5516/6513 (84.69%)


training set: Average loss: 0.3016, Accuracy: 22475/26048 (86.28%)
validation set: Average loss: 0.3243, Accuracy: 5534/6513 (84.97%)


training set: Average loss: 0.3002, Accur

training set: Average loss: 0.2906, Accuracy: 22588/26048 (86.72%)
validation set: Average loss: 0.3233, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2924, Accuracy: 22606/26048 (86.79%)
validation set: Average loss: 0.3238, Accuracy: 5535/6513 (84.98%)


training set: Average loss: 0.2914, Accuracy: 22619/26048 (86.84%)
validation set: Average loss: 0.3237, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2917, Accuracy: 22595/26048 (86.74%)
validation set: Average loss: 0.3231, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2911, Accuracy: 22636/26048 (86.90%)
validation set: Average loss: 0.3232, Accuracy: 5548/6513 (85.18%)


training set: Average loss: 0.2918, Accuracy: 22576/26048 (86.67%)
validation set: Average loss: 0.3237, Accuracy: 5532/6513 (84.94%)


training set: Average loss: 0.2934, Accuracy: 22658/26048 (86.99%)
validation set: Average loss: 0.3237, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2930, Accuracy: 22

training set: Average loss: 0.2893, Accuracy: 22642/26048 (86.92%)
validation set: Average loss: 0.3237, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2897, Accuracy: 22647/26048 (86.94%)
validation set: Average loss: 0.3219, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2885, Accuracy: 22671/26048 (87.04%)
validation set: Average loss: 0.3236, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2894, Accuracy: 22642/26048 (86.92%)
validation set: Average loss: 0.3240, Accuracy: 5563/6513 (85.41%)


training set: Average loss: 0.2890, Accuracy: 22670/26048 (87.03%)
validation set: Average loss: 0.3242, Accuracy: 5556/6513 (85.31%)


training set: Average loss: 0.2897, Accuracy: 22647/26048 (86.94%)
validation set: Average loss: 0.3234, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2895, Accuracy: 22643/26048 (86.93%)
validation set: Average loss: 0.3243, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2890, Accuracy: 22

training set: Average loss: 0.2884, Accuracy: 22690/26048 (87.11%)
validation set: Average loss: 0.3232, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2896, Accuracy: 22651/26048 (86.96%)
validation set: Average loss: 0.3228, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2885, Accuracy: 22666/26048 (87.02%)
validation set: Average loss: 0.3218, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2890, Accuracy: 22648/26048 (86.95%)
validation set: Average loss: 0.3227, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2897, Accuracy: 22657/26048 (86.98%)
validation set: Average loss: 0.3255, Accuracy: 5532/6513 (84.94%)


training set: Average loss: 0.2897, Accuracy: 22697/26048 (87.14%)
validation set: Average loss: 0.3230, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2898, Accuracy: 22631/26048 (86.88%)
validation set: Average loss: 0.3250, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2894, Accuracy: 22

training set: Average loss: 0.2878, Accuracy: 22716/26048 (87.21%)
validation set: Average loss: 0.3229, Accuracy: 5542/6513 (85.09%)


training set: Average loss: 0.2895, Accuracy: 22650/26048 (86.95%)
validation set: Average loss: 0.3237, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2889, Accuracy: 22639/26048 (86.91%)
validation set: Average loss: 0.3244, Accuracy: 5530/6513 (84.91%)


training set: Average loss: 0.2887, Accuracy: 22615/26048 (86.82%)
validation set: Average loss: 0.3244, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2877, Accuracy: 22691/26048 (87.11%)
validation set: Average loss: 0.3246, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2893, Accuracy: 22658/26048 (86.99%)
validation set: Average loss: 0.3223, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2889, Accuracy: 22679/26048 (87.07%)
validation set: Average loss: 0.3245, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2882, Accuracy: 22

training set: Average loss: 0.2875, Accuracy: 22684/26048 (87.09%)
validation set: Average loss: 0.3230, Accuracy: 5555/6513 (85.29%)


training set: Average loss: 0.2877, Accuracy: 22699/26048 (87.14%)
validation set: Average loss: 0.3230, Accuracy: 5557/6513 (85.32%)


training set: Average loss: 0.2885, Accuracy: 22638/26048 (86.91%)
validation set: Average loss: 0.3235, Accuracy: 5534/6513 (84.97%)


training set: Average loss: 0.2889, Accuracy: 22665/26048 (87.01%)
validation set: Average loss: 0.3249, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2882, Accuracy: 22647/26048 (86.94%)
validation set: Average loss: 0.3236, Accuracy: 5534/6513 (84.97%)


training set: Average loss: 0.2879, Accuracy: 22678/26048 (87.06%)
validation set: Average loss: 0.3227, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2891, Accuracy: 22643/26048 (86.93%)
validation set: Average loss: 0.3231, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2907, Accuracy: 22

training set: Average loss: 0.2880, Accuracy: 22682/26048 (87.08%)
validation set: Average loss: 0.3234, Accuracy: 5539/6513 (85.05%)


training set: Average loss: 0.2876, Accuracy: 22684/26048 (87.09%)
validation set: Average loss: 0.3238, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2875, Accuracy: 22667/26048 (87.02%)
validation set: Average loss: 0.3245, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2885, Accuracy: 22648/26048 (86.95%)
validation set: Average loss: 0.3244, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2880, Accuracy: 22682/26048 (87.08%)
validation set: Average loss: 0.3255, Accuracy: 5561/6513 (85.38%)


training set: Average loss: 0.2890, Accuracy: 22654/26048 (86.97%)
validation set: Average loss: 0.3241, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2877, Accuracy: 22704/26048 (87.16%)
validation set: Average loss: 0.3228, Accuracy: 5563/6513 (85.41%)


training set: Average loss: 0.2885, Accuracy: 22

training set: Average loss: 0.2885, Accuracy: 22698/26048 (87.14%)
validation set: Average loss: 0.3235, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2887, Accuracy: 22662/26048 (87.00%)
validation set: Average loss: 0.3230, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2911, Accuracy: 22592/26048 (86.73%)
validation set: Average loss: 0.3280, Accuracy: 5541/6513 (85.08%)


training set: Average loss: 0.2885, Accuracy: 22709/26048 (87.18%)
validation set: Average loss: 0.3230, Accuracy: 5560/6513 (85.37%)


training set: Average loss: 0.2877, Accuracy: 22660/26048 (86.99%)
validation set: Average loss: 0.3230, Accuracy: 5556/6513 (85.31%)


training set: Average loss: 0.2884, Accuracy: 22695/26048 (87.13%)
validation set: Average loss: 0.3230, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2878, Accuracy: 22651/26048 (86.96%)
validation set: Average loss: 0.3233, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2879, Accuracy: 22

training set: Average loss: 0.2884, Accuracy: 22712/26048 (87.19%)
validation set: Average loss: 0.3229, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2879, Accuracy: 22661/26048 (87.00%)
validation set: Average loss: 0.3250, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2898, Accuracy: 22606/26048 (86.79%)
validation set: Average loss: 0.3272, Accuracy: 5527/6513 (84.86%)


training set: Average loss: 0.2905, Accuracy: 22675/26048 (87.05%)
validation set: Average loss: 0.3252, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2891, Accuracy: 22618/26048 (86.83%)
validation set: Average loss: 0.3251, Accuracy: 5535/6513 (84.98%)


training set: Average loss: 0.2884, Accuracy: 22663/26048 (87.00%)
validation set: Average loss: 0.3241, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2871, Accuracy: 22696/26048 (87.13%)
validation set: Average loss: 0.3228, Accuracy: 5559/6513 (85.35%)


training set: Average loss: 0.2898, Accuracy: 22

training set: Average loss: 0.2893, Accuracy: 22720/26048 (87.22%)
validation set: Average loss: 0.3223, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2879, Accuracy: 22661/26048 (87.00%)
validation set: Average loss: 0.3244, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2881, Accuracy: 22709/26048 (87.18%)
validation set: Average loss: 0.3244, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2889, Accuracy: 22627/26048 (86.87%)
validation set: Average loss: 0.3250, Accuracy: 5539/6513 (85.05%)


training set: Average loss: 0.2883, Accuracy: 22629/26048 (86.87%)
validation set: Average loss: 0.3241, Accuracy: 5534/6513 (84.97%)


training set: Average loss: 0.2872, Accuracy: 22692/26048 (87.12%)
validation set: Average loss: 0.3238, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2889, Accuracy: 22697/26048 (87.14%)
validation set: Average loss: 0.3250, Accuracy: 5533/6513 (84.95%)


training set: Average loss: 0.2868, Accuracy: 22

training set: Average loss: 0.2879, Accuracy: 22707/26048 (87.17%)
validation set: Average loss: 0.3236, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2882, Accuracy: 22676/26048 (87.05%)
validation set: Average loss: 0.3233, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2875, Accuracy: 22699/26048 (87.14%)
validation set: Average loss: 0.3245, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2879, Accuracy: 22665/26048 (87.01%)
validation set: Average loss: 0.3249, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2880, Accuracy: 22694/26048 (87.12%)
validation set: Average loss: 0.3230, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2881, Accuracy: 22658/26048 (86.99%)
validation set: Average loss: 0.3257, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2896, Accuracy: 22639/26048 (86.91%)
validation set: Average loss: 0.3245, Accuracy: 5526/6513 (84.85%)


training set: Average loss: 0.2885, Accuracy: 22

training set: Average loss: 0.2900, Accuracy: 22627/26048 (86.87%)
validation set: Average loss: 0.3267, Accuracy: 5533/6513 (84.95%)


training set: Average loss: 0.2872, Accuracy: 22684/26048 (87.09%)
validation set: Average loss: 0.3234, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2876, Accuracy: 22675/26048 (87.05%)
validation set: Average loss: 0.3238, Accuracy: 5541/6513 (85.08%)


training set: Average loss: 0.2879, Accuracy: 22674/26048 (87.05%)
validation set: Average loss: 0.3230, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2892, Accuracy: 22667/26048 (87.02%)
validation set: Average loss: 0.3241, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2891, Accuracy: 22597/26048 (86.75%)
validation set: Average loss: 0.3255, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2877, Accuracy: 22672/26048 (87.04%)
validation set: Average loss: 0.3237, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2876, Accuracy: 22

training set: Average loss: 0.2893, Accuracy: 22653/26048 (86.97%)
validation set: Average loss: 0.3227, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2880, Accuracy: 22645/26048 (86.94%)
validation set: Average loss: 0.3236, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2880, Accuracy: 22667/26048 (87.02%)
validation set: Average loss: 0.3229, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2898, Accuracy: 22705/26048 (87.17%)
validation set: Average loss: 0.3223, Accuracy: 5539/6513 (85.05%)


training set: Average loss: 0.2881, Accuracy: 22700/26048 (87.15%)
validation set: Average loss: 0.3227, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2878, Accuracy: 22631/26048 (86.88%)
validation set: Average loss: 0.3247, Accuracy: 5530/6513 (84.91%)


training set: Average loss: 0.2880, Accuracy: 22675/26048 (87.05%)
validation set: Average loss: 0.3228, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2878, Accuracy: 22

training set: Average loss: 0.2883, Accuracy: 22657/26048 (86.98%)
validation set: Average loss: 0.3248, Accuracy: 5542/6513 (85.09%)


training set: Average loss: 0.2875, Accuracy: 22694/26048 (87.12%)
validation set: Average loss: 0.3230, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2874, Accuracy: 22722/26048 (87.23%)
validation set: Average loss: 0.3225, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2870, Accuracy: 22698/26048 (87.14%)
validation set: Average loss: 0.3228, Accuracy: 5548/6513 (85.18%)


training set: Average loss: 0.2875, Accuracy: 22688/26048 (87.10%)
validation set: Average loss: 0.3237, Accuracy: 5556/6513 (85.31%)


training set: Average loss: 0.2870, Accuracy: 22681/26048 (87.07%)
validation set: Average loss: 0.3234, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2882, Accuracy: 22680/26048 (87.07%)
validation set: Average loss: 0.3237, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2873, Accuracy: 22

training set: Average loss: 0.2871, Accuracy: 22703/26048 (87.16%)
validation set: Average loss: 0.3250, Accuracy: 5548/6513 (85.18%)


training set: Average loss: 0.2872, Accuracy: 22666/26048 (87.02%)
validation set: Average loss: 0.3257, Accuracy: 5530/6513 (84.91%)


training set: Average loss: 0.2879, Accuracy: 22689/26048 (87.10%)
validation set: Average loss: 0.3228, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2887, Accuracy: 22664/26048 (87.01%)
validation set: Average loss: 0.3232, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2874, Accuracy: 22659/26048 (86.99%)
validation set: Average loss: 0.3236, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2877, Accuracy: 22698/26048 (87.14%)
validation set: Average loss: 0.3248, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2874, Accuracy: 22683/26048 (87.08%)
validation set: Average loss: 0.3241, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2883, Accuracy: 22

training set: Average loss: 0.2869, Accuracy: 22679/26048 (87.07%)
validation set: Average loss: 0.3246, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2875, Accuracy: 22684/26048 (87.09%)
validation set: Average loss: 0.3222, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2877, Accuracy: 22673/26048 (87.04%)
validation set: Average loss: 0.3231, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2878, Accuracy: 22708/26048 (87.18%)
validation set: Average loss: 0.3229, Accuracy: 5562/6513 (85.40%)


training set: Average loss: 0.2887, Accuracy: 22631/26048 (86.88%)
validation set: Average loss: 0.3239, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2881, Accuracy: 22702/26048 (87.15%)
validation set: Average loss: 0.3228, Accuracy: 5555/6513 (85.29%)


training set: Average loss: 0.2879, Accuracy: 22661/26048 (87.00%)
validation set: Average loss: 0.3230, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2876, Accuracy: 22

training set: Average loss: 0.2877, Accuracy: 22707/26048 (87.17%)
validation set: Average loss: 0.3227, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2877, Accuracy: 22673/26048 (87.04%)
validation set: Average loss: 0.3229, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2871, Accuracy: 22686/26048 (87.09%)
validation set: Average loss: 0.3232, Accuracy: 5557/6513 (85.32%)


training set: Average loss: 0.2874, Accuracy: 22663/26048 (87.00%)
validation set: Average loss: 0.3227, Accuracy: 5539/6513 (85.05%)


training set: Average loss: 0.2867, Accuracy: 22685/26048 (87.09%)
validation set: Average loss: 0.3232, Accuracy: 5562/6513 (85.40%)


training set: Average loss: 0.2881, Accuracy: 22723/26048 (87.24%)
validation set: Average loss: 0.3238, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2880, Accuracy: 22716/26048 (87.21%)
validation set: Average loss: 0.3229, Accuracy: 5557/6513 (85.32%)


training set: Average loss: 0.2863, Accuracy: 22

training set: Average loss: 0.2878, Accuracy: 22680/26048 (87.07%)
validation set: Average loss: 0.3236, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2872, Accuracy: 22676/26048 (87.05%)
validation set: Average loss: 0.3235, Accuracy: 5565/6513 (85.44%)


training set: Average loss: 0.2880, Accuracy: 22628/26048 (86.87%)
validation set: Average loss: 0.3257, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2875, Accuracy: 22645/26048 (86.94%)
validation set: Average loss: 0.3242, Accuracy: 5531/6513 (84.92%)


training set: Average loss: 0.2883, Accuracy: 22675/26048 (87.05%)
validation set: Average loss: 0.3232, Accuracy: 5559/6513 (85.35%)


training set: Average loss: 0.2870, Accuracy: 22685/26048 (87.09%)
validation set: Average loss: 0.3232, Accuracy: 5555/6513 (85.29%)


training set: Average loss: 0.2864, Accuracy: 22692/26048 (87.12%)
validation set: Average loss: 0.3233, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2877, Accuracy: 22

training set: Average loss: 0.2875, Accuracy: 22692/26048 (87.12%)
validation set: Average loss: 0.3239, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2885, Accuracy: 22685/26048 (87.09%)
validation set: Average loss: 0.3237, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2877, Accuracy: 22656/26048 (86.98%)
validation set: Average loss: 0.3239, Accuracy: 5529/6513 (84.89%)


training set: Average loss: 0.2881, Accuracy: 22705/26048 (87.17%)
validation set: Average loss: 0.3245, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2870, Accuracy: 22693/26048 (87.12%)
validation set: Average loss: 0.3230, Accuracy: 5560/6513 (85.37%)


training set: Average loss: 0.2883, Accuracy: 22612/26048 (86.81%)
validation set: Average loss: 0.3251, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2880, Accuracy: 22668/26048 (87.02%)
validation set: Average loss: 0.3244, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2877, Accuracy: 22

training set: Average loss: 0.2864, Accuracy: 22722/26048 (87.23%)
validation set: Average loss: 0.3239, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2876, Accuracy: 22696/26048 (87.13%)
validation set: Average loss: 0.3255, Accuracy: 5542/6513 (85.09%)


training set: Average loss: 0.2885, Accuracy: 22630/26048 (86.88%)
validation set: Average loss: 0.3258, Accuracy: 5541/6513 (85.08%)


training set: Average loss: 0.2875, Accuracy: 22705/26048 (87.17%)
validation set: Average loss: 0.3226, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2879, Accuracy: 22648/26048 (86.95%)
validation set: Average loss: 0.3238, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2868, Accuracy: 22676/26048 (87.05%)
validation set: Average loss: 0.3246, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2883, Accuracy: 22691/26048 (87.11%)
validation set: Average loss: 0.3232, Accuracy: 5541/6513 (85.08%)


training set: Average loss: 0.2870, Accuracy: 22

training set: Average loss: 0.2902, Accuracy: 22644/26048 (86.93%)
validation set: Average loss: 0.3237, Accuracy: 5526/6513 (84.85%)


training set: Average loss: 0.2876, Accuracy: 22669/26048 (87.03%)
validation set: Average loss: 0.3250, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2878, Accuracy: 22674/26048 (87.05%)
validation set: Average loss: 0.3236, Accuracy: 5537/6513 (85.01%)


training set: Average loss: 0.2885, Accuracy: 22699/26048 (87.14%)
validation set: Average loss: 0.3237, Accuracy: 5530/6513 (84.91%)


training set: Average loss: 0.2881, Accuracy: 22696/26048 (87.13%)
validation set: Average loss: 0.3228, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2875, Accuracy: 22696/26048 (87.13%)
validation set: Average loss: 0.3233, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2884, Accuracy: 22642/26048 (86.92%)
validation set: Average loss: 0.3236, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2876, Accuracy: 22

training set: Average loss: 0.2875, Accuracy: 22678/26048 (87.06%)
validation set: Average loss: 0.3247, Accuracy: 5541/6513 (85.08%)


training set: Average loss: 0.2900, Accuracy: 22585/26048 (86.71%)
validation set: Average loss: 0.3248, Accuracy: 5542/6513 (85.09%)


training set: Average loss: 0.2875, Accuracy: 22709/26048 (87.18%)
validation set: Average loss: 0.3225, Accuracy: 5561/6513 (85.38%)


training set: Average loss: 0.2886, Accuracy: 22654/26048 (86.97%)
validation set: Average loss: 0.3245, Accuracy: 5556/6513 (85.31%)


training set: Average loss: 0.2877, Accuracy: 22667/26048 (87.02%)
validation set: Average loss: 0.3238, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2877, Accuracy: 22695/26048 (87.13%)
validation set: Average loss: 0.3228, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2873, Accuracy: 22689/26048 (87.10%)
validation set: Average loss: 0.3239, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2886, Accuracy: 22

training set: Average loss: 0.2894, Accuracy: 22657/26048 (86.98%)
validation set: Average loss: 0.3242, Accuracy: 5536/6513 (85.00%)


training set: Average loss: 0.2866, Accuracy: 22706/26048 (87.17%)
validation set: Average loss: 0.3237, Accuracy: 5551/6513 (85.23%)


training set: Average loss: 0.2873, Accuracy: 22690/26048 (87.11%)
validation set: Average loss: 0.3231, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2887, Accuracy: 22632/26048 (86.89%)
validation set: Average loss: 0.3234, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2879, Accuracy: 22656/26048 (86.98%)
validation set: Average loss: 0.3231, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2880, Accuracy: 22686/26048 (87.09%)
validation set: Average loss: 0.3238, Accuracy: 5539/6513 (85.05%)


training set: Average loss: 0.2878, Accuracy: 22690/26048 (87.11%)
validation set: Average loss: 0.3231, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2875, Accuracy: 22

In [13]:
# Final training run with the Adam optimizer, using one fixed set of
# hyper-parameters (NOTE(review): presumably the best configuration found by
# the earlier search cells -- confirm).

# Run-level settings.
# NOTE(review): log_interval is not referenced in this cell; presumably it is
# read as a global by the training helpers defined earlier -- confirm.
log_interval = 1000
epochs = 215
max_count = 1  # number of runs launched by the loop below
print("Using Adam optimizer Finer search")

# Candidate hidden-layer widths; not referenced by the run below.
hidden_set = [2048, 1024, 512, 256, 128, 64, 32, 16]

for count in range(max_count):
    # Fixed hyper-parameters for this run.
    hidden_units = [256, 16]
    dropout = 0.46523891276893126
    lr = 0.0005717946142694649
    l2_reg = 0.001821698869940675

    print("{}, hidden units{}, lr {}, dropout {}, l2_reg {}".format(count, hidden_units, lr, dropout, l2_reg))

    # Fix the torch RNG seed before the model is constructed.
    torch.manual_seed(1234)
    first_width, second_width = hidden_units
    model = Net2HiddenLayers(first_width, second_width, dropout)

    # l2_reg is handed to Adam as its weight_decay coefficient.
    optimizer = optim.Adam(model.parameters(), weight_decay=l2_reg, lr=lr)
    # The meaning of the trailing True flag is defined by train_and_eval,
    # which is not visible in this part of the notebook.
    train_and_eval(optimizer, model, epochs, True)

Using Adam optimizer Finer search
0, hidden units[256, 16], lr 0.0005717946142694649, dropout 0.46523891276893126, l2_reg 0.001821698869940675
training set: Average loss: 0.3229, Accuracy: 22199/26048 (85.22%)
validation set: Average loss: 0.3389, Accuracy: 5467/6513 (83.94%)


training set: Average loss: 0.3095, Accuracy: 22311/26048 (85.65%)
validation set: Average loss: 0.3272, Accuracy: 5503/6513 (84.49%)


training set: Average loss: 0.3055, Accuracy: 22397/26048 (85.98%)
validation set: Average loss: 0.3247, Accuracy: 5523/6513 (84.80%)


training set: Average loss: 0.3034, Accuracy: 22418/26048 (86.06%)
validation set: Average loss: 0.3242, Accuracy: 5520/6513 (84.75%)


training set: Average loss: 0.3042, Accuracy: 22413/26048 (86.04%)
validation set: Average loss: 0.3259, Accuracy: 5516/6513 (84.69%)


training set: Average loss: 0.3016, Accuracy: 22475/26048 (86.28%)
validation set: Average loss: 0.3243, Accuracy: 5534/6513 (84.97%)


training set: Average loss: 0.3002, Accur

training set: Average loss: 0.2906, Accuracy: 22588/26048 (86.72%)
validation set: Average loss: 0.3233, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2924, Accuracy: 22606/26048 (86.79%)
validation set: Average loss: 0.3238, Accuracy: 5535/6513 (84.98%)


training set: Average loss: 0.2914, Accuracy: 22619/26048 (86.84%)
validation set: Average loss: 0.3237, Accuracy: 5552/6513 (85.24%)


training set: Average loss: 0.2917, Accuracy: 22595/26048 (86.74%)
validation set: Average loss: 0.3231, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2911, Accuracy: 22636/26048 (86.90%)
validation set: Average loss: 0.3232, Accuracy: 5548/6513 (85.18%)


training set: Average loss: 0.2918, Accuracy: 22576/26048 (86.67%)
validation set: Average loss: 0.3237, Accuracy: 5532/6513 (84.94%)


training set: Average loss: 0.2934, Accuracy: 22658/26048 (86.99%)
validation set: Average loss: 0.3237, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2930, Accuracy: 22

training set: Average loss: 0.2893, Accuracy: 22642/26048 (86.92%)
validation set: Average loss: 0.3237, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2897, Accuracy: 22647/26048 (86.94%)
validation set: Average loss: 0.3219, Accuracy: 5550/6513 (85.21%)


training set: Average loss: 0.2885, Accuracy: 22671/26048 (87.04%)
validation set: Average loss: 0.3236, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2894, Accuracy: 22642/26048 (86.92%)
validation set: Average loss: 0.3240, Accuracy: 5563/6513 (85.41%)


training set: Average loss: 0.2890, Accuracy: 22670/26048 (87.03%)
validation set: Average loss: 0.3242, Accuracy: 5556/6513 (85.31%)


training set: Average loss: 0.2897, Accuracy: 22647/26048 (86.94%)
validation set: Average loss: 0.3234, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2895, Accuracy: 22643/26048 (86.93%)
validation set: Average loss: 0.3243, Accuracy: 5545/6513 (85.14%)


training set: Average loss: 0.2890, Accuracy: 22

training set: Average loss: 0.2884, Accuracy: 22690/26048 (87.11%)
validation set: Average loss: 0.3232, Accuracy: 5543/6513 (85.11%)


training set: Average loss: 0.2896, Accuracy: 22651/26048 (86.96%)
validation set: Average loss: 0.3228, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2885, Accuracy: 22666/26048 (87.02%)
validation set: Average loss: 0.3218, Accuracy: 5547/6513 (85.17%)


training set: Average loss: 0.2890, Accuracy: 22648/26048 (86.95%)
validation set: Average loss: 0.3227, Accuracy: 5540/6513 (85.06%)


training set: Average loss: 0.2897, Accuracy: 22657/26048 (86.98%)
validation set: Average loss: 0.3255, Accuracy: 5532/6513 (84.94%)


training set: Average loss: 0.2897, Accuracy: 22697/26048 (87.14%)
validation set: Average loss: 0.3230, Accuracy: 5544/6513 (85.12%)


training set: Average loss: 0.2898, Accuracy: 22631/26048 (86.88%)
validation set: Average loss: 0.3250, Accuracy: 5554/6513 (85.28%)


training set: Average loss: 0.2894, Accuracy: 22

training set: Average loss: 0.2878, Accuracy: 22716/26048 (87.21%)
validation set: Average loss: 0.3229, Accuracy: 5542/6513 (85.09%)


training set: Average loss: 0.2895, Accuracy: 22650/26048 (86.95%)
validation set: Average loss: 0.3237, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2889, Accuracy: 22639/26048 (86.91%)
validation set: Average loss: 0.3244, Accuracy: 5530/6513 (84.91%)


training set: Average loss: 0.2887, Accuracy: 22615/26048 (86.82%)
validation set: Average loss: 0.3244, Accuracy: 5538/6513 (85.03%)


training set: Average loss: 0.2877, Accuracy: 22691/26048 (87.11%)
validation set: Average loss: 0.3246, Accuracy: 5553/6513 (85.26%)


training set: Average loss: 0.2893, Accuracy: 22658/26048 (86.99%)
validation set: Average loss: 0.3223, Accuracy: 5549/6513 (85.20%)


training set: Average loss: 0.2889, Accuracy: 22679/26048 (87.07%)
validation set: Average loss: 0.3245, Accuracy: 5546/6513 (85.15%)


training set: Average loss: 0.2882, Accuracy: 22

In [14]:
# Report average loss and accuracy on the test split via the `evaluate`
# helper defined earlier in the notebook; "test" is the label shown in the
# printed summary line.
# NOTE(review): no model is passed in, so `evaluate` presumably reads the
# global `model` left by the training cell above -- confirm.
evaluate(test_loader, "test")

test set: Average loss: 0.3108, Accuracy: 13940/16281 (85.62%)
