In [1]:
import os
import dill as pickle
import numpy as np
from matplotlib import pyplot as plt
import glob
from sklearn.metrics import precision_recall_fscore_support as prfs

In [None]:
# List the contents of ../gcloud_data, the directory that get_experiments()
# globs for per-experiment result folders.
!ls ../gcloud_data

In [None]:
!pip install torch torchvision

In [None]:
import torch

In [None]:
def _label_for_result_file(pfile, num_custom=0):
    """Map a result-file path to the algorithm label it belongs to.

    The label is chosen by substring matching on the path.  Returns ``None``
    when no rule matches, including the case of a ``custom`` file whose
    mixture id is not one of the ``num_custom`` one-hot ids
    (``custom_1.0,0.0,...``).
    """
    # Order matters: the more specific marker must be tested before the
    # shorter marker that it contains (e.g. "coordinate-halving__True" before
    # "coordinate-halving", "importance-weighted-uniform_constant" before
    # "uniform_constant").
    substring_labels = (
        ("coordinate-halving__True", "Mix&MatchCH+0.1Step"),
        ("coordinate-halving", "Mix&MatchCH"),
        ("delaunay-partitioning__True", "Mix&MatchDP+0.1Step"),
        ("delaunay", "Mix&MatchDP"),
        ("alpha-star", "Genie"),
        ("validation", "Validation"),
        ("tree-results", "Mix&Match+1.0Step"),
        ("importance-weighted-uniform_constant", "IW-Uniform"),
        ("uniform_constant", "Uniform"),
        ("importance-weighted-erm_constant", "IW-ERM"),
        ("mmd_constant", "MMD"),
    )
    for marker, label in substring_labels:
        if marker in pfile:
            return label

    if "custom" in pfile:
        for i in range(num_custom):
            custom_id_list = ['0.0'] * num_custom
            custom_id_list[i] = '1.0'
            if "custom_{}".format(','.join(custom_id_list)) in pfile:
                return "Only" + str(i)
    return None


def get_experiments(experiment_ids, keys_to_filter, keys_to_keep, num_custom=0):
    """Load pickled experiment results from ``../gcloud_data``.

    For every ``experiment_id`` the files matching
    ``../gcloud_data/<experiment_id>/*/*.p`` are considered.

    Parameters
    ----------
    experiment_ids : iterable of str
        Directory names under ``../gcloud_data``.
    keys_to_filter : iterable of str
        A file is skipped if its path contains ANY of these substrings.
    keys_to_keep : iterable of str
        A file is skipped unless its path contains ALL of these substrings
        (``[""]`` keeps everything).
    num_custom : int
        Number of single-source ``Only<i>`` baselines to recognise.

    Returns
    -------
    dict
        ``{label: {experiment_id: {...}}}`` with the keys ``train_data``,
        ``actual_costs_all``, ``test_mf_fn_results_all``,
        ``validation_mf_fn_results_all`` and ``recorders_all`` (only the first
        entry of ``res.recorders_all`` is kept).  If several results map to
        the same ``(label, experiment_id)`` the last one loaded wins.

    Files for which no label can be determined are reported with
    ``Skipping:`` and are not unpickled.  Previously an unmatched ``custom``
    file reused the label of the preceding file (or raised ``NameError``).
    """
    all_exp_res = {
        "Mix&MatchCH": {},
        "Mix&MatchDP": {},
        "Mix&Match+1.0Step": {},
        "Genie": {},
        "Uniform": {},
        "IW-Uniform": {},
        "IW-ERM": {},
        "MMD": {},
        "Mix&MatchCH+0.1Step": {},
        "Mix&MatchDP+0.1Step": {},
        "Validation": {}
    }
    for i in range(num_custom):
        all_exp_res['Only' + str(i)] = {}

    for experiment_id in experiment_ids:
        for exper in glob.glob('../gcloud_data/{}/*'.format(experiment_id)):
            for pfile in glob.glob(exper + '/*.p'):
                if any(filterkey in pfile for filterkey in keys_to_filter):
                    continue
                if any(keepkey not in pfile for keepkey in keys_to_keep):
                    continue

                # The label depends only on the path, so resolve it once per
                # file and before paying for the unpickle.
                label = _label_for_result_file(pfile, num_custom)
                if label is None:
                    print("Skipping:", pfile)
                    continue

                print("Loading file:", pfile)
                # NOTE: unpickling executes arbitrary code; only load result
                # files that come from a trusted source.
                with open(pfile, 'rb') as f:
                    manager = pickle.load(f)

                for res in manager.results:
                    validation_mf_fn_results_all = [
                        [node.validation_mf_fn_results for node in nodes]
                        for nodes in res.best_sols_all
                    ]
                    all_exp_res[label][experiment_id] = {
                        "train_data": manager.experiment_configurer.dataset_config.data.train,
                        "actual_costs_all": res.actual_costs_all,
                        "test_mf_fn_results_all": res.mf_fn_results_all,
                        "validation_mf_fn_results_all": validation_mf_fn_results_all,
                        "recorders_all": res.recorders_all[0],
                    }
    return all_exp_res
    

