## Data Loading

In [1]:
import h5py
import numpy as np
import pickle

data = h5py.File('/home/sunji/ANN/nytimes_256_angular/nytimes-256-angular.hdf5', 'r')
data_train = np.array(data['train'])
data_test = np.array(data['test'])
with open('/home/sunji/ANN/nytimes_256_angular/clusters_nytimes_256_angular.pkl', 'rb') as f:
    clusters = pickle.load(f)
with open('/home/sunji/ANN/nytimes_256_angular/ground_truth_nytimes_256_angular_0_4_0_5.pkl', 'rb') as f:
    ground_truth_total = pickle.load(f)

In [2]:
cluster_size = len(clusters)

In [3]:
for idx, c in enumerate(clusters):
    print (idx, len(c))

0 3440
1 3874
2 3050
3 3451
4 1692
5 3706
6 3331
7 3519
8 1112
9 3575
10 1166
11 1228
12 1735
13 3264
14 3195
15 2372
16 1023
17 1430
18 1482
19 5186
20 4602
21 4707
22 2990
23 2835
24 1600
25 3676
26 2222
27 2137
28 2717
29 4326
30 1547
31 2505
32 2130
33 3988
34 2460
35 2340
36 4093
37 3962
38 3492
39 1952
40 3878
41 2474
42 2141
43 2741
44 2642
45 2193
46 2937
47 4289
48 3080
49 3005
50 2962
51 2052
52 1157
53 740
54 4492
55 3340
56 5026
57 2348
58 3400
59 2421
60 5009
61 3233
62 2780
63 5000
64 1347
65 2566
66 1374
67 1656
68 4120
69 1899
70 3087
71 2751
72 2892
73 3397
74 3156
75 4429
76 2045
77 2559
78 2069
79 1844
80 1222
81 3854
82 4271
83 1979
84 2675
85 2500
86 3875
87 2067
88 3661
89 3115
90 4743
91 1664
92 3927
93 5654
94 2604
95 5129
96 1377
97 3917
98 2495
99 1628


In [4]:
ground_truth_total_level = [[[] for _ in range(10000)] for _ in range(cluster_size)]
for clus in range(cluster_size):
    for t in ground_truth_total[clus]:
        ground_truth_total_level[t[0]][t[1]].append(t)

In [5]:
clusters[7][0]

array([ 0.03958282, -0.04249519, -0.02680158,  0.05558047,  0.02956207,
        0.01495276,  0.11445832,  0.02025234,  0.05141935,  0.12908261,
       -0.0020322 , -0.09796277, -0.11000596, -0.04206981, -0.04520715,
        0.04104386,  0.17379147,  0.01629554,  0.08113959, -0.0837682 ,
        0.04311689, -0.01379663, -0.02449157, -0.00272997, -0.08257152,
        0.0285256 , -0.02959241, -0.01321774,  0.01224518,  0.04638979,
        0.04583569, -0.02181366,  0.10432127,  0.021705  ,  0.11971048,
        0.00679282, -0.05672008, -0.07022616,  0.06480139,  0.05908615,
       -0.0466856 , -0.02422996, -0.04829268, -0.08923909,  0.02421452,
       -0.03067837,  0.01344449,  0.06449297, -0.009079  , -0.01631479,
        0.09215888, -0.004335  ,  0.03206396,  0.10120524,  0.12466195,
        0.11014991, -0.05343785, -0.01973541, -0.03676889, -0.01938055,
        0.03820704, -0.07167533, -0.10006417, -0.05048365, -0.02969268,
       -0.12271964, -0.03789967,  0.10456178,  0.01343348,  0.03

In [6]:
centroids = []
for cluster in clusters:
    centroids.append(np.mean(cluster, 0))

## Prepare Inputs

In [8]:
import numpy as np
from numpy import dot
from numpy.linalg import norm
from scipy import spatial
import random
from multiprocessing import Pool

def angular_dist(x1, x2=None, eps=1e-8):
    cosine_sim = 1 - spatial.distance.cosine(x1, x2)
    distance = np.arccos(cosine_sim) / 3.14159267
    return distance 

def euclidean_dist_normalized(x1, x2=None, eps=1e-8):
    if x2 is None:
        return 1.0
    return np.sqrt(((x1 - x2) ** 2).mean())

def prepare_for_cluster(cluster_id):
    slot = 0.002
    batch_size = 128
    min_card = 1e10
    max_card = 0
    train_queries = []
    train_distances = []
    train_thresholds = []
    train_targets = []
    for query_id in range(8000):
        cardinality = 0
        for threshold_id, threshold in enumerate(np.arange(0.4, 0.5, slot)):
            cardinality += ground_truth_total_level[cluster_id][query_id][threshold_id][-1]
            if cardinality > 0 or random.random() < 0.4:
                train_queries.append(data_test[query_id])
                train_distances.append([angular_dist(data_test[query_id], centroids[cluster_id])])
                train_thresholds.append([threshold+slot])
                train_targets.append([cardinality])

    test_queries = []
    test_distances = []
    test_thresholds = []
    test_targets = []
    for query_id in range(8000,10000):
        cardinality = 0
        for threshold_id, threshold in enumerate(np.arange(0.4, 0.5, slot)):
            cardinality += ground_truth_total_level[cluster_id][query_id][threshold_id][-1]
            if cardinality > 0 or random.random() < 0.4:
                test_queries.append(data_test[query_id])
                test_distances.append([angular_dist(data_test[query_id], centroids[cluster_id])])
                test_thresholds.append([threshold+slot])
                test_targets.append([cardinality])
    
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.FloatTensor(train_queries), torch.FloatTensor(train_distances), torch.FloatTensor(train_thresholds), torch.FloatTensor(train_targets)), batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(torch.FloatTensor(test_queries), torch.FloatTensor(test_distances), torch.FloatTensor(test_thresholds), torch.FloatTensor(test_targets)), batch_size=batch_size, shuffle=True)
    
    return train_loader, test_loader, min_card, max_card
        

In [70]:
 np.arccos(-0.2)

1.7721542475852274

In [10]:
import torch
import torch.utils.data

train_loaders = []
test_loaders = []
min_cards = []
max_cards = []
for cluster_id in range(cluster_size):
    print (cluster_id)
    train, test, min_card, max_card = prepare_for_cluster(cluster_id)
    train_loaders.append(train)
    test_loaders.append(test)
    min_cards.append(min_card)
    max_cards.append(max_card)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [None]:
def normalize(labels, mini, maxi):
    return (torch.log(labels) - mini) / (maxi - mini)

def unnormalize(labels, mini, maxi):
    return torch.exp(labels * (maxi - mini) + mini)


## BaseLine Local Model

In [11]:
from __future__ import print_function
import argparse
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

queries_dimension = 960
hidden_num = 128

class Test_Model(nn.Module):
    
    def __init__(self):
        super(Test_Model, self).__init__()
        self.threshold1 = nn.Linear(1, hidden_num)
        self.threshold2 = nn.Linear(hidden_num, 1)

        self.cnn_layer1 = nn.Sequential(
            nn.Conv1d(1, 8, kernel_size=3, stride=1, padding=2), 
            nn.BatchNorm1d(8),
            nn.LeakyReLU(),
            nn.AvgPool1d(kernel_size=3, stride=3))
        
        self.cnn_layer2 = nn.Sequential(
            nn.Conv1d(8, 16, kernel_size=3, stride=1, padding=2), 
            nn.BatchNorm1d(16),
            nn.LeakyReLU(),
            nn.AvgPool1d(kernel_size=3, stride=3))
        
        self.out1 = nn.Linear(368, hidden_num)
        self.out2 = nn.Linear(hidden_num, 1)
        
    def forward(self, query, threshold):
        
        query = query.unsqueeze(2).permute(0,2,1)
#         print (query.shape)
        threshold = F.relu(self.threshold1(threshold))
        threshold = self.threshold2(threshold)
#         print (threshold.shape)
        query = self.cnn_layer1(query)
        query = self.cnn_layer2(query)
        query = query.view(query.shape[0], -1)
#         print (query.shape)
        query = self.out1(query)
        
        output = self.out2(F.relu(query+threshold))
        
        return output

def begin_test_model():
#     models = []
#     errors = []
#     for idx in range(100):
    idx = 0
    print ('idx: {}'.format(idx))
    train = train_loaders[idx]
    test = test_loaders[idx]
    mini = min_cards[idx]
    maxi = max_cards[idx]
    episode = 5
    queries_dimension = 960
    model = Test_Model()
    opt = optim.Adam(model.parameters(), lr=0.01)
    error = test_model(model, opt, train, test, episode)
    models.append(model)
    errors.append(error)

def test_model(model, opt, train, test, episode):
    print ('size: {}'.format(len(train)))
    test_errors = []
    for e in range(episode):
        model.train()
        for batch_idx, (queries, _, thresholds, targets) in enumerate(train):
    #         print (torch.cat((queries, thresholds), 1)[0])
            queries = Variable(queries)
            thresholds = Variable(thresholds)
            targets = Variable(targets)
    #         print (targets)
            opt.zero_grad()
            
            estimates = model(queries, thresholds)
            
            loss = l1_loss(estimates, targets)
            loss.backward()
            opt.step()
#             if batch_idx % 100 == 0:
#                 print('Training: Iteration {0}, Batch {1}, Loss {2}'.format(e, batch_idx, loss.item()))
#                 print(model.cnn_layer1[0].weight)
            for p in model.parameters():
                p.data.clamp_(-2, 2)
        model.eval()
        test_loss = 0.0
        mse_error = 0.0
        q_mean = 0.0
        q_max = 0.0
        for batch_idx, (queries, _, thresholds, targets) in enumerate(test):
            queries = Variable(queries)
            thresholds = Variable(thresholds)
            targets = Variable(targets)
            
            estimates = model(queries, thresholds)
            
            loss = l1_loss(estimates, targets)
            mse, qer_mean, qer_max = print_loss(estimates, targets)
            test_loss += loss.item()
            mse_error += mse.item()
            q_mean += qer_mean
            if qer_max > q_max:
                q_max = qer_max
        test_loss /= len(test)
        mse_error /= len(test)
        q_mean /= len(test)
        test_errors.append(q_mean)
        print ('Testing: Iteration {0}, Loss {1}, MSE_error {2}, Q_error_mean {3}, Q_error_max {4}'.format(e, test_loss, mse_error, q_mean, q_max))
    return np.mean(test_errors[-3:])

In [None]:
begin_test_model()

## Stacked Model Definition

In [66]:
from __future__ import print_function
import argparse
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# queries_dimension = 200
# hidden_num = 128

class Threshold_Model(nn.Module):
    
    def __init__(self):
        super(Threshold_Model, self).__init__()
        self.fc1 = nn.Linear(1, hidden_num)
        self.fc2 = nn.Linear(hidden_num, 1)
    
    def forward(self, threshold):
        t1 = F.relu(self.fc1(threshold))
        t2 = self.fc2(t1)
        return t2

class Distance_Model(nn.Module):
    
    def __init__(self):
        super(Distance_Model, self).__init__()
        self.fc1 = nn.Linear(1, hidden_num)
        self.fc2 = nn.Linear(hidden_num, 1)
    
    def forward(self, distance):
        t1 = F.relu(self.fc1(distance))
        t2 = self.fc2(t1)
        return t2


