In [9]:
#%% ========== ========== ========== ========== ========== ========== ========== ==========
import matplotlib.pyplot as plt
import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import time
import sys
import copy
import platform
import json

In [28]:
#%%
import torch
from sklearn.linear_model import LogisticRegression
from evaluate_DC import distribution_calibration
from pipeline.free_lunch_util import normalize_l2, score_ens, get_task, \
                                     checkAacc, NC_o, redata, relabel

In [29]:
#%%
# ---- data loading: few-shot episode configuration
dataset = 'CUB'  # 'miniImagenet' 'CUB' 'tieredImagenet' 'MultiDigitMNIST'
n_shot = 5       # 1 5
n_ways = 5
n_queries = 15
n_runs = 1000

# labelled (support) / unlabelled (query) sample counts per episode
n_lsamples = n_shot * n_ways
n_usamples = n_queries * n_ways
n_samples = n_usamples + n_lsamples
print(f'{dataset}: {n_ways} way {n_shot} shot {n_queries} query')

import FSLTask
cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries}
# Load the pre-extracted s2m2 features of the novel split
# (dgr='0', 84px, no flip -- the file read is ..._s84_r0_f0.plk).
feat_file_name = FSLTask.loadDataSet(dataset,
                                     dgr='0',
                                     mth='s2m2',
                                     img_size='84',
                                     flip=False,
                                     nv='novel')
FSLTask.setRandomStates(cfg)
ndatas, labset, idxset = FSLTask.GenerateRunSet(end=n_runs, cfg=cfg,
                                                re_true_lab=True)
# Swap axes 1 and 2, then flatten to (n_runs, n_samples, dim): samples end up
# position-major, i.e. all n_ways classes for position 0, then position 1, ...
# (assumes GenerateRunSet returns (runs, ways, shot+queries, dim) -- TODO confirm)
ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1)
# Labels in the same position-major order. The class axis must be n_ways
# (it used to be a literal 5, which breaks for any other way-setting).
labels = (torch.arange(n_ways)
          .view(1, 1, n_ways)
          .expand(n_runs, n_shot + n_queries, n_ways)
          .clone()
          .view(n_runs, n_samples))

print('- all task data:', ndatas.shape)
print('- all task label:', labels.shape)

# Label names of the novel split, read from the file-list JSON.
JS = './filelists/{}/novel.json'.format(dataset)  # alt. location: backup/
with open(JS) as f:
    JSlist = json.load(f)
super_lab = np.asarray(JSlist['label_names'])


CUB: 5 way 5 shot 15 query
>>> reading ./checkpoints/CUB/WideResNet28_10_S2M2_R/last/novel_features_s84_r0_f0.plk
all labels: [  3   3   3 ... 199 199 199]
min/max numbers per class over 50 classes: 48/60

Total of 50 classes, 60 elements each, with dimension 640

labels in order: [  3   7  11  15  19  23  27  31  35  39  43  47  51  55  59  63  67  71
  75  79  83  87  91  95  99 103 107 111 115 119 123 127 131 135 139 143
 147 151 155 159 163 167 171 175 179 183 187 191 195 199]
reloading random states from file....
generating task from 0 to 1000
- all task data: torch.Size([1000, 100, 640])
- all task label: torch.Size([1000, 100])


In [30]:
#%% sanity check on one task with NC_o (presumably nearest-centroid)
i_run = 999
dataSlabelS = get_task(i_run, ndatas, labels, n_lsamples)
(support_data, support_label,
 query_data, query_label) = dataSlabelS

be = 0.5
l2 = True
acc_nc, nn_pre, dist = NC_o(
    support_data, support_label, query_data, query_label,
    n_ways, n_shot,
    beta=be, l2=l2, use_mean=True, re_dist=True, merge_query=0,
)
print(acc_nc[0])

0.8666666666666667


In [10]:
#%% ========== ========== ========== ========== ========== ========== ========== ==========

# ---------- ----------
#
# large-scale ensemble.
#
# ---------- ----------

from pipeline.free_lunch_util import Run

In [13]:
#%%

# -------------------
#
# no distance-pattern
#
# -------------------

# Reference configuration for a single Run.
dataset   = 'CUB' # miniImagenet CUB tieredImageNet
n_shot    = 5 # 1 5
n_ways    = 5
img_size  = 84
dgr       = '0'
flip      = False
transform = 'beta' # beta log
be        = 0.5
bias      = 0.0

# Per-dataset constants. lgbse is the base step of the 'log' transform grid
# (transs); n2w and idx2v are not used in the cells visible here.
if dataset == 'miniImagenet':
    n2w   = {1:7.0,2:5.5,3:4.5,4:4.0,5:4.0,6:3.5,7:3.5,8:3.0,9:3.0,10:3.0}
    idx2v = {1:0.75,2:0.08}
    lgbse = 0.02
elif dataset == 'CUB':
    n2w   = {1:7.0,2:6.0,3:3.5,4:3.0,5:3.0,6:3.0,7:3.0,8:2.0,9:2.5,10:2.0}
    idx2v = {1:1.0,2:1.0}
    lgbse = 0.5 if n_shot == 1 else 0.02
