In [3]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets

import dlc_practical_prologue as prologue
from BaseNet import *
from ConvNet1 import *

########################################################
### FRAMEWORK FOR INPUT AS TWO SINGLE CHANNEL IMAGES ###

# In this framework, the network is first trained to recognize the digit in each image of every pair, using the provided class labels and a CrossEntropyLoss to maximize the response of the correct digit. Once the network can predict the digits, the two predicted digits of each pair are compared to decide whether the images form a pair or not.

# Network configuration: number of digit classes handed to the model
# constructor, and number of channels of each individual input image.
nb_input_channels = 1
nb_classes = 10

# Training hyper-parameters; train_model_1C uses them as its default arguments.
eta = 1e-3             # learning rate given to the SGD optimizer
nb_epochs = 300        # number of passes over the training set
mini_batch_size = 100  # number of samples per optimizer step

def prep_input_vanilla(train_input):
    """Reshape a tensor of 14x14 images into single-channel samples.

    Every 14x14 image contained in ``train_input`` becomes its own sample with
    exactly one channel, so the result has shape (N, 1, 14, 14) where N is
    inferred from the total number of elements. The result is a view on
    ``train_input`` (no data is copied).
    """
    return train_input.view(-1, 1, 14, 14)

def prep_target_vanilla(train_classes, nb_classes=10):
    """One-hot encode integer class labels.

    Args:
        train_classes: integer tensor of class labels, of any shape; it is
            flattened before encoding.
        nb_classes: width of the one-hot encoding (defaults to 10 digits).

    Returns:
        A float tensor of shape (train_classes.numel(), nb_classes) with a
        single 1 per row at the column given by the label. An empty input
        yields a (0, nb_classes) tensor. The result is created on the same
        device as ``train_classes``.

    Raises:
        IndexError: if a label is outside the valid index range.
    """
    labels = train_classes.flatten().long()
    # Row k of the identity matrix is the one-hot vector of class k, so
    # indexing it with the labels builds all targets in one vectorized step
    # (no per-sample concatenation, no sentinel row to strip afterwards).
    return torch.eye(nb_classes, device=labels.device)[labels]


# Computes the number of errors made by the model passed as parameter when its predictions are compared with the target classes; this corresponds to the task of our network, i.e. predicting the right digit.
def compute_nb_errors(model, input_, target, mini_batch_size):
    """Count the samples whose predicted class differs from the target class.

    Args:
        model: callable mapping a batch of inputs to per-class scores of
            shape (batch, nb_classes).
        input_: tensor of samples, batched along dimension 0.
        target: tensor aligned with ``input_`` along dimension 0 whose rows
            encode the target class as their argmax (e.g. one-hot rows).
        mini_batch_size: maximum number of samples per forward pass.

    Returns:
        The number of misclassified samples, as a Python int.
    """
    nb_errors = 0
    nb_samples = input_.size(0)
    # Pure evaluation: no autograd graph is needed for counting errors.
    with torch.no_grad():
        for b in range(0, nb_samples, mini_batch_size):
            # The last batch may be smaller when nb_samples is not a multiple
            # of mini_batch_size; narrow() would raise on an over-long slice.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(input_.narrow(0, b, batch_size))
            _, target_classes = torch.max(target.narrow(0, b, batch_size), 1)
            _, predicted_classes = torch.max(output, 1)
            nb_errors += (predicted_classes != target_classes).sum().item()
    return nb_errors

# input = 1000x2x14x14: takes a 2-channel input tensor, predicts a digit for the image in each channel using the model passed as parameter, then compares the two predictions of each pair
# output = 1000: 1 if the 2 images are a pair (same predicted digit), 0 otherwise
def compare_pairs(model, input_):
    """Tell, for each two-channel sample, whether both channels get the same predicted class.

    Each channel of ``input_`` is extracted, reshaped to (-1, 1, 14, 14) and
    passed through ``model``; the predicted class is the index of the largest
    score along dimension 1. The returned tensor holds the element-wise
    equality of the two predictions.
    """
    def predict_digits(channel):
        # Single-channel batch built from the requested channel of every sample.
        images = input_[:, channel, :, :].view(-1, 1, 14, 14)
        scores = model(images)
        return torch.max(scores, 1)[1]

    first_digits = predict_digits(0)
    second_digits = predict_digits(1)
    return torch.eq(first_digits, second_digits)


