In [434]:
import torch
import math

from torch import optim
from torch import Tensor
from torch import nn
from torch.nn import functional as F

import dlc_practical_prologue as prologue

In [552]:
# Fetch the train/test tensors from the course helper module.  1000 is the
# size argument handed to generate_pair_sets (presumably the number of pairs
# per split -- confirm against dlc_practical_prologue).
(
    train_input, train_target, train_classes,
    test_input, test_target, test_classes,
) = prologue.generate_pair_sets(1000)

In [555]:
# Standardise both splits using statistics of the training inputs only, so
# that nothing computed from the test data is used.
mean = train_input.mean()
std = train_input.std()

# Augmented assignment on tensors is in place, like sub_/div_: the existing
# tensors are modified rather than copied.
train_input -= mean
train_input /= std
test_input -= mean
test_input /= std

# CNN Baseline

In [569]:
class Net(nn.Module):
    """Baseline convolutional classifier for two-channel inputs.

    Two (3x3 convolution -> 2x2 max-pool -> ReLU) stages feed one hidden
    fully connected layer and a two-way output layer.  ``forward`` reshapes
    the convolutional features to 256 values per row, which is what the
    second stage yields (64 * 2 * 2) for, e.g., 2 x 14 x 14 inputs.

    Number of parameters when ``nb_hidden = 200``:
        conv1: 32 * (2 * 3 * 3 + 1)  =    608
        conv2: 64 * (32 * 3 * 3 + 1) = 18,496
        fc1:   (256 + 1) * 200       = 51,400
        fc2:   (200 + 1) * 2         =    402
    """

    def __init__(self, nb_hidden):
        """Create the layers; ``nb_hidden`` is the width of the hidden layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        """Return one row of two output scores per input sample."""
        stage1 = F.max_pool2d(self.conv1(x), kernel_size=2)
        stage1 = F.relu(stage1)

        stage2 = F.max_pool2d(self.conv2(stage1), kernel_size=2)
        stage2 = F.relu(stage2)

        # Flatten the feature maps to the 256 inputs expected by fc1.
        flat = stage2.view(-1, 256)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [570]:
def train_model(model, train_input, train_target, mini_batch_size,
                nb_epochs=25, eta=1e-2):
    """Train ``model`` in place with plain SGD on a cross-entropy loss.

    The data is visited in order, ``mini_batch_size`` samples at a time.  The
    last mini-batch of an epoch is allowed to be shorter, so the number of
    samples does not have to be a multiple of ``mini_batch_size``.

    After every epoch one line is printed: the epoch index, the sum of the
    mini-batch losses of that epoch, and the loss of its last mini-batch.

    Args:
        model: module whose parameters are optimised.
        train_input: tensor of inputs, samples along dimension 0.
        train_target: targets accepted by ``nn.CrossEntropyLoss`` for the
            outputs of ``model``, samples along dimension 0.
        mini_batch_size: maximum number of samples per optimisation step.
        nb_epochs: number of passes over the data (default 25).
        eta: SGD learning rate (default 1e-2).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=eta)
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        sum_loss = 0
        # Only stays NaN when there is no data, where the original code
        # failed on an undefined `loss`.
        last_loss = float('nan')
        for b in range(0, nb_samples, mini_batch_size):
            # narrow() raises if the slice runs past the end, so clamp the
            # length of the final mini-batch to the samples that remain.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_size))
            loss = criterion(output, train_target.narrow(0, b, batch_size))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            sum_loss = sum_loss + last_loss
        print(e, sum_loss, last_loss)

def compute_nb_errors(model, input, target, mini_batch_size):
    """Count the samples whose highest-scoring output differs from ``target``.

    The data is processed ``mini_batch_size`` samples at a time; the last
    mini-batch may be shorter, so the number of samples does not have to be a
    multiple of ``mini_batch_size``.

    Args:
        model: callable mapping a batch of inputs to one row of scores per
            sample.
        input: tensor of inputs, samples along dimension 0.
        target: tensor with the expected class index of every sample.
        mini_batch_size: maximum number of samples per forward pass.

    Returns:
        The number of misclassified samples, as a Python ``int``.
    """
    nb_errors = 0
    nb_samples = input.size(0)

    # Counting errors needs no gradients; do not build the autograd graph.
    with torch.no_grad():
        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the final mini-batch so narrow() never runs past the end.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(input.narrow(0, b, batch_size))
            _, predicted_classes = torch.max(output, 1)
            expected_classes = target.narrow(0, b, batch_size)
            # One vectorised comparison instead of a per-sample Python loop.
            nb_errors += int((predicted_classes != expected_classes).sum().item())

    return nb_errors

