In [1]:
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets

import matplotlib.pyplot as plt
import time
import copy
import dlc_practical_prologue as prologue
from BaseNet import *
from ConvNet1 import *

########################################################
### FRAMEWORK FOR INPUT AS TWO SINGLE CHANNEL IMAGES ###

# In this framework, the network is first trained to recognize the digits of each image from each pair and with the help of the class labels. To do so, we use the class labels provided and use a CrossEntropyLoss to maximize the response of the correct digit. Once the network can predict the digits, we compare the digits together.

nb_classes = 10
nb_input_channels = 1

mini_batch_size = 1000
nb_epochs = 10
eta = 0.01

def prep_input_vanilla(train_input):
    new_train_input = train_input.view(-1,1,14,14)
    return new_train_input

# One-hot encoding for train_classes
def prep_target_vanilla(train_classes): 
    y = train_classes.flatten().view(-1, 1)
    train_onehot = torch.FloatTensor(len(y), nb_classes)
    return train_onehot.scatter_(1, y, 1)


# Computes the number of errors when predicting with the model passed as parameter and comparing results with the target classes; this corresponds to the task of our network, ie predicting the right digit.
def compute_nb_errors(model, input_, target, mini_batch_size=mini_batch_size):
    #print("Model in training mode? ", model.training)
    nb_errors = 0  
    for b in range(0, input_.size(0), mini_batch_size):
        output = model(input_.narrow(0, b, mini_batch_size))
        #_, target_classes = torch.max(target.narrow(0, b, mini_batch_size), 1)
        target_classes = target.narrow(0, b, mini_batch_size)
        _, predicted_classes = torch.max(output, 1)
        nb_errors += (predicted_classes != target_classes).sum().item()      
    return nb_errors

# input = 1000x2x14x14: Gets a 2 channels input tensor, predicts a digit for each image of each channel using the model passed as parameter, then compares the pairs channel-wise
# output = 1000: 1 if the val_a < val_b, 0 otherwise
def compare_pairs(model, input_):
    tensor_a = torch.max(model(input_[:,0,:,:].view(-1,1,14,14)), 1)[1]
    tensor_b = torch.max(model(input_[:,1,:,:].view(-1,1,14,14)),1)[1]
    return torch.le(tensor_a, tensor_b)


