# Evaluate how robust the trained policy is to environment variations

In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer
from pprint import pprint

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict, set_dict_value, get_dict_value

In [None]:
# Seed handed to ``set_seed`` for global seeding of this notebook.
# NOTE(review): with ``None`` the session is presumably left unseeded, so the random
# environment seeds drawn later via ``np.random.randint`` would differ between runs —
# confirm against ``jaxl.utils.set_seed``.
run_seed = None
set_seed(run_seed)

In [None]:
def get_env(agent_path):
    """Return the environment config a trained run was launched with.

    :param agent_path: directory of the run; must contain ``config.json``.
    :return: ``config["learner_config"]["env_config"]`` read from that file.
    :raises FileNotFoundError: if ``config.json`` does not exist.
    :raises KeyError: if the config lacks ``learner_config`` or ``env_config``.
    """
    agent_config_path = os.path.join(agent_path, "config.json")
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # locale encoding (e.g. cp1252 on Windows).
    with open(agent_config_path, "r", encoding="utf-8") as f:
        agent_config_dict = json.load(f)
    return agent_config_dict["learner_config"]["env_config"]

In [None]:
# Run whose training-time environment parameters define the reference variant.
agent_to_load_env_path = "/Users/chanb/research/personal/jaxl/jaxl/logs/hopper-default/ppo-discrete_control/07-10-23_07_58_53-7d0e3ac1-196d-43a6-bd5e-17ec441700e1"
trained_env_parameters = get_env(agent_to_load_env_path)

# Run whose policy is evaluated — here the very same run as above.
agent_path = agent_to_load_env_path

# Evaluation settings.
record_video = True
buffer_size = 0
env_seed = 9999
num_episodes = 10

In [None]:
def get_config(agent_path, seed, use_default=False):
    """Build the evaluation config of ``agent_path`` for one environment seed.

    The run's ``config.json`` is loaded and its ``env_config`` is replaced by a
    *copy* of the module-level ``trained_env_parameters``. The copy is then set up
    for evaluation: ``render_mode="rgb_array"``, the requested ``seed`` and the
    ``use_default`` flag. If the learner config names no ``policy_distribution``,
    ``CONST_DETERMINISTIC`` is used. ``vmap_all`` is set to ``False``.

    :param agent_path: directory of the trained run containing ``config.json``.
    :param seed: value written to ``env_kwargs["seed"]``.
    :param use_default: value written to ``env_kwargs["use_default"]``.
    :return: ``(agent_config, aux)``. ``agent_config`` is the result of
        ``parse_dict``. ``aux`` holds the ``multitask`` flag and ``num_models``
        as returned by ``get_dict_value(..., "num_models")``.
    """
    import copy

    agent_config_path = os.path.join(agent_path, "config.json")
    with open(agent_config_path, "r", encoding="utf-8") as f:
        agent_config_dict = json.load(f)

    # Deep-copy so the per-call edits below never leak into the shared
    # ``trained_env_parameters``. Mutating it in place would, for example, make a
    # re-run of the evaluation cell read the last random seed as the default seed.
    env_config = copy.deepcopy(trained_env_parameters)
    env_kwargs = env_config["env_kwargs"]
    env_kwargs["render_mode"] = "rgb_array"
    env_kwargs["seed"] = seed
    env_kwargs["use_default"] = use_default

    learner_config = agent_config_dict["learner_config"]
    learner_config["env_config"] = env_config
    # Fall back to a deterministic policy when the config does not name one.
    learner_config.setdefault("policy_distribution", CONST_DETERMINISTIC)

    set_dict_value(agent_config_dict, "vmap_all", False)
    (multitask, num_models) = get_dict_value(agent_config_dict, "num_models")
    agent_config = parse_dict(agent_config_dict)
    return agent_config, {
        "multitask": multitask,
        "num_models": num_models,
    }

In [None]:
# Evaluate the trained policy on the reference environment and on
# ``num_env_variants`` variants drawn from random seeds. Index 0 is the reference:
# the trained seed with ``use_default=True``. Returns and environment configs are
# keyed by log2(seed).
episodic_returns_per_variant = {}
env_configs = {}
checkpointer = PyTreeCheckpointer()

num_env_variants = 10
default_seed = trained_env_parameters["env_kwargs"]["seed"]

