# Compare RL-Trained Agents and BC-Trained Agents

In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict

In [None]:
# Seed handed to `jaxl.utils.set_seed` before any other cell runs.
# NOTE(review): `None` presumably leaves the global RNGs unseeded, which would
# make runs non-reproducible -- confirm against `set_seed`'s implementation.
run_seed = None
set_seed(run_seed)

In [None]:
# Alternative locations of the same three run directories under /mnt/HDD,
# kept commented out; uncomment these (and comment the block below) to switch.
# rl_trained_path = "/mnt/HDD/research/mtil/inverted_double_pendulum/expert_models/gravity/runs/0/gravity_-8.249612491943623/06-09-23_15_21_56-296b3f54-5c33-43f3-97dd-3b7eb184bc99"
# bc_trained_path = "/mnt/HDD/research/mtil/inverted_double_pendulum/test_bc/gravity-num_tasks_analysis/runs/0/06-13-23_16_37_46-41025fa6-ec46-4be9-ad76-f0efda797283"
# mtbcs_trained_path = "/mnt/HDD/research/mtil/inverted_double_pendulum/test_mtbc/gravity-num_tasks_analysis/runs/0"

# Mac
# `rl_trained_path` / `bc_trained_path`: run directories from which the cells
# below read `config.json` and the `termination_model` checkpoint.
rl_trained_path = "/Users/chanb/research/personal/jaxl/data/inverted_double_pendulum/expert_models/gravity/runs/0/gravity_-8.249612491943623/06-09-23_15_21_56-296b3f54-5c33-43f3-97dd-3b7eb184bc99"
bc_trained_path = "/Users/chanb/research/personal/jaxl/data/inverted_double_pendulum/test_bc/gravity-num_tasks_analysis/runs/0/06-13-23_16_37_46-41025fa6-ec46-4be9-ad76-f0efda797283"
# Parent directory whose entries are each searched for an MTBC run (a folder
# containing `config.json`). A falsy value skips the MTBC cells entirely.
mtbcs_trained_path = "/Users/chanb/research/personal/jaxl/data/inverted_double_pendulum/test_mtbc/gravity-num_tasks_analysis/runs/0"

# Number of evaluation episodes passed to every rollout.
num_episodes = 100
# Written into the RL buffer config; when 0 (falsy) no buffer is handed to the
# rollouts, so transitions are not stored.
buffer_size = 0
# Seed given to each `EvaluationRollout`, shared by all agents being compared.
env_seed = 9999
# When True, each environment is wrapped in `RecordVideoV0`.
record_video = False

In [None]:
# --- RL expert configuration -------------------------------------------------
# Read the saved config, then override the buffer settings and the render mode
# for evaluation before converting the dict with `parse_dict`.
rl_config_path = os.path.join(rl_trained_path, "config.json")
with open(rl_config_path, "r") as f:
    rl_config_dict = json.load(f)

rl_config_dict["learner_config"]["buffer_config"].update(
    {"buffer_size": buffer_size, "buffer_type": CONST_DEFAULT}
)
rl_config_dict["learner_config"]["env_config"]["env_kwargs"].update(
    {"render_mode": "rgb_array"}
)
rl_config = parse_dict(rl_config_dict)

# --- BC configuration --------------------------------------------------------
# Same procedure; the only override is the policy distribution, which is set
# to CONST_DETERMINISTIC.
bc_config_path = os.path.join(bc_trained_path, "config.json")
with open(bc_config_path, "r") as f:
    bc_config_dict = json.load(f)

bc_config_dict["learner_config"].update(
    {"policy_distribution": CONST_DETERMINISTIC}
)
bc_config = parse_dict(bc_config_dict)

In [None]:
# Evaluate the RL expert: rebuild the environment, buffer and policy from the
# saved config, restore the checkpointed parameters, then roll the policy out
# for `num_episodes` episodes.

# Use the config's `h_state_dim` when it defines one, otherwise fall back to (1,).
h_state_dim = getattr(rl_config.model_config, "h_state_dim", (1,))

env = get_environment(rl_config.learner_config.env_config)
if record_video:
    env = RecordVideoV0(env, "rl_expert-videos")

