# Evaluate whether there appears to be Neural Collapse in the policy

In [1]:
from gymnasium.experimental.wrappers import RecordVideoV0
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager
from pprint import pprint

import _pickle as pickle
import jax
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

from jaxl.buffers import get_buffer
from jaxl.constants import *
from jaxl.models import (
    get_model,
    get_policy,
    policy_output_dim,
)
from jaxl.models.policies import MultitaskPolicy
from jaxl.envs import get_environment
from jaxl.envs.rollouts import EvaluationRollout
from jaxl.utils import (
    set_seed,
    parse_dict,
    set_dict_value,
    get_dict_value,
    RunningMeanStd,
    get_device,
)

In [2]:
# Session-wide setup: choose the seed and the compute device, then hand both
# to the jaxl helpers before anything else runs.
run_seed = 0
# NOTE(review): "gpu:1" presumably selects the second GPU; confirm the string
# format expected by jaxl.utils.get_device.
device = "gpu:1"
get_device(device)
set_seed(run_seed)

In [3]:
def get_env(agent_path):
    """Return the environment config stored with a logged run.

    Reads ``config.json`` inside ``agent_path`` and returns the value found
    under ``["learner_config"]["env_config"]``.
    """
    config_file = os.path.join(agent_path, "config.json")
    with open(config_file, "r") as handle:
        run_config = json.load(handle)
    learner_section = run_config["learner_config"]
    return learner_section["env_config"]

In [14]:
# Where the logged training run lives on disk.
base_path = "/home/bryanpu1/projects/rl_nc_representation/jaxl"
log_path = os.path.join(base_path, "jaxl/logs")
project_name = "cartpole"
run_name = "ppo-03-05-24_17_28_52-0d5db8df-df17-47e1-9802-8ddeec98c0b2"

# The agent and the environment configuration are both taken from the same
# run directory, so the two path names refer to the same location.
agent_path = os.path.join(log_path, project_name, run_name)
agent_to_load_env_path = agent_path
trained_env_parameters = get_env(agent_to_load_env_path)

# Evaluation settings.
num_episodes = 10
env_seed = 9999
buffer_size = 0
num_seeds = 50
record_video = False

# Drop the last eight dash-separated fields of the run name to get the
# experiment name.
run_name_fields = run_name.split("-")
exp_name = "-".join(run_name_fields[:-8])

In [15]:
def get_config(agent_path):
    """Load a run's ``config.json`` and adapt it for evaluation.

    The loaded configuration is modified as follows before being parsed:

    * ``learner_config.env_config`` is replaced by a private copy of the
      module-level ``trained_env_parameters`` with
      ``env_kwargs["render_mode"]`` set to ``"rgb_array"``.
    * ``learner_config.policy_distribution`` defaults to
      ``CONST_DETERMINISTIC`` when the run did not specify one.
    * ``"vmap_all"`` is set to ``False`` through ``set_dict_value``.

    Args:
        agent_path: Directory of the logged run containing ``config.json``.

    Returns:
        A pair ``(agent_config, aux)`` where ``agent_config`` is the result of
        ``parse_dict`` on the adapted configuration and ``aux`` is a dict with
        the keys ``"multitask"`` and ``"num_models"``, taken from the pair
        returned by ``get_dict_value(agent_config_dict, "num_models")``.
    """
    # Local import keeps this cell self-contained; `copy` is standard library.
    import copy

    agent_config_path = os.path.join(agent_path, "config.json")
    # Only the read needs the file handle; everything else happens after close.
    with open(agent_config_path, "r") as f:
        agent_config_dict = json.load(f)

    learner_config = agent_config_dict["learner_config"]

    # Work on a copy so that forcing the render mode does not silently modify
    # the shared module-level `trained_env_parameters` dict.
    env_config = copy.deepcopy(trained_env_parameters)
    # Tolerate environment configs that were saved without `env_kwargs`.
    env_config.setdefault("env_kwargs", {})["render_mode"] = "rgb_array"
    learner_config["env_config"] = env_config

    learner_config.setdefault("policy_distribution", CONST_DETERMINISTIC)

    set_dict_value(agent_config_dict, "vmap_all", False)
    multitask, num_models = get_dict_value(agent_config_dict, "num_models")
    agent_config = parse_dict(agent_config_dict)
    return agent_config, {
        "multitask": multitask,
        "num_models": num_models,
    }

In [16]:
from jaxl.buffers.ram_buffers import NextStateNumPyBuffer

In [17]:
# Restore the most recent checkpoint saved under <agent_path>/models.
checkpoint_manager = CheckpointManager(
    os.path.join(agent_path, "models"),
    PyTreeCheckpointer(),
)
params = checkpoint_manager.restore(checkpoint_manager.latest_step())
model_dict = params[CONST_MODEL_DICT]
agent_policy_params = model_dict[CONST_MODEL][CONST_POLICY]
# `False` acts as the "no observation normalisation" marker; it is replaced by
# a RunningMeanStd only when the checkpoint carries normalisation state.
agent_obs_rms = False
if CONST_OBS_RMS in params:
    agent_obs_rms = RunningMeanStd()
    agent_obs_rms.set_state(params[CONST_OBS_RMS])

# Build the evaluation environment from the adapted run configuration.
agent_config, aux = get_config(agent_path)
env = get_environment(agent_config.learner_config.env_config)

