# Evaluate PPO Parameters (Single File)

Pick one `.npy` file from `parameters/` and evaluate it using the same network configuration as training.

In [None]:
from pathlib import Path
import os, sys
import numpy as np
import jax

# Headless rendering for MuJoCo
os.environ.setdefault('MUJOCO_GL', 'egl')


PosixPath('/home/loke/Desktop/EAI2025_RL_FINAL')

## Load in paths

In [None]:
# Get paths
from config import ProjectPaths
# Models
model_paths = {
    'baseline': ProjectPaths.PARAMS_BASELINE,
    'with_knee': ProjectPaths.PARAMS_WITH_HEIGHT,
    'with_height': ProjectPaths.PARAMS_WITH_HEIGHT,
    'with_height_and_knee': ProjectPaths.PARAMS_WITH_HEIGHT_AND_KNEE
}

env_paths = {
    'custom': ProjectPaths.CUSTOM_ENV_XML, # Maybe should be named default
    'custom_debug': ProjectPaths.CUSTOM_ENV_DEBUG_XML,
    'stairs': ProjectPaths.STAIRS_ENV_XML
}

print("Available models:")
print("=" * 35)
for key in model_paths:
    print(f"'{key}'")

print("\n")

print("Available Environments:")
print("=" * 35)
for key in env_paths:
    print(f"'{key}'")



Available models:
'baseline'
'with_knee'
'with_height'
'with_height_and_knee'


Available Environments:
'custom'
'custom_debug'
'stairs'


## Choose model and environment

In [None]:
model = model_paths['with_height_and_knee'] # Enter model here
env = env_paths['custom'] # Enter environment here

In [43]:
# Paste a path from the list above
from pathlib import Path
PARAM_FILE = model
PARAM_FILE = Path(PARAM_FILE)
print('Selected:', PARAM_FILE)
assert PARAM_FILE.exists(), f'File not found: {PARAM_FILE}'
PARAM_FILE


Selected: /home/loke/Desktop/EAI2025_RL_FINAL/parameters/params_baseline.npy


PosixPath('/home/loke/Desktop/EAI2025_RL_FINAL/parameters/params_baseline.npy')

## Run Evaluation

In [44]:
# Build env and policy like in training, then evaluate
from environments.custom_env import Joystick, default_config
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from utils import evaluate_policy

# Environment
xml_path = env
env_cfg = default_config()
env = Joystick(xml_path=str(xml_path), config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Observation shape pytree as plain tuples (matches training)
state0 = jit_reset(jax.random.PRNGKey(0))
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[-1:], state0.obs)

# Network factory kwargs from training config
net_kwargs = {}
try:
    from mujoco_playground.config import locomotion_params
    ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
    if hasattr(ppo_params, 'network_factory'):
        net_kwargs = dict(ppo_params.network_factory)
except Exception as e:
    print('Warning: using default network settings:', e)

# Build networks and inference fn
ppo_net = ppo_networks.make_ppo_networks(
    obs_shape, env.action_size, preprocess_observations_fn=running_statistics.normalize, **net_kwargs
)
make_policy = ppo_networks.make_inference_fn(ppo_net)

# Load params and evaluate
params_arr = np.load(PARAM_FILE, allow_pickle=True)
if getattr(params_arr, 'dtype', None) != object or len(params_arr) != 3:
    raise ValueError(f'Unexpected params format: dtype={getattr(params_arr, "dtype", None)} shape={getattr(params_arr, "shape", None)}')
normalizer_params, policy_params, value_params = params_arr
params_tuple = (normalizer_params, policy_params, value_params)
inference_fn = make_policy(params_tuple, deterministic=True)

evaluate_policy(env, jax.jit(inference_fn), jit_step, jit_reset, env_cfg, env)


TypeError: sub got incompatible shapes for broadcasting: (52,), (48,).