rl_buffer = get_buffer(
    rl_config.learner_config.buffer_config,
    rl_config.learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

# The buffer supplies the input/output dimensions used to build the model.
input_dim, output_dim = (
    rl_buffer.input_dim,
    policy_output_dim(rl_buffer.output_dim, rl_config.learner_config),
)
model = get_model(
    input_dim,
    output_dim,
    # Prefer the nested `policy` model config when the config has one.
    getattr(rl_config.model_config, "policy", rl_config.model_config),
)
policy = get_policy(model, rl_buffer.output_dim, rl_config.learner_config)

# Restore policy parameters and the stored observation statistics from the
# `termination_model` checkpoint directory.
rl_model_path = os.path.join(rl_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(rl_model_path)
rl_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
with open(os.path.join(rl_model_path, "learner_dict.pkl"), "rb") as f:
    learner_dict = pickle.load(f)
rl_obs_rms = learner_dict[CONST_OBS_RMS]

rl_rollout = EvaluationRollout(env, seed=env_seed)
rl_rollout.rollout(
    rl_policy_params,
    policy,
    rl_obs_rms,
    num_episodes,
    # Only hand over the buffer when a non-zero size was requested.
    rl_buffer if buffer_size else None,
)

In [None]:
# Evaluate the BC-trained agent in the same environment as the RL expert.
# The environment and buffer are built from the RL config (the BC config's
# environment/buffer settings are not used here); the model, policy and
# checkpoint come from the BC run.

# Read `h_state_dim` from the BC config itself, falling back to (1,).
# (Previously the presence check looked at `rl_config` while the value was read
# from `bc_config`, which could raise AttributeError or ignore the BC value.)
h_state_dim = getattr(bc_config.model_config, "h_state_dim", (1,))

env = get_environment(rl_config.learner_config.env_config)
if record_video:
    env = RecordVideoV0(env, "bc-videos")
bc_buffer = get_buffer(
    rl_config.learner_config.buffer_config,
    rl_config.learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

input_dim = bc_buffer.input_dim
output_dim = policy_output_dim(bc_buffer.output_dim, bc_config.learner_config)
model = get_model(input_dim, output_dim, bc_config.model_config)
# Use this cell's own buffer for the action dimension rather than the RL
# cell's `rl_buffer`, so the cell does not depend on the RL cell's buffer.
policy = get_policy(model, bc_buffer.output_dim, bc_config.learner_config)

# Restore policy parameters and the stored observation statistics from the
# BC run's `termination_model` checkpoint directory.
bc_model_path = os.path.join(bc_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(bc_model_path)
bc_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
with open(os.path.join(bc_model_path, "learner_dict.pkl"), "rb") as f:
    learner_dict = pickle.load(f)
    bc_obs_rms = learner_dict[CONST_OBS_RMS]

bc_rollout = EvaluationRollout(env, seed=env_seed)
bc_rollout.rollout(
    bc_policy_params,
    policy,
    bc_obs_rms,
    num_episodes,
    # Only hand over the buffer when a non-zero size was requested.
    bc_buffer if buffer_size else None,
)

In [None]:
# Evaluate every MTBC variant found under `mtbcs_trained_path`.
# Result: `mtbc_episodic_returns_per_task` maps each entry name of that
# directory to a list holding one sequence of episodic returns per task.
if mtbcs_trained_path:
    mtbc_episodic_returns_per_task = {}

    for num_tasks_dir in os.listdir(mtbcs_trained_path):
        dir_to_walk = os.path.join(mtbcs_trained_path, num_tasks_dir)

        # Locate the run directory: the last directory visited by `os.walk`
        # that contains a `config.json`.
        mtbc_trained_path = None
        for run_path, _, filenames in os.walk(dir_to_walk):
            if "config.json" in filenames:
                mtbc_trained_path = run_path

        # Entries without a run (plain files such as `.DS_Store`, or empty
        # folders) are skipped instead of crashing on `os.path.join(None, ...)`.
        if mtbc_trained_path is None:
            continue

        mtbc_config_path = os.path.join(mtbc_trained_path, "config.json")
        with open(mtbc_config_path, "r") as f:
            mtbc_config_dict = json.load(f)
            mtbc_config_dict["learner_config"][
                "policy_distribution"
            ] = CONST_DETERMINISTIC
            mtbc_config_dict["model_config"]["predictor"]["vmap_all"] = False
            mtbc_config = parse_dict(mtbc_config_dict)

        mtbc_episodic_returns = []
        mtbc_buffers = []

        # Read `h_state_dim` from the MTBC config itself, falling back to (1,).
        # (Previously the presence check looked at `rl_config` while the value
        # was read from `mtbc_config`.)
        h_state_dim = getattr(mtbc_config.model_config, "h_state_dim", (1,))
        num_tasks = len(mtbc_config.learner_config.buffer_configs)

        for task_i in range(num_tasks):
            # Each task gets a fresh environment and buffer built from the RL
            # expert's config.
            env = get_environment(rl_config.learner_config.env_config)
            if record_video:
                env = RecordVideoV0(env, f"mtbc-{num_tasks_dir}-videos")
            mtbc_buffers.append(
                get_buffer(
                    rl_config.learner_config.buffer_config,
                    rl_config.learner_config.seeds.buffer_seed,
                    env,
                    h_state_dim,
                )
            )

            # The model, policy and checkpoint are shared by all tasks of a
            # variant, so they are built once, as soon as the first buffer
            # provides the input/output dimensions.
            if task_i == 0:
                input_dim = mtbc_buffers[-1].input_dim
                output_dim = policy_output_dim(
                    mtbc_buffers[-1].output_dim, mtbc_config.learner_config
                )

                model = get_model(input_dim, output_dim, mtbc_config.model_config)
                # NOTE(review): `get_policy` is called with two arguments here
                # but with three (model, output_dim, learner_config) in the RL
                # and BC cells -- confirm which signature the installed jaxl
                # version expects.
                policy = MultitaskPolicy(
                    get_policy(model, mtbc_config.learner_config), model, num_tasks
                )

                mtbc_model_path = os.path.join(mtbc_trained_path, "termination_model")
                checkpointer = PyTreeCheckpointer()
                model_dict = checkpointer.restore(mtbc_model_path)
                mtbc_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
                with open(os.path.join(mtbc_model_path, "learner_dict.pkl"), "rb") as f:
                    learner_dict = pickle.load(f)
                    mtbc_obs_rms = learner_dict[CONST_OBS_RMS]

            mtbc_rollout = EvaluationRollout(env, seed=env_seed)
            mtbc_rollout.rollout(
                mtbc_policy_params,
                policy,
                mtbc_obs_rms,
                num_episodes,
                # Only hand over the buffer when a non-zero size was requested.
                mtbc_buffers[-1] if buffer_size else None,
            )
            mtbc_episodic_returns.append(mtbc_rollout.episodic_returns)
        mtbc_episodic_returns_per_task[num_tasks_dir] = mtbc_episodic_returns

In [None]:
# Per-episode return curves: one line for the RL expert, one for BC, and one
# dashed line per (MTBC variant, task) pair. The x-axis is the 1-based episode
# index.
plt.plot(
    1 + np.arange(len(rl_rollout.episodic_returns)),
    rl_rollout.episodic_returns,
    label="RL Expert",
    linewidth=1.0,
)
plt.plot(
    1 + np.arange(len(bc_rollout.episodic_returns)),
    bc_rollout.episodic_returns,
    label="BC",
    linewidth=1.0,
)

if mtbcs_trained_path:
    for variant_name, episodic_returns in mtbc_episodic_returns_per_task.items():
        # Plot this variant's own returns. The previous code indexed
        # `mtbc_episodic_returns`, a leftover from the last iteration of the
        # evaluation cell, so every variant drew the same data.
        for task_returns in episodic_returns:
            plt.plot(
                1 + np.arange(len(task_returns)),
                task_returns,
                label=f"MTBC - {variant_name}",
                alpha=0.7,
                linestyle="--",
                linewidth=2.0,
            )
plt.title("Comparison in Returns")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.legend()
plt.show()

In [None]:
# Summary plot: mean return (with a +/- one-std band clipped to the observed
# min/max) against log2(number of tasks), with the RL expert drawn as a
# horizontal reference line and band.
fig, ax = plt.subplots(1, figsize=(7, 5))

rl_expert_baseline = {
    CONST_MEAN: np.mean(rl_rollout.episodic_returns),
    CONST_STD: np.std(rl_rollout.episodic_returns),
    "max": np.max(rl_rollout.episodic_returns),
    "min": np.min(rl_rollout.episodic_returns),
}

# The plain BC agent is placed at 0.5 "tasks" so that it lands at
# log2(0.5) = -1 on the x-axis (the "no pretraining" point in the axis label).
num_tasks_variants = [0.5]
means = [np.mean(bc_rollout.episodic_returns)]
stds = [np.std(bc_rollout.episodic_returns)]
maxs = [np.max(bc_rollout.episodic_returns)]
mins = [np.min(bc_rollout.episodic_returns)]

if mtbcs_trained_path:
    for curr_num_tasks, episodic_returns in mtbc_episodic_returns_per_task.items():
        # The task count is the integer after the last "_" in the key.
        num_tasks_variants.append(int(curr_num_tasks.split("_")[-1]))
        # Statistics are taken over all tasks and episodes of the variant.
        means.append(np.mean(episodic_returns))
        stds.append(np.std(episodic_returns))
        maxs.append(np.max(episodic_returns))
        mins.append(np.min(episodic_returns))

num_tasks_variants = np.array(num_tasks_variants)
means = np.array(means)
stds = np.array(stds)
maxs = np.array(maxs)
mins = np.array(mins)

# The arrays always contain at least the BC entry, so no empty-input fallback
# is needed (the old one was unreachable and would have replaced the array
# with a list that cannot be indexed by `sort_idxes`).

# Order every statistic by ascending task count.
sort_idxes = np.argsort(num_tasks_variants)
num_tasks_variants = num_tasks_variants[sort_idxes]
means = means[sort_idxes]
stds = stds[sort_idxes]
maxs = maxs[sort_idxes]
mins = mins[sort_idxes]

# RL expert reference: mean line plus a one-std band clipped to [min, max].
ax.axhline(rl_expert_baseline[CONST_MEAN], 0, 1, c="black", label="RL Expert")
ax.axhspan(
    np.clip(
        rl_expert_baseline[CONST_MEAN] + rl_expert_baseline[CONST_STD],
        a_min=rl_expert_baseline["min"],
        a_max=rl_expert_baseline["max"],
    ),
    np.clip(
        rl_expert_baseline[CONST_MEAN] - rl_expert_baseline[CONST_STD],
        a_min=rl_expert_baseline["min"],
        a_max=rl_expert_baseline["max"],
    ),
    alpha=0.3,
    color="black",
)

if mtbcs_trained_path:
    log2_num_tasks = np.log2(num_tasks_variants)
    ax.plot(log2_num_tasks, means, marker="x", label="MTBC")
    ax.fill_between(
        log2_num_tasks,
        np.clip(means + stds, a_min=mins, a_max=maxs),
        np.clip(means - stds, a_min=mins, a_max=maxs),
        alpha=0.3,
    )


# ax.set_ylim(9250, 9500)
ax.set_title(f"Returns Across {num_episodes} Episodes")
ax.set_xlabel("Number of Tasks in $log_2$ (-1 means no pretraining)")
ax.set_ylabel("Return")
ax.legend()
fig.show()

In [None]:
import _pickle as pickle

# Persist the summary statistics to `num_tasks.pkl` in the working directory.
# Every MTBC series drops its first element: after the ascending sort in the
# plotting cell that is the 0.5 (plain BC) entry, assuming each MTBC variant
# has at least one task.
mtbc_summary = {
    key: series[1:]
    for key, series in (
        ("num_tasks_variants", num_tasks_variants),
        (CONST_MEAN, means),
        (CONST_STD, stds),
        ("max", maxs),
        ("min", mins),
    )
}
summary = {"rl_expert": rl_expert_baseline, "mtbc": mtbc_summary}

with open("num_tasks.pkl", "wb") as f:
    pickle.dump(summary, f)