In [1]:
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 28 09:11:04 2021

@author: aurel
"""
import torch
import dlc_practical_prologue as prologue

from torch import optim
from torch.nn import functional as F
from torch import nn
from torch.autograd import Variable


######################################################################
def compute_nb_errors(model, data_input, data_target, mini_batch_size):
    """Count samples whose predicted comparison class differs from the target.

    ``model`` must return a triple ``(digit1, digit2, result)``; only
    ``result`` is used here, and the predicted class is its argmax over dim 1.

    Args:
        model: callable returning ``(digit1_scores, digit2_scores, result_scores)``.
        data_input: input tensor, batched along dim 0.
        data_target: class indices aligned with ``data_input`` along dim 0.
        mini_batch_size: number of samples fed to ``model`` per forward pass.
            The final batch may be shorter when the sample count is not a
            multiple of ``mini_batch_size``.

    Returns:
        int: number of misclassified samples.
    """
    nb_data_errors = 0

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for b in range(0, data_input.size(0), mini_batch_size):
            end = b + mini_batch_size
            # Slicing (unlike ``narrow``) tolerates a short final batch.
            _, _, result = model(data_input[b:end])
            _, predicted_classes = torch.max(result, 1)
            nb_data_errors += int((predicted_classes != data_target[b:end]).sum())

    return nb_data_errors

In [2]:
def compute_nb_errors_comparison(model, data_input, data_target, data_classes, mini_batch_size):
    """Count comparison, per-digit and consistency errors of a two-headed model.

    ``model`` returns a triple ``(digit1, digit2, result)`` of score tensors;
    each prediction is the argmax over dim 1.

    Args:
        model: callable returning ``(digit1_scores, digit2_scores, result_scores)``.
        data_input: input tensor, batched along dim 0.
        data_target: comparison class index per sample.
        data_classes: per-sample digit classes; column 0 is the first digit,
            column 1 the second.
        mini_batch_size: samples per forward pass; the final batch may be shorter.

    Returns:
        tuple of four ints
        ``(nb_data_errors, nb_errors_comparison, nb_errors_digit1, nb_errors_digit2)``:

        * ``nb_data_errors``: ``result`` prediction != ``data_target``;
        * ``nb_errors_comparison``: ``result`` prediction disagrees with the
          class implied by the two digit predictions. The implied class is 1
          when ``digit1 <= digit2`` and 0 otherwise, the same rule
          ``accuracy_based_on_imgs`` uses. This includes the equal-digits
          case, which the previous version never counted.
        * ``nb_errors_digit1`` / ``nb_errors_digit2``: digit prediction differs
          from the matching column of ``data_classes``.
    """
    nb_data_errors = 0
    nb_errors_comparison = 0
    nb_errors_digit1 = 0
    nb_errors_digit2 = 0

    with torch.no_grad():
        for b in range(0, data_input.size(0), mini_batch_size):
            end = b + mini_batch_size
            digit1, digit2, result = model(data_input[b:end])
            _, predicted_classes = torch.max(result, 1)
            _, predicted_classes_digit1 = torch.max(digit1, 1)
            _, predicted_classes_digit2 = torch.max(digit2, 1)

            batch_target = data_target[b:end]
            batch_classes = data_classes[b:end]

            nb_data_errors += int((predicted_classes != batch_target).sum())
            nb_errors_digit1 += int((predicted_classes_digit1 != batch_classes[:, 0]).sum())
            nb_errors_digit2 += int((predicted_classes_digit2 != batch_classes[:, 1]).sum())

            # Assumes ``result`` has two classes (0/1), as ConvNoWS's final layer does.
            implied_classes = (predicted_classes_digit1 <= predicted_classes_digit2).long()
            nb_errors_comparison += int((predicted_classes != implied_classes).sum())

    return nb_data_errors, nb_errors_comparison, nb_errors_digit1, nb_errors_digit2

In [3]:
def accuracy_based_on_imgs(model, data_input, data_target):
    """Error rate of the comparison derived from the two digit predictions.

    The comparison ``result`` head is ignored. Instead, each sample's class is
    taken to be 1 when the predicted first digit is <= the predicted second
    digit, and 0 otherwise.

    NOTE: despite its name, this returns an *error rate*
    (``1 - fraction correct``), not an accuracy.
    """
    digit1_scores, digit2_scores, _result_scores = model(data_input)

    first_digit_classes = torch.max(digit1_scores.detach(), 1)[1]
    second_digit_classes = torch.max(digit2_scores.detach(), 1)[1]

    implied_target = (first_digit_classes <= second_digit_classes).long()
    nb_correct = int((implied_target == data_target).sum())

    return 1 - nb_correct / data_input.size(0)

In [4]:
def accuracy_based_on_result(model, data_input, data_target):
    """Error rate of the model's comparison (``result``) head.

    NOTE: despite its name, this returns an *error rate*
    (``1 - fraction correct``), mirroring ``accuracy_based_on_imgs``.

    Args:
        model: callable returning ``(digit1_scores, digit2_scores, result_scores)``.
        data_input: input tensor, batched along dim 0.
        data_target: comparison class index per sample.

    Returns:
        float: fraction of samples whose ``result`` argmax differs from ``data_target``.
    """
    with torch.no_grad():
        _, _, result = model(data_input)
        _, predictions = torch.max(result, 1)

    # Previously compared against an undefined global ``test_target_`` and
    # divided by an undefined ``total`` (NameError); use the arguments instead.
    well_predicted_count = (predictions == data_target).sum().item()

    return 1 - well_predicted_count / data_input.size(0)

In [5]:
######################################################################
def train_model_decay(model, train_input, train_target, train_classes, nb_epochs, mini_batch_size,
                      eta0=1e-1, decay=1):
    """Train ``model`` with plain SGD and a 1/t learning-rate decay.

    The learning rate used during epoch ``e`` (0-based) is
    ``eta0 / (1 + decay * e)``. The loss is the sum of three cross-entropies:
    ``result`` vs ``train_target``, ``digit1`` vs ``train_classes[:, 0]`` and
    ``digit2`` vs ``train_classes[:, 1]``.

    Args:
        model: module returning ``(digit1_scores, digit2_scores, result_scores)``.
        train_input: training samples, batched along dim 0.
        train_target: comparison class index per sample.
        train_classes: per-sample digit classes (column 0 = first digit,
            column 1 = second digit).
        nb_epochs: number of passes over the data.
        mini_batch_size: samples per SGD step; the final batch may be shorter.
        eta0: initial learning rate (default 0.1, the former hard-coded value).
        decay: decay factor (default 1, the former hard-coded value).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=eta0)
    nb_samples = train_input.size(0)

    for e in range(nb_epochs):
        # The schedule must depend on the current epoch ``e``. Using
        # ``nb_epochs`` here gave a constant rate of eta0 / (1 + decay * nb_epochs).
        eta = (1 / (1 + decay * e)) * eta0
        # Adjust the existing optimizer rather than rebuilding it each epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = eta

        for b in range(0, nb_samples, mini_batch_size):
            end = b + mini_batch_size
            # Slicing (unlike ``narrow``) tolerates a short final batch.
            digit1, digit2, result = model(train_input[b:end])

            loss_result = criterion(result, train_target[b:end])
            loss_digit1 = criterion(digit1, train_classes[b:end, 0])
            loss_digit2 = criterion(digit2, train_classes[b:end, 1])
            loss = loss_result + loss_digit1 + loss_digit2

            model.zero_grad()
            loss.backward()
            optimizer.step()