In [None]:
def _get_costs_and_vals(costs_all, mf_fn_results_all, alt_metric_to_include, label):
    if label == "Mix&Match+1.0Step":
        costs_all *= 2
    avg_costs = np.average(costs_all, axis=1)
    std_costs = np.std(costs_all, axis=1)
    
    num_classes = len(mf_fn_results_all[0,0].precision) if alt_metric_to_include else 0
    if alt_metric_to_include == "precision":
        f = lambda i: np.vectorize(lambda mf_fn_result: mf_fn_result.precision[i])
    elif alt_metric_to_include == "recall":
        f = lambda i: np.vectorize(lambda mf_fn_result: mf_fn_result.recall[i])
    elif alt_metric_to_include == "F":
        f = lambda i: np.vectorize(lambda mf_fn_result: mf_fn_result.f1[i])
    elif alt_metric_to_include == "support":
        f = lambda i: np.vectorize(lambda mf_fn_result: mf_fn_result.support[i])
    elif alt_metric_to_include == "auc_roc_ovo":
        f = np.vectorize(lambda mf_fn_result: mf_fn_result.auc_roc_ovo)
    elif alt_metric_to_include == "auc_roc_ovr":
        f = np.vectorize(lambda mf_fn_result: mf_fn_result.auc_roc_ovr)
    else:
        f = np.vectorize(lambda mf_fn_result: mf_fn_result.error)
        
    if alt_metric_to_include and "auc_roc" not in alt_metric_to_include:
        vals_all = np.stack([f(i)(mf_fn_results_all) for i in range(num_classes)], axis=-1)
    else:
        vals_all = f(mf_fn_results_all)
    avg_vals = np.average(vals_all, axis=1)
    std_vals = np.std(vals_all, axis=1)
    
    return avg_costs, std_costs, avg_vals, std_vals


In [None]:
def _get_costs_and_vals_from_recorders(recorders_all, alt_metric_to_include, show_validation, label):
    costs_all = None
    mf_fn_results_all = None
    for recorder in recorders_all:
        costs_all = np.append(costs_all, np.array(recorder.results.costs)[:,None], axis=1) if costs_all is not None else np.array(recorder.results.costs)[:,None]
        mf_fn_results = recorder.results.val_mf_fn_results if show_validation else recorder.results.test_mf_fn_results
        mf_fn_results = [res[0] for res in mf_fn_results]
        mf_fn_results_all = np.append(mf_fn_results_all, np.array(mf_fn_results)[:,None], axis=1) if mf_fn_results_all is not None else np.array(mf_fn_results)[:,None]
        
    return _get_costs_and_vals(costs_all=costs_all,
                               mf_fn_results_all=mf_fn_results_all,
                               alt_metric_to_include=alt_metric_to_include, 
                               label=label)