In [571]:
# Train the baseline network (hidden width 200) and report its test error.
mini_batch_size = 10
model = Net(200)

train_model(model, train_input, train_target, mini_batch_size)
nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
print('test error of Net {:0.2f}% {:d}/{:d}'.format(
    (100 * nb_test_errors) / test_input.size(0),
    nb_test_errors,
    test_input.size(0),
))

0 67.6786521077156 0.6805179119110107
1 64.84696900844574 0.6585527658462524
2 60.8485132753849 0.6247842311859131
3 55.65058037638664 0.6004491448402405
4 51.15196964144707 0.589335560798645
5 47.78808984160423 0.5765329599380493
6 44.7143219858408 0.5606686472892761
7 41.68522794544697 0.542558491230011
8 38.71396865695715 0.5174574851989746
9 35.90190816670656 0.4904783368110657
10 33.2814721390605 0.45714932680130005
11 30.735079683363438 0.4257314205169678
12 28.2323512211442 0.38853901624679565
13 25.69994755089283 0.353310763835907
14 23.130427099764347 0.3111879527568817
15 20.49694600701332 0.2673245072364807
16 17.913814686238766 0.22689025104045868
17 15.476673144847155 0.19381263852119446
18 13.122067553922534 0.15925191342830658
19 10.846822110936046 0.12692096829414368
20 8.666893860325217 0.09930717945098877
21 6.815945664420724 0.08093750476837158
22 5.340742407366633 0.06476687639951706
23 4.265223665162921 0.054555557668209076
24 3.412018136586994 0.044531241059303284

# CNN with secondary loss

In [572]:
class Net(nn.Module):
    """Convolutional classifier used in the secondary-loss section.

    NOTE(review): this definition repeats the architecture of the ``Net``
    class defined in the baseline section and replaces it in the namespace.
    It returns only the two-way output, with no additional output for a
    secondary loss.

    Layers: two (3x3 convolution -> 2x2 max-pool -> ReLU) stages, a hidden
    fully connected layer of width ``nb_hidden`` and a two-way output layer.
    ``forward`` reshapes the convolutional features to 256 values per row.

    Number of parameters when ``nb_hidden = 200``:
        conv1: 32 * (2 * 3 * 3 + 1)  =    608
        conv2: 64 * (32 * 3 * 3 + 1) = 18,496
        fc1:   (256 + 1) * 200       = 51,400
        fc2:   (200 + 1) * 2         =    402
    """

    def __init__(self, nb_hidden):
        """Create the layers; ``nb_hidden`` is the width of the hidden layer."""
        super().__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        """Return one row of two output scores per input sample."""
        # Both convolutional stages share the same conv -> pool -> ReLU shape.
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), kernel_size=2))
        hidden = F.relu(self.fc1(x.view(-1, 256)))
        return self.fc2(hidden)

In [581]:
# Display the per-pair class labels: the recorded output shows two integer
# values per row (presumably the class of each image in the pair -- confirm
# against dlc_practical_prologue).
train_classes

tensor([[9, 6],
        [9, 3],
        [2, 8],
        ...,
        [6, 5],
        [2, 6],
        [5, 2]])