# ``dtype=np.int64``: the bound 2**32 - 1 overflows the default integer type where
# that is 32-bit (e.g. Windows with NumPy < 2), and ``randint`` would raise there.
# ``tolist()`` turns the draws into plain Python ints before they go into the
# environment config.
seeds = [
    default_seed,
    *np.random.randint(0, 2**32 - 1, num_env_variants, dtype=np.int64).tolist(),
]
for env_i, seed in enumerate(seeds):
    variant_key = np.log2(seed)
    agent_config, aux = get_config(agent_path, seed, use_default=(env_i == 0))
    env = get_environment(agent_config.learner_config.env_config)
    env_configs[variant_key] = (
        env.get_config() if hasattr(env, "get_config") else None
    )

    if env_i == 0:
        # The same policy and parameters are used for every variant, so build and
        # restore them once. Dimensions are read from the environment before any
        # video wrapper is applied, so no attribute forwarding is needed.
        input_dim = env.observation_space.shape
        output_dim = policy_output_dim(env.act_dim, agent_config.learner_config)
        model = get_model(
            input_dim,
            output_dim,
            getattr(agent_config.model_config, "policy", agent_config.model_config),
        )
        policy = get_policy(model, agent_config.learner_config)
        if aux["multitask"]:
            policy = MultitaskPolicy(policy, model, aux["num_models"])

        agent_model_path = os.path.join(agent_path, "termination_model")
        model_dict = checkpointer.restore(agent_model_path)
        agent_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
        with open(os.path.join(agent_model_path, "learner_dict.pkl"), "rb") as f:
            learner_dict = pickle.load(f)
        agent_obs_rms = learner_dict[CONST_OBS_RMS]

    if record_video:
        env = RecordVideoV0(env, f"env_seed_{seed}-videos")

    agent_rollout = EvaluationRollout(env, seed=env_seed)
    agent_rollout.rollout(
        agent_policy_params, policy, agent_obs_rms, num_episodes, None
    )
    # Close each variant's environment once its rollouts finish. This releases
    # rendering resources and lets RecordVideoV0 finalise a recording that is still
    # in progress.
    env.close()

    episodic_returns_per_variant[variant_key] = agent_rollout.episodic_returns

In [None]:
# Show which environment attributes each variant modified, keyed by log2(seed).
# Variants whose environment has no ``get_config`` are stored as ``None`` in
# ``env_configs``; print ``None`` for them instead of raising a TypeError.
pprint(
    {
        log_seed: (
            None if env_config is None else env_config["modified_attributes"]
        )
        for log_seed, env_config in env_configs.items()
    }
)

In [None]:
# Box plot of the per-episode returns: one box per environment variant, labelled by
# log2(seed) in insertion order (reference variant first).
fig, ax = plt.subplots(1, figsize=(10, 5))
# ``Axes.boxplot`` documents its input as an array or a sequence of vectors; a
# ``dict_values`` view is neither (it has no indexing), so materialise it.
ax.boxplot(list(episodic_returns_per_variant.values()))
ax.set_xticks(
    range(1, len(episodic_returns_per_variant) + 1),
    [f"{log_seed:.2f}" for log_seed in episodic_returns_per_variant],
)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Seed in Log-scale")
ax.set_ylabel("Return")
fig.show()

In [None]:
# Per-variant summary as (log2 seed, mean return, std of returns), in insertion order.
pprint(
    list(
        zip(
            episodic_returns_per_variant.keys(),
            map(np.mean, episodic_returns_per_variant.values()),
            map(np.std, episodic_returns_per_variant.values()),
        )
    )
)

In [None]:
# Mean return per environment variant, with a band of ± one standard deviation
# across episodes, plotted against log2(seed) in increasing order. The dashed
# vertical line marks the first key of the results dict, i.e. the variant run with
# the trained environment parameters.
fig, ax = plt.subplots(1, figsize=(10, 5))

means = np.array([np.mean(val) for val in episodic_returns_per_variant.values()])
stds = np.array([np.std(val) for val in episodic_returns_per_variant.values()])

# Take the x values from the dict keys, which already hold log2(seed).
# Transforming ``seeds`` in place would apply log2 again on every re-run of this
# cell. The keys also line up with ``values()`` by construction. ``seeds`` is
# rebound to these log2 values because the final cell prints it.
seeds = np.array(list(episodic_returns_per_variant.keys()))
sort_idxes = np.argsort(seeds)
sorted_seeds = seeds[sort_idxes]
sorted_means = means[sort_idxes]
sorted_stds = stds[sort_idxes]

ax.plot(sorted_seeds, sorted_means, marker="x")
ax.fill_between(
    sorted_seeds,
    sorted_means + sorted_stds,
    sorted_means - sorted_stds,
    alpha=0.1,
)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.axvline(
    seeds[0],
    label="trained parameter: {:.2f}".format(seeds[0]),
    linestyle="--",
    linewidth=1,
)
ax.set_xlabel("Seed in Log-scale")
ax.set_ylabel("Return")
ax.legend()
fig.show()

In [None]:
# Print the x values (``seeds`` as left by the plotting cell above) and the
# per-variant mean returns, one per line.
print(seeds, means, sep="\n")