# Compare RL-Trained, BC-Trained, and Multitask BC (MTBC) Agents

In [None]:
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict

In [None]:
# Global seed for this session; forwarded unchanged to jaxl.utils.set_seed.
run_seed = None
set_seed(run_seed)

In [None]:
# Checkpoint directories of the agents to compare. Each run directory is
# expected to contain config.json and a termination_model/ sub-directory.
rl_trained_path = "/home/chanb/Documents/research/personal/jaxl/data/inverted_pendulum/runs/0/gravity_-9.059259217791961/06-04-23_02_23_53-420709bb-6e78-4587-b383-91d4bcdd0de4"
bc_trained_path = "/home/chanb/Documents/research/personal/jaxl/jaxl/logs/inverted_pendulum/bc/06-06-23_16_08_30-d90fd198-0564-49b0-b97b-7e5aa996f55e"
# Leave empty to skip the multitask BC (MTBC) evaluation and plotting.
mtbcs_trained_path = ""

# Alternative checkpoint locations:
# rl_trained_path = "/mnt/HDD/research/mtil/inverted_pendulum/expert_models/gravity/runs/0/gravity_-10.788581730190907/06-03-23_22_40_06-e516828d-4848-4fd1-bce8-18d3f297d2c5"
# bc_trained_path = "/mnt/HDD/research/mtil/inverted_pendulum/test_bc/gravity-num_tasks_analysis/runs/0/06-06-23_09_22_09-a6cccfa8-1bb2-473e-b0e1-990623e2866b"
# mtbcs_trained_path = "/mnt/HDD/research/mtil/inverted_pendulum/test_mtbc/gravity-num_tasks_analysis/runs/0"

# Evaluation settings shared by every agent below.
num_episodes = 10  # episodes rolled out per agent (and per MTBC task)
buffer_size = 100000  # overrides buffer_config.buffer_size of the RL config
env_seed = 9999  # seed given to every EvaluationRollout

In [None]:
# RL expert config: override the buffer size and type before parsing.
rl_config_path = os.path.join(rl_trained_path, "config.json")
with open(rl_config_path, "r") as f:
    rl_config_dict = json.load(f)
rl_buffer_config_dict = rl_config_dict["learner_config"]["buffer_config"]
rl_buffer_config_dict["buffer_size"] = buffer_size
rl_buffer_config_dict["buffer_type"] = CONST_DEFAULT
rl_config = parse_dict(rl_config_dict)

# BC config: evaluate with a deterministic policy distribution.
bc_config_path = os.path.join(bc_trained_path, "config.json")
with open(bc_config_path, "r") as f:
    bc_config_dict = json.load(f)
bc_config_dict["learner_config"]["policy_distribution"] = CONST_DETERMINISTIC
bc_config = parse_dict(bc_config_dict)

In [None]:
# --- Evaluate the RL expert -------------------------------------------------
# Use the model config's h_state_dim when present, otherwise (1,).
h_state_dim = getattr(rl_config.model_config, "h_state_dim", (1,))