class CNN_Model(nn.Module):
    
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, pool_type, pool_size):
        super(CNN_Model, self).__init__()
        if pool_type == 0:
            pool_layer = nn.MaxPool1d(kernel_size=pool_size, stride=pool_size)
        elif pool_type == 1:
            pool_layer = nn.AvgPool1d(kernel_size=pool_size, stride=pool_size)
        else:
            print ('CNN_Model Init Error, invalid pool_type {}'.format(pool_type))
            return
        self.layer = nn.Sequential(
            nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding), 
            nn.BatchNorm1d(out_channel),
            nn.ReLU(),
            pool_layer)
        
    def forward(self, inputs):
        hid = self.layer(inputs)
#         print (hid.shape)
#         hid = F.relu(self.n3(hid))
#         hid = F.relu(self.n4(hid))
#         hid = self.norm2(hid)
#         print (hid.shape)
#         out2 = self.fc(hid.view(out1.shape[0], -1))
        return hid

class Output_Model(nn.Module):
    
    def __init__(self, inputs_dim):
        super(Output_Model, self).__init__()
        self.fc1 = nn.Linear(inputs_dim + 2, hidden_num)
        self.fc2 = nn.Linear(hidden_num, 1)
        
    def forward(self, queries, threshold, distance):
        out1 = F.relu(self.fc1(torch.cat((queries, distance, threshold), dim=1)))
        out2 = out1
#         print ('out2: {0}, threshold: {1}'.format(out2.shape, threshold.shape))
        out3 = self.fc2(out2)
        return out3

class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.nn1 = nn.Linear(queries_dimension+1, hidden_num)
        self.n1 = nn.Linear(hidden_num, hidden_num)
        self.n2 = nn.Linear(hidden_num, hidden_num)
#         self.n3 = nn.Linear(hidden_num, hidden_num)
#         self.n4 = nn.Linear(hidden_num, hidden_num)
        self.nn2 = nn.Linear(hidden_num, 1)

    def forward(self, queries, threshold):
        out1 = F.relu(self.nn1(torch.cat([queries, threshold],1)))
        hid = out1
        hid = F.relu(self.n1(hid))
        hid = F.relu(self.n2(hid))
#         hid = F.relu(self.n3(hid))
#         hid = F.relu(self.n4(hid))
#         hid = self.norm2(hid)
        out2 = self.nn2(hid)
        return out2

def loss_fn(estimates, targets, mini, maxi):
    est = unnormalize(estimates, mini, maxi)
    print (torch.cat((est, targets), 1))
    return F.mse_loss(est, targets)

def l1_loss(estimates, targets, eps=1e-5):
    estimates = torch.exp(estimates)
    qerror = 0.0
    targets += 1.0
    for i in range(estimates.shape[0]):
        if estimates[i] > targets[i]:
            qerror += ((estimates[i] / (targets[i])))
        else:
            qerror += (((targets[i]) / estimates[i]))
    return qerror / estimates.shape[0]

def mse_loss(estimates, targets, eps=1e-5):
#     print (torch.cat((estimates, targets), 1))
    return F.mse_loss(estimates, torch.log(targets))

def qerror_loss(preds, targets, mini, maxi):
    qerror = []
    preds = unnormal1ize_label(preds, mini, maxi)
#     print (torch.cat((preds, targets), 1))
    for i in range(len(targets)):
        if (preds[i] > targets[i]).cpu().data.numpy()[0]:
            qerror.append(preds[i]/targets[i])
        else:
            qerror.append(targets[i]/(preds[i] + 0.1))
    return torch.mean(torch.cat(qerror) ** 2)

def print_loss(estimates, targets):
    esti = torch.exp(estimates)
#     print (torch.cat((estimates, esti, targets), 1))
    qerror = []
    targets += 1
    for i in range(esti.shape[0]):
        if esti[i] > targets[i]:
            qerror.append((esti[i] / (targets[i])).item())
        else:
            qerror.append(((targets[i]) / esti[i]).item())
    
    return F.mse_loss(esti, targets), np.mean(qerror), np.max(qerror)

## Hyper-parameters Selection Methods

In [67]:
from random import sample

def repair_specific_local_model(errors, next_cnn_parameterss, next_cnn_modelss, next_output_models, threshold_models, cluster_id):
    clus = cluster_id
    print ('Begin Cluster: {}'.format(clus))
    idx = clus
    train = train_loaders[idx]
    test = test_loaders[idx]
    mini = min_cards[idx]
    maxi = max_cards[idx]
    prev_best_error = 100000.0
    cnn_parameters, cnn_models = [], []
    episode = 5
    queries_dimension = 256
    threshold_model = Threshold_Model()
    error, next_cnn_parameters, next_cnn_models,next_output_model = select_best_layer(prev_best_error, cnn_parameters, cnn_models, threshold_model, train, test, episode, queries_dimension)
    saved_error, saved_next_cnn_parameters, saved_next_cnn_models,saved_next_output_model = error, next_cnn_parameters, next_cnn_models,next_output_model
    while error is not None:
        saved_error, saved_next_cnn_parameters, saved_next_cnn_models,saved_next_output_model = error, next_cnn_parameters, next_cnn_models,next_output_model
        print ('Cluster: {}, Error: {}, CNN Layer Num: {}, Added CNN Layer: {}'.format(clus, error, len(next_cnn_parameters), next_cnn_parameters[-1]))
        error, next_cnn_parameters, next_cnn_models,next_output_model = select_best_layer(error, next_cnn_parameters, next_cnn_models, threshold_model, train, test, episode, queries_dimension)
    errors[clus] = saved_error
    next_cnn_parameterss[clus] = saved_next_cnn_parameters
    next_cnn_modelss[clus] = saved_next_cnn_models
    next_output_models[clus] = saved_next_output_model
    threshold_models[clus] = threshold_model

def construct_model():
    errors = []
    next_cnn_parameterss = []
    next_cnn_modelss = []
    next_output_models = []
    threshold_models = []
    distance_models = []
    for clus in range(cluster_size):
        print ('Begin Cluster: {}'.format(clus))
        idx = clus
        train = train_loaders[idx]
        test = test_loaders[idx]
        mini = min_cards[idx]
        maxi = max_cards[idx]
        prev_best_error = 100000.0
        cnn_parameters, cnn_models = [], []
        episode = 5
        queries_dimension = 256
        threshold_model = Threshold_Model()
        distance_model = Distance_Model()
        error, next_cnn_parameters, next_cnn_models,next_output_model = select_best_layer(prev_best_error, cnn_parameters, cnn_models, threshold_model, distance_model, train, test, episode, queries_dimension)
        saved_error, saved_next_cnn_parameters, saved_next_cnn_models,saved_next_output_model = error, next_cnn_parameters, next_cnn_models,next_output_model
        while error is not None:
            saved_error, saved_next_cnn_parameters, saved_next_cnn_models,saved_next_output_model = error, next_cnn_parameters, next_cnn_models,next_output_model
            print ('Cluster: {}, Error: {}, CNN Layer Num: {}, Added CNN Layer: {}'.format(clus, error, len(next_cnn_parameters), next_cnn_parameters[-1]))
            error, next_cnn_parameters, next_cnn_models,next_output_model = select_best_layer(error, next_cnn_parameters, next_cnn_models, threshold_model, train, test, episode, queries_dimension)
        errors.append(saved_error)
        next_cnn_parameterss.append(saved_next_cnn_parameters)
        next_cnn_modelss.append(saved_next_cnn_models)
        next_output_models.append(saved_next_output_model)
        threshold_models.append(threshold_model)
        distance_models.append(distance_model)
    return errors, next_cnn_parameterss, next_cnn_modelss, next_output_models, threshold_models

class TunableParameters():
    
    def __init__(self, out_channel, kernel_size, stride, padding, pool_size, pool_type):
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_type = pool_type
    
    def __repr__(self):
        return str(self.out_channel) +' '+ str(self.kernel_size) +' '+ str(self.stride) +' '+ str(self.padding) +' '+ str(self.pool_size) +' '+ str(self.pool_type)
 
    def __str__(self):
        return str(self.out_channel) +' '+ str(self.kernel_size) +' '+ str(self.stride) +' '+ str(self.padding) +' '+ str(self.pool_size) +' '+ str(self.pool_type)

def select_best_layer(prev_best_error, cnn_parameters, cnn_models, threshold_model, distance_model, train, test, episode, queries_dimension):
    print ('Input Model Size: {}'.format(len(cnn_parameters)))
    if len(cnn_parameters) > 0:
        in_channel = cnn_parameters[-1].out_channel
    else:
        in_channel = 1
    in_size = queries_dimension
    for para in cnn_parameters:
        in_size = int((int((in_size - para.kernel_size + 2*(para.padding)) / para.stride) + 1) / para.pool_size)
        print(para.kernel_size, para.padding, para.stride, para.pool_size, in_size)
    
    if in_size < 10 or len(cnn_parameters) > 5:
        return None, None, None, None
    
    current_paras = []
    current_paras.append(TunableParameters(8,10,1,3,10,0))
    current_paras.append(TunableParameters(4,5,3,2,5,0))
#     current_paras.append(TunableParameters(4,3,1,0,3,0))
    current_paras.append(TunableParameters(2,2,1,0,2,0))
    
#     for out_channel in [2,4,8]:
#         for kernel_size in [2,4,8]:
#             for stride in range(1, min(4, kernel_size)):
#                 for padding in [0,2]:
#                     for pool_size in [kernel_size,]:
#                         for pool_type in [0,]:
#                             current_paras.append(TunableParameters(out_channel,kernel_size,stride,padding,pool_size,pool_type))
    print ('Group of parameters: {}'.format(len(current_paras)))
    next_cnn_models = []
    next_cnn_parameters = []
    next_output_model = None
#     current_paras = sample(current_paras, 2)
    for para in current_paras:
        print (para)
        in_size_local = int((int((in_size - para.kernel_size + 2*(para.padding)) / para.stride) + 1) / para.pool_size)
        if in_size_local < 10:
            continue
        print (in_size_local, para.out_channel)
        output_model = Output_Model(in_size_local * para.out_channel)
        added_cnn_layer = CNN_Model(in_channel, para.out_channel, para.kernel_size, para.stride, para.padding, para.pool_type, para.pool_size)
        paras = [{"params": model.parameters()} for model in cnn_models]
        paras.append({"params": added_cnn_layer.parameters()})
        paras.append({"params": threshold_model.parameters()})
        paras.append({"params": distance_model.parameters()})
        paras.append({"params": output_model.parameters()})
        opt = optim.Adam(paras, lr=0.001)
        new_cnn_models = []
        for model in cnn_models:
            new_cnn_models.append(model)
        new_cnn_models.append(added_cnn_layer)
        error = train_and_test(distance_model, new_cnn_models, threshold_model, output_model, opt, train, test, episode)
        if error < prev_best_error - 0.1:
            prev_best_error = error
            new_cnn_parameters = []
            for para_old in cnn_parameters:
                new_cnn_parameters.append(para_old)
            next_output_model = output_model
            new_cnn_parameters.append(para)
            print ('Update layer: {}'.format(len(new_cnn_parameters)))
            next_cnn_parameters = new_cnn_parameters
            next_cnn_models = new_cnn_models
    if len(next_cnn_models) == 0:
        return None, None, None, None
    return prev_best_error, next_cnn_parameters, next_cnn_models, next_output_model
        

