In [1]:
import pickle, os
import numpy as np
from matplotlib import pyplot as plt
import glob

In [None]:
# List the experiment ids (subdirectories) that `display_experiment` below can read from.
!ls ../gcloud_data/experiment_running

In [None]:
!pip install torch torchvision

In [3]:
def display_experiment(experiment_id, keys_to_filter, keys_to_keep, start_idx=0,
                       base_dir='../gcloud_data/experiment_running'):
    """Plot test error against SGD iteration budget for one experiment directory.

    Globs ``{base_dir}/{experiment_id}/*/*.p``. It skips pickles whose path contains any
    of ``keys_to_filter`` or lacks any of ``keys_to_keep``, and loads the rest. For every
    entry ``res`` of each loaded ``manager.results`` it draws one error-bar line:
      - x: mean ± std of ``res.actual_costs_all`` over axis 1;
      - y: ``res.vals_avg`` ± ``res.vals_std``.

    Args:
        experiment_id: Directory name under ``base_dir``.
        keys_to_filter: Substrings; a pickle path containing any of them is skipped.
        keys_to_keep: Substrings; a pickle path must contain all of them to be loaded
            (``[""]`` keeps everything, since "" is a substring of every path).
        start_idx: Number of leading budget points to drop from every line.
        base_dir: Root directory holding the experiment directories.

    Raises:
        ValueError: if a loaded pickle path matches none of the known method names.
    """
    # (substring of the pickle path, legend label); first match wins, in this order.
    file_labels = (
        ("coordinate-halving", "TreeCH"),
        ("delaunay", "TreeDP"),
        ("alpha-star", "Genie"),
        ("tree-results", "Tree results"),
        ("uniform_constant_1", "Uniform1x"),
        ("uniform_constant_2", "Uniform2x"),
    )
    fmts = ['rs-', 'bo-', 'k^-', 'gx-', 'cd-', 'mh-']
    fig, ax = plt.subplots()
    counter = 0
    for exper in glob.glob('{}/{}/*'.format(base_dir, experiment_id)):
        for pfile in glob.glob(exper + '/*.p'):
            if any(key in pfile for key in keys_to_filter):
                continue
            if not all(key in pfile for key in keys_to_keep):
                continue
            print("Loading file:", pfile)
            # NOTE: pickle.load can execute arbitrary code; only load result files you produced.
            with open(pfile, 'rb') as f:
                manager = pickle.load(f)
            label = next((lbl for key, lbl in file_labels if key in pfile), None)
            if label is None:
                raise ValueError("Cannot infer method label from file name: {}".format(pfile))
            for res in manager.results:
                avg_costs = np.average(res.actual_costs_all, axis=1)
                std_costs = np.std(res.actual_costs_all, axis=1)
                ax.errorbar(avg_costs[start_idx:], res.vals_avg[start_idx:],
                            xerr=std_costs[start_idx:], yerr=res.vals_std[start_idx:],
                            # Cycle styles so more than len(fmts) lines no longer raise IndexError.
                            fmt=fmts[counter % len(fmts)], label=label)
                counter += 1
    ax.set_xlabel("SGD Iteration budget")
    ax.set_ylabel("Test error")
    ax.set_title("Experiment {}".format(experiment_id))
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [None]:
# Replace the "<>" placeholder with one of the directory names listed by the `!ls` cell.
experiment_name = "<>"
# keys_to_keep=[""] keeps every file: the empty string is a substring of any path.
display_experiment(
    experiment_name,
    keys_to_filter=['ACTUAL_MIXTURES'],
    keys_to_keep=[""],
)

