In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import os

from functools import partial
from penzai import pz
from tqdm.notebook import tqdm

from jaxl.constants import *
from jaxl.models.utils import get_model, load_config, load_params, get_wsrl_model, iterate_params, get_policy, policy_output_dim
from jaxl.buffers import get_buffer
from jaxl.utils import get_device, parse_dict

import IPython

# Presumably makes penzai's treescope renderer the default notebook
# pretty-printer — confirm against the penzai documentation.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

# Requests device "gpu:0" through the project helper.
# NOTE(review): looks like this pins computation to the first GPU — confirm in jaxl.utils.get_device.
get_device("gpu:0")

In [None]:
# Root directory that holds the experiment logs.
# NOTE(review): absolute, user-specific path — adjust for your machine.
result_dir = "/home/bryan/research/jaxl/logs/manipulator_learning"

ablation_name = "stack"
# Alternative runs that can be inspected instead; keep exactly one learner_name active.
# learner_name = "bc-10k_steps-06-04-24_10_06_35-333b32a8-c019-4fed-9b8f-1ce59166bb2b"
# learner_name = "warm_start_policy_evaluation-06-05-24_10_50_37-f9950c16-2b62-4078-85fd-2e43f5e6d1ed"
# learner_name = "rlpd-sac-06-05-24_16_15_56-c5ad96da-4ac4-466b-a221-74cfea71bd19"
learner_name = "rlpd-sac-high_utd-06-06-24_08_51_26-65753ce9-cb20-4d8b-ad9f-8b9c24b98b14"

# Run directory: <result_dir>/<ablation_name>/<learner_name>.
learner_path = os.path.join(result_dir, ablation_name, learner_name)
# Only consulted when the learner is neither BC nor a WSRL task:
# True inspects the policy network, False inspects the value / Q network.
check_policy = False

In [None]:
# load_config returns a pair; only its second element (the run config) is used in this notebook.
_, config = load_config(learner_path)
# Bare expression so the notebook displays the loaded config.
config

## Setup Buffer and Model

In [None]:
# Options handed to get_buffer for the dataset the activations are evaluated on.
# NOTE(review): load_buffer is an absolute, user-specific path — adjust for your machine.
buffer_config = parse_dict(dict(
    load_buffer='/home/bryan/research/lfgp/lfgp_data/custom_expert_data/stack/1000000_steps_no_extra_final/int_0.gz',
    buffer_type='default',
    set_size=False,
))

In [None]:
# Build the buffer described by buffer_config; its input_dim, act_dim and sample() are used
# by the model-setup cell.
buffer = get_buffer(buffer_config)

In [None]:
model_out_dim = policy_output_dim(buffer.act_dim, config.learner_config)

# Flattened observation size as stored in the buffer; the final branch drops one
# trailing observation entry (treated as the absorbing-state flag in this notebook).
flat_obs_dim = int(np.prod(buffer.input_dim))

# Defaults; individual branches override them where needed.
param_key = CONST_POLICY
include_absorbing_state = True

if config.learner_config.learner == CONST_BC:
    # Behavioural cloning: a single model built from the whole model config.
    model = get_model(flat_obs_dim, buffer.act_dim, config.model_config)
elif config.learner_config.task == CONST_WSRL:
    if config.learner_config.learner == CONST_POLICY_EVALUATION:
        # Policy evaluation: inspect the scalar value function.
        param_key = CONST_VF
        model = get_model(flat_obs_dim, (1,), config.model_config.vf)
    else:
        model = get_wsrl_model(flat_obs_dim, model_out_dim, config.model_config.policy)
else:
    # Remaining learners are built on observations without the last entry.
    include_absorbing_state = False
    trimmed_obs_dim = flat_obs_dim - 1
    if check_policy:
        model = get_model(trimmed_obs_dim, model_out_dim, config.model_config.policy)
    elif hasattr(config.model_config, CONST_VF):
        param_key = CONST_VF
        model = get_model(trimmed_obs_dim, (1,), config.model_config.vf)
    else:
        # Q-function: input is the trimmed observation concatenated with the action.
        param_key = CONST_QF
        model = get_model(
            trimmed_obs_dim + int(np.prod(buffer.act_dim)),
            (1,),
            config.model_config.qf,
        )

