In [1]:
%%javascript
// Raise the classic-Notebook output auto-scroll threshold to 9999 so that the
// long result tables / logs printed below are presumably rendered in full
// instead of inside a scrolling box. NOTE(review): relies on the global
// `IPython` JS object, which may not exist in newer front-ends (JupyterLab).
IPython.OutputArea.auto_scroll_threshold = 9999;

<IPython.core.display.Javascript object>

In [2]:
from matplotlib import pyplot
import numpy as np
import pandas
from matplotlib import pyplot as plt
from scipy.stats import rankdata

import sys
sys.path.append('../../')
from bayesian_benchmarks.database_utils import Database
from bayesian_benchmarks.data import classification_datasets, _ALL_REGRESSION_DATATSETS, _ALL_CLASSIFICATION_DATATSETS
# Single lookup table covering every dataset: start from a copy of the
# regression entries, then layer the classification entries on top (on a name
# clash the classification entry wins, exactly as with two successive updates).
ALL_DATATSETS = dict(_ALL_REGRESSION_DATATSETS)
ALL_DATATSETS.update(_ALL_CLASSIFICATION_DATATSETS)

from IPython.display import display, HTML



/homes/hrs13/Documents/github/bayesian_benchmarks/bayesian_benchmarks


In [3]:
def rankarray(A):
    """Rank the entries of each row of ``A`` independently, ignoring NaNs.

    Each row is ranked with ``scipy.stats.rankdata`` (ascending, ties share
    the average rank), but only over its non-NaN entries. NaN entries stay
    NaN in the output so that callers can aggregate with ``np.nanmean``.

    Passing NaNs straight to ``rankdata`` is not safe: depending on the SciPy
    version they are either given the largest ranks (so a missing result
    looks like the highest value) or they turn the whole row into NaN.

    Parameters
    ----------
    A : iterable of 1-D array-likes
        One row per item; entries must be convertible to float.

    Returns
    -------
    numpy.ndarray
        Array with one row of float ranks per input row (NaN where the input
        was NaN). An empty ``A`` yields an empty array.
    """
    ranks = []
    for a in A:
        a = np.asarray(a, dtype=float)
        row_ranks = np.full(a.shape, np.nan)
        present = ~np.isnan(a)
        # Only call rankdata when there is something to rank; an all-NaN row
        # simply stays all-NaN.
        if present.any():
            row_ranks[present] = rankdata(a[present])
        ranks.append(row_ranks)
    return np.array(ranks)


