# Visualize Trained Agents

In [None]:
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict

In [None]:
# Seed handed to jaxl.utils.set_seed before anything else runs. Replace None
# with an int to pin it.
# NOTE(review): how set_seed treats None is defined in jaxl.utils — confirm.
run_seed = None
set_seed(run_seed)

In [None]:
# Directory of the trained run to visualize; it must contain config.json and
# the termination_model/ checkpoint read by the cells below. Set the
# JAXL_RL_TRAINED_PATH environment variable to point at another run without
# editing the notebook; otherwise the original default path is used.
rl_trained_path = os.environ.get(
    "JAXL_RL_TRAINED_PATH",
    "/Users/chanb/research/personal/jaxl/jaxl/logs/inverted_double_pendulum/ppo/06-07-23_13_30_37-77f2ee17-4bad-45aa-9446-6dbe1964dbbe",
)

num_episodes = 10  # number of episodes passed to the evaluation rollout
env_seed = 9999  # seed given to EvaluationRollout
# NOTE(review): buffer_size is not referenced by any visible cell.
buffer_size = 0

In [None]:
# Load the run's JSON config, switch the environment to "rgb_array" rendering
# (the frame-returning render mode gymnasium's video recorder consumes), and
# parse it into the config object used to rebuild the env and policy.
rl_config_path = os.path.join(rl_trained_path, "config.json")
with open(rl_config_path, "r", encoding="utf-8") as f:
    rl_config_dict = json.load(f)

# The file handle is no longer needed; mutate and parse after closing it.
rl_config_dict["learner_config"]["env_config"]["env_kwargs"][
    "render_mode"
] = "rgb_array"
rl_config = parse_dict(rl_config_dict)

In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0

In [None]:
# Rebuild the trained policy, restore its parameters and observation
# normalization statistics, and run evaluation episodes in a video-recording
# environment.

# Hidden-state dimension from the model config, defaulting to (1,).
# NOTE(review): not used in this cell; kept for parity with the original.
h_state_dim = getattr(rl_config.model_config, "h_state_dim", (1,))

# Output directory for recorded clips (relative to the working directory).
video_dir = "videos"

env = get_environment(rl_config.learner_config.env_config)
env = RecordVideoV0(env, video_dir)
try:
    input_dim = env.observation_space.shape
    output_dim = policy_output_dim(env.action_space.shape, rl_config.learner_config)
    model = get_model(input_dim, output_dim, rl_config.model_config.policy)
    policy = get_policy(model, rl_config.learner_config)

    rl_model_path = os.path.join(rl_trained_path, "termination_model")
    checkpointer = PyTreeCheckpointer()
    model_dict = checkpointer.restore(rl_model_path)
    rl_policy_params = model_dict[CONST_MODEL][CONST_POLICY]

    # SECURITY: unpickling can execute arbitrary code embedded in the file;
    # only point rl_trained_path at checkpoints you produced or trust.
    with open(os.path.join(rl_model_path, "learner_dict.pkl"), "rb") as f:
        learner_dict = pickle.load(f)
    rl_obs_rms = learner_dict[CONST_OBS_RMS]

    rl_rollout = EvaluationRollout(env, seed=env_seed)
    # NOTE(review): the trailing positional None is kept from the original;
    # presumably an optional buffer argument — confirm in EvaluationRollout.
    rl_rollout.rollout(rl_policy_params, policy, rl_obs_rms, num_episodes, None)
finally:
    # RecordVideoV0 finalizes an in-progress recording only on the next reset
    # or on close(); close explicitly so the last clip is written to disk and
    # the environment's resources are released even if a step above fails.
    env.close()