elif dataset == 'tieredImageNet':
    n2w   = {1:4.0,2:4.0,3:5.0,4:6.0,5:6.0,6:6.0,7:7.0,8:7.0,9:7.0,10:7.0}
    idx2v = {1:0.75,2:0.30}
    lgbse = 0.1 if n_shot == 1 else 0.02
else:
    # Fail here instead of with a NameError on lgbse several cells later.
    raise ValueError("unknown dataset %r; expected 'miniImagenet', 'CUB' "
                     "or 'tieredImageNet'" % (dataset,))

pre = 'dgr_val' # dgr or dgr_val

In [14]:
#%% load and check one configuration (k=0, w=0; alternative: k=9, w=4)
run = Run(dataset, n_shot, n_ways, img_size, dgr, flip, transform, be, bias,
          pre=pre, w=0, k=0)
run.check()

> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0.50_l0.plk
90.37±0.26
90.37±0.26


In [15]:
#%% now do complete loading
# Full grid of settings: 8 image sizes x 4 rotations x 2 flips x 8 transforms.
img_sizes = [84] + list(range(90, 151, 10))
dgrs      = [str(angle) for angle in range(0, 360, 90)]
flips     = [False, True]
# (transform name, beta, log bias): three 'beta' settings, then five 'log'
# settings whose bias steps in multiples of lgbse.
transs    = ([('beta', b, 0.0) for b in (0.5, 0.75, 1.0)]
             + [('log', 0.5, m * lgbse) for m in range(1, 6)])


In [16]:
#%% read all -- one Run per (image size, rotation, flip, transform) grid point
runSet = []
idx = 1
for img_s in img_sizes:
    for dg in dgrs:
        for fl in flips:
            for tran in transs:
                print(idx)
                t_name, t_beta, t_bias = tran[0], tran[1], tran[2]
                run = Run(dataset, n_shot, n_ways, img_s, dg, fl,
                          t_name, t_beta, t_bias, check=True, pre=pre)
                runSet += [run]
                idx += 1


1
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0.50_l0.plk
90.37±0.26
2
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0.75_l0.plk
90.53±0.25
3
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b1.00_l0.plk
89.75±0.26
4
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0_l0.02.plk
90.24±0.26
5
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0_l0.04.plk
90.46±0.25
6
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0_l0.06.plk
90.51±0.25
7
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0_l0.08.plk
90.55±0.25
8
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f0_b0_l0.10.plk
90.55±0.25
9
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b0.50_l0.plk
90.37±0.26
10
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b0.75_l0.plk
90.53±0.25
11
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b1.00_l0.plk
89.75±0.26
12
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b0_l0.02.plk
90.24±0.26
13
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b0_l0.04.plk
90.46±0.25
14
> loading: ./dgr_val/5way5shot/CUB/s84_r0_f1_b0_l0.06.plk
90.51±0.25
1

84.82±0.31
119
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f0_b0_l0.08.plk
84.80±0.31
120
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f0_b0_l0.10.plk
84.81±0.31
121
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0.50_l0.plk
84.33±0.33
122
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0.75_l0.plk
84.60±0.32
123
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b1.00_l0.plk
83.81±0.32
124
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0_l0.02.plk
84.23±0.33
125
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0_l0.04.plk
84.52±0.32
126
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0_l0.06.plk
84.62±0.32
127
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0_l0.08.plk
84.63±0.32
128
> loading: ./dgr_val/5way5shot/CUB/s90_r270_f1_b0_l0.10.plk
84.61±0.32
129
> loading: ./dgr_val/5way5shot/CUB/s100_r0_f0_b0.50_l0.plk
90.18±0.26
130
> loading: ./dgr_val/5way5shot/CUB/s100_r0_f0_b0.75_l0.plk
90.72±0.25
131
> loading: ./dgr_val/5way5shot/CUB/s100_r0_f0_b1.00_l0.plk
90.20±0.25
132
> loading: ./

89.79±0.26
237
> loading: ./dgr_val/5way5shot/CUB/s110_r180_f1_b0_l0.04.plk
90.24±0.26
238
> loading: ./dgr_val/5way5shot/CUB/s110_r180_f1_b0_l0.06.plk
90.41±0.26
239
> loading: ./dgr_val/5way5shot/CUB/s110_r180_f1_b0_l0.08.plk
90.50±0.26
240
> loading: ./dgr_val/5way5shot/CUB/s110_r180_f1_b0_l0.10.plk
90.53±0.26
241
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0.50_l0.plk
84.39±0.32
242
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0.75_l0.plk
84.80±0.32
243
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b1.00_l0.plk
84.13±0.32
244
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0_l0.02.plk
84.30±0.32
245
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0_l0.04.plk
84.70±0.32
246
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0_l0.06.plk
84.82±0.32
247
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0_l0.08.plk
84.84±0.32
248
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f0_b0_l0.10.plk
84.84±0.32
249
> loading: ./dgr_val/5way5shot/CUB/s110_r270_f1_b0.50_l0.plk
84.13±0.33
2

