In [1]:
import pickle, os
import numpy as np
from matplotlib import pyplot as plt
import glob

In [None]:
# List the experiment folders under ../gcloud_data; get_experiments reads ../gcloud_data/<experiment_id>/*.
!ls ../gcloud_data

In [None]:
!pip install torch torchvision

In [None]:
import torch

In [None]:
# Ordered (path substring, label) rules for result pickles; the first match wins.
# The "__True" variants must come before their plain prefixes because they contain them.
_PICKLE_LABEL_RULES = (
    ("coordinate-halving__True", "Mix&MatchCH+0.1Step"),
    ("coordinate-halving", "Mix&MatchCH"),
    ("delaunay-partitioning__True", "Mix&MatchDP+0.1Step"),
    ("delaunay", "Mix&MatchDP"),
    ("alpha-star", "Genie"),
    ("validation", "Validation"),
    ("tree-results", "Mix&Match+1.0Step"),
    ("uniform_constant", "Uniform"),
)


def _label_for_pickle(pfile, num_custom):
    """Return the plot label for a result-pickle path, or None if it is not recognised.

    "constant-mixture_constant" files are labelled "Only<i>" when their path contains
    the one-hot id "custom_0,..,1,..,0" (a 1 at position i, i < num_custom); any other
    custom mixture returns None rather than reusing an unrelated label.
    """
    for substring, label in _PICKLE_LABEL_RULES:
        if substring in pfile:
            return label
    if "constant-mixture_constant" in pfile:
        for i in range(num_custom):
            one_hot = ['0'] * num_custom
            one_hot[i] = '1'
            if "custom_{}".format(','.join(one_hot)) in pfile:
                return "Only" + str(i)
    return None


def _confusion_matrix(model, dataloader, num_classes):
    """Count (predicted, true) class pairs of `model` over every batch in `dataloader`.

    Returns a (num_classes, num_classes) float array whose entry [i, j] is the number
    of samples predicted as class i (argmax of the model output) with label j.
    Predictions or labels outside range(num_classes) are not counted.
    """
    conf_mat = np.zeros((num_classes, num_classes))
    with torch.no_grad():
        for sample_batch, label_batch in dataloader:
            # Flatten each sample to a vector before the forward pass.
            sample_view = sample_batch.view(sample_batch.shape[0], -1)
            _, preds = model(sample_view).max(1)
            for cl_i in range(num_classes):
                pred_is_i = preds == cl_i
                for cl_j in range(num_classes):
                    conf_mat[cl_i, cl_j] += (pred_is_i & (label_batch == cl_j)).sum().item()
    return conf_mat


def get_experiments(experiment_ids, keys_to_filter, keys_to_keep, num_custom=0, record_alt_metrics=False, num_target_classes=0, data_dir='../gcloud_data'):
    """Load pickled experiment managers and group their results by method label.

    Every ``<data_dir>/<experiment_id>/*/*.p`` file is considered. A file is skipped
    if its path contains any string in ``keys_to_filter``, lacks any string in
    ``keys_to_keep``, or does not map to a known label (``_label_for_pickle``).

    Parameters
    ----------
    experiment_ids : iterable of str
        Sub-directories of ``data_dir`` to load.
    keys_to_filter, keys_to_keep : iterable of str
        Path-substring exclusion / inclusion filters ("" matches every path).
    num_custom : int
        Number of one-hot custom-mixture baselines; adds labels "Only0".."Only<n-1>".
    record_alt_metrics : bool
        Also build test and validation confusion matrices for every node by running
        ``node.final_model`` over ``node.mf_fn.test_dl`` / ``node.mf_fn.validation_dl``.
    num_target_classes : int
        Confusion-matrix size used when ``record_alt_metrics`` is set.
    data_dir : str
        Root directory containing the experiment folders.

    Returns
    -------
    dict
        label -> experiment_id -> {"actual_costs_all", "vals_all",
        "validation_vals_all", "conf_mats_all", "validation_conf_mats_all"}.
        ``validation_vals_all[k]`` holds ``node.value.item()`` for each node in
        ``res.best_sols_all[k]``; the two conf-mat lists stay empty unless
        ``record_alt_metrics`` is set.
    """
    all_exp_res = {
        "Mix&MatchCH": {},
        "Mix&MatchDP": {},
        "Mix&Match+1.0Step": {},
        "Genie": {},
        "Uniform": {},
        "Mix&MatchCH+0.1Step": {},
        "Mix&MatchDP+0.1Step": {},
        "Validation": {}
    }
    for i in range(num_custom):
        all_exp_res['Only' + str(i)] = {}

    for experiment_id in experiment_ids:
        for exper in glob.glob(os.path.join(data_dir, experiment_id, '*')):
            for pfile in glob.glob(os.path.join(exper, '*.p')):
                if any(key in pfile for key in keys_to_filter):
                    continue
                if not all(key in pfile for key in keys_to_keep):
                    continue
                # Label before loading so unrecognised files are never unpickled.
                label = _label_for_pickle(pfile, num_custom)
                if label is None:
                    print("Skipping:", pfile)
                    continue
                print("Loading file:", pfile)
                # NOTE: pickle.load can execute arbitrary code; only use on trusted experiment output.
                with open(pfile, 'rb') as f:
                    manager = pickle.load(f)
                for res in manager.results:
                    validation_vals_all = []
                    conf_mats_all = []
                    validation_conf_mats_all = []
                    for nodes in res.best_sols_all:
                        vals, conf_mats, validation_conf_mats = [], [], []
                        for node in nodes:
                            vals.append(node.value.item())
                            if record_alt_metrics:
                                # Test before validation per node, matching the original evaluation order.
                                conf_mats.append(_confusion_matrix(node.final_model, node.mf_fn.test_dl, num_target_classes))
                                validation_conf_mats.append(_confusion_matrix(node.final_model, node.mf_fn.validation_dl, num_target_classes))
                        validation_vals_all.append(vals)
                        if record_alt_metrics:
                            conf_mats_all.append(conf_mats)
                            validation_conf_mats_all.append(validation_conf_mats)
                    # NOTE(review): results (and files) sharing a label within one experiment
                    # overwrite each other here, so only the last one loaded survives — confirm intended.
                    all_exp_res[label][experiment_id] = {
                        "actual_costs_all": res.actual_costs_all,
                        "vals_all": res.vals_all,
                        "validation_vals_all": validation_vals_all,
                        "conf_mats_all": conf_mats_all,
                        "validation_conf_mats_all": validation_conf_mats_all
                    }
    return all_exp_res

