# Visualize Best Model Evidence Results

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from pathlib import Path
import importlib
from src.autoks.postprocessing.summarize_group import _parse_experiment_group

## Set Default Parameters

In [None]:
# Parsing params
# Root directory that holds the experiment-group result directories.
# A plain string is also accepted; it is converted to a Path further down.
result_dir = Path("results")

# Visualization params
# When True, the finished figure is written to `output_path`.
save_fig = False
output_path = 'best_model_evidence.pdf'
# Optional figure-level title; no title is drawn while this is None/empty.
fig_title = None
# Label applied to the x-axis of every subplot.
x_label = 'evaluations'

# Upper bound on the number of subplot columns in the figure grid.
MAX_N_COLS = 3
# Width-to-height ratio of each subplot (height = width / ratio).
subplot_aspect_ratio = 1

# Parse Experiment Groups

## Format result directory
(if a custom value is given)

In [None]:
# Accept a plain string for `result_dir` and normalize it to a Path so that
# it can be joined with sub-directory names using the `/` operator.
if isinstance(result_dir, str):
    result_dir = Path(result_dir)

## Get Result Files

In [None]:
# One path per experiment-group directory, rooted at `result_dir`.
# NOTE(review): `experiment_dir_names` is not assigned in the visible part of
# this notebook -- presumably it is injected as a parameter; confirm.
paths = [result_dir / group_name for group_name in experiment_dir_names]
print(f'Created {len(paths)} paths.')

## Parse Results

In [None]:
# Parse every experiment group; each entry is later iterated as a collection
# of per-experiment dicts.
exp_dicts_list = [_parse_experiment_group(group_path) for group_path in paths]

## Create dict of results

In [None]:
# Nested mapping: dataset key -> strategy key -> list with one entry of best
# scores per experiment run.
result_dict = {}

# The datasets module is the same for every experiment, so resolve it once
# instead of once per loop iteration.
datasets_module = importlib.import_module('src.datasets')

for exp_group_dict in exp_dicts_list:
    for exp_dict in exp_group_dict:
        # Strategy label = the part of the selector's class name that
        # precedes 'ModelSelector'.
        model_selector = exp_dict['model_selector']
        strategy_label = model_selector.__class__.__name__.split('ModelSelector')[0]

        # Get model search history.
        history = exp_dict["history"]

        # Re-create the dataset from its recorded class name and arguments;
        # only its `name` is used here, as the grouping key.
        dataset_class_ = getattr(datasets_module, exp_dict['dataset_cls'])
        dataset_args = exp_dict.get('dataset_args', {})
        dataset = dataset_class_(**dataset_args)

        # Presumably the running maximum of the score over evaluations --
        # `running_max` is defined outside this notebook; confirm.
        best_scores = history.stat_book_collection.stat_books['evaluations'].running_max('score')

        # Add to result dict, creating the dataset/strategy levels on first use.
        ds_key = dataset.name.lower()
        strat_key = strategy_label.lower()
        result_dict.setdefault(ds_key, {}).setdefault(strat_key, []).append(best_scores)

### Summarize `result_dict`

In [None]:
# Print a two-level overview: one line per dataset, followed by one indented
# line per strategy with the number of runs collected for it.
print('Created result dict.\n')
ds_keys = result_dict.keys()
for ds_key, ds_val in result_dict.items():
    n_strats = len(ds_val)
    strat_label = 'strategies' if n_strats != 1 else 'strategy'
    print(f'{ds_key} ({n_strats} {strat_label})')

    for strat_key, strat_val in ds_val.items():
        n_runs = len(strat_val)
        runs_label = 'runs' if n_runs != 1 else 'run'
        print(f'   {strat_key} ({n_runs} {runs_label})')