84.20±0.32
351
> loading: ./dgr_val/5way5shot/CUB/s130_r90_f1_b0_l0.08.plk
84.25±0.32
352
> loading: ./dgr_val/5way5shot/CUB/s130_r90_f1_b0_l0.10.plk
84.25±0.32
353
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0.50_l0.plk
84.22±0.33
354
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0.75_l0.plk
84.59±0.33
355
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b1.00_l0.plk
83.97±0.34
356
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0_l0.02.plk
84.26±0.33
357
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0_l0.04.plk
84.60±0.33
358
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0_l0.06.plk
84.68±0.33
359
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0_l0.08.plk
84.67±0.33
360
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f0_b0_l0.10.plk
84.66±0.33
361
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f1_b0.50_l0.plk
88.80±0.28
362
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f1_b0.75_l0.plk
89.52±0.27
363
> loading: ./dgr_val/5way5shot/CUB/s130_r180_f1_b1.00_l0.plk
89.22±0.27
364

87.42±0.28
460
> loading: ./dgr_val/5way5shot/CUB/s150_r0_f1_b0_l0.02.plk
87.01±0.30
461
> loading: ./dgr_val/5way5shot/CUB/s150_r0_f1_b0_l0.04.plk
87.50±0.29
462
> loading: ./dgr_val/5way5shot/CUB/s150_r0_f1_b0_l0.06.plk
87.70±0.29
463
> loading: ./dgr_val/5way5shot/CUB/s150_r0_f1_b0_l0.08.plk
87.84±0.28
464
> loading: ./dgr_val/5way5shot/CUB/s150_r0_f1_b0_l0.10.plk
87.87±0.28
465
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0.50_l0.plk
82.44±0.34
466
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0.75_l0.plk
82.77±0.33
467
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b1.00_l0.plk
81.90±0.34
468
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0_l0.02.plk
82.51±0.34
469
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0_l0.04.plk
82.85±0.33
470
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0_l0.06.plk
82.91±0.33
471
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0_l0.08.plk
82.91±0.33
472
> loading: ./dgr_val/5way5shot/CUB/s150_r90_f0_b0_l0.10.plk
82.89±0.33
473
> loading: ./dg

In [17]:
#%% check accs -- mean accuracy and ± width of every loaded Run
m_s, m_w = [], []
for ru in runSet:
    m_s.append(ru.mean)
    m_w.append(ru.width)
    print(m_s[-1], end=' ')

mi_ = min(m_s)
ma_ = max(m_s)
print('\nlen, min, max', len(m_s), mi_, ma_)
print('~~~~~ All mean ~~~~~')
avg_mean, avg_width = np.mean(m_s), np.mean(m_w)
print('%4.2f±%4.2f' % (avg_mean, avg_width))

90.37133333333333 90.528 89.748 90.23533333333334 90.46466666666667 90.50933333333334 90.554 90.55066666666666 90.37133333333333 90.528 89.748 90.23533333333334 90.46466666666667 90.50933333333334 90.554 90.55066666666666 84.83733333333333 84.90933333333334 84.03333333333335 84.73466666666666 84.982 85.01 84.976 84.97533333333334 84.69533333333334 84.74866666666665 83.92066666666668 84.57 84.80733333333333 84.82 84.77533333333335 84.76266666666666 85.03533333333334 85.368 84.742 84.984 85.26 85.342 85.35666666666667 85.374 90.39 90.54466666666666 89.81 90.26066666666665 90.49866666666665 90.54266666666666 90.576 90.544 84.26533333333333 84.50066666666666 83.672 84.17133333333334 84.45866666666667 84.50933333333333 84.53066666666668 84.51533333333334 84.308 84.412 83.54733333333334 84.17933333333335 84.45133333333334 84.47533333333334 84.472 84.42933333333335 90.23666666666666 90.59266666666666 89.95533333333333 90.11333333333334 90.434 90.552 90.6 90.61 90.23666666666666 90.59266666666

In [18]:
# Persist the per-configuration accuracies (same order as runSet, since m_s
# was filled by iterating runSet).
# NOTE(review): the file name hard-codes 'cub' regardless of `dataset`.
os.makedirs('tables', exist_ok=True)  # open() fails if the folder is missing
with open('tables/acc_val_cub.plk', 'wb') as handle:
    pickle.dump(m_s, handle, protocol=pickle.HIGHEST_PROTOCOL)


In [19]:
# Switch to the 'dgr' result files (the cells above used 'dgr_val').
pre = 'dgr'  # dgr or dgr_val

run = Run(dataset, n_shot, n_ways, img_size, dgr, flip, transform, be, bias,
          pre=pre, w=0, k=0)  # alternative: k=9, w=4
run.check()

> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0.50_l0.plk
91.31±0.23
91.31±0.23


In [20]:
#%% read all -- reload the full grid, now with the current `pre` setting
runSet = []
idx = 1
for img_s in img_sizes:
    for dg in dgrs:
        for fl in flips:
            for tran in transs:
                print(idx)
                t_name, t_beta, t_bias = tran[0], tran[1], tran[2]
                run = Run(dataset, n_shot, n_ways, img_s, dg, fl,
                          t_name, t_beta, t_bias, check=True, pre=pre)
                runSet += [run]
                idx += 1


