In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict, set_dict_value, get_dict_value

In [None]:
# Seed handed to jaxl's `set_seed` helper for this notebook session.
# NOTE(review): with `None` the run is presumably not reproducible across
# kernel restarts -- confirm how `set_seed` treats `None`, or pin an integer.
run_seed = None
set_seed(run_seed)

In [None]:
def get_env(agent_path):
    """Return the environment configuration recorded for an agent run.

    Args:
        agent_path: Directory that contains the run's ``config.json``.

    Returns:
        The value stored under ``["learner_config"]["env_config"]`` in that
        JSON file.

    Raises:
        OSError: If ``config.json`` cannot be opened.
        KeyError: If either of the two keys is missing from the file.
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        learner_section = json.load(config_file)["learner_config"]
    return learner_section["env_config"]

In [None]:
# Root directory that is later walked to discover saved runs.
# NOTE(review): this and the buffer path below are absolute, machine-specific
# paths; they must be edited before running the notebook elsewhere.
runs_dir = "/Users/chanb/research/personal/jaxl/jaxl/logs/half_cheetah/ppo"

# Build the raw buffer settings and convert them with `parse_dict` in one step;
# `load_buffer` points at a previously saved buffer file.
buffer_config = parse_dict(
    dict(
        buffer_type="default",
        load_buffer="/Users/chanb/research/personal/jaxl/jaxl/logs/half_cheetah/ppo/06-27-23_10_59_35-b651bdeb-e751-4e3e-962b-9d599009e77e/termination_buffer.gzip",
    )
)

# The second positional argument (42) is presumably the buffer's seed --
# TODO confirm against `jaxl.buffers.get_buffer`.
buffer = get_buffer(buffer_config, 42)

In [None]:
def get_config(agent_path):
    """Load a run's ``config.json`` and return it in parsed form.

    The raw JSON dictionary is queried for ``"num_models"`` via
    ``get_dict_value`` (which yields a ``(multitask, num_models)`` pair), then
    ``"vmap_all"`` is set to ``False`` via ``set_dict_value`` before the
    dictionary is handed to ``parse_dict``.

    Args:
        agent_path: Directory that contains the run's ``config.json``.

    Returns:
        A 2-tuple ``(agent_config, info)`` where ``agent_config`` is the result
        of ``parse_dict`` and ``info`` is a dict with the keys ``"multitask"``
        and ``"num_models"``.
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        raw_config = json.load(config_file)

    # The file is fully read at this point; the rest only touches the dict.
    multitask, num_models = get_dict_value(raw_config, "num_models")
    set_dict_value(raw_config, "vmap_all", False)
    parsed_config = parse_dict(raw_config)

    model_info = {"multitask": multitask, "num_models": num_models}
    return parsed_config, model_info

In [None]:
# For every directory under `runs_dir` that contains a "termination_model"
# sub-directory, collect:
#   * models[run]      -> a model built by `get_model` plus the checkpoint
#                         contents restored from "termination_model"
#   * env_configs[run] -> the object unpickled from that run's "env_config.pkl"
# Both dictionaries are keyed by the base name of the run directory.
models = {}
env_configs = {}

checkpointer = PyTreeCheckpointer()
for root, dirnames, _ in os.walk(runs_dir):
    for dirname in dirnames:
        if dirname != "termination_model":
            continue
        # Computed once and reused as the key for both result dictionaries.
        run_name = os.path.basename(root)
        agent_model_path = os.path.join(root, dirname)
        agent_config, _ = get_config(root)
        # Not read again in this cell; kept in case later cells rely on it.
        learner_config = agent_config.learner_config

        model_dict = checkpointer.restore(agent_model_path)
        models[run_name] = {
            "model": get_model(buffer.input_dim, buffer.output_dim, agent_config.model_config),
            "model_dict": model_dict,
        }

        # Use a context manager so the file handle is closed deterministically
        # instead of leaking one open handle per discovered run.
        # NOTE: unpickling executes arbitrary code; only load files from
        # run directories you trust.
        with open(os.path.join(root, "env_config.pkl"), "rb") as env_config_file:
            env_configs[run_name] = pickle.load(env_config_file)