In [None]:
import ipywidgets as widgets

# Avoid non-compliant Type 3 fonts
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42  # pylint: disable=wrong-import-position

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ipywidgets import interact
from IPython.display import display
from tqdm.notebook import tqdm

import utils

In [None]:
# Show every row of result tables and use large default figures
pd.options.display.max_rows = None
plt.rcParams.update({'figure.figsize': (12, 8)})

In [None]:
# Directories holding per-run configs (logs_dir) and per-run evaluation arrays (eval_dir)
logs_dir = utils.get_logs_dir()
eval_dir = utils.get_eval_dir()

# Environments offered in the show_curves() selector
env_names = [
    'small_empty', 'small_divider',
    'large_empty', 'large_doors', 'large_tunnels', 'large_rooms',
]

# Spacing, in simulation steps, of the grid that evaluation curves are resampled onto
step_size = 100

In [None]:
# Load the config of every run directory under logs_dir, in sorted (deterministic) order
cfgs = [
    utils.load_config(str(run_dir / 'config.yml'))
    for run_dir in tqdm(sorted(logs_dir.iterdir()))
    if run_dir.is_dir()
]

In [None]:
def extend_curves(curves, min_len=None):
    """Pad curves to a common length by repeating each curve's last value.

    The input list is NOT modified: callers pass lists stored in the global
    `all_curves` dict, and padding them in place would silently change global
    state for every later cell.

    Args:
        curves: list of 1-D array-likes (possibly of different lengths).
        min_len: optional minimum target length; the result is at least this
            long even if every curve is shorter.

    Returns:
        A new list of numpy arrays, all of length
        max(max(len(c) for c in curves), min_len). An empty input is returned as-is.
    """
    if len(curves) == 0:
        return curves
    target_len = max(len(curve) for curve in curves)
    if min_len is not None:
        target_len = max(target_len, min_len)
    # 'edge' mode repeats the final value, i.e. a finished curve stays at its last count
    return [np.pad(curve, (0, target_len - len(curve)), 'edge') for curve in curves]

In [None]:
def get_curve_for_run(cfg):
    """Return one run's mean cube-count curve, resampled every `step_size` simulation steps.

    Loads `eval_dir / '<run_name>.npy'`, iterated as a sequence of episodes whose
    per-step records expose 'cubes' and 'simulation_steps'. Each episode's cube
    count is linearly interpolated onto the grid 0, step_size, 2*step_size, ...
    (ending at the first grid point at or past the episode's last simulation step),
    with 0 before the first recorded step, then floored. The episode curves are
    padded to a common length and averaged point-wise.
    """
    run_eval_path = eval_dir / '{}.npy'.format(cfg.run_name)
    # allow_pickle is needed to load the stored per-step records; only use on trusted eval files
    episodes = np.load(run_eval_path, allow_pickle=True)

    def resample(episode):
        # Floor of the linearly interpolated cube count on the regular step grid
        counts = np.asarray([record['cubes'] for record in episode])
        sim_steps = np.array([record['simulation_steps'] for record in episode])
        grid = np.arange(0, sim_steps[-1] + step_size, step_size)
        return np.floor(np.interp(grid, sim_steps, counts, left=0))

    episode_curves = [resample(episode) for episode in episodes]
    return np.mean(extend_curves(episode_curves), axis=0)

In [None]:
def get_all_curves():
    """Group every run's mean curve by experiment.

    Returns:
        dict mapping cfg.experiment_name -> list of curves (one per run, in cfgs order).
    """
    curves_by_experiment = {}
    for cfg in tqdm(cfgs):
        curves_by_experiment.setdefault(cfg.experiment_name, []).append(get_curve_for_run(cfg))
    return curves_by_experiment

In [None]:
# experiment_name -> list of per-run mean curves (one array per run)
all_curves = get_all_curves()

In [None]:
def get_all_cutoffs():
    """Compute the comparison cutoff for every (robot config, environment) pair.

    For each experiment, the cutoff is the first curve index at which its mean
    curve (over runs) comes within 0.5 of its own final value, i.e. when its last
    object was collected. The cutoff of a (robot config, env) group is the
    minimum over its experiments. The robot config is the part of
    experiment_name before the first '-'.

    Returns:
        dict robot_config_str -> dict env_name -> cutoff index into the curves.
    """
    all_cutoffs = {}
    # The per-experiment cutoff depends only on experiment_name, so compute it once per experiment
    cutoff_by_experiment = {}
    for cfg in tqdm(cfgs):
        robot_config_str = cfg.experiment_name.split('-')[0]
        env_cutoffs = all_cutoffs.setdefault(robot_config_str, {})
        env_cutoffs.setdefault(cfg.env_name, float('inf'))

        if cfg.experiment_name not in cutoff_by_experiment:
            # Find the time at which the last cube was successfully foraged
            y_mean = np.mean(extend_curves(all_curves[cfg.experiment_name]), axis=0)
            # Subtract 0.5 since the mean of floored curves need not hit the final value exactly
            reached_final = y_mean > y_mean[-1] - 0.5
            # argmax gives the first True even if the mean curve is not monotonic
            # (searchsorted on this mask would require it to be sorted)
            cutoff_by_experiment[cfg.experiment_name] = int(np.argmax(reached_final))

        env_cutoffs[cfg.env_name] = min(env_cutoffs[cfg.env_name], cutoff_by_experiment[cfg.experiment_name])
    return all_cutoffs