In [None]:
def get_formats(num_custom):
    """Return the matplotlib style used for each algorithm label.

    The result maps a label to ``{"fmt": <errorbar format string>,
    "color": <colour name>}``.  Besides the fixed algorithm labels,
    ``num_custom`` entries ``Only0 .. Only<num_custom-1>`` are added, all with
    the ``".-"`` format and the colours ``C1``, ``C2``, ...
    """
    # (label, format string, colour)
    named_styles = (
        ("Uniform", "^-", "xkcd:black"),
        ("IW-Uniform", "|-", "xkcd:electric green"),
        ("IW-ERM", "_-", "xkcd:very light green"),
        ("MMD", "p-", "xkcd:powder pink"),
        ("Mix&Match+1.0Step", "o-", "xkcd:sky blue"),
        ("Genie", "s-", "xkcd:coral"),
        ("Validation", "P-", "xkcd:violet"),
        ("Mix&MatchCH", "x-", "xkcd:lavender"),
        ("Mix&MatchCH+0.1Step", "x-", "xkcd:olive"),
        ("Mix&MatchDP", "d-", "xkcd:plum"),
        ("Mix&MatchDP+0.1Step", "d-", "xkcd:sienna"),
    )
    styles = {
        name: {"fmt": marker, "color": colour}
        for name, marker, colour in named_styles
    }
    styles.update(
        ("Only" + str(idx), {"fmt": ".-", "color": "C{}".format(idx + 1)})
        for idx in range(num_custom)
    )
    return styles



In [None]:
def plot_experiments(all_exp_res, start_idx=0, num_custom=0, keys_to_exclude=[], stop_idx_tree=None, alt_metric_to_plot=None, target_class_idx=None, plot_validation=False, label_mapping={}, recorded_or_final_results='recorded'):
    """Plot metric-vs-cost error-bar curves for every algorithm in ``all_exp_res``.

    Parameters
    ----------
    all_exp_res : dict
        ``{label: {experiment_id: data}}`` as returned by ``get_experiments``.
    start_idx : int
        First index of each curve to draw.
    num_custom : int
        Number of ``Only<i>`` labels to create plot styles for.
    keys_to_exclude : list of str
        Labels containing any of these substrings are not plotted.  The
        default list is never mutated.
    stop_idx_tree : int or None
        End index used only for the ``"Mix&Match+1.0Step"`` curve.
    alt_metric_to_plot : str or None
        Metric name forwarded to ``_get_costs_and_vals``; falsy plots error.
    target_class_idx : int or None
        Class to plot; required when the metric is per-class (a truthy metric
        name that does not contain ``"auc"``).
    plot_validation : bool
        Plot validation results instead of test results.
    label_mapping : dict
        Optional ``label -> display label`` overrides for the legend.  The
        default dict is never mutated.
    recorded_or_final_results : {'recorded', 'final'}
        ``'recorded'`` summarises the stored recorders; ``'final'`` summarises
        ``actual_costs_all`` and the ``*_mf_fn_results_all`` arrays.

    Raises
    ------
    ValueError
        If ``recorded_or_final_results`` is not ``'recorded'`` or ``'final'``,
        or if a per-class metric is requested without ``target_class_idx``.
    """
    # Validate up front so a bad argument fails before anything is drawn.
    if recorded_or_final_results not in ('recorded', 'final'):
        raise ValueError(
            "recorded_or_final_results argument {!r} is invalid. "
            "Please set either 'recorded' or 'final'.".format(recorded_or_final_results)
        )
    per_class_metric = bool(alt_metric_to_plot) and "auc" not in alt_metric_to_plot
    if per_class_metric and target_class_idx is None:
        # Indexing with None would add an axis instead of selecting a class.
        raise ValueError(
            "target_class_idx must be given for the per-class metric {!r}".format(alt_metric_to_plot)
        )

    fmt = get_formats(num_custom=num_custom)
    mf_fn_key = "{}_mf_fn_results_all".format("validation" if plot_validation else "test")

    print("Algorithm & SDG Iteration Budget & Average Error $\\pm$ $1$ std dev")
    for label, info in all_exp_res.items():
        if info == {}:
            continue
        if np.any([key in label for key in keys_to_exclude]):
            continue
        fmt_l = fmt[label]
        runs = list(info.values())

        if recorded_or_final_results == 'recorded':
            # atleast_1d keeps the first and the later experiments in the same
            # 1-D layout so recorders from several experiments can be joined.
            recorders_all = np.hstack(
                [np.atleast_1d(np.array(data['recorders_all'])) for data in runs]
            )
            avg_costs, std_costs, avg_vals, std_vals = _get_costs_and_vals_from_recorders(
                recorders_all, alt_metric_to_plot, plot_validation, label)
        else:
            costs_all = np.hstack([np.array(data['actual_costs_all']) for data in runs])
            mf_fn_results_all = np.concatenate(
                [np.array(data[mf_fn_key]) for data in runs], axis=1)
            avg_costs, std_costs, avg_vals, std_vals = _get_costs_and_vals(
                costs_all, mf_fn_results_all, alt_metric_to_plot, label)

        if per_class_metric:
            avg_vals = avg_vals[:, target_class_idx]
            std_vals = std_vals[:, target_class_idx]

        display_label = label_mapping.get(label, label)

        # Only the Mix&Match+1.0Step curve is truncated at stop_idx_tree.
        stop_idx = stop_idx_tree if "Mix&Match+1.0Step" in label else None
        window = slice(start_idx, stop_idx)
        plt.errorbar(avg_costs[window], avg_vals[window],
                     xerr=std_costs[window], yerr=std_vals[window],
                     color=fmt_l['color'], fmt=fmt_l['fmt'], label=display_label)

    plt.xlabel("SGD Iteration budget")
    plt_type = "Validation" if plot_validation else "Test"
    if alt_metric_to_plot:
        plt.ylabel(plt_type + " " + alt_metric_to_plot)
    else:
        plt.ylabel(plt_type + " " + "error")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [1]:
def print_table(all_exp_res, keys_to_exclude=[], num_classes=0, label_mapping={}, idx_to_display=0, include_cost=True, alt_metric_to_include="", caption="CAPTION HERE", table_label="DERP", show_validation=False):
    """Print a LaTeX ``table`` summarising one metric per algorithm.

    Parameters
    ----------
    all_exp_res : dict
        ``{label: {experiment_id: data}}`` as returned by ``get_experiments``.
    keys_to_exclude : list of str
        Labels containing any of these substrings are omitted (never mutated).
    num_classes : int
        Number of value columns printed for per-class metrics.
    label_mapping : dict
        Optional ``label -> display label`` overrides (never mutated).
    idx_to_display : int
        Index into the averaged arrays whose values are printed.
    include_cost : bool
        Add an "Average Cost" column to the header AND to every row.
    alt_metric_to_include : str or None
        Metric name forwarded to ``_get_costs_and_vals``; falsy prints error.
        A truthy name that does not contain ``"auc_roc"`` is treated as a
        per-class metric.
    caption, table_label : str
        Text for ``\\caption{}`` and ``\\label{tab:...}``.
    show_validation : bool
        Use validation results instead of test results.
    """
    # Evaluated once, by truthiness, so that None behaves like "" everywhere.
    per_class = bool(alt_metric_to_include) and "auc_roc" not in alt_metric_to_include
    mf_res_key = "{}_mf_fn_results_all".format("validation" if show_validation else "test")
    value_columns = "|".join(["c"] * num_classes) if per_class else "c"

    print("\\begin{table}[h!]")
    print("\\centering")
    print("\\begin{tabular}{ "
          + "c|{}{}".format("c|" if include_cost else "", value_columns)
          + " }")
    print("\\hline")
    print("Algorithm & {}{} \\\\".format(
        "Average Cost & " if include_cost else "",
        "Average {}".format(alt_metric_to_include.replace('_', '\\_') if alt_metric_to_include else "Error")
    ))
    print("\\hline")
    for label, info in all_exp_res.items():
        if info == {}:
            continue
        if np.any([key in label for key in keys_to_exclude]):
            continue
        runs = list(info.values())
        costs_all = np.hstack([np.array(data['actual_costs_all']) for data in runs])
        mf_fn_results_all = np.hstack([np.array(data[mf_res_key]) for data in runs])

        avg_costs, _, avg_vals, std_vals = _get_costs_and_vals(costs_all, mf_fn_results_all, alt_metric_to_include, label)

        display_label = label_mapping.get(label, label)

        if per_class:
            table_row_vals = " & ".join(
                "${:0.2f} \\pm {:0.2f}$".format(avg_vals[idx_to_display, cl], std_vals[idx_to_display, cl])
                for cl in range(num_classes)
            )
        else:
            table_row_vals = "${:0.4f} \\pm {:0.4f}$".format(avg_vals[idx_to_display], std_vals[idx_to_display])

        # Keep the row in step with the header: the cost cell is emitted
        # exactly when the "Average Cost" column was declared.
        cells = [display_label.replace("&", "\\&")]
        if include_cost:
            cells.append("${:0.0f}$".format(avg_costs[idx_to_display]))
        cells.append(table_row_vals)
        print(" & ".join(cells) + " \\\\")

    print("\\hline")
    print("\\end{tabular}")
    print("\\caption{" + caption + "}")
    print("\\label{tab:" + table_label + "}")
    print("\\end{table}")
            