def only_test(distance_model, cnn_models, threshold_model, output_model, test):
    for model in cnn_models:
        model.eval()
    threshold_model.eval()
    output_model.eval()
    q_errors = []
    for batch_idx, (queries, distances, thresholds, targets) in enumerate(test):
        queries = Variable(queries)
        thresholds = Variable(thresholds)
        distances = Variable(distances)
        targets = Variable(targets)

        queries = queries.unsqueeze(2).permute(0,2,1)
        for model in cnn_models:
            queries = model(queries)
        threshold = threshold_model(thresholds)
#         distances = distance_model(distances)
        queries = queries.view(queries.shape[0], -1)
        estimates = output_model(queries, threshold, distances)

        loss = l1_loss(estimates, targets)
        
        esti = torch.exp(estimates)
        for i in range(esti.shape[0]):
            if esti[i] > targets[i] + 0.1:
                q_errors.append((esti[i] / (targets[i] + 0.1)).item())
            else:
                q_errors.append(((targets[i] + 0.1) / esti[i]).item())
    mean = np.mean(q_errors)
    percent90 = np.percentile(q_errors, 90)
    percent95 = np.percentile(q_errors, 95)
    percent99 = np.percentile(q_errors, 99)
    median = np.median(q_errors)
    maxi = np.max(q_errors)
    print ('Testing: Mean Error {}, Median Error {}, 90 Percent {}, 95 Percent {}, 99 Percent {}, Max Percent {}'
           .format(mean, median, percent90, percent95, percent99, maxi))
    
    
def train_and_test(distance_model, cnn_models, threshold_model, output_model, opt, train, test, episode):
    print ('size: {}'.format(len(train)))
    test_errors = []
    for e in range(episode):
        for model in cnn_models:
            model.train()
        distance_model.train()
        threshold_model.train()
        output_model.train()
        for batch_idx, (queries, distances, thresholds, targets) in enumerate(train):
    #         print (torch.cat((queries, thresholds), 1)[0])
            queries = Variable(queries)
            thresholds = Variable(thresholds)
            distances = Variable(distances)
            targets = Variable(targets)
    #         print (targets)
            opt.zero_grad()
            queries = queries.unsqueeze(2).permute(0,2,1)
            for model in cnn_models:
                queries = model(queries)
            threshold = threshold_model(thresholds)
#             distances = distance_model(distances)
            queries = queries.view(queries.shape[0], -1) 
            estimates = output_model(queries, threshold, distances)
            print (torch.cat((distances, torch.exp(estimates), targets), dim=1))
            
            loss = l1_loss(estimates, targets)
            loss.backward()
            opt.step()
#             for p in distance_model.parameters():
#                 p.data.clamp_(-1, 1)
#             for p in threshold_model.parameters():
#                 p.data.clamp_(-1, 1)
            next(threshold_model.fc1.parameters()).data.clamp_(0)
            next(threshold_model.fc2.parameters()).data.clamp_(0)
            next(output_model.fc2.parameters()).data.clamp_(0)
#             if batch_idx % 100 == 0:
#                 print('Training: Iteration {0}, Batch {1}, Loss {2}'.format(e, batch_idx, loss.item()))
#                 print(cnn_models[0].layer[0].weight.grad)
        for model in cnn_models:
            model.eval()
        distance_model.eval()
        threshold_model.eval()
        output_model.eval()
        test_loss = 0.0
        mse_error = 0.0
        q_mean = 0.0
        q_max = 0.0
        for batch_idx, (queries, distances, thresholds, targets) in enumerate(test):
            queries = Variable(queries)
            thresholds = Variable(thresholds)
            distances = Variable(distances)
            targets = Variable(targets)
            
            queries = queries.unsqueeze(2).permute(0,2,1)
            for model in cnn_models:
                queries = model(queries)
            threshold = threshold_model(thresholds)
#             distances = distance_model(distances)
            queries = queries.view(queries.shape[0], -1)
            estimates = output_model(queries, threshold, distances)
            
            loss = l1_loss(estimates, targets)
            mse, qer_mean, qer_max = print_loss(estimates, targets)
            test_loss += loss.item()
            mse_error += mse.item()
            q_mean += qer_mean
            if qer_max > q_max:
                q_max = qer_max
        test_loss /= len(test)
        mse_error /= len(test)
        q_mean /= len(test)
        test_errors.append(q_mean)
        print ('Testing: Iteration {0}, Loss {1}, MSE_error {2}, Q_error_mean {3}, Q_error_max {4}'.format(e, test_loss, mse_error, q_mean, q_max))
    return np.mean(test_errors[-3:])

## Select Model Hyper-parameters

In [71]:
errors, next_cnn_parameterss, next_cnn_modelss, next_output_models, threshold_models = construct_model()

