In [1]:
import pickle, os
import numpy as np
from matplotlib import pyplot as plt
import glob

In [None]:
# Shell escape: list the contents of ../gcloud_data (the directory get_experiments globs under).
!ls ../gcloud_data

In [None]:
!pip install torch torchvision

In [None]:
import torch

In [None]:
def _label_for_file(pfile, num_custom):
    """Return the result label for a pickle path, or None if the path is not recognized.

    The label is decided purely by substrings of ``pfile``. Rules are tried in
    order, so the more specific "__True" variants must come before the shorter
    substrings they contain.
    """
    substring_labels = (
        ("coordinate-halving__True", "Mix&MatchCH+0.1Step"),
        ("coordinate-halving", "Mix&MatchCH"),
        ("delaunay-partitioning__True", "Mix&MatchDP+0.1Step"),
        ("delaunay", "Mix&MatchDP"),
        ("alpha-star", "Genie"),
        ("tree-results", "Mix&Match+1.0Step"),
        ("uniform_constant", "Uniform"),
    )
    for needle, label in substring_labels:
        if needle in pfile:
            return label
    if "constant-mixture_constant" in pfile:
        # "Only<i>" corresponds to the one-hot pattern custom_0,..,1,..,0 with the 1 at position i.
        for i in range(num_custom):
            one_hot = ['0'] * num_custom
            one_hot[i] = '1'
            if "custom_{}".format(','.join(one_hot)) in pfile:
                return "Only" + str(i)
    # No rule matched (including a constant-mixture file with no matching one-hot id).
    return None


def _node_accuracy(node):
    """Return the fraction of ``node.mf_fn.test_dataset`` that ``node.final_model`` predicts correctly."""
    correct = 0
    with torch.no_grad():
        for sample_batch, label_batch in node.mf_fn.test_dl:
            # Flatten every sample to a vector before the forward pass.
            sample_view = sample_batch.view(sample_batch.shape[0], -1)
            _, preds = node.final_model(sample_view).max(1)
            correct += (preds == label_batch).sum().item()
    return correct / len(node.mf_fn.test_dataset)


def get_experiments(experiment_ids, keys_to_filter, keys_to_keep, num_custom=0):
    """Load pickled experiment results and group them by method label.

    For every id in ``experiment_ids`` this scans
    ``../gcloud_data/<experiment_id>/*/*.p``, unpickles each accepted file and
    stores its results under a label derived from the file path.

    Args:
        experiment_ids: iterable of experiment ids (directory names under ../gcloud_data).
        keys_to_filter: a file is skipped if its path contains ANY of these substrings.
        keys_to_keep: a file is skipped unless its path contains ALL of these substrings.
        num_custom: number of ``Only<i>`` labels to create and to recognize in
            ``constant-mixture_constant`` file names.

    Returns:
        dict mapping label -> {experiment_id -> {"actual_costs_all", "vals_all",
        "accuracies_all"}}. A later result with the same (label, experiment_id)
        overwrites an earlier one. Labels with no matching file map to ``{}``.
    """
    all_exp_res = {
        "Mix&MatchCH": {},
        "Mix&MatchDP": {},
        "Mix&Match+1.0Step": {},
        "Genie": {},
        "Uniform": {},
        "Mix&MatchCH+0.1Step": {},
        "Mix&MatchDP+0.1Step": {}
    }
    for i in range(num_custom):
        all_exp_res['Only' + str(i)] = {}

    for experiment_id in experiment_ids:
        for exper in glob.glob('../gcloud_data/{}/*'.format(experiment_id)):
            for pfile in glob.glob(exper + '/*.p'):
                if any(filterkey in pfile for filterkey in keys_to_filter):
                    continue
                if any(keepkey not in pfile for keepkey in keys_to_keep):
                    continue

                # The label depends only on the path, so resolve it once per file and
                # skip unrecognized files before paying for the unpickle. Previously an
                # unmatched constant-mixture file left `label` unset (NameError) or
                # stale from the previous file (results stored under the wrong method).
                label = _label_for_file(pfile, num_custom)
                if label is None:
                    print("Skipping:", pfile)
                    continue

                print("Loading file:", pfile)
                # SECURITY: pickle.load can execute arbitrary code; only load files you trust.
                with open(pfile, 'rb') as f:
                    manager = pickle.load(f)

                for res in manager.results:
                    accuracies_all = [
                        [_node_accuracy(node) for node in nodes]
                        for nodes in res.best_sols_all
                    ]
                    all_exp_res[label][experiment_id] = {
                        "actual_costs_all": res.actual_costs_all,
                        "vals_all": res.vals_all,
                        "accuracies_all": accuracies_all
                    }
    return all_exp_res