# Network is a classifier: it treats the pairs as 2 images of 1 channel and is trained to predict the digit from the image
def train_model_1C(model, train_input, train_classes, optimizer, mini_batch_size=mini_batch_size,
                   criterion=torch.nn.CrossEntropyLoss(), nb_epochs=nb_epochs):
    train_input = prep_input_vanilla(train_input)
    train_target = train_classes.flatten() # the target are the class labels 
    nb_samples = len(train_input)
    
    since = time.time()
    val_acc_history = []
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for e in range(0, nb_epochs):
        
        for phase in ['train', 'val']:
            if phase == 'train': model.train()
            else: model.eval()
                
            running_loss = 0.0
            running_corrects = 0
            
            for b in range(0, train_input.size(0), mini_batch_size):
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    output = model(train_input.narrow(0, b, mini_batch_size))
                    print("output = ", output[0])
                    target = train_target.narrow(0, b, mini_batch_size)
                    print("target = ", target[:10])
                    
                    # the nn.CrossEntropyLoss expects a class index as the target for each value
                    loss = criterion(output, target)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                running_loss += loss.item() * train_input.size(0)
                print("torch.max = ", torch.max(output, 1)[1][:10])
                running_corrects += torch.sum(torch.max(output, 1)[1] == target)       

            epoch_loss = running_loss / nb_samples
            epoch_acc = running_corrects.double() / nb_samples
            
            if (e % 100 == 99):
                print('phase: %s, epoch: %d, loss: %.5f, acc: %.4f' %
                      (phase, e+1, epoch_loss, epoch_acc))
                
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                
    time_elapsed = time.time() - since
    print('Training complete in %.0f min %.0f s' % (time_elapsed // 60, time_elapsed % 60))
    print('Best val acc: %.4f' % (best_acc))
    
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, best_acc, time_elapsed
            
def test_model_1C(model, test_input, test_target, test_classes):
    model.eval()
    print("Model in training mode? ", model.training)
    test_input_vanilla = prep_input_vanilla(test_input)
    test_classes_target = test_classes.flatten()
    #test_classes_target = prep_target_vanilla(test_classes)
    
    # Number of digits incorrectly identified
    nb_errors_digits = compute_nb_errors(model, test_input_vanilla, test_classes_target)
    print("nb_errors_digits = ", nb_errors_digits)
    
    # Test accuracy on task = predicting digits
    acc_digits = nb_errors_digits / len(test_input_vanilla)
    
    # Number of wrong predictions (first digit less than or equal to the second)
    test_output_pairs = compare_pairs(model, test_input).type(torch.LongTensor)
    nb_errors_pairs = torch.abs(test_output_pairs-test_target).sum().item()
    print("nb_errors_pairs = ", nb_errors_pairs)
    
    # Test accuracy on task = comparison of pairs
    acc_pairs = nb_errors_pairs / len(test_input)
    
    return nb_errors_digits, acc_digits, nb_errors_pairs, acc_pairs
    

In [2]:
train_input, train_target, train_classes, test_input, test_target, test_classes = \
    prologue.generate_pair_sets(nb=1000)

#model = BaseNet1C(nb_classes)
model = ConvNet1_1C(nb_classes)

optimizer = torch.optim.SGD(model.parameters(), lr=eta, momentum=0.95)
#model, val_acc_history, best_acc, time_elapsed = train_model_1C(model, train_input, train_classes, optimizer)

def compute_properties(lst):
    mean = sum(lst) / len(lst)
    variance = sum([(e-mean)**2 for e in lst]) / (len(lst)-1)
    return mean, variance ** (1/2)

def multiple_training_runs(model, nb_runs):
    list_time = []
    list_best_val_acc = []
    list_nb_errors_digits = []
    list_acc_digits = []
    list_nb_errors_pairs = []
    list_acc_pairs = []
    for i in range(nb_runs):
        model, val_acc_history, best_acc, time_elapsed = train_model_1C(model, train_input, train_classes, optimizer)
        list_time.append(time_elapsed)
        list_best_val_acc.append(best_acc)
        
        nb_errors_digits, acc_digits, nb_errors_pairs, acc_pairs = test_model_1C(model, test_input, test_target, test_classes)
        list_nb_errors_digits.append(nb_errors_digits)
        list_acc_digits.append(acc_digits)
        list_nb_errors_pairs.append(nb_errors_pairs)
        list_acc_pairs.append(acc_pairs)
        
    mean_time, std_time = compute_properties(list_time)
    mean_best_val_acc, std_best_val_acc = compute_properties(list_best_val_acc)
    mean_nb_errors_digits, std_nb_errors_digits = compute_properties(list_nb_errors_digits)
    mean_acc_digits, std_acc_digits = compute_properties(list_acc_digits)
    mean_nb_errors_pairs, std_nb_errors_pairs = compute_properties(list_nb_errors_pairs)
    mean_acc_pairs, std_acc_pairs = compute_properties(list_acc_pairs)
    
    print("mean_time = {}, std_time = {}".format(mean_time, std_time))
    print("mean_acc_digits = {}, std_acc_digits = {}".format(mean_acc_digits, std_acc_digits))
    print("mean_acc_pairs = {}, std_acc_pairs = {}".format(mean_acc_pairs, std_acc_pairs))
    print("mean_nb_errors_pairs = {}, std_nb_errors_pairs = {}".format(mean_nb_errors_pairs, std_nb_errors_pairs))
    
    return mean_time, std_time, mean_best_val_acc, std_best_val_acc, mean_nb_errors_digits, std_nb_errors_digits, mean_acc_digits, std_acc_digits, mean_nb_errors_pairs, std_nb_errors_pairs, mean_acc_pairs, std_acc_pairs
    
multiple_training_runs(model, 10)       

output =  tensor([ 1.0008,  6.3652,  0.2859, -5.0930,  8.3421,  8.6901,  1.0041, -7.3581,
        -3.2682, -2.9248], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([5, 4, 4, 4, 4, 4, 1, 4, 4, 1])
output =  tensor([  278.2824,  -194.8909,   323.4098,   364.3342, -1500.2299,    30.9957,
          263.9705,   245.6450,   241.4313,   332.2892],
       grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
output =  tensor([ 100.1775,   79.2622,  112.0291, -758.2475,    1.5727,   82.5626,
         107.0772,   94.7957,  112.0170,   99.1546])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 8])
output =  tensor([ 122.9980,   96.7903,  136.6442, -922.6578,   -1.9350,  100.6969,
         130.9464,  116.1273,  136.3516,  121.0441])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([2, 2, 2, 2, 2, 2, 2, 8, 8, 2])

output =  tensor([-0.1194,  0.0642,  0.1160,  0.0871,  0.0071,  0.1511,  0.0295, -0.5574,
         0.0910, -0.0581])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([5, 5, 5, 1, 5, 5, 5, 5, 1, 5])
output =  tensor([-0.1208,  0.1938,  0.1004,  0.0413,  0.0392,  0.1328,  0.0036, -0.5516,
         0.0760, -0.0887], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 5, 1, 5, 1, 5, 5, 5, 1])
output =  tensor([-0.1507,  0.1823,  0.0506, -0.0195,  0.0023,  0.0994, -0.0297, -0.2250,
         0.0204, -0.1331], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1491,  0.1864,  0.0504, -0.0307, -0.0264,  0.0957, -0.0333, -0.1764,
         0.0278, -0.1302])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1581,  0.1892,  0.0505, -0.0262, -0.0380,  0.0949, -0.0

