In [5]:
##################
## Load Imports ##
##################

import ast
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import interact, interactive, interactive_output, HBox, VBox, fixed, interact_manual
import ipywidgets as widgets

%matplotlib inline
%load_ext autoreload
%autoreload 2

from experiments import folder_name, load_metric_history

def plot_init(fig_size, xlabel = "Epoch", ylabel = r"\log_base", log_base = None, title = ""):
    r"""Create and style a new pyplot figure for loss-vs-epoch curves.

    Parameters
    ----------
    fig_size : tuple
        Forwarded to ``plt.figure(figsize=...)``.
    xlabel : str
        X-axis label (default ``"Epoch"``).
    ylabel : str
        Y-axis label. The sentinel value ``\log_base`` (the default) means
        "build a log-loss label from ``log_base``"; any other string is used
        verbatim.
    log_base : number or None
        Only consulted when ``ylabel`` is the sentinel. ``None`` gives the
        label ``$\log$loss``; otherwise ``$\log_{<log_base>}$loss``.
    title : str
        Figure title.

    Returns
    -------
    The figure object returned by ``plt.figure``.
    """
    fig = plt.figure(figsize = fig_size)
    # NOTE: plt.rc changes a *global* rcParam, so every figure created after
    # this call is rendered through LaTeX (requires a working TeX install).
    plt.rc('text', usetex = True)
    plt.grid(True, which="both", color='0.75')
    plt.title(title)
    plt.xlabel(xlabel)
    # Raw string: "\l" is not a valid escape sequence in a normal literal.
    if ylabel == r"\log_base":
        if log_base is None:
            plt.ylabel(r'$\log$loss')
        else:
            plt.ylabel(r'$\log_{' + str(log_base) + '}$loss')
    else:
        plt.ylabel(ylabel)
    return fig

def plot_metrics(axis, keys, colors, linestyles, results_folder, experiment_name, data_set, model_params, train_params, optim_params, log_base=np.exp(1), *args, **kwargs):
    """Plot stored metric histories of one experiment against training epoch.

    Loads the history for the given configuration via ``load_metric_history``
    and draws one line per entry of ``keys`` on ``axis``.

    Parameters
    ----------
    axis : object with a ``plot`` method
        Drawing target, e.g. a matplotlib ``Axes`` or the ``pyplot`` module.
    keys : sequence of str
        Metric names to look up in the loaded history.
    colors, linestyles : sequence
        Per-key line colour / line style, indexed in step with ``keys``.
    results_folder, experiment_name, data_set, model_params, train_params, optim_params
        Forwarded unchanged to ``load_metric_history``. ``train_params`` must
        also contain ``'num_epochs'``, which is used to build the x-axis.
    log_base : float
        Every metric value is divided by ``ln(log_base)``. Assuming the stored
        values are natural-log quantities, this re-expresses them in base
        ``log_base``; the default ``e`` leaves them unchanged.
    *args, **kwargs
        Extra positional / keyword arguments forwarded to ``axis.plot``.
    """
    if len(keys) == 0:
        # Nothing to draw, so skip loading the history entirely.
        return

    metrics = load_metric_history(experiment_name = experiment_name,
                                  data_set = data_set,
                                  model_params = model_params,
                                  train_params = train_params,
                                  optim_params = optim_params,
                                  results_folder = results_folder)
    num_epochs = train_params['num_epochs']
    scale = np.log(log_base)  # loop-invariant
    for i, key in enumerate(keys):
        values = np.asarray(metrics[key]) / scale
        num_evals = len(values)
        # Assumes evaluations are evenly spaced with the last one at epoch
        # `num_epochs`. Built per key so metrics recorded at different
        # frequencies each get an x-axis of matching length.
        epoch = np.arange(1, num_evals + 1) * num_epochs / num_evals
        axis.plot(epoch, values, *args, color = colors[i], linestyle = linestyles[i], **kwargs)
        

# Seed NumPy's legacy global RNG. Nothing visible in this cell draws random
# numbers, so this only matters for np.random.* calls made elsewhere.
np.random.seed(seed=111)

######################
## Interactive Plot ##
######################

