In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import dlc_practical_prologue as prologue

# Fetch the train/test tensors from the course helper module. The keyword
# flags request one-hot labels, normalized inputs, and non-flattened samples
# (the convolutional models below consume the non-flattened form).
train_input, train_target, test_input, test_target = prologue.load_data(
    one_hot_labels = True,
    normalize = True,
    flatten = False,
)

######################################################################

class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages and two linear layers.

    Args:
        nb_hidden: number of units in the hidden fully connected layer.

    The flattening step hard-codes 256 features per sample. For a
    single-channel 28x28 input this is 64 channels x 2 x 2:
    conv1 (5x5) -> 24x24, 3x3 max-pool -> 8x8, conv2 (5x5) -> 4x4,
    2x2 max-pool -> 2x2.
    """

    def __init__(self, nb_hidden):
        super().__init__()
        # Convolutional feature extractor (1 -> 32 -> 64 channels).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        # Classifier head producing 10 output scores.
        self.fc1 = nn.Linear(in_features=256, out_features=nb_hidden)
        self.fc2 = nn.Linear(in_features=nb_hidden, out_features=10)

    def forward(self, x):
        """Return the 10 raw output scores for each sample in `x`."""
        # Stage 1: convolution, non-overlapping 3x3 max-pool, then ReLU.
        stage1 = self.conv1(x)
        stage1 = F.max_pool2d(stage1, kernel_size=3, stride=3)
        stage1 = F.relu(stage1)

        # Stage 2: convolution, non-overlapping 2x2 max-pool, then ReLU.
        stage2 = self.conv2(stage1)
        stage2 = F.max_pool2d(stage2, kernel_size=2, stride=2)
        stage2 = F.relu(stage2)

        # Flatten to rows of 256 features for the fully connected layers.
        flat = stage2.view(-1, 256)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

######################################################################

class Net2(nn.Module):
    """Three-convolution classifier followed by two linear layers.

    Args:
        nb_hidden: number of units in the hidden fully connected layer.
            Defaults to 200, the value that was previously hard-coded, so
            `Net2()` builds the same architecture as before.

    The flattening step expects 9 * 64 features per sample. For a
    single-channel 28x28 input this is 64 channels x 3 x 3:
    conv1 (5x5) -> 24x24, 2x2 max-pool -> 12x12, conv2 (5x5) -> 8x8,
    2x2 max-pool -> 4x4, conv3 (2x2) -> 3x3.
    """

    def __init__(self, nb_hidden = 200):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(9 * 64, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)

    def forward(self, x):
        """Return the 10 raw output scores for each sample in `x`."""
        # Two conv + 2x2 max-pool + ReLU stages, then a conv + ReLU without pooling.
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        x = F.relu(self.conv3(x))
        # Flatten to rows of 9 * 64 features for the fully connected layers.
        x = F.relu(self.fc1(x.view(-1, 9 * 64)))
        x = self.fc2(x)
        return x

######################################################################

def train_model(model, train_input, train_target, mini_batch_size, nb_epochs = 100, eta = 1e-1):
    """Train `model` in place with plain mini-batch SGD on an MSE loss.

    Args:
        model: an `nn.Module`; its parameters are updated in place.
        train_input: tensor of training samples, indexed along dimension 0.
        train_target: tensor of targets aligned with `train_input` along
            dimension 0, with the same shape as the model output.
        mini_batch_size: number of samples per gradient step. The final
            batch of an epoch may be smaller when the sample count is not a
            multiple of this value.
        nb_epochs: number of passes over the training set.
        eta: SGD step size.

    Prints the epoch index and the sum of the per-batch losses after each epoch.
    """
    criterion = nn.MSELoss()

    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, train_input.size(0), mini_batch_size):
            # Slicing clamps at the end of the tensor, so a short trailing
            # batch is handled instead of raising like a fixed-length narrow().
            batch_input = train_input[b:b + mini_batch_size]
            batch_target = train_target[b:b + mini_batch_size]

            output = model(batch_input)
            loss = criterion(output, batch_target)
            acc_loss = acc_loss + loss.item()

            model.zero_grad()
            loss.backward()

            with torch.no_grad():
                for p in model.parameters():
                    # Parameters that did not take part in the forward pass
                    # have no gradient; leave them untouched.
                    if p.grad is not None:
                        p -= eta * p.grad

        print(e, acc_loss)

def compute_nb_errors(model, input, target, mini_batch_size):
    """Count the samples of `input` that `model` misclassifies.

    Args:
        model: callable mapping a batch of samples to per-class scores
            (one row per sample, one column per class).
        input: tensor of samples, indexed along dimension 0.
        target: 2-D tensor aligned with `input`; a prediction counts as an
            error when the target entry at the predicted class is <= 0
            (as with one-hot targets).
        mini_batch_size: number of samples evaluated per forward pass. The
            final batch may be smaller when the sample count is not a
            multiple of this value.

    Returns:
        The number of misclassified samples, as a Python int.
    """
    nb_errors = 0

    # Evaluation only: no gradient bookkeeping is needed.
    with torch.no_grad():
        for b in range(0, input.size(0), mini_batch_size):
            # Slicing clamps at the end of the tensor, so a short trailing
            # batch works without indexing past the last sample.
            output = model(input[b:b + mini_batch_size])
            _, predicted_classes = output.max(1)
            batch_target = target[b:b + mini_batch_size]
            # Pick, for every sample, the target entry at its predicted class.
            target_at_prediction = batch_target.gather(1, predicted_classes.unsqueeze(1))
            nb_errors = nb_errors + int((target_at_prediction <= 0).sum().item())

    return nb_errors

######################################################################

# Number of samples per mini-batch, passed to every train_model and
# compute_nb_errors call in the experiments below.
mini_batch_size = 100

######################################################################
# Question 2

# Train ten independently initialized 200-hidden-unit networks and report the
# test error of each one.
for k in range(10):
    model = Net(200)
    train_model(model, train_input, train_target, mini_batch_size)

    nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
    nb_test_samples = test_input.size(0)
    test_error_rate = (100 * nb_test_errors) / nb_test_samples

    print('test error Net {:0.2f}% {:d}/{:d}'.format(test_error_rate,
                                                      nb_test_errors,
                                                      nb_test_samples))

######################################################################
# Question 3

# Train one network per hidden-layer size and report its test error.
for nh in [ 10, 50, 200, 500, 2500 ]:
    model = Net(nh)
    train_model(model, train_input, train_target, mini_batch_size)
    nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
    # str.format only treats braces specially, so a literal percent sign is
    # written as a single '%' ('%%' would be printed verbatim).
    print('test error Net nh={:d} {:0.2f}% {:d}/{:d}'.format(nh,
                                                             (100 * nb_test_errors) / test_input.size(0),
                                                             nb_test_errors, test_input.size(0)))

######################################################################
# Question 4

# Train the three-convolution network once and report its test error.
model = Net2()
train_model(model, train_input, train_target, mini_batch_size)
nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
# str.format only treats braces specially, so a literal percent sign is
# written as a single '%' ('%%' would be printed verbatim).
print('test error Net2 {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                  nb_test_errors, test_input.size(0)))

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples


  Variable._execution_engine.run_backward(


0 0.9033293724060059
1 0.7755392640829086
2 0.6989829391241074
3 0.6369436904788017
4 0.585752759128809
5 0.5477792806923389
6 0.5352977477014065
7 0.492869108915329
8 0.4682561010122299
9 0.45754119753837585
10 0.4197626858949661
11 0.4122621901333332
12 0.3939913995563984
13 0.39873431995511055
14 0.36853643134236336
15 0.3492521718144417
16 0.35211366042494774
17 0.33363276720046997
18 0.3345471788197756
19 0.3199388384819031
20 0.30358470790088177
21 0.30524951219558716
22 0.2839912623167038
23 0.29041992127895355
24 0.2780223246663809
25 0.26280489191412926
26 0.26611946895718575
27 0.25464852899312973
28 0.2536900229752064
29 0.24975788220763206
30 0.23268995992839336
31 0.24464355781674385
32 0.23444974422454834
33 0.21726738661527634
34 0.22033331915736198
35 0.22154290974140167
36 0.21384438313543797
37 0.21591318398714066
38 0.19906130246818066
39 0.19539870414882898
40 0.21349128521978855
41 0.1948958933353424
42 0.19021200016140938
43 0.19509266503155231
44 0.18238064087927

60 0.14387714117765427
61 0.14362910855561495
62 0.13984736613929272
63 0.13491330854594707
64 0.1335650123655796
65 0.13462816458195448
66 0.138306625187397
67 0.13987261150032282
68 0.1350750019773841
69 0.12995135132223368
70 0.12559451255947351
71 0.12394803296774626
72 0.12375781685113907
73 0.12665096390992403
74 0.13132713548839092
75 0.12497649434953928
76 0.12053326610475779
77 0.12093785032629967
78 0.11813299171626568
79 0.1177752623334527
80 0.12055205926299095
81 0.1167949540540576
82 0.1133378529921174
83 0.11075842380523682
84 0.10948862787336111
85 0.11289350874722004
86 0.11457317881286144
87 0.10962224844843149
88 0.10690208990126848
89 0.10536643490195274
90 0.10477024875581264
91 0.10664260666817427
92 0.10748065263032913
93 0.10633579548448324
94 0.10487331077456474
95 0.10157750733196735
96 0.10183803550899029
97 0.10648817103356123
98 0.10200215689837933
99 0.0962270894087851
test error Net 7.00% 70/1000
0 0.9222614914178848
1 0.7986443638801575
2 0.7361811995506

20 0.30163620971143246
21 0.29608009196817875
22 0.301229489967227
23 0.2714862208813429
24 0.2676946949213743
25 0.26236897706985474
26 0.25268949195742607
27 0.26867837086319923
28 0.2516995444893837
29 0.23458069376647472
30 0.23733846843242645
31 0.23469932563602924
32 0.22486749663949013
33 0.22509419918060303
34 0.22011995315551758
35 0.22795865312218666
36 0.21640407852828503
37 0.20980535633862019
38 0.20284109935164452
39 0.1947871558368206
40 0.19250351376831532
41 0.1980342660099268
42 0.19401358906179667
43 0.18319150526076555
44 0.18472572788596153
45 0.19864250533282757
46 0.18238922953605652
47 0.179458262398839
48 0.17944965790957212
49 0.17115240450948477
50 0.16924925334751606
51 0.16706496849656105
52 0.1633601300418377
53 0.16166158020496368
54 0.16163552924990654
55 0.1632257653400302
56 0.1629148069769144
57 0.15476581174880266
58 0.15272643975913525
59 0.15974125918000937
60 0.1567033687606454
61 0.14563098922371864
62 0.14104751963168383
63 0.13886678125709295
6

82 0.42045261338353157
83 0.4155973009765148
84 0.41091326251626015
85 0.4098556935787201
86 0.4113599732518196
87 0.4036650024354458
88 0.3980194181203842
89 0.3977685123682022
90 0.3975721709430218
91 0.393405519425869
92 0.3868793621659279
93 0.3825283944606781
94 0.38804537430405617
95 0.38904760405421257
96 0.377056747674942
97 0.3948100246489048
98 0.38570595905184746
99 0.3724111206829548
test error Net nh=10 21.80%% 218/1000
0 0.8999888747930527
1 0.777551181614399
2 0.7096283733844757
3 0.6560536548495293
4 0.6141289062798023
5 0.582493893802166
6 0.5551295019686222
7 0.5275567024946213
8 0.4992729350924492
9 0.500631183385849
10 0.46767666935920715
11 0.43226010724902153
12 0.4306519664824009
13 0.4139394015073776
14 0.41043708100914955
15 0.38735366985201836
16 0.3710239678621292
17 0.3708147294819355
18 0.3514758925884962
19 0.3414238393306732
20 0.32740948535501957
21 0.32692842558026314
22 0.3167440500110388
23 0.30328054912388325
24 0.2979781739413738
25 0.28948668763041

41 0.14797527063637972
42 0.14937864430248737
43 0.1443672077730298
44 0.14483896270394325
45 0.13786066509783268
46 0.13330982625484467
47 0.1347966082394123
48 0.1327350102365017
49 0.12811705842614174
50 0.13962119538336992
51 0.12909044232219458
52 0.12391479592770338
53 0.13088703993707895
54 0.12241585738956928
55 0.11619672738015652
56 0.11450216826051474
57 0.11437861900776625
58 0.11431473586708307
59 0.11280319839715958
60 0.11115698888897896
61 0.11223777942359447
62 0.12003114353865385
63 0.11262854747474194
64 0.10961450543254614
65 0.1116732032969594
66 0.10731378477066755
67 0.105434600263834
68 0.10707893501967192
69 0.10286740120500326
70 0.0982779897749424
71 0.09769027587026358
72 0.09675494208931923
73 0.09475962817668915
74 0.09609472658485174
75 0.0986628495156765
76 0.09314997214823961
77 0.09110120683908463
78 0.09388773608952761
79 0.09080445999279618
80 0.0893521229736507
81 0.09148143883794546
82 0.09022894501686096
83 0.0889164418913424
84 0.0867196908220648