# Fixed evaluation batch: first and third entries of the sampled tuple.
obss, _, acts = buffer.sample(256)[:3]

if not include_absorbing_state:
    # Strip the trailing observation entry so inputs match the model built above.
    obss = obss[..., :-1]

print(obss.shape, acts.shape)

## Visualize Dormant Neurons

In [None]:
def get_model_outputs(params, obss, acts, config, model, param_key):
    """Run a forward pass while capturing every intermediate activation.

    Args:
        params: Loaded parameter tree; the network parameters are read from
            ``params[CONST_MODEL_DICT][CONST_MODEL][param_key]``.
        obss: Batch of observations.
        acts: Batch of actions; only used when ``param_key`` is ``CONST_QF``.
        config: Run config; ``config.model_config.qf.architecture`` is read for
            Q-functions to detect an ensemble.
        model: Model wrapper whose inner ``apply`` is invoked.
        param_key: Which network to evaluate (e.g. ``CONST_QF``).

    Returns:
        ``(out, state, multi_output)`` where ``state`` holds the captured
        ``"intermediates"`` collection and ``multi_output`` is True only when
        an ensemble Q-function was evaluated through ``jax.vmap``.
    """
    net_params = params[CONST_MODEL_DICT][CONST_MODEL][param_key]
    capture_kwargs = dict(
        capture_intermediates=True,
        mutable=["intermediates"],
        eval=True,
    )

    # Anything that is not a Q-function only consumes observations.
    if param_key != CONST_QF:
        out, state = model.model.apply(net_params, obss, **capture_kwargs)
        return out, state, False

    # Q-functions take the observation concatenated with the action.
    obs_acts = np.concatenate((obss, acts[:, None]), axis=-1)
    qf_apply = model.model.model.apply

    if config.model_config.qf.architecture == CONST_ENSEMBLE:
        # Map over the leading axis of the parameters; the inputs are shared.
        ensemble_apply = jax.vmap(
            partial(qf_apply, **capture_kwargs),
            in_axes=[0, None],
        )
        out, state = ensemble_apply(net_params, obs_acts)
        return out, state, True

    out, state = qf_apply(net_params, obs_acts, **capture_kwargs)
    return out, state, False

# Bind the notebook-level config, model and param_key so callers only pass
# (params, obss, acts). The values are captured when this cell runs, so re-run
# it after rebuilding the model.
inference = partial(get_model_outputs, config=config, model=model, param_key=param_key)

In [None]:
# Load the parameters of the selected run.
# NOTE(review): the ":latest" suffix presumably selects the most recent checkpoint — confirm in load_params.
params = load_params(f"{learner_path}:latest")