In [None]:
def get_experiments(experiment_ids, keys_to_filter, keys_to_keep, start_idx=0,
                    base_dir='../gcloud_data'):
    """Collect raw per-method results from several experiment directories.

    For every id in ``experiment_ids`` this globs ``{base_dir}/{id}/*/*.p``. It skips
    pickles whose path contains any of ``keys_to_filter`` or lacks any of
    ``keys_to_keep``, loads the rest, and files each entry of ``manager.results`` under
    a method label derived from the pickle path.

    Args:
        experiment_ids: Iterable of directory names under ``base_dir``.
        keys_to_filter: Substrings; a pickle path containing any of them is skipped.
        keys_to_keep: Substrings; a pickle path must contain all of them to be loaded
            (``[""]`` keeps everything).
        start_idx: Unused; kept so existing callers keep working.
        base_dir: Root directory holding the experiment directories.

    Returns:
        dict mapping label -> {experiment_id -> {"actual_costs_all", "vals_all", "fmt"}}.
        The labels "TreeCH", "TreeDP", "Tree results", "Uniform1x" and "Uniform2x" are
        always present (possibly empty). "Genie" is added only when an alpha-star
        pickle is loaded (previously this raised KeyError).

    Raises:
        ValueError: if a loaded pickle path matches none of the known method names.
    """
    import warnings

    # (substring of the pickle path, label); first match wins, in this order.
    file_labels = (
        ("coordinate-halving", "TreeCH"),
        ("delaunay", "TreeDP"),
        ("alpha-star", "Genie"),
        ("tree-results", "Tree results"),
        ("uniform_constant_1", "Uniform1x"),
        ("uniform_constant_2", "Uniform2x"),
    )
    fmts = ['rs-', 'bo-', 'k^-', 'gx-', 'cd-', 'mh-']
    all_exp_res = {
        "TreeCH": {},
        "TreeDP": {},
        "Tree results": {},
        "Uniform1x": {},
        "Uniform2x": {},
    }
    for experiment_id in experiment_ids:
        counter = 0  # style index restarts for each experiment id
        for exper in glob.glob('{}/{}/*'.format(base_dir, experiment_id)):
            for pfile in glob.glob(exper + '/*.p'):
                if any(key in pfile for key in keys_to_filter):
                    continue
                if not all(key in pfile for key in keys_to_keep):
                    continue
                print("Loading file:", pfile)
                # NOTE: pickle.load can execute arbitrary code; only load result files you produced.
                with open(pfile, 'rb') as f:
                    manager = pickle.load(f)
                label = next((lbl for key, lbl in file_labels if key in pfile), None)
                if label is None:
                    raise ValueError("Cannot infer method label from file name: {}".format(pfile))
                for res in manager.results:
                    # setdefault: labels outside the pre-seeded five (e.g. "Genie") get a bucket on demand.
                    bucket = all_exp_res.setdefault(label, {})
                    if experiment_id in bucket:
                        # Only one entry per (label, experiment_id) is kept; make the loss visible.
                        warnings.warn("Overwriting earlier {!r} result for experiment {!r} with {}".format(
                            label, experiment_id, pfile))
                    bucket[experiment_id] = {
                        "actual_costs_all": res.actual_costs_all,
                        "vals_all": res.vals_all,
                        # Cycle styles so more than len(fmts) results no longer raise IndexError.
                        "fmt": fmts[counter % len(fmts)],
                    }
                    counter += 1
    return all_exp_res

In [None]:
# Directory names under ../gcloud_data whose results are pooled below; replace the "<>" placeholder.
experiment_names = ['<>']
# keys_to_keep=[""] keeps every file: the empty string is a substring of any path.
all_exp_res = get_experiments(
    experiment_names,
    keys_to_filter=['ACTUAL_MIXTURES'],
    keys_to_keep=[""],
)

In [None]:
# Pool every experiment id's runs per method, then plot mean ± std over all pooled runs.
start_idx = 0  # number of leading budget points to drop from each line
fig, ax = plt.subplots()
for label, info in all_exp_res.items():
    print(label)
    if not info:
        # get_experiments pre-seeds every label, so a method with no matching pickles
        # arrives here empty; averaging would fail on None, so skip it.
        print("  no results, skipping")
        continue
    runs = list(info.values())
    # For 2-D inputs hstack joins along axis 1, the same axis averaged over below, so this
    # pools the columns of every experiment id (assumes (n_budgets, n_runs) arrays — TODO confirm).
    costs_all = np.hstack([np.array(data['actual_costs_all']) for data in runs])
    vals_all = np.hstack([np.array(data['vals_all']) for data in runs])
    print(costs_all.shape, vals_all.shape)
    # The line style comes from the last experiment id for this label.
    fmt = runs[-1]['fmt']
    avg_costs = np.average(costs_all, axis=1)
    std_costs = np.std(costs_all, axis=1)
    avg_vals = np.average(vals_all, axis=1)
    std_vals = np.std(vals_all, axis=1)
    print(avg_costs)
    ax.errorbar(avg_costs[start_idx:], avg_vals[start_idx:],
                xerr=std_costs[start_idx:], yerr=std_vals[start_idx:],
                fmt=fmt, label=label)
ax.set_xlabel("SGD Iteration budget")
ax.set_ylabel("Test error")
ax.set_title("Test error vs. SGD budget, pooled across experiment ids")
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