def read_regression_classification(fs, models_names, datasets, task,
                                   db_path='../results/results.db',
                                   verbose=True,
                                   max_abs_value=1000):
    """Collect benchmark results from the database into per-metric tables.

    Parameters
    ----------
    fs : sequence of str
        Metric (column) names to read, e.g. ``'test_loglik'``.
    models_names : sequence of (str, str)
        ``(model key used in the database, short column label)`` pairs.
    datasets : iterable of str
        Dataset names; each must be a key of ``ALL_DATATSETS``.
    task : str
        Task/table name passed to ``db.read``. ``'classification'`` also adds
        a ``K`` column taken from the dataset entry.
    db_path : str, optional
        Path of the results database (defaults to the previously hard-coded
        location).
    verbose : bool, optional
        When true (default), print ``model dataset n_runs`` for every
        model/dataset pair that has results.
    max_abs_value : float, optional
        Means whose absolute value is not strictly below this threshold are
        reported as ``'nan'`` and excluded from the summary statistics.

    Returns
    -------
    results : dict
        ``results[metric]['table']`` maps each field name to a list of cell
        values (one per dataset, plus summary rows); ``results[metric]['vals']``
        holds one list of per-model mean values (NaN if missing) per dataset.
    fields : list of str
        Column order for the tables.
    """
    import warnings  # local import keeps this definition self-contained

    short_names = [name for _, name in models_names]
    if task == 'classification':
        fields = ['dataset', 'N', 'D', 'K'] + short_names
    else:
        fields = ['dataset', 'N', 'D'] + short_names

    results = {f: {'table': {field: [] for field in fields}, 'vals': []}
               for f in fs}

    with Database(db_path) as db:
        for dataset in datasets:
            info = ALL_DATATSETS[dataset]
            for f in fs:
                table = results[f]['table']
                # Dataset names are truncated to 10 characters for display.
                table['dataset'].append(dataset[:10])
                table['N'].append(info.N)
                table['D'].append(info.D)
                if task == 'classification':
                    table['K'].append(info.K)

            row = {f: [] for f in fs}
            for model, name in models_names:
                # presumably one result row per stored run, with values in
                # the same order as ``fs`` -- the indexing below relies on it
                res = db.read(task, fs, {'model': model,
                                         'dataset': dataset})

                if len(res) == 0:
                    # No runs stored: blank cell, NaN in the numeric values.
                    for f in fs:
                        results[f]['table'][name].append('')
                        row[f].append(np.nan)
                    continue

                if verbose:
                    print('{} {} {}'.format(model, dataset, len(res)))
                for i, f in enumerate(fs):
                    L = [float(l[i]) for l in res]
                    m = np.average(L)
                    # Spread is the standard deviation over runs (np.std
                    # default), not the standard error; undefined for one run.
                    std = np.std(L) if len(L) > 1 else np.nan
                    if -max_abs_value < m < max_abs_value:
                        r = '{:.3f}({:.3f})'.format(m, std)
                        row[f].append(m)
                    else:
                        # Out-of-range (or NaN) mean: treat as missing.
                        r = 'nan'
                        row[f].append(np.nan)

                    results[f]['table'][name].append(r)

            for f in fs:
                results[f]['vals'].append(row[f])

    # Summary rows are only added for metrics without 'unnormalized' in
    # their name.
    for f in fs:
        if 'unnormalized' in f:
            continue

        vals = np.array(results[f]['vals'])
        if vals.ndim != 2 or vals.shape[0] == 0:
            # No datasets were processed, so there is nothing to summarise.
            continue

        with warnings.catch_warnings():
            # A model with no results anywhere gives an all-NaN column; the
            # resulting "empty slice" RuntimeWarnings are expected noise.
            warnings.simplefilter('ignore', category=RuntimeWarning)
            avgs = np.nanmean(vals, 0)
            meds = np.nanmedian(vals, 0)
            rks = np.nanmean(rankarray(vals), 0)

        table = results[f]['table']
        for s, n in [[avgs, 'avg'], [meds, 'median'], [rks, 'avg rank']]:
            table['dataset'].append(n)
            table['N'].append('')
            table['D'].append('')
            if task == 'classification':
                table['K'].append('')
            for ss, name in zip(s, short_names):
                table[name].append('{:.3f}'.format(ss))

    return results, fields


In [4]:
# [model key as stored in the results database, short label used as the
# table column header]. Models left commented out are excluded from the tables.
models_names = [
    ['linear', 'lin'],
    ['variationally_sparse_gp', 'SVGP'],
    ['variationally_sparse_gp_minibatch', 'SVGP_mb'],
    ['deep_gp_doubly_stochastic', 'DGP'],
    ['svm', 'svm'],
    ['knn', 'knn'],
    # ['naive_bayes', 'nb'],
    # ['decision_tree', 'dt'],
    # ['random_forest', 'rf'],
    ['gradient_boosting_machine', 'gbm'],
    ['adaboost', 'ab'],
    ['mlp', 'mlp'],
]

# Regression metrics to tabulate (normalised and unnormalized variants).
fs = (
    'test_loglik',
    'test_rmse',
    'test_loglik_unnormalized',
    'test_rmse_unnormalized',
)

results, fields = read_regression_classification(
    fs,
    models_names,
    _ALL_REGRESSION_DATATSETS,
    'regression',
)


linear energy 10
variationally_sparse_gp energy 10
deep_gp_doubly_stochastic energy 10
svm energy 10
knn energy 10
gradient_boosting_machine energy 10
adaboost energy 10
mlp energy 10
linear yacht 10
variationally_sparse_gp yacht 10
deep_gp_doubly_stochastic yacht 10
svm yacht 10
knn yacht 10
gradient_boosting_machine yacht 10
adaboost yacht 10
mlp yacht 10
linear boston 10
variationally_sparse_gp boston 11
deep_gp_doubly_stochastic boston 10
svm boston 10
knn boston 10
gradient_boosting_machine boston 10
adaboost boston 10
mlp boston 10
linear winered 10
variationally_sparse_gp winered 10
deep_gp_doubly_stochastic winered 10
svm winered 10
knn winered 10
gradient_boosting_machine winered 10
adaboost winered 10
mlp winered 10
linear protein 10
variationally_sparse_gp protein 10
deep_gp_doubly_stochastic protein 10
svm protein 10
knn protein 10
gradient_boosting_machine protein 10
adaboost protein 10
mlp protein 10
linear power 10
variationally_sparse_gp power 10
deep_gp_doubly_stochast

  r = func(a, **kwargs)