output =  tensor([-0.1407,  0.2479, -0.0306, -0.1220, -0.0251,  0.0025, -0.0636,  0.0240,
        -0.0700, -0.0127])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1376,  0.2278, -0.0297, -0.0549,  0.0050, -0.0045, -0.0777,  0.0032,
        -0.0638, -0.0300], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1656,  0.2563, -0.0302, -0.0892, -0.0064, -0.0038, -0.0764,  0.0125,
        -0.0741, -0.0208], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1633,  0.2518, -0.0336, -0.0671,  0.0319, -0.0058, -0.0850,  0.0012,
        -0.0723, -0.0283])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
output =  tensor([-0.1826,  0.2675, -0.0270, -0.0734, -0.0094, -0.0074, -0.0

torch.max =  tensor([1, 1, 1, 1, 1, 3, 1, 1, 1, 3])
output =  tensor([-0.1450,  0.2168, -0.0323, -0.0172,  0.0309, -0.0166, -0.0857, -0.0200,
        -0.0819, -0.0118])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 3, 1, 1, 1, 1, 1, 1, 3])
output =  tensor([-1.2887e-01,  2.1845e-01, -2.2116e-02,  1.8657e-03, -6.4075e-02,
        -2.5003e-02, -7.1841e-02, -1.7478e-04, -6.9148e-02, -1.7123e-02])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 3, 1, 1, 1, 3])
output =  tensor([-0.1450,  0.2168, -0.0323, -0.0172,  0.0309, -0.0166, -0.0857, -0.0200,
        -0.0819, -0.0118], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 3, 1, 1, 1, 1, 1, 1, 3])
output =  tensor([-0.1395,  0.2300, -0.0239, -0.0018, -0.0614, -0.0286, -0.0746,  0.0012,
        -0.0736, -0.0109], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 

torch.max =  tensor([1, 1, 3, 1, 1, 1, 1, 1, 3, 3])
output =  tensor([-0.2067,  0.3213, -0.0370, -0.0380, -0.0360, -0.0525, -0.0995, -0.0070,
        -0.1026,  0.0425], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 3, 1, 1, 1, 3])
