In [5]:
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict

In [6]:
# Seed handed to jaxl's set_seed. The original passed None, which leaves the
# notebook without an explicit seed; a fixed integer keeps whatever RNG state
# set_seed controls identical on every Restart & Run All.
# NOTE(review): set_seed's implementation is not visible here — confirm it
# accepts a plain int. Change the value to draw a different random stream.
run_seed = 0
set_seed(run_seed)

In [7]:
# Root directory that holds the jaxl training logs. The default is the
# original author's checkout, so behaviour is unchanged when nothing is set;
# export JAXL_LOG_DIR to point the notebook at logs stored elsewhere without
# editing this cell.
log_dir = os.environ.get(
    "JAXL_LOG_DIR", "/Users/chanb/research/personal/jaxl/jaxl/logs"
)

# Run directory of the policy trained with REINFORCE.
rl_trained_path = os.path.join(
    log_dir,
    "inverted_pendulum",
    "reinforce",
    "05-30-23_11_06_49-d9afde67-e3fd-4d7c-8df9-4be90167d3a3",
)
# Run directory of the policy trained with behavioural cloning.
bc_trained_path = os.path.join(
    log_dir,
    "inverted_pendulum",
    "behavioural_cloning",
    "05-30-23_13_08_07-449d4f3c-f440-4308-8ffd-d859614c602e",
)

# Number of episodes requested from each evaluation rollout.
num_episodes = 100
# Buffer size written into the RL run's buffer_config before the buffer is built.
buffer_size = 100000

In [8]:
# --- RL run configuration ---------------------------------------------------
# Read the saved JSON, then patch and parse it once the file handle is closed.
rl_config_path = os.path.join(rl_trained_path, "config.json")
with open(rl_config_path, "r") as config_file:
    rl_config_dict = json.load(config_file)

# Override the stored buffer settings: use `buffer_size` from the cell above
# and set the buffer type to CONST_DEFAULT.
rl_buffer_overrides = rl_config_dict["learner_config"]["buffer_config"]
rl_buffer_overrides["buffer_size"] = buffer_size
rl_buffer_overrides["buffer_type"] = CONST_DEFAULT
rl_config = parse_dict(rl_config_dict)

# Environment seed recorded in the RL run's config.
env_seed = rl_config.learner_config.seeds.env_seed

# --- BC run configuration ---------------------------------------------------
bc_config_path = os.path.join(bc_trained_path, "config.json")
with open(bc_config_path, "r") as config_file:
    bc_config_dict = json.load(config_file)

# Set the BC learner's policy distribution to CONST_DETERMINISTIC before parsing.
bc_config_dict["learner_config"]["policy_distribution"] = CONST_DETERMINISTIC
bc_config = parse_dict(bc_config_dict)

In [9]:
# Evaluate the RL-trained policy: build the environment, buffer, model and
# policy from the RL run's config, restore the saved parameters, then run
# `num_episodes` evaluation episodes with `rl_buffer` passed to the rollout.
rl_learner_config = rl_config.learner_config
rl_model_config = rl_config.model_config

# Use the model config's h_state_dim when it declares one, otherwise (1,).
h_state_dim = getattr(rl_model_config, "h_state_dim", (1,))

env = get_environment(rl_learner_config.env_config)
rl_buffer = get_buffer(
    rl_learner_config.buffer_config,
    rl_learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

# Model dimensions are taken from the buffer; the output size is derived by
# policy_output_dim from the learner config.
input_dim = rl_buffer.input_dim
output_dim = policy_output_dim(rl_buffer.output_dim, rl_learner_config)
model = get_model(input_dim, output_dim, rl_model_config)
policy = get_policy(model, rl_learner_config)

# Restore the policy parameters saved under "termination_model".
rl_model_path = os.path.join(rl_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(rl_model_path)
rl_policy_params = model_dict[CONST_MODEL][CONST_POLICY]

# NOTE: pickle.load can execute arbitrary code — only open learner_dict.pkl
# files produced by runs you trust.
rl_learner_dict_path = os.path.join(rl_model_path, "learner_dict.pkl")
with open(rl_learner_dict_path, "rb") as learner_dict_file:
    learner_dict = pickle.load(learner_dict_file)
rl_obs_rms = learner_dict[CONST_OBS_RMS]

rl_rollout = EvaluationRollout(env, seed=env_seed)
rl_rollout.rollout(rl_policy_params, policy, rl_obs_rms, num_episodes, rl_buffer)

TypeError: PRNG key seed must be an integer; got Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

In [None]:
# Evaluate the behavioural-cloning policy in the same environment as the RL
# policy. The environment and buffer are still built from the RL run's config
# (as in the original cell); the model and policy are built from the BC run's
# config so they match the BC checkpoint being restored.
# NOTE(review): assumes the BC config carries the model/learner settings the
# BC checkpoint was trained with — confirm against the BC run's config.json.

# Check and read h_state_dim on the SAME config. The original tested
# rl_config but read bc_config, which raises AttributeError when only the RL
# model config declares h_state_dim.
h_state_dim = (1,)
if hasattr(bc_config.model_config, "h_state_dim"):
    h_state_dim = bc_config.model_config.h_state_dim

env = get_environment(rl_config.learner_config.env_config)
bc_buffer = get_buffer(
    rl_config.learner_config.buffer_config,
    rl_config.learner_config.seeds.buffer_seed,
    env,
    h_state_dim,
)

# Build the BC model/policy from bc_config: its learner_config is where
# policy_distribution was set to CONST_DETERMINISTIC when the config was
# loaded, and its model_config describes the network the BC parameters
# belong to. The original used rl_config here, leaving bc_config unused.
input_dim = bc_buffer.input_dim
output_dim = policy_output_dim(bc_buffer.output_dim, bc_config.learner_config)
model = get_model(input_dim, output_dim, bc_config.model_config)
policy = get_policy(model, bc_config.learner_config)

# Restore the policy parameters saved under "termination_model".
bc_model_path = os.path.join(bc_trained_path, "termination_model")
checkpointer = PyTreeCheckpointer()
model_dict = checkpointer.restore(bc_model_path)
bc_policy_params = model_dict[CONST_MODEL][CONST_POLICY]

# NOTE: pickle.load can execute arbitrary code — only open learner_dict.pkl
# files produced by runs you trust.
with open(os.path.join(bc_model_path, "learner_dict.pkl"), "rb") as f:
    learner_dict = pickle.load(f)
    bc_obs_rms = learner_dict[CONST_OBS_RMS]

bc_rollout = EvaluationRollout(env, seed=env_seed)
bc_rollout.rollout(bc_policy_params, policy, bc_obs_rms, num_episodes, bc_buffer)