You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I ran into an issue similar to #467 when simulating a robot with random actions. I followed the suggestions mentioned in that thread but still observe NaN values in the simulation. Below are the minimal code and XML file needed to reproduce the issue. Could you help check what might be causing it? Thanks a lot!
from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
from jax import numpy as jp
import mujoco
class Unimal(PipelineEnv):
    """Brax pipeline environment for a robot loaded from an MJCF file.

    The environment only advances the physics: observations are ``None`` and
    reward/done are always zero.  Actions arrive in a padded layout with two
    slots per link and are gathered down to one value per DoF after the
    first six before every physics step.
    """

    def __init__(
        self,
        xml_path,
        backend='generalized',
        **kwargs,
    ):
        """Load the system from ``xml_path`` and build the action mapping.

        Args:
            xml_path: Path of the MJCF (XML) file handed to ``mjcf.load``.
            backend: Physics backend name forwarded to ``PipelineEnv``.
            **kwargs: Extra keyword arguments forwarded to ``PipelineEnv``.
                ``n_frames`` defaults to 5 when the caller does not set it.
        """
        sys = mjcf.load(xml_path)
        sys = sys.replace(dt=0.005)
        kwargs.setdefault('n_frames', 5)
        super().__init__(sys=sys, backend=backend, **kwargs)
        self.get_action_index()
        self._reset_noise_scale = 0.1

    def get_action_index(self):
        """Compute the gather indices that map padded actions onto DoFs.

        Sets ``self.limb_num`` (number of links) and ``self.action_index``
        (one index per DoF after the first six).  A DoF owned by link ``k``
        reads slot ``2 * k``; a DoF whose link equals the previous DoF's link
        reads slot ``2 * k + 1`` instead.
        """
        self.limb_num = self.sys.num_links()
        # Skip the first 6 DoFs -- presumably the free root joint, which is
        # not actuated; TODO confirm against the MJCF model.
        dof_link_idx = self.sys.dof_link()[6:].copy()
        # True where a DoF belongs to the same link as the DoF before it.
        repeat_mask = dof_link_idx[1:] == dof_link_idx[:-1]
        # The first DoF has no predecessor, so it is never a repeat.
        repeat_mask = jp.insert(repeat_mask, 0, 0)
        # NOTE(review): a link with three or more DoFs would map several of
        # them to the same odd slot; this looks like it assumes at most two
        # DoFs per link -- confirm.
        self.action_index = dof_link_idx * 2 + repeat_mask

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state.

        Args:
            rng: JAX PRNG key used to draw the initial perturbation.

        Returns:
            A ``State`` whose pipeline state starts from ``sys.qpos0`` plus
            uniform noise in ``[-_reset_noise_scale, _reset_noise_scale]``,
            with velocities drawn from the same range, ``obs=None``, zero
            reward/done and empty metrics.
        """
        rng, rng1, rng2 = jax.random.split(rng, 3)
        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            rng1, (self.sys.nq,), minval=low, maxval=hi
        )
        qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)
        data = self.pipeline_init(qpos, qvel)
        obs = None
        reward, done = jp.zeros(2)
        metrics = {}
        return State(data, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Run one timestep of the environment's dynamics.

        Args:
            state: Environment state returned by ``reset`` or a previous
                ``step``.
            action: Padded action vector; it is indexed with
                ``self.action_index``, so it must be long enough to cover
                every index in it (the reproduction script uses
                ``2 * num_links`` entries).

        Returns:
            ``state`` with its pipeline state advanced by one
            ``pipeline_step`` call, ``obs=None`` and zero reward/done.
        """
        # Drop the padded slots that do not correspond to any DoF.
        action = action[self.action_index]
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = None
        # Use arrays (not Python floats) so the State leaves have the same
        # types after step() as after reset().
        reward, done = jp.zeros(2)
        return state.replace(
            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
        )
# Build the environment from the MJCF file and jit its two entry points.
xml_path = 'robot.xml'
agent = Unimal(xml_path)
jit_env_reset = jax.jit(agent.reset)
jit_env_step = jax.jit(agent.step)

# Pre-sample one standard-normal action vector per step; the padded action
# layout has two slots per link.
episode_length = 2560
action_dim = 2 * agent.sys.num_links()
random_action = jax.random.normal(
    jax.random.PRNGKey(seed=1),
    shape=(episode_length, action_dim),
)

# Roll the simulation forward, printing q every step, and stop at the first
# step whose generalized positions contain a NaN.
state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))
for t in range(episode_length):
    state = jit_env_step(state, random_action[t])
    q = state.pipeline_state.q
    print(q)
    if jp.isnan(q).any():
        print(t)
        break
Hi,
I ran into an issue similar to #467 when simulating a robot with random actions. I followed the suggestions mentioned in that thread but still observe NaN values in the simulation. Below are the minimal code and XML file needed to reproduce the issue. Could you help check what might be causing it? Thanks a lot!
The XML configuration is:
The text was updated successfully, but these errors were encountered: