Skip to content

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NaN encountered in pipeline_step() #474

Closed
MasterXiong opened this issue Apr 7, 2024 · 1 comment
Closed

NaN encountered in pipeline_step() #474

MasterXiong opened this issue Apr 7, 2024 · 1 comment

Comments

@MasterXiong
Copy link

Hi,

I encountered an issue similar to #467 when running a simulation of a robot with random actions. I followed the suggestions in that thread but still observed NaN values in the simulation. Below are the minimal code and XML file needed to reproduce the issue. Could you help check what the problem might be? Thanks a lot!

from brax import base
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
from jax import numpy as jp
import mujoco

class Unimal(PipelineEnv):
  """Minimal brax environment that steps a robot loaded from an MJCF file.

  The incoming action vector is treated as having two slots per link (slots
  ``2 * i`` and ``2 * i + 1`` for link ``i``). `step` gathers from it only the
  entries listed in `self.action_index` and forwards those to the pipeline.
  No observation is computed: `obs` is always `None`.
  """

  def __init__(
      self,
      xml_path,
      backend='generalized',
      **kwargs,
  ):
    """Loads the model and configures the physics pipeline.

    Args:
      xml_path: path of the MJCF file, passed to `mjcf.load`.
      backend: physics backend name, forwarded to `PipelineEnv`.
      **kwargs: extra arguments forwarded to `PipelineEnv`. `n_frames`
        defaults to 5 when the caller does not supply it.
    """
    sys = mjcf.load(xml_path)

    # Override the system timestep regardless of what the file specifies.
    sys = sys.replace(dt=0.005)
    n_frames = 5
    kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

    super().__init__(sys=sys, backend=backend, **kwargs)

    self.get_action_index()
    self._reset_noise_scale = 0.1

  def get_action_index(self):
    """Builds `self.action_index`, the gather indices applied to actions.

    Also stores the number of links in `self.limb_num`.

    For every dof kept (see the note on the `[6:]` slice below) with link
    index ``i``, the index is ``2 * i`` when the previous kept dof belongs to a
    different link and ``2 * i + 1`` when it belongs to the same link. A link
    with more than two consecutive dofs would therefore map all dofs after the
    first to the same slot ``2 * i + 1``.
    """
    # mask the joints for each limb
    self.limb_num = self.sys.num_links()
    # NOTE(review): the first 6 dofs are skipped -- presumably the 6 dofs of a
    # free root joint; confirm before using a model without a free root.
    dof_link_idx = self.sys.dof_link()[6:].copy()
    # True where a dof shares its link with the dof immediately before it.
    repeat_mask = (dof_link_idx[1:] == dof_link_idx[:-1])
    # The first kept dof has no predecessor, so it is never a repeat.
    repeat_mask = jp.insert(repeat_mask, 0, 0)
    self.action_index = dof_link_idx * 2 + repeat_mask

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state.

    Args:
      rng: JAX PRNG key; split to draw position and velocity noise.

    Returns:
      A `State` with the initial pipeline state, `obs=None`, zero reward,
      zero done flag and empty metrics.
    """
    rng, rng1, rng2 = jax.random.split(rng, 3)

    # Uniform noise in [-scale, scale) is added to every entry of qpos0 and
    # used directly as the initial velocity.
    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi)

    data = self.pipeline_init(qpos, qvel)
    obs = None

    reward, done = jp.zeros(2)
    metrics = {}
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Run one timestep of the environment's dynamics.

    Args:
      state: current environment state.
      action: action vector indexed by `self.action_index`; it must be long
        enough to cover the largest index stored there.

    Returns:
      `state` with its pipeline state advanced and `obs=None`, zero reward
      and zero done flag.
    """

    # remove useless action dimensions
    action = action[self.action_index]

    # step
    pipeline_state0 = state.pipeline_state
    pipeline_state = self.pipeline_step(pipeline_state0, action)

    obs = None
    # Use float32 arrays, as in `reset`, so the State leaves keep the same
    # dtype/weak-type across steps and a jitted `step` is not retraced.
    reward, done = jp.zeros(2)

    return state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )


# Build the environment; the action vector carries two slots per link.
xml_path = 'robot.xml'
agent = Unimal(xml_path)
action_dim = 2 * agent.sys.num_links()