output =  tensor([-0.2356,  0.3354, -0.0428, -0.0531,  0.0293, -0.0477, -0.1103, -0.0246,
        -0.1133,  0.0601])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 3, 1, 1, 1, 1, 1, 3, 3])
output =  tensor([-0.2256,  0.3413, -0.0422, -0.0492, -0.0127, -0.0557, -0.1071, -0.0102,
        -0.1115,  0.0497])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 1, 1, 3, 1, 1, 1, 3])
output =  tensor([-0.2356,  0.3354, -0.0428, -0.0531,  0.0293, -0.0477, -0.1103, -0.0246,
        -0.1133,  0.0601], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 3, 1, 1, 1, 1, 1, 3, 3])
output =  tensor([-0.243

torch.max =  tensor([1, 1, 2, 2, 1, 2, 1, 1, 7, 2])
output =  tensor([-0.5844,  0.5990, -0.0621, -0.2160,  0.0384, -0.1530, -0.2010,  0.0366,
        -0.2251,  0.3248], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 2, 1, 2, 1, 1, 1, 2])
output =  tensor([-0.6329,  0.6559, -0.0630, -0.2125,  0.0983, -0.1253, -0.1831,  0.0274,
        -0.2255,  0.2427])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 2, 2, 1, 2, 7, 1, 7, 2])
output =  tensor([-0.6197,  0.6625, -0.0749, -0.2325,  0.0503, -0.1656, -0.2130,  0.0329,
        -0.2421,  0.3450])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([1, 1, 1, 2, 1, 3, 1, 1, 1, 2])
output =  tensor([-0.6329,  0.6559, -0.0630, -0.2125,  0.0983, -0.1253, -0.1831,  0.0274,
        -0.2255,  0.2427], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([1, 1, 2, 2, 1, 2, 7, 1, 7, 2])
output =  tensor([-0.664

torch.max =  tensor([9, 9, 3, 7, 9, 7, 9, 1, 4, 3])
output =  tensor([-1.1272,  0.8268, -0.2925, -0.4294,  0.4493, -0.4073, -0.4179,  0.2003,
        -0.5339,  0.9447], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 3, 1, 1, 9, 3])
output =  tensor([-1.1340,  0.0518, -0.1861, -0.4132,  0.6957, -0.3483, -0.3803,  0.3493,
        -0.4864,  1.0666])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([9, 9, 3, 7, 9, 7, 9, 1, 4, 3])
output =  tensor([-1.1328,  0.7875, -0.3028, -0.4282,  0.4619, -0.4183, -0.4229,  0.2146,
        -0.5423,  0.9752])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 3, 1, 1, 9, 3])
output =  tensor([-1.1340,  0.0518, -0.1861, -0.4132,  0.6957, -0.3483, -0.3803,  0.3493,
        -0.4864,  1.0666], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([9, 9, 3, 7, 9, 7, 9, 1, 4, 3])
output =  tensor([-1.133

torch.max =  tensor([4, 9, 3, 7, 9, 7, 9, 1, 4, 3])
output =  tensor([-0.8490, -0.7831, -0.2681, -0.4514,  0.5847, -0.3464, -0.3305,  0.5634,
        -0.4254,  1.0983], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 2, 1, 1, 9, 3])
output =  tensor([-0.7461, -3.3654,  0.0273, -0.5008,  1.4599, -0.1587, -0.1904,  1.1589,
        -0.2304,  1.3976])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([4, 9, 3, 7, 9, 7, 9, 1, 9, 3])
output =  tensor([-0.8312, -0.8712, -0.2565, -0.4645,  0.5711, -0.3327, -0.3156,  0.5892,
        -0.4098,  1.0841])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 2, 1, 1, 4, 3])