Begin Cluster: 0
Input Model Size: 0
Group of parameters: 3
8 10 1 3 10 0
25 8
size: 1423
tensor([[4.4898e-01, 1.2956e+00, 0.0000e+00],
        [4.0354e-01, 1.2562e+00, 3.4400e+03],
        [4.7352e-01, 1.3042e+00, 0.0000e+00],
        [4.6407e-01, 1.2699e+00, 3.4400e+03],
        [4.7708e-01, 1.2864e+00, 3.4400e+03],
        [4.8106e-01, 1.2567e+00, 0.0000e+00],
        [4.6336e-01, 1.1915e+00, 0.0000e+00],
        [4.4429e-01, 1.2954e+00, 0.0000e+00],
        [4.7985e-01, 1.2094e+00, 0.0000e+00],
        [4.4060e-01, 1.2689e+00, 0.0000e+00],
        [4.9596e-01, 1.2683e+00, 0.0000e+00],
        [4.6224e-01, 1.2322e+00, 0.0000e+00],
        [5.0762e-01, 1.4368e+00, 0.0000e+00],
        [4.8462e-01, 1.2645e+00, 0.0000e+00],
        [4.6952e-01, 1.4650e+00, 0.0000e+00],
        [4.4266e-01, 1.2598e+00, 0.0000e+00],
        [4.9733e-01, 1.2684e+00, 0.0000e+00],
        [5.0699e-01, 1.1885e+00, 0.0000e+00],
        [4.9052e-01, 1.2743e+00, 0.0000e+00],
        [5.1660e-01, 1.1591e+00, 0.0

tensor([[4.5496e-01,        nan, 0.0000e+00],
        [4.3404e-01,        nan, 0.0000e+00],
        [5.0792e-01,        nan, 0.0000e+00],
        [4.7890e-01,        nan, 0.0000e+00],
        [4.6875e-01,        nan, 0.0000e+00],
        [4.9721e-01,        nan, 0.0000e+00],
        [5.0330e-01,        nan, 0.0000e+00],
        [4.4970e-01,        nan, 3.4400e+03],
        [4.3095e-01,        nan, 0.0000e+00],
        [4.5672e-01,        nan, 3.4400e+03],
        [5.1488e-01,        nan, 3.4400e+03],
        [4.5794e-01,        nan, 3.4400e+03],
        [4.7360e-01,        nan, 0.0000e+00],
        [4.4711e-01,        nan, 0.0000e+00],
        [4.2908e-01,        nan, 3.4400e+03],
        [4.5914e-01,        nan, 0.0000e+00],
        [4.8612e-01,        nan, 0.0000e+00],
        [4.5436e-01,        nan, 0.0000e+00],
        [4.8941e-01,        nan, 3.4400e+03],
        [4.4513e-01,        nan, 0.0000e+00],
        [4.5784e-01,        nan, 3.4400e+03],
        [4.9904e-01,        nan, 0

tensor([[4.5303e-01,        nan, 0.0000e+00],
        [4.9222e-01,        nan, 3.4400e+03],
        [5.1612e-01,        nan, 3.4400e+03],
        [4.8111e-01,        nan, 0.0000e+00],
        [5.2929e-01,        nan, 0.0000e+00],
        [4.8217e-01,        nan, 3.4400e+03],
        [4.4354e-01,        nan, 0.0000e+00],
        [4.8618e-01,        nan, 0.0000e+00],
        [5.1895e-01,        nan, 0.0000e+00],
        [4.5319e-01,        nan, 0.0000e+00],
        [5.3456e-01,        nan, 0.0000e+00],
        [4.8200e-01,        nan, 0.0000e+00],
        [4.6266e-01,        nan, 0.0000e+00],
        [4.6733e-01,        nan, 0.0000e+00],
        [5.1987e-01,        nan, 3.4400e+03],
        [5.5098e-01,        nan, 0.0000e+00],
        [4.5873e-01,        nan, 0.0000e+00],
        [4.5564e-01,        nan, 0.0000e+00],
        [4.9899e-01,        nan, 0.0000e+00],
        [4.9579e-01,        nan, 0.0000e+00],
        [4.3289e-01,        nan, 0.0000e+00],
        [4.4937e-01,        nan, 0

tensor([[4.6327e-01,        nan, 3.4400e+03],
        [5.0630e-01,        nan, 0.0000e+00],
        [5.0876e-01,        nan, 0.0000e+00],
        [5.0413e-01,        nan, 0.0000e+00],
        [4.8784e-01,        nan, 0.0000e+00],
        [4.6667e-01,        nan, 0.0000e+00],
        [4.9188e-01,        nan, 0.0000e+00],
        [4.9964e-01,        nan, 0.0000e+00],
        [5.0329e-01,        nan, 0.0000e+00],
        [4.7991e-01,        nan, 0.0000e+00],
        [4.5288e-01,        nan, 0.0000e+00],
        [4.9351e-01,        nan, 0.0000e+00],
        [4.8970e-01,        nan, 0.0000e+00],
        [4.5732e-01,        nan, 0.0000e+00],
        [4.6010e-01,        nan, 0.0000e+00],
        [5.0609e-01,        nan, 0.0000e+00],
        [4.7118e-01,        nan, 0.0000e+00],
        [4.8025e-01,        nan, 0.0000e+00],
        [4.7125e-01,        nan, 3.4400e+03],
        [4.4236e-01,        nan, 0.0000e+00],
        [5.0240e-01,        nan, 3.4400e+03],
        [5.1347e-01,        nan, 0

tensor([[4.8799e-01,        nan, 0.0000e+00],
        [4.7525e-01,        nan, 0.0000e+00],
        [4.7776e-01,        nan, 0.0000e+00],
        [5.1907e-01,        nan, 3.4400e+03],
        [4.4988e-01,        nan, 0.0000e+00],
        [4.9277e-01,        nan, 0.0000e+00],
        [4.5250e-01,        nan, 0.0000e+00],
        [4.3896e-01,        nan, 0.0000e+00],
        [5.3101e-01,        nan, 0.0000e+00],
        [4.6547e-01,        nan, 0.0000e+00],
        [4.7823e-01,        nan, 0.0000e+00],
        [4.3372e-01,        nan, 3.4400e+03],
        [4.7836e-01,        nan, 3.4400e+03],
        [4.7089e-01,        nan, 0.0000e+00],
        [4.8296e-01,        nan, 0.0000e+00],
        [4.4325e-01,        nan, 0.0000e+00],
        [5.3535e-01,        nan, 0.0000e+00],
        [4.7712e-01,        nan, 0.0000e+00],
        [4.7656e-01,        nan, 3.4400e+03],
        [4.6196e-01,        nan, 0.0000e+00],
        [4.4369e-01,        nan, 3.4400e+03],
        [5.1241e-01,        nan, 0

tensor([[4.7240e-01,        nan, 3.4400e+03],
        [4.9610e-01,        nan, 0.0000e+00],
        [5.0272e-01,        nan, 0.0000e+00],
        [4.9477e-01,        nan, 0.0000e+00],
        [4.8691e-01,        nan, 0.0000e+00],
        [5.3312e-01,        nan, 0.0000e+00],
        [4.7790e-01,        nan, 0.0000e+00],
        [5.4188e-01,        nan, 0.0000e+00],
        [5.5697e-01,        nan, 0.0000e+00],
        [5.4383e-01,        nan, 0.0000e+00],
        [4.5384e-01,        nan, 0.0000e+00],
        [5.0408e-01,        nan, 0.0000e+00],
        [5.1689e-01,        nan, 0.0000e+00],
        [4.8404e-01,        nan, 0.0000e+00],
        [4.7913e-01,        nan, 0.0000e+00],
        [4.8941e-01,        nan, 3.4400e+03],
        [5.0236e-01,        nan, 0.0000e+00],
        [4.8063e-01,        nan, 3.4400e+03],
        [4.8045e-01,        nan, 0.0000e+00],
        [4.8414e-01,        nan, 0.0000e+00],
        [4.9735e-01,        nan, 0.0000e+00],
        [4.4175e-01,        nan, 0

tensor([[4.7041e-01,        nan, 0.0000e+00],
        [4.8602e-01,        nan, 0.0000e+00],
        [4.7419e-01,        nan, 0.0000e+00],
        [5.0916e-01,        nan, 0.0000e+00],
        [4.7526e-01,        nan, 0.0000e+00],
        [4.9018e-01,        nan, 3.4400e+03],
        [4.8842e-01,        nan, 0.0000e+00],
        [5.1454e-01,        nan, 0.0000e+00],
        [4.8736e-01,        nan, 0.0000e+00],
        [5.1733e-01,        nan, 0.0000e+00],
        [4.9754e-01,        nan, 0.0000e+00],
        [4.9173e-01,        nan, 3.4400e+03],
        [4.8402e-01,        nan, 0.0000e+00],
        [5.1488e-01,        nan, 0.0000e+00],
        [4.9027e-01,        nan, 0.0000e+00],
        [5.1499e-01,        nan, 0.0000e+00],
        [4.3894e-01,        nan, 0.0000e+00],
        [4.8551e-01,        nan, 0.0000e+00],
        [4.3607e-01,        nan, 0.0000e+00],
        [5.0100e-01,        nan, 3.4400e+03],
        [5.1489e-01,        nan, 3.4400e+03],
        [4.9422e-01,        nan, 0

tensor([[4.3065e-01,        nan, 3.4400e+03],
        [5.0528e-01,        nan, 0.0000e+00],
        [4.6306e-01,        nan, 0.0000e+00],
        [4.7005e-01,        nan, 0.0000e+00],
        [4.3853e-01,        nan, 0.0000e+00],
        [4.7338e-01,        nan, 0.0000e+00],
        [4.9908e-01,        nan, 0.0000e+00],
        [       nan,        nan, 0.0000e+00],
        [4.9406e-01,        nan, 0.0000e+00],
        [5.0943e-01,        nan, 0.0000e+00],
        [4.1651e-01,        nan, 0.0000e+00],
        [4.7015e-01,        nan, 0.0000e+00],
        [4.6562e-01,        nan, 3.4400e+03],
        [4.9737e-01,        nan, 0.0000e+00],
        [4.7152e-01,        nan, 0.0000e+00],
        [4.5749e-01,        nan, 0.0000e+00],
        [4.9763e-01,        nan, 0.0000e+00],
        [4.7008e-01,        nan, 0.0000e+00],
        [4.8201e-01,        nan, 0.0000e+00],
        [4.4741e-01,        nan, 3.4400e+03],
        [4.6097e-01,        nan, 0.0000e+00],
        [4.6734e-01,        nan, 3

tensor([[4.7720e-01,        nan, 0.0000e+00],
        [4.4778e-01,        nan, 0.0000e+00],
        [5.0447e-01,        nan, 0.0000e+00],
        [5.2365e-01,        nan, 0.0000e+00],
        [4.4539e-01,        nan, 3.4400e+03],
        [4.6273e-01,        nan, 0.0000e+00],
        [4.7455e-01,        nan, 3.4400e+03],
        [5.0577e-01,        nan, 0.0000e+00],
        [4.6508e-01,        nan, 0.0000e+00],
        [4.7180e-01,        nan, 0.0000e+00],
        [4.5898e-01,        nan, 0.0000e+00],
        [4.5885e-01,        nan, 0.0000e+00],
        [4.7022e-01,        nan, 0.0000e+00],
        [5.1496e-01,        nan, 0.0000e+00],
        [4.9142e-01,        nan, 0.0000e+00],
        [4.6677e-01,        nan, 3.4400e+03],
        [4.3993e-01,        nan, 0.0000e+00],
        [4.3773e-01,        nan, 0.0000e+00],
        [4.8618e-01,        nan, 0.0000e+00],
        [4.4593e-01,        nan, 0.0000e+00],
        [4.6738e-01,        nan, 0.0000e+00],
        [4.9599e-01,        nan, 0

tensor([[4.7436e-01,        nan, 0.0000e+00],
        [4.7320e-01,        nan, 0.0000e+00],
        [4.7275e-01,        nan, 3.4400e+03],
        [4.9193e-01,        nan, 0.0000e+00],
        [4.4724e-01,        nan, 0.0000e+00],
        [4.6605e-01,        nan, 3.4400e+03],
        [5.0285e-01,        nan, 0.0000e+00],
        [4.9730e-01,        nan, 3.4400e+03],
        [4.7659e-01,        nan, 3.4400e+03],
        [5.0755e-01,        nan, 0.0000e+00],
        [4.8063e-01,        nan, 0.0000e+00],
        [4.9383e-01,        nan, 0.0000e+00],
        [4.9619e-01,        nan, 0.0000e+00],
        [4.7521e-01,        nan, 0.0000e+00],
        [4.5453e-01,        nan, 0.0000e+00],
        [4.1383e-01,        nan, 3.4400e+03],
        [5.1469e-01,        nan, 0.0000e+00],
        [5.1127e-01,        nan, 0.0000e+00],
        [4.6602e-01,        nan, 0.0000e+00],
        [4.6421e-01,        nan, 0.0000e+00],
        [4.7899e-01,        nan, 0.0000e+00],
        [4.9221e-01,        nan, 0

tensor([[4.7699e-01,        nan, 0.0000e+00],
        [4.5125e-01,        nan, 0.0000e+00],
        [4.6656e-01,        nan, 0.0000e+00],
        [4.8199e-01,        nan, 0.0000e+00],
        [4.3733e-01,        nan, 3.4400e+03],
        [4.5560e-01,        nan, 0.0000e+00],
        [4.8201e-01,        nan, 0.0000e+00],
        [4.9354e-01,        nan, 0.0000e+00],
        [4.9377e-01,        nan, 0.0000e+00],
        [5.0286e-01,        nan, 0.0000e+00],
        [4.8530e-01,        nan, 3.4400e+03],
        [4.7167e-01,        nan, 0.0000e+00],
        [4.8684e-01,        nan, 3.4400e+03],
        [4.6648e-01,        nan, 3.4400e+03],
        [4.6882e-01,        nan, 0.0000e+00],
        [4.5694e-01,        nan, 0.0000e+00],
        [4.8391e-01,        nan, 0.0000e+00],
        [4.9390e-01,        nan, 0.0000e+00],
        [4.1247e-01,        nan, 0.0000e+00],
        [4.3770e-01,        nan, 0.0000e+00],
        [4.7869e-01,        nan, 0.0000e+00],
        [4.8830e-01,        nan, 0

tensor([[4.8211e-01,        nan, 0.0000e+00],
        [4.7411e-01,        nan, 3.4400e+03],
        [4.2752e-01,        nan, 0.0000e+00],
        [4.3885e-01,        nan, 3.4400e+03],
        [4.7932e-01,        nan, 0.0000e+00],
        [4.2410e-01,        nan, 3.4400e+03],
        [4.7828e-01,        nan, 0.0000e+00],
        [5.3535e-01,        nan, 3.4400e+03],
        [4.2551e-01,        nan, 0.0000e+00],
        [4.6855e-01,        nan, 0.0000e+00],
        [4.5923e-01,        nan, 0.0000e+00],
        [4.4929e-01,        nan, 0.0000e+00],
        [4.9420e-01,        nan, 3.4400e+03],
        [4.5622e-01,        nan, 0.0000e+00],
        [4.1821e-01,        nan, 0.0000e+00],
        [4.9575e-01,        nan, 0.0000e+00],
        [4.7988e-01,        nan, 0.0000e+00],
        [4.6335e-01,        nan, 0.0000e+00],
        [5.3014e-01,        nan, 0.0000e+00],
        [5.0974e-01,        nan, 3.4400e+03],
        [4.5348e-01,        nan, 3.4400e+03],
        [4.8616e-01,        nan, 0

tensor([[4.5335e-01,        nan, 0.0000e+00],
        [4.4908e-01,        nan, 0.0000e+00],
        [4.6577e-01,        nan, 3.4400e+03],
        [5.0148e-01,        nan, 0.0000e+00],
        [4.9156e-01,        nan, 0.0000e+00],
        [4.8248e-01,        nan, 0.0000e+00],
        [4.4419e-01,        nan, 3.4400e+03],
        [5.0358e-01,        nan, 0.0000e+00],
        [4.7844e-01,        nan, 0.0000e+00],
        [4.7521e-01,        nan, 0.0000e+00],
        [4.8561e-01,        nan, 0.0000e+00],
        [4.8559e-01,        nan, 0.0000e+00],
        [4.7835e-01,        nan, 3.4400e+03],
        [4.8745e-01,        nan, 0.0000e+00],
        [5.1633e-01,        nan, 0.0000e+00],
        [4.5987e-01,        nan, 3.4400e+03],
        [4.5689e-01,        nan, 0.0000e+00],
        [4.7773e-01,        nan, 0.0000e+00],
        [5.7117e-01,        nan, 0.0000e+00],
        [5.0116e-01,        nan, 0.0000e+00],
        [4.9173e-01,        nan, 3.4400e+03],
        [5.0271e-01,        nan, 0

tensor([[4.5619e-01,        nan, 0.0000e+00],
        [4.9364e-01,        nan, 0.0000e+00],
        [4.6340e-01,        nan, 0.0000e+00],
        [5.2030e-01,        nan, 0.0000e+00],
        [5.0114e-01,        nan, 3.4400e+03],
        [4.9150e-01,        nan, 3.4400e+03],
        [4.4598e-01,        nan, 3.4400e+03],
        [4.5556e-01,        nan, 3.4400e+03],
        [4.6984e-01,        nan, 0.0000e+00],
        [4.1501e-01,        nan, 3.4400e+03],
        [4.3929e-01,        nan, 0.0000e+00],
        [4.1651e-01,        nan, 0.0000e+00],
        [4.4424e-01,        nan, 3.4400e+03],
        [4.7986e-01,        nan, 3.4400e+03],
        [4.8124e-01,        nan, 0.0000e+00],
        [4.9454e-01,        nan, 0.0000e+00],
        [4.6939e-01,        nan, 0.0000e+00],
        [4.7634e-01,        nan, 0.0000e+00],
        [4.8649e-01,        nan, 0.0000e+00],
        [4.7903e-01,        nan, 0.0000e+00],
        [4.5241e-01,        nan, 0.0000e+00],
        [4.7788e-01,        nan, 0

tensor([[4.8851e-01,        nan, 0.0000e+00],
        [4.8269e-01,        nan, 0.0000e+00],
        [4.3843e-01,        nan, 0.0000e+00],
        [5.1347e-01,        nan, 0.0000e+00],
        [4.5683e-01,        nan, 0.0000e+00],
        [4.9524e-01,        nan, 0.0000e+00],
        [4.6534e-01,        nan, 0.0000e+00],
        [4.7664e-01,        nan, 0.0000e+00],
        [4.9220e-01,        nan, 3.4400e+03],
        [4.9115e-01,        nan, 0.0000e+00],
        [4.8850e-01,        nan, 0.0000e+00],
        [4.7856e-01,        nan, 3.4400e+03],
        [4.5905e-01,        nan, 0.0000e+00],
        [5.0640e-01,        nan, 0.0000e+00],
        [4.7031e-01,        nan, 0.0000e+00],
        [4.8106e-01,        nan, 0.0000e+00],
        [4.9342e-01,        nan, 3.4400e+03],
        [4.7597e-01,        nan, 0.0000e+00],
        [5.3570e-01,        nan, 0.0000e+00],
        [4.3198e-01,        nan, 0.0000e+00],
        [5.3667e-01,        nan, 0.0000e+00],
        [4.7007e-01,        nan, 0

tensor([[5.2137e-01,        nan, 0.0000e+00],
        [4.7250e-01,        nan, 0.0000e+00],
        [4.8579e-01,        nan, 0.0000e+00],
        [4.5234e-01,        nan, 0.0000e+00],
        [4.5248e-01,        nan, 0.0000e+00],
        [4.2700e-01,        nan, 0.0000e+00],
        [4.7250e-01,        nan, 0.0000e+00],
        [4.0890e-01,        nan, 0.0000e+00],
        [4.7619e-01,        nan, 3.4400e+03],
        [4.7771e-01,        nan, 0.0000e+00],
        [4.6721e-01,        nan, 0.0000e+00],
        [4.7861e-01,        nan, 0.0000e+00],
        [4.7634e-01,        nan, 0.0000e+00],
        [4.4669e-01,        nan, 0.0000e+00],
        [5.0524e-01,        nan, 0.0000e+00],
        [4.1470e-01,        nan, 0.0000e+00],
        [4.7503e-01,        nan, 0.0000e+00],
        [4.6741e-01,        nan, 3.4400e+03],
        [4.5221e-01,        nan, 3.4400e+03],
        [4.8236e-01,        nan, 0.0000e+00],
        [4.6348e-01,        nan, 3.4400e+03],
        [4.6430e-01,        nan, 0

tensor([[4.8935e-01,        nan, 0.0000e+00],
        [4.6520e-01,        nan, 3.4400e+03],
        [5.0384e-01,        nan, 3.4400e+03],
        [5.1273e-01,        nan, 3.4400e+03],
        [4.7539e-01,        nan, 3.4400e+03],
        [4.5265e-01,        nan, 3.4400e+03],
        [5.1390e-01,        nan, 3.4400e+03],
        [4.6517e-01,        nan, 0.0000e+00],
        [4.9604e-01,        nan, 3.4400e+03],
        [4.8072e-01,        nan, 0.0000e+00],
        [5.2379e-01,        nan, 0.0000e+00],
        [4.5317e-01,        nan, 0.0000e+00],
        [4.6952e-01,        nan, 0.0000e+00],
        [4.5608e-01,        nan, 0.0000e+00],
        [4.9311e-01,        nan, 0.0000e+00],
        [4.6885e-01,        nan, 3.4400e+03],
        [4.6693e-01,        nan, 0.0000e+00],
        [4.6401e-01,        nan, 0.0000e+00],
        [4.9628e-01,        nan, 0.0000e+00],
        [4.8634e-01,        nan, 0.0000e+00],
        [4.7133e-01,        nan, 0.0000e+00],
        [5.2960e-01,        nan, 3

tensor([[4.3362e-01,        nan, 3.4400e+03],
        [4.7499e-01,        nan, 0.0000e+00],
        [4.6516e-01,        nan, 0.0000e+00],
        [5.3720e-01,        nan, 0.0000e+00],
        [4.3031e-01,        nan, 0.0000e+00],
        [5.3808e-01,        nan, 0.0000e+00],
        [4.5840e-01,        nan, 0.0000e+00],
        [4.4709e-01,        nan, 0.0000e+00],
        [5.2635e-01,        nan, 0.0000e+00],
        [4.8423e-01,        nan, 0.0000e+00],
        [4.9142e-01,        nan, 0.0000e+00],
        [4.7423e-01,        nan, 3.4400e+03],
        [4.5600e-01,        nan, 3.4400e+03],
        [4.6099e-01,        nan, 0.0000e+00],
        [4.8703e-01,        nan, 0.0000e+00],
        [5.3216e-01,        nan, 0.0000e+00],
        [4.7969e-01,        nan, 0.0000e+00],
        [4.7953e-01,        nan, 0.0000e+00],
        [5.1706e-01,        nan, 0.0000e+00],
        [4.4207e-01,        nan, 0.0000e+00],
        [5.5415e-01,        nan, 3.4400e+03],
        [4.6443e-01,        nan, 0

tensor([[4.5254e-01,        nan, 0.0000e+00],
        [4.4420e-01,        nan, 0.0000e+00],
        [5.0046e-01,        nan, 0.0000e+00],
        [4.8736e-01,        nan, 0.0000e+00],
        [4.8716e-01,        nan, 0.0000e+00],
        [4.8807e-01,        nan, 0.0000e+00],
        [4.6546e-01,        nan, 3.4400e+03],
        [4.6136e-01,        nan, 0.0000e+00],
        [5.0007e-01,        nan, 0.0000e+00],
        [4.8674e-01,        nan, 0.0000e+00],
        [5.3300e-01,        nan, 0.0000e+00],
        [4.7275e-01,        nan, 3.4400e+03],
        [4.4495e-01,        nan, 0.0000e+00],
        [5.0577e-01,        nan, 0.0000e+00],
        [4.7702e-01,        nan, 0.0000e+00],
        [5.3162e-01,        nan, 3.4400e+03],
        [4.7129e-01,        nan, 0.0000e+00],
        [4.7828e-01,        nan, 0.0000e+00],
        [4.4721e-01,        nan, 0.0000e+00],
        [5.0061e-01,        nan, 0.0000e+00],
        [4.8559e-01,        nan, 0.0000e+00],
        [4.4101e-01,        nan, 0

tensor([[5.4696e-01,        nan, 0.0000e+00],
        [4.7883e-01,        nan, 0.0000e+00],
        [4.5814e-01,        nan, 0.0000e+00],
        [4.6326e-01,        nan, 0.0000e+00],
        [5.0562e-01,        nan, 0.0000e+00],
        [4.5476e-01,        nan, 0.0000e+00],
        [4.4089e-01,        nan, 0.0000e+00],
        [4.7708e-01,        nan, 0.0000e+00],
        [5.4460e-01,        nan, 0.0000e+00],
        [4.5613e-01,        nan, 0.0000e+00],
        [5.0820e-01,        nan, 0.0000e+00],
        [4.5845e-01,        nan, 0.0000e+00],
        [4.7340e-01,        nan, 0.0000e+00],
        [4.6682e-01,        nan, 3.4400e+03],
        [4.6502e-01,        nan, 0.0000e+00],
        [4.9184e-01,        nan, 0.0000e+00],
        [5.1262e-01,        nan, 0.0000e+00],
        [5.2836e-01,        nan, 0.0000e+00],
        [4.7914e-01,        nan, 0.0000e+00],
        [4.9429e-01,        nan, 3.4400e+03],
        [4.5975e-01,        nan, 0.0000e+00],
        [4.9163e-01,        nan, 0

tensor([[4.7724e-01,        nan, 0.0000e+00],
        [4.5501e-01,        nan, 0.0000e+00],
        [4.7352e-01,        nan, 0.0000e+00],
        [5.0351e-01,        nan, 0.0000e+00],
        [5.1418e-01,        nan, 3.4400e+03],
        [4.9844e-01,        nan, 0.0000e+00],
        [4.7064e-01,        nan, 0.0000e+00],
        [4.8276e-01,        nan, 0.0000e+00],
        [4.5476e-01,        nan, 0.0000e+00],
        [4.7580e-01,        nan, 0.0000e+00],
        [4.5398e-01,        nan, 0.0000e+00],
        [5.1332e-01,        nan, 0.0000e+00],
        [4.4914e-01,        nan, 0.0000e+00],
        [5.0381e-01,        nan, 0.0000e+00],
        [5.2411e-01,        nan, 0.0000e+00],
        [4.7172e-01,        nan, 0.0000e+00],
        [4.7825e-01,        nan, 0.0000e+00],
        [4.8845e-01,        nan, 0.0000e+00],
        [4.9762e-01,        nan, 3.4400e+03],
        [4.6274e-01,        nan, 0.0000e+00],
        [4.8446e-01,        nan, 0.0000e+00],
        [5.0329e-01,        nan, 0

tensor([[4.2940e-01,        nan, 0.0000e+00],
        [4.6483e-01,        nan, 0.0000e+00],
        [4.6188e-01,        nan, 0.0000e+00],
        [4.5394e-01,        nan, 0.0000e+00],
        [4.6441e-01,        nan, 0.0000e+00],
        [5.2596e-01,        nan, 0.0000e+00],
        [4.8576e-01,        nan, 0.0000e+00],
        [5.1060e-01,        nan, 0.0000e+00],
        [4.5216e-01,        nan, 3.4400e+03],
        [4.8681e-01,        nan, 0.0000e+00],
        [4.7116e-01,        nan, 3.4400e+03],
        [4.6673e-01,        nan, 0.0000e+00],
        [4.7228e-01,        nan, 0.0000e+00],
        [4.5353e-01,        nan, 0.0000e+00],
        [4.8763e-01,        nan, 0.0000e+00],
        [4.8476e-01,        nan, 0.0000e+00],
        [4.8412e-01,        nan, 0.0000e+00],
        [4.6648e-01,        nan, 3.4400e+03],
        [4.6077e-01,        nan, 0.0000e+00],
        [4.7572e-01,        nan, 0.0000e+00],
        [4.5450e-01,        nan, 0.0000e+00],
        [4.7219e-01,        nan, 3

tensor([[4.8625e-01,        nan, 0.0000e+00],
        [5.2975e-01,        nan, 3.4400e+03],
        [4.6168e-01,        nan, 3.4400e+03],
        [4.7985e-01,        nan, 0.0000e+00],
        [4.3243e-01,        nan, 3.4400e+03],
        [4.9150e-01,        nan, 0.0000e+00],
        [4.8548e-01,        nan, 3.4400e+03],
        [4.3860e-01,        nan, 0.0000e+00],
        [4.1106e-01,        nan, 0.0000e+00],
        [4.9941e-01,        nan, 0.0000e+00],
        [4.5382e-01,        nan, 0.0000e+00],
        [4.6986e-01,        nan, 0.0000e+00],
        [4.3495e-01,        nan, 3.4400e+03],
        [5.0965e-01,        nan, 0.0000e+00],
        [4.5391e-01,        nan, 0.0000e+00],
        [5.3012e-01,        nan, 0.0000e+00],
        [5.1772e-01,        nan, 0.0000e+00],
        [4.7702e-01,        nan, 0.0000e+00],
        [5.1281e-01,        nan, 0.0000e+00],
        [4.4883e-01,        nan, 0.0000e+00],
        [4.9195e-01,        nan, 3.4400e+03],
        [4.6604e-01,        nan, 0

tensor([[4.9621e-01,        nan, 0.0000e+00],
        [4.1782e-01,        nan, 0.0000e+00],
        [5.2006e-01,        nan, 0.0000e+00],
        [4.8680e-01,        nan, 0.0000e+00],
        [5.1770e-01,        nan, 0.0000e+00],
        [4.7580e-01,        nan, 0.0000e+00],
        [4.6394e-01,        nan, 0.0000e+00],
        [4.6543e-01,        nan, 0.0000e+00],
        [4.8382e-01,        nan, 0.0000e+00],
        [5.0184e-01,        nan, 0.0000e+00],
        [4.8368e-01,        nan, 0.0000e+00],
        [4.3825e-01,        nan, 0.0000e+00],
        [4.8205e-01,        nan, 0.0000e+00],
        [4.6019e-01,        nan, 3.4400e+03],
        [5.1578e-01,        nan, 3.4400e+03],
        [4.8060e-01,        nan, 0.0000e+00],
        [5.0168e-01,        nan, 0.0000e+00],
        [4.7353e-01,        nan, 0.0000e+00],
        [4.6491e-01,        nan, 3.4400e+03],
        [4.9305e-01,        nan, 0.0000e+00],
        [4.9431e-01,        nan, 0.0000e+00],
        [4.8344e-01,        nan, 0

tensor([[4.4545e-01,        nan, 3.4400e+03],
        [4.5907e-01,        nan, 3.4400e+03],
        [4.5905e-01,        nan, 0.0000e+00],
        [4.7522e-01,        nan, 0.0000e+00],
        [5.6484e-01,        nan, 0.0000e+00],
        [4.5958e-01,        nan, 3.4400e+03],
        [4.5223e-01,        nan, 0.0000e+00],
        [4.6656e-01,        nan, 0.0000e+00],
        [4.8210e-01,        nan, 0.0000e+00],
        [5.1793e-01,        nan, 3.4400e+03],
        [4.5325e-01,        nan, 0.0000e+00],
        [4.6503e-01,        nan, 0.0000e+00],
        [4.5641e-01,        nan, 0.0000e+00],
        [4.9909e-01,        nan, 0.0000e+00],
        [4.6168e-01,        nan, 3.4400e+03],
        [4.4837e-01,        nan, 0.0000e+00],
        [5.1080e-01,        nan, 0.0000e+00],
        [5.0147e-01,        nan, 0.0000e+00],
        [4.2944e-01,        nan, 0.0000e+00],
        [5.0844e-01,        nan, 3.4400e+03],
        [4.6752e-01,        nan, 0.0000e+00],
        [4.3753e-01,        nan, 0

tensor([[4.3156e-01,        nan, 0.0000e+00],
        [4.5768e-01,        nan, 0.0000e+00],
        [4.8384e-01,        nan, 0.0000e+00],
        [5.0772e-01,        nan, 0.0000e+00],
        [4.8529e-01,        nan, 0.0000e+00],
        [4.9306e-01,        nan, 0.0000e+00],
        [4.7419e-01,        nan, 0.0000e+00],
        [5.1826e-01,        nan, 0.0000e+00],
        [4.5954e-01,        nan, 0.0000e+00],
        [4.7480e-01,        nan, 0.0000e+00],
        [4.5995e-01,        nan, 0.0000e+00],
        [5.1440e-01,        nan, 3.4400e+03],
        [5.3127e-01,        nan, 0.0000e+00],
        [4.8074e-01,        nan, 0.0000e+00],
        [4.5830e-01,        nan, 0.0000e+00],
        [4.5435e-01,        nan, 0.0000e+00],
        [4.7432e-01,        nan, 0.0000e+00],
        [4.8962e-01,        nan, 0.0000e+00],
        [4.3607e-01,        nan, 0.0000e+00],
        [4.5919e-01,        nan, 0.0000e+00],
        [5.2187e-01,        nan, 0.0000e+00],
        [4.8124e-01,        nan, 0

tensor([[4.8242e-01,        nan, 0.0000e+00],
        [4.8970e-01,        nan, 3.4400e+03],
        [5.0022e-01,        nan, 0.0000e+00],
        [4.4887e-01,        nan, 0.0000e+00],
        [4.8443e-01,        nan, 0.0000e+00],
        [4.6692e-01,        nan, 0.0000e+00],
        [5.0580e-01,        nan, 0.0000e+00],
        [4.6223e-01,        nan, 0.0000e+00],
        [5.1504e-01,        nan, 0.0000e+00],
        [5.0916e-01,        nan, 0.0000e+00],
        [4.6587e-01,        nan, 0.0000e+00],
        [4.5287e-01,        nan, 0.0000e+00],
        [4.2768e-01,        nan, 3.4400e+03],
        [4.9721e-01,        nan, 0.0000e+00],
        [4.7870e-01,        nan, 0.0000e+00],
        [4.6604e-01,        nan, 0.0000e+00],
        [4.8911e-01,        nan, 3.4400e+03],
        [4.4545e-01,        nan, 0.0000e+00],
        [5.0220e-01,        nan, 0.0000e+00],
        [4.6696e-01,        nan, 0.0000e+00],
        [4.4334e-01,        nan, 0.0000e+00],
        [4.3030e-01,        nan, 0

tensor([[4.7467e-01,        nan, 0.0000e+00],
        [4.8087e-01,        nan, 0.0000e+00],
        [4.8542e-01,        nan, 0.0000e+00],
        [4.8239e-01,        nan, 0.0000e+00],
        [4.6273e-01,        nan, 0.0000e+00],
        [4.6813e-01,        nan, 0.0000e+00],
        [4.5453e-01,        nan, 0.0000e+00],
        [4.3564e-01,        nan, 0.0000e+00],
        [4.6298e-01,        nan, 0.0000e+00],
        [4.2641e-01,        nan, 0.0000e+00],
        [4.7849e-01,        nan, 0.0000e+00],
        [5.0619e-01,        nan, 0.0000e+00],
        [4.1651e-01,        nan, 0.0000e+00],
        [4.8854e-01,        nan, 0.0000e+00],
        [5.1790e-01,        nan, 0.0000e+00],
        [4.6910e-01,        nan, 3.4400e+03],
        [4.9171e-01,        nan, 0.0000e+00],
        [4.5866e-01,        nan, 0.0000e+00],
        [4.5071e-01,        nan, 0.0000e+00],
        [4.9252e-01,        nan, 0.0000e+00],
        [4.7938e-01,        nan, 0.0000e+00],
        [4.5462e-01,        nan, 0

tensor([[5.0524e-01,        nan, 0.0000e+00],
        [4.7912e-01,        nan, 0.0000e+00],
        [4.7147e-01,        nan, 0.0000e+00],
        [4.7521e-01,        nan, 0.0000e+00],
        [4.6750e-01,        nan, 0.0000e+00],
        [4.5902e-01,        nan, 0.0000e+00],
        [5.0378e-01,        nan, 0.0000e+00],
        [4.5501e-01,        nan, 0.0000e+00],
        [4.7863e-01,        nan, 0.0000e+00],
        [4.6826e-01,        nan, 0.0000e+00],
        [5.0962e-01,        nan, 0.0000e+00],
        [4.6786e-01,        nan, 0.0000e+00],
        [4.4625e-01,        nan, 0.0000e+00],
        [4.8290e-01,        nan, 3.4400e+03],
        [5.0388e-01,        nan, 0.0000e+00],
        [4.9353e-01,        nan, 0.0000e+00],
        [4.5004e-01,        nan, 0.0000e+00],
        [4.4159e-01,        nan, 0.0000e+00],
        [4.7652e-01,        nan, 0.0000e+00],
        [4.5735e-01,        nan, 0.0000e+00],
        [4.6911e-01,        nan, 0.0000e+00],
        [4.8864e-01,        nan, 3

tensor([[4.5033e-01,        nan, 3.4400e+03],
        [4.6183e-01,        nan, 0.0000e+00],
        [4.9524e-01,        nan, 0.0000e+00],
        [4.7521e-01,        nan, 0.0000e+00],
        [4.7445e-01,        nan, 3.4400e+03],
        [4.7924e-01,        nan, 3.4400e+03],
        [5.2411e-01,        nan, 0.0000e+00],
        [4.9307e-01,        nan, 0.0000e+00],
        [4.4399e-01,        nan, 3.4400e+03],
        [4.8425e-01,        nan, 0.0000e+00],
        [4.5091e-01,        nan, 3.4400e+03],
        [4.4386e-01,        nan, 0.0000e+00],
        [5.2269e-01,        nan, 0.0000e+00],
        [4.5923e-01,        nan, 0.0000e+00],
        [4.7349e-01,        nan, 0.0000e+00],
        [4.9427e-01,        nan, 0.0000e+00],
        [4.7768e-01,        nan, 0.0000e+00],
        [5.3014e-01,        nan, 0.0000e+00],
        [4.8210e-01,        nan, 0.0000e+00],
        [4.9535e-01,        nan, 0.0000e+00],
        [4.6953e-01,        nan, 0.0000e+00],
        [4.8961e-01,        nan, 0

tensor([[4.7557e-01,        nan, 0.0000e+00],
        [5.0056e-01,        nan, 0.0000e+00],
        [4.7155e-01,        nan, 0.0000e+00],
        [4.6921e-01,        nan, 0.0000e+00],
        [4.6881e-01,        nan, 0.0000e+00],
        [4.6836e-01,        nan, 0.0000e+00],
        [4.6119e-01,        nan, 0.0000e+00],
        [4.9285e-01,        nan, 0.0000e+00],
        [4.7572e-01,        nan, 0.0000e+00],
        [5.1365e-01,        nan, 3.4400e+03],
        [5.0381e-01,        nan, 0.0000e+00],
        [4.8797e-01,        nan, 0.0000e+00],
        [4.7299e-01,        nan, 0.0000e+00],
        [5.0388e-01,        nan, 0.0000e+00],
        [5.0398e-01,        nan, 0.0000e+00],
        [4.6793e-01,        nan, 0.0000e+00],
        [4.6200e-01,        nan, 0.0000e+00],
        [5.0708e-01,        nan, 0.0000e+00],
        [5.2577e-01,        nan, 0.0000e+00],
        [5.0748e-01,        nan, 0.0000e+00],
        [5.0212e-01,        nan, 0.0000e+00],
        [5.0005e-01,        nan, 0

tensor([[4.7595e-01,        nan, 0.0000e+00],
        [4.8911e-01,        nan, 0.0000e+00],
        [4.9266e-01,        nan, 3.4400e+03],
        [4.8361e-01,        nan, 0.0000e+00],
        [4.9932e-01,        nan, 0.0000e+00],
        [4.5768e-01,        nan, 3.4400e+03],
        [5.2714e-01,        nan, 0.0000e+00],
        [4.3254e-01,        nan, 3.4400e+03],
        [4.8197e-01,        nan, 3.4400e+03],
        [4.9097e-01,        nan, 0.0000e+00],
        [4.4269e-01,        nan, 3.4400e+03],
        [4.6907e-01,        nan, 0.0000e+00],
        [4.2983e-01,        nan, 0.0000e+00],
        [4.5844e-01,        nan, 0.0000e+00],
        [4.6952e-01,        nan, 0.0000e+00],
        [4.8298e-01,        nan, 0.0000e+00],
        [4.7625e-01,        nan, 3.4400e+03],
        [4.4920e-01,        nan, 0.0000e+00],
        [5.6033e-01,        nan, 0.0000e+00],
        [4.7635e-01,        nan, 0.0000e+00],
        [4.6422e-01,        nan, 0.0000e+00],
        [4.3395e-01,        nan, 0

tensor([[5.3213e-01,        nan, 0.0000e+00],
        [4.7432e-01,        nan, 0.0000e+00],
        [5.1674e-01,        nan, 0.0000e+00],
        [4.5054e-01,        nan, 0.0000e+00],
        [4.6618e-01,        nan, 3.4400e+03],
        [5.0803e-01,        nan, 3.4400e+03],
        [4.8878e-01,        nan, 0.0000e+00],
        [5.0079e-01,        nan, 0.0000e+00],
        [4.7319e-01,        nan, 0.0000e+00],
        [4.9370e-01,        nan, 0.0000e+00],
        [4.9682e-01,        nan, 3.4400e+03],
        [4.5384e-01,        nan, 3.4400e+03],
        [4.1247e-01,        nan, 0.0000e+00],
        [5.0057e-01,        nan, 0.0000e+00],
        [4.9140e-01,        nan, 0.0000e+00],
        [4.5286e-01,        nan, 0.0000e+00],
        [4.9586e-01,        nan, 0.0000e+00],
        [4.6796e-01,        nan, 3.4400e+03],
        [4.7978e-01,        nan, 0.0000e+00],
        [5.1814e-01,        nan, 3.4400e+03],
        [4.5479e-01,        nan, 0.0000e+00],
        [4.4279e-01,        nan, 0

tensor([[4.7255e-01,        nan, 0.0000e+00],
        [4.4892e-01,        nan, 0.0000e+00],
        [4.4144e-01,        nan, 0.0000e+00],
        [4.9234e-01,        nan, 0.0000e+00],
        [4.8527e-01,        nan, 0.0000e+00],
        [4.5245e-01,        nan, 0.0000e+00],
        [4.7720e-01,        nan, 0.0000e+00],
        [4.6997e-01,        nan, 3.4400e+03],
        [4.7443e-01,        nan, 0.0000e+00],
        [4.8408e-01,        nan, 0.0000e+00],
        [4.4946e-01,        nan, 0.0000e+00],
        [4.7713e-01,        nan, 0.0000e+00],
        [4.8146e-01,        nan, 0.0000e+00],
        [4.8345e-01,        nan, 0.0000e+00],
        [4.8665e-01,        nan, 0.0000e+00],
        [5.3577e-01,        nan, 0.0000e+00],
        [5.0181e-01,        nan, 0.0000e+00],
        [4.9808e-01,        nan, 0.0000e+00],
        [5.2444e-01,        nan, 0.0000e+00],
        [4.8579e-01,        nan, 0.0000e+00],
        [5.1525e-01,        nan, 0.0000e+00],
        [5.5351e-01,        nan, 0

tensor([[5.0753e-01,        nan, 0.0000e+00],
        [4.8956e-01,        nan, 0.0000e+00],
        [4.6969e-01,        nan, 0.0000e+00],
        [4.8276e-01,        nan, 0.0000e+00],
        [4.4936e-01,        nan, 0.0000e+00],
        [4.8706e-01,        nan, 0.0000e+00],
        [5.1456e-01,        nan, 0.0000e+00],
        [4.7191e-01,        nan, 0.0000e+00],
        [4.7458e-01,        nan, 0.0000e+00],
        [5.1843e-01,        nan, 0.0000e+00],
        [4.7735e-01,        nan, 0.0000e+00],
        [5.2906e-01,        nan, 0.0000e+00],
        [4.6494e-01,        nan, 3.4400e+03],
        [5.0694e-01,        nan, 0.0000e+00],
        [5.3862e-01,        nan, 0.0000e+00],
        [4.6536e-01,        nan, 0.0000e+00],
        [5.1567e-01,        nan, 3.4400e+03],
        [5.0064e-01,        nan, 0.0000e+00],
        [4.8522e-01,        nan, 0.0000e+00],
        [4.6024e-01,        nan, 3.4400e+03],
        [4.9908e-01,        nan, 0.0000e+00],
        [4.3759e-01,        nan, 0

KeyboardInterrupt: 

In [27]:
repair_specific_local_model(errors, next_cnn_parameterss, next_cnn_modelss, next_output_models, threshold_models, 19)

Begin Cluster: 19
Input Model Size: 0
Group of parameters: 3
8 10 1 3 10 0
78 8
size: 970
Testing: Iteration 0, Loss 4.496146740516027, MSE_error 78755.40997721354, Q_error_mean 4.496146721127854, Q_error_max 82.85245513916016


KeyboardInterrupt: 

## Final Training

In [37]:
threshold_models = [Threshold_Model() for _ in range(100)]
for idx in range(100):
    print (idx)
    test = test_loaders[idx]
    train = train_loaders[idx]
    paras = [{"params": model.parameters()} for model in next_cnn_modelss[idx]]
    paras.append({"params": threshold_models[idx].parameters()})
    paras.append({"params": next_output_models[idx].parameters()})
    opt = optim.Adam(paras, lr=0.001)
    episode = 5
    train_and_test(next_cnn_modelss[idx], threshold_models[idx], next_output_models[idx], opt, train, test, episode)
    only_test(next_cnn_modelss[idx], threshold_models[idx], next_output_models[idx], test)

0
size: 1131
Testing: Iteration 0, Loss 4.475145844802314, MSE_error 90524.63895685054, Q_error_mean 4.475145870479033, Q_error_max 77.77468872070312
Testing: Iteration 1, Loss 3.8405503522458875, MSE_error 467586.64640791813, Q_error_mean 3.840550355745948, Q_error_max 73.310546875
Testing: Iteration 2, Loss 3.7021278079294224, MSE_error 450876.8805604982, Q_error_mean 3.702127858905608, Q_error_max 54.42537307739258
Testing: Iteration 3, Loss 3.690968884267841, MSE_error 1058717.502780249, Q_error_mean 3.6909688751230694, Q_error_max 62.4102783203125
Testing: Iteration 4, Loss 3.608779061307262, MSE_error 472581.2518905694, Q_error_mean 3.6087791034206806, Q_error_max 49.42038345336914
Testing: Mean Error 3.608209824788999, Median Error 2.72385835647583, 90 Percent 6.31022539138794, 95 Percent 9.353417968749994, 99 Percent 18.789539432525636, Max Percent 49.42038345336914
1
size: 1034
Testing: Iteration 0, Loss 4.649038359005019, MSE_error 158538.25171935328, Q_error_mean 4.649038231

Testing: Iteration 1, Loss 4.22055240352778, MSE_error 18716.91684549356, Q_error_mean 4.22055237036547, Q_error_max 44.34416961669922
Testing: Iteration 2, Loss 2.5998087666065395, MSE_error 12324.254220594152, Q_error_mean 2.59980873358115, Q_error_max 32.569149017333984
Testing: Iteration 3, Loss 2.2436658379346004, MSE_error 74374.28091889083, Q_error_mean 2.2436658364688427, Q_error_max 27.04608726501465
Testing: Iteration 4, Loss 2.182767124646723, MSE_error 125227.45295315853, Q_error_mean 2.1827671488490403, Q_error_max 23.38239288330078
Testing: Mean Error 2.182151433046747, Median Error 1.607520341873169, 90 Percent 3.936448574066164, 95 Percent 5.764556646347042, 99 Percent 9.272054481506352, Max Percent 23.38239288330078
10
size: 990
Testing: Iteration 0, Loss 5.125562102205841, MSE_error 79752.49077998482, Q_error_mean 5.125562078603323, Q_error_max 113.33219909667969
Testing: Iteration 1, Loss 4.460556915414478, MSE_error 107657.94341472672, Q_error_mean 4.460556937217888

Testing: Iteration 3, Loss 5.483208049790131, MSE_error 809501.6715481172, Q_error_mean 5.483207995040418, Q_error_max 142.56082153320312
Testing: Iteration 4, Loss 5.3589942555048475, MSE_error 2011061.5249738493, Q_error_mean 5.358994330720027, Q_error_max 154.2837677001953
Testing: Mean Error 5.358341019181995, Median Error 2.86792528629303, 90 Percent 8.108980464935302, 95 Percent 17.40299119949337, 99 Percent 55.784951019287114, Max Percent 154.2837677001953
19
size: 970
Testing: Iteration 0, Loss 5.276794320344925, MSE_error 90053.27141927084, Q_error_mean 5.276794288812808, Q_error_max 92.52643585205078
Testing: Iteration 1, Loss 3.942901588479678, MSE_error 99338.36484375, Q_error_mean 3.9429015817983517, Q_error_max 67.9149398803711
Testing: Iteration 2, Loss 3.622026730577151, MSE_error 225775.68536783854, Q_error_mean 3.622026759683036, Q_error_max 54.99528503417969
Testing: Iteration 3, Loss 3.546451868613561, MSE_error 285214.56559244794, Q_error_mean 3.5464518595655794, Q

Testing: Mean Error 6.385575846503595, Median Error 3.98740816116333, 90 Percent 11.251104736328134, 95 Percent 22.48454627990722, 99 Percent 51.18293159484858, Max Percent 88.0138168334961
28
size: 879
Testing: Iteration 0, Loss 4.573026110773129, MSE_error 34719.3490663537, Q_error_mean 4.573026092092691, Q_error_max 91.14608764648438
Testing: Iteration 1, Loss 3.6894224609494746, MSE_error 81396.94127049467, Q_error_mean 3.689422494780228, Q_error_max 77.85680389404297
Testing: Iteration 2, Loss 3.3252994094728887, MSE_error 139516.3047050168, Q_error_mean 3.325299444176802, Q_error_max 69.22696685791016
Testing: Iteration 3, Loss 3.1674517804731703, MSE_error 230693.54509704316, Q_error_mean 3.1674517800719366, Q_error_max 65.60953521728516
Testing: Iteration 4, Loss 3.0735336461943894, MSE_error 344455.2591350196, Q_error_mean 3.0735336141613954, Q_error_max 60.792545318603516
Testing: Mean Error 3.0745670112491297, Median Error 2.065812587738037, 90 Percent 6.155717086791995, 95 

Testing: Iteration 0, Loss 4.091165912535883, MSE_error 396803.9603830645, Q_error_mean 4.091165946069298, Q_error_max 140.535400390625
Testing: Iteration 1, Loss 3.9034468250889933, MSE_error 1205426.2531754032, Q_error_mean 3.9034468403652762, Q_error_max 98.1204605102539
Testing: Iteration 2, Loss 3.8841965083153016, MSE_error 657294.1019153226, Q_error_mean 3.884196497972765, Q_error_max 75.95145416259766
Testing: Iteration 3, Loss 3.886181701383283, MSE_error 1346157.6415322581, Q_error_mean 3.886181720720141, Q_error_max 88.1377944946289
Testing: Iteration 4, Loss 4.0018166326707405, MSE_error 469511.2227318548, Q_error_mean 4.001816635210787, Q_error_max 55.82062530517578
Testing: Mean Error 4.002333121314786, Median Error 3.144629955291748, 90 Percent 6.696932983398437, 95 Percent 9.383298492431598, 99 Percent 23.607375259399447, Max Percent 55.82062530517578
38
size: 1358
Testing: Iteration 0, Loss 3.4828744868496404, MSE_error 614091.9007325667, Q_error_mean 3.482874489305448

Testing: Iteration 1, Loss 4.32334310937636, MSE_error 24870.74341062036, Q_error_mean 4.323343156287515, Q_error_max 91.3995132446289
Testing: Iteration 2, Loss 3.7911304223655473, MSE_error 22091.774708965037, Q_error_mean 3.791130418153447, Q_error_max 75.62140655517578
Testing: Iteration 3, Loss 3.378408627934975, MSE_error 23343.930610883355, Q_error_mean 3.378408640352934, Q_error_max 66.12960815429688
Testing: Iteration 4, Loss 3.1309783942628613, MSE_error 38015.006574876235, Q_error_mean 3.1309783285092796, Q_error_max 58.830322265625
Testing: Mean Error 3.1309292239984474, Median Error 2.003519058227539, 90 Percent 6.1241876602172844, 95 Percent 9.391020965576162, 99 Percent 19.65645385742182, Max Percent 58.830322265625
47
size: 1041
Testing: Iteration 0, Loss 6.862230358197707, MSE_error 128930.60443919574, Q_error_mean 6.8622303996724225, Q_error_max 109.17193603515625
Testing: Iteration 1, Loss 5.471337554990783, MSE_error 192093.96245155038, Q_error_mean 5.47133765539865

Testing: Iteration 2, Loss 3.8006854779528876, MSE_error 193359.82450387772, Q_error_mean 3.800685480790747, Q_error_max 57.87752151489258
Testing: Iteration 3, Loss 3.684641328171222, MSE_error 331686.5082116788, Q_error_mean 3.6846413339734414, Q_error_max 60.39252471923828
Testing: Iteration 4, Loss 3.6323338564294967, MSE_error 355979.68481979927, Q_error_mean 3.6323338453934277, Q_error_max 56.43357849121094
Testing: Mean Error 3.63221642482851, Median Error 2.6155203580856323, 90 Percent 5.907627820968628, 95 Percent 9.283514976501465, 99 Percent 24.703142166137695, Max Percent 56.43357849121094
56
size: 1031
Testing: Iteration 0, Loss 4.19073320089602, MSE_error 265130.3939338235, Q_error_mean 4.190733176291427, Q_error_max 102.66426086425781
Testing: Iteration 1, Loss 4.069810059491326, MSE_error 405796.21875, Q_error_mean 4.069810090086478, Q_error_max 67.03651428222656
Testing: Iteration 2, Loss 4.2812121812035056, MSE_error 1845110.8337009803, Q_error_mean 4.281212232992638,

Testing: Iteration 3, Loss 3.342221773103519, MSE_error 383520.36223700497, Q_error_mean 3.342221779203297, Q_error_max 51.60331344604492
Testing: Iteration 4, Loss 3.506566577225235, MSE_error 857118.644234736, Q_error_mean 3.506566567500746, Q_error_max 71.49742889404297
Testing: Mean Error 3.506615615824254, Median Error 2.017192840576172, 90 Percent 6.072141456604008, 95 Percent 11.338158226013178, 99 Percent 31.612789993286142, Max Percent 71.49742889404297
65
size: 1370
Testing: Iteration 0, Loss 4.330545342912422, MSE_error 614396.9692082112, Q_error_mean 4.330545261809247, Q_error_max 269.57177734375
Testing: Iteration 1, Loss 4.123457194073809, MSE_error 1164386.310896261, Q_error_mean 4.12345721272921, Q_error_max 182.55905151367188
Testing: Iteration 2, Loss 4.176407470731092, MSE_error 834954.2006964809, Q_error_mean 4.17640752027199, Q_error_max 143.0494842529297
Testing: Iteration 3, Loss 4.041696094697522, MSE_error 1221498.8451246335, Q_error_mean 4.041696078806044, Q_e

Testing: Iteration 4, Loss 2.9526649944117813, MSE_error 180138.5049564549, Q_error_mean 2.952665015908773, Q_error_max 33.67550277709961
Testing: Mean Error 2.9528755310397923, Median Error 2.2280309200286865, 90 Percent 4.890291595458983, 95 Percent 7.235526847839354, 99 Percent 15.65044670104981, Max Percent 33.67550277709961
74
size: 1109
Testing: Iteration 0, Loss 4.711936951547429, MSE_error 118975.84287250906, Q_error_mean 4.71193687170609, Q_error_max 121.19062805175781
Testing: Iteration 1, Loss 4.3805105470228884, MSE_error 293331.63994565216, Q_error_mean 4.380510541226418, Q_error_max 115.99091339111328
Testing: Iteration 2, Loss 4.305288567059282, MSE_error 368837.90344769025, Q_error_mean 4.305288519650913, Q_error_max 113.01488494873047
Testing: Iteration 3, Loss 4.262809546097465, MSE_error 732912.5146059783, Q_error_mean 4.2628095255165865, Q_error_max 135.56251525878906
Testing: Iteration 4, Loss 4.199365541554879, MSE_error 554923.4172327898, Q_error_mean 4.199365524

Testing: Mean Error 3.789489862648798, Median Error 2.409791111946106, 90 Percent 6.17481384277344, 95 Percent 9.538255882263183, 99 Percent 33.86518314361571, Max Percent 72.33265686035156
83
size: 1374
Testing: Iteration 0, Loss 3.5491730187994994, MSE_error 258425.46978097508, Q_error_mean 3.5491730821103165, Q_error_max 116.48943328857422
Testing: Iteration 1, Loss 3.4482110791192375, MSE_error 410395.76236024563, Q_error_mean 3.448211083961569, Q_error_max 87.40058135986328
Testing: Iteration 2, Loss 3.410115006032927, MSE_error 332898.28908541054, Q_error_mean 3.4101149980306538, Q_error_max 54.8708610534668
Testing: Iteration 3, Loss 3.5349352919111503, MSE_error 913516.9986284093, Q_error_mean 3.5349353337906235, Q_error_max 71.03258514404297
Testing: Iteration 4, Loss 3.431066697643649, MSE_error 304973.6149880865, Q_error_mean 3.431066718277497, Q_error_max 51.65279769897461
Testing: Mean Error 3.4204707866386177, Median Error 2.5590708255767822, 90 Percent 5.668932151794436,

Testing: Iteration 0, Loss 4.023472505784684, MSE_error 32315.104161600193, Q_error_mean 4.023472499240059, Q_error_max 89.41334533691406
Testing: Iteration 1, Loss 3.4327632610899927, MSE_error 64391.59071771644, Q_error_mean 3.4327633249417584, Q_error_max 86.98357391357422
Testing: Iteration 2, Loss 3.256614314906792, MSE_error 100019.26681055447, Q_error_mean 3.2566142774147164, Q_error_max 67.57281494140625
Testing: Iteration 3, Loss 3.2082114683514904, MSE_error 117935.11654152481, Q_error_mean 3.2082114459200244, Q_error_max 66.70587921142578
Testing: Iteration 4, Loss 3.2088809486493064, MSE_error 109376.15117339494, Q_error_mean 3.2088809561216296, Q_error_max 55.702457427978516
Testing: Mean Error 3.1924290574798797, Median Error 2.279960036277771, 90 Percent 4.844652891159058, 95 Percent 7.611490249633779, 99 Percent 22.702716503143318, Max Percent 55.702457427978516
93
size: 1045
Testing: Iteration 0, Loss 6.425294947900367, MSE_error 111470.37490950772, Q_error_mean 6.4252

## Save The Model

In [38]:
for idx in range(100):
    states = {}
    for idd, cnn_model in enumerate(next_cnn_modelss[idx]):
        states['cnn_model_state_dict_' + str(idd)] = cnn_model.state_dict()
    states['threshold_model_state_dict'] = threshold_models[idx].state_dict()
    states['output_model_state_dict'] = next_output_models[idx].state_dict()
    torch.save(states, '/home/sunji/ANN/fashion_mnist_784_euclidean/saved_models/local_fashion_mnist_784_euclidean_cluster_' + str(idx) + '.model')

In [39]:
import pickle
with open('/home/sunji/ANN/fashion_mnist_784_euclidean/saved_models/cnn_hyper_parameters.hyperpara', 'w') as handle:
    for idx in range(100):
        handle.write(';'.join(str(x) for x in next_cnn_parameterss[idx]))
        handle.write('\n')

In [24]:
next_cnn_parameterss = hyper_parameterss
next_cnn_modelss = cnn_modelss
next_output_models = output_models
errors = [0.0 for _ in range(100)]

In [25]:
len(errors)

100

# Load Model

In [31]:
import pickle
hyper_parameterss = []
with open('/home/sunji/ANN/fashion_mnist_784_euclidean/saved_models/cnn_hyper_parameters.hyperpara', 'r') as handle:
    for paras in handle.readlines():
        hyper_parameters = []
        for para in paras.split(';'):
            para = para.split(' ')
            hyper_parameters.append(TunableParameters(int(para[0]), int(para[1]), int(para[2]),
                                                      int(para[3]), int(para[4]), int(para[5])))
        hyper_parameterss.append(hyper_parameters)

In [32]:
cnn_modelss = []
threshold_models = []
output_models = []
for idx in range(100):
    states = torch.load('/home/sunji/ANN/fashion_mnist_784_euclidean/saved_models/local_fashion_mnist_784_euclidean_cluster_' + str(idx) + '.model')
    hyper_para = hyper_parameterss[idx]
    cnn_models = []
    weights = [None for _ in range(len(hyper_para))]
    for key, value in states.items():
        if key != 'threshold_model_state_dict' and key != 'output_model_state_dict':
#             print (key)
            layer_id = int(key.split('_')[-1])
#             print (layer_id)
            weights[layer_id] = value
    in_channel = 1
    in_size = queries_dimension
    for weight_idx, weight in enumerate(weights):
        hyper = hyper_para[weight_idx]
        cnn_model = CNN_Model(in_channel, hyper.out_channel, hyper.kernel_size,
                              hyper.stride, hyper.padding, hyper.pool_type, hyper.pool_size)
        in_size = int((int((in_size - hyper.kernel_size + 2*(hyper.padding)) / hyper.stride) + 1) / hyper.pool_size)
        in_channel = hyper.out_channel
        cnn_model.load_state_dict(weight)
        cnn_model.eval()
        cnn_models.append(cnn_model)
    cnn_modelss.append(cnn_models)
        
    threshold_model_state_dict = states['threshold_model_state_dict']
    threshold_model = Threshold_Model()
    threshold_model.load_state_dict(threshold_model_state_dict)
    threshold_model.eval()
    threshold_models.append(threshold_model)
    
    output_model_state_dict = states['output_model_state_dict']
    output_model = Output_Model(in_size * in_channel)
    output_model.load_state_dict(output_model_state_dict)
    output_model.eval()
    output_models.append(output_model)

## Model Test

In [None]:
model = Model()
model.load_state_dict(torch.load('MDN_model'))
model.eval()