# Evaluate trained policies at checkpoints and choose the best one

In [None]:
from pprint import pprint

import _pickle as pickle
import math
import matplotlib
%matplotlib widget
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import notebook_utils
import numpy as np
import os

from jaxl.constants import *
from jaxl.utils import set_seed

In [None]:
# Global seed handed to jaxl's set_seed. NOTE(review): with None this presumably
# leaves RNG state unseeded / nondeterministic — confirm in jaxl.utils.set_seed.
run_seed = None
set_seed(run_seed)

In [None]:
# Root directory that is walked for trained runs; a run is any directory holding a config.json.
runs_dir = "/Users/chanb/research/personal/jaxl/data/expert_models/pendulum_disc/runs"
# Pickle cache for the evaluation results; an empty string disables writing the cache.
save_path = "pendulum_disc.pkl"

# Evaluation settings forwarded to notebook_utils.get_episodic_returns_per_checkpoint.
env_seed = 9999
num_episodes = 10
record_video = False

In [None]:
# Evaluate every run under `runs_dir` (or reuse the cached results in `save_path`).
# `results` maps variant name -> whatever get_episodic_returns_per_checkpoint returns;
# later cells read its CONST_ENV_CONFIG and CONST_EPISODIC_RETURNS entries.
results = {}

if save_path and os.path.isfile(save_path):
    # NOTE: unpickling can execute arbitrary code — only load cache files you created.
    with open(save_path, "rb") as f:
        results = pickle.load(f)
else:
    run_i = 0
    for run_path, _, filenames in os.walk(runs_dir):
        # A run directory is recognised by the presence of its config.json.
        if "config.json" not in filenames:
            continue
        agent_path = run_path
        # assumes the parent directory is named "<prefix>-<variant>[-...]" — TODO confirm
        variant = os.path.basename(os.path.dirname(agent_path)).split("-")[1]
        results[variant] = notebook_utils.get_episodic_returns_per_checkpoint(
            agent_path,
            variant,
            env_seed,
            num_episodes,
            None,
            record_video,
            False,
        )
        # Count only after the variant has actually been evaluated.
        run_i += 1
        if run_i % 10 == 0:
            print(f"Processed {run_i} variants")

    if save_path:
        with open(save_path, "wb") as f:
            pickle.dump(results, f)

In [None]:
# Variant names are numeric strings; sort them numerically for plotting order.
sorted_variants = np.sort(list(map(int, results)))

In [None]:
# Collect, per modified environment attribute, its value in every variant, plus
# each variant's best checkpoint return (in the same iteration order).
# NOTE(review): the lists only stay index-aligned with best_returns if every
# variant modifies the same set of attributes — confirm for this sweep.
modified_attributes = {}
best_returns = []

for variant_result in results.values():
    env_attrs = variant_result[CONST_ENV_CONFIG]["modified_attributes"]
    for attr_name, attr_value in env_attrs.items():
        modified_attributes.setdefault(attr_name, []).append(attr_value)

    # Mean over the last axis (presumably episodes) per checkpoint, then take the best checkpoint.
    checkpoint_means = np.mean(
        list(variant_result[CONST_EPISODIC_RETURNS].values()), axis=-1
    )
    best_returns.append(np.max(checkpoint_means))


In [None]:
# One subplot per modified attribute: best return versus attribute value,
# laid out on a grid with three columns.
n_cols = 3
# At least one row so plt.subplots never receives 0 (which raises).
n_rows = max(1, math.ceil(len(modified_attributes) / n_cols))
# squeeze=False always returns a 2-D axes array, so a single row needs no special case.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(17.5, (5 * n_rows)), squeeze=False)
best_returns_arr = np.array(best_returns)

for attr_i, (attr, attr_vals) in enumerate(modified_attributes.items()):
    ax = axes[attr_i // n_cols, attr_i % n_cols]
    # Sort by attribute value so the line is drawn left to right.
    sort_idxes = np.argsort(attr_vals)
    ax.plot(np.array(attr_vals)[sort_idxes], best_returns_arr[sort_idxes])
    ax.set_title(f"Attribute: {attr}")

# Hide trailing grid cells that have no attribute to show.
for ax in axes.flat[len(modified_attributes):]:
    ax.set_visible(False)

fig.supylabel("Best Expected Returns")
fig.tight_layout()

In [None]:
# Display how many environment attributes vary across the evaluated variants;
# the 3-D scatter cell only draws when this is exactly 3.
len(modified_attributes)

In [None]:
# When exactly three attributes vary, show every variant as a point in attribute
# space, coloured by its best return.
if len(modified_attributes) == 3:
    (x_label, y_label, z_label) = list(modified_attributes.keys())

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    c_map = plt.get_cmap("jet")
    c_norm = matplotlib.colors.Normalize(vmin=min(best_returns), vmax=max(best_returns))

    # Apply the explicit norm to the points and reuse the scatter as the colorbar's
    # mappable, so point colours and colorbar scale are guaranteed to match.
    scatter = ax.scatter(
        *list(modified_attributes.values()), c=best_returns, cmap=c_map, norm=c_norm
    )
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_zlabel(z_label)
    fig.colorbar(scatter, ax=ax, location="top")

In [None]:
# For each variant, plot the mean return at every checkpoint with a ±1 std band.
for variant in sorted_variants:
    returns_by_checkpoint = results[f"{variant}"][CONST_EPISODIC_RETURNS]
    fig, ax = plt.subplots(1, figsize=(10, 5))

    # Per-checkpoint statistics over the evaluation episodes.
    checkpoint_ids = np.array(list(returns_by_checkpoint.keys()))
    mean_returns = np.array(
        [np.mean(ep_returns) for ep_returns in returns_by_checkpoint.values()]
    )
    std_returns = np.array(
        [np.std(ep_returns) for ep_returns in returns_by_checkpoint.values()]
    )

    # Draw in ascending checkpoint order.
    order = np.argsort(checkpoint_ids)
    iterations = checkpoint_ids[order]
    mean_curve = mean_returns[order]
    std_curve = std_returns[order]

    ax.plot(iterations, mean_curve, marker="x")
    ax.fill_between(
        iterations,
        mean_curve + std_curve,
        mean_curve - std_curve,
        alpha=0.1,
    )
    ax.set_title(f"Returns Across {num_episodes} Episodes - Variant: {variant}")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Return")
    fig.show()