# Plots for experiments

In [2]:
%load_ext autoreload
%autoreload 2

import os
os.chdir("../")
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import data
import utils

%matplotlib inline


In [32]:
# ---------------------------------------------------------------------------
# Global configuration for this cell: plotting style, experiment selection
# and output switches. Everything below is read by the reporting loop that
# follows.
# ---------------------------------------------------------------------------
sns.set(font_scale=2.0)
np.random.seed(1)  # seed NumPy's global RNG
np.set_printoptions(precision=6, suppress=True)

# Figure geometry and line styling shared by all plots.
plot_height = 10
legend_size = 18
marker_size = 4.0
line_width = 1.5

# Pretty-printed replacements for the eps strings that appear in the model
# names; applied to the finished latex table by plain string replacement.
eps_format_dict = {
    '0.300': '0.3',
    '0.050': '0.05',
    '0.025': '0.025',
    '0.100': '0.1',
    '0.031': '8/255',
}

# Which experiments to report.
datasets = [
    'breast_cancer',
    'diabetes',
    'cod_rna',
    'mnist_1_5',
    'mnist_2_6',
    'fmnist_sandal_sneaker',
    'gts_100_roadworks',
    'gts_30_70',
    'har',
    'ijcnn1',
]
models = [
    'plain',
    'at_cube',
    'robust_bound',
    'robust_exact',
]
exp_folder = 'exps_diff_depth'
weak_learner = 'tree'
tree_depth = '4'  # kept as a string, exactly as passed to utils.get_model_names
model_names = utils.get_model_names(datasets, models, exp_folder, weak_learner, tree_depth)

# Output switches.
flag_plot = False
flag_latex = True
flag_n_trees_latex = (weak_learner == 'tree')  # extra "number of trees" column only for trees
flag_pruning_stats = False

# Accumulators: the finished table and the row currently being assembled.
latex_table = ''
latex_str = ''
# Per-model report: for every model name, load its metrics file, pick one
# iteration by validation error, print a summary line, append the numbers to
# the latex table, and optionally save a three-panel plot.

# Highest column index read from a metrics row below is 13 (time_cert), so a
# usable metrics file needs at least 14 columns.
_MIN_METRIC_COLS = 14


def _name_field(name, key):
    """Return the token that follows ``key=`` in a space-separated model name."""
    return name.split(key + '=')[1].split(' ')[0]


def _plot_panel(ax, x_vals, curves, y_label, title, as_percent=False):
    """Plot ``curves`` (a list of ``(label, values)`` pairs) on ``ax`` with the shared styling.

    When ``as_percent`` is true the y tick labels are rendered as percentages
    through a formatter, so they stay correct whatever ticks the locator picks.
    """
    for label, values in curves:
        ax.plot(x_vals, values, label=label, linestyle='solid', linewidth=line_width,
                marker='o', markersize=marker_size)
    if as_percent:
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda val, _pos: '{:.0%}'.format(val)))
    ax.set_xlabel('iteration')
    ax.set_ylabel(y_label)
    ax.grid(which='both', alpha=0.5, linestyle='--')
    ax.legend(loc='best', prop={'size': legend_size})
    ax.set_title(title)


if flag_plot:
    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)