In [None]:
def compute_dormant(params, obss, acts, dormant_threshold=0.025):
    """Score every captured intermediate activation and flag dormant units.

    For each captured layer the mean absolute activation over the batch is
    normalised by the layer-wide mean (taken over the last axis); a unit is
    dormant when its normalised score is ``<= dormant_threshold``.

    Args:
        params: Parameter tree forwarded to ``inference``.
        obss: Batch of observations.
        acts: Batch of actions (only used by Q-functions).
        dormant_threshold: Normalised-score cutoff at or below which a unit
            counts as dormant.

    Returns:
        ``(dormant_score, is_dormant, multi_output)``. Both dicts are keyed by
        the "/"-joined module path. ``is_dormant`` keeps the per-unit layout,
        while ``dormant_score`` is folded into 4 rows (per ensemble member when
        ``multi_output``) whenever the per-member unit count is divisible by 4,
        purely for display.
    """
    _, state, multi_output = inference(params, obss, acts)
    # With an ensemble the outputs come from jax.vmap, so axis 0 is the
    # ensemble member and the batch sits on axis 1.
    batch_axis = 1 if multi_output else 0
    dormant_score = dict()
    is_dormant = dict()
    for (kp, val) in jax.tree_util.tree_flatten_with_path(state["intermediates"])[0]:
        # Skip the capture recorded directly under the top-level "__call__".
        if getattr(kp[0], "key", False) == "__call__":
            continue
        per_neuron_score = jnp.mean(jnp.abs(val), axis=batch_axis)
        # Drop the trailing two path entries so the key names the module itself.
        curr_key = "/".join([curr_kp.key if hasattr(curr_kp, "key") else str(curr_kp.idx) for curr_kp in kp][:-2])
        # XXX: https://github.com/google/dopamine/issues/209
        layer_mean = jnp.mean(per_neuron_score, axis=-1, keepdims=True)
        # A layer with no activity at all has mean 0; dividing would yield NaN,
        # and NaN <= threshold is False, hiding a fully dead layer. Divide by 1
        # instead so every unit scores 0 and is flagged dormant.
        safe_mean = jnp.where(layer_mean > 0, layer_mean, 1.0)
        score = per_neuron_score / safe_mean
        is_dormant[curr_key] = score <= dormant_threshold

        # Decide on the per-member unit count: including the ensemble axis in
        # the check would attempt an impossible reshape for e.g. shape (4, 1).
        units_per_member = np.prod(score.shape[int(multi_output):])
        if units_per_member % 4 == 0:
            if multi_output:
                score = score.reshape((len(score), 4, -1))
            else:
                score = score.reshape((4, -1))
        dormant_score[curr_key] = score
    return dormant_score, is_dormant, multi_output

def compute_dormant_percentage(is_dormant, multi_output):
    """Return the fraction of units flagged dormant across all layers.

    Args:
        is_dormant: Pytree (dict) of boolean masks, one per layer.
        multi_output: True when the masks carry a leading ensemble axis; the
            result then holds one fraction per ensemble member.

    Returns:
        Dormant count divided by total unit count: a scalar, or an array of
        shape ``(ensemble,)`` when ``multi_output`` is True.
    """
    # Number of leading axes that must be preserved (the ensemble axis, if any).
    lead = int(multi_output)

    def _add_dormant(acc, mask):
        # Reduce over exactly the axes the denominator counts, so masks with
        # more than one feature axis still yield one number per member.
        return acc + jnp.sum(mask, axis=tuple(range(lead, mask.ndim)))

    def _add_units(acc, mask):
        return acc + np.prod(mask.shape[lead:])

    num_dormant = jax.tree_util.tree_reduce(_add_dormant, is_dormant, 0)
    num_units = jax.tree_util.tree_reduce(_add_units, is_dormant, 0)
    return num_dormant / num_units

In [None]:
# Dormant scores and masks for the loaded parameters on the sampled batch (default threshold).
dormant_score, is_dormant, multi_output = compute_dormant(params, obss, acts)

In [None]:
# Show the scores of every recorded layer except the last one as a single flat vector.
np.concatenate([layer_score.flatten() for layer_score in list(dormant_score.values())[:-1]])

In [None]:
# Ratio of units flagged dormant to all units for the loaded parameters.
compute_dormant_percentage(is_dormant, multi_output)

## Visualize Parameters

In [None]:
# One entry per parameter leaf, keyed by its "/"-joined tree path and wrapped for penzai display.
dict(
    ("/".join(step.key for step in path), pz.nx.wrap(leaf))
    for path, leaf in jax.tree_util.tree_flatten_with_path(
        params["model_dict"]["model"][param_key]
    )[0]
)

## Check Dormant Percentage Across Checkpoints

In [None]:
# Iterator over the run's saved checkpoints; the loop cell unpacks each item as (params, checkpoint index).
# NOTE(review): if this is a one-shot iterator, re-run this cell before re-running the loop.
params_iter = iterate_params(learner_path)

In [None]:
# Threshold used for the per-checkpoint sweep (differs from compute_dormant's default of 0.025).
dormant_threshold = 0.25
for params, checkpoint_i in params_iter:
    dormant_score, is_dormant, multi_output = compute_dormant(
        params, obss, acts, dormant_threshold=dormant_threshold
    )
    dormant_fraction = compute_dormant_percentage(is_dormant, multi_output)
    print(checkpoint_i, dormant_fraction)