In [5]:
# Render one HTML table per regression metric. The last heading used to be a
# copy-paste of 'normalised test rmse' although the table shown is the
# unnormalized one; it is now labelled correctly.
# For LaTeX output, use
# pandas.DataFrame(results[key]['table'], columns=fields).to_latex() instead.
for title, key in [('normalised test loglikelihood', 'test_loglik'),
                   ('unnormalized test loglikelihood', 'test_loglik_unnormalized'),
                   ('normalised test rmse', 'test_rmse'),
                   ('unnormalized test rmse', 'test_rmse_unnormalized')]:
    print(title)
    display(HTML(pandas.DataFrame(results[key]['table'], columns=fields).to_html(index=False)))



normalised test loglikelihood


dataset,N,D,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
energy,768.0,8.0,-0.220(0.114),1.227(0.042),,1.590(0.106),0.038(0.159),-0.022(0.253),1.603(0.154),0.237(0.034),0.121(0.153)
yacht,308.0,6.0,-0.929(0.083),1.943(0.141),,2.259(0.134),-0.614(0.286),-1.155(0.320),-0.589(2.235),0.799(0.351),-0.090(0.318)
boston,506.0,13.0,-0.644(0.066),-0.180(0.064),,-0.190(0.052),-0.157(0.083),-0.467(0.134),-0.642(0.243),-0.368(0.099),-0.228(0.111)
winered,1599.0,11.0,-1.208(0.060),-1.163(0.065),,-1.168(0.064),-1.174(0.095),-1.280(0.121),-1.206(0.100),-1.176(0.082),-1.189(0.098)
protein,45730.0,9.0,-1.257(0.005),-1.090(0.008),,-1.035(0.009),-1.150(0.011),-1.013(0.018),-1.156(0.009),-1.359(0.014),-1.092(0.014)
power,9568.0,4.0,-0.098(0.031),0.038(0.036),,0.047(0.037),0.034(0.038),0.046(0.056),0.066(0.042),-0.309(0.037),0.002(0.034)
concrete,1030.0,8.0,-0.953(0.052),-0.353(0.054),,-0.359(0.055),-0.558(0.107),-0.828(0.101),-0.440(0.162),-0.678(0.086),-0.433(0.122)
winewhite,4898.0,12.0,-1.254(0.039),-1.167(0.027),,-1.164(0.025),-1.161(0.031),-1.230(0.040),-1.161(0.031),-1.228(0.031),-1.170(0.036)
naval,11934.0,12.0,-0.489(0.017),5.135(0.015),,2.527(0.224),0.120(0.028),0.740(0.109),-0.088(0.030),-1.297(0.019),1.080(0.154)
avg,,,-0.784,0.488,,0.279,-0.514,-0.579,-0.401,-0.598,-0.333


unnormalized test loglikelihood


dataset,N,D,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
energy,768,8,-2.531(0.114),-1.084(0.042),,-0.721(0.106),-2.273(0.159),-2.333(0.253),-0.708(0.154),-2.074(0.034),-2.190(0.153)
yacht,308,6,-3.646(0.083),-0.774(0.141),,-0.458(0.134),-3.331(0.286),-3.872(0.320),-3.306(2.235),-1.918(0.351),-2.807(0.318)
boston,506,13,-2.862(0.066),-2.398(0.064),,-2.408(0.052),-2.375(0.083),-2.685(0.134),-2.860(0.243),-2.586(0.099),-2.446(0.111)
winered,1599,11,-0.994(0.060),-0.949(0.065),,-0.954(0.064),-0.960(0.095),-1.066(0.121),-0.991(0.100),-0.962(0.082),-0.975(0.098)
protein,45730,9,-3.068(0.005),-2.901(0.008),,-2.846(0.009),-2.962(0.011),-2.824(0.018),-2.967(0.009),-3.170(0.014),-2.903(0.014)
power,9568,4,-2.935(0.031),-2.799(0.036),,-2.790(0.037),-2.803(0.038),-2.791(0.056),-2.771(0.042),-3.146(0.037),-2.835(0.034)
concrete,1030,8,-3.768(0.052),-3.168(0.054),,-3.175(0.055),-3.373(0.107),-3.643(0.101),-3.256(0.162),-3.494(0.086),-3.248(0.122)
winewhite,4898,12,-1.133(0.039),-1.045(0.027),,-1.043(0.025),-1.039(0.031),-1.108(0.040),-1.040(0.031),-1.106(0.031),-1.049(0.036)
naval,11934,12,3.730(0.017),9.353(0.015),,6.746(0.224),4.338(0.028),4.958(0.109),4.130(0.030),2.921(0.019),5.298(0.154)