output =  tensor([-0.7461, -3.3654,  0.0273, -0.5008,  1.4599, -0.1587, -0.1904,  1.1589,
        -0.2304,  1.3976], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([4, 9, 3, 7, 9, 7, 9, 1, 9, 3])
output =  tensor([-0.815

torch.max =  tensor([4, 9, 3, 7, 7, 7, 9, 1, 4, 0])
output =  tensor([-0.7589, -3.0150, -0.0479, -0.7147,  0.5909, -0.1030,  0.0179,  1.0755,
        -0.0977,  1.2610], grad_fn=<SelectBackward>)
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 4, 9, 1, 4, 3])
output =  tensor([-0.7502, -6.7750,  0.0998, -0.9341,  2.4171,  0.0417,  0.2270,  2.0770,
         0.0951,  1.9299])
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([4, 9, 3, 7, 7, 7, 9, 1, 4, 0])
output =  tensor([-0.7705, -3.1486, -0.0473, -0.7221,  0.5990, -0.1027,  0.0371,  1.0998,
        -0.0805,  1.3103])
target =  tensor([7, 5, 9, 0, 9, 2, 8, 1, 4, 2])
torch.max =  tensor([9, 1, 9, 3, 9, 4, 9, 1, 4, 3])
output =  tensor([-0.7502, -6.7750,  0.0998, -0.9341,  2.4171,  0.0417,  0.2270,  2.0770,
         0.0951,  1.9299], grad_fn=<SelectBackward>)
target =  tensor([9, 3, 5, 4, 7, 4, 9, 6, 8, 8])
torch.max =  tensor([4, 9, 3, 7, 7, 7, 9, 1, 4, 0])
output =  tensor([-0.785

KeyboardInterrupt: 

In [None]:
def plot_val_accuracy(model, val_acc_history):
    plt.plot(val_acc_history)
    plt.xlabel('Number of epochs')
    plt.ylabel('Validation accuracy')
    plt.title('Validation accuracy on training set with model = ' + model.name)
    plt.show()
    
plot_val_accuracy(model, val_acc_history)

In [None]:
nb_errors_digits, nb_errors_pairs = test_model_1C(model, test_input, test_target, test_classes)

In [None]:
# Computes the number of errors between target and classes
def compute_nb_errors_data(set_, target, classes): 
    is_pair = torch.le(classes[:,0], classes[:,1]).type(torch.LongTensor)
    nb_errors = torch.abs(target-is_pair).sum().item()
    print('There are %d errors in the %s dataset' % (nb_errors, set_))
    
compute_nb_errors_data('training', train_target, train_classes)
compute_nb_errors_data('test', test_target, test_classes)

In [None]:
import csv

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def check_overwrite(filename, model, row_to_write):
    overwrite = False
    with open(filename, 'r') as readFile:
        reader = csv.reader(readFile)
        row_list = list(reader)
        for index, row in enumerate(row_list):
            if row[0] == model.name: 
                row_list[index] = row
                overwrite = True
                break
    with open(filename, 'w') as writeFile:
        writer = csv.writer(writeFile)
        writer.writerows(row_list)
    readFile.close()
    writeFile.close()
    return overwrite
    
def write_to_csv(filename, model, time_elapsed):
    nb_params = count_parameters(model)
    nb_errors_digits, nb_errors_pairs = test_model_1C(model, test_input, test_target, test_classes)
    row = [model.name, nb_params, round(time_elapsed, 2), nb_errors_digits, nb_errors_pairs]
    
    try: file = open(filename, 'r')
    except FileNotFoundError:
        csvData = [['Model', 'Number of parameters', 'Training time', 'Number of missidentified digits', 'Number of errors in prediction']]
        with open('1channel2images.csv', 'w') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerows(csvData)
        csvFile.close()
        
    overwrite = check_overwrite(filename, model, row)
    if overwrite == False:    
        with open(filename, 'a') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerow(row)
        csvFile.close()

write_to_csv('1channel2images.csv', model, time_elapsed)