# Visualize Trained Agents

In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict

In [None]:
# Global seed for this notebook run, passed straight to jaxl's set_seed.
# NOTE(review): how set_seed treats None is defined in jaxl.utils (not shown
# here) -- set an int for a reproducible run; confirm None's behavior there.
run_seed = None
set_seed(run_seed)

In [None]:
def get_trained_parameters(rl_trained_path):
    """Return the environment keyword arguments a run was trained with.

    Reads ``config.json`` inside ``rl_trained_path`` and returns the
    ``["learner_config"]["env_config"]["env_kwargs"]`` entry unchanged.

    Raises:
        FileNotFoundError: if ``config.json`` does not exist.
        KeyError: if any of the nested keys is missing.
    """
    config_file = os.path.join(rl_trained_path, "config.json")
    with open(config_file, "r") as handle:
        run_config = json.load(handle)
    env_config = run_config["learner_config"]["env_config"]
    return env_config["env_kwargs"]

In [None]:
# Run directory of the trained expert; later cells read its config.json and
# its termination_model checkpoint.
rl_trained_path = "/mnt/HDD/research/mtil/inverted_double_pendulum/expert_models/gravity/runs/0/gravity_-8.249612491943623/06-09-23_15_21_56-296b3f54-5c33-43f3-97dd-3b7eb184bc99"
trained_env_parameters = get_trained_parameters(rl_trained_path)

num_episodes = 100  # episodes evaluated per gravity setting
env_seed = 9999  # seed handed to EvaluationRollout
buffer_size = 0  # not referenced elsewhere in the visible notebook
record_video = False  # wrap each environment in RecordVideoV0 when True

# Gravity settings to evaluate: the value the agent was trained with first,
# followed by a sweep of alternative values.
gravity_sweep = [-0.01, -1.0, -5.0, -9.81, -13.0, -15.0, -30.0]
gravities = [trained_env_parameters["gravity"]] + gravity_sweep

In [None]:
def get_config(rl_trained_path, gravity):
    """Load a run's training config, adapted for evaluation.

    The environment is switched to ``render_mode="rgb_array"``. Unless
    ``gravity`` is the string ``"default"``, the env ``gravity`` kwarg is
    overwritten with ``gravity``. The patched dict is converted with
    ``parse_dict`` and returned.
    """
    config_file = os.path.join(rl_trained_path, "config.json")
    with open(config_file, "r") as handle:
        raw_config = json.load(handle)

    env_kwargs = raw_config["learner_config"]["env_config"]["env_kwargs"]
    env_kwargs["render_mode"] = "rgb_array"
    # "default" keeps whatever gravity the run was trained with.
    if gravity != "default":
        env_kwargs["gravity"] = gravity
    return parse_dict(raw_config)

In [None]:
# Evaluate the trained policy once per gravity setting and collect the
# per-episode returns, keyed by gravity.
episodic_returns_per_variant = {}

# The policy checkpoint does not depend on gravity, so restore it once
# instead of once per variant.
rl_model_path = os.path.join(rl_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(rl_model_path)
rl_policy_params = model_dict[CONST_MODEL][CONST_POLICY]

for gravity in gravities:
    rl_config = get_config(rl_trained_path, gravity)
    env = get_environment(rl_config.learner_config.env_config)

    if record_video:
        # To record every episode, also pass episode_trigger=lambda x: True.
        env = RecordVideoV0(env, f"gravity_{gravity}-videos")

    try:
        input_dim = env.observation_space.shape
        output_dim = policy_output_dim(env.action_space.shape, rl_config.learner_config)
        model = get_model(input_dim, output_dim, rl_config.model_config.policy)
        policy = get_policy(model, rl_config.learner_config)

        # Reload the saved observation-normalization statistics for each
        # variant so every rollout starts from the same values.
        # NOTE(review): only needed if EvaluationRollout mutates obs_rms in
        # place -- confirm; otherwise this load can be hoisted as well.
        # pickle.load can execute arbitrary code: only load trusted runs.
        with open(os.path.join(rl_model_path, "learner_dict.pkl"), "rb") as f:
            learner_dict = pickle.load(f)
        rl_obs_rms = learner_dict[CONST_OBS_RMS]

        rl_rollout = EvaluationRollout(env, seed=env_seed)
        rl_rollout.rollout(rl_policy_params, policy, rl_obs_rms, num_episodes, None)
    finally:
        # Always release the environment (including the RecordVideoV0
        # wrapper when recording), even if building or rolling out fails.
        env.close()

    episodic_returns_per_variant[gravity] = rl_rollout.episodic_returns

In [None]:
# Box plot of the per-episode returns for every gravity setting, in the
# order the settings were evaluated.
fig, ax = plt.subplots(1, figsize=(10, 5))
ax.boxplot(list(episodic_returns_per_variant.values()))
ax.set_xticks(
    range(1, len(episodic_returns_per_variant) + 1),
    # Only a missing gravity (None) is labelled "default"; a truthiness test
    # would also mislabel a legitimate gravity of 0.0.
    ["default" if val is None else val for val in episodic_returns_per_variant],
)
# Title follows the configured episode count instead of a hard-coded number.
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Gravity")
ax.set_ylabel("Return")
fig.show()

In [None]:
from pprint import pprint

# One (gravity, mean return, std of return) tuple per evaluated setting.
return_summary = [
    (variant, np.mean(returns), np.std(returns))
    for variant, returns in episodic_returns_per_variant.items()
]
pprint(return_summary)

In [None]:
# Mean episodic return (line) +/- one standard deviation (shaded band) as a
# function of gravity.
fig, ax = plt.subplots(1, figsize=(10, 5))

# Take the x-values from the result dict itself rather than from `gravities`:
# duplicate gravities collapse into a single dict entry, so indexing the
# per-variant stats with positions from `gravities` could misalign or fail.
# This also avoids rebinding the global `gravities` to an ndarray.
variant_gravities = np.array(list(episodic_returns_per_variant.keys()), dtype=float)
means = np.array([np.mean(val) for val in episodic_returns_per_variant.values()])
stds = np.array([np.std(val) for val in episodic_returns_per_variant.values()])

# Sort by gravity so the line and band are drawn left to right.
sort_idxes = np.argsort(variant_gravities)
ax.plot(variant_gravities[sort_idxes], means[sort_idxes], marker="x")
ax.fill_between(
    variant_gravities[sort_idxes],
    means[sort_idxes] + stds[sort_idxes],
    means[sort_idxes] - stds[sort_idxes],
    alpha=0.1,
)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Gravity")
ax.set_ylabel("Return")
fig.show()