1
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0.50_l0.plk
91.31±0.23
2
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0.75_l0.plk
91.55±0.22
3
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b1.00_l0.plk
90.91±0.23
4
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0_l0.02.plk
91.20±0.23
5
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0_l0.04.plk
91.43±0.22
6
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0_l0.06.plk
91.49±0.22
7
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0_l0.08.plk
91.52±0.22
8
> loading: ./dgr/5way5shot/CUB/s84_r0_f0_b0_l0.10.plk
91.53±0.22
9
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0.50_l0.plk
91.31±0.23
10
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0.75_l0.plk
91.55±0.22
11
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b1.00_l0.plk
90.91±0.23
12
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0_l0.02.plk
91.20±0.23
13
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0_l0.04.plk
91.43±0.22
14
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0_l0.06.plk
91.49±0.22
15
> loading: ./dgr/5way5shot/CUB/s84_r0_f1_b0_l0.08.plk


91.79±0.22
135
> loading: ./dgr/5way5shot/CUB/s100_r0_f0_b0_l0.08.plk
91.84±0.22
136
> loading: ./dgr/5way5shot/CUB/s100_r0_f0_b0_l0.10.plk
91.87±0.22
137
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0.50_l0.plk
91.32±0.23
138
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0.75_l0.plk
91.90±0.22
139
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b1.00_l0.plk
91.48±0.22
140
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0_l0.02.plk
91.20±0.23
141
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0_l0.04.plk
91.63±0.22
142
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0_l0.06.plk
91.79±0.22
143
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0_l0.08.plk
91.84±0.22
144
> loading: ./dgr/5way5shot/CUB/s100_r0_f1_b0_l0.10.plk
91.87±0.22
145
> loading: ./dgr/5way5shot/CUB/s100_r90_f0_b0.50_l0.plk
86.65±0.27
146
> loading: ./dgr/5way5shot/CUB/s100_r90_f0_b0.75_l0.plk
87.08±0.27
147
> loading: ./dgr/5way5shot/CUB/s100_r90_f0_b1.00_l0.plk
86.50±0.28
148
> loading: ./dgr/5way5shot/CUB/s100_r90_f0_b0_l0.02.plk
86.55±0.27
149
>

86.20±0.28
254
> loading: ./dgr/5way5shot/CUB/s110_r270_f1_b0_l0.06.plk
86.34±0.28
255
> loading: ./dgr/5way5shot/CUB/s110_r270_f1_b0_l0.08.plk
86.40±0.28
256
> loading: ./dgr/5way5shot/CUB/s110_r270_f1_b0_l0.10.plk
86.42±0.28
257
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0.50_l0.plk
90.53±0.24
258
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0.75_l0.plk
91.28±0.23
259
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b1.00_l0.plk
90.89±0.23
260
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0_l0.02.plk
90.41±0.24
261
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0_l0.04.plk
90.91±0.23
262
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0_l0.06.plk
91.14±0.23
263
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0_l0.08.plk
91.26±0.23
264
> loading: ./dgr/5way5shot/CUB/s120_r0_f0_b0_l0.10.plk
91.31±0.23
265
> loading: ./dgr/5way5shot/CUB/s120_r0_f1_b0.50_l0.plk
90.53±0.24
266
> loading: ./dgr/5way5shot/CUB/s120_r0_f1_b0.75_l0.plk
91.28±0.23
267
> loading: ./dgr/5way5shot/CUB/s120_r0_f1_b1.00_l0.plk
90.89±0.23
268

85.11±0.29
382
> loading: ./dgr/5way5shot/CUB/s130_r270_f1_b0_l0.06.plk
85.26±0.29
383
> loading: ./dgr/5way5shot/CUB/s130_r270_f1_b0_l0.08.plk
85.33±0.29
384
> loading: ./dgr/5way5shot/CUB/s130_r270_f1_b0_l0.10.plk
85.35±0.29
385
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0.50_l0.plk
89.13±0.26
386
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0.75_l0.plk
89.91±0.25
387
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b1.00_l0.plk
89.54±0.25
388
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0_l0.02.plk
89.02±0.26
389
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0_l0.04.plk
89.62±0.25
390
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0_l0.06.plk
89.86±0.25
391
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0_l0.08.plk
89.96±0.25
392
> loading: ./dgr/5way5shot/CUB/s140_r0_f0_b0_l0.10.plk
89.99±0.25
393
> loading: ./dgr/5way5shot/CUB/s140_r0_f1_b0.50_l0.plk
89.13±0.26
394
> loading: ./dgr/5way5shot/CUB/s140_r0_f1_b0.75_l0.plk
89.91±0.25
395
> loading: ./dgr/5way5shot/CUB/s140_r0_f1_b1.00_l0.plk
89.54±0.25
396

83.50±0.31
511
> loading: ./dgr/5way5shot/CUB/s150_r270_f1_b0_l0.08.plk
83.56±0.31
512
> loading: ./dgr/5way5shot/CUB/s150_r270_f1_b0_l0.10.plk
83.56±0.31


In [21]:
#%% check accs -- summary of the reloaded runSet
m_s, m_w = [], []
for ru in runSet:
    m_s.append(ru.mean)
    m_w.append(ru.width)
    print(m_s[-1], end=' ')

mi_ = min(m_s)
ma_ = max(m_s)
print('\nlen, min, max', len(m_s), mi_, ma_)
print('~~~~~ All mean ~~~~~')
avg_mean, avg_width = np.mean(m_s), np.mean(m_w)
print('%4.2f±%4.2f' % (avg_mean, avg_width))

