In [None]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import set_seed, parse_dict, set_dict_value, get_dict_value, l2_norm

In [None]:
# Global seed for the notebook. Use a concrete integer: passing None presumably
# leaves the RNGs seeded from entropy (TODO confirm against
# jaxl.utils.set_seed), which makes Restart & Run All non-reproducible.
run_seed = 42
set_seed(run_seed)

In [None]:
def get_env(agent_path):
    """Return the environment config recorded in a saved agent run.

    :param agent_path: directory of the agent run; must contain ``config.json``
    :return: the ``["learner_config"]["env_config"]`` entry of that JSON file
    :raises KeyError: if either key is missing from the JSON document
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        learner_section = json.load(config_file)["learner_config"]
    return learner_section["env_config"]

In [None]:
# Experiment data root. Override it with the JAXL_DATA_DIR environment variable
# instead of editing the machine-specific default below.
data_dir = os.environ.get("JAXL_DATA_DIR", "/Users/chanb/research/personal/jaxl/data")
env_data_dir = os.path.join(data_dir, "inverted_double_pendulum")

# Directory that is searched recursively for trained runs to compare.
runs_dir = os.path.join(
    env_data_dir, "test_mtbc", "gravity-representation_sensitivity", "runs", "0"
)

# Expert buffer that supplies the observations every model is evaluated on.
# `buffer_seed` is the second argument of get_buffer (presumably its RNG seed).
buffer_seed = 42
buffer_config = parse_dict(
    {
        "buffer_type": "default",
        "load_buffer": os.path.join(
            env_data_dir,
            "expert_data",
            "gravity",
            "gravity_-8.249612491943623-06-09-23_15_21_56-296b3f54-5c33-43f3-97dd-3b7eb184bc99.gzip",
        ),
    }
)
buffer = get_buffer(buffer_config, buffer_seed)

In [None]:
def get_config(agent_path):
    """Load and parse the ``config.json`` of a saved agent run.

    ``vmap_all`` is forced to ``False`` in the raw dict before it is parsed.

    :param agent_path: directory of the agent run; must contain ``config.json``
    :return: ``(agent_config, info)`` where ``agent_config`` is the
        ``parse_dict`` result and ``info`` maps ``"multitask"`` and
        ``"num_models"`` to the two values unpacked from
        ``get_dict_value(..., "num_models")``
    """
    with open(os.path.join(agent_path, "config.json"), "r") as config_file:
        raw_config = json.load(config_file)

    # NOTE(review): get_dict_value is unpacked as a 2-tuple here; the meaning
    # of its first element ("multitask") is not visible from this file.
    multitask, num_models = get_dict_value(raw_config, "num_models")
    set_dict_value(raw_config, "vmap_all", False)
    parsed_config = parse_dict(raw_config)

    model_info = {"multitask": multitask, "num_models": num_models}
    return parsed_config, model_info

In [None]:
# Restore every `termination_model` checkpoint found under `runs_dir`. Each run
# is keyed by the name of the directory one level above the run directory that
# holds the checkpoint. Its network is rebuilt from the run's config.json using
# the expert buffer's input/output dimensions.
CHECKPOINT_DIRNAME = "termination_model"

models = {}
checkpointer = PyTreeCheckpointer()
for root, dirnames, _ in os.walk(runs_dir):
    # Sorting in place fixes the order in which os.walk descends. The insertion
    # order of `models`, and therefore the row/column order of the comparison
    # tables, then no longer depends on the filesystem's listing order.
    dirnames.sort()
    if CHECKPOINT_DIRNAME not in dirnames:
        continue

    model_key = os.path.basename(os.path.dirname(root))
    if model_key in models:
        # Two runs under the same parent directory would otherwise silently
        # overwrite each other and drop a model from the comparison.
        raise ValueError(
            f"Duplicate model key {model_key!r}: {root} clashes with an "
            f"earlier run under {os.path.dirname(root)}"
        )

    agent_config, _ = get_config(root)
    model_dict = checkpointer.restore(os.path.join(root, CHECKPOINT_DIRNAME))
    models[model_key] = {
        "model": get_model(
            buffer.input_dim, buffer.output_dim, agent_config.model_config
        ),
        "model_dict": model_dict,
    }

# os.walk silently yields nothing for a missing directory. Fail here with a
# clear message instead of failing obscurely in the comparison cells.
if not models:
    raise FileNotFoundError(
        f"No {CHECKPOINT_DIRNAME!r} checkpoints found under {runs_dir!r}"
    )

In [None]:
# Names of the restored runs, in the order they were loaded.
model_names = models.keys()
model_names

In [None]:
# Parameter magnitude (jaxl's l2_norm) of each restored policy, reported
# separately for its encoder and predictor parameter sub-trees.
for key, model in models.items():
    policy_params = model["model_dict"][CONST_MODEL][CONST_POLICY]
    print(
        f"{key}: encoder l2_norm={l2_norm(policy_params[CONST_ENCODER])}, "
        f"predictor l2_norm={l2_norm(policy_params[CONST_PREDICTOR])}"
    )

In [None]:
# One fixed batch from the expert buffer. Every model below is evaluated on
# this same batch. `acts` (presumably the buffer's stored actions) is kept for
# reference but is not used by the comparison cells shown here.
num_samples = 256
obss, h_states, acts, *_ = buffer.sample(num_samples)

In [None]:
# Evaluate every restored policy on the shared batch. `latent_preds` stores the
# first output of the encoder alone, and `act_preds` the first output of the
# full policy forward pass, both keyed by run name.
latent_preds = {}
act_preds = {}
for run_name, entry in models.items():
    network = entry["model"]
    policy_params = entry["model_dict"][CONST_MODEL][CONST_POLICY]

    encoded = network.encode(policy_params[CONST_ENCODER], obss, h_states)
    latent_preds[run_name] = encoded[0]

    forwarded = network.forward(policy_params, obss, h_states)
    act_preds[run_name] = forwarded[0]

In [None]:
# Squared difference between the action predictions of every ordered pair of
# runs, summed over all array elements. act_diffs[i][j] compares the i-th and
# j-th entries of act_preds, so the matrix is symmetric with a zero diagonal
# for finite predictions.
act_diffs = [
    [np.sum((row_preds - col_preds) ** 2) for col_preds in act_preds.values()]
    for row_preds in act_preds.values()
]

In [None]:
# Render act_diffs as a labelled table. Rows and columns follow the key order
# of act_preds. Values are formatted to 4 significant digits so the cells stay
# readable instead of showing raw array scalars.
collabel = list(act_preds.keys())
rowlabel = list(act_preds.keys())
act_cell_text = [[f"{float(diff):.4g}" for diff in row] for row in act_diffs]

fig, ax = plt.subplots(figsize=(30, 5))
ax.axis("off")  # hide the empty plot axes; only the table is shown
ax.table(
    cellText=act_cell_text, loc="center", colLabels=collabel, rowLabels=rowlabel
)
ax.set_title("Sum of squared differences between action predictions")
fig.tight_layout()

In [None]:
# Squared difference between the encoder latents of every ordered pair of runs,
# summed over all array elements. latent_diffs[i][j] compares the i-th and j-th
# entries of latent_preds.
latent_diffs = [
    [np.sum((row_latent - col_latent) ** 2) for col_latent in latent_preds.values()]
    for row_latent in latent_preds.values()
]

In [None]:
# Render latent_diffs as a labelled table. Rows and columns follow the key order
# of latent_preds. Values are formatted to 4 significant digits so the cells
# stay readable instead of showing raw array scalars.
collabel = list(latent_preds.keys())
rowlabel = list(latent_preds.keys())
latent_cell_text = [[f"{float(diff):.4g}" for diff in row] for row in latent_diffs]

fig, ax = plt.subplots(figsize=(30, 5))
ax.axis("off")  # hide the empty plot axes; only the table is shown
ax.table(
    cellText=latent_cell_text, loc="center", colLabels=collabel, rowLabels=rowlabel
)
ax.set_title("Sum of squared differences between encoder latents")
fig.tight_layout()