In [None]:
# List ../gcloud_data again; its sub-directory names are the experiment ids
# that get_experiments() expects in `experiment_names` below.
!ls ../gcloud_data

In [None]:
# Active configuration.
#   num_custom         -- how many "Only<i>" single-source labels
#                         get_experiments()/get_formats() create.
#   num_target_classes -- loop bound for the per-class plots and the
#                         `num_classes` passed to print_table().
#   label_mapping      -- internal label -> display label used in plot
#                         legends and table rows.
num_custom = 4
num_target_classes = 4
label_mapping = {
    "Only0": "OnlyUS",
    "Only1": "OnlyFrance",
    "Only2": "OnlyItaly",
    "Only3": "OnlySpain",
}
# Disabled alternative configurations follow (presumably for other datasets
# -- TODO confirm); uncomment exactly one block to use it instead.
# num_custom = 4
# num_target_classes = 2
# label_mapping = {
#     "Only0": "Only117878",
#     "Only1": "Only117941",
#     "Only2": "Only117945",
#     "Only3": "Only117920",
# }
# num_custom = 3
# num_target_classes = 4
# label_mapping = {
#     "Only0": "OnlyFrance",
#     "Only1": "OnlyItaly",
#     "Only2": "OnlySpain",
# }
# num_custom = 3
# num_target_classes = 4
# label_mapping = {
#     "Only0": "OnlyFL",
#     "Only1": "OnlyCT",
#     "Only2": "OnlyOH",
# }
# num_custom = 3
# num_target_classes = 2
# label_mapping = {
#     "Only0": "Only0.1",
#     "Only1": "Only0.2",
#     "Only2": "Only0.7",
# }

In [None]:
# Experiment directories (under ../gcloud_data) to load.
# NOTE(review): '<>' looks like a placeholder -- replace it with a real
# directory name before running.
experiment_names = ['<>']
# NOTE(review): record_alt_metrics is not read by any visible cell.
record_alt_metrics = True
# record_alt_metrics = False

# Files whose path contains any of the three substrings are skipped; the
# keep-list [""] matches every path, so nothing else is filtered out.
all_exp_res = get_experiments(experiment_names,['ACTUAL_MIXTURES','PRETRAINED_UNIFORM_BASELINE','INDIVIDUAL_SRC_BASELINE'],[""], num_custom=num_custom)


In [None]:
# Validation-split plots from the recorded traces: the two AUC-ROC variants
# first, then precision / recall / F / support for every target class.
start_idx = 0
keys_to_exclude = []
stop_idx_tree = 7
plot_validation = True
recorded_or_final_results = 'recorded'

for metric_name in ("auc_roc_ovo", "auc_roc_ovr"):
    plot_experiments(
        all_exp_res,
        start_idx=start_idx,
        num_custom=num_custom,
        keys_to_exclude=keys_to_exclude,
        alt_metric_to_plot=metric_name,
        target_class_idx=0,
        stop_idx_tree=stop_idx_tree,
        plot_validation=plot_validation,
        label_mapping=label_mapping,
        recorded_or_final_results=recorded_or_final_results,
    )