91.30933333333334 91.55133333333333 90.91 91.20266666666666 91.428 91.494 91.51866666666668 91.52933333333333 91.30933333333334 91.55133333333333 90.91 91.20266666666666 91.428 91.494 91.51866666666668 91.52933333333333 86.39266666666666 86.58666666666667 85.92666666666668 86.24866666666667 86.544 86.618 86.63333333333334 86.6 86.43733333333334 86.79933333333332 86.14266666666666 86.352 86.656 86.77133333333333 86.79866666666668 86.81533333333333 86.698 87.11266666666666 86.57533333333333 86.56266666666666 86.94333333333333 87.05066666666666 87.078 87.07 91.336 91.46933333333334 90.91466666666666 91.24866666666667 91.456 91.508 91.524 91.51 86.05533333333332 86.44533333333334 85.78666666666668 85.91466666666666 86.24866666666665 86.34733333333332 86.40866666666668 86.42933333333333 85.98133333333334 86.338 85.70466666666667 85.79333333333334 86.112 86.202 86.24466666666667 86.25733333333334 91.31133333333334 91.686 91.106 91.228 91.544 91.64266666666666 91.68266666666666 91.666 91.3113

In [22]:
# Accuracies written by the earlier dump cell (a locally produced pickle).
with open('tables/acc_val_cub.plk', 'rb') as handle:
    m_s_v = pickle.load(handle)

n_val, lo_val, hi_val = len(m_s_v), min(m_s_v), max(m_s_v)
print('val set -- len, min, max', n_val, lo_val, hi_val)


val set -- len, min, max 512 81.288 90.89666666666668


In [34]:
#%%
from numpy import linalg as LA
from itertools import combinations
import random

def norm_dist(dist_):
    """Divide every row of ``dist_`` by its sum taken over axis 1.

    A new array is returned; ``dist_`` itself is not modified.
    """
    row_sums = np.sum(dist_, axis=1)[:, None]
    return dist_ / row_sums

def dist_ens(distsS, norm=False):
    """Combine several distance matrices into one ensemble score matrix.

    Returns the element-wise mean of the *negated* distances, so a larger
    value means "closer" (callers take the argmax). With ``norm=True`` each
    member is first row-normalised with ``norm_dist``.
    """
    stacked = np.stack(distsS)
    if norm:
        stacked = np.stack([norm_dist(member) for member in stacked])
    return np.mean(-stacked, axis=0)

def full_ens(disSub, query_label):
    """Score every non-empty subset of ``disSub`` as an ensemble.

    For each subset size s = 1..len(disSub), all C(len, s) subsets are
    combined with ``dist_ens`` and scored with ``checkAacc``. This evaluates
    2**len(disSub) - 1 ensembles, so it is only practical for a handful of
    members.

    Returns a list whose entry s-1 is an array of the mean accuracies of all
    subsets of size s.
    """
    n_members = len(disSub)
    accsS = []
    for size in range(1, n_members + 1):
        print('> use:', size)
        size_accs = []
        for subset in combinations(np.arange(n_members), size):
            acc, _wid = checkAacc(dist_ens([disSub[k] for k in subset]),
                                  query_label)
            size_accs.append(acc)
        accsS.append(np.array(size_accs))
    return accsS

def random_ens(disSub, query_label, Get=10, seed=None):
    """Score randomly drawn ensembles of every size.

    For each size s = 1..len(disSub), ``Get`` subsets of s distinct members
    are drawn, combined with ``dist_ens`` and scored with ``checkAacc``.

    Parameters
    ----------
    disSub : sequence
        Member distance matrices.
    query_label : array-like
        Forwarded to ``checkAacc``.
    Get : int
        Number of random subsets drawn per size.
    seed : int or None
        If given, draws come from a private ``random.Random(seed)``, so the
        result is reproducible. If None, the module-level ``random`` generator
        is used (the original, unseeded behaviour).

    Returns
    -------
    list of np.ndarray
        Entry s-1 holds the ``Get`` mean accuracies for ensembles of size s.
    """
    rng = random if seed is None else random.Random(seed)
    n_members = len(disSub)
    accsS = []
    for size in range(1, n_members + 1):
        print('> use:', size)
        accs = []
        for _ in range(Get):
            selec = rng.sample(range(n_members), size)
            acc, _wid = checkAacc(dist_ens([disSub[k] for k in selec]),
                                  query_label)
            accs.append(acc)
        accsS.append(np.array(accs))
    return accsS

def guided_ens(disSub, query_label, m_s_v):
    """Validation-guided ensemble: add members best-first and score each prefix.

    Members of ``disSub`` are ranked by their validation accuracy ``m_s_v``
    (highest first). For every size i = 1..len(disSub), the top-i members are
    combined like ``dist_ens`` (mean of the negated distances) and scored with
    ``checkAacc``.

    The ensemble is grown with a running sum instead of re-stacking the top-i
    members at every step, so the work is linear rather than quadratic in the
    number of members. Results match ``dist_ens`` up to float rounding.

    Parameters
    ----------
    disSub : sequence of array-like
        Member distance matrices, all of the same shape.
    query_label : array-like
        Forwarded to ``checkAacc``.
    m_s_v : sequence of float
        Validation accuracy of each member, aligned index-by-index with
        ``disSub``.

    Returns
    -------
    list of float
        Mean accuracy of the top-i ensemble for i = 1..len(disSub).

    Raises
    ------
    ValueError
        If ``m_s_v`` and ``disSub`` differ in length. Otherwise the ranking
        would silently pick the wrong members or index out of range.
    """
    m_s_v = np.asarray(m_s_v)
    if len(m_s_v) != len(disSub):
        raise ValueError('m_s_v has %d entries but disSub has %d'
                         % (len(m_s_v), len(disSub)))
    val_rank = np.argsort(m_s_v)[::-1]  # best validation accuracy first
    accsS = []
    neg_sum = None
    for i, k in enumerate(val_rank, start=1):
        print('> use:', i)
        neg = -np.asarray(disSub[k])
        neg_sum = neg if neg_sum is None else neg_sum + neg
        acc, _wid = checkAacc(neg_sum / i, query_label)  # mean of the top-i
        accsS.append(acc)
    return accsS

