# Evaluate trained policies at their final (termination) checkpoint

In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from pprint import pprint

import matplotlib.pyplot as plt
import notebook_utils
import numpy as np
import os

from jaxl.constants import *
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed

In [None]:
# Seed handed to jaxl's `set_seed` for this notebook run.
# NOTE(review): None presumably leaves the run unseeded -- confirm against
# `jaxl.utils.set_seed`.
run_seed = None
set_seed(run_seed)

In [None]:
# Root directory that is walked recursively to find trained runs.
# Alternative experiment directories are kept here for quick switching:
# runs_dir = "/Users/chanb/research/personal/jaxl/data/hyperparam_search/hopper_cont/search_expert/runs"
# runs_dir = "/Users/chanb/research/personal/jaxl/data/hyperparam_search/hopper_disc/search_expert/runs"
# runs_dir = "/Users/chanb/research/personal/jaxl/data/hyperparam_search/pendulum_disc/search_expert/runs/"
runs_dir = "/Users/chanb/research/personal/jaxl/data/hyperparam_search/cheetah_disc/search_expert/runs/1"

# Number of evaluation episodes rolled out for each agent.
num_episodes = 10
# Seed passed to the evaluation rollout.
env_seed = 9999
# When True, each environment is wrapped in a video recorder.
record_video = True

In [None]:
# Evaluate the saved "termination_model" of every run found under `runs_dir`
# and collect the episodic returns, keyed by "<variant>-final".
episodic_returns_per_variant = {}

for run_path, _, filenames in os.walk(runs_dir):
    # A directory counts as a run only if it holds both a config file and a
    # saved termination model.
    if "config.json" not in filenames:
        continue
    termination_model_path = os.path.join(run_path, "termination_model")
    if not os.path.isdir(termination_model_path):
        continue

    agent_config, aux = notebook_utils.get_config(run_path)
    env = get_environment(agent_config.learner_config.env_config)

    # The variant name is the directory that contains the run directory.
    variant = os.path.basename(os.path.dirname(run_path))
    model_id = f"{variant}-final"

    try:
        _, policy = notebook_utils.get_agent(env, agent_config, aux)
        agent_policy_params, agent_obs_rms = notebook_utils.restore_agent_state(
            termination_model_path
        )

        if record_video:
            env = RecordVideoV0(
                env, f"videos/model_id_{model_id}-videos", disable_logger=True
            )

        agent_rollout = EvaluationRollout(env, seed=env_seed)
        agent_rollout.rollout(
            agent_policy_params, policy, agent_obs_rms, num_episodes, None, False
        )

        episodic_returns_per_variant[model_id] = agent_rollout.episodic_returns
        print(env.get_config()["modified_attributes"])
    finally:
        # Always release the environment: closing lets the video recorder
        # finalize its last recording and avoids leaking one environment per
        # evaluated run.
        env.close()

In [None]:
# Summarize each evaluated model as (model id, mean return, std of returns).
return_summary = [
    (variant_key, np.mean(returns), np.std(returns))
    for variant_key, returns in episodic_returns_per_variant.items()
]
pprint(return_summary)

In [None]:
# Plot the mean return of every evaluated variant, with a shaded band of one
# standard deviation on either side.
fig, ax = plt.subplots(1, figsize=(10, 5))

# The second "-"-separated token of each key is parsed as the integer id used
# on the x-axis.
# NOTE(review): this assumes keys of the form "<prefix>-<int>-..." -- confirm
# against the variant directory names.
model_ids = np.array(
    [int(key.split("-")[1]) for key in episodic_returns_per_variant]
)
means = np.array([np.mean(val) for val in episodic_returns_per_variant.values()])
stds = np.array([np.std(val) for val in episodic_returns_per_variant.values()])

# Order the points by id so the line is drawn left to right.
sort_idxes = np.argsort(model_ids)
sorted_ids = model_ids[sort_idxes]
sorted_means = means[sort_idxes]
sorted_stds = stds[sort_idxes]

# Label both artists so that `ax.legend()` has entries to show; without labels
# the legend is empty and matplotlib emits a warning.
ax.plot(sorted_ids, sorted_means, marker="x", label="Mean return")
ax.fill_between(
    sorted_ids,
    sorted_means + sorted_stds,
    sorted_means - sorted_stds,
    alpha=0.1,
    label="+/- 1 std",
)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Variant")
ax.set_ylabel("Return")
ax.legend()
fig.show()