In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import os

from penzai import pz
from tqdm.notebook import tqdm

from jaxl.constants import *
from jaxl.models.utils import get_model, load_config, load_params, get_wsrl_model
from jaxl.models.policies import get_policy, policy_output_dim
from jaxl.buffers import get_buffer
from jaxl.utils import get_device, parse_dict

import IPython

# Rendering setup: register penzai's treescope (`pz.ts`) as the default.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

# NOTE(review): presumably selects the first GPU as the compute device --
# confirm against `jaxl.utils.get_device`. The result is not bound to a name.
get_device("gpu:0")

In [None]:
# Root directory that holds the experiment logs.
result_dir = "/home/bryan/research/jaxl/logs/manipulator_learning"

# Run to inspect; it lives at <result_dir>/<ablation_name>/<learner_name>.
ablation_name, learner_name = (
    "stack",
    "bc-10k_steps-06-04-24_10_06_35-333b32a8-c019-4fed-9b8f-1ce59166bb2b",
)

# Full path of the selected run directory.
learner_path = os.path.join(result_dir, ablation_name, learner_name)

In [None]:
# Load the saved configuration of the selected run. `load_config` is unpacked
# as a pair here and only its second element is kept.
_, config = load_config(learner_path)
# Bare trailing expression so the notebook displays the loaded config.
config

In [None]:
# Settings for the buffer built in the next cell: where to load the stored
# data from, which buffer implementation to use, and the `set_size` flag.
buffer_config = parse_dict({
    "load_buffer": '/home/bryan/research/lfgp/lfgp_data/custom_expert_data/stack/1000000_steps_no_extra_final/int_0.gz',
    "buffer_type": 'default',
    "set_size": False,
})

In [None]:
# Build the buffer described by `buffer_config` (which names the file to load
# through its `load_buffer` entry).
buffer = get_buffer(buffer_config)

In [None]:
# Output width for the policy network, derived from the buffer's action
# dimension and the learner configuration (used by the non-BC branches).
model_out_dim = policy_output_dim(buffer.act_dim, config.learner_config)

# Flattened size of one buffer input, shared by every branch below.
flat_input_dim = int(np.prod(buffer.input_dim))

# Rebuild the network the same way the run's learner type requires.
if config.learner_config.learner == CONST_BC:
    # BC: the model is built from the whole model config and outputs
    # `buffer.act_dim` values directly.
    include_absorbing_state = True
    model = get_model(flat_input_dim, buffer.act_dim, config.model_config)
elif config.learner_config.task == CONST_WSRL:
    # WSRL: only the policy sub-config is used, via the dedicated builder.
    include_absorbing_state = True
    model = get_wsrl_model(
        flat_input_dim, model_out_dim, config.model_config.policy
    )
else:
    # Everything else: the model takes one input fewer than the buffer stores.
    # NOTE(review): presumably the dropped entry is the absorbing-state flag,
    # matching `include_absorbing_state = False` -- confirm.
    include_absorbing_state = False
    model = get_model(
        flat_input_dim - 1, model_out_dim, config.model_config.policy
    )

# Load the parameters stored under the run's "latest" tag.
params = load_params(f"{learner_path}:latest")

## Visualize Parameters

In [None]:
# `IPython` is imported and the penzai/treescope rendering setup
# (`register_as_default`, autovisualize magic, interactive context, default
# `ArrayAutovisualizer`) is performed once in the notebook's first cell, so the
# verbatim copy that used to sit here has been dropped.

def my_continuous_autovisualizer(
    value,
    path,
):
  """Autovisualizer callback that renders arrays with a continuous colormap.

  Args:
    value: The object being visualized.
    path: Location of `value` in the displayed structure (unused here).

  Returns:
    A `pz.ts.IPythonVisualization` wrapping `pz.ts.render_array(...)` (called
    with `continuous=True, around_zero=False`) when `value` is a NumPy array
    or a `pz.nx.NamedArray`; `None` for any other value.
  """
  if not isinstance(value, (np.ndarray, pz.nx.NamedArray)):
    return None
  rendering = pz.ts.render_array(value, continuous=True, around_zero=False)
  return pz.ts.IPythonVisualization(rendering)
  
# Show every policy parameter leaf, keyed by its "/"-joined tree path, using
# the continuous-colormap autovisualizer for the duration of this display.
with pz.ts.active_autovisualizer.set_scoped(
    my_continuous_autovisualizer
):
  IPython.display.display({
      # Path entries normally expose `.key` (dict entries); fall back to
      # `.idx` for sequence entries, as the dormant-score cell does, and
      # stringify so that non-string keys cannot break the join.
      "/".join(
          str(curr_kp.key) if hasattr(curr_kp, "key") else str(curr_kp.idx)
          for curr_kp in kp
      ): pz.nx.wrap(val)
      for (kp, val) in jax.tree_util.tree_flatten_with_path(
          params["model_dict"]["model"]["policy"]
      )[0]
  })

## Check for Dormant Neurons

In [None]:
# Draw 256 samples from the buffer and keep the first returned element, which
# is fed to the policy as the observation batch in the next cell.
obss = buffer.sample(256)[0]

In [None]:
# Run the policy network on the sampled observations. Intermediate capture is
# enabled and the "intermediates" collection is marked mutable, so `state`
# carries the per-layer activations read by the dormant-neuron cell below.
out, state = model.model.apply(
    params[CONST_MODEL_DICT][CONST_MODEL][CONST_POLICY],
    obss,
    capture_intermediates=True,
    mutable=["intermediates"],
    eval=True,
)

In [None]:
# Per-layer dormant-neuron scores: for each captured intermediate, average the
# absolute activation over the batch axis (axis 0) and divide by the mean of
# those averages along the last axis. Scores near zero mark units that are
# almost inactive relative to the rest of their layer.
res = {}
for (kp, val) in jax.tree_util.tree_flatten_with_path(state["intermediates"])[0]:
    # Build the "/"-joined layer name once. Path entries expose either `.key`
    # (dict entries) or `.idx` (sequence entries).
    layer_name = "/".join(
        str(curr_kp.key) if hasattr(curr_kp, "key") else str(curr_kp.idx)
        for curr_kp in kp
    )

    per_neuron_score = jnp.mean(jnp.abs(val), axis=0)
    # `keepdims=True` keeps the normaliser aligned with the last axis. The
    # result is unchanged for the usual 1-D score vector and is correct when
    # the activation has more than one non-batch axis.
    normalized_score = per_neuron_score / jnp.mean(
        per_neuron_score, axis=-1, keepdims=True
    )

    # Fold into 4 rows for a compact display when the width allows it;
    # otherwise fall back to a single row instead of failing in `reshape`.
    num_rows = 4 if normalized_score.size % 4 == 0 else 1
    res[layer_name] = normalized_score.reshape((num_rows, -1))

    # Smallest normalized score in this layer.
    print(jnp.min(res[layer_name]))

In [None]:
# Display the per-layer normalized activation scores collected in `res`.
res