def checkAacc(distsE, query_label, n_tasks=None):
    """Score ensemble predictions over many stacked tasks; print mean±width.

    NOTE(review): this definition shadows the ``checkAacc`` imported from
    ``pipeline.free_lunch_util`` at the top of the notebook.

    Parameters
    ----------
    distsE : array-like
        Scores whose last axis indexes the classes; the prediction is the
        argmax over that axis. Flattened, the predictions are the queries of
        ``n_tasks`` tasks laid out back to back.
    query_label : tensor or ndarray (anything with ``.tolist()``)
        Query labels of a single task. Every task is assumed to share this
        label layout (``labels`` is built from the same ``arange(n_ways)``
        pattern for every run), so the ground truth is this vector tiled
        ``n_tasks`` times.
    n_tasks : int, optional
        Number of stacked tasks. When omitted it is inferred as
        ``#predictions // len(query_label)`` (e.g. 2000 for a
        (2000*75, n_ways) score matrix) instead of being hard-coded.

    Returns
    -------
    (mean, width)
        Output of ``score_ens`` on the per-task accuracies (in percent).

    Raises
    ------
    ValueError
        If the number of predictions is not ``n_tasks * len(query_label)``.
    """
    pred = np.argmax(distsE, axis=-1).ravel()
    task_labels = np.array(query_label.tolist()).ravel()
    n_q = task_labels.size
    if n_tasks is None:
        if n_q == 0 or pred.size % n_q:
            raise ValueError('%d predictions cannot be split into tasks of %d '
                             'queries' % (pred.size, n_q))
        n_tasks = pred.size // n_q
    if pred.size != n_tasks * n_q:
        raise ValueError('expected %d predictions (%d tasks x %d queries), '
                         'got %d' % (n_tasks * n_q, n_tasks, n_q, pred.size))
    correct = (pred == np.tile(task_labels, n_tasks)).reshape(n_tasks, n_q)
    score = correct.mean(axis=1) * 100  # per-task accuracy in %
    mean, width = score_ens(score)
    print('%4.2f±%4.2f' % (mean, width))
    return mean, width


In [35]:
# Naive ensemble: equal-weight average over *all* loaded configurations.
disSub = [cfg_run.dists for cfg_run in runSet]  # or build it from a runSub subset
print(len(disSub))

distsE = dist_ens(disSub, norm=False)
full_m, full_v = checkAacc(distsE, query_label)

512
91.70±0.22


In [36]:
#%% random ensemble: several random subsets per ensemble size
n_draws = 5  # 5 or 2
accsR = random_ens(disSub, query_label, Get=n_draws)

> use: 1
91.43±0.22
86.28±0.28
85.83±0.29
86.62±0.28
85.45±0.28
> use: 2
85.07±0.30
91.15±0.23
87.27±0.27
87.97±0.26
89.26±0.25
> use: 3
91.22±0.23
88.85±0.26
87.25±0.28
88.58±0.26
88.91±0.25
> use: 4
92.16±0.22
91.34±0.23
91.77±0.22
89.90±0.25
91.79±0.22
> use: 5
92.25±0.22
88.34±0.26
89.05±0.25
92.02±0.22
91.01±0.23
> use: 6
89.87±0.24
89.87±0.24
91.23±0.23
92.25±0.22
89.24±0.25
> use: 7
92.45±0.21
91.53±0.23
91.60±0.22
90.36±0.24
91.44±0.22
> use: 8
91.68±0.22
89.73±0.25
91.84±0.22
90.23±0.24
90.60±0.24
> use: 9
91.99±0.22
90.21±0.24
90.59±0.24
91.61±0.22
90.29±0.24
> use: 10
91.22±0.23
90.61±0.24
89.99±0.24
92.40±0.21
91.81±0.22
> use: 11
92.23±0.21
91.47±0.23
92.21±0.21
89.37±0.25
89.43±0.25
> use: 12
91.00±0.23
91.67±0.22
91.37±0.22
91.54±0.22
92.16±0.22
> use: 13
91.89±0.22
90.97±0.23
91.66±0.22
90.83±0.23
92.21±0.22
> use: 14
90.80±0.23
92.48±0.22
91.70±0.22
91.67±0.22
91.65±0.22
> use: 15
92.60±0.21
89.22±0.25
90.49±0.24
91.85±0.22
90.81±0.23
> use: 16
91.20±0.23
89.50±0.25
91

