## Deep Learning Project 1 

Comparing digits

### TODO next:

- Are we already using weight sharing? **Yes, we are.**
- Can we get better performance?
- Add dropout layers and similar regularization
- Add auxiliary losses (also return the CNNs' digit predictions and train them with their own loss function, similar to the comparison net) **Done**
- Benchmark **Function done**
- Write report **Started**

In [1]:
import torch
import math
import dlc_practical_prologue as prologue
from torch import optim
from torch import Tensor
from torch import nn
from torch.nn import functional as F

In [2]:
# Generate the train and test pair sets; N is forwarded to the prologue's generator.
N = 1000
(train_input, train_target, train_classes,
 test_input, test_target, test_classes) = prologue.generate_pair_sets(N)

In [21]:
# This model performs each digit classification with 2 different CNNs (so no weight sharing)        
class No_Weight_Sharing_Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Layers that handle digit classification with first CNN
        self.conv1_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1_1 = nn.Linear(256, 200)
        self.fc2_1 = nn.Linear(200, 10)
        
        # Layers that handle digit classification with second CNN
        self.conv1_2 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2_2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1_2 = nn.Linear(256, 200)
        self.fc2_2 = nn.Linear(200, 10)
        
        # Layers that handle comparisson 
        self.fc3 = nn.Linear(20, 300)
        self.fc4 = nn.Linear(300, 300)
        self.fc5 = nn.Linear(300, 2)
        
    def cnn1(self, x):
        x = F.relu(F.max_pool2d(self.conv1_1(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2_1(x), kernel_size=2))
        x = F.relu(self.fc1_1(x.view(-1, 256)))
        x = self.fc2_1(x)
        return x
    
    def cnn2(self, x):
        x = F.relu(F.max_pool2d(self.conv1_2(x), kernel_size=2))
        x = F.relu(F.max_pool2d(self.conv2_2(x), kernel_size=2))
        x = F.relu(self.fc1_2(x.view(-1, 256)))
        x = self.fc2_2(x)
        return x
    
    def mlp(self, x):
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x
    
    def forward(self, x):
        s = x.shape
        input_1 = x[:,0,:,:].reshape([s[0],1,s[2],s[3]])
        input_2 = x[:,1,:,:].reshape([s[0],1,s[2],s[3]])
        
        output_1 = self.cnn1(input_1)
        output_2 = self.cnn2(input_2)
        
        concatenated = torch.cat((output_1, output_2), 1)
        
        comparison = self.mlp(concatenated)
        return comparison   

In [53]:
# Model Definition 

        
class Simple_Net(nn.Module):
    """Digit-pair comparison network with a single, shared digit CNN.

    The same CNN (``cnn``) scores both input channels, so its weights are
    shared between the two digits. The two 10-way score vectors are
    concatenated and mapped to 2 comparison scores by ``mlp``.
    """

    def __init__(self):
        super().__init__()

        # Shared digit classifier: two conv stages, then two linear layers.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

        # Comparison head: 2 x 10 digit scores -> 2 comparison scores.
        self.fc3 = nn.Linear(20, 300)
        self.fc4 = nn.Linear(300, 300)
        self.fc5 = nn.Linear(300, 2)

    def cnn(self, x):
        """Map a (B, 1, H, W) batch of single digits to (B, 10) class scores."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), kernel_size=2))
        # 256 = 64 channels * 2 * 2, the feature size left by a 14x14 input
        # (assumed input size).
        hidden = F.relu(self.fc1(x.view(-1, 256)))
        return self.fc2(hidden)

    def mlp(self, x):
        """Map (B, 20) concatenated digit scores to (B, 2) comparison scores."""
        for layer in (self.fc3, self.fc4):
            x = F.relu(layer(x))
        return self.fc5(x)

    def forward(self, x):
        """Score both digits of a (B, 2, H, W) batch with the shared CNN and compare them."""
        digit_scores = [self.cnn(x[:, k:k + 1]) for k in range(2)]
        return self.mlp(torch.cat(digit_scores, dim=1))

In [56]:
def train_model_simple_net(model, train_input, train_target, mini_batch_size, nb_epochs=100, use_optimizer=None, _print=False, eta=1e-3):
    """Train a comparison model with an MSE loss against one-hot targets.

    Parameters
    ----------
    model : module whose output for a mini-batch has the shape that each
        target slice is reshaped to.
    train_input : training samples, batched along dim 0.
    train_target : one-hot targets, sliced along dim 0 like ``train_input``.
    mini_batch_size : samples per gradient step; the last batch may be smaller.
    nb_epochs : number of passes over the training set.
    use_optimizer : ``None`` for a manual gradient step, or ``"sgd"`` / ``"adam"``.
    _print : if True, print the epoch index and summed mini-batch loss per epoch.
    eta : learning rate (defaults to the previously hard-coded 1e-3).

    Raises
    ------
    ValueError
        If ``use_optimizer`` is not ``None``, ``"sgd"`` or ``"adam"``.
    """
    criterion = nn.MSELoss()
    optimizers = {"sgd": optim.SGD, "adam": optim.Adam}
    # Fail fast: an unknown name used to leave `optimizer` unbound until the first step.
    if use_optimizer is not None and use_optimizer not in optimizers:
        raise ValueError(f"Unknown optimizer {use_optimizer!r}; expected None, 'sgd' or 'adam'")
    optimizer = None if use_optimizer is None else optimizers[use_optimizer](model.parameters(), lr=eta)

    nb_samples = train_input.size(0)
    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the last batch so sizes not divisible by mini_batch_size work.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_size))
            target = train_target.narrow(0, b, batch_size).reshape(output.shape).float()

            loss = criterion(output, target)
            acc_loss += loss.item()

            model.zero_grad()
            loss.backward()

            if optimizer is not None:
                optimizer.step()
            else:
                # Plain gradient descent step without an optimizer object.
                with torch.no_grad():
                    for p in model.parameters():
                        if p.grad is not None:
                            p -= eta * p.grad
        if _print:
            print(e, acc_loss)
            
def train_model_simple_net_2(model, train_input, train_target, mini_batch_size, nb_epochs=100, use_optimizer=None, _print=False, eta=1e-3):
    """Train a comparison model with a cross-entropy loss on class-index targets.

    Parameters
    ----------
    model : module returning per-class scores for a mini-batch.
    train_input : training samples, batched along dim 0.
    train_target : class-index targets (cast to ``long``), sliced like ``train_input``.
    mini_batch_size : samples per gradient step; the last batch may be smaller.
    nb_epochs : number of passes over the training set.
    use_optimizer : ``None`` for a manual gradient step, or ``"sgd"`` / ``"adam"``.
    _print : if True, print the epoch index and summed mini-batch loss per epoch.
    eta : learning rate (defaults to the previously hard-coded 1e-3).

    Raises
    ------
    ValueError
        If ``use_optimizer`` is not ``None``, ``"sgd"`` or ``"adam"``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizers = {"sgd": optim.SGD, "adam": optim.Adam}
    # Fail fast: an unknown name used to leave `optimizer` unbound until the first step.
    if use_optimizer is not None and use_optimizer not in optimizers:
        raise ValueError(f"Unknown optimizer {use_optimizer!r}; expected None, 'sgd' or 'adam'")
    optimizer = None if use_optimizer is None else optimizers[use_optimizer](model.parameters(), lr=eta)

    nb_samples = train_input.size(0)
    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the last batch so sizes not divisible by mini_batch_size work.
            batch_size = min(mini_batch_size, nb_samples - b)
            output = model(train_input.narrow(0, b, batch_size))
            target = train_target.narrow(0, b, batch_size).long()

            loss = criterion(output, target)
            acc_loss += loss.item()

            model.zero_grad()
            loss.backward()

            if optimizer is not None:
                optimizer.step()
            else:
                # Plain gradient descent step without an optimizer object.
                with torch.no_grad():
                    for p in model.parameters():
                        if p.grad is not None:
                            p -= eta * p.grad
        if _print:
            print(e, acc_loss)
        
def compute_nb_errors_simple_net(model, input, target, mini_batch_size):
    """Count samples whose predicted comparison class is wrong.

    The model is run in eval mode (so layers such as dropout are disabled)
    and without gradient tracking; its previous train/eval mode is restored
    afterwards.

    Parameters
    ----------
    model : module returning per-class scores for a mini-batch.
    input : samples, batched along dim 0.
    target : one-hot targets; a prediction counts as an error when the
        target entry at the predicted class is <= 0.
    mini_batch_size : evaluation batch size; the last batch may be smaller.

    Returns
    -------
    int
        Number of misclassified samples.
    """
    was_training = model.training
    model.eval()
    nb_errors = 0
    nb_samples = input.size(0)
    try:
        with torch.no_grad():
            for b in range(0, nb_samples, mini_batch_size):
                # Clamp the last batch so sizes not divisible by mini_batch_size work.
                batch_size = min(mini_batch_size, nb_samples - b)
                output = model(input.narrow(0, b, batch_size))
                _, predicted_classes = output.max(1)
                rows = torch.arange(b, b + batch_size, device=target.device)
                hits = target[rows, predicted_classes.to(target.device)]
                nb_errors += int((hits <= 0).sum().item())
    finally:
        model.train(was_training)

    return nb_errors

In [57]:
# Train the weight-sharing network with the cross-entropy trainer, which takes
# class-index targets directly (the unused one-hot conversion was dropped).
model_total = Simple_Net()
train_model_simple_net_2(model_total, train_input, train_target, mini_batch_size=250, nb_epochs=25, use_optimizer="adam")

In [58]:

# Evaluate the trained weight-sharing network against one-hot test targets.
test_target_total = prologue.convert_to_one_hot_labels(test_input, test_target)
nb_test_errors = compute_nb_errors_simple_net(
    model_total, test_input, test_target_total, mini_batch_size=250
)
print(f"test error Net {(100 * nb_test_errors) / test_input.size(0):0.2f}% "
      f"{nb_test_errors:d}/{test_input.size(0):d}")

test error Net 14.90% 149/1000


In [90]:
# Print the model's predicted relation for the first 10 test pairs next to their true digits.
# NOTE(review): class 1 is printed as '<=' because the prologue's target appears to be
# 1 when the first digit is <= the second (equal digits belong to class 1) -- confirm
# against prologue.generate_pair_sets.
with torch.no_grad():
    for i in range(10):
        first_label, second_label = test_classes[i][0], test_classes[i][1]
        # Add a batch dimension: (2, H, W) -> (1, 2, H, W).
        output = model_total(test_input[i].unsqueeze(0))
        _, predicted_classes = output.max(1)
        print(f"Predicted : {first_label} {'>' if predicted_classes.item() == 0 else '<='} {second_label}")

Predicted : 4 < 8
Predicted : 3 < 7
Predicted : 3 < 7
Predicted : 1 < 4
Predicted : 4 < 5
Predicted : 4 < 6
Predicted : 3 < 7
Predicted : 4 > 1
Predicted : 5 > 1
Predicted : 6 < 5


In [91]:
# Benchmark of the basic weight-sharing network (MSE loss) with the Adam optimizer.
nb_trials = 10
N = 1000
performances = []
for trial in range(nb_trials):

    # Generate fresh data for every trial.
    train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(N)
    train_target_one_hot = prologue.convert_to_one_hot_labels(train_input, train_target)
    test_target_total = prologue.convert_to_one_hot_labels(test_input, test_target)

    # Define the model.
    model_total = Simple_Net()

    # Train the model.
    train_model_simple_net(model_total, train_input, train_target_one_hot, mini_batch_size=250,
                           nb_epochs=25, use_optimizer="adam")

    # Evaluate performance.
    nb_test_errors = compute_nb_errors_simple_net(model_total, test_input, test_target_total, mini_batch_size=250)
    print('test error Net {:d} {:0.2f}% {:d}/{:d}'.format(trial, (100 * nb_test_errors) / test_input.size(0),
                                                          nb_test_errors, test_input.size(0)))
    performances.append(nb_test_errors)

# `performances` holds error counts, so this is the mean test error rate (not a precision).
# Divide by the actual test-set size rather than assuming it equals N.
mean_perf = 100 * sum(performances) / (test_input.size(0) * nb_trials)
print(f"Average error rate of this architecture {mean_perf}%")

0 4.6856465339660645
1 1.4614585041999817
2 1.0191197395324707
3 0.9204623401165009
4 0.8185340911149979
5 0.7011362165212631
6 0.6271707266569138
7 0.5884476453065872
8 0.5331521481275558
9 0.4940749928355217
10 0.4571880176663399
11 0.421296089887619
12 0.39193376153707504
13 0.35857199132442474
14 0.327596090734005
15 0.3004740923643112
16 0.2747432813048363
17 0.2485593818128109
18 0.22220168635249138


KeyboardInterrupt: 

In [67]:
class Auxiliary_Loss_Net(nn.Module):
    """Weight-sharing comparison network that also exposes its digit scores.

    Identical in structure to a shared-CNN comparison net, but ``forward``
    returns the two per-digit score vectors alongside the comparison scores
    so that auxiliary digit-classification losses can be applied to them.
    """

    def __init__(self):
        super().__init__()

        # Shared digit classifier: two conv stages, then two linear layers.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

        # Comparison head: 2 x 10 digit scores -> 2 comparison scores.
        self.fc3 = nn.Linear(20, 300)
        self.fc4 = nn.Linear(300, 300)
        self.fc5 = nn.Linear(300, 2)

    def cnn(self, x):
        """Map a (B, 1, H, W) batch of single digits to (B, 10) class scores."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), kernel_size=2))
        # 256 = 64 channels * 2 * 2, the feature size left by a 14x14 input
        # (assumed input size).
        hidden = F.relu(self.fc1(x.view(-1, 256)))
        return self.fc2(hidden)

    def mlp(self, x):
        """Map (B, 20) concatenated digit scores to (B, 2) comparison scores."""
        for layer in (self.fc3, self.fc4):
            x = F.relu(layer(x))
        return self.fc5(x)

    def forward(self, x):
        """Return ``(first_digit_scores, second_digit_scores, comparison_scores)``.

        ``x`` is indexed as (B, 2, H, W): channel 0 is the first digit and
        channel 1 the second; both go through the same ``cnn``.
        """
        first_scores, second_scores = (self.cnn(x[:, k:k + 1]) for k in range(2))
        comparison = self.mlp(torch.cat((first_scores, second_scores), dim=1))
        return first_scores, second_scores, comparison
    
def train_model_auxiliary_loss(model, train_input, train_target, train_classes, mini_batch_size, nb_epochs=100, use_optimizer=None, _print=False, eta=1e-3):
    """Train a model returning ``(digit_1, digit_2, comparison)`` with auxiliary losses.

    The total loss per mini-batch is the sum of:

    * cross-entropy of ``digit_1`` against ``train_classes[:, 0]``;
    * cross-entropy of ``digit_2`` against ``train_classes[:, 1]``;
    * MSE of ``comparison`` against the one-hot ``train_target`` slice, reshaped
      to the comparison output's shape and cast to float.

    Parameters
    ----------
    model : module returning a ``(digit_1, digit_2, comparison)`` tuple.
    train_input : training samples, batched along dim 0.
    train_target : one-hot comparison targets, sliced like ``train_input``.
    train_classes : per-pair digit class indices; column 0 and column 1
        are the targets for the first and second digit.
    mini_batch_size : samples per gradient step; the last batch may be smaller.
    nb_epochs : number of passes over the training set.
    use_optimizer : ``None`` for a manual gradient step, or ``"sgd"`` / ``"adam"``.
    _print : if True, print the epoch index and summed loss per epoch.
    eta : learning rate (defaults to the previously hard-coded 1e-3).

    Raises
    ------
    ValueError
        If ``use_optimizer`` is not ``None``, ``"sgd"`` or ``"adam"``.
    """
    criterion_auxilary = nn.CrossEntropyLoss()
    criterion_final = nn.MSELoss()
    optimizers = {"sgd": optim.SGD, "adam": optim.Adam}
    # Fail fast: an unknown name used to leave `optimizer` unbound until the first step.
    if use_optimizer is not None and use_optimizer not in optimizers:
        raise ValueError(f"Unknown optimizer {use_optimizer!r}; expected None, 'sgd' or 'adam'")
    optimizer = None if use_optimizer is None else optimizers[use_optimizer](model.parameters(), lr=eta)

    nb_samples = train_input.size(0)
    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the last batch so sizes not divisible by mini_batch_size work.
            batch_size = min(mini_batch_size, nb_samples - b)
            digit_1, digit_2, comparison = model(train_input.narrow(0, b, batch_size))

            target_comparison = train_target.narrow(0, b, batch_size).reshape(comparison.shape).float()
            batch_classes = train_classes.narrow(0, b, batch_size)

            loss1 = criterion_auxilary(digit_1, batch_classes[:, 0])
            loss2 = criterion_auxilary(digit_2, batch_classes[:, 1])
            loss3 = criterion_final(comparison, target_comparison)
            # One backward on the sum accumulates the same gradients as three
            # separate backward passes, without retaining the graph.
            loss = loss1 + loss2 + loss3
            acc_loss += loss.item()

            model.zero_grad()
            loss.backward()

            if optimizer is not None:
                optimizer.step()
            else:
                # Plain gradient descent step without an optimizer object.
                with torch.no_grad():
                    for p in model.parameters():
                        if p.grad is not None:
                            p -= eta * p.grad
        if _print:
            print(e, acc_loss)
        
def train_model_auxiliary_loss_2(model, train_input, train_target, train_classes, mini_batch_size, nb_epochs=100, use_optimizer=None, _print=False, eta=1e-3):
    """Train a model returning ``(digit_1, digit_2, comparison)`` with cross-entropy losses.

    The total loss per mini-batch is the sum of three cross-entropy terms:
    ``digit_1`` vs ``train_classes[:, 0]``, ``digit_2`` vs
    ``train_classes[:, 1]``, and ``comparison`` vs the class-index
    ``train_target`` slice (cast to ``long``).

    Parameters
    ----------
    model : module returning a ``(digit_1, digit_2, comparison)`` tuple.
    train_input : training samples, batched along dim 0.
    train_target : class-index comparison targets, sliced like ``train_input``.
    train_classes : per-pair digit class indices; column 0 and column 1
        are the targets for the first and second digit.
    mini_batch_size : samples per gradient step; the last batch may be smaller.
    nb_epochs : number of passes over the training set.
    use_optimizer : ``None`` for a manual gradient step, or ``"sgd"`` / ``"adam"``.
    _print : if True, print the epoch index and summed loss per epoch.
    eta : learning rate (defaults to the previously hard-coded 1e-3).

    Raises
    ------
    ValueError
        If ``use_optimizer`` is not ``None``, ``"sgd"`` or ``"adam"``.
    """
    criterion_auxilary = nn.CrossEntropyLoss()
    criterion_final = nn.CrossEntropyLoss()
    optimizers = {"sgd": optim.SGD, "adam": optim.Adam}
    # Fail fast: an unknown name used to leave `optimizer` unbound until the first step.
    if use_optimizer is not None and use_optimizer not in optimizers:
        raise ValueError(f"Unknown optimizer {use_optimizer!r}; expected None, 'sgd' or 'adam'")
    optimizer = None if use_optimizer is None else optimizers[use_optimizer](model.parameters(), lr=eta)

    nb_samples = train_input.size(0)
    for e in range(nb_epochs):
        acc_loss = 0

        for b in range(0, nb_samples, mini_batch_size):
            # Clamp the last batch so sizes not divisible by mini_batch_size work.
            batch_size = min(mini_batch_size, nb_samples - b)
            digit_1, digit_2, comparison = model(train_input.narrow(0, b, batch_size))

            target_comparison = train_target.narrow(0, b, batch_size).long()
            batch_classes = train_classes.narrow(0, b, batch_size)

            loss1 = criterion_auxilary(digit_1, batch_classes[:, 0])
            loss2 = criterion_auxilary(digit_2, batch_classes[:, 1])
            loss3 = criterion_final(comparison, target_comparison)
            # One backward on the sum accumulates the same gradients as three
            # separate backward passes, without retaining the graph.
            loss = loss1 + loss2 + loss3
            acc_loss += loss.item()

            model.zero_grad()
            loss.backward()

            if optimizer is not None:
                optimizer.step()
            else:
                # Plain gradient descent step without an optimizer object.
                with torch.no_grad():
                    for p in model.parameters():
                        if p.grad is not None:
                            p -= eta * p.grad
        if _print:
            print(e, acc_loss)
def compute_nb_errors_auxilary_loss(model, input, target, mini_batch_size):
    """Count samples whose predicted comparison class is wrong, for auxiliary-loss models.

    The model must return a ``(digit_1, digit_2, comparison)`` tuple; only
    ``comparison`` is scored. The model is run in eval mode (so dropout,
    e.g. in ``Auxiliary_Loss_Net_2``, is disabled) and without gradient
    tracking; its previous train/eval mode is restored afterwards.

    Parameters
    ----------
    model : module returning ``(digit_1, digit_2, comparison)``.
    input : samples, batched along dim 0.
    target : one-hot comparison targets; a prediction counts as an error
        when the target entry at the predicted class is <= 0.
    mini_batch_size : evaluation batch size; the last batch may be smaller.

    Returns
    -------
    int
        Number of misclassified samples.
    """
    was_training = model.training
    model.eval()
    nb_errors = 0
    nb_samples = input.size(0)
    try:
        with torch.no_grad():
            for b in range(0, nb_samples, mini_batch_size):
                # Clamp the last batch so sizes not divisible by mini_batch_size work.
                batch_size = min(mini_batch_size, nb_samples - b)
                _, _, output = model(input.narrow(0, b, batch_size))
                _, predicted_classes = output.max(1)
                rows = torch.arange(b, b + batch_size, device=target.device)
                hits = target[rows, predicted_classes.to(target.device)]
                nb_errors += int((hits <= 0).sum().item())
    finally:
        model.train(was_training)

    return nb_errors

In [50]:
# Train the auxiliary-loss network: MSE on one-hot comparison targets plus
# cross-entropy on each digit's class.
model_auxiliary = Auxiliary_Loss_Net()
train_target_one_hot = prologue.convert_to_one_hot_labels(train_input, train_target)
train_model_auxiliary_loss(
    model_auxiliary, train_input, train_target_one_hot, train_classes,
    mini_batch_size=250, nb_epochs=25, use_optimizer="adam",
)

torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])
torch.Size([250, 10]) torch.Size([250])


In [101]:

# Evaluate the auxiliary-loss network's comparison output against one-hot test targets.
test_target_total = prologue.convert_to_one_hot_labels(test_input, test_target)
nb_test_errors = compute_nb_errors_auxilary_loss(
    model_auxiliary, test_input, test_target_total, mini_batch_size=250
)
print(f"test error Net {(100 * nb_test_errors) / test_input.size(0):0.2f}% "
      f"{nb_test_errors:d}/{test_input.size(0):d}")

test error Net 9.70% 97/1000


In [76]:
def benchmark_model(model, train_function, evaluate_function, nb_trials=40, N=1000, mini_batch_size=250, nb_epochs=25, model_requires_target_and_classes=False, one_hot_train_target=True, _print=False):
    # Benchmark of the basic network with Adam optimizer
    performances = []
    for trial in range(nb_trials):

        # Generate Data 
        train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(N)
        if one_hot_train_target:
            train_target_one_hot = prologue.convert_to_one_hot_labels(train_input, train_target)
        else :
            train_target_one_hot = train_target
        test_target_one_hot = prologue.convert_to_one_hot_labels(test_input, test_target)

        # Define the model 
        model_total = model()

        # Train the model
        if model_requires_target_and_classes : 
            train_function(model_total, train_input, train_target_one_hot, train_classes, mini_batch_size=mini_batch_size,
                           nb_epochs=nb_epochs, use_optimizer="adam", _print=_print)
        else :
            train_function(model_total, train_input, train_target_one_hot, mini_batch_size=mini_batch_size,
                           nb_epochs=nb_epochs, use_optimizer="adam", _print=_print)

        # Evaluate performances 
        nb_test_errors = evaluate_function(model_total, test_input, test_target_one_hot, mini_batch_size=mini_batch_size)
        print('test error Net trial {:d} {:0.2f}% {:d}/{:d}'.format(trial, (100 * nb_test_errors) / test_input.size(0),
                                                              nb_test_errors, test_input.size(0)))
        performances.append(nb_test_errors)

    mean_perf = 100 * sum(performances) / (N * nb_trials)
    print(f"Average precision of this architecture {mean_perf}%")
    
    std_dev = math.sqrt(sum(list(map(lambda x : x - mean_perf,performances))))/nb_trials
    print(f"With standard deviation of  {std_dev}")

In [77]:
# The MSE-loss variant is currently disabled; only its header is still printed.
print("Benchmark of the model with no Weight Sharing")
# benchmark_model(No_Weight_Sharing_Net, train_model_simple_net, compute_nb_errors_simple_net)
print("Benchmark of the model with no Weight Sharing CrossEntropyLoss")
benchmark_model(
    No_Weight_Sharing_Net,
    train_function=train_model_simple_net_2,
    evaluate_function=compute_nb_errors_simple_net,
    one_hot_train_target=False,
)

Benchmark of the model with no Weight Sharing
Benchmark of the model with no Weight Sharing CrossEntropyLoss
test error Net trial 0 16.50% 165/1000
test error Net trial 1 16.90% 169/1000
test error Net trial 2 20.10% 201/1000
test error Net trial 3 16.00% 160/1000
test error Net trial 4 24.10% 241/1000
test error Net trial 5 16.30% 163/1000
test error Net trial 6 18.30% 183/1000
test error Net trial 7 14.70% 147/1000
test error Net trial 8 16.00% 160/1000
test error Net trial 9 17.70% 177/1000
test error Net trial 10 16.80% 168/1000
test error Net trial 11 15.20% 152/1000
test error Net trial 12 17.80% 178/1000
test error Net trial 13 17.50% 175/1000
test error Net trial 14 16.70% 167/1000
test error Net trial 15 16.90% 169/1000
test error Net trial 16 19.40% 194/1000
test error Net trial 17 25.00% 250/1000
test error Net trial 18 18.40% 184/1000
test error Net trial 19 15.20% 152/1000
test error Net trial 20 15.50% 155/1000
test error Net trial 21 15.20% 152/1000
test error Net trial 

In [78]:
# The MSE-loss variant is currently disabled; only its header is still printed.
print("Benchmark of the model with Weight Sharing MSE")
# benchmark_model(Simple_Net, train_model_simple_net, compute_nb_errors_simple_net)
print("Benchmark of the model with Weight Sharing CrossEntropyLoss")
benchmark_model(
    Simple_Net,
    train_function=train_model_simple_net_2,
    evaluate_function=compute_nb_errors_simple_net,
    one_hot_train_target=False,
)

Benchmark of the model with Weight Sharing MSE
Benchmark of the model with Weight Sharing CrossEntropyLoss
test error Net trial 0 15.60% 156/1000
test error Net trial 1 15.80% 158/1000
test error Net trial 2 16.60% 166/1000
test error Net trial 3 18.60% 186/1000
test error Net trial 4 18.10% 181/1000
test error Net trial 5 16.90% 169/1000
test error Net trial 6 15.70% 157/1000
test error Net trial 7 16.20% 162/1000
test error Net trial 8 14.10% 141/1000
test error Net trial 9 17.00% 170/1000
test error Net trial 10 13.30% 133/1000
test error Net trial 11 15.50% 155/1000
test error Net trial 12 14.40% 144/1000
test error Net trial 13 16.20% 162/1000
test error Net trial 14 17.40% 174/1000
test error Net trial 15 17.40% 174/1000
test error Net trial 16 15.70% 157/1000
test error Net trial 17 16.20% 162/1000
test error Net trial 18 13.80% 138/1000
test error Net trial 19 14.60% 146/1000
test error Net trial 20 12.40% 124/1000
test error Net trial 21 17.30% 173/1000
test error Net trial 22

In [79]:
# The MSE-loss variant is currently disabled; only its header is still printed.
print("Benchmark of the model with Weight Sharing and an auxiliary loss MSE")
# benchmark_model(Auxiliary_Loss_Net, train_model_auxiliary_loss, compute_nb_errors_auxilary_loss, model_requires_target_and_classes=True)
print("Benchmark of the model with Weight Sharing and an auxiliary loss Cross Entropy Loss")
benchmark_model(
    Auxiliary_Loss_Net,
    train_function=train_model_auxiliary_loss_2,
    evaluate_function=compute_nb_errors_auxilary_loss,
    model_requires_target_and_classes=True,
    one_hot_train_target=False,
)

Benchmark of the model with Weight Sharing and an auxiliary loss MSE
Benchmark of the model with Weight Sharing and an auxiliary loss Cross Entropy Loss
test error Net trial 0 11.60% 116/1000
test error Net trial 1 9.80% 98/1000
test error Net trial 2 11.20% 112/1000
test error Net trial 3 10.30% 103/1000
test error Net trial 4 9.50% 95/1000
test error Net trial 5 10.80% 108/1000
test error Net trial 6 9.60% 96/1000
test error Net trial 7 10.40% 104/1000
test error Net trial 8 11.30% 113/1000
test error Net trial 9 9.20% 92/1000
test error Net trial 10 13.90% 139/1000
test error Net trial 11 13.50% 135/1000
test error Net trial 12 10.10% 101/1000
test error Net trial 13 8.80% 88/1000
test error Net trial 14 11.60% 116/1000
test error Net trial 15 9.60% 96/1000
test error Net trial 16 12.20% 122/1000
test error Net trial 17 9.90% 99/1000
test error Net trial 18 8.90% 89/1000
test error Net trial 19 9.40% 94/1000
test error Net trial 20 10.00% 100/1000
test error Net trial 21 9.90% 99/10

In [116]:
class Auxiliary_Loss_Net_2(nn.Module):
    """Auxiliary-loss comparison network with dropout in the shared digit CNN.

    Same layout as the plain auxiliary-loss network, plus ``nn.Dropout(p=0.1)``
    after each of the first three stages of ``cnn``. ``forward`` returns the
    two per-digit score vectors alongside the comparison scores.
    """

    def __init__(self):
        super().__init__()

        # Shared digit classifier: two conv stages, then two linear layers.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)
        # One dropout after each of the first three CNN stages (only active in training mode).
        self.dropout_1 = nn.Dropout(p=0.1)
        self.dropout_2 = nn.Dropout(p=0.1)
        self.dropout_3 = nn.Dropout(p=0.1)

        # Comparison head: 2 x 10 digit scores -> 2 comparison scores.
        self.fc3 = nn.Linear(20, 300)
        self.fc4 = nn.Linear(300, 300)
        self.fc5 = nn.Linear(300, 2)

    def cnn(self, x):
        """Map a (B, 1, H, W) batch of single digits to (B, 10) class scores."""
        x = self.dropout_1(F.relu(F.max_pool2d(self.conv1(x), kernel_size=2)))
        x = self.dropout_2(F.relu(F.max_pool2d(self.conv2(x), kernel_size=2)))
        # 256 = 64 channels * 2 * 2, the feature size left by a 14x14 input
        # (assumed input size).
        x = self.dropout_3(F.relu(self.fc1(x.view(-1, 256))))
        return self.fc2(x)

    def mlp(self, x):
        """Map (B, 20) concatenated digit scores to (B, 2) comparison scores."""
        for layer in (self.fc3, self.fc4):
            x = F.relu(layer(x))
        return self.fc5(x)

    def forward(self, x):
        """Return ``(first_digit_scores, second_digit_scores, comparison_scores)``.

        ``x`` is indexed as (B, 2, H, W): channel 0 is the first digit and
        channel 1 the second; both go through the same ``cnn``.
        """
        first_scores = self.cnn(x[:, 0:1])
        second_scores = self.cnn(x[:, 1:2])
        comparison = self.mlp(torch.cat((first_scores, second_scores), dim=1))
        return first_scores, second_scores, comparison
    

In [119]:
trials = 40
# Dropout variant, trained for twice as many epochs as the default.
benchmark_model(
    Auxiliary_Loss_Net_2,
    train_function=train_model_auxiliary_loss_2,
    evaluate_function=compute_nb_errors_auxilary_loss,
    model_requires_target_and_classes=True,
    one_hot_train_target=False,
    nb_trials=trials,
    nb_epochs=50,
)

# Baseline without dropout, default number of epochs.
benchmark_model(
    Auxiliary_Loss_Net,
    train_function=train_model_auxiliary_loss_2,
    evaluate_function=compute_nb_errors_auxilary_loss,
    model_requires_target_and_classes=True,
    one_hot_train_target=False,
    nb_trials=trials,
)


test error Net trial 0 8.70% 87/1000
test error Net trial 1 8.70% 87/1000
test error Net trial 2 8.50% 85/1000
test error Net trial 3 10.30% 103/1000
test error Net trial 4 8.00% 80/1000
test error Net trial 5 9.60% 96/1000
test error Net trial 6 9.00% 90/1000
test error Net trial 7 9.20% 92/1000
test error Net trial 8 7.90% 79/1000
test error Net trial 9 7.90% 79/1000
test error Net trial 10 7.80% 78/1000
test error Net trial 11 9.60% 96/1000
test error Net trial 12 9.50% 95/1000
test error Net trial 13 9.30% 93/1000
test error Net trial 14 8.40% 84/1000
test error Net trial 15 9.60% 96/1000
test error Net trial 16 9.40% 94/1000
test error Net trial 17 8.10% 81/1000
test error Net trial 18 9.10% 91/1000
test error Net trial 19 8.40% 84/1000
test error Net trial 20 7.00% 70/1000
test error Net trial 21 9.90% 99/1000
test error Net trial 22 9.90% 99/1000
test error Net trial 23 9.20% 92/1000
test error Net trial 24 8.00% 80/1000
test error Net trial 25 9.10% 91/1000
test error Net trial