In [579]:
def train_model(model, train_input, train_target, train_classes, mini_batch_size,
                nb_epochs=25, eta=1e-2, lambda_l2=1e-3):
    """Train ``model`` in place with SGD on cross-entropy plus an L2 penalty.

    The loss of a mini-batch is the cross-entropy of the model outputs
    against ``train_target`` plus ``lambda_l2`` times the sum of the squares
    of every model parameter (biases included).

    The data is visited in order, ``mini_batch_size`` samples at a time; the
    last mini-batch of an epoch may be shorter.  After every epoch one line
    is printed: the epoch index, the sum of the mini-batch losses of that
    epoch, and the loss of its last mini-batch.

    NOTE(review): the original cell ended the loss computation with an
    unfinished ``loss +=`` statement (a SyntaxError) next to an otherwise
    unused weight ``lambda_dif = 1e-2``.  The secondary loss was never
    written, so ``train_classes`` is accepted but not used yet.

    Args:
        model: module whose parameters are optimised.
        train_input: tensor of inputs, samples along dimension 0.
        train_target: targets accepted by ``nn.CrossEntropyLoss`` for the
            outputs of ``model``, samples along dimension 0.
        train_classes: per-sample class labels intended for the secondary
            loss; currently unused.
        mini_batch_size: maximum number of samples per optimisation step.
        nb_epochs: number of passes over the data (default 25).
        eta: SGD learning rate (default 1e-2).
        lambda_l2: weight of the L2 penalty (default 1e-3).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=eta)
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        sum_loss = 0
        # Only stays NaN when there is no data at all.
        last_loss = float('nan')
        for b in range(0, nb_samples, mini_batch_size):
            # narrow() raises if the slice runs past the end, so clamp the
            # length of the final mini-batch to the samples that remain.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_size))

            loss = criterion(output, train_target.narrow(0, b, batch_size))
            for p in model.parameters():
                loss = loss + lambda_l2 * p.pow(2).sum()
            # TODO: add the secondary loss term here (the original draft
            # reserved the weight lambda_dif = 1e-2 for it).  It needs a
            # model that also produces outputs comparable to train_classes.

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            sum_loss = sum_loss + last_loss

        print(e, sum_loss, last_loss)

def compute_nb_errors(model, input, target, mini_batch_size):
    """Count the samples whose highest-scoring output differs from ``target``.

    NOTE(review): a function with this name is already defined in the
    baseline section of the notebook; this cell redefines it.

    The data is processed ``mini_batch_size`` samples at a time; the last
    mini-batch may be shorter, so the number of samples does not have to be a
    multiple of ``mini_batch_size``.

    Args:
        model: callable mapping a batch of inputs to one row of scores per
            sample.
        input: tensor of inputs, samples along dimension 0.
        target: tensor with the expected class index of every sample.
        mini_batch_size: maximum number of samples per forward pass.

    Returns:
        The number of misclassified samples, as a Python ``int``.
    """
    nb_errors = 0
    nb_samples = input.size(0)

    # Counting errors needs no gradients; do not build the autograd graph.
    with torch.no_grad():
        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the final mini-batch so narrow() never runs past the end.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(input.narrow(0, b, batch_size))
            _, predicted_classes = torch.max(output, 1)
            expected_classes = target.narrow(0, b, batch_size)
            # One vectorised comparison instead of a per-sample Python loop.
            nb_errors += int((predicted_classes != expected_classes).sum().item())

    return nb_errors

In [580]:
# Train the network of this section (hidden width 200) and report its test
# error.
model = Net(200)
mini_batch_size = 10

# train_model in this section takes train_classes before mini_batch_size;
# the previous four-argument call no longer matched that signature.
train_model(model, train_input, train_target, train_classes, mini_batch_size)
nb_test_errors = compute_nb_errors(model, test_input, test_target, mini_batch_size)
print('test error of Net {:0.2f}% {:d}/{:d}'\
      .format((100 * nb_test_errors) / test_input.size(0), nb_test_errors, test_input.size(0)))

0 77.71205252408981 0.7981922626495361
1 75.2626336812973 0.7867611050605774
2 72.00800466537476 0.7636763453483582
3 67.23345014452934 0.739741861820221
4 62.08496016263962 0.7242040038108826
5 57.79992687702179 0.7162256836891174
6 54.226865351200104 0.7041693329811096
7 51.12853318452835 0.6902925372123718
8 48.222820833325386 0.678102433681488
9 45.53885594010353 0.6560991406440735
10 42.98359262943268 0.6288289427757263
11 40.37042357027531 0.5963085889816284
12 37.89741112291813 0.5597342848777771
13 35.393269032239914 0.5280466079711914
14 33.00294107198715 0.4892234802246094
15 30.49362015724182 0.44699838757514954
16 28.146579667925835 0.40244346857070923
17 25.764848053455353 0.35953861474990845
18 23.5286066532135 0.3182789981365204
19 21.309089943766594 0.2768385708332062
20 19.436357460916042 0.23924520611763
21 17.692589186131954 0.20404361188411713
22 16.371249094605446 0.1814272105693817
23 15.337117664515972 0.1671266108751297
24 14.527177058160305 0.15548457205295563