for i in range(num_target_classes):
    for metric_name in ("precision", "recall", "F", "support"):
        plot_experiments(
            all_exp_res,
            start_idx=start_idx,
            num_custom=num_custom,
            keys_to_exclude=keys_to_exclude,
            alt_metric_to_plot=metric_name,
            target_class_idx=i,
            stop_idx_tree=stop_idx_tree,
            plot_validation=plot_validation,
            label_mapping=label_mapping,
            recorded_or_final_results=recorded_or_final_results,
        )

In [None]:
# Test-split plots from the recorded traces: the two AUC-ROC variants first,
# then precision / recall / F / support for every target class.
start_idx = 0
keys_to_exclude = []
stop_idx_tree = 7
plot_validation = False
recorded_or_final_results = 'recorded'

for metric_name in ("auc_roc_ovo", "auc_roc_ovr"):
    plot_experiments(
        all_exp_res,
        start_idx=start_idx,
        num_custom=num_custom,
        keys_to_exclude=keys_to_exclude,
        alt_metric_to_plot=metric_name,
        target_class_idx=0,
        stop_idx_tree=stop_idx_tree,
        plot_validation=plot_validation,
        label_mapping=label_mapping,
        recorded_or_final_results=recorded_or_final_results,
    )

for i in range(num_target_classes):
    for metric_name in ("precision", "recall", "F", "support"):
        plot_experiments(
            all_exp_res,
            start_idx=start_idx,
            num_custom=num_custom,
            keys_to_exclude=keys_to_exclude,
            alt_metric_to_plot=metric_name,
            target_class_idx=i,
            stop_idx_tree=stop_idx_tree,
            plot_validation=plot_validation,
            label_mapping=label_mapping,
            recorded_or_final_results=recorded_or_final_results,
        )


In [None]:
# Test-split plot of the default metric: an empty metric name makes
# plot_experiments() fall back to the error.
start_idx = 0
keys_to_exclude = []
alt_metric_to_plot = ""
target_class_idx = 0
stop_idx_tree = 7
plot_validation = False
recorded_or_final_results = 'recorded'

plot_experiments(
    all_exp_res,
    start_idx=start_idx,
    num_custom=num_custom,
    keys_to_exclude=keys_to_exclude,
    alt_metric_to_plot=alt_metric_to_plot,
    target_class_idx=target_class_idx,
    stop_idx_tree=stop_idx_tree,
    plot_validation=plot_validation,
    label_mapping=label_mapping,
    recorded_or_final_results=recorded_or_final_results,
)

In [None]:
# Cumulative histogram of the 'importance_weights' column of the training
# data stored for the IW-Uniform run.
# experiment_name must be a key of all_exp_res['IW-Uniform'], i.e. one of the
# experiment ids passed to get_experiments().
# NOTE(review): '<>' looks like a placeholder -- replace before running.
experiment_name = '<>'
nbins = 20
title=''
training_data = all_exp_res['IW-Uniform'][experiment_name]['train_data']
# presumably a pandas Series (uses the .plot.hist accessor) -- TODO confirm
iw_column = training_data['importance_weights']
iw_column.plot.hist(cumulative=True, bins=nbins,alpha=0.5)
plt.title(title)
plt.show()

In [None]:
# Print the LaTeX results table.
# NOTE: keys_to_exclude is not set in this cell; it is whatever the most
# recently run plotting cell assigned.
caption="CAPTION"
lab="DERP"
include_cost=False
# Empty string -> report the error; uncomment one line below for another metric.
alt_metric_to_include=""
# alt_metric_to_include="precision"
# alt_metric_to_include="recall"
# alt_metric_to_include="F"
# alt_metric_to_include="support"
# alt_metric_to_include="auc_roc_ovo"
# -1 selects the last entry of the averaged arrays.
print_vals_at_idx = -1
show_validation=False
print_table(all_exp_res, keys_to_exclude=keys_to_exclude, num_classes=num_target_classes, label_mapping=label_mapping, idx_to_display=print_vals_at_idx, alt_metric_to_include=alt_metric_to_include, include_cost=include_cost, caption=caption, table_label=lab, show_validation=show_validation)
