# Evaluate PPO Parameters (Single File)

Pick one `.npy` file from `parameters/` and evaluate it using the same network configuration as training.

In [1]:
from pathlib import Path
import os, sys
import numpy as np
import jax

# Resolve repo root: walk upward from the current working directory and take the
# first directory that contains both `parameters/` and `evaluation/`.
repo_root = Path.cwd()
for candidate in (repo_root, *repo_root.parents):
    if (candidate / 'parameters').exists() and (candidate / 'evaluation').exists():
        repo_root = candidate
        break
else:
    # No marker directories found: keep the working directory as a best-effort
    # fallback, but say so, otherwise the project imports below fail with an
    # unrelated-looking ModuleNotFoundError.
    print(
        'Warning: could not find a directory containing both parameters/ and '
        f'evaluation/ above {repo_root}; falling back to the current directory.'
    )

# Make project-local packages importable from the notebook.
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

# Headless rendering for MuJoCo (setdefault keeps a backend chosen by the caller's environment)
os.environ.setdefault('MUJOCO_GL', 'egl')
repo_root

PosixPath('/home/jovyan/EAI2025_RL_FINAL')

In [2]:
# Select the parameter file to evaluate.
# The file name is resolved inside <repo_root>/parameters/ so the notebook does not
# depend on one machine's home directory. An absolute path may be given instead of
# the file name: joining with `/` discards the prefix when the last segment is absolute.
from pathlib import Path
PARAM_FILE = repo_root / 'parameters' / 'params_with_height_and_knee.npy'
print('Selected:', PARAM_FILE)
assert PARAM_FILE.exists(), f'File not found: {PARAM_FILE}'
PARAM_FILE


Selected: /home/jovyan/EAI2025_RL_FINAL/parameters/params_with_height_and_knee.npy


PosixPath('/home/jovyan/EAI2025_RL_FINAL/parameters/params_with_height_and_knee.npy')

In [3]:
# Build env and policy like in training, then evaluate
# (the project-local imports rely on the repo root added to sys.path in the first cell)
from environments.custom_env import Joystick, default_config
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from utils import evaluate_policy

# Environment: Joystick task on the custom_env.xml scene with the default config
xml_path = repo_root / 'environments' / 'custom_env.xml'
env_cfg = default_config()
env = Joystick(xml_path=str(xml_path), config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Observation shape pytree as plain tuples (matches training).
# Only the last dimension of every observation leaf is kept, as a 1-tuple.
state0 = jit_reset(jax.random.PRNGKey(0))
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[-1:], state0.obs)

# Network factory kwargs from training config.
# NOTE(review): assumes training used the 'Go1JoystickRoughTerrain' PPO config - confirm.
# Best-effort: on any failure the library defaults are used and a warning is printed;
# the loaded parameters may then not fit the network built below.
net_kwargs = {}
try:
    from mujoco_playground.config import locomotion_params
    ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
    if hasattr(ppo_params, 'network_factory'):
        net_kwargs = dict(ppo_params.network_factory)
except Exception as e:
    print('Warning: using default network settings:', e)

# Build networks and inference fn
ppo_net = ppo_networks.make_ppo_networks(
    obs_shape, env.action_size, preprocess_observations_fn=running_statistics.normalize, **net_kwargs
)
make_policy = ppo_networks.make_inference_fn(ppo_net)

# Load params and evaluate.
# SECURITY: allow_pickle=True unpickles arbitrary Python objects, which can execute
# code on load - only point PARAM_FILE at files you produced or trust.
params_arr = np.load(PARAM_FILE, allow_pickle=True)
params_dtype = getattr(params_arr, 'dtype', None)
params_shape = getattr(params_arr, 'shape', None)
# Expect an object array whose first axis has exactly three entries. Checking the
# shape (instead of calling len()) also rejects 0-d object arrays with the
# ValueError below rather than an unrelated "len() of unsized object" TypeError.
if params_dtype != object or not params_shape or params_shape[0] != 3:
    raise ValueError(f'Unexpected params format: dtype={getattr(params_arr, "dtype", None)} shape={getattr(params_arr, "shape", None)}')
normalizer_params, policy_params, value_params = params_arr
params_tuple = (normalizer_params, policy_params, value_params)
inference_fn = make_policy(params_tuple, deterministic=True)

# NOTE(review): `env` is passed twice; presumably the last argument is a separate
# evaluation/rendering env - confirm against utils.evaluate_policy.
evaluate_policy(env, jax.jit(inference_fn), jit_step, jit_reset, env_cfg, env)


2025-10-15 12:20:51.238898: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
100%|██████████| 500/500 [01:31<00:00,  5.47it/s]


Evaluate the trained model on an unseen, harder environment containing stairs.

In [None]:
# Build the stairs env and the policy network, then evaluate the selected parameters on it
# (the project-local imports rely on the repo root added to sys.path in the first cell)
from environments.custom_env import Joystick, default_config
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from utils import evaluate_policy

# Environment: same Joystick task and default config, but on the stairs_env.xml scene
xml_path = repo_root / 'environments' / 'stairs_env.xml'
env_cfg = default_config()
env = Joystick(xml_path=str(xml_path), config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Observation shape pytree as plain tuples (matches training).
# Only the last dimension of every observation leaf is kept, as a 1-tuple.
state0 = jit_reset(jax.random.PRNGKey(0))
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[-1:], state0.obs)

# Network factory kwargs from training config.
# NOTE(review): assumes training used the 'Go1JoystickRoughTerrain' PPO config - confirm.
# Best-effort: on any failure the library defaults are used and a warning is printed;
# the loaded parameters may then not fit the network built below.
net_kwargs = {}
try:
    from mujoco_playground.config import locomotion_params
    ppo_params = locomotion_params.brax_ppo_config('Go1JoystickRoughTerrain')
    if hasattr(ppo_params, 'network_factory'):
        net_kwargs = dict(ppo_params.network_factory)
except Exception as e:
    print('Warning: using default network settings:', e)

# Build networks and inference fn
ppo_net = ppo_networks.make_ppo_networks(
    obs_shape, env.action_size, preprocess_observations_fn=running_statistics.normalize, **net_kwargs
)
make_policy = ppo_networks.make_inference_fn(ppo_net)

# Load params and evaluate.
# SECURITY: allow_pickle=True unpickles arbitrary Python objects, which can execute
# code on load - only point PARAM_FILE at files you produced or trust.
params_arr = np.load(PARAM_FILE, allow_pickle=True)
params_dtype = getattr(params_arr, 'dtype', None)
params_shape = getattr(params_arr, 'shape', None)
# Expect an object array whose first axis has exactly three entries. Checking the
# shape (instead of calling len()) also rejects 0-d object arrays with the
# ValueError below rather than an unrelated "len() of unsized object" TypeError.
if params_dtype != object or not params_shape or params_shape[0] != 3:
    raise ValueError(f'Unexpected params format: dtype={getattr(params_arr, "dtype", None)} shape={getattr(params_arr, "shape", None)}')
normalizer_params, policy_params, value_params = params_arr
params_tuple = (normalizer_params, policy_params, value_params)
inference_fn = make_policy(params_tuple, deterministic=True)

# NOTE(review): `env` is passed twice; presumably the last argument is a separate
# evaluation/rendering env - confirm against utils.evaluate_policy.
evaluate_policy(env, jax.jit(inference_fn), jit_step, jit_reset, env_cfg, env)