# Evaluate trained policies at checkpoints and choose the best one

In [None]:
from pprint import pprint

import _pickle as pickle
import matplotlib.pyplot as plt
import notebook_utils
import numpy as np
import os

from jaxl.constants import *
from jaxl.utils import set_seed

In [None]:
# Seed handed to jaxl.utils.set_seed for this notebook session.
# NOTE(review): None presumably leaves the run unseeded (non-reproducible) —
# confirm against set_seed's implementation.
run_seed = None
set_seed(run_seed)

In [None]:
# Root directory that is walked recursively; every sub-directory containing a
# config.json is evaluated as a run.
runs_dir = "/Users/chanb/research/personal/jaxl/data/expert_models/pendulum_cont/runs"
# Destination of the pickled results; a falsy value skips saving.
save_path = "pendulum_cont.pkl"

# Evaluation settings forwarded to
# notebook_utils.get_episodic_returns_per_checkpoint.
num_episodes = 10
env_seed = 9999
record_video = False

In [None]:
# Evaluate every run found under runs_dir and collect, per variant, whatever
# notebook_utils.get_episodic_returns_per_checkpoint returns for it.
results = {}

run_i = 0
for run_path, _, filenames in os.walk(runs_dir):
    # Only directories that hold a config.json are treated as runs.
    if "config.json" not in filenames:
        continue

    # Lightweight progress report every 10 evaluated runs.
    run_i += 1
    if run_i % 10 == 0:
        print(f"Processed {run_i} variants")

    agent_path = run_path
    # The variant name is the second "-"-separated token of the name of the
    # directory that contains the run directory.
    parent_dir_name = os.path.split(os.path.dirname(agent_path))[1]
    variant = parent_dir_name.split("-")[1]

    results[variant] = notebook_utils.get_episodic_returns_per_checkpoint(
        agent_path, variant, env_seed, num_episodes, None, record_video, False
    )

# Persist the collected results unless saving is disabled (falsy save_path).
if save_path:
    with open(save_path, "wb") as f:
        pickle.dump(results, f)

In [None]:
# Plot the mean episodic return per model id, with a band of one standard
# deviation around the mean.
fig, ax = plt.subplots(1, figsize=(10, 5))

# NOTE(review): `episodic_returns_per_variant` is not defined in any cell
# visible above (the evaluation cell stores its output in `results`) — confirm
# where this mapping is supposed to come from before running this cell.
model_ids = np.array(list(episodic_returns_per_variant.keys()))
means = np.array([np.mean(val) for val in episodic_returns_per_variant.values()])
stds = np.array([np.std(val) for val in episodic_returns_per_variant.values()])

# Order the points by model id so the line is drawn left to right.
# NOTE(review): if the keys are strings, np.argsort orders them
# lexicographically ("10" before "2") — confirm the keys are numeric.
sort_idxes = np.argsort(model_ids)
sorted_ids = model_ids[sort_idxes]
sorted_means = means[sort_idxes]
sorted_stds = stds[sort_idxes]

# Both artists carry a label; without labels ax.legend() has nothing to show
# and only emits a warning.
ax.plot(sorted_ids, sorted_means, marker="x", label="Mean return")
ax.fill_between(
    sorted_ids,
    sorted_means + sorted_stds,
    sorted_means - sorted_stds,
    alpha=0.1,
    label="+/- 1 std",
)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Iteration")
ax.set_ylabel("Return")
ax.legend()
fig.show()