In [1]:
import jax.numpy as jp
import numpy as np
import jax
# Analytical gradients work much better with double precision.
from jax import config
# Ask JAX to raise an error as soon as a computation produces a NaN.
config.update("jax_debug_nans", True)
# Enable 64-bit (double precision) types in JAX.
config.update("jax_enable_x64", True)
# Request 'high' precision for matrix multiplications by default.
config.update('jax_default_matmul_precision', 'high')
# NOTE(review): this import comes after the config updates; presumably so the
# flags are already set when the package is imported -- confirm before reordering.
from mujoco_playground import registry
# Load the environment registered under the name "AnymalTrot".
env = registry.load("AnymalTrot")

In [2]:
# Show the environment's `lowers` and `uppers` arrays, separated by one blank line.
print(env.lowers, env.uppers, sep="\n\n")

[-0.72    -9.42478 -9.42478 -0.49    -9.42478 -9.42478 -0.72    -9.42478
 -9.42478 -0.49    -9.42478 -9.42478]

[0.49    9.42478 9.42478 0.72    9.42478 9.42478 0.49    9.42478 9.42478
 0.72    9.42478 9.42478]


In [3]:
# JIT-wrap the environment's reset and step entry points.
reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)

# Reset with a fixed PRNG key, then advance once with an all-zero action.
state = step_fn(reset_fn(jax.random.PRNGKey(42)), jp.zeros(env.action_size))

In [4]:
def generate_test_actions(action_dim, steps=20, seed=0):
    """Build a deterministic sequence of test actions.

    The sequence is assembled in this order:
      1. one all-zero action,
      2. five small Gaussian actions (standard deviation 0.1),
      3. five sinusoidal actions of amplitude 0.3 (same value on every dimension),
      4. two extreme actions, all +1 and all -1
         (assumes actions are normalized to [-1, 1] -- TODO confirm),
      5. sine-plus-noise actions filling the remaining steps.

    When ``steps`` is smaller than the 13 fixed actions above, the sequence is
    truncated so that exactly ``steps`` actions are returned.

    Args:
        action_dim: number of entries in each action vector.
        steps: number of actions to return; must be non-negative.
        seed: seed for the NumPy random generator.

    Returns:
        np.ndarray of shape (steps, action_dim) and dtype float64.

    Raises:
        ValueError: if ``steps`` is negative.
    """
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")

    rng = np.random.default_rng(seed)
    actions = []

    # 1. Zero action.
    actions.append(np.zeros(action_dim))

    # 2. Small random actions.
    for _ in range(5):
        actions.append(0.1 * rng.standard_normal(action_dim))

    # 3. Periodic actions (sine wave, one full period over five steps).
    for t in range(5):
        a = 0.3 * np.sin(2 * np.pi * t / 5) * np.ones(action_dim)
        actions.append(a)

    # 4. Extreme actions.
    actions.append(np.ones(action_dim))    # max
    actions.append(-np.ones(action_dim))   # min

    # 5. Mixed: sine wave plus Gaussian noise for whatever steps remain.
    for t in range(steps - len(actions)):
        a = 0.2 * np.sin(2 * np.pi * t / 10) * np.ones(action_dim)
        a += 0.1 * rng.standard_normal(action_dim)
        actions.append(a)

    # Honor `steps` even when it is smaller than the fixed prefix, and keep the
    # result two-dimensional when no actions are selected.
    selected = actions[:steps]
    return np.array(selected, dtype=float).reshape(len(selected), action_dim)

# Usage: build a 20-step action sequence, then print its shape
# -- (20, action_dim) -- followed by the generated actions themselves.
actions = generate_test_actions(action_dim=env.action_size, steps=20, seed=42)
print(actions.shape, actions, sep="\n")


