In [1]:
import pickle
import glob, os
import fnmatch
import numpy
import matplotlib
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import pyplot as plt

In [2]:
def average_trials(min_results):
    """Aggregate per-seed result curves into mean curves with asymmetric error offsets.

    Parameters
    ----------
    min_results : dict
        ``{method: {seed: rows}}`` where ``rows[i]`` is a numeric sequence for
        the i-th checkpoint of that seed's run (e.g. ``[4, 0.2088, 0.2111]``).
        Column 0 is carried through as the x-value; all columns are aggregated
        across seeds while ignoring NaNs.

    Returns
    -------
    dict
        ``{method: [[lower, mean, upper], ...]}`` with one entry per checkpoint,
        up to the length of that method's shortest seed run (longer runs are
        truncated). ``mean`` is the NaN-ignoring mean over seeds. For columns
        1 onwards, ``lower`` holds ``mean - min`` and ``upper`` holds
        ``max - mean`` (distances, not absolute values); column 0 of
        ``lower``/``upper`` keeps the raw min/max. A method with no seeds
        maps to an empty list.
    """
    mean_data = {}
    for method, runs_by_seed in min_results.items():
        runs = list(runs_by_seed.values())
        mean_data[method] = []
        if not runs:
            # min() of an empty sequence would raise; there is nothing to average.
            continue
        n_points = min(len(run) for run in runs)
        for i in range(n_points):
            # Stack the i-th row of every seed once, as float, so that integer-only
            # rows do not truncate the fractional offsets assigned below.
            stacked = numpy.asarray([run[i] for run in runs], dtype=float)
            avg_r = numpy.nanmean(stacked, axis=0)
            max_r = numpy.nanmax(stacked, axis=0)
            min_r = numpy.nanmin(stacked, axis=0)
            # Convert metric columns to distances from the mean (column 0 untouched).
            max_r[1:] = max_r[1:] - avg_r[1:]
            min_r[1:] = avg_r[1:] - min_r[1:]
            mean_data[method].append([min_r, avg_r, max_r])
    return mean_data

In [10]:
# Matplotlib colour code used for each search method's curve.
colors = dict(
    hyperband_constant='y',
    smac='r',
    hyperopt='b',
    random='g',
    spearmint='c',
    random_2x='g',
    smac_early='r',
)
# Per-method integer settings. NOTE(review): not read by any cell shown here;
# presumably the number of seeds/trials run for each method -- confirm.
seeds = dict(
    hyperband_constant=3,
    random=3,
    smac=100,
    hyperopt=100,
    spearmint=100,
)
# Legend text for each method; 'hyperband_constant' is shown as 'hyperband'
# and 'hyperopt' as 'TPE'.
labels = dict(
    hyperband_constant='hyperband',
    smac='smac',
    hyperopt='TPE',
    random='random',
    spearmint='spearmint',
    random_2x='random_2x',
    smac_early='smac_early',
)
def results_plot(mean_data, index, error_bar, exclude=('hyperband', 'random_2x')):
    """Draw averaged learning curves for several search methods on the current figure.

    Parameters
    ----------
    mean_data : dict
        Output of ``average_trials``: ``{method: [[lower, mean, upper], ...]}``.
        The x-value is taken from ``lower[0]``, the y-value from ``mean[index]``,
        and the error-bar extents from ``lower[index]`` / ``upper[index]``.
        This dict is NOT modified.
    index : int
        Column to plot. ``2`` is labelled 'Average Test Error'; any other value
        is labelled 'Average Val Error'.
    error_bar : bool
        When truthy, draw asymmetric min/max error bars; otherwise plain lines.
    exclude : iterable of str, optional
        Methods to skip. The default reproduces the methods this notebook never
        shows ('hyperband' and 'random_2x').

    Side effects
    ------------
    Sets the global ``font.size`` rcParam to 16 and draws via the pyplot
    state machine (current axes), including x-limits and axis labels.
    """
    matplotlib.rcParams.update({'font.size': 16})
    line_width = 2
    edge_width = 2
    shift = 0.3
    # Decided up front so the y-label exists even if nothing ends up plotted.
    axis_label = 'Average Test Error' if index == 2 else 'Average Val Error'
    # method -> (marker kwargs shared by plot/errorbar, errorevery for errorbar).
    # Methods not listed are drawn as plain lines whose error bars are shifted
    # right by a growing offset so bars from different methods do not overlap.
    marker_styles = {
        'hyperband': ({'marker': 'x', 'markeredgewidth': edge_width}, 1),
        'hyperband_constant': ({'marker': 'x', 'markeredgewidth': edge_width}, 1),
        'random_2x': ({'marker': '+', 'markeredgewidth': edge_width,
                       'markevery': 4}, 8),
        'smac_early': ({'marker': '^', 'markeredgewidth': edge_width,
                        'markeredgecolor': 'r', 'markevery': 4}, 4),
    }
    skipped = set(exclude)

    for method, rows in mean_data.items():
        if method in skipped:
            continue
        xs = [row[0][0] for row in rows]
        ys = [row[1][index] for row in rows]
        style = {'label': labels.get(method, method), 'linewidth': line_width}
        if method in colors:
            # Unknown methods fall back to matplotlib's default colour cycle.
            style['color'] = colors[method]

        if method in marker_styles:
            marker_kwargs, errorevery = marker_styles[method]
            style.update(marker_kwargs)
            err_xs = xs
        else:
            errorevery = 4
            err_xs = [x + shift for x in xs]
            shift += 0.3

        if error_bar:
            lower = [row[0][index] for row in rows]
            upper = [row[2][index] for row in rows]
            plt.errorbar(err_xs, ys, yerr=[lower, upper], errorevery=errorevery,
                         elinewidth=0.5, capthick=0.5, **style)
        else:
            plt.plot(xs, ys, **style)

    plt.xlim([0, 50])
    plt.xlabel('Multiple of Max Iter Used')
    plt.ylabel(axis_label)