In [1]:
# Label whose costs are doubled before plotting and whose curve stop_idx_tree truncates.
_TREE_LABEL = "Mix&Match+1.0Step"


def _safe_ratio(num, den):
    """Elementwise num / den with NaN results (0/0) replaced by 0, without RuntimeWarnings."""
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = np.divide(num, den)
    ratio[np.isnan(ratio)] = 0
    return ratio


def _alt_metric(conf_mats, alt_metric_to_plot, target_class_idx):
    """Reduce a 4-D stack of confusion matrices to one scalar per (row, repeat) cell.

    ``conf_mats`` must be 4-D, ``(n_points, n_repeats, C, C)``, with
    ``conf_mats[..., i, j]`` counting samples predicted as class i whose label is j
    (rows = predictions, columns = true labels). The metric is chosen by substring,
    checked in the order "F", "precision", "recall". If "avg" is in the name, the
    result is averaged over classes; otherwise class ``target_class_idx`` is selected.
    Returns an array of shape ``(n_points, n_repeats)``.
    """
    conf_mats = np.asarray(conf_mats)
    if conf_mats.ndim != 4:
        raise ValueError("No confusion matrices found; call get_experiments with record_alt_metrics=True")
    diag = np.diagonal(conf_mats, axis1=2, axis2=3)
    # precision_i = TP_i / #predicted i  -> sum over the true-label axis (3)
    # recall_i    = TP_i / #labelled i   -> sum over the predicted axis (2)
    if "F" in alt_metric_to_plot:
        precision = _safe_ratio(diag, conf_mats.sum(axis=3))
        recall = _safe_ratio(diag, conf_mats.sum(axis=2))
        per_class = 2 * _safe_ratio(precision * recall, precision + recall)
    elif "precision" in alt_metric_to_plot:
        per_class = _safe_ratio(diag, conf_mats.sum(axis=3))
    elif "recall" in alt_metric_to_plot:
        per_class = _safe_ratio(diag, conf_mats.sum(axis=2))
    else:
        raise ValueError("Unknown alt_metric_to_plot: {!r}".format(alt_metric_to_plot))
    if "avg" in alt_metric_to_plot:
        return np.average(per_class, axis=2)
    if target_class_idx is None:
        raise ValueError("target_class_idx is required unless alt_metric_to_plot contains 'avg'")
    return per_class[:, :, target_class_idx]