normalised test rmse


dataset,N,D,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
energy,768.0,8.0,0.300(0.034),0.064(0.005),,0.048(0.006),0.227(0.027),0.219(0.029),0.047(0.005),0.190(0.006),0.210(0.033)
yacht,308.0,6.0,0.608(0.048),0.043(0.014),,0.031(0.013),0.419(0.092),0.670(0.138),0.044(0.014),0.103(0.023),0.244(0.051)
boston,506.0,13.0,0.444(0.044),0.282(0.029),,0.286(0.025),0.267(0.037),0.380(0.059),0.283(0.020),0.344(0.028),0.298(0.026)
winered,1599.0,11.0,0.808(0.046),0.777(0.051),,0.780(0.051),0.768(0.055),0.825(0.063),0.761(0.046),0.778(0.054),0.775(0.056)
protein,45730.0,9.0,0.850(0.004),0.719(0.007),,0.681(0.007),0.764(0.008),0.623(0.007),0.768(0.007),0.941(0.013),0.721(0.010)
power,9568.0,4.0,0.267(0.008),0.232(0.008),,0.230(0.009),0.234(0.009),0.219(0.008),0.226(0.008),0.329(0.012),0.241(0.008)
concrete,1030.0,8.0,0.626(0.031),0.357(0.024),,0.357(0.022),0.405(0.030),0.519(0.031),0.323(0.026),0.468(0.030),0.347(0.027)
winewhite,4898.0,12.0,0.847(0.033),0.778(0.024),,0.776(0.022),0.768(0.021),0.788(0.020),0.768(0.021),0.825(0.025),0.778(0.026)
naval,11934.0,12.0,0.394(0.007),0.001(0.000),,0.017(0.007),0.215(0.006),0.104(0.006),0.263(0.007),0.885(0.017),0.083(0.014)
avg,,,0.572,0.362,,0.356,0.452,0.483,0.387,0.540,0.411


normalised test rmse


dataset,N,D,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
energy,768,8,3.029(0.342),0.650(0.049),,0.486(0.060),2.290(0.274),2.204(0.295),0.477(0.051),1.921(0.063),2.115(0.329)
yacht,308,6,9.208(0.721),0.655(0.217),,0.476(0.198),6.339(1.385),10.136(2.090),0.673(0.209),1.553(0.346),3.687(0.767)
boston,506,13,4.076(0.406),2.588(0.262),,2.625(0.232),2.455(0.336),3.487(0.538),2.598(0.183),3.162(0.255),2.735(0.241)
winered,1599,11,0.652(0.037),0.627(0.041),,0.630(0.041),0.620(0.044),0.666(0.051),0.615(0.037),0.628(0.044),0.626(0.045)
protein,45730,9,5.201(0.027),4.401(0.040),,4.169(0.041),4.676(0.049),3.812(0.041),4.701(0.041),5.758(0.078),4.410(0.062)
power,9568,4,4.553(0.142),3.967(0.145),,3.927(0.156),3.987(0.146),3.731(0.129),3.856(0.145),5.612(0.205),4.115(0.139)
concrete,1030,8,10.458(0.518),5.965(0.393),,5.955(0.366),6.767(0.495),8.672(0.521),5.390(0.427),7.813(0.508),5.787(0.451)
winewhite,4898,12,0.750(0.029),0.689(0.021),,0.687(0.020),0.680(0.018),0.698(0.018),0.680(0.018),0.731(0.022),0.689(0.023)
naval,11934,12,0.006(0.000),0.000(0.000),,0.000(0.000),0.003(0.000),0.002(0.000),0.004(0.000),0.013(0.000),0.001(0.000)