def train_model_1C(model, train_input, train_classes, mini_batch_size=mini_batch_size,
                   criterion=torch.nn.CrossEntropyLoss(), nb_epochs=nb_epochs, eta=eta):
    """Train ``model`` in place to classify single-channel 14x14 digit images.

    The inputs are reshaped with ``prep_input_vanilla`` and the class labels
    are one-hot encoded with ``prep_target_vanilla``; the model is then
    optimized with SGD (momentum 0.95) on mini-batches.

    Args:
        model: network whose parameters are optimized; it is called on
            batches of shape (batch, 1, 14, 14).
        train_input: tensor of 14x14 images (any leading layout that
            ``prep_input_vanilla`` can reshape).
        train_classes: integer class labels, one per image.
        mini_batch_size: maximum number of samples per optimizer step.
        criterion: loss taking (scores, class indices).
        nb_epochs: number of passes over the training set.
        eta: learning rate of the SGD optimizer.

    Returns:
        None. Every 10th epoch, the loss of that epoch's last mini-batch is
        printed.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=eta, momentum=0.95)

    train_input = prep_input_vanilla(train_input)
    train_target = prep_target_vanilla(train_classes)  # the targets are the class labels
    nb_samples = train_input.size(0)

    for e in range(0, nb_epochs):
        for b in range(0, nb_samples, mini_batch_size):
            # The last batch may be smaller when nb_samples is not a multiple
            # of mini_batch_size; narrow() would raise on an over-long slice.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_size))
            target = train_target.narrow(0, b, batch_size)
            # nn.CrossEntropyLoss expects a class index as the target of each
            # sample, so the one-hot rows are converted back with an argmax.
            loss = criterion(output, target.max(1)[1])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if e % 10 == 9:
            print('epoch: %d, loss: %.5f' %
                  (e + 1, loss.item()))


In [4]:
# Fetch the train/test pair sets from the prologue helper (nb=1000).
(train_input, train_target, train_classes,
 test_input, test_target, test_classes) = prologue.generate_pair_sets(nb=1000)

# Digit classifier to train; the commented line swaps in the alternative network.
model = BaseNet(nb_classes)
#model = ConvNet1(nb_classes)

# Train on the per-image digit labels with the default hyper-parameters.
train_model_1C(model, train_input, train_classes)

epoch: 10, loss: 0.04437
epoch: 20, loss: 0.00349
epoch: 30, loss: 0.00189
epoch: 40, loss: 0.00135
epoch: 50, loss: 0.00105
epoch: 60, loss: 0.00086
epoch: 70, loss: 0.00073
epoch: 80, loss: 0.00063
epoch: 90, loss: 0.00055
epoch: 100, loss: 0.00049
epoch: 110, loss: 0.00044
epoch: 120, loss: 0.00040
epoch: 130, loss: 0.00037
epoch: 140, loss: 0.00034
epoch: 150, loss: 0.00032
epoch: 160, loss: 0.00030
epoch: 170, loss: 0.00028
epoch: 180, loss: 0.00026
epoch: 190, loss: 0.00025
epoch: 200, loss: 0.00023
epoch: 210, loss: 0.00022
epoch: 220, loss: 0.00021
epoch: 230, loss: 0.00020
epoch: 240, loss: 0.00019
epoch: 250, loss: 0.00018
epoch: 260, loss: 0.00017
epoch: 270, loss: 0.00017
epoch: 280, loss: 0.00016
epoch: 290, loss: 0.00015
epoch: 300, loss: 0.00015