def plot_results(hu, pp, bs, mc, yt, yl, yu):
    """Plot test log-loss of BBVI vs. Vadam for one hyper-parameter setting.

    Intended as the callback of ``interactive_output``: each argument is the
    current value of the widget with the same name.

    Parameters
    ----------
    hu : str
        Hidden-layer sizes as a Python list literal, e.g. ``'[400,400]'``.
    pp : str
        Prior precision as a numeric literal, e.g. ``'1e0'``. Also used as
        the optimiser's ``prec_init``.
    bs : str
        Batch size as an integer literal.
    mc : str
        Number of training Monte-Carlo samples as an integer literal.
    yt : str
        ``"True"`` to fix the y-axis to ``[yl, yu]``; any other value keeps
        matplotlib's automatic limits.
    yl, yu : float
        Lower / upper y-axis limits, used only when ``yt == "True"``.
    """
    # Widget values arrive as strings. literal_eval parses them to the same
    # Python values eval() would, without executing arbitrary code.
    hidden_sizes = ast.literal_eval(hu)
    prec = ast.literal_eval(pp)
    batch_size = ast.literal_eval(bs)
    train_mc_samples = ast.literal_eval(mc)
    # NOTE(review): epochs are tied to batch size here; this must match the
    # configuration the stored results were generated with, since
    # train_params is forwarded to load_metric_history — confirm.
    num_epochs = 2 * batch_size

    ####################
    ## Set parameters ##
    ####################

    results_folder = "./results/"
    data_set = "mnist"
    log_base = 10

    model_params = {'hidden_sizes': hidden_sizes,
                    'act_func': 'relu',
                    'prior_prec': prec}

    train_params = {'num_epochs': num_epochs,
                    'batch_size': batch_size,
                    'train_mc_samples': train_mc_samples,
                    'eval_mc_samples': 10,
                    'seed': 123}

    optim_params = {'learning_rate': 0.001,
                    'betas': (0.9,0.999),
                    'prec_init': prec}

    ##################
    ## Plot logloss ##
    ##################

    # (experiment name, legend label, line colour) per compared method; the
    # legend below is built from this table so its order matches the lines.
    methods = [("bbb_mlp_class", "BBVI", 'k'),
               ("vadam_mlp_class", "Vadam", 'r')]

    plot_init(fig_size = (12, 8), log_base = log_base)
    for experiment_name, _, color in methods:
        plot_metrics(plt,
                     keys = ['test_pred_logloss'],
                     log_base = log_base,
                     colors = [color],
                     linestyles = ['-'],
                     results_folder = results_folder,
                     experiment_name = experiment_name,
                     data_set = data_set,
                     model_params = model_params,
                     train_params = train_params,
                     optim_params = optim_params)
    if yt == "True":
        plt.ylim(yl, yu)
    plt.legend([label for _, label, _ in methods], loc='upper center')
    
from IPython.display import display

# Options shared by both SelectionSlider controls below.
_slider_kwargs = dict(disabled=False,
                      continuous_update=False,
                      orientation='horizontal',
                      readout=True)

# --- Experiment-configuration controls (values are strings parsed by plot_results) ---
hu = widgets.RadioButtons(options=['[400]', '[400,400]'],
                          description='Hidden Units:',
                          disabled=False)

pp = widgets.SelectionSlider(options=['1e-2', '2e-2', '5e-2',
                                      '1e-1', '2e-1', '5e-1',
                                      '1e0', '2e0', '5e0',
                                      '1e1', '2e1', '5e1',
                                      '1e2', '2e2', '5e2'],
                             value='1e0',
                             description='Prior Prec.:',
                             **_slider_kwargs)

bs = widgets.RadioButtons(options=['1', '10', '100'],
                          description='Batch Size:',
                          disabled=False)

mc = widgets.SelectionSlider(options=['1', '10'],
                             value='10',
                             description='MC Samples:',
                             **_slider_kwargs)

# --- Y-axis limit controls ---
yt = widgets.RadioButtons(options=['True', 'False'],
                          description='Fix ylim:',
                          disabled=False)
yl = widgets.FloatText(value=0.0, description='y_min:', disabled=False)
yu = widgets.FloatText(value=1.0, description='y_max:', disabled=False)

# Each key is a parameter name of plot_results; it is re-run on any change.
out = interactive_output(plot_results,
                         dict(hu=hu, pp=pp, bs=bs, mc=mc, yt=yt, yl=yl, yu=yu))

# Layout: y-limit row on top, configuration columns underneath.
hbox1 = HBox([yt, yl, yu])
vbox1, vbox2, vbox3 = VBox([hu]), VBox([bs]), VBox([pp, mc])
hbox2 = HBox([vbox1, vbox2, vbox3])
ui = VBox([hbox1, hbox2])

display(ui, out)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


VBox(children=(HBox(children=(RadioButtons(description='Fix ylim:', options=('True', 'False'), value='True'), …

Output()