In [6]:
# Classification metrics to tabulate. NOTE: this rebinds `fs`, `results` and
# `fields`, replacing the regression values computed earlier.
fs = ['test_loglik', 'test_acc']
results, fields = read_regression_classification(
    fs=fs,
    models_names=models_names,
    datasets=classification_datasets,
    task='classification',
)


linear abalone 20
variationally_sparse_gp abalone 10
deep_gp_doubly_stochastic abalone 10
svm abalone 10
knn abalone 10
gradient_boosting_machine abalone 10
adaboost abalone 10
mlp abalone 10
linear acute-inflammation 20
variationally_sparse_gp acute-inflammation 10
deep_gp_doubly_stochastic acute-inflammation 10
svm acute-inflammation 10
knn acute-inflammation 10
gradient_boosting_machine acute-inflammation 10
adaboost acute-inflammation 10
mlp acute-inflammation 10
linear acute-nephritis 20
variationally_sparse_gp acute-nephritis 10
deep_gp_doubly_stochastic acute-nephritis 10
svm acute-nephritis 10
knn acute-nephritis 10
gradient_boosting_machine acute-nephritis 10
adaboost acute-nephritis 10
mlp acute-nephritis 10
linear adult 20
svm adult 10
knn adult 10
gradient_boosting_machine adult 10
adaboost adult 10
mlp adult 10
linear annealing 20
variationally_sparse_gp annealing 10
deep_gp_doubly_stochastic annealing 10
svm annealing 10
knn annealing 10
gradient_boosting_machine annealin

adaboost heart-hungarian 10
mlp heart-hungarian 10
linear heart-switzerland 20
variationally_sparse_gp heart-switzerland 10
deep_gp_doubly_stochastic heart-switzerland 10
svm heart-switzerland 10
knn heart-switzerland 10
gradient_boosting_machine heart-switzerland 10
adaboost heart-switzerland 10
mlp heart-switzerland 10
linear heart-va 20
variationally_sparse_gp heart-va 10
deep_gp_doubly_stochastic heart-va 10
svm heart-va 10
knn heart-va 10
gradient_boosting_machine heart-va 10
adaboost heart-va 10
mlp heart-va 10
linear hepatitis 20
variationally_sparse_gp hepatitis 10
deep_gp_doubly_stochastic hepatitis 10
svm hepatitis 10
knn hepatitis 10
gradient_boosting_machine hepatitis 10
adaboost hepatitis 10
mlp hepatitis 10
linear hill-valley 20
variationally_sparse_gp hill-valley 10
deep_gp_doubly_stochastic hill-valley 4
svm hill-valley 10
knn hill-valley 10
gradient_boosting_machine hill-valley 10
adaboost hill-valley 10
mlp hill-valley 10
linear horse-colic 20
variationally_sparse_gp 

knn pittsburg-bridges-T-OR-D 10
gradient_boosting_machine pittsburg-bridges-T-OR-D 10
adaboost pittsburg-bridges-T-OR-D 10
mlp pittsburg-bridges-T-OR-D 10
linear pittsburg-bridges-TYPE 20
deep_gp_doubly_stochastic pittsburg-bridges-TYPE 10
svm pittsburg-bridges-TYPE 10
knn pittsburg-bridges-TYPE 10
gradient_boosting_machine pittsburg-bridges-TYPE 10
adaboost pittsburg-bridges-TYPE 10
mlp pittsburg-bridges-TYPE 10
linear planning 20
variationally_sparse_gp planning 10
deep_gp_doubly_stochastic planning 10
svm planning 10
knn planning 10
gradient_boosting_machine planning 9
adaboost planning 10
mlp planning 10
linear plant-margin 20
variationally_sparse_gp plant-margin 8
svm plant-margin 10
knn plant-margin 10
gradient_boosting_machine plant-margin 10
adaboost plant-margin 10
mlp plant-margin 10
linear plant-shape 20
variationally_sparse_gp plant-shape 8
deep_gp_doubly_stochastic plant-shape 2
svm plant-shape 10
knn plant-shape 10
gradient_boosting_machine plant-shape 10
adaboost plant-s

