In [1]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict, set_dict_value, get_dict_value

In [2]:
# Seed handed to jaxl's `set_seed` helper for this notebook session.
# NOTE(review): the seed is None, so results are presumably not reproducible
# across kernel restarts — confirm how `set_seed` treats None and pin an
# integer here if repeatable runs are required.
run_seed = None
set_seed(run_seed)

In [3]:
def get_env(agent_path):
    """Return the environment configuration stored with a run.

    Reads ``config.json`` located directly inside ``agent_path`` and returns
    the value found under ``learner_config -> env_config``.

    Args:
        agent_path: Directory that contains the run's ``config.json``.

    Returns:
        The ``env_config`` entry of the run's ``learner_config``.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``agent_path``.
        KeyError: If ``learner_config`` or ``env_config`` is missing.
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        learner_config = json.load(config_file)["learner_config"]
    return learner_config["env_config"]

In [4]:
# Root of the jaxl log tree. The default is the original author's checkout;
# set the JAXL_LOGS_DIR environment variable to point the notebook at another
# machine's logs without editing this cell.
logs_dir = os.environ.get(
    "JAXL_LOGS_DIR", "/Users/chanb/research/personal/jaxl/jaxl/logs"
)

# Directory tree that is walked further down to find trained runs.
runs_dir = os.path.join(logs_dir, "half_cheetah-default", "ppo")

# Raw settings for the buffer, loaded from one specific earlier PPO run.
buffer_config_dict = {
    "buffer_type": "default",
    "load_buffer": os.path.join(
        logs_dir,
        "half_cheetah",
        "ppo",
        "06-27-23_10_59_35-b651bdeb-e751-4e3e-962b-9d599009e77e",
        "termination_buffer.gzip",
    ),
}
# Keep the raw dict and the parsed config under separate names; `buffer_config`
# still ends up bound to the parsed object, as before.
buffer_config = parse_dict(buffer_config_dict)
# NOTE(review): 42 is presumably the buffer's seed — confirm against get_buffer.
buffer = get_buffer(buffer_config, 42)

In [5]:
def get_config(agent_path):
    """Load a run's ``config.json`` and return it parsed, plus model-count info.

    The raw JSON dictionary is queried for ``"num_models"`` through
    ``get_dict_value`` (whose result is unpacked as a
    ``(multitask, num_models)`` pair), then ``set_dict_value`` is called to set
    ``"vmap_all"`` to ``False`` before the dictionary is handed to
    ``parse_dict``.

    Args:
        agent_path: Directory that contains the run's ``config.json``.

    Returns:
        A 2-tuple ``(agent_config, info)`` where ``agent_config`` is the result
        of ``parse_dict`` and ``info`` is a dict with the keys ``"multitask"``
        and ``"num_models"``.
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        raw_config = json.load(config_file)

    # Read the model-count information before the dictionary is modified.
    multitask, num_models = get_dict_value(raw_config, "num_models")
    set_dict_value(raw_config, "vmap_all", False)
    parsed_config = parse_dict(raw_config)

    info = dict(multitask=multitask, num_models=num_models)
    return parsed_config, info

In [6]:
# Run directory name -> {"model": built model, "model_dict": restored checkpoint}.
models = {}
# Run directory name -> unpickled contents of that run's env_config.pkl.
env_configs = {}

checkpointer = PyTreeCheckpointer()
for root, dirnames, _ in os.walk(runs_dir):
    for dirname in dirnames:
        # A run is identified by the presence of a "termination_model"
        # checkpoint directory; every other sub-directory is ignored.
        if dirname != "termination_model":
            continue
        agent_model_path = os.path.join(root, dirname)
        agent_config, misc = get_config(root)
        # Not used in this cell; kept bound because later cells may read it.
        learner_config = agent_config.learner_config

        model_dict = checkpointer.restore(agent_model_path)
        run_name = os.path.basename(root)
        models[run_name] = {
            "model": get_model(
                buffer.input_dim,
                buffer.output_dim,
                # Use the `policy` sub-config when the model config has one,
                # otherwise fall back to the model config itself.
                getattr(agent_config.model_config, "policy", agent_config.model_config),
            ),
            "model_dict": model_dict,
        }
        # Open the pickle in a context manager so the handle is closed
        # deterministically instead of being leaked once per run.
        # WARNING: unpickling can execute arbitrary code — only load
        # env_config.pkl files from runs you trust.
        with open(os.path.join(root, "env_config.pkl"), "rb") as env_config_file:
            env_configs[run_name] = pickle.load(env_config_file)

In [10]:
from pprint import pprint

# For every loaded run, show its directory name followed by the
# "modified_attributes" entry of its environment config, one per line.
for key, val in env_configs.items():
    print(key, val["modified_attributes"], sep="\n")

07-06-23_09_18_42-8df3523f-540c-4f9e-87ad-76780b5951c2
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14" />
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8" />
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1" />
    <motor ctrllimited="true" ctrlrange="-1 1" />
  </default>
  <size nstack="300000" nuser_geom="1" />
  <option gravity="0 0 -9.81" timestep="0.01" />
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100" />
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127" />
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100" />
    <material name="MatPlane