In [10]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sklearn.metrics as metrics
import numpy as np
import torchvision
from torchvision import datasets, transforms
from kuzu import NetLin, NetFull, NetConv

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()        
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    conf_matrix = np.zeros((10,10)) # initialize confusion matrix
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            conf_matrix = conf_matrix + metrics.confusion_matrix(
                          pred.cpu(), target.cpu(), labels=[0,1,2,3,4,5,6,7,8,9])
        np.set_printoptions(precision=4, suppress=True)
        print(type(conf_matrix))
        print(conf_matrix)

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))




In [162]:

'''
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=1,
    shuffle=True
)
print(loader)
print('dito',dataset.targets)
for batch_idx, (data, target) in enumerate(loader):
    print(batch_idx)
    print(data.shape)
    print(target)
    break
'''

# Fetch and load the test data
#testset = datasets.KMNIST(root='./Pokemon Dataset', train=False, download=False, transform=transform)
#test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

"\nloader = torch.utils.data.DataLoader(\n    dataset,\n    batch_size=8,\n    num_workers=1,\n    shuffle=True\n)\nprint(loader)\nprint('dito',dataset.targets)\nfor batch_idx, (data, target) in enumerate(loader):\n    print(batch_idx)\n    print(data.shape)\n    print(target)\n    break\n"

In [170]:
import codecs
csvReader = csv.reader(codecs.open('pokemon.csv', 'rU', 'utf-16'), delimiter='\t')
header = next(csvReader)
print(header)
print('---------')
types = dict()
english_name = 2
primary_type = 4
secondary_type = 5
for row in csvReader:
    name=row[english_name].lower().replace('\'','')
    #types[name] = [row[primary_type],row[secondary_type]]
    types[name] = [row[primary_type],'']
print(types)

