In [1]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets

import dlc_practical_prologue as prologue
from BaseNet import *
from ConvNet1 import *

########################################################
### FRAMEWORK FOR INPUT AS TWO SINGLE CHANNEL IMAGES ###

# In this framework, the network is first trained to recognize the digits of each image from each pair and with the help of the class labels. To do so, we use the class labels provided and use a CrossEntropyLoss to maximize the response of the correct digit. Once the network can predict the digits, we compare the digits and define if they are a pair or not

nb_classes = 10
nb_input_channels = 1

mini_batch_size = 100
nb_epochs = 300
eta = 0.001

def prep_input_vanilla(train_input):
    new_train_input = train_input.view(-1,1,14,14)
    return new_train_input

def prep_target_vanilla(train_classes):    
    train_classes = train_classes.flatten()   
    stack = torch.tensor([0.0]*10)    
    for i in train_classes:
        current = torch.tensor([0.0]*10)
        current[i.item()] = 1
        stack = torch.cat([stack.view(-1,10),current.view(-1,10)])       
    return stack[1:]


# Computes the number of errors when predicting with the model passed as parameter and comparing results with the target classes; this corresponds to the task of our network, ie predicting the right digit.
def compute_nb_errors(model, input_, target, mini_batch_size=mini_batch_size):
    nb_errors = 0  
    for b in range(0, input_.size(0), mini_batch_size):
        output = model(input_.narrow(0, b, mini_batch_size))
        _, target_classes = torch.max(target.narrow(0, b, mini_batch_size), 1)
        _, predicted_classes = torch.max(output, 1)
        nb_errors += (predicted_classes != target_classes).sum().item()      
    return nb_errors

# input = 1000x2x14x14: Gets a 2 channels input tensor, predicts a digit for each image of each channel using the model passed as parameter, then compares the pairs channel-wise
# output = 1000: 1 if the 2 images are a pair, 0 otherwise
def compare_pairs(model, input_):
    tensor_a = torch.max(model(input_[:,0,:,:].view(-1,1,14,14)), 1)[1]
    tensor_b = torch.max(model(input_[:,1,:,:].view(-1,1,14,14)),1)[1]
    return torch.eq(tensor_a, tensor_b)


# Network is a classifier: it treats the pairs as 2 images of 1 channel and is trained to predict the digit from the image
def train_model_1C(model, train_input, train_classes, mini_batch_size=mini_batch_size,
                   criterion=torch.nn.CrossEntropyLoss(), nb_epochs=nb_epochs, eta=eta):
    
    optimizer = torch.optim.SGD(model.parameters(), lr=eta, momentum=0.95)
    
    train_input = prep_input_vanilla(train_input)
    train_target = prep_target_vanilla(train_classes) # the target are the class labels
    
    for e in range(0, nb_epochs):
        for b in range(0, train_input.size(0), mini_batch_size):
            output = model(train_input.narrow(0, b, mini_batch_size))
            target = train_target.narrow(0, b, mini_batch_size)
            # the nn.CrossEntropyLoss expects a class index as the target for each value
            loss = criterion(output,target.max(1)[1])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if (e % 100 == 99):
            print('epoch: %d, loss: %.5f' %
                  (e+1, loss.data.item()))
            
def test_model_1C(model, test_input, test_target, test_classes):
    test_input_vanilla = prep_input_vanilla(test_input)
    test_classes_target = prep_target_vanilla(test_classes)
    
    # Number of digits incorrectly identified
    nb_errors_digits = compute_nb_errors(model, test_input_vanilla, test_classes_target)
    print("nb_errors_digits = ", nb_errors_digits)
    
    # Number of pairs incorrectly identified
    test_output_pairs = compare_pairs(model, test_input).type(torch.LongTensor)
    nb_errors_pairs = torch.abs(test_output_pairs-test_target).sum().item()
    print("nb_errors_pairs = ", nb_errors_pairs)
    

In [2]:
train_input, train_target, train_classes, test_input, test_target, test_classes = \
    prologue.generate_pair_sets(nb=1000)

#model = BaseNet(nb_classes)
model = ConvNet1(nb_classes)

train_model_1C(model, train_input, train_classes)


epoch: 100, loss: 0.00136
epoch: 200, loss: 0.00045
epoch: 300, loss: 0.00025


In [3]:
test_model_1C(model, test_input, test_target, test_classes)

nb_errors_digits =  172
nb_errors_pairs =  448


In [4]:
# Computes the number of errors between target and classes
def compute_nb_errors_data(set_, target, classes): 
    is_pair = torch.eq(classes[:,0], classes[:,1]).type(torch.LongTensor)
    nb_errors = torch.abs(target-is_pair).sum().item()
    print('There are %d errors in the %s dataset' % (nb_errors, set_))
    
compute_nb_errors_data('training', train_target, train_classes)
compute_nb_errors_data('test', test_target, test_classes)

There are 462 errors in the training dataset
There are 429 errors in the test dataset