In [6]:
######################################################################
def train_model(model, train_input, train_target, train_classes, nb_epochs, mini_batch_size,
                eta=1e-3):
    """Train ``model`` with Adam on the sum of comparison and digit losses.

    The loss is the sum of three cross-entropies: ``result`` vs
    ``train_target``, ``digit1`` vs ``train_classes[:, 0]`` and ``digit2`` vs
    ``train_classes[:, 1]``.

    Args:
        model: module returning ``(digit1_scores, digit2_scores, result_scores)``.
        train_input: training samples, batched along dim 0.
        train_target: comparison class index per sample.
        train_classes: per-sample digit classes (column 0 = first digit,
            column 1 = second digit).
        nb_epochs: number of passes over the data.
        mini_batch_size: samples per optimizer step; the final batch may be
            shorter when the sample count is not a multiple of it.
        eta: Adam learning rate (default 1e-3, the former hard-coded value).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=eta)
    nb_samples = train_input.size(0)

    for _ in range(nb_epochs):
        for b in range(0, nb_samples, mini_batch_size):
            end = b + mini_batch_size
            # Slicing (unlike ``narrow``) tolerates a short final batch.
            digit1, digit2, result = model(train_input[b:end])

            loss_result = criterion(result, train_target[b:end])
            loss_digit1 = criterion(digit1, train_classes[b:end, 0])
            loss_digit2 = criterion(digit2, train_classes[b:end, 1])
            loss = loss_result + loss_digit1 + loss_digit2

            model.zero_grad()
            loss.backward()
            optimizer.step()

In [7]:
######################################################################            
# def eval_Model(model, mini_batch_size, nb_epochs):


In [8]:
######################################################################   
class ConvNoWS(nn.Module):
    """Two-digit network: classifies each digit, then compares them.

    Input is ``x`` of shape ``(N, 2, H, W)``. Channel 0 and channel 1 are each
    processed as a single-channel image. The layer sizes (``fc1`` expects
    ``4 * 4 * 64`` features) correspond to 14x14 images. ``forward`` returns
    ``(first_digit, second_digit, result)``:

    * ``first_digit`` and ``second_digit`` are ``(N, 10)`` digit-class scores;
    * ``result`` is ``(N, 2)`` comparison scores, computed from the
      concatenated digit scores.

    NOTE(review): despite the "NoWS" suffix (presumably "no weight sharing"),
    both digits pass through the *same* ``layer1``/``layer2``/``fc1``/``fc2``
    modules, so their weights are shared. Confirm the intended design.
    """

    def __init__(self):
        super().__init__()

        # 1 -> 32 channels. Conv k=3, p=1 keeps 14x14; max-pool 2 -> 7x7.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))

        # 32 -> 64 channels. Conv k=2, p=1: 7x7 -> 8x8; max-pool 2 -> 4x4.
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))

        # Conv output size: (in_size - kernel_size + 2 * padding) / stride + 1.
        # 4 * 4 * 64 input features -> 1000 hidden features.
        self.fc1 = nn.Linear(4 * 4 * 64, 1000)

        # 1000 hidden features -> 10 digit-class scores.
        self.fc2 = nn.Linear(1000, 10)

        # Comparison head: 2 * 10 concatenated digit scores -> 2 classes.
        self.layer_comp = nn.Sequential(
            nn.Linear(20, 200),
            nn.ReLU(),
            nn.Linear(200, 1000),
            nn.ReLU(),
            nn.Linear(1000, 2))

    def _digit_scores(self, image):
        """Run one ``(N, 1, H, W)`` image batch through the shared digit branch."""
        features = self.layer2(self.layer1(image))
        # ``flatten(1)`` keeps the batch dimension. With a mis-sized image,
        # ``fc1`` now raises instead of silently regrouping rows as
        # ``view(-1, 4 * 4 * 64)`` could.
        hidden = F.relu(self.fc1(features.flatten(1)))
        return self.fc2(hidden)

    def forward(self, x):
        # ``[0]`` / ``[1]`` (list indices) keep the channel dim -> (N, 1, H, W).
        first_digit = self._digit_scores(x[:, [0]])
        second_digit = self._digit_scores(x[:, [1]])

        result = self.layer_comp(torch.cat((first_digit, second_digit), dim=1))

        return first_digit, second_digit, result

In [9]:
    
######################################################################   
    
# Seed torch's global RNG so weight initialisation and any torch-based sampling
# done inside ``prologue`` are reproducible under Restart & Run All.
torch.manual_seed(0)

# Only the training split is kept here; test splits are drawn by ``get_tests``.
# (Removed a dead, commented-out block that wrapped tensors in the deprecated
# ``torch.autograd.Variable``. Plain tensors have carried autograd since 0.4.)
train_input, train_target, train_classes, _, _, _ \
    = prologue.generate_pair_sets(1000)

def get_tests(n):
    """Draw ``n`` independent test splits from ``prologue.generate_pair_sets(1000)``.

    Each call to the generator also yields a training split, which is
    discarded. Returns a list with ``n`` entries, each of the form
    ``[test_input, test_target, test_classes]``.
    """
    def _draw_test_split():
        (_train_input, _train_target, _train_classes,
         test_input, test_target, test_classes) = prologue.generate_pair_sets(1000)
        return [test_input, test_target, test_classes]

    return [_draw_test_split() for _ in range(n)]


model = ConvNoWS()
nb_epochs = 25
mini_batch_size = 100
nb_test_sets = 10  # number of independently drawn test sets to average over

train_model(model, train_input, train_target, train_classes, nb_epochs, mini_batch_size)
L = get_tests(nb_test_sets)  # list of [test_input, test_target, test_classes]


nb_train_errors = compute_nb_errors(model, train_input, train_target, mini_batch_size)

print('train error ConvNoWS {:0.2f}%{:d}/{:d}'.format((100 * nb_train_errors) / train_input.size(0),
                                        nb_train_errors, train_input.size(0)))

# Total misclassifications summed over all test sets; averaged below.
nb_moy_test_error = 0

for set_input, set_target, _set_classes in L:
    nb_test_errors = compute_nb_errors(model, set_input, set_target, mini_batch_size)
    nb_moy_test_error += nb_test_errors

    print('test error ConvNoWS {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors) / set_input.size(0),
                                                nb_test_errors, set_input.size(0)))

# Average over the number of sets actually drawn (previously a hard-coded 10),
# so changing ``nb_test_sets`` keeps the reported mean correct.
# Assumes every test set has the size of the first one.
mean_test_errors = nb_moy_test_error / len(L)
test_set_size = L[0][0].size(0)
print('Average test error ConvNoWS {:0.2f}% {:0.1f}/{:d}'.format((100 * mean_test_errors) / test_set_size,
                                                              mean_test_errors, test_set_size))




tensor([[  9.7611,  -8.2174],
        [  3.9286,  -2.7851],
        [  7.5686,  -5.6129],
        [  7.3327,  -6.8025],
        [ -2.7232,   3.3621],
        [  4.7019,  -3.4338],
        [ -3.5787,   3.8880],
        [  8.4273,  -7.5260],
        [ -3.7779,   4.1848],
        [ -2.1525,   3.7909],
        [  7.3586,  -7.0583],
        [  3.2879,  -2.9451],
        [  3.0323,  -3.3121],
        [  4.1939,  -4.8284],
        [  6.0814,  -6.6718],
        [  3.6979,  -3.5124],
        [  6.4941,  -5.2498],
        [ -9.5554,  10.3321],
        [  5.2434,  -5.3159],
        [ -2.1026,   3.9984],
        [ -3.8130,   4.5506],
        [ -6.8594,   6.8949],
        [ -2.6751,   4.0219],
        [-14.3277,  15.4553],
        [  8.2685,  -7.2293],
        [ -6.4187,   8.1849],
        [ -2.7003,   3.3563],
        [  3.2651,  -3.4209],
        [ -5.4164,   6.4462],
        [  7.7199,  -6.4057],
        [ -4.5889,   6.0134],
        [ -3.1581,   4.9282],
        [ -5.8527,   6.4452],
        [ 

tensor([[  5.9018,  -4.4314],
        [ 14.4067, -13.6201],
        [ 12.8635, -12.7037],
        [  9.1732, -10.3537],
        [  8.3261,  -7.4192],
        [ -5.6639,   6.0948],
        [-14.9826,  16.3520],
        [  4.8014,  -4.7746],
        [  3.1224,  -3.4564],
        [ -7.3816,  10.2282],
        [  5.9423,  -5.0555],
        [ -3.7284,   4.6889],
        [ 10.3799,  -9.1817],
        [  5.6668,  -6.1782],
        [ -2.8208,   3.8496],
        [  5.5004,  -4.1412],
        [  4.0475,  -4.4419],
        [  8.6743,  -7.8059],
        [-14.3959,  15.6720],
        [-11.3376,  13.5774],
        [  5.6886,  -6.1198],
        [ -4.9156,   5.0478],
        [  4.6550,  -4.4016],
        [  5.6838,  -5.5773],
        [-13.0142,  14.4578],
        [-10.1002,  12.2919],
        [ -6.6431,   7.5355],
        [-11.9558,  14.2380],
        [  4.0469,  -3.5030],
        [ -4.7015,   5.9461],
        [  6.7434,  -6.6932],
        [  3.9122,  -4.5729],
        [ -9.8103,  11.5424],
        [ 

tensor([[  6.3910,  -7.2937],
        [ -8.9722,  11.8530],
        [-10.3979,  11.2196],
        [ -7.0620,   8.3632],
        [ -3.5140,   4.3030],
        [ -4.2998,   4.1849],
        [  4.9183,  -5.0284],
        [ 11.7087, -10.3179],
        [ -3.3960,   4.3839],
        [  4.1878,  -4.0498],
        [  4.5592,  -2.8840],
        [  3.7188,  -3.9263],
        [-11.1210,  11.0294],
        [ -5.1281,   5.5993],
        [  4.1894,  -3.5948],
        [  3.2324,  -3.2198],
        [-13.1560,  15.9570],
        [ 12.5475, -11.9892],
        [  4.0018,  -2.9361],
        [  4.3804,  -2.1595],
        [ -4.5030,   5.0570],
        [  6.2223,  -6.7043],
        [-10.1675,  11.3942],
        [-10.5701,  11.6988],
        [ -6.0431,   6.5171],
        [ -6.8418,   7.7336],
        [ 11.9190, -11.4081],
        [ -9.5072,  10.3551],
        [  4.8655,  -6.0053],
        [ -6.2617,   6.5845],
        [ -3.9626,   4.6621],
        [  7.1044,  -7.5897],
        [-14.5401,  16.6656],
        [ 

tensor([[ 8.8278e+00, -7.9998e+00],
        [-3.4112e+00,  4.3861e+00],
        [ 3.3671e+00, -6.9022e-01],
        [-6.8740e-01,  2.1743e+00],
        [ 2.4264e+00, -2.0889e+00],
        [-3.5585e+00,  4.2709e+00],
        [ 4.4702e+00, -4.3386e+00],
        [ 1.0795e+01, -1.0869e+01],
        [ 1.3888e+00, -1.4509e+00],
        [-9.9166e+00,  1.2214e+01],
        [-8.1986e+00,  8.7462e+00],
        [ 3.6665e+00, -4.3468e+00],
        [ 2.9487e+00, -2.6984e+00],
        [ 5.9210e+00, -6.1330e+00],
        [ 4.1850e+00, -4.1365e+00],
        [-1.2718e+01,  1.3147e+01],
        [ 4.9252e+00, -3.7557e+00],
        [ 2.4022e-01,  6.8043e-01],
        [-1.0895e+01,  1.2429e+01],
        [ 4.1463e+00, -4.4949e+00],
        [ 7.4949e+00, -7.4051e+00],
        [-2.2204e-01,  9.4126e-01],
        [-9.4640e-02,  8.5112e-01],
        [ 3.7999e+00, -3.4153e+00],
        [-4.7605e+00,  4.5534e+00],
        [-8.9017e+00,  1.0062e+01],
        [-1.0115e+01,  1.1695e+01],
        [-6.4105e+00,  8.103

tensor([[ -6.2897,   7.1701],
        [ 10.3059, -10.0714],
        [ -4.9855,   6.3246],
        [  5.6804,  -5.6446],
        [ -0.1078,   1.4155],
        [-10.6866,  11.8491],
        [  4.7561,  -4.7452],
        [ -1.2371,   1.4755],
        [  5.1616,  -2.8166],
        [  1.0778,  -1.0547],
        [  4.8832,  -4.4219],
        [ -9.7568,  10.2259],
        [  1.2736,   0.1889],
        [  1.7474,  -0.4080],
        [ -6.4538,   6.2935],
        [  4.4903,  -4.0963],
        [ -2.0725,   2.8029],
        [  9.2837,  -9.4142],
        [ -9.6015,  10.4397],
        [  5.8552,  -5.5207],
        [  5.1659,  -5.9793],
        [  7.2916,  -6.9431],
        [  0.4873,  -1.0524],
        [ -7.1126,   7.4295],
        [  3.6623,  -2.3648],
        [ -6.4444,   8.1442],
        [  5.9870,  -5.8260],
        [  3.6242,  -3.4296],
        [  1.0508,  -0.3290],
        [ -8.7773,   8.7616],
        [  0.8918,  -0.3673],
        [ -6.2533,   7.0349],
        [  3.1276,  -0.3488],
        [ 

tensor([[ -5.2343,   8.1398],
        [  5.9046,  -5.2527],
        [  8.4501,  -7.8363],
        [ -1.2130,   3.0254],
        [  6.1447,  -5.1598],
        [ 11.5335, -11.8706],
        [  4.7703,  -3.1492],
        [  9.5106,  -9.4238],
        [ -3.9162,   4.1947],
        [ -6.0535,   6.9151],
        [ -6.9209,   8.2234],
        [ -6.2709,   8.7784],
        [ -7.9914,   8.6284],
        [  5.4188,  -5.0662],
        [  8.9694,  -9.9256],
        [ -1.2425,   3.6469],
        [ -3.6479,   3.1095],
        [  7.5436,  -6.8282],
        [  4.3792,  -3.7198],
        [ -6.9687,   8.1403],
        [ -8.8846,  10.7039],
        [ -8.9246,   9.8214],
        [ -4.2061,   5.1954],
        [ -8.6503,  10.5772],
        [ -4.8835,   5.5641],
        [ -5.1018,   5.4584],
        [ -5.1482,   5.6748],
        [  5.3380,  -3.7004],
        [-13.5847,  15.8521],
        [  9.5605,  -9.4820],
        [-14.3716,  15.4995],
        [  4.6371,  -5.6552],
        [ -2.0203,   2.0070],
        [ 

tensor([[ -4.4793,   6.1457],
        [ -8.0440,   8.4635],
        [  0.4812,   0.5574],
        [  6.8696,  -6.1986],
        [-11.6483,  13.2410],
        [ -1.9818,   3.4486],
        [  3.5000,  -3.2817],
        [ -1.8005,   1.3226],
        [  2.9337,  -0.1157],
        [ -3.0610,   5.1066],
        [ -0.3676,   1.2233],
        [ -0.8825,   1.6821],
        [ -4.3835,   8.4967],
        [ 12.8312, -11.4304],
        [ -0.9519,   1.3688],
        [  2.2649,  -2.2084],
        [ -2.2483,   2.2564],
        [  2.8652,  -2.9406],
        [ -2.6854,   3.8055],
        [  5.6497,  -5.7634],
        [  4.7094,  -5.6136],
        [ -6.9010,   8.4087],
        [-10.1186,  11.5329],
        [  6.3524,  -4.9687],
        [ -3.2688,   4.1497],
        [ 10.6871, -11.8737],
        [  0.1553,   0.9694],
        [ -9.9182,  11.4829],
        [  5.2381,  -5.3093],
        [  2.1940,  -1.6687],
        [ -5.7667,   5.7078],
        [ -1.2497,   3.7146],
        [ -7.8065,   8.2572],
        [ 

tensor([[ -0.2874,   0.7165],
        [  8.0792,  -7.2211],
        [  4.0366,  -4.9415],
        [ -8.2189,   7.9411],
        [  0.4144,   1.0729],
        [ -9.4428,  10.4097],
        [ -7.3556,   7.3640],
        [ -1.8460,   3.2823],
        [  4.4463,  -4.3118],
        [  6.0416,  -5.6330],
        [ -4.2420,   7.1435],
        [-12.1903,  15.6328],
        [  8.3804,  -8.6238],
        [ -5.3669,   5.5907],
        [ -5.1707,   5.7227],
        [ -5.2727,   6.3764],
        [ -6.4143,   7.8141],
        [ -6.0947,   6.2644],
        [ -3.0908,   3.6715],
        [ -4.5408,   4.3641],
        [  9.5513,  -8.3412],
        [-16.9527,  19.8446],
        [ -7.3489,   8.0960],
        [ -0.5346,   1.0363],
        [  7.3117,  -6.2350],
        [-10.9851,  13.6937],
        [  1.9776,  -2.8462],
        [  4.8613,  -4.1873],
        [ 13.5075, -12.4528],
        [ 10.1258, -10.0796],
        [ -1.5976,   2.1756],
        [ -6.4873,   6.7665],
        [ -8.5237,   9.4920],
        [ 

tensor([[ -5.1639,   5.5426],
        [ -0.3564,   1.1315],
        [ -3.9649,   4.4328],
        [ -7.2274,   7.2890],
        [ -7.4184,   7.9672],
        [-11.0551,  12.6828],
        [ -0.6809,   3.7525],
        [ -4.7581,   5.8618],
        [-12.8219,  13.8631],
        [  5.5963,  -3.3220],
        [ -5.8630,   7.3787],
        [  3.3184,  -0.9763],
        [  5.8728,  -4.6252],
        [  2.5665,  -1.4496],
        [  6.0528,  -4.9715],
        [  7.4950,  -6.3190],
        [ -4.4138,   5.0322],
        [  0.3340,   0.4032],
        [  4.9703,  -4.9217],
        [  8.2554,  -8.7851],
        [ -5.5809,   6.8144],
        [-12.9401,  16.5599],
        [-10.1426,  13.0307],
        [ -8.7067,  11.6099],
        [ -4.7741,   4.2104],
        [ -4.2151,   4.9292],
        [ -7.4191,   8.4802],
        [ -1.3445,   2.8456],
        [ -0.3037,   0.4545],
        [  9.2336,  -9.8689],
        [ 13.0037, -11.6992],
        [  8.8255,  -8.2993],
        [  6.6092,  -6.2526],
        [ 

tensor([[  7.9526,  -7.2894],
        [ -4.9188,   5.6608],
        [ -3.4397,   3.1928],
        [ -6.8089,   8.4079],
        [ -4.4697,   4.8229],
        [ -0.5669,   2.1777],
        [ -8.7917,  10.6859],
        [ -8.3397,   9.1325],
        [-13.6111,  15.5365],
        [-10.4073,  11.6965],
        [  6.3239,  -4.9657],
        [ -0.7736,   1.9287],
        [ -5.5560,   6.5064],
        [  5.0701,  -6.1574],
        [  1.6658,  -0.4051],
        [ -1.3226,   0.7353],
        [ -7.6519,   9.2369],
        [  0.5470,   0.1312],
        [  2.0633,  -1.4070],
        [ -6.5028,  10.1548],
        [ -8.1480,   9.3200],
        [ -6.2702,   7.4348],
        [  7.8568,  -6.4975],
        [ -8.2767,  10.0697],
        [ -2.3649,   3.9648],
        [  6.1207,  -6.5989],
        [  0.5265,  -0.2976],
        [ 11.3137, -10.8940],
        [ -0.5259,   1.1510],
        [ -0.9590,   1.1052],
        [-10.8143,  12.1818],
        [ -4.9612,   6.2044],
        [  0.1930,   0.9775],
        [ 

tensor([[-1.0043e+01,  1.3655e+01],
        [-5.5247e+00,  7.8707e+00],
        [ 3.1248e+00, -3.0957e+00],
        [ 4.6819e+00, -2.6384e+00],
        [ 5.9087e+00, -5.5806e+00],
        [-1.6092e+00,  2.0588e+00],
        [-9.0885e+00,  9.7090e+00],
        [-8.0523e+00,  8.7519e+00],
        [-1.2767e+00,  3.1905e+00],
        [ 2.2079e+00, -3.3200e+00],
        [ 7.4451e+00, -7.9706e+00],
        [-5.2020e+00,  6.3042e+00],
        [ 1.0395e+01, -1.0241e+01],
        [-7.6959e+00,  8.7452e+00],
        [ 7.7538e+00, -6.3109e+00],
        [ 7.5144e+00, -7.2052e+00],
        [ 5.5900e+00, -3.9035e+00],
        [ 9.8004e+00, -9.3248e+00],
        [ 4.7396e+00, -5.2719e+00],
        [-1.1323e+01,  1.1541e+01],
        [ 1.7220e+00, -1.1755e+00],
        [ 7.8265e+00, -6.6503e+00],
        [ 6.3997e+00, -7.6708e+00],
        [ 9.0068e+00, -9.1230e+00],
        [-9.5933e+00,  1.2993e+01],
        [-1.2408e+01,  1.3143e+01],
        [ 2.5280e+00, -2.0584e+00],
        [ 4.0100e+00, -3.681

tensor([[ -2.0929,   2.4964],
        [-13.7512,  15.7719],
        [  0.2471,   0.1003],
        [  0.5028,   2.0012],
        [ -6.5540,   7.4648],
        [ -2.3587,   2.8595],
        [  3.6298,  -3.0577],
        [ -4.1518,   4.4120],
        [-10.6031,  10.5943],
        [  3.0989,  -1.7242],
        [ -8.7852,   9.2569],
        [-10.7716,  12.0965],
        [ -0.6040,   1.5826],
        [  5.5830,  -5.6704],
        [ -7.3510,   8.2389],
        [ -8.5420,   8.5513],
        [ -7.3189,   8.3571],
        [-11.9044,  14.3038],
        [ -2.4091,   3.0167],
        [  4.2985,  -5.2782],
        [ -2.1816,   4.3848],
        [ -2.4636,   2.4040],
        [ -5.7121,   7.1446],
        [ 10.5534,  -9.6248],
        [ -7.7260,   7.8931],
        [  8.9025,  -7.9000],
        [ -4.3961,   5.6369],
        [ -2.6425,   3.1041],
        [-10.4463,  13.8019],
        [  6.6272,  -7.3942],
        [ 15.0593, -13.3299],
        [ -6.0763,   6.6212],
        [ -2.2088,   3.0110],
        [ 

tensor([[  5.8583,  -6.9118],
        [ -5.3754,   7.1437],
        [ -2.4986,   2.1173],
        [ -4.5035,   6.1009],
        [ -6.3361,   6.3870],
        [ -0.7613,   1.0722],
        [ -8.1104,  10.7142],
        [  5.7370,  -6.4551],
        [  5.3272,  -5.3304],
        [ -3.5271,   3.7333],
        [ -5.4910,   5.5552],
        [  0.5306,  -0.3827],
        [  6.6216,  -6.5835],
        [  1.4809,  -2.2251],
        [  1.3742,   0.4867],
        [  5.4155,  -5.7525],
        [-10.4635,  11.6464],
        [ -6.2765,   6.8811],
        [  2.3531,  -0.5478],
        [ -6.4548,   6.8507],
        [  4.5869,  -4.4620],
        [  9.7739, -11.6524],
        [ -6.2641,   7.5319],
        [ -4.2814,   6.1609],
        [ -7.6290,  10.1964],
        [ -3.8441,   4.5489],
        [  5.1082,  -5.7035],
        [  7.8073,  -6.5024],
        [ -3.5716,   5.6627],
        [  4.0435,  -4.7717],
        [  5.5861,  -6.0125],
        [ -9.4051,   9.9879],
        [ -6.2901,   7.6504],
        [ 

tensor([[  0.5054,  -0.4276],
        [-10.5147,  11.9156],
        [ -5.5846,   7.5039],
        [  3.7480,  -3.7775],
        [ 14.5310, -13.9893],
        [  6.6732,  -3.9992],
        [  0.1047,   1.0722],
        [ -1.7592,   3.4318],
        [  9.2035,  -8.5510],
        [  7.7141,  -8.0297],
        [  3.6177,  -3.1976],
        [  2.2583,  -1.2227],
        [ -4.4694,   4.1686],
        [ 13.9270, -13.3360],
        [ -9.6874,  10.7130],
        [  5.1264,  -3.8744],
        [ -1.2775,   2.6235],
        [  0.9336,   0.4441],
        [  3.6674,  -3.3810],
        [ -9.7356,  10.2728],
        [  2.4458,  -2.2532],
        [ 12.0768, -12.0073],
        [  3.0084,  -2.6995],
        [ -7.5205,   8.2027],
        [  6.4296,  -6.0309],
        [ 14.2935, -14.2606],
        [  8.8559,  -8.0602],
        [  2.1439,  -2.5848],
        [  5.5478,  -6.3058],
        [ -4.7841,   6.3928],
        [  1.5183,  -1.5770],
        [ -0.8383,   1.5277],
        [-11.3087,  11.8979],
        [ 

tensor([[  7.3906,  -6.7957],
        [ -2.4518,   4.7982],
        [  4.5238,  -3.5557],
        [  9.2144,  -8.6975],
        [  0.3909,  -0.2472],
        [  0.0491,   0.1808],
        [-10.9235,  13.0008],
        [  2.4356,  -0.2485],
        [  3.5135,  -3.4219],
        [  9.0467,  -9.9090],
        [-11.3844,  12.0260],
        [ -2.9149,   3.3714],
        [ -8.6589,  10.7689],
        [  8.9871,  -8.0935],
        [ -1.7861,   2.7601],
        [  9.3691,  -9.6979],
        [ 11.1749,  -8.6015],
        [ -0.0783,   0.5928],
        [-15.7839,  17.8869],
        [  9.0605,  -7.9671],
        [  6.1199,  -5.5621],
        [ -2.6187,   2.2828],
        [ -0.0212,  -0.5943],
        [ -4.4550,   3.9624],
        [  6.7307,  -5.6790],
        [ 15.8199, -15.0654],
        [  4.9246,  -5.1925],
        [  5.3745,  -5.3678],
        [  6.3936,  -7.1966],
        [  0.3153,  -0.1189],
        [  7.2439,  -8.5455],
        [  3.1225,  -4.5653],
        [  2.1247,  -0.3933],
        [ 

tensor([[  3.6395,  -4.2423],
        [  4.3866,  -4.8265],
        [-14.2288,  17.1370],
        [ -6.4740,   8.2352],
        [ -2.1508,   3.9328],
        [ -2.7690,   3.6633],
        [ 10.1588,  -9.7760],
        [ 10.9680, -10.2587],
        [ -4.6157,   5.6323],
        [-11.4350,  12.7250],
        [  8.7595,  -8.5851],
        [ -4.1116,   4.8434],
        [ -5.3198,   6.8150],
        [ -9.9056,  10.8756],
        [ -8.3162,   8.6583],
        [ -7.2620,   7.6531],
        [  6.9778,  -6.8156],
        [ -6.0970,   7.1788],
        [ 13.8552, -12.2874],
        [ -5.4684,   6.0486],
        [-13.0156,  15.0538],
        [  2.6639,  -3.5880],
        [-14.2877,  16.3728],
        [ -3.0330,   3.5647],
        [ -9.6828,  10.0920],
        [ -6.1886,   6.5517],
        [  1.1226,  -0.3443],
        [ -7.9359,   9.9344],
        [  5.8385,  -6.0899],
        [-10.1967,  11.5872],
        [ -2.8588,   4.3503],
        [ 12.7041, -12.1366],
        [  4.3033,  -2.4092],
        [ 

tensor([[ 11.9686,  -9.0234],
        [ -7.1168,   7.7327],
        [ -0.8693,   1.2122],
        [ -3.2024,   4.1022],
        [  2.4562,  -2.0766],
        [-10.3871,  11.3435],
        [  5.2353,  -5.8153],
        [ -5.2681,   5.5537],
        [ -3.6891,   4.9867],
        [ -4.5528,   5.2519],
        [-10.0298,  10.6900],
        [  7.2337,  -7.4776],
        [  4.9895,  -5.5458],
        [ -6.9837,   7.9325],
        [ -1.4686,   2.8469],
        [  1.4125,  -0.7182],
        [  1.4283,  -1.2375],
        [  5.9894,  -5.2983],
        [  0.2993,   0.3558],
        [  6.6472,  -6.0789],
        [  2.8798,  -2.4795],
        [  0.3867,   1.6497],
        [ 14.1387, -13.6113],
        [  6.5291,  -4.5874],
        [ -4.5884,   4.8558],
        [ -1.9474,   2.5810],
        [  6.4906,  -5.5678],
        [  6.3102,  -4.9713],
        [  9.8641,  -9.3620],
        [  4.8966,  -6.0128],
        [-11.2453,  13.2814],
        [ -7.7675,   7.7552],
        [-11.6635,  12.1508],
        [ 

tensor([[ 9.4770e+00, -1.0673e+01],
        [-3.4016e+00,  3.8080e+00],
        [-3.4399e+00,  3.9573e+00],
        [-1.0010e+01,  1.2654e+01],
        [ 8.9114e-01, -5.5723e-01],
        [-8.6262e+00,  9.6160e+00],
        [-1.4879e+01,  1.7354e+01],
        [-9.9631e-01,  1.1850e+00],
        [-8.7891e+00,  1.0488e+01],
        [-1.0771e+01,  1.3037e+01],
        [ 1.1234e+00,  1.0852e+00],
        [-6.8051e-01,  3.3366e+00],
        [ 1.6547e+00, -8.0455e-01],
        [ 3.7345e-01,  8.9091e-02],
        [-5.5302e+00,  6.8076e+00],
        [-1.1174e+01,  1.3006e+01],
        [-1.0319e+00,  2.6872e+00],
        [ 6.8452e+00, -6.4918e+00],
        [-7.5616e-01,  2.4025e+00],
        [-1.6137e+01,  1.8825e+01],
        [ 3.9861e+00, -4.8940e+00],
        [ 9.6858e-01, -8.6775e-01],
        [ 3.5046e-01, -4.1510e-01],
        [-4.6070e+00,  5.4766e+00],
        [ 1.2364e+01, -1.1040e+01],
        [-8.5867e+00,  1.0561e+01],
        [-8.6004e+00,  8.7192e+00],
        [ 1.5511e+01, -1.471

test error ConvNoWS 10.50% 105/1000
tensor([[ -2.4118,   3.3100],
        [ -6.9946,   7.5490],
        [  2.7008,  -3.2792],
        [  9.6759,  -9.1100],
        [  3.3489,  -2.5181],
        [  5.5873,  -5.4679],
        [ -9.9556,  11.3014],
        [ 14.3974, -12.8267],
        [  5.8632,  -6.7123],
        [  6.0194,  -5.3962],
        [-12.2999,  13.6407],
        [  7.3482,  -6.9895],
        [  6.3344,  -5.3685],
        [  0.7806,   1.1283],
        [-16.8543,  18.9991],
        [ -6.2812,   7.7364],
        [ -5.1249,   9.1912],
        [  7.5574,  -6.8709],
        [ -9.3643,   9.9176],
        [  2.1715,  -0.8648],
        [  5.7783,  -6.1433],
        [ -8.4955,   9.4940],
        [ -2.0528,   3.0603],
        [ -6.6024,   5.5243],
        [ -6.3185,   7.5877],
        [ -2.9501,   4.2344],
        [ -4.4250,   5.4136],
        [  6.5283,  -7.5543],
        [ -9.2370,  11.7521],
        [  4.3452,  -4.0020],
        [  2.1273,  -2.5393],
        [ -1.1262,   1.8402],
    

tensor([[  7.4272,  -6.5750],
        [-13.0407,  14.2984],
        [  3.3608,  -3.5373],
        [  6.6401,  -7.3305],
        [  7.1437,  -6.8216],
        [  8.2010,  -6.8716],
        [ -3.5703,   4.5591],
        [  5.7165,  -5.4569],
        [  1.0556,  -1.2428],
        [  3.7836,  -1.6624],
        [-13.3159,  16.0436],
        [ -2.7626,   3.2842],
        [  0.6624,   2.1916],
        [  5.3586,  -4.8593],
        [  5.6811,  -4.9449],
        [  1.7900,  -1.5375],
        [  4.4604,  -5.1466],
        [ -7.6711,   8.4138],
        [  0.3067,   0.5320],
        [ 15.6580, -14.7753],
        [  2.5010,  -1.9688],
        [  3.9592,  -4.4305],
        [ -4.3696,   5.2603],
        [  8.3101,  -7.7881],
        [  6.2928,  -5.3378],
        [ -4.0520,   4.9852],
        [ -3.9593,   4.6408],
        [  4.1905,  -4.3989],
        [  4.7711,  -4.1197],
        [ -3.7677,   4.0259],
        [-11.0484,  13.2703],
        [ -6.5339,   7.7308],
        [  5.2039,  -4.4508],
        [-

tensor([[ 6.9959e+00, -5.8806e+00],
        [-3.3053e+00,  4.9027e+00],
        [ 4.1800e+00, -2.6464e+00],
        [ 1.7567e+00, -9.0583e-01],
        [ 4.0312e+00, -3.3144e+00],
        [ 7.4144e+00, -7.8256e+00],
        [-4.9486e+00,  6.2892e+00],
        [ 6.9211e+00, -7.7124e+00],
        [ 8.8124e+00, -8.8650e+00],
        [-9.2412e-01,  1.6065e+00],
        [-3.8961e+00,  3.7212e+00],
        [ 1.1587e+01, -1.0983e+01],
        [ 1.2199e+01, -1.1391e+01],
        [-1.0627e+01,  1.1667e+01],
        [-6.6763e+00,  7.2782e+00],
        [-1.9181e+00,  3.1088e+00],
        [ 5.5312e+00, -4.7274e+00],
        [ 1.1245e+01, -1.2021e+01],
        [-4.1900e+00,  3.9857e+00],
        [ 5.7881e+00, -5.9350e+00],
        [-1.1064e+01,  1.2286e+01],
        [ 1.3939e+00,  6.5009e-01],
        [-1.1141e-01,  1.0901e+00],
        [-3.8309e+00,  4.2113e+00],
        [-9.9000e+00,  1.1679e+01],
        [-7.4706e+00,  6.6401e+00],
        [ 1.0834e+01, -1.0576e+01],
        [ 1.1528e+00, -1.190

tensor([[  5.0121,  -3.6420],
        [-14.6461,  17.3984],
        [-15.5754,  17.4544],
        [ 11.1072, -10.0877],
        [  7.3001,  -8.0849],
        [ -2.7270,   3.6626],
        [ -7.9800,   9.9032],
        [ -8.8440,  10.3913],
        [ -4.8016,   4.7398],
        [ -3.7858,   4.5245],
        [  5.4018,  -3.4978],
        [  0.9678,  -0.8964],
        [  2.4077,  -2.6464],
        [ -4.7707,   5.4205],
        [  5.4233,  -5.4187],
        [ -5.6398,   6.3939],
        [ -6.7067,   7.6343],
        [  3.9856,  -4.7729],
        [  1.1727,   0.0244],
        [ 10.8352,  -9.5655],
        [ -3.3657,   5.3635],
        [  7.7905,  -8.2401],
        [  3.0067,  -3.3694],
        [ -8.8480,   9.4702],
        [ -2.1989,   3.7904],
        [ -1.5021,   1.4528],
        [  1.9216,  -2.0239],
        [ -0.3974,   3.9463],
        [-10.6889,  12.0880],
        [ -3.3388,   4.4171],
        [ -1.6887,   1.9374],
        [ -4.0997,   5.4687],
        [  0.8993,  -0.7795],
        [ 

tensor([[  6.1981,  -6.9396],
        [-10.9543,  12.5810],
        [  6.3625,  -4.6293],
        [ -0.5950,   0.3697],
        [ -0.4761,   2.6783],
        [ -2.9851,   2.8365],
        [ -1.3738,   3.0868],
        [ -3.7033,   4.5009],
        [ -9.9619,  11.0460],
        [  3.4784,  -1.1449],
        [  8.9027,  -9.1555],
        [ -3.7786,   4.1544],
        [  1.8241,   0.5643],
        [  5.6508,  -6.1139],
        [  5.3983,  -4.3203],
        [ -3.5480,   4.5348],
        [ -4.2841,   4.7356],
        [ -7.2740,   7.4623],
        [ -5.3520,   6.0265],
        [-13.6157,  15.4899],
        [-10.1455,  11.9742],
        [  4.3861,  -3.8825],
        [ -7.8056,   8.0675],
        [ -8.6544,   9.2538],
        [ -2.2631,   3.0898],
        [  3.9297,  -3.5012],
        [ 11.4051, -11.0247],
        [ -8.8971,   9.9861],
        [  3.0034,  -0.2837],
        [ -5.5686,   7.2362],
        [ -7.5217,   7.9364],
        [  0.1772,  -0.0304],
        [ 10.4533,  -9.3698],
        [ 

tensor([[ 11.6201, -10.6608],
        [  6.5438,  -6.2306],
        [ -8.4675,   8.9331],
        [ 10.9841, -10.4605],
        [  0.5849,   1.0424],
        [  3.3333,  -1.1241],
        [-15.4252,  17.7166],
        [-11.7479,  13.2742],
        [ -1.2054,   2.5959],
        [ -7.5106,   8.3000],
        [-11.9619,  14.4904],
        [ -9.1586,   9.1606],
        [ -4.4485,   5.4827],
        [ -8.3113,  10.2548],
        [ -6.0574,   8.3505],
        [  2.3125,  -1.3175],
        [ -1.4881,   2.3947],
        [  0.6966,   0.1563],
        [ 13.9546, -13.1838],
        [  9.5571,  -9.2970],
        [-10.9324,  13.2926],
        [  6.8428,  -8.0267],
        [  1.6225,  -0.8652],
        [  1.2359,  -0.2503],
        [ -6.4854,   8.5387],
        [-12.3821,  13.1495],
        [  6.5958,  -4.7482],
        [-13.2539,  16.5707],
        [ -0.4334,   1.1506],
        [  6.3257,  -6.3941],
        [  6.4691,  -6.0643],
        [  7.8220,  -8.0020],
        [ -9.7030,  10.8681],
        [ 

tensor([[ -4.6267,   7.2969],
        [ -3.0855,   3.6719],
        [  7.4035,  -7.0140],
        [ -8.8981,  11.2740],
        [ -2.7188,   3.6948],
        [ -6.5186,   8.2023],
        [ -1.6535,   2.2801],
        [ -2.4110,   3.1914],
        [  7.5669,  -9.2424],
        [  1.8971,  -2.3708],
        [  5.1476,  -4.9459],
        [  5.9190,  -4.6839],
        [ -6.4190,   6.9652],
        [  0.0962,   0.9786],
        [  6.9275,  -4.5480],
        [  6.8243,  -7.1684],
        [  6.7858,  -6.5735],
        [-12.9518,  13.6907],
        [  1.5904,  -1.2446],
        [ -6.1938,   6.8084],
        [  5.0432,  -3.0817],
        [  3.3126,  -3.3754],
        [  7.7376,  -6.9708],
        [  0.6096,   1.4760],
        [ 12.0832, -10.6994],
        [ -3.4378,   2.8634],
        [ -0.5277,   1.0990],
        [ -5.8178,   6.4955],
        [-10.8698,  11.8267],
        [ -4.4925,   5.3583],
        [ -9.3112,   8.6661],
        [ -3.4467,   4.3864],
        [-10.3086,  11.4582],
        [ 

tensor([[  3.1930,  -2.9645],
        [  8.8971,  -6.5549],
        [ -3.2628,   4.6504],
        [  4.1878,  -3.4880],
        [  0.5875,   1.6615],
        [  9.8623,  -9.0356],
        [ -3.2562,   4.4469],
        [ -6.8435,   8.9514],
        [  2.4624,  -1.2131],
        [  3.5034,  -1.9528],
        [  0.4425,  -0.2573],
        [  4.8169,  -5.0825],
        [  9.7715,  -9.1753],
        [  8.7161,  -8.7689],
        [ -1.5198,   1.2365],
        [ -1.1787,   3.0410],
        [ -7.6701,  10.2916],
        [-16.2414,  17.6097],
        [  1.3795,  -1.2523],
        [ -5.4480,   5.2778],
        [  5.0914,  -5.5866],
        [ -9.8675,  11.9677],
        [ -7.2641,   9.0741],
        [  6.9565,  -5.4752],
        [  0.2708,   0.2178],
        [  4.0658,  -3.5966],
        [ 11.9937, -12.2314],
        [ -2.4351,   2.8596],
        [ 12.2771, -12.1375],
        [ -5.1116,   5.4483],
        [ -8.3904,   9.2342],
        [ -0.3017,  -0.3298],
        [  2.3736,  -1.4644],
        [ 