91.49±0.22
> use: 127
91.27±0.23
91.72±0.22
91.65±0.22
91.37±0.23
91.58±0.22
> use: 128
91.48±0.22
91.60±0.22
91.88±0.22
91.49±0.22
91.83±0.22
> use: 129
91.53±0.22
91.55±0.22
91.73±0.22
91.90±0.22
91.72±0.22
> use: 130
91.64±0.22
91.49±0.23
91.09±0.23
91.54±0.22
91.80±0.22
> use: 131
91.99±0.22
91.55±0.22
91.85±0.22
91.73±0.22
91.24±0.23
> use: 132
91.67±0.22
91.64±0.22
91.67±0.22
91.73±0.22
91.58±0.22
> use: 133
91.65±0.22
91.66±0.22
91.19±0.23
91.18±0.23
91.78±0.22
> use: 134
91.69±0.22
91.88±0.22
91.63±0.22
91.75±0.22
91.55±0.22
> use: 135
91.61±0.22
91.68±0.22
91.33±0.23
91.40±0.23
91.57±0.22
> use: 136
91.95±0.22
91.82±0.22
91.64±0.22
91.71±0.22
91.48±0.23
> use: 137
91.60±0.22
91.37±0.23
91.80±0.22
91.70±0.22
91.66±0.22
> use: 138
91.69±0.22
91.86±0.22
91.76±0.22
91.37±0.23
91.67±0.22
> use: 139
91.64±0.22
91.56±0.22
91.83±0.22
91.81±0.22
91.96±0.22
> use: 140
91.59±0.22
91.53±0.22
91.48±0.23
91.57±0.22
91.37±0.23
> use: 141
91.68±0.22
91.45±0.23
91.68±0.22
91.44±0.23
91.74±0.22

91.44±0.23
91.53±0.22
91.90±0.22
91.45±0.23
91.76±0.22
> use: 252
91.86±0.22
91.55±0.22
91.54±0.22
91.66±0.22
91.66±0.22
> use: 253
91.89±0.22
91.77±0.22
91.85±0.22
91.66±0.22
91.66±0.22
> use: 254
91.83±0.22
91.77±0.22
91.59±0.22
91.53±0.22
91.84±0.22
> use: 255
91.84±0.22
91.72±0.22
91.59±0.22
91.65±0.22
91.95±0.22
> use: 256
91.54±0.22
91.61±0.22
91.69±0.22
91.81±0.22
91.57±0.22
> use: 257
91.65±0.22
91.78±0.22
91.81±0.22
91.49±0.22
91.64±0.22
> use: 258
91.60±0.22
91.71±0.22
91.63±0.22
91.61±0.22
91.77±0.22
> use: 259
91.57±0.22
91.74±0.22
91.62±0.22
91.68±0.22
91.82±0.22
> use: 260
91.57±0.22
91.81±0.22
91.60±0.22
91.66±0.22
91.73±0.22
> use: 261
91.73±0.22
91.76±0.22
91.52±0.22
91.84±0.22
91.68±0.22
> use: 262
91.72±0.22
91.74±0.22
91.62±0.22
91.60±0.22
91.82±0.22
> use: 263
91.70±0.22
91.82±0.22
91.85±0.22
91.81±0.22
91.66±0.22
> use: 264
91.51±0.22
91.75±0.22
91.49±0.22
91.60±0.22
91.72±0.22
> use: 265
91.82±0.22
91.62±0.22
91.56±0.22
91.84±0.22
91.64±0.22
> use: 266
91.75±0.22

91.71±0.22
91.68±0.22
91.67±0.22
91.76±0.22
> use: 376
91.80±0.22
91.73±0.22
91.82±0.22
91.77±0.22
91.76±0.22
> use: 377
91.65±0.22
91.80±0.22
91.72±0.22
91.72±0.22
91.65±0.22
> use: 378
91.63±0.22
91.61±0.22
91.76±0.22
91.68±0.22
91.85±0.22
> use: 379
91.57±0.22
91.83±0.22
91.71±0.22
91.78±0.22
91.70±0.22
> use: 380
91.62±0.22
91.66±0.22
91.66±0.22
91.71±0.22
91.77±0.22
> use: 381
91.72±0.22
91.67±0.22
91.66±0.22
91.79±0.22
91.67±0.22
> use: 382
91.67±0.22
91.60±0.22
91.75±0.22
91.75±0.22
91.77±0.22
> use: 383
91.69±0.22
91.80±0.22
91.86±0.22
91.77±0.22
91.86±0.22
> use: 384
91.70±0.22
91.65±0.22
91.64±0.22
91.68±0.22
91.69±0.22
> use: 385
91.69±0.22
91.60±0.22
91.76±0.22
91.62±0.22
91.67±0.22
> use: 386
91.64±0.22
91.67±0.22
91.74±0.22
91.76±0.22
91.71±0.22
> use: 387
91.70±0.22
91.84±0.22
91.67±0.22
91.80±0.22
91.66±0.22
> use: 388
91.77±0.22
91.76±0.22
91.84±0.22
91.74±0.22
91.65±0.22
> use: 389
91.77±0.22
91.68±0.22
91.67±0.22
91.79±0.22
91.63±0.22
> use: 390
91.71±0.22
91.82±0.22