['national_number', 'gen', 'english_name', 'japanese_name', 'primary_type', 'secondary_type', 'classification', 'percent_male', 'percent_female', 'height_m', 'weight_kg', 'capture_rate', 'base_egg_steps', 'hp', 'attack', 'defense', 'sp_attack', 'sp_defense', 'speed', 'abilities_0', 'abilities_1', 'abilities_2', 'abilities_hidden', 'against_normal', 'against_fire', 'against_water', 'against_electric', 'against_grass', 'against_ice', 'against_fighting', 'against_poison', 'against_ground', 'against_flying', 'against_psychict', 'against_bug', 'against_rock', 'against_ghost', 'against_dragon', 'against_dark', 'against_steel', 'against_fairy', 'is_sublegendary', 'is_legendary', 'is_mythical', 'evochain_0', 'evochain_1', 'evochain_2', 'evochain_3', 'evochain_4', 'evochain_5', 'evochain_6', 'gigantamax', 'mega_evolution', 'mega_evolution_alt', 'description']
---------
{'bulbasaur': ['grass', ''], 'ivysaur': ['grass', ''], 'venusaur': ['grass', ''], 'charmander': ['fire', ''], 'charmeleon': ['f

In [171]:
print(dataset.targets)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 

In [176]:
data_path = './Pokemon_Dataset'

dataset = torchvision.datasets.ImageFolder(
    root=data_path,
    transform=torchvision.transforms.ToTensor()
)

In [177]:
import os
import shutil

#match image name with the name in the csv file
nametype = dict()
#targets = []
sorted_path = "pokemon_type_sorted"
if not os.path.exists(sorted_path):
        os.makedirs(sorted_path)
for p in dataset.imgs:
    name = p[0].split('\\')[-1][:-4]
    print(name)
    if(len(name.split('-'))>1 and len(name.split('-')[1])>2):
        name = name.split('-')[0]
    print(p[0], types[name])
    #targets.append(types[name])
    for t in types[name]:
        if t == '':
            continue
        subsorted = sorted_path+"\\"+t
        if not os.path.exists(subsorted):
            os.makedirs(subsorted)
        shutil.copy2(p[0], subsorted+"\\"+name+".png") # complete target filename given
        nametype[p[0]] = types[name]

#dataset.targets = targets

abomasnow
./Pokemon_Dataset\1\abomasnow.png ['grass', '']
abra
./Pokemon_Dataset\1\abra.png ['psychic', '']
absol
./Pokemon_Dataset\1\absol.png ['dark', '']
accelgor
./Pokemon_Dataset\1\accelgor.png ['bug', '']
aegislash-shield
./Pokemon_Dataset\1\aegislash-shield.png ['steel', '']
aerodactyl
./Pokemon_Dataset\1\aerodactyl.png ['rock', '']
aggron
./Pokemon_Dataset\1\aggron.png ['steel', '']
aipom
./Pokemon_Dataset\1\aipom.png ['normal', '']
alakazam
./Pokemon_Dataset\1\alakazam.png ['psychic', '']
alcremie
./Pokemon_Dataset\1\alcremie.png ['fairy', '']
alomomola
./Pokemon_Dataset\1\alomomola.png ['water', '']
altaria
./Pokemon_Dataset\1\altaria.png ['dragon', '']
amaura
./Pokemon_Dataset\1\amaura.png ['rock', '']
ambipom
./Pokemon_Dataset\1\ambipom.png ['normal', '']
amoonguss
./Pokemon_Dataset\1\amoonguss.png ['grass', '']
ampharos
./Pokemon_Dataset\1\ampharos.png ['electric', '']
anorith
./Pokemon_Dataset\1\anorith.png ['rock', '']
appletun
./Pokemon_Dataset\1\appletun.png ['grass', 

./Pokemon_Dataset\1\froslass.png ['ice', '']
frosmoth
./Pokemon_Dataset\1\frosmoth.png ['ice', '']
furfrou-natural
./Pokemon_Dataset\1\furfrou-natural.png ['normal', '']
furret
./Pokemon_Dataset\1\furret.png ['normal', '']
gabite
./Pokemon_Dataset\1\gabite.png ['dragon', '']
gallade
./Pokemon_Dataset\1\gallade.png ['psychic', '']
galvantula
./Pokemon_Dataset\1\galvantula.png ['bug', '']
garbodor
./Pokemon_Dataset\1\garbodor.png ['poison', '']
garchomp
./Pokemon_Dataset\1\garchomp.png ['dragon', '']
gardevoir
./Pokemon_Dataset\1\gardevoir.png ['psychic', '']
gastly
./Pokemon_Dataset\1\gastly.png ['ghost', '']
gastrodon-west
./Pokemon_Dataset\1\gastrodon-west.png ['water', '']
genesect
./Pokemon_Dataset\1\genesect.png ['bug', '']
gengar
./Pokemon_Dataset\1\gengar.png ['ghost', '']
geodude
./Pokemon_Dataset\1\geodude.png ['rock', '']
gible
./Pokemon_Dataset\1\gible.png ['dragon', '']
gigalith
./Pokemon_Dataset\1\gigalith.png ['rock', '']
girafarig
./Pokemon_Dataset\1\girafarig.png ['norma

./Pokemon_Dataset\1\orbeetle.png ['bug', '']
oricorio-baile
./Pokemon_Dataset\1\oricorio-baile.png ['fire', '']
oshawott
./Pokemon_Dataset\1\oshawott.png ['water', '']
pachirisu
./Pokemon_Dataset\1\pachirisu.png ['electric', '']
palkia
./Pokemon_Dataset\1\palkia.png ['water', '']
palossand
./Pokemon_Dataset\1\palossand.png ['ghost', '']
palpitoad
./Pokemon_Dataset\1\palpitoad.png ['water', '']
pancham
./Pokemon_Dataset\1\pancham.png ['fighting', '']
pangoro
./Pokemon_Dataset\1\pangoro.png ['fighting', '']
panpour
./Pokemon_Dataset\1\panpour.png ['water', '']
pansage
./Pokemon_Dataset\1\pansage.png ['grass', '']
pansear
./Pokemon_Dataset\1\pansear.png ['fire', '']
paras
./Pokemon_Dataset\1\paras.png ['bug', '']
parasect
./Pokemon_Dataset\1\parasect.png ['bug', '']
passimian
./Pokemon_Dataset\1\passimian.png ['fighting', '']
patrat
./Pokemon_Dataset\1\patrat.png ['normal', '']
pawniard
./Pokemon_Dataset\1\pawniard.png ['dark', '']
pelipper
./Pokemon_Dataset\1\pelipper.png ['water', '']
p

tyrogue
./Pokemon_Dataset\1\tyrogue.png ['fighting', '']
tyrunt
./Pokemon_Dataset\1\tyrunt.png ['rock', '']
umbreon
./Pokemon_Dataset\1\umbreon.png ['dark', '']
unfezant
./Pokemon_Dataset\1\unfezant.png ['normal', '']
unown
./Pokemon_Dataset\1\unown.png ['psychic', '']
ursaring
./Pokemon_Dataset\1\ursaring.png ['normal', '']
urshifu-single-strike
./Pokemon_Dataset\1\urshifu-single-strike.png ['fighting', '']
uxie
./Pokemon_Dataset\1\uxie.png ['psychic', '']
vanillish
./Pokemon_Dataset\1\vanillish.png ['ice', '']
vanillite
./Pokemon_Dataset\1\vanillite.png ['ice', '']
vanilluxe
./Pokemon_Dataset\1\vanilluxe.png ['ice', '']
vaporeon
./Pokemon_Dataset\1\vaporeon.png ['water', '']
venipede
./Pokemon_Dataset\1\venipede.png ['bug', '']
venomoth
./Pokemon_Dataset\1\venomoth.png ['bug', '']
venonat
./Pokemon_Dataset\1\venonat.png ['bug', '']
venusaur
./Pokemon_Dataset\1\venusaur.png ['grass', '']
vespiquen
./Pokemon_Dataset\1\vespiquen.png ['bug', '']
vibrava
./Pokemon_Dataset\1\vibrava.png ['

In [178]:
#print(dataset.targets)
data_path = sorted_path

dataset = torchvision.datasets.ImageFolder(
    root=data_path,
    transform=torchvision.transforms.ToTensor()
)
print(dataset)
print(dataset.class_to_idx)

Dataset ImageFolder
    Number of datapoints: 898
    Root location: pokemon_type_sorted
    StandardTransform
Transform: ToTensor()
{'bug': 0, 'dark': 1, 'dragon': 2, 'electric': 3, 'fairy': 4, 'fighting': 5, 'fire': 6, 'flying': 7, 'ghost': 8, 'grass': 9, 'ground': 10, 'ice': 11, 'normal': 12, 'poison': 13, 'psychic': 14, 'rock': 15, 'steel': 16, 'water': 17}


In [179]:
#https://www.machinecurve.com/index.php/2021/02/03/how-to-use-k-fold-cross-validation-with-pytorch/

from sklearn.model_selection import KFold

# Configuration options
k_folds = 5
loss_function = nn.CrossEntropyLoss()

# For fold results
results = {}

# Set fixed random number seed
torch.manual_seed(42)

# Define the K-fold Cross Validator
kfold = KFold(n_splits=k_folds, shuffle=True)


# Start print
print('--------------------------------')

# K-fold Cross Validation model evaluation
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):

    # Print
    print(f'FOLD {fold}')
    print('--------------------------------')
    
    #print(train_ids, test_ids)

    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    #print(train_subsampler, test_subsampler)
    # Define data loaders for training and testing data in this fold
    trainloader = torch.utils.data.DataLoader(
                      dataset, 
                      batch_size=10, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=10, sampler=test_subsampler)
    

--------------------------------
FOLD 0
--------------------------------
FOLD 1
--------------------------------
FOLD 2
--------------------------------
FOLD 3
--------------------------------
FOLD 4
--------------------------------


In [181]:
def reset_weights(m):
  '''
    Try resetting model weights to avoid
    weight leakage.
  '''
  for layer in m.children():
    if hasattr(layer, 'reset_parameters'):
        print(f'Reset trainable parameters of layer = {layer}')
        layer.reset_parameters()

class SimpleConvNet(nn.Module):
    '''
    Simple Convolutional Neural Network
    '''
    def __init__(self):
        super(SimpleConvNet, self).__init__()
        # INSERT CODE HERE
        self.conv1 = nn.Conv2d(3, 6, 5,  stride=3)
        self.conv2 = nn.Conv2d(6, 12, 5,  stride=3)
        self.fc1 = nn.Linear(31212, 1000)
        self.fc2 = nn.Linear(1000, 50)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #print(x.shape)

        x = x.view(x.shape[0], 31212)

        #print(x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        #print(x.shape)
        probs = F.log_softmax(x)
        return probs # CHANGE CODE HERE
  
# https://www.machinecurve.com/index.php/2021/02/03/how-to-use-k-fold-cross-validation-with-pytorch/
num_epochs = 5
# Init the neural network
network = SimpleConvNet()
network.apply(reset_weights)

# Initialize optimizer
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)

# Run the training loop for defined number of epochs
for epoch in range(0, num_epochs):

  # Print epoch
  print(f'Starting epoch {epoch+1}')

  # Set current loss value
  current_loss = 0.0

  # Iterate over the DataLoader for training data
  for i, data in enumerate(trainloader, 0):

    # Get inputs
    inputs, targets = data

    #print(inputs, "targets:",targets)
    # Zero the gradients
    optimizer.zero_grad()

    # Perform forward pass
    outputs = network(inputs)

    # Compute loss
    loss = loss_function(outputs, targets)

    # Perform backward pass
    loss.backward()

    # Perform optimization
    optimizer.step()

    # Print statistics
    current_loss += loss.item()
    if i % 500 == 499:
        print('Loss after mini-batch %5d: %.3f' %
              (i + 1, current_loss / 500))
        current_loss = 0.0

# Process is complete.
print('Training process has finished. Saving trained model.')

# Print about testing
print('Starting testing')

# Saving the model
save_path = f'./model-fold-{fold}.pth'
torch.save(network.state_dict(), save_path)

# Evaluationfor this fold
correct, total = 0, 0
with torch.no_grad():

  # Iterate over the test data and generate predictions
  for i, data in enumerate(testloader, 0):

    # Get inputs
    inputs, targets = data

    # Generate outputs
    outputs = network(inputs)

    # Set total and correct
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    print(predicted, " :: ", targets)
    correct += (predicted == targets).sum().item()

  # Print accuracy
  print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
  print('--------------------------------')
  results[fold] = 100.0 * (correct / total)

# Print fold results
print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
print('--------------------------------')
sum = 0.0
for key, value in results.items():
    print(f'Fold {key}: {value} %')
    sum += value
print(f'Average: {sum/len(results.items())} %')

Reset trainable parameters of layer = Conv2d(3, 6, kernel_size=(5, 5), stride=(3, 3))
Reset trainable parameters of layer = Conv2d(6, 12, kernel_size=(5, 5), stride=(3, 3))
Reset trainable parameters of layer = Linear(in_features=31212, out_features=1000, bias=True)
Reset trainable parameters of layer = Linear(in_features=1000, out_features=50, bias=True)
Starting epoch 1


  probs = F.log_softmax(x)


Starting epoch 2
Starting epoch 3
Starting epoch 4
Starting epoch 5
Training process has finished. Saving trained model.
Starting testing
tensor([12, 17, 17, 12, 12, 12, 12, 17, 17, 12])  ::  tensor([17, 12,  9, 12,  0, 14, 11, 11, 10,  5])
tensor([17, 17, 12, 13, 17, 17, 12, 17, 17,  6])  ::  tensor([ 1,  2,  3, 17,  8, 17, 17,  6, 13,  6])
tensor([17, 12, 12, 17, 12, 17, 17, 17, 17, 12])  ::  tensor([17, 12, 15,  6,  9, 16,  5, 17, 17,  8])
tensor([17, 12,  6, 17,  0,  9, 17,  6, 12, 13])  ::  tensor([ 8, 10, 12, 17,  6, 17,  1,  0,  9, 16])
tensor([12, 12,  9, 17, 17, 12, 12, 17, 17, 17])  ::  tensor([ 9, 17,  0, 10, 15,  0, 12, 10,  3, 12])
tensor([17, 12, 12, 12, 17, 17, 12, 17, 12,  3])  ::  tensor([ 8, 15,  1,  7, 17, 17, 11,  4, 14, 17])
tensor([17, 12, 17, 17, 17, 17, 17, 12, 12, 17])  ::  tensor([17,  3,  3,  4,  1,  9,  3, 12,  4,  9])
tensor([17, 13, 12, 17, 12, 17, 12, 17, 17, 12])  ::  tensor([17,  9,  8, 16,  0,  9, 12,  0, 12,  9])
tensor([12, 17, 12, 17, 17,  6, 17, 17

In [3]:
device = 'cpu'
epochs = 10
lr = 0.01
mom = 0.5
args = ''

In [129]:
# python3 kuzu_main.py --net lin
net = NetLin().to(device)

if list(net.parameters()):
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)

    for epoch in range(1, epochs + 1):
        train(args, net, device, trainloader, optimizer, epoch)
        test(args, net, device, testloader)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x676875 and 784x10)

In [9]:
# python3 kuzu_main.py --net full
hidden_node = 10
#device = 'cuda'

for i in range(6, 12):
    hidden_node = 10 * i
    print('==== hidden_node is ', hidden_node, '====')
    
    net = NetFull().to(device)
    if list(net.parameters()):
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)

        for epoch in range(1, epochs + 1):
            train(args, net, device, train_loader, optimizer, epoch)
            test(args, net, device, test_loader)

==== hidden_node is  60 ====


  probs = F.log_softmax(output)


<class 'numpy.ndarray'>
[[753.   4.   6.   3.  76.  10.   4.  15.  11.   5.]
 [  4. 655.  59.  40.  53.  33.  17.  25.  38.  58.]
 [  6. 113. 683.  69.  75. 146. 148.  30.  73.  82.]
 [ 11.  16.  38. 751.  20.  26.  12.  19.  43.   2.]
 [ 30.  28.  23.  14. 630.  25.  25. 111.   6.  73.]
 [ 58.  21.  10.  42.  13. 686.  18.  13.  36.  23.]
 [  2.  84.  63.  18.  25.  35. 733.  87.  51.  21.]
 [ 74.   9.  37.  19.  23.   8.  20. 531.   7.  31.]
 [ 36.  18.  38.  32.  22.  25.   9. 112. 714.  35.]
 [ 26.  52.  43.  12.  63.   6.  14.  57.  21. 670.]]

Test set: Average loss: 1.0027, Accuracy: 6806/10000 (68%)

<class 'numpy.ndarray'>
[[791.   3.   6.   4.  70.  10.   5.  13.  12.   5.]
 [  4. 705.  48.  24.  52.  25.  18.  26.  35.  48.]
 [  4.  82. 730.  69.  62. 139. 133.  25.  47.  82.]
 [ 11.  14.  39. 799.  15.  23.   8.  15.  47.   3.]
 [ 28.  28.  23.  11. 663.  23.  23.  80.   6.  66.]
 [ 41.  17.   9.  30.  12. 712.  12.  11.  27.  15.]
 [  2.  79.  51.  14.  30.  34. 768.  75. 

<class 'numpy.ndarray'>
[[841.   3.   7.   4.  46.  11.   3.  17.  12.   2.]
 [  4. 798.  14.  10.  32.  15.  14.  13.  29.  23.]
 [  1.  34. 836.  35.  20.  87.  68.  24.  33.  46.]
 [  6.   8.  44. 903.  10.  13.   7.   7.  43.   4.]
 [ 31.  23.  13.   4. 793.  15.  19.  28.   2.  37.]
 [ 26.  15.  13.  12.   8. 800.   4.  11.   9.  10.]
 [  3.  63.  31.   7.  34.  29. 870.  42.  32.  24.]
 [ 46.   2.  10.   7.  15.   1.   8. 795.   5.  15.]
 [ 34.  18.  17.   8.  20.  22.   2.  31. 827.  13.]
 [  8.  36.  15.  10.  22.   7.   5.  32.   8. 826.]]

Test set: Average loss: 0.5529, Accuracy: 8289/10000 (83%)

<class 'numpy.ndarray'>
[[845.   3.   8.   4.  45.  12.   3.  15.   9.   2.]
 [  4. 803.  12.  10.  31.  12.  11.  13.  29.  21.]
 [  1.  33. 841.  34.  16.  87.  61.  23.  29.  48.]
 [  5.   7.  44. 907.  10.  14.   9.   6.  42.   4.]
 [ 31.  24.  11.   4. 805.  14.  18.  26.   2.  33.]
 [ 24.  13.  13.  12.   7. 805.   4.  10.   9.  10.]
 [  3.  62.  29.   6.  32.  30. 879.  37. 

<class 'numpy.ndarray'>
[[816.   5.   7.   7.  54.   7.   4.  18.   9.   5.]
 [  4. 775.  33.  11.  34.  17.  19.  14.  26.  24.]
 [  1.  39. 806.  61.  29.  84.  79.  19.  34.  69.]
 [  6.   5.  32. 875.  10.  10.   8.   5.  43.   4.]
 [ 32.  24.  13.   3. 765.  12.  19.  35.   3.  39.]
 [ 45.  13.  15.  15.  11. 806.   8.  11.  14.  12.]
 [  6.  73.  31.   5.  34.  27. 844.  39.  36.  23.]
 [ 47.   9.  17.   6.  23.   6.  10. 783.   9.  15.]
 [ 36.  26.  26.   8.  17.  23.   2.  40. 817.  16.]
 [  7.  31.  20.   9.  23.   8.   7.  36.   9. 793.]]

Test set: Average loss: 0.6156, Accuracy: 8080/10000 (81%)

<class 'numpy.ndarray'>
[[829.   4.   7.   6.  51.   8.   3.  19.   8.   5.]
 [  4. 786.  28.   9.  32.  15.  18.  13.  23.  24.]
 [  1.  40. 821.  47.  31.  83.  73.  18.  31.  63.]
 [  7.   4.  32. 892.   8.   9.   9.   5.  45.   4.]
 [ 36.  22.  12.   2. 773.  13.  17.  31.   3.  39.]
 [ 38.  11.  15.  15.  11. 814.   5.  11.  13.  12.]
 [  6.  69.  29.   6.  34.  25. 856.  34. 

<class 'numpy.ndarray'>
[[812.   5.   6.   4.  69.  11.   4.  13.  11.   5.]
 [  2. 745.  41.  20.  50.  23.  23.  15.  35.  33.]
 [  2.  54. 764.  55.  50. 112. 110.  20.  37.  81.]
 [ 10.  10.  42. 839.  14.  20.   8.   9.  57.   4.]
 [ 25.  27.  18.   7. 684.  15.  21.  49.   5.  53.]
 [ 44.  18.  10.  21.  11. 749.   6.  10.  22.  11.]
 [  3.  77.  44.  12.  36.  34. 800.  56.  45.  29.]
 [ 53.   6.  19.  13.  21.   4.  17. 699.   6.  28.]
 [ 36.  18.  28.  16.  23.  23.   5.  76. 770.  18.]
 [ 13.  40.  28.  13.  42.   9.   6.  53.  12. 738.]]

Test set: Average loss: 0.7638, Accuracy: 7600/10000 (76%)

<class 'numpy.ndarray'>
[[826.   5.   6.   4.  63.  10.   4.  17.  12.   6.]
 [  2. 762.  31.  15.  49.  20.  21.  13.  32.  32.]
 [  1.  40. 783.  51.  35.  93.  97.  14.  33.  70.]
 [  9.   6.  44. 858.  13.  19.   9.   8.  58.   2.]
 [ 24.  28.  17.   5. 719.  15.  20.  38.   6.  50.]
 [ 41.  16.  11.  21.  10. 773.   6.  10.  18.   8.]
 [  3.  78.  43.   9.  39.  34. 823.  53. 

<class 'numpy.ndarray'>
[[755.   4.   7.   4.  74.  11.   5.  15.  10.   5.]
 [  4. 652.  67.  50.  47.  32.  21.  24.  40.  57.]
 [  5. 102. 682.  58.  74. 147. 152.  24.  73.  89.]
 [ 10.  17.  37. 749.  26.  23.  10.  16.  55.   4.]
 [ 30.  31.  21.  14. 626.  20.  32. 128.   7.  66.]
 [ 58.  23.  15.  34.  15. 694.  17.  14.  37.  23.]
 [  2.  84.  58.  17.  27.  32. 725.  75.  50.  22.]
 [ 74.  11.  34.  20.  27.   8.  20. 540.   7.  29.]
 [ 39.  18.  38.  38.  20.  24.   9. 106. 704.  34.]
 [ 23.  58.  41.  16.  64.   9.   9.  58.  17. 671.]]

Test set: Average loss: 1.0069, Accuracy: 6798/10000 (68%)

<class 'numpy.ndarray'>
[[778.   2.   7.   4.  69.  10.   4.  13.  11.   4.]
 [  3. 700.  56.  34.  45.  30.  22.  32.  37.  50.]
 [  4.  83. 708.  57.  68. 124. 126.  20.  51.  77.]
 [ 10.  16.  41. 804.  18.  18.   9.  12.  55.   6.]
 [ 31.  27.  20.  13. 653.  16.  25.  84.   5.  62.]
 [ 52.  21.  16.  26.  15. 740.  13.  12.  27.  21.]
 [  2.  80.  50.  17.  31.  31. 771.  63. 

<class 'numpy.ndarray'>
[[839.   4.   9.   3.  45.   9.   3.  16.  11.   3.]
 [  4. 795.  18.  10.  33.  18.  13.  18.  28.  21.]
 [  1.  27. 817.  39.  20.  76.  68.  18.  34.  52.]
 [  6.   4.  42. 903.   8.  15.   9.   3.  54.   5.]
 [ 35.  26.  12.   2. 788.  12.  20.  30.   4.  34.]
 [ 31.  14.  20.  15.  12. 817.   7.  10.  11.   5.]
 [  4.  69.  31.   9.  36.  29. 863.  35.  31.  24.]
 [ 42.   7.  13.   4.  17.   2.   6. 809.   5.  17.]
 [ 31.  20.  19.   6.  23.  16.   3.  28. 818.  13.]
 [  7.  34.  19.   9.  18.   6.   8.  33.   4. 826.]]

Test set: Average loss: 0.5606, Accuracy: 8275/10000 (83%)

<class 'numpy.ndarray'>
[[843.   4.   9.   3.  40.   8.   3.  15.   9.   3.]
 [  4. 806.  15.  10.  32.  18.  13.  17.  29.  18.]
 [  1.  24. 827.  38.  17.  76.  66.  19.  33.  50.]
 [  6.   4.  40. 908.   7.  14.   9.   3.  55.   5.]
 [ 33.  26.  11.   2. 804.  10.  18.  28.   3.  32.]
 [ 30.  12.  19.  15.   8. 824.   6.  10.  10.   5.]
 [  5.  64.  31.   7.  37.  28. 869.  35. 

<class 'numpy.ndarray'>
[[825.   5.   6.   7.  55.   9.   3.  22.   9.   2.]
 [  4. 776.  27.  16.  39.  19.  19.  14.  32.  22.]
 [  1.  42. 794.  34.  29.  79.  77.  20.  26.  63.]
 [  7.   4.  42. 884.  12.  11.   6.   3.  59.   4.]
 [ 29.  26.  16.   3. 753.  13.  17.  27.   8.  41.]
 [ 44.  17.  22.  23.  12. 810.  12.  11.  16.  12.]
 [  3.  67.  40.   9.  27.  30. 846.  43.  33.  31.]
 [ 47.   5.  15.   6.  22.   2.  11. 785.   5.  17.]
 [ 33.  25.  21.   7.  19.  19.   2.  34. 805.  15.]
 [  7.  33.  17.  11.  32.   8.   7.  41.   7. 793.]]

Test set: Average loss: 0.6201, Accuracy: 8071/10000 (81%)

<class 'numpy.ndarray'>
[[828.   5.   5.   7.  48.  10.   3.  22.   9.   2.]
 [  5. 789.  21.  14.  37.  18.  18.  16.  31.  19.]
 [  1.  38. 805.  33.  27.  78.  71.  23.  25.  55.]
 [  7.   4.  43. 890.  12.  11.   6.   3.  60.   4.]
 [ 35.  25.  20.   3. 771.  12.  19.  23.   8.  40.]
 [ 41.  15.  23.  23.  11. 814.  12.  11.  14.  10.]
 [  3.  65.  32.   5.  26.  30. 852.  43. 

<class 'numpy.ndarray'>
[[807.   5.   6.   6.  76.   9.   4.  17.   9.   6.]
 [  2. 729.  35.  19.  44.  23.  18.  17.  37.  37.]
 [  3.  61. 763.  57.  50. 109. 117.  22.  37.  83.]
 [  8.  11.  38. 840.  15.  19.   8.   9.  62.   5.]
 [ 29.  30.  26.   6. 692.  18.  21.  59.   3.  53.]
 [ 41.  19.  11.  26.   7. 766.   7.  10.  25.  16.]
 [  4.  82.  43.  10.  32.  29. 795.  58.  42.  27.]
 [ 54.   8.  27.  11.  22.   4.  15. 690.   7.  23.]
 [ 39.  18.  27.  14.  20.  16.   5.  72. 765.  23.]
 [ 13.  37.  24.  11.  42.   7.  10.  46.  13. 727.]]

Test set: Average loss: 0.7653, Accuracy: 7574/10000 (76%)

<class 'numpy.ndarray'>
[[820.   6.   7.   5.  73.   9.   4.  18.   9.   5.]
 [  2. 762.  24.  13.  41.  19.  16.  17.  34.  29.]
 [  2.  47. 792.  49.  37.  98. 102.  22.  30.  73.]
 [  8.   6.  40. 867.  14.  14.   7.   8.  59.   6.]
 [ 27.  26.  18.   5. 719.  15.  17.  48.   3.  46.]
 [ 43.  16.  11.  23.   7. 787.   7.  10.  22.  12.]
 [  4.  79.  42.   8.  35.  28. 823.  51. 

In [14]:
# torch.cuda.get_device_name(0)

# https://stats.stackexchange.com/questions/380996/convolutional-network-how-to-choose-output-channels-number-stride-and-padding
# depth of feature maps https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html?highlight=channels%20conv2d
# https://nextjournal.com/gkoehler/pytorch-mnist
class NetConv(nn.Module):
    # two convolutional layers and one fully connected layer,
    # all using relu, followed by log_softmax
    def __init__(self):
        super(NetConv, self).__init__()
        # INSERT CODE HERE
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 12, 5)
        self.fc1 = nn.Linear(4800, 400)
        self.fc2 = nn.Linear(400, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #print(x.shape)
        
        x = x.view(x.shape[0], 4800)
        
        #print(x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        #print(x.shape)
        probs = F.log_softmax(x)
        return probs # CHANGE CODE HERE
    
net = NetConv().to(device)
mom = 0.9
lr = 0.01
if list(net.parameters()):
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=mom)

    for epoch in range(1, epochs + 1):
        train(args, net, device, train_loader, optimizer, epoch)
        test(args, net, device, test_loader)



<class 'numpy.ndarray'>
[[928.   4.   7.   4.  23.   5.   4.   7.  19.  10.]
 [  2. 892.   2.   2.  13.   9.   9.   8.  14.   4.]
 [  3.  12. 863.  26.   8.  75.  25.  15.   6.  18.]
 [  0.   2.  33. 927.  10.   8.   3.   0.  12.   5.]
 [ 47.  11.  15.   2. 903.   6.   6.  20.  16.  21.]
 [  8.   5.  12.  22.   2. 854.   7.   0.   4.   7.]
 [  1.  52.  36.   7.  20.  32. 941.  25.   8.  15.]
 [  7.   2.   1.   2.   1.   0.   1. 873.   0.   3.]
 [  1.   7.  15.   5.  16.   8.   3.  22. 919.  21.]
 [  3.  13.  16.   3.   4.   3.   1.  30.   2. 896.]]

Test set: Average loss: 0.3362, Accuracy: 8996/10000 (90%)

<class 'numpy.ndarray'>
[[965.   3.  10.   3.  22.   6.   3.   8.  20.   9.]
 [  1. 879.   1.   0.   3.   7.   6.   2.   5.   4.]
 [  2.   3. 853.   9.   3.  75.  14.   1.  17.  12.]
 [  1.   1.  36. 956.  10.   9.   2.   0.   5.   5.]
 [ 21.  14.   7.   1. 923.   9.   5.   3.  14.  14.]
 [  1.   1.   5.   4.   1. 821.   1.   0.   2.   1.]
 [  0.  64.  45.  13.  18.  41. 962.  12. 