mlp yeast 10
linear zoo 20
deep_gp_doubly_stochastic zoo 10
svm zoo 10
knn zoo 10
gradient_boosting_machine zoo 10
adaboost zoo 10
mlp zoo 10


  r = func(a, **kwargs)


In [7]:
# Render one HTML table per classification metric.
# For LaTeX output, use
# pandas.DataFrame(results[key]['table'], columns=fields).to_latex() instead.
for title, key in (('test loglikelihood', 'test_loglik'),
                   ('test accuracy', 'test_acc')):
    print(title)
    frame = pandas.DataFrame(results[key]['table'], columns=fields)
    display(HTML(frame.to_html(index=False)))



test loglikelihood


dataset,N,D,K,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
abalone,4177.0,9.0,3.0,-0.760(0.039),-2.150(0.181),,-1.825(0.170),-0.746(0.040),-2.564(0.324),-0.654(0.030),-1.053(0.007),-0.708(0.038)
acute-infl,120.0,7.0,2.0,-0.049(0.015),-0.008(0.002),,-0.063(0.008),-0.018(0.001),-0.000(0.000),-0.000(0.000),-0.025(0.063),-0.030(0.008)
acute-neph,120.0,7.0,2.0,-0.031(0.008),-0.007(0.002),,-0.051(0.006),-0.019(0.001),-0.000(0.000),-0.000(0.000),-0.085(0.256),-0.017(0.004)
adult,48842.0,15.0,2.0,-0.342(0.005),,,,-0.359(0.005),-1.146(0.048),-0.290(0.004),-0.665(0.000),-0.315(0.005)
annealing,898.0,32.0,5.0,-0.365(0.061),-0.743(0.358),,-1.051(0.333),-0.342(0.086),-0.769(0.394),-0.104(0.062),-1.219(0.021),-0.279(0.052)
arrhythmia,452.0,263.0,13.0,-1.343(0.298),-1.162(0.285),,-1.270(0.225),-1.071(0.143),-5.885(1.417),-1.331(0.344),-2.568(0.360),-1.333(0.287)
audiology-,196.0,60.0,18.0,-1.072(0.229),-0.914(0.338),,-1.140(0.311),-1.453(0.216),-4.130(1.691),-0.979(0.378),-2.904(0.326),-0.867(0.215)
balance-sc,625.0,5.0,3.0,-0.363(0.107),-0.039(0.046),,-0.018(0.014),-0.220(0.074),-2.087(0.818),-0.397(0.119),-0.994(0.009),-0.139(0.050)
balloons,16.0,5.0,2.0,-0.708(0.355),,,-0.615(0.119),-0.617(0.139),-0.576(0.161),-1.630(2.162),-2.676(2.902),-0.652(0.463)
bank,4521.0,17.0,2.0,-0.271(0.028),-0.252(0.024),,-0.253(0.024),-0.286(0.029),-1.143(0.224),-0.235(0.020),-0.646(0.002),-0.281(0.036)


test accuracy


dataset,N,D,K,lin,SVGP,SVGP_mb,DGP,svm,knn,gbm,ab,mlp
abalone,4177.0,9.0,3.0,0.636(0.029),0.664(0.024),,0.669(0.026),0.661(0.024),0.633(0.029),0.696(0.022),0.688(0.024),0.668(0.028)
acute-infl,120.0,7.0,2.0,1.000(0.000),1.000(0.000),,1.000(0.000),1.000(0.000),1.000(0.000),1.000(0.000),0.983(0.050),1.000(0.000)
acute-neph,120.0,7.0,2.0,1.000(0.000),1.000(0.000),,1.000(0.000),1.000(0.000),1.000(0.000),1.000(0.000),0.992(0.025),1.000(0.000)
adult,48842.0,15.0,2.0,0.843(0.005),,,,0.849(0.004),0.830(0.003),0.868(0.004),0.859(0.005),0.854(0.004)
annealing,898.0,32.0,5.0,0.848(0.041),0.892(0.046),,0.862(0.043),0.876(0.044),0.880(0.041),0.967(0.017),0.870(0.036),0.890(0.047)
arrhythmia,452.0,263.0,13.0,0.707(0.092),0.761(0.056),,0.761(0.048),0.674(0.051),0.615(0.048),0.754(0.061),0.613(0.070),0.678(0.056)
audiology-,196.0,60.0,18.0,0.795(0.096),0.755(0.069),,0.730(0.081),0.630(0.105),0.535(0.081),0.810(0.073),0.265(0.125),0.705(0.072)
balance-sc,625.0,5.0,3.0,0.863(0.048),0.986(0.019),,0.997(0.006),0.916(0.036),0.835(0.062),0.870(0.049),0.908(0.042),0.970(0.026)
balloons,16.0,5.0,2.0,0.500(0.316),,,0.600(0.300),0.600(0.300),0.600(0.300),0.650(0.450),0.400(0.300),0.550(0.350)
bank,4521.0,17.0,2.0,0.892(0.018),0.892(0.015),,0.891(0.017),0.892(0.013),0.890(0.014),0.900(0.011),0.892(0.015),0.894(0.013)


