In [50]:
#!/usr/bin/env python

import torch
from torch import nn
from torch import optim
from torch import Tensor
from torch.nn import functional as F
import dlc_practical_prologue as prologue

In [None]:
# NOTE(review): this cell's execution count is [None] -- it was never run in the
# saved notebook, and train_data / val_data are not referenced by any visible
# cell (the pair data below comes from prologue.generate_pair_sets). Either
# delete this cell or explain in markdown why the manual MNIST fetch is needed.
#
# Fetch a mirror of the MNIST archive and unpack it into the working directory.
# NOTE(review): the URL has no https:// scheme and no checksum is verified;
# also confirm the tarball's layout matches root='./data/mnist/' used below,
# otherwise download=True presumably makes torchvision fetch the data again.
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

from torchvision.datasets import MNIST
import torchvision.transforms as transforms

# The only preprocessing applied: convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split (train=True).
train_data = MNIST(root = './data/mnist/', train=True, download=True, transform=transform)

# Held-out split (train=False), named val_data here.
val_data = MNIST(root = './data/mnist/', train=False, download=True, transform=transform)

In [66]:
# Data generation via the course prologue. generate_pair_sets(N) returns six
# tensors: inputs, targets and digit classes for the training split, followed
# by the same three for the test split.
N = 10 ** 3
(train_input, train_target, train_classes,
 test_input, test_target, test_classes) = prologue.generate_pair_sets(N)

# train_model fits these with nn.MSELoss, which needs floating-point tensors.
train_input, train_target = train_input.float(), train_target.float()

In [52]:
#print(new_train_target)

In [87]:
#Base functions adapted from the practicals
class Net(nn.Module):
    """Small CNN mapping a 2-channel image to a 2-way softmax.

    ``conv1`` uses ``groups=2``, so each input channel is convolved with its
    own 16 filters. With 14x14 inputs (presumably what
    ``prologue.generate_pair_sets`` yields -- TODO confirm) the 3x3 conv gives
    12x12 maps, the 3x3 / stride-3 max-pool gives 4x4, and
    32 * 4 * 4 = 512 features feed ``fc1``.

    Args:
        nb_hidden: width of the hidden fully connected layer.
    """

    def __init__(self, nb_hidden):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, groups=2)
        self.fc1 = nn.Linear(512, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        """Return class probabilities of shape (batch, 2); each row sums to 1."""
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        # flatten(1) keeps the batch dimension: an unexpected input size now
        # fails in fc1 instead of view(-1, 512) silently regrouping samples
        # into a different batch size.
        x = F.relu(self.fc1(x.flatten(1)))
        # Softmax is applied here because train_model regresses output[:, 1]
        # with MSELoss. Do not pair this output with nn.CrossEntropyLoss,
        # which expects raw logits.
        return F.softmax(self.fc2(x), dim=1)

def train_model(model, train_input, train_target, mini_batch_size, nb_epochs = 500, eta = 1e-2):
    """Train ``model`` in place with plain mini-batch SGD on an MSE loss.

    The loss compares column 1 of the model output (for ``Net``, the softmax
    probability of class 1) with ``train_target``, so ``train_target`` must be
    a floating-point tensor with one value per sample. Samples are visited in
    their stored order (no shuffling). After each epoch the epoch index and
    the sum of that epoch's batch losses are printed.

    Args:
        model: module whose output has at least two columns along dim 1.
        train_input: training inputs, batched along dim 0.
        train_target: float targets for ``output[:, 1]``, batched along dim 0.
        mini_batch_size: samples per SGD step; the final batch may be smaller.
        nb_epochs: number of passes over the training data.
        eta: SGD learning rate.
    """
    criterion = nn.MSELoss()
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # The last batch is shorter when nb_samples is not a multiple of
            # mini_batch_size; narrow() with the full length would raise.
            batch_len = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_len))
            loss = criterion(output[:, 1], train_target.narrow(0, b, batch_len))
            acc_loss = acc_loss + loss.item()
            model.zero_grad()
            loss.backward()

            # Manual SGD step; no_grad so the update is not itself tracked.
            with torch.no_grad():
                for p in model.parameters():
                    # Parameters not used in forward() have no gradient.
                    if p.grad is not None:
                        p -= eta * p.grad

        print(e, acc_loss)
            