def plot_experiments(all_exp_res, start_idx=0, num_custom=0, keys_to_exclude=None, stop_idx_tree=None, alt_metric_to_plot=None, target_class_idx=None, plot_validation=False):
    """Plot mean +/- std of a metric against mean +/- std cost for every method label.

    Parameters
    ----------
    all_exp_res : dict
        Output of get_experiments: label -> experiment_id -> result dict. Per-label
        arrays are concatenated along axis 1 across experiment ids, then mean/std are
        taken over axis 1.
    start_idx : int
        First point of each curve to draw.
    num_custom : int
        Number of "Only<i>" labels to assign a marker style.
    keys_to_exclude : iterable of str, optional
        Labels containing any of these substrings are not plotted.
    stop_idx_tree : int, optional
        Exclusive end index applied only to the "Mix&Match+1.0Step" curve.
    alt_metric_to_plot : str, optional
        Falsy: plot the stored values (y-axis labelled "error"). Otherwise a name
        containing "F", "precision" or "recall", optionally with "avg".
    target_class_idx : int, optional
        Class to plot when alt_metric_to_plot does not contain "avg".
    plot_validation : bool
        Use the validation values / confusion matrices instead of the test ones.

    Raises
    ------
    ValueError
        Unknown metric name, missing target_class_idx, or no confusion matrices recorded.
    """
    if keys_to_exclude is None:
        keys_to_exclude = []
    fmt = {
        "Uniform": {
            "fmt": "^-",
            "color": "xkcd:black"
        },
        "Mix&Match+1.0Step": {
            "fmt": "o-",
            "color": "xkcd:sky blue"
        },
        "Genie": {
            "fmt": "s-",
            "color": "xkcd:coral"
        },
        "Validation": {
            "fmt": "P-",
            "color": "xkcd:violet"
        },
        "Mix&MatchCH": {
            "fmt": "x-",
            "color": "xkcd:lavender"
        },
        "Mix&MatchCH+0.1Step": {
            "fmt": "x-",
            "color": "xkcd:olive"
        },
        "Mix&MatchDP": {
            "fmt": "d-",
            "color": "xkcd:plum"
        },
        "Mix&MatchDP+0.1Step": {
            "fmt": "d-",
            "color": "xkcd:sienna"
        }
    }
    for i in range(num_custom):
        fmt['Only' + str(i)] = {
            "fmt": ".-",
            "color": "C{}".format(i + 1)
        }

    vals_key = 'validation_vals_all' if plot_validation else 'vals_all'
    conf_key = 'validation_conf_mats_all' if plot_validation else 'conf_mats_all'

    for label, info in all_exp_res.items():
        if not info:
            continue
        if any(key in label for key in keys_to_exclude):
            continue
        fmt_l = fmt[label]
        costs_all = np.hstack([np.array(data['actual_costs_all']) for data in info.values()])
        if label == _TREE_LABEL:
            # NOTE(review): costs for this method are doubled before plotting; presumably it
            # records half the SGD iterations it actually spends — confirm.
            costs_all = costs_all * 2
        if alt_metric_to_plot:
            metric_all = np.hstack([_alt_metric(data[conf_key], alt_metric_to_plot, target_class_idx) for data in info.values()])
        else:
            metric_all = np.hstack([np.array(data[vals_key]) for data in info.values()])

        avg_costs = np.average(costs_all, axis=1)
        std_costs = np.std(costs_all, axis=1)
        avg_vals = np.average(metric_all, axis=1)
        std_vals = np.std(metric_all, axis=1)

        window = slice(start_idx, stop_idx_tree if label == _TREE_LABEL else None)
        plt.errorbar(avg_costs[window], avg_vals[window], xerr=std_costs[window], yerr=std_vals[window], color=fmt_l['color'], fmt=fmt_l['fmt'], label=label)

    plt.xlabel("SGD Iteration budget")
    plt_type = "Validation" if plot_validation else "Test"
    if alt_metric_to_plot:
        plt.ylabel(plt_type + " " + alt_metric_to_plot)
        if "avg" not in alt_metric_to_plot:
            # Per-class plots are otherwise indistinguishable when drawn in a loop.
            plt.title("Class {}".format(target_class_idx))
    else:
        plt.ylabel(plt_type + " " + "error")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

SyntaxError: invalid syntax (<ipython-input-1-3ae97401a000>, line 18)

In [None]:
# List ../gcloud_data again to pick the experiment folder name(s) for experiment_names below.
!ls ../gcloud_data

In [None]:
# Experiment folder names under ../gcloud_data ('<>' is presumably a placeholder to fill in).
experiment_names = ['<>']
num_custom = 4           # number of one-hot custom-mixture baselines ("Only0".."Only3")
num_target_classes = 4   # confusion-matrix size
record_alt_metrics = True

# Drop any result file whose path mentions ACTUAL_MIXTURES; the "" keep-key matches every path.
all_exp_res = get_experiments(
    experiment_names,
    ['ACTUAL_MIXTURES'],
    [""],
    num_custom=num_custom,
    record_alt_metrics=record_alt_metrics,
    num_target_classes=num_target_classes,
)


In [None]:
start_idx = 0
num_custom = 4
keys_to_exclude = []
num_target_classes = 4
stop_idx_tree = 7
plot_validation = False

# For every class, draw precision, recall and F-score (in that order), one figure each.
for class_idx in range(num_target_classes):
    for metric_name in ("precision", "recall", "F"):
        plot_experiments(
            all_exp_res,
            start_idx=start_idx,
            num_custom=num_custom,
            keys_to_exclude=keys_to_exclude,
            alt_metric_to_plot=metric_name,
            target_class_idx=class_idx,
            stop_idx_tree=stop_idx_tree,
            plot_validation=plot_validation,
        )


In [None]:
start_idx = 0
num_custom = 4
keys_to_exclude = []
alt_metric_to_plot = ""   # empty (falsy) -> plot_experiments plots the stored values as "error"
target_class_idx = 0
stop_idx_tree = 7
plot_validation = False

plot_experiments(
    all_exp_res,
    start_idx=start_idx,
    num_custom=num_custom,
    keys_to_exclude=keys_to_exclude,
    alt_metric_to_plot=alt_metric_to_plot,
    target_class_idx=target_class_idx,
    stop_idx_tree=stop_idx_tree,
    plot_validation=plot_validation,
)