# Buffer that the rollout below fills with the visited transitions.
# NOTE(review): the capacity is hard-coded here; the `buffer_size` variable
# defined earlier in the notebook is not used. Confirm this is intended.
buffer = NextStateNumPyBuffer(
    buffer_size=100000,
    obs_dim=env.observation_space.shape,
    act_dim=env.act_dim,
    rew_dim=env.reward_dim,
    h_state_dim=(1,),
    rng=np.random.RandomState(42)
)

# Optionally wrap the environment so episodes are written to a video folder.
# This happens after the buffer has been sized from the unwrapped env.
if record_video:
    env = RecordVideoV0(env, f"videos/{exp_name}-videos")

# Rebuild the policy network from the run configuration. Configurations
# without a dedicated `policy` entry fall back to the whole model config.
input_dim = env.observation_space.shape
output_dim = policy_output_dim(env.act_dim, agent_config.learner_config)
model = get_model(
    input_dim,
    output_dim,
    getattr(agent_config.model_config, "policy", agent_config.model_config),
)
policy = get_policy(model, agent_config.learner_config)
if aux["multitask"]:
    policy = MultitaskPolicy(policy, model, aux["num_models"])

# Roll out the restored policy for `num_episodes` episodes, passing `buffer`
# so the rollout can record into it.
agent_rollout = EvaluationRollout(env, seed=env_seed)
agent_rollout.rollout(
    agent_policy_params, policy, agent_obs_rms, num_episodes, buffer
)

  logger.warn(
100%|██████████| 10/10 [00:08<00:00,  1.17it/s]


In [22]:
# Take the first `buffer.pointer` action entries.
# NOTE(review): presumably `pointer` marks the end of the filled portion of
# the buffer; confirm against NextStateNumPyBuffer.
all_acts = buffer.actions[:buffer.pointer]

In [27]:
# Without extra flags np.unique returns a single array of unique values, so
# unpacking it into two names only "works" when there happen to be exactly two
# unique actions, and then binds the two action values rather than
# (values, counts). Requesting the counts makes the result a genuine 2-tuple.
unique_acts, _ = np.unique(all_acts, return_counts=True)

In [31]:
# Display the parsed run configuration (notebook cell output).
agent_config

namespace(logging_config=namespace(save_path='./logs/cartpole',
                                   experiment_name='ppo',
                                   log_interval=10,
                                   checkpoint_interval=100),
          model_config=namespace(policy=namespace(architecture='mlp',
                                                  layers=[64, 64],
                                                  activation='tanh'),
                                 vf=namespace(architecture='mlp',
                                              layers=[64, 64],
                                              activation='tanh')),
          optimizer_config=namespace(policy=namespace(optimizer='adam',
                                                      lr=namespace(scheduler='constant_schedule',
                                                                   scheduler_kwargs=namespace(value=0.0003)),
                                                      max_grad_norm=False),
      

In [38]:
# List which sub-models the restored checkpoint stores parameters for.
params[CONST_MODEL_DICT][CONST_MODEL].keys()

dict_keys(['policy', 'vf'])

In [41]:
from jaxl.models.common import get_activation
from jaxl.models.modules import MLPModule

from collections import OrderedDict

def get_latent(params, inputs, carries):
    """Apply the policy MLP to ``inputs`` and return its recorded latents.

    An ``MLPModule`` is rebuilt from the policy section of ``agent_config``
    and applied with the restored policy parameters. The ``"mlp_latents"``
    collection is marked mutable so that the values the module records there
    are returned alongside the output.

    Args:
        params: Restored checkpoint; the policy parameters are read from
            ``params[CONST_MODEL_DICT][CONST_MODEL][CONST_POLICY][CONST_PARAMS]``.
        inputs: Input passed to the module's ``apply``.
        carries: Accepted for call compatibility; not used by this function.

    Returns:
        An ``OrderedDict`` with the entries of the ``"mlp_latents"``
        collection, in the order that collection yields them.
    """
    # Same fallback as where the policy model is built: configurations without
    # a dedicated `policy` entry use the whole model config.
    policy_config = getattr(
        agent_config.model_config, "policy", agent_config.model_config
    )
    # Use the activation the run was configured with instead of always ReLU,
    # so the latents match the trained network (the displayed run config lists
    # `activation='tanh'`). ReLU remains the default when none is specified.
    # NOTE(review): assumes get_activation accepts the config's activation
    # string, as it does for the CONST_* names; confirm in jaxl.models.common.
    hidden_activation = get_activation(
        getattr(policy_config, "activation", CONST_RELU)
    )
    policy_mlp = MLPModule(
        policy_config.layers,
        hidden_activation,
        get_activation(CONST_IDENTITY),
        use_batch_norm=False,
    )

    policy_params = params[CONST_MODEL_DICT][CONST_MODEL][CONST_POLICY][CONST_PARAMS]
    _, mlp_states = policy_mlp.apply(
        {"params": policy_params},
        inputs,
        eval=True,
        capture_intermediates=True,
        mutable=["mlp_latents"],
    )

    return OrderedDict(mlp_states["mlp_latents"].items())

In [43]:
# Map get_latent over the leading axis of the stored observations and hidden
# states; `params` is shared across the batch (axis None).
latents = jax.vmap(get_latent, in_axes=(None, 0, 0))(
    params,
    buffer.observations[: buffer.pointer],
    buffer.hidden_states[: buffer.pointer],
)