def compute_nb_errors(model, input, target, mini_batch_size):
    """Count samples whose predicted class differs from ``target``.

    The predicted class is the index of the largest value along dim 1 of the
    model output. Inputs are processed in mini-batches; the final batch may be
    smaller than ``mini_batch_size``.

    Args:
        model: module mapping a batch of inputs to per-class scores (batch, C).
        input: inputs, batched along dim 0.
        target: one class label per sample (integer or float dtype).
        mini_batch_size: samples per forward pass.

    Returns:
        The number of misclassified samples, as a Python int.
    """
    nb_errors = 0
    nb_samples = input.size(0)

    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for b in range(0, nb_samples, mini_batch_size):
            # Shorter last batch when nb_samples % mini_batch_size != 0.
            batch_len = min(mini_batch_size, nb_samples - b)
            output = model(input.narrow(0, b, batch_len))
            _, predicted_classes = output.max(1)
            mismatches = predicted_classes != target.narrow(0, b, batch_len)
            nb_errors = nb_errors + int(mismatches.sum())

    return nb_errors

mini_batch_size=100

# Seed right before building the model so weight initialisation -- and hence
# the training curve and the reported test error -- is reproducible under
# Restart & Run All. The saved log sits on a ~4.45 loss plateau for roughly the
# first 20 epochs, so the result is likely sensitive to the starting weights.
torch.manual_seed(0)
model = Net(20)
train_model(model, train_input, train_target, mini_batch_size)
nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
print('test error Net {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / test_input.size(0),
                                                      nb_test_errors, test_input.size(0)))

0 4.331114739179611
1 4.449987322092056
2 4.449986636638641
3 4.449985831975937
4 4.449985027313232
5 4.449984043836594
6 4.449982941150665
7 4.449981719255447
8 4.449980169534683
9 4.449978470802307
10 4.44997638463974
11 4.449973911046982
12 4.449970901012421
13 4.449967056512833
14 4.449962168931961
15 4.449955701828003
16 4.44994643330574
17 4.44993269443512
18 4.449910283088684
19 4.449867904186249
20 4.449765503406525
21 4.449350327253342
22 4.454218506813049
23 4.439814388751984
24 4.44235372543335
25 4.4396699368953705
26 4.389044880867004
27 4.4698441326618195
28 4.477373510599136
29 3.7869069427251816
30 3.8381370306015015
31 3.017942175269127
32 3.4553590565919876
33 3.8607137501239777
34 3.8134843856096268
35 3.3726906776428223
36 3.3813152760267258
37 4.238058000802994
38 3.4239528626203537
39 3.043279081583023
40 2.9182365387678146
41 3.0237418562173843
42 2.613054931163788
43 2.7237123548984528
44 2.0827048271894455
45 1.8645974695682526
46 1.8683878630399704
47 1.670266

358 0.08832469381013652
359 0.08831794314755825
360 0.08831158611428691
361 0.08830635723279556
362 0.08829904456797522
363 0.08829251252609538
364 0.0882868600601796
365 0.08828028976859059
366 0.08827170221047709
367 0.08826695168681908
368 0.08826302468514768
369 0.08825573261128739
370 0.08824990315042669
371 0.08824362060113344
372 0.08823770206799963
373 0.08823328083963133
374 0.08822612392395968
375 0.08890027848246973
376 0.08835682300559711
377 0.08826728321582777
378 0.08823360165843042
379 0.08821622200775892
380 0.08820339081285056
381 0.08819535544898827
382 0.08818978950876044
383 0.0881844817995443
384 0.08817775262286887
385 0.08817375725629972
386 0.08816745010699378
387 0.08816371441207593
388 0.08815898711327463
389 0.08815230536129093
390 0.08815022825729102
391 0.08814534106932115
392 0.08813986161840148
393 0.08813434112380492
394 0.08814940076263156
395 0.08852155992644839
396 0.08828545008145738
397 0.08820757256034994
398 0.08816771659621736
399 0.088142912194