In [None]:
# robot config (experiment_name prefix before the first '-') -> env_name -> cutoff index into the curves
all_cutoffs = get_all_cutoffs()

In [None]:
def get_all_results():
    """Summarize each experiment by its cube count at its group's cutoff.

    For every run config, reads the cutoff of its (robot config, env_name) group
    from all_cutoffs and takes each run curve's value at that index.

    Returns:
        dict robot_config_str -> dict experiment_name -> 'mean ± std' string
        (population std over runs, two decimals).
    """
    results_by_robot = {}
    for cfg in tqdm(cfgs):
        robot_key = cfg.experiment_name.split('-')[0]
        experiment_results = results_by_robot.setdefault(robot_key, {})
        cutoff = all_cutoffs[robot_key][cfg.env_name]
        # Make sure every curve reaches the cutoff index before sampling it
        padded = extend_curves(all_curves[cfg.experiment_name], min_len=(cutoff + 1))
        counts_at_cutoff = [curve[cutoff] for curve in padded]
        experiment_results[cfg.experiment_name] = '{:.2f} ± {:.2f}'.format(
            np.mean(counts_at_cutoff), np.std(counts_at_cutoff))
    return results_by_robot

In [None]:
# robot config -> experiment_name -> 'mean ± std' cube count at the group's cutoff
all_results = get_all_results()

In [None]:
def show_table():
    """Show an interactive table of all_results, one robot config at a time."""
    def render(robot_config_str):
        display(pd.DataFrame({'performance': all_results[robot_config_str]}))

    robot_config_radio = widgets.RadioButtons(options=sorted(all_results.keys()))
    interact(render, robot_config_str=robot_config_radio)

In [None]:
# Note: These metrics measure relative performance (objects collected at the shared cutoff of each
# robot config / environment group); see show_curves() for plots of absolute performance
show_table()

In [None]:
def show_curves():
    """Display an interactive plot of mean cube-count curves per experiment.

    A radio button selects the environment. The multi-select lists every
    experiment whose name contains the selected environment name. Each selected
    experiment is drawn as its mean curve over runs with a ±1 std band. The cutoff
    its robot config uses in get_all_results() is marked by a red vertical line.
    """
    def plot_curves(experiment_names, env_name, fontsize=20):
        """Plot the given experiments' curves and cutoffs on the current pyplot axes."""
        drawn_cutoffs = set()
        for experiment_name in experiment_names:
            robot_config_str = experiment_name.split('-')[0]

            # Plot cutoff; experiments of the same robot config share it, so draw each line once
            best_len = all_cutoffs[robot_config_str][env_name]
            if best_len not in drawn_cutoffs:
                drawn_cutoffs.add(best_len)
                plt.axvline(best_len * step_size, linewidth=1, c='r')

            # Plot curve (mean ± std over runs)
            curves = extend_curves(all_curves[experiment_name])
            # The -0.5 keeps the float arange endpoint exclusive, giving exactly len(curves[0]) points
            x = np.arange(0, (len(curves[0]) - 0.5) * step_size, step_size)
            y_mean, y_std = np.mean(curves, axis=0), np.std(curves, axis=0)
            label = '{} ({})'.format(experiment_name, all_results[robot_config_str][experiment_name])
            plt.plot(x, y_mean, label=label)
            plt.fill_between(x, y_mean - y_std, y_mean + y_std, alpha=0.2)

        num_cubes = 20 if env_name.startswith('large') else 10
        plt.xlim(0)
        plt.ylim(0, num_cubes)
        plt.xticks(fontsize=fontsize - 2)
        plt.yticks(range(0, num_cubes + 1, 2), fontsize=fontsize - 2)
        plt.xlabel('Simulation Steps', fontsize=fontsize)
        plt.ylabel('Num Objects', fontsize=fontsize)
        plt.legend(fontsize=fontsize - 2)

    def f(env_name, experiment_names):
        if len(experiment_names) == 0:
            return
        plot_curves(experiment_names, env_name)
        plt.show()

    env_name_radio = widgets.RadioButtons(options=env_names)
    experiment_names_select = widgets.SelectMultiple(layout=widgets.Layout(width='60%', height='150px'))

    def update_experiment_names_options(*_):
        """List the experiments matching the selected environment and clear the selection."""
        matching_experiment_names = [
            experiment_name for experiment_name in sorted(all_curves)
            if env_name_radio.value in experiment_name
        ]
        experiment_names_select.options = matching_experiment_names
        experiment_names_select.rows = len(matching_experiment_names)
        experiment_names_select.value = ()

    # React only to changes of the selected value; observing every trait also fires on index/label changes
    env_name_radio.observe(update_experiment_names_options, names='value')
    # observe() never fires on creation, so populate the list for the initially selected environment
    update_experiment_names_options()
    interact(f, env_name=env_name_radio, experiment_names=experiment_names_select)

In [None]:
# Note: The vertical red line marks the cutoff used for the relative-performance metrics in show_table();
# the curves themselves show absolute performance
show_curves()