In [None]:
def plot_experiments(all_exp_res, start_idx=0, keys_to_exclude=(), stop_idx_tree=None, plot_accuracy=False):
    """Draw one error-bar curve per method: mean value against mean cost.

    For each non-empty label in ``all_exp_res`` the per-experiment arrays are
    stacked side by side (``np.hstack``), then mean and standard deviation are
    taken along axis 1 and drawn with ``plt.errorbar``.

    Args:
        all_exp_res: dict label -> {experiment_id -> data}, where data has
            "actual_costs_all" and either "vals_all" or "accuracies_all"
            (whichever is being plotted).
        start_idx: first index (along axis 0) to plot for every curve.
        keys_to_exclude: labels containing any of these substrings are not plotted.
        stop_idx_tree: end index applied only to the "Mix&Match+1.0Step" curve.
        plot_accuracy: plot "accuracies_all" instead of "vals_all".

    Raises:
        KeyError: if a plotted label has no style entry (i.e. it is neither a
            built-in method label nor of the form ``Only<i>``).
    """
    fmt = {
        "Uniform": {
            "fmt": "^-",
            "color": "xkcd:black"
        },
        "Mix&Match+1.0Step": {
            "fmt": "o-",
            "color": "xkcd:sky blue"
        },
        "Genie": {
            "fmt": "s-",
            "color": "xkcd:coral"
        },
        "Mix&MatchCH": {
            "fmt": "x-",
            "color": "xkcd:lavender"
        },
        "Mix&MatchCH+0.1Step": {
            "fmt": "x-",
            "color": "xkcd:olive"
        },
        "Mix&MatchDP": {
            "fmt": "d-",
            "color": "xkcd:plum"
        },
        "Mix&MatchDP+0.1Step": {
            "fmt": "d-",
            "color": "xkcd:sienna"
        }
    }
    # Style every "Only<i>" label that is actually present in the results. This
    # replaces a read of a notebook-level `num_custom` global, which made the
    # function fail on a fresh kernel or when the global was out of sync.
    for label in all_exp_res:
        suffix = label[len('Only'):]
        if label.startswith('Only') and suffix.isdigit():
            fmt[label] = {
                "fmt": ".-",
                "color": "C{}".format(int(suffix) + 1)
            }

    value_key = 'accuracies_all' if plot_accuracy else 'vals_all'

    for label, info in all_exp_res.items():
        if not info:
            continue
        if any(key in label for key in keys_to_exclude):
            continue
        print(label)
        fmt_l = fmt[label]

        # Stack the experiments side by side; only the series being plotted is needed.
        runs = list(info.values())
        costs_all = np.hstack([np.array(run['actual_costs_all']) for run in runs])
        values_all = np.hstack([np.array(run[value_key]) for run in runs])
        print(costs_all.shape, values_all.shape)

        is_tree = label == "Mix&Match+1.0Step"
        if is_tree:
            # NOTE(review): costs of this method are doubled; the reason is not visible here.
            costs_all = costs_all * 2

        avg_costs = np.average(costs_all, axis=1)
        std_costs = np.std(costs_all, axis=1)
        avg_vals = np.average(values_all, axis=1)
        std_vals = np.std(values_all, axis=1)
        print(avg_costs)

        # Only the tree curve honours stop_idx_tree; every other curve runs to the end.
        window = slice(start_idx, stop_idx_tree if is_tree else None)
        plt.errorbar(
            avg_costs[window],
            avg_vals[window],
            xerr=std_costs[window],
            yerr=std_vals[window],
            color=fmt_l['color'],
            fmt=fmt_l['fmt'],
            label=label,
        )

    plt.xlabel("SGD Iteration budget")
    if plot_accuracy:
        plt.ylabel("Test accuracy")
    else:
        plt.ylabel("Test error")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [None]:
# Shell escape: list ../gcloud_data again to pick the experiment ids used in the next cell.
!ls ../gcloud_data

In [None]:
# Experiment ids to load; each one names a sub-directory of ../gcloud_data.
# NOTE(review): '<>' looks like a placeholder — replace with real ids before running.
experiment_names = ['<>']
# Number of "Only<i>" result buckets get_experiments should create.
num_custom = 4

# Drop any file whose path mentions ACTUAL_MIXTURES; the empty keep-key matches every path.
all_exp_res = get_experiments(
    experiment_names,
    keys_to_filter=['ACTUAL_MIXTURES'],
    keys_to_keep=[""],
    num_custom=num_custom,
)


In [None]:
# Plot every loaded method, starting from the first index and excluding no labels.
start_idx = 0
keys_to_exclude = []
plot_experiments(
    all_exp_res,
    start_idx=start_idx,
    keys_to_exclude=keys_to_exclude,
)