for i, model_name in enumerate(model_names):
    dataset = _name_field(model_name, 'dataset')
    model = _name_field(model_name, 'model')
    eps = _name_field(model_name, 'eps')
    max_depth = _name_field(model_name, 'max_depth')
    print('Model (depth={}): {}'.format(max_depth, model_name))

    # ndmin=2 keeps a single-row file two-dimensional so the column check works
    metrics = np.loadtxt(os.path.join(exp_folder, model_name + '.metrics'), ndmin=2)

    if metrics.shape[1] < _MIN_METRIC_COLS:
        print('An old model encountered! Just skipping.')
        continue

    # per-iteration curves (used for model selection and for the plots)
    iters = metrics[:, 0]
    test_errs, test_adv_errs = metrics[:, 1], metrics[:, 3]
    train_errs, train_adv_errs = metrics[:, 5], metrics[:, 6]
    train_losses = metrics[:, 7]
    valid_errs, valid_adv_errs_lb, valid_adv_errs = metrics[:, 8], metrics[:, 9], metrics[:, 10]

    # Model selection: plain models by validation error, robust models by
    # validation adversarial error. (An alternative criterion that was tried is
    # the mean of the two: (valid_errs + valid_adv_errs) / 2.)
    if model == 'plain':
        iter_to_print = np.argmin(valid_errs)
    elif model in ('at_cube', 'robust_bound', 'robust_exact'):
        iter_to_print = np.argmin(valid_adv_errs)
    else:
        raise ValueError('wrong model name: {}'.format(model))

    # TODO: the last entries have to be revisited; I added a time_cert_test and removed depths/n_nodes before pruning
    # values at the selected iteration, printed directly and used for the latex table
    selected = metrics[iter_to_print]
    last_iter, n_iter_done, time_total = int(selected[0]), len(metrics), metrics[-1, 12]
    test_err, test_adv_err_lb, test_adv_err, test_adv_err_ub = selected[1:5]
    train_err, train_adv_err, train_loss = selected[5:8]
    valid_err, valid_adv_err_lb, valid_adv_err, valid_adv_err_ub = selected[8:12]
    time_cert = selected[13]

    test_str = 'iter: {}/{}  [test] err {:.2%} adv_err_lb {:.2%} adv_err {:.2%}  adv_err_ub {:.2%}'.format(
        last_iter, n_iter_done, test_err, test_adv_err_lb, test_adv_err, test_adv_err_ub)
    valid_str = '[valid] err {:.2%} adv_err_lb {:.2%} adv_err {:.2%}'.format(
        valid_err, valid_adv_err_lb, valid_adv_err)
    train_str = '[train] err: {:.2%}  adv_err: {:.2%}  loss: {:.5f}'.format(
        train_err, train_adv_err, train_loss)
    pruning_str = ''
    if flag_pruning_stats:
        # NOTE(review): column 13 is read as time_cert above, yet this branch
        # reads columns 13:17 as pruning statistics. Per the TODO the column
        # layout changed, so verify the indices before enabling this flag.
        d_before, d_after, nodes_before, nodes_after = metrics[:iter_to_print+1, 13:17].mean(0)
        pruning_str = ' | depth {:.2f}->{:.2f} nodes {:.2f}->{:.2f}'.format(d_before, d_after, nodes_before, nodes_after)
    print('{}  |  {}  |  {} {} (cert {:.3f}s, total {:.2f} min)'.format(
        test_str, valid_str, train_str, pruning_str, time_cert, time_total/60))

    # form the latex table
    if flag_latex:
        if model == 'plain':
            # 'plain' opens a row with the dataset name and eps
            latex_str += '{} & {} &  '.format(data.dataset_names_dict[dataset], eps)
        # stumps report the adversarial error itself, trees its lower bound
        adv_col = test_adv_err if weak_learner == 'stump' else test_adv_err_lb
        latex_str += '{:.1f} & {:.1f} & {:.1f}'.format(
            test_err*100, adv_col*100, test_adv_err_ub*100)

        if flag_n_trees_latex:  # add the number of trees
            latex_str += ' & {}'.format(last_iter)

        # if the last column of a block
        is_last_column = ((weak_learner == 'stump' and model == 'robust_exact') or
                          (weak_learner == 'tree' and model == 'robust_bound'))
        if is_last_column:
            latex_table += utils.finalize_curr_row(latex_str, weak_learner, flag_n_trees_latex)
            latex_str = ''  # re-initialize to an empty string
        else:
            latex_str += ' & '

    if flag_plot:
        plot_name_short = '{}-{}'.format(dataset, model)
        plot_name_long = 'dataset={}-model={}-iter={}'.format(dataset, model, last_iter)
        fig, axs = plt.subplots(1, 3, figsize=(3*plot_height, plot_height))

        _plot_panel(axs[0], iters,
                    [('test error', test_errs),
                     ('test adv error', test_adv_errs),
                     ('valid error', valid_errs),
                     ('valid adv error', valid_adv_errs)],
                    'test error', plot_name_short, as_percent=True)
        # show both training curves, each under its own label
        _plot_panel(axs[1], iters,
                    [('train error', train_errs),
                     ('train adv error', train_adv_errs)],
                    'training error', plot_name_short, as_percent=True)
        _plot_panel(axs[2], iters,
                    [('train loss', train_losses)],
                    'training loss', plot_name_short)

        fig.savefig('plots/{}.pdf'.format(plot_name_long), bbox_inches='tight')

    # blank line after every third model to visually group the printed output
    if weak_learner in ('stump', 'tree') and i % 3 == 2:
        print()

if flag_latex:
    # Table-wide clean-up, done by plain substring replacement on the whole
    # table: first shorten '100.0' to '100' to save width, then swap every raw
    # eps string for its display form.
    latex_table = latex_table.replace('100.0', '100')
    for raw_eps, shown_eps in eps_format_dict.items():
        latex_table = latex_table.replace(raw_eps, shown_eps)

    # One call emits the same three lines: a blank line, the header, the table.
    print('', 'Latex table:', latex_table, sep='\n')


Model (depth=4): 2019-08-11 14:28:04 dataset=breast_cancer weak_learner=tree model=robust_bound n_train=-1 n_trials_coord=10 eps=0.300 max_depth=4 lr=0.01
iter: 46/150  [test] err 0.73% adv_err_lb 6.57% adv_err 6.57%  adv_err_ub 6.57%  |  [valid] err 1.83% adv_err_lb 8.26% adv_err 8.26%  |  [train] err: 7.32%  adv_err: 13.73%  loss: 0.77129  (cert 2.356s, total 26.63 min)
Model (depth=4): 2019-08-11 14:28:04 dataset=diabetes weak_learner=tree model=robust_bound n_train=-1 n_trials_coord=8 eps=0.050 max_depth=4 lr=0.2
iter: 9/150  [test] err 27.27% adv_err_lb 35.71% adv_err 35.71%  adv_err_ub 35.71%  |  [valid] err 22.13% adv_err_lb 30.33% adv_err 30.33%  |  [train] err: 21.75%  adv_err: 28.46%  loss: 0.87517  (cert 3.165s, total 26.82 min)
Model (depth=4): 2019-08-11 14:28:08 dataset=cod_rna weak_learner=tree model=robust_bound n_train=-1 n_trials_coord=8 eps=0.025 max_depth=4 lr=0.2
iter: 36/150  [test] err 6.91% adv_err_lb 21.26% adv_err 21.37%  adv_err_ub 21.37%  |  [valid] err 7.95