env = get_environment(rl_config.learner_config.env_config)
rl_buffer = get_buffer(
    rl_config.learner_config.buffer_config,
    rl_config.learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

# Rebuild the policy network from the buffer's input/output dimensions.
input_dim = rl_buffer.input_dim
output_dim = policy_output_dim(rl_buffer.output_dim, rl_config.learner_config)
model = get_model(input_dim, output_dim, rl_config.model_config)
policy = get_policy(model, rl_config.learner_config)

# Restore the policy parameters saved under <run>/termination_model.
rl_model_path = os.path.join(rl_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(rl_model_path)
rl_policy_params = model_dict[CONST_MODEL][CONST_POLICY]

# learner_dict.pkl carries the CONST_OBS_RMS entry (presumably observation
# normalisation statistics). NOTE: pickle can execute arbitrary code on load;
# only point this at trusted checkpoints.
with open(os.path.join(rl_model_path, "learner_dict.pkl"), "rb") as f:
    learner_dict = pickle.load(f)
    rl_obs_rms = learner_dict[CONST_OBS_RMS]

rl_rollout = EvaluationRollout(env, seed=env_seed)
rl_rollout.rollout(rl_policy_params, policy, rl_obs_rms, num_episodes, rl_buffer)

In [None]:
# --- Evaluate the single-task BC agent ---------------------------------------
# h_state_dim belongs to the BC model, so both the presence check and the read
# use bc_config; fall back to (1,) when the field is absent.
h_state_dim = getattr(bc_config.model_config, "h_state_dim", (1,))

# The BC agent is evaluated in the RL expert's environment, with a buffer built
# from the RL config (whose size/type were overridden when it was loaded).
env = get_environment(rl_config.learner_config.env_config)
bc_buffer = get_buffer(
    rl_config.learner_config.buffer_config,
    rl_config.learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

# Rebuild the BC policy network from the buffer's input/output dimensions.
input_dim = bc_buffer.input_dim
output_dim = policy_output_dim(bc_buffer.output_dim, bc_config.learner_config)
model = get_model(input_dim, output_dim, bc_config.model_config)
policy = get_policy(model, bc_config.learner_config)

# Restore the policy parameters saved under <run>/termination_model.
bc_model_path = os.path.join(bc_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(bc_model_path)
bc_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
# NOTE: pickle can execute arbitrary code on load; only use trusted checkpoints.
with open(os.path.join(bc_model_path, "learner_dict.pkl"), "rb") as f:
    learner_dict = pickle.load(f)
    bc_obs_rms = learner_dict[CONST_OBS_RMS]

bc_rollout = EvaluationRollout(env, seed=env_seed)
bc_rollout.rollout(bc_policy_params, policy, bc_obs_rms, num_episodes, bc_buffer)

In [None]:
if mtbcs_trained_path:
    # Maps each sub-directory name of mtbcs_trained_path to a list containing
    # one entry of episodic returns per task.
    mtbc_episodic_returns_per_task = {}

    # Sorted so evaluation, plotting and printing order is deterministic.
    for num_tasks_dir in sorted(os.listdir(mtbcs_trained_path)):
        dir_to_walk = os.path.join(mtbcs_trained_path, num_tasks_dir)
        if not os.path.isdir(dir_to_walk):
            continue

        # Use the last directory (in os.walk order) that contains config.json.
        mtbc_trained_path = None
        for run_path, _, filenames in os.walk(dir_to_walk):
            if "config.json" in filenames:
                mtbc_trained_path = run_path
        if mtbc_trained_path is None:
            logging.warning("No config.json found under %s; skipping", dir_to_walk)
            continue

        mtbc_config_path = os.path.join(mtbc_trained_path, "config.json")
        with open(mtbc_config_path, "r") as f:
            mtbc_config_dict = json.load(f)
        mtbc_config_dict["learner_config"]["policy_distribution"] = CONST_DETERMINISTIC
        mtbc_config_dict["model_config"]["predictor"]["vmap_all"] = False
        mtbc_config = parse_dict(mtbc_config_dict)

        # h_state_dim is checked on and read from the same (MTBC) model config.
        h_state_dim = getattr(mtbc_config.model_config, "h_state_dim", (1,))
        num_tasks = len(mtbc_config.learner_config.buffer_configs)

        mtbc_episodic_returns = []
        mtbc_buffers = []
        for task_i in range(num_tasks):
            env = get_environment(rl_config.learner_config.env_config)
            mtbc_buffers.append(
                get_buffer(
                    rl_config.learner_config.buffer_config,
                    rl_config.learner_config.seeds.buffer_seed,
                    env,
                    h_state_dim,
                )
            )

            # Build the multitask policy and restore its checkpoint once per run.
            if task_i == 0:
                input_dim = mtbc_buffers[-1].input_dim
                output_dim = policy_output_dim(
                    mtbc_buffers[-1].output_dim, mtbc_config.learner_config
                )

                model = get_model(input_dim, output_dim, mtbc_config.model_config)
                policy = MultitaskPolicy(
                    get_policy(model, mtbc_config.learner_config), model, num_tasks
                )

                mtbc_model_path = os.path.join(mtbc_trained_path, "termination_model")
                checkpointer = PyTreeCheckpointer()
                model_dict = checkpointer.restore(mtbc_model_path)
                mtbc_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
                # NOTE: pickle can execute arbitrary code; only trusted checkpoints.
                with open(os.path.join(mtbc_model_path, "learner_dict.pkl"), "rb") as f:
                    learner_dict = pickle.load(f)
                    mtbc_obs_rms = learner_dict[CONST_OBS_RMS]

            # NOTE(review): rollout() receives no task index, so each iteration
            # evaluates the same params/policy/obs_rms in an identically
            # configured env with the same seed. Confirm MultitaskPolicy selects
            # the task some other way, or pass task_i through here.
            mtbc_rollout = EvaluationRollout(env, seed=env_seed)
            mtbc_rollout.rollout(
                mtbc_policy_params, policy, mtbc_obs_rms, num_episodes, mtbc_buffers[-1]
            )
            mtbc_episodic_returns.append(mtbc_rollout.episodic_returns)
        mtbc_episodic_returns_per_task[num_tasks_dir] = mtbc_episodic_returns

In [None]:
# Episodic return per evaluation episode for each agent (x-axis is 1-based).
plt.plot(
    1 + np.arange(len(rl_rollout.episodic_returns)),
    rl_rollout.episodic_returns,
    label="RL Expert",
)
plt.plot(
    1 + np.arange(len(bc_rollout.episodic_returns)),
    bc_rollout.episodic_returns,
    label="BC",
)

if mtbcs_trained_path:
    for num_tasks_dir, returns_per_task in mtbc_episodic_returns_per_task.items():
        # Plot this run's own per-task returns (not a leftover global).
        for task_i, task_returns in enumerate(returns_per_task):
            plt.plot(
                1 + np.arange(len(task_returns)),
                task_returns,
                # Only the first task's curve is labelled so each run appears
                # once in the legend; "_nolegend_" is skipped by matplotlib.
                label=f"MTBC - {num_tasks_dir}" if task_i == 0 else "_nolegend_",
                alpha=0.3,
                linestyle="--",
            )
plt.title("Comparison in Returns")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.legend()
plt.show()

In [None]:
from pprint import pprint

# MTBC results only exist when mtbcs_trained_path is set; without this guard
# the cell raises NameError under the default empty path.
if mtbcs_trained_path:
    for num_tasks_dir, returns_per_task in mtbc_episodic_returns_per_task.items():
        print(num_tasks_dir)
        pprint(returns_per_task)