91.68±0.22
91.69±0.22
91.68±0.22
> use: 500
91.73±0.22
91.73±0.22
91.70±0.22
91.68±0.22
91.69±0.22
> use: 501
91.70±0.22
91.71±0.22
91.71±0.22
91.70±0.22
91.70±0.22
> use: 502
91.71±0.22
91.70±0.22
91.72±0.22
91.69±0.22
91.69±0.22
> use: 503
91.68±0.22
91.72±0.22
91.70±0.22
91.69±0.22
91.71±0.22
> use: 504
91.69±0.22
91.65±0.22
91.74±0.22
91.72±0.22
91.69±0.22
> use: 505
91.70±0.22
91.67±0.22
91.69±0.22
91.69±0.22
91.68±0.22
> use: 506
91.69±0.22
91.72±0.22
91.66±0.22
91.69±0.22
91.69±0.22
> use: 507
91.70±0.22
91.69±0.22
91.70±0.22
91.70±0.22
91.69±0.22
> use: 508
91.68±0.22
91.70±0.22
91.70±0.22
91.68±0.22
91.69±0.22
> use: 509
91.70±0.22
91.70±0.22
91.71±0.22
91.73±0.22
91.70±0.22
> use: 510
91.70±0.22
91.70±0.22
91.68±0.22
91.71±0.22
91.68±0.22
> use: 511
91.70±0.22
91.69±0.22
91.70±0.22
91.69±0.22
91.70±0.22
> use: 512
91.70±0.22
91.70±0.22
91.70±0.22
91.70±0.22
91.70±0.22


In [37]:
#%% guided ensemble: grow the ensemble in order of validation accuracy
accsG = guided_ens(disSub, query_label, m_s_v=m_s_v)

> use: 1
91.81±0.22
> use: 2
91.80±0.22
> use: 3
91.78±0.22
> use: 4
91.77±0.22
> use: 5
91.92±0.22
> use: 6
92.08±0.22
> use: 7
92.35±0.21
> use: 8
92.45±0.21
> use: 9
92.48±0.21
> use: 10
92.48±0.21
> use: 11
92.54±0.21
> use: 12
92.57±0.21
> use: 13
92.51±0.21
> use: 14
92.56±0.21
> use: 15
92.57±0.21
> use: 16
92.57±0.21
> use: 17
92.56±0.21
> use: 18
92.56±0.21
> use: 19
92.55±0.21
> use: 20
92.55±0.21
> use: 21
92.56±0.21
> use: 22
92.56±0.21
> use: 23
92.56±0.21
> use: 24
92.58±0.21
> use: 25
92.58±0.21
> use: 26
92.59±0.21
> use: 27
92.57±0.21
> use: 28
92.55±0.21
> use: 29
92.56±0.21
> use: 30
92.55±0.21
> use: 31
92.56±0.21
> use: 32
92.56±0.21
> use: 33
92.58±0.21
> use: 34
92.60±0.21
> use: 35
92.60±0.21
> use: 36
92.60±0.21
> use: 37
92.60±0.21
> use: 38
92.61±0.21
> use: 39
92.60±0.21
> use: 40
92.58±0.21
> use: 41
92.60±0.21
> use: 42
92.62±0.21
> use: 43
92.60±0.21
> use: 44
92.57±0.21
> use: 45
92.57±0.21
> use: 46
92.57±0.21
> use: 47
92.55±0.21
> use: 48
92.54±0.21
>

92.25±0.21
> use: 379
92.24±0.21
> use: 380
92.24±0.21
> use: 381
92.24±0.21
> use: 382
92.23±0.21
> use: 383
92.23±0.21
> use: 384
92.23±0.21
> use: 385
92.23±0.21
> use: 386
92.23±0.21
> use: 387
92.23±0.21
> use: 388
92.23±0.21
> use: 389
92.23±0.21
> use: 390
92.22±0.21
> use: 391
92.23±0.21
> use: 392
92.23±0.21
> use: 393
92.23±0.21
> use: 394
92.21±0.21
> use: 395
92.20±0.21
> use: 396
92.20±0.21
> use: 397
92.20±0.21
> use: 398
92.20±0.21
> use: 399
92.20±0.21
> use: 400
92.19±0.22
> use: 401
92.18±0.22
> use: 402
92.18±0.22
> use: 403
92.18±0.22
> use: 404
92.18±0.22
> use: 405
92.17±0.22
> use: 406
92.16±0.22
> use: 407
92.17±0.22
> use: 408
92.16±0.22
> use: 409
92.15±0.22
> use: 410
92.15±0.22
> use: 411
92.15±0.22
> use: 412
92.14±0.22
> use: 413
92.13±0.22
> use: 414
92.13±0.22
> use: 415
92.12±0.22
> use: 416
92.12±0.22
> use: 417
92.12±0.22
> use: 418
92.11±0.22
> use: 419
92.11±0.22
> use: 420
92.10±0.22
> use: 421
92.10±0.22
> use: 422
92.10±0.22
> use: 423
92.09±0.22

In [42]:
# Best top-i guided-ensemble accuracy.
# NOTE(review): i is picked by the max over accsG, which were scored on the
# same runs being reported (pre='dgr') rather than on the 'dgr_val'
# accuracies, so this number is an optimistic upper bound.
best_guided_acc = max(accsG)
print('now we can achieve an accuracy of', best_guided_acc)

now we can achieve an accuracy of 92.61666666666666