(20, 12)
[[ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.          0.          0.        ]
 [ 0.03047171 -0.10399841  0.07504512  0.09405647 -0.19510352 -0.13021795
   0.01278404 -0.03162426 -0.00168012 -0.08530439  0.0879398   0.07777919]
 [ 0.00660307  0.11272412  0.04675093 -0.08592925  0.03687508 -0.09588826
   0.08784503 -0.00499259 -0.01848624 -0.06809295  0.12225413 -0.01545295]
 [-0.04283278 -0.03521336  0.05323092  0.03654441  0.04127326  0.0430821
   0.21416476 -0.0406415  -0.05122427 -0.08137727  0.06159794  0.11289723]
 [-0.01139475 -0.08401565 -0.08244812  0.06505928  0.07432542  0.05431543
  -0.06655097  0.02321613  0.01166858  0.02186886  0.08714288  0.02235955]
 [ 0.06789136  0.00675791  0.02891194  0.06312882 -0.14571558 -0.03196712
  -0.04703727 -0.06388778 -0.02751423  0.14949413 -0.08658311  0.09682784]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0

In [7]:
import time

def rollout(reset_fn, step_fn, actions, seed=0):
    """Roll an environment forward along a fixed sequence of actions.

    Prints "start" once, then the wall-clock cost of every step.

    Args:
        reset_fn: callable taking a JAX PRNG key and returning the initial state.
        step_fn: callable taking (state, action) and returning the next state.
        actions: iterable of actions, e.g. an array of shape
            [episode_length, action_dim]; one step is taken per entry.
        seed: seed of the PRNG key handed to ``reset_fn``.

    Returns:
        List holding the initial state followed by the state after each
        action, i.e. ``len(actions) + 1`` entries.
    """
    rng = jax.random.PRNGKey(seed)
    # Wait for the reset to finish so its cost is not billed to the first step.
    state = jax.block_until_ready(reset_fn(rng))
    states = [state]
    print("start")
    for action in actions:
        start = time.perf_counter()
        # JAX dispatches work asynchronously; block until the result is ready
        # so the measured interval covers the computation, not just dispatch.
        state = jax.block_until_ready(step_fn(state, action))
        states.append(state)
        print("step cost:", time.perf_counter() - start)
    return states

# Replay the generated actions (converted to a JAX array) starting from a seed-42 reset.
states = rollout(
    reset_fn=reset_fn,
    step_fn=step_fn,
    actions=jp.array(actions),
    seed=42,
)

start
step cost: 0.030487775802612305
step cost: 0.031014680862426758
step cost: 0.03406476974487305
step cost: 0.02564239501953125
step cost: 0.031165122985839844
step cost: 0.02360248565673828
step cost: 0.02169322967529297
step cost: 0.018362045288085938
step cost: 0.033351898193359375
step cost: 0.02606487274169922
step cost: 0.030895471572875977
step cost: 0.026174306869506836
step cost: 0.0293428897857666
step cost: 0.026992082595825195
step cost: 0.03629803657531738
step cost: 0.01883101463317871
step cost: 0.018632888793945312
step cost: 0.022719383239746094
step cost: 0.02656722068786621
step cost: 0.025464296340942383


In [None]:
# Pull the per-step quantities out of every recorded state.
qpos_list = [snapshot.data.qpos for snapshot in states]
qvel_list = [snapshot.data.qvel for snapshot in states]
reward_list = [snapshot.reward for snapshot in states]
done_list = [snapshot.done for snapshot in states]

# Stack each list into a single NumPy array with the step index as the leading axis.
qpos_array, qvel_array, reward_array, done_array = (
    np.stack(series)
    for series in (qpos_list, qvel_list, reward_list, done_list)
)

# Report the resulting shapes.
print("qpos_array shape:", qpos_array.shape)
print("qvel_array shape:", qvel_array.shape)
print("reward_array shape:", reward_array.shape)
print("done_array shape:", done_array.shape)

# Optional: persist the trajectory to disk (left disabled).
# np.savez_compressed(
#     'test_trajectory_data.npz',
#     qpos=qpos_array,
#     qvel=qvel_array,
#     reward=reward_array,
#     done=done_array
# )

qpos_array shape: (21, 19)
qvel_array shape: (21, 18)
reward_array shape: (21,)
done_array shape: (21,)