In [8]:
# fields = ['dataset', 'N', 'D']

                
# colours = ['C{}'.format(i) for i in range(10)]

# fields = fields + [m[1] for m in models_names]
# results = {f:[] for f in fields}


# for dataset in regression_datasets:
    
#     fig, axs = plt.subplots(1, 2, figsize=(10, 5))

#     results['dataset'].append(dataset)
#     results['N'].append(ALL_REGRESSION_DATATSETS[dataset].N)
#     results['D'].append(ALL_REGRESSION_DATATSETS[dataset].D)

#     for (model, name), c in zip(models_names, colours):
#         with Database('../results/results.db') as db:
#             d = {'model':model, 'dataset':dataset}

#             res = db.read('active_learning_continuous', ['total_loglik', 'total_rmse'], d) 
#         if len(res)>0:
#             test_ll = res[0][0]
#             test_acc = res[0][1]

#             axs[0].plot(test_ll, label=model, color=c)
#             axs[1].plot(test_acc, label=model, color=c)
#     axs[0].set_ylim(-10, 10)
#     plt.title('{} {} {}'.format(dataset,
#                                    ALL_REGRESSION_DATATSETS[dataset].N,
#                                    ALL_REGRESSION_DATATSETS[dataset].D))
#     plt.legend()
#     plt.show()


In [9]:

# fields = ['dataset', 'N', 'D', 'K']

# models_names = [['linear', 'lin'],
#                 ['variationally_sparse_gp', 'SVGP'],
#                 ['deep_gp_doubly_stochastic','DGP'],
#                 ['svm', 'svm'],
#                 ['knn', 'knn'],
#                 ['naive_bayes', 'nb'],
#                 ['decision_tree', 'dt'],
#                 ['random_forest', 'rf'],
#                 ['gradient_boosting_machine', 'gbm'],
#                 ['adaboost', 'ab'],
#                 ['mlp', 'mlp'],
#                 ]
                
# colours = ['C{}'.format(i) for i in range(10)]

# fields = fields + [m[1] for m in models_names]
# results = {f:[] for f in fields}


# for dataset in classification_datasets[:4]:  # don't show them all...
    
#     fig, axs = plt.subplots(1, 2, figsize=(10, 5))

#     results['dataset'].append(dataset)
#     results['N'].append(ALL_CLASSIFICATION_DATATSETS[dataset].N)
#     results['D'].append(ALL_CLASSIFICATION_DATATSETS[dataset].D)
#     results['K'].append(ALL_CLASSIFICATION_DATATSETS[dataset].K)

#     for (model, name), c in zip(models_names, colours):
#         with Database('../results/results.db') as db:
#             d = {'model':model, 'dataset':dataset}

#             res = db.read('active_learning_discrete', ['test_loglik', 'total_acc'], d) 
#         if len(res)>0:
#             test_ll = res[0][0]
#             test_acc = res[0][1]

#             axs[0].plot(test_ll, label=model, color=c)
#             axs[1].plot(test_acc, label=model, color=c)

#     plt.title('{} {} {} {}'.format(dataset,
#                                    ALL_CLASSIFICATION_DATATSETS[dataset].N,
#                                    ALL_CLASSIFICATION_DATATSETS[dataset].D,
#                                    ALL_CLASSIFICATION_DATATSETS[dataset].K))
#     plt.legend()
#     plt.show()