In [None]:
import sys
from pathlib import Path
# Make the parent of the working directory importable (presumably the project
# root, so the `from src.rl_utils import ...` in the next cell resolves).
# Guarded so that re-running this cell does not append duplicate entries.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from copy import deepcopy
import numpy as np
from collections import defaultdict
from src.rl_utils import get_project_folder
# Global plotting theme: paper-sized context, white grid, colour-blind-safe palette.
sns.set_theme(
    context="paper",
    style="whitegrid",
    palette="colorblind",
    font_scale=1.3,
)

In [None]:
# Execute eval_focus.py in this notebook's namespace (IPython `%run`), passing
# `--mode blur`. Later cells read `INFO[env_id]` (per-method error buckets) and
# `INFO['mode']`; INFO is presumably defined by this script — confirm there.
%run eval_focus.py --mode blur

In [None]:
def plot_ready_data(errors, percentile_step=10):
    """Flatten per-percentile error buckets into a long-form DataFrame for seaborn.

    Parameters
    ----------
    errors : dict
        Maps a method key (e.g. ``"agent"``, ``"moe"``) to a sequence of
        array-likes of errors, one per percentile bucket. Bucket ``i`` is
        labelled with percentile ``i * percentile_step``. Each method's own
        bucket sizes are used, so methods need not share sample counts.
    percentile_step : int, default 10
        Spacing between consecutive percentile labels (with eleven buckets the
        default yields 0, 10, ..., 100).

    Returns
    -------
    pd.DataFrame
        Long-form frame with columns ``Error``, ``Method`` and ``Percentile``,
        one row per individual error value. ``Method`` holds a display title,
        falling back to the raw key for methods without one. Methods with no
        buckets contribute no rows.
    """
    title_dict = {"agent": "Black box (PPO)", "pwnet": "PW-Net", "moe": "Ours",
                  "viper": "VIPER with Concepts", "sdt": "SDT"}

    plot_dict = {"Error": [], "Method": [], "Percentile": []}
    for key, buckets in errors.items():
        if len(buckets) == 0:
            continue  # np.concatenate rejects an empty sequence
        flat_errors = np.concatenate(buckets).tolist()
        # Label each error with the percentile of *this* method's bucket, so
        # labels stay aligned even when methods have different sample counts.
        percentiles = [percentile_step * i
                       for i, bucket in enumerate(buckets)
                       for _ in range(len(bucket))]
        plot_dict["Error"].extend(flat_errors)
        plot_dict["Method"].extend([title_dict.get(key, key)] * len(flat_errors))
        plot_dict["Percentile"].extend(percentiles)

    return pd.DataFrame(plot_dict)


# One long-form plotting frame per environment, keyed by environment id.
# INFO is presumably populated by the `%run eval_focus.py` cell above.
dfs = {
    env_id: plot_ready_data(INFO[env_id])
    for env_id in ("CarRacing-v2", "PongNoFrameskip-v4",
                   "BreakoutNoFrameskip-v4", "MsPacmanNoFrameskip-v4")
}

In [None]:
# Four-panel line plot of Error vs. Percentile (one panel per environment,
# hue = method) with a single shared legend; saved as a PDF named after INFO['mode'].
fig, ax = plt.subplots(layout="constrained", figsize=(3 * 4, 4), ncols=4)

title = {"CarRacing-v2": "Car Racing", "PongNoFrameskip-v4": "Pong",
         "BreakoutNoFrameskip-v4": "Breakout", "MsPacmanNoFrameskip-v4": "MsPacman"}

for i, (env_id, env_title) in enumerate(title.items()):
    # legend=True so the plotted artists carry method labels for the shared
    # figure legend below; the per-axes legend itself is then removed.
    sns.lineplot(dfs[env_id], x="Percentile", y="Error", hue="Method", ax=ax[i],
                 legend=True)
    ax[i].get_legend().remove()
    ax[i].set_xlabel("")  # replaced by one shared x label (fig.text below)
    if i > 0:
        ax[i].set_ylabel("")  # only the leftmost panel keeps its y label
    ax[i].set_title(env_title)

fig.text(0.5, -0.04, 'Percentile', ha='center')
handles, labels = ax[0].get_legend_handles_labels()
# A 2-tuple `loc` is the legend's lower-left corner in figure coordinates;
# these values look hand-tuned to place it over the first panel.
fig.legend(handles, labels, loc=(0.068, 0.64))

save_folder = get_project_folder() / "experiment-data/input-fidelity/"
save_folder.mkdir(parents=True, exist_ok=True)
fig.savefig(save_folder / f"{INFO['mode']}__input-fidelity.pdf",
            bbox_inches='tight', pad_inches=0)