### Use result dict for visualization 
Example `best_scores_data` (a strategy's value may be a single run given as a flat list, or a list of runs):
```
{
    'airline': {
        'boems': [0.1, 0.2, 0.3, 0.4, 0.4],
        'cks': [[0.1, 0.1, 0.2, 0.2, 0.25]]
    },
    'mauna': {
        'cks': [[1, 2, 3, 4, 5], [1, 2, 3, 3, 5], [1, 1, 2, 4, 5]]
    }
}
```

In [None]:
# Alias only (no copy): the formatting step further down rewrites the entries
# of `best_scores_data` in place, which therefore also changes `result_dict`.
best_scores_data = result_dict

In [None]:
# Dataset keys, plus the strategy keys of each dataset in the same order.
best_score_keys = best_scores_data.keys()
strategy_keys = tuple(best_scores_data[name].keys() for name in best_score_keys)

# Upper-cased display labels, aligned one-to-one with the keys above.
dataset_labels = tuple(name.upper() for name in best_score_keys)
strategy_labels = tuple(
    tuple(name.upper() for name in names)
    for names in strategy_keys
)

print(f"Dataset labels:\n   {dataset_labels}")
print()
print(f"Strategy labels:\n   {strategy_labels}")

## Format Data

### Force best score data to be 2D numpy arrays

In [None]:
# Replace every strategy entry by a 2-D numpy array. A 1-D entry (a single
# run) becomes an array with one row; anything that is still not 2-D after
# that trips the assertion.
for dataset_key, dataset_values in best_scores_data.items():
    for strategy_key, data in dataset_values.items():
        runs = np.array(data)

        if runs.ndim == 1:
            # Prepend a length-1 axis: (n,) -> (1, n).
            runs = runs[np.newaxis, :]

        assert runs.ndim == 2

        best_scores_data[dataset_key][strategy_key] = runs

## Define Plotting Functions


In [None]:
def plot_dataset_results(best_scores_list, ax, labels, title, legend=False):
    """Draw one curve per non-empty entry of ``best_scores_list`` on ``ax``.

    :param best_scores_list: iterable of arrays, one per strategy; entries
        whose ``size`` is 0 are skipped.
    :param ax: axes to draw on.
    :param labels: curve labels, paired positionally with ``best_scores_list``.
    :param title: title set on ``ax``.
    :param legend: when True, a legend is added provided at least one curve
        was drawn.
    """
    plotted_any = False
    for scores, curve_label in zip(best_scores_list, labels):
        if scores.size == 0:
            # Nothing to draw for this strategy.
            continue
        plotted_any = True
        plot_mean_pm_std(scores, ax, label=curve_label)

    ax.set_title(title, size='large', weight='book')

    # A legend without any drawn curve would be empty, so skip it.
    if legend and plotted_any:
        ax.legend(fontsize='large')

def plot_mean_pm_std(data, ax=None, plot_confidence=True, **kwargs):
    """Plot the mean of ``data`` over axis 0, optionally with a +/- 1 std band.

    :param data: 2-D array; statistics are taken over axis 0 and the x values
        are ``0 .. data.shape[1] - 1``.
    :param ax: axes to draw on; the current axes are used when None.
    :param plot_confidence: when True, shade the region between
        ``mean - std`` and ``mean + std``.
    :param kwargs: forwarded to ``ax.plot``. The line width defaults to 4
        unless the caller supplies ``lw`` or ``linewidth``.
    :return: whatever ``ax.plot`` returns.
    """
    if ax is None:
        ax = plt.gca()

    x = np.arange(data.shape[1])
    mu = np.mean(data, axis=0)

    if plot_confidence:
        std = np.std(data, axis=0)
        ax.fill_between(x, mu - std, mu + std, alpha=0.3)

    ax.margins(x=0)

    # Apply the default width only when the caller did not choose one;
    # passing it unconditionally would clash with a caller-supplied value.
    if 'lw' not in kwargs and 'linewidth' not in kwargs:
        kwargs['lw'] = 4
    return ax.plot(x, mu, **kwargs)

def hide_upper_ax_lines(ax):
    """Hide the top and right spines of ``ax`` and keep ticks on the left and bottom only."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    # With the upper/right spines gone, ticks belong on the remaining ones.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

## Create Visualization

### Define Global Styles

In [None]:
# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8-*'
# and newer releases no longer ship the old names, so fall back to the
# renamed sheet whenever 'seaborn-paper' is not listed as available.
style_name = 'seaborn-paper'
if style_name not in plt.style.available:
    style_name = 'seaborn-v0_8-paper'
plt.style.use(style_name)
plt.rcParams['font.family'] = "serif"

### Infer subplot ordering

In [None]:
# figure out plot alignment using MAX_N_COLS
n_subplots = len(best_scores_data)

# assume MAX_N_COLS = 3
n_cols = min(MAX_N_COLS, n_subplots)
n_rows = int(np.ceil(n_subplots / MAX_N_COLS))

print(f'Going to create a plot with {n_subplots} subplots by allocating {n_rows} rows x {n_cols} columns.')

### Infer Figure Size

In [None]:
# Size of a single subplot; the height follows from the requested
# width-to-height ratio.
fig_subplot_width = 4.
fig_subplot_height = fig_subplot_width / subplot_aspect_ratio

# Total figure size = per-subplot size scaled by the grid dimensions.
fig_w = fig_subplot_width * n_cols
fig_h = fig_subplot_height * n_rows

figsize = (fig_w, fig_h)

print(f"Figure size = {figsize}")

### Now, Create Plot

In [None]:
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize)

# plt.subplots returns a bare Axes object for a 1x1 grid; wrap it so that
# `axes` is always an array.
if n_rows == 1 and n_cols == 1:
    axes = np.array([axes])

# Row-major flat sequence of the grid's axes: entry i is the grid cell
# (i // n_cols, i % n_cols). Indexing by subplot number therefore fills the
# grid left-to-right, top-to-bottom for every grid shape, including grids
# whose row and column counts differ.
flat_axes = axes.reshape(-1)

# One subplot per dataset, in dataset order.
for axis, key, dataset_label, strategy_label in zip(
        flat_axes, best_score_keys, dataset_labels, strategy_labels):
    best_scores_list = best_scores_data[key].values()
    plot_dataset_results(best_scores_list, axis, strategy_label, dataset_label, legend=True)

# Format subplots.
for ax in flat_axes:
    hide_upper_ax_lines(ax)
    ax.set_xlabel(x_label, size='large', weight='book')
    ax.locator_params(nbins=5, axis='y')
    ax.locator_params(nbins=4, axis='x')

    # Set the tick labels font.
    for label in (ax.get_xticklabels() + ax.get_yticklabels()):
        label.set_fontname('Arial')
        label.set_fontsize(12)

# Hide unused subplots (grid cells beyond the last dataset).
for axis in flat_axes[n_subplots:]:
    axis.axis('off')

plt.subplots_adjust(hspace=0.4)

if fig_title:
    plt.suptitle(fig_title)

if save_fig:
    print(f'Saving figure to {output_path}')
    plt.savefig(output_path)

plt.show()