# Visualising trials from Fast AutoAugment

In this notebook, I aim to understand how Fast AutoAugment found the optimal data augmentation policies for CIFAR-10.

In the Bayesian optimization loop I designed, at each round the evaluation function receives a randomly selected sub-policy consisting of two operations, each with its own probability and magnitude. The function saves the sub-policy and its validation loss to Trials. I'm going to delve into the Trials and visualize the relationship between the validation loss (which we aim to minimize) and the operations.

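To focus on one operation at a time, I decoupled the two operations of each sub-policy: each operation becomes its own record and is assigned the sub-policy's joint validation loss. This might weaken the interpretation, since the loss really depends on both operations together, but it still shows the general trend.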

In [None]:
import os
import pickle
from collections import defaultdict
from search_fastautoaugment import decipher_trial
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Settings that determine which saved Trials files are loaded below.
# Number of `k` indices scanned in the "k{}_t{}_trials.pkl" file names
# (presumably the cross-validation folds of the search -- TODO confirm).
cv_folds = 5
# Number of `t` indices scanned for every `k`
# (presumably the independent searches run per fold -- TODO confirm).
search_width = 2
# Directory containing the pickled Trials files.
search_results_folder = 'fastautoaugment'

In [None]:
# Flatten every saved trial into one record per operation. Each record carries
# the operation's name, probability and magnitude together with the error of
# the sub-policy it belonged to.
agg_list = []

for k_idx in range(cv_folds):

    byT_error = []
    byT_policies = []

    for t_idx in range(search_width):
        trials_path = os.path.join(
            search_results_folder, "k{}_t{}_trials.pkl".format(k_idx, t_idx))
        # NOTE: unpickling can execute arbitrary code; only load Trials files
        # produced by your own search runs.
        # The context manager closes the handle even if loading fails.
        with open(trials_path, "rb") as trials_file:
            trials = pickle.load(trials_file)
        # decipher_trial is expected to return two parallel lists: the errors
        # and the matching sub-policies (see search_fastautoaugment).
        val_error_list, sub_policy_list = decipher_trial(trials)

        byT_error.extend(val_error_list)
        byT_policies.extend(sub_policy_list)

    for error, policy in zip(byT_error, byT_policies):
        # A policy maps operation name -> (probability, magnitude). The joint
        # error is deliberately duplicated onto each of its operations.
        for op_name, (op_prob, op_value) in policy.items():
            agg_list.append({
                'op_name': op_name,
                'op_prob': op_prob,
                'op_value': op_value,
                'error': error,
            })

In [None]:
# Tabulate the per-operation records: one row per (trial, operation) with the
# columns op_name, op_prob, op_value and error.
agg_df = pd.DataFrame(agg_list)

In [None]:
# Preview the first few rows to sanity-check the aggregated table.
agg_df.head()

In [None]:
# For every operation, scatter the error against the operation's probability
# and against its magnitude. The trial with the lowest error for that
# operation is highlighted in red.
g = agg_df.groupby('op_name')

# Two operations share one row (a "prob" and a "value" panel each, hence four
# columns). Size the grid from the number of operations rather than assuming
# exactly 16, which overflowed the fixed 8x4 grid when more were present.
n_ops = g.ngroups
n_rows = max(1, (n_ops + 1) // 2)
# squeeze=False keeps `ax` two-dimensional even when there is a single row.
fig, ax = plt.subplots(n_rows, 4, figsize=(12, 3 * n_rows), squeeze=False)

for idx, (label, data) in enumerate(g):
    # Even-numbered operations use the left pair of panels, odd the right pair.
    nrow, slot = divmod(idx, 2)
    prob_ax = ax[nrow, slot * 2]
    value_ax = ax[nrow, slot * 2 + 1]

    # np.argmin yields a position, and the group keeps agg_df's index labels,
    # so the row has to be fetched positionally with iloc.
    minimum_record = data.iloc[np.argmin(np.array(data.error))]

    prob_ax.scatter(data.op_prob, data.error, alpha=0.5)
    prob_ax.scatter(minimum_record.op_prob, minimum_record.error, color='red')
    # Operations without a magnitude appear to be stored as the string 'None'
    # (TODO confirm against decipher_trial); their value panel stays empty.
    if data.op_value.iloc[0] != 'None':
        value_ax.scatter(data.op_value, data.error, alpha=0.5)
        value_ax.scatter(minimum_record.op_value, minimum_record.error, color='red')
    prob_ax.set_title(label + " prob")
    value_ax.set_title(label + " value")

    prob_ax.set_ylabel("error")

# In row-major order operation i occupies flat panels 2*i and 2*i + 1, so
# everything from 2*n_ops onwards is unused; hide those panels.
for unused_ax in ax.ravel()[2 * n_ops:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.show()