In [12]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    Opens in binary mode (required by pickle) and always closes the handle.
    pickle can execute arbitrary code while loading, so only use this on
    trusted local files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


cifar_data = _load_pickle('./cifar10/cifar10_data.pkl')
mrbi_data = _load_pickle('./mrbi/mrbi_data.pkl')
svhn_data = _load_pickle('./svhn/svhn_data.pkl')
# Hand-entered rows for SVHN spearmint run 8700 at x = 55 and 56.
# NOTE(review): presumably pads a run that ended early so that average_trials,
# which truncates each method to its shortest run, keeps these checkpoints -- confirm.
svhn_data['spearmint'][8700].append([55, 0.044499999999999998, 0.040231000000000003])
svhn_data['spearmint'][8700].append([56, 0.044499999999999998, 0.040231000000000003])
cifar_mean = average_trials(cifar_data)
mrbi_mean = average_trials(mrbi_data)
svhn_mean = average_trials(svhn_data)

# One PDF page per dataset, each plotting column 2 ('Average Test Error').
# Entries: (averaged curves, y-axis limits, whether to draw the legend).
panels = [
    (cifar_mean, (0.175, 0.325), False),  # CIFAR-10
    (mrbi_mean, (0.22, 0.37), True),      # MRBI (only page with a legend)
    (svhn_mean, (0.025, 0.175), False),   # SVHN
]
# The context manager closes the PDF even if a plotting call raises.
with PdfPages('./error_avg_constant.pdf') as pdf:
    for curves, y_limits, show_legend in panels:
        plt.figure(figsize=(6.5, 5.5))
        results_plot(curves, 2, False)
        if show_legend:
            plt.legend(ncol=2, columnspacing=0.2, fancybox=True, framealpha=0.75)
        plt.ylim(*y_limits)
        pdf.savefig()
        plt.close()



In [2]:
# Load CIFAR-10 results from the alternate pickle (cifar_data.pkl, not the
# cifar10_data.pkl file loaded earlier) and rebind `cifar_data`.
# Binary mode is required by pickle, and `with` closes the handle. pickle can
# execute arbitrary code, so only load trusted local files.
with open('./cifar10/cifar_data.pkl', 'rb') as results_file:
    cifar_data = pickle.load(results_file)

In [3]:
# Inspect the raw results: method -> run key (seed) -> list of [x, col1, col2] rows.
# NOTE(review): this renders the entire nested dict; consider displaying a
# summary (e.g. methods and number of runs) instead.
cifar_data

{'hyperband': {300: [[4, 0.20879999876022337, 0.21110000550746921],
   [8, 0.20260000109672549, 0.20330000281333926],
   [12, 0.20260000109672549, 0.20330000281333926],
   [16, 0.20260000109672549, 0.20330000281333926],
   [24, 0.19360000133514399, 0.19110000371932978],
   [32, 0.19360000133514399, 0.19110000371932978],
   [40, 0.19360000133514399, 0.19110000371932978],
   [48, 0.19360000133514399, 0.19110000371932978],
   [56, 0.19360000133514399, 0.19110000371932978]],
  400: [[4, 0.24529999792575841, 0.24640000164508824],
   [8, 0.24529999792575841, 0.24640000164508824],
   [12, 0.24529999792575841, 0.24640000164508824],
   [16, 0.24529999792575841, 0.24640000164508824],
   [24, 0.21819999754428865, 0.214500002861023],
   [32, 0.20849999845027922, 0.21220000445842746],
   [40, 0.20849999845027922, 0.21220000445842746],
   [48, 0.20849999845027922, 0.21220000445842746],
   [56, 0.20820000112056736, 0.2039000058174133]],
  500: [[4, 0.21640000462532039, 0.22050000429153438],
   [8, 0.