# Compile reset/step once up front.
jit_env_reset = jax.jit(agent.reset)
jit_env_step = jax.jit(agent.step)

# Pre-sample one standard-normal action per step of the episode.
episode_length = 2560
action_key = jax.random.PRNGKey(seed=1)
random_action = jax.random.normal(
    action_key, shape=(episode_length, action_dim)
)

state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))
for t in range(episode_length):
  state = jit_env_step(state, random_action[t])
  q = state.pipeline_state.q
  print(q)
  # Report the first step index at which `q` contains a NaN, then stop.
  if jp.isnan(q).any():
    print(t)
    break

And the XML configuration is:

<?xml version='1.0' encoding='UTF-8'?>
<!-- Universal Animal Template: unimal -->
<mujoco model="unimal">
  <compiler angle="degree"/>
  <size njmax="2000" nconmax="500"/>
  <option timestep=".005">
    <flag filterparent="disable"/>
  </option>
  <!-- Common defaults to make search space tractable -->
  <default>
    <!-- Define motor defaults -->
    <motor ctrlrange="-1 1" ctrllimited="true"/>
    <!-- Define joint defaults -->
    <default class="normal_joint">
      <joint type="hinge" damping="1" stiffness="1" armature="1" limited="true" range="-120 120" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="walker_joint">
      <joint type="hinge" damping="0.2" stiffness="1" armature=".01" limited="true" range="-120 120" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="stiff_joint">
      <joint type="hinge" damping="5" stiffness="10" armature=".01" limited="true" solimplimit="0 0.99 0.01"/>
    </default>
    <default class="free">
      <joint limited="false" damping="0" armature="0" stiffness="0"/>
    </default>
    <default class="growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="torso_growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="mirror_growth_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="btm_pos_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="box_face_site">
      <site size="1e-6 1e-6 1e-6"/>
    </default>
    <default class="imu_vel">
      <site type="box" size="0.05" rgba="1 0 0 1"/>
    </default>
    <default class="touch_site">
      <site group="3" rgba="0 0 1 .3"/>
    </default>
    <default class="food_site">
      <site material="food" size="0.15"/>
    </default>
    <!-- Define geom defaults -->
    <geom type="capsule" condim="3" friction="0.7 0.1 0.1" material="self"/>
  </default>
  <worldbody>
    <light diffuse="1 1 1" directional="true" exponent="1" pos="0 0 1" specular=".1 .1 .1"/>
    <!-- <geom name="floor" type="plane" pos="0 0 0" size="50 50 1" material="grid"/> -->
    <!-- Programmatically generated xml goes here -->
    <body name="torso/0" pos="0 0 0.75">
      <joint name="root" type="free" class="free"/>
      <site name="root" class="imu_vel"/>
      <geom name="torso/0" type="sphere" size="0.1" condim="3" density="1000"/>
      <camera name="side" pos="0 -7 2" xyaxes="1 0 0 0 1 2" mode="trackcom"/>
      <site name="torso/0" class="growth_site" pos="0 0 0"/>
      <site name="torso/btm_pos/0" class="btm_pos_site" pos="0 0 -0.1"/>
      <site name="torso/touch/0" class="touch_site" size="0.11"/>
      <site name="torso/horizontal_y/0" class="torso_growth_site" pos="-0.1 0 0"/>
      <body name="limb/0" pos="0.0 0.0 -0.1">
        <joint name="limbx/0" type="hinge" class="normal_joint" range="0 60" pos="0.0 0.0 0.05" axis="1.0 0.0 0.0"/>
        <joint name="limby/0" type="hinge" class="normal_joint" range="-60 30" pos="0.0 0.0 0.05" axis="0.0 1.0 0.0"/>
        <geom name="limb/0" type="capsule" fromto="0.0 0.0 0.0 0.0 0.0 -0.45" size="0.05" density="600"/>
        <site name="limb/mid/0" class="growth_site" pos="0.0 0.0 -0.25"/>
        <site name="limb/btm/0" class="growth_site" pos="0.0 0.0 -0.45"/>
        <site name="limb/btm_pos/0" class="btm_pos_site" pos="0.0 0.0 -0.45"/>
        <site name="limb/touch/0" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 0.0 -0.45" type="capsule"/>
      </body>
      <body name="limb/4" pos="-0.1 0.0 0.0">
        <joint name="limby/4" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
        <geom name="limb/4" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
        <site name="limb/mid/4" class="growth_site" pos="-0.25 0.0 0.0"/>
        <site name="limb/btm/4" class="growth_site" pos="-0.45 0.0 0.0"/>
        <site name="limb/btm_pos/4" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
        <site name="limb/touch/4" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
        <body name="limb/5" pos="-0.45 0.05 0.0">
          <joint name="limbx/5" type="hinge" class="normal_joint" range="-60 30" pos="0.0 -0.05 0.0" axis="1.0 0.0 0.0"/>
          <geom name="limb/5" type="capsule" fromto="0.0 0.0 0.0 0.0 0.45 0.0" size="0.05" density="600"/>
          <site name="limb/mid/5" class="mirror_growth_site" pos="0.0 0.25 0.0"/>
          <site name="limb/btm/5" class="mirror_growth_site" pos="0.0 0.45 0.0"/>
          <site name="limb/btm_pos/5" class="btm_pos_site" pos="0.0 0.5 0.0"/>
          <site name="limb/touch/5" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 0.45 0.0" type="capsule"/>
          <body name="limb/7" pos="-0.05 0.45 0.0">
            <joint name="limbx/7" type="hinge" class="normal_joint" range="-60 30" pos="0.05 0.0 0.0" axis="0.0 0.0 -1.0"/>
            <joint name="limby/7" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
            <geom name="limb/7" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" size="0.05" density="600"/>
            <site name="limb/mid/7" class="mirror_growth_site" pos="-0.15 0.0 0.0"/>
            <site name="limb/btm/7" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
            <site name="limb/btm_pos/7" class="btm_pos_site" pos="-0.3 0.0 0.0"/>
            <site name="limb/touch/7" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" type="capsule"/>
            <body name="limb/9" pos="-0.3 0.0 0.0">
              <joint name="limby/9" type="hinge" class="normal_joint" range="-30 30" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
              <geom name="limb/9" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
              <site name="limb/mid/9" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
              <site name="limb/btm/9" class="mirror_growth_site" pos="-0.45 0.0 0.0"/>
              <site name="limb/btm_pos/9" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
              <site name="limb/touch/9" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
              <body name="limb/12" pos="-0.49 0.0 -0.04">
                <joint name="limby/12" type="hinge" class="normal_joint" range="-90 0" pos="0.04 0.0 0.04" axis="0.0 1.0 0.0"/>
                <geom name="limb/12" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" size="0.05" density="600"/>
                <site name="limb/mid/12" class="mirror_growth_site" pos="-0.14 0.0 -0.14"/>
                <site name="limb/btm/12" class="mirror_growth_site" pos="-0.25 0.0 -0.25"/>
                <site name="limb/btm_pos/12" class="btm_pos_site" pos="-0.28 0.0 -0.28"/>
                <site name="limb/touch/12" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" type="capsule"/>
              </body>
            </body>
          </body>
        </body>
        <body name="limb/6" pos="-0.45 -0.05 0.0">
          <joint name="limbx/6" type="hinge" class="normal_joint" range="-60 30" pos="0.0 0.05 0.0" axis="1.0 0.0 -0.0"/>
          <geom name="limb/6" type="capsule" fromto="0.0 0.0 0.0 0.0 -0.45 0.0" size="0.05" density="600"/>
          <site name="limb/mid/6" class="mirror_growth_site" pos="0.0 -0.25 0.0"/>
          <site name="limb/btm/6" class="mirror_growth_site" pos="0.0 -0.45 0.0"/>
          <site name="limb/btm_pos/6" class="btm_pos_site" pos="0.0 -0.5 0.0"/>
          <site name="limb/touch/6" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 0.0 -0.45 0.0" type="capsule"/>
          <body name="limb/8" pos="-0.05 -0.45 0.0">
            <joint name="limbx/8" type="hinge" class="normal_joint" range="-60 30" pos="0.05 0.0 0.0" axis="0.0 0.0 -1.0"/>
            <joint name="limby/8" type="hinge" class="normal_joint" range="-30 60" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
            <geom name="limb/8" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" size="0.05" density="600"/>
            <site name="limb/mid/8" class="mirror_growth_site" pos="-0.15 0.0 0.0"/>
            <site name="limb/btm/8" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
            <site name="limb/btm_pos/8" class="btm_pos_site" pos="-0.3 0.0 0.0"/>
            <site name="limb/touch/8" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 0.0" type="capsule"/>
            <body name="limb/10" pos="-0.3 0.0 0.0">
              <joint name="limby/10" type="hinge" class="normal_joint" range="-30 30" pos="0.05 0.0 0.0" axis="0.0 1.0 0.0"/>
              <geom name="limb/10" type="capsule" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" size="0.05" density="600"/>
              <site name="limb/mid/10" class="mirror_growth_site" pos="-0.25 0.0 0.0"/>
              <site name="limb/btm/10" class="mirror_growth_site" pos="-0.45 0.0 0.0"/>
              <site name="limb/btm_pos/10" class="btm_pos_site" pos="-0.5 0.0 0.0"/>
              <site name="limb/touch/10" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.45 0.0 0.0" type="capsule"/>
              <body name="limb/11" pos="-0.49 0.0 -0.04">
                <joint name="limby/11" type="hinge" class="normal_joint" range="-90 0" pos="0.04 0.0 0.04" axis="0.0 1.0 0.0"/>
                <geom name="limb/11" type="capsule" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" size="0.05" density="600"/>
                <site name="limb/mid/11" class="mirror_growth_site" pos="-0.14 0.0 -0.14"/>
                <site name="limb/btm/11" class="mirror_growth_site" pos="-0.25 0.0 -0.25"/>
                <site name="limb/btm_pos/11" class="btm_pos_site" pos="-0.28 0.0 -0.28"/>
                <site name="limb/touch/11" class="touch_site" size="0.060000000000000005" fromto="0.0 0.0 0.0 -0.25 0.0 -0.25" type="capsule"/>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <!-- One motor per hinge joint declared in the worldbody (13 in total); each motor is named after the joint it drives. -->
    <!-- Control inputs are limited to [-1, 1] by the motor default (ctrlrange="-1 1", ctrllimited="true") in the <default> section. -->
    <motor joint="limbx/0" gear="200" name="limbx/0"/>
    <motor joint="limby/0" gear="300" name="limby/0"/>
    <motor joint="limby/4" gear="150" name="limby/4"/>
    <motor joint="limbx/5" gear="300" name="limbx/5"/>
    <motor joint="limbx/7" gear="250" name="limbx/7"/>
    <motor joint="limby/7" gear="250" name="limby/7"/>
    <motor joint="limby/9" gear="150" name="limby/9"/>
    <motor joint="limby/12" gear="150" name="limby/12"/>
    <motor joint="limbx/6" gear="300" name="limbx/6"/>
    <motor joint="limbx/8" gear="250" name="limbx/8"/>
    <motor joint="limby/8" gear="250" name="limby/8"/>
    <motor joint="limby/10" gear="150" name="limby/10"/>
    <motor joint="limby/11" gear="150" name="limby/11"/>
  </actuator>
  <sensor>
    <accelerometer name="torso_accel" site="root"/>
    <gyro name="torso_gyro" site="root"/>
    <velocimeter name="torso_vel" site="root"/>
    <subtreeangmom name="unimal_am" body="torso/0"/>
    <touch name="torso/0" site="torso/touch/0"/>
    <touch name="limb/0" site="limb/touch/0"/>
    <touch name="limb/4" site="limb/touch/4"/>
    <touch name="limb/5" site="limb/touch/5"/>
    <touch name="limb/7" site="limb/touch/7"/>
    <touch name="limb/9" site="limb/touch/9"/>
    <touch name="limb/12" site="limb/touch/12"/>
    <touch name="limb/6" site="limb/touch/6"/>
    <touch name="limb/8" site="limb/touch/8"/>
    <touch name="limb/10" site="limb/touch/10"/>
    <touch name="limb/11" site="limb/touch/11"/>
  </sensor>
  <!-- Add hfield assets -->
  <asset/>
  <!-- List of contacts to exclude -->
  <contact/>
  <!-- Define material, texture etc -->
  <asset>
    <material name="grid" texture="grid" texrepeat="1 1" texuniform="true" reflectance="0"/>
    <material name="hfield" texture="hfield" texrepeat="1 1" texuniform="true" reflectance="0"/>
    <material name="wall" texture="wall" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="platform" texture="platform" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="boundary" texture="boundary" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="jump" texture="jump" texrepeat="1 1" texuniform="true" reflectance=".5"/>
    <material name="goal" rgba="1 0 0 1"/>
    <material name="food" rgba="0 0 1 1" emission="1"/>
    <material name="self" rgba=".7 .5 .3 1"/>
    <material name="self_default" rgba=".7 .5 .3 1"/>
    <material name="self_highlight" rgba="0 .5 .3 1"/>
    <material name="effector" rgba=".7 .4 .2 1"/>
    <material name="effector_default" rgba=".7 .4 .2 1"/>
    <material name="effector_highlight" rgba="0 .5 .3 1"/>
    <material name="decoration" rgba=".3 .5 .7 1"/>
    <material name="eye" rgba="0 .2 1 1"/>
    <material name="target" rgba=".6 .3 .3 1"/>
    <material name="target_default" rgba=".6 .3 .3 1"/>
    <material name="target_highlight" rgba=".6 .3 .3 .4"/>
    <material name="site" rgba=".5 .5 .5 .3"/>
    <material name="ball" texture="ball"/>
  </asset>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1="0.1 0.1 0.1" rgb2="0.1 0.1 0.1" width="300" height="300" mark="edge" markrgb="0.2 0.2 0.2"/>
    <texture name="hfield" type="2d" builtin="checker" rgb1="0.1 0.1 0.1" rgb2="0.1 0.1 0.1" width="300" height="300"/>
    <texture name="wall" type="2d" builtin="flat" rgb1="0.9 0.7 0" rgb2="0.9 0.7 0" width="300" height="300"/>
    <texture name="platform" type="2d" builtin="flat" rgb1="0.3 0 0.8" rgb2="0.3 0 0.8" width="300" height="300"/>
    <texture name="boundary" type="2d" builtin="flat" rgb1="0.3 0.3 0.3" rgb2="0.3 0.3 0.3" width="300" height="300"/>
    <texture name="jump" type="2d" builtin="flat" rgb1="0.3 0.3 0.3" rgb2="0.3 0.3 0.3" width="300" height="300"/>
    <texture name="skybox" type="skybox" builtin="flat" rgb1="0.8 1 1" rgb2="0.8 1 1" width="800" height="800"/>
    <texture name="ball" builtin="checker" mark="cross" width="151" height="151" rgb1="0.1 0.1 0.1" rgb2="0.9 0.9 0.9" markrgb="1 1 1"/>
  </asset>
  <visual>
    <headlight ambient=".4 .4 .4" diffuse=".8 .8 .8" specular="0.1 0.1 0.1"/>
    <map znear=".01"/>
    <quality shadowsize="2048"/>
  </visual>
</mujoco>
@btaba
Copy link
Collaborator

btaba commented Apr 23, 2024

Hi @MasterXiong , I loaded the model and visualized with

mjpython -m mujoco.mjx.viewer --mjcf=tmp.xml

and it looks like the fixture falls forever. I removed the free joint at the root.

Then I ran your code, but it seems to contain some bugs (your env is producing actions with incompatible shapes).

I ran this instead:

# Load the MJCF model and create the matching MJX model and data.
m = mujoco.MjModel.from_xml_path(path)
mx = mjx.put_model(m)
dx = mjx.make_data(mx)

# Wrap the step function once, outside the loop, rather than re-creating the
# jitted wrapper on every iteration.
jit_step = jax.jit(mjx.step)

for _ in range(3000):
  # Draw mx.nu control values uniformly from [-1, 1) for this step.
  dx = dx.replace(ctrl=np.random.uniform(low=-1, high=1, size=(mx.nu,)))
  dx = jit_step(mx, dx)

and the results seem OK

@google google locked and limited conversation to collaborators Apr 23, 2024
@btaba btaba converted this issue into discussion #482 Apr 23, 2024

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants