Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions source/isaaclab/isaaclab/envs/mdp/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,11 @@ def __call__(
else:
joint_ids = torch.tensor(self.asset_cfg.joint_ids, dtype=torch.int, device=self.asset.device)

if env_ids != slice(None) and joint_ids != slice(None):
env_ids_for_slice = env_ids[:, None]
else:
env_ids_for_slice = env_ids

# sample joint properties from the given ranges and set into the physics simulation
# joint friction coefficient
if friction_distribution_params is not None:
Expand All @@ -719,7 +724,7 @@ def __call__(
friction_coeff = torch.clamp(friction_coeff, min=0.0)

# Always set static friction (indexed once)
static_friction_coeff = friction_coeff[env_ids[:, None], joint_ids]
static_friction_coeff = friction_coeff[env_ids_for_slice, joint_ids]

# if isaacsim version is lower than 5.0.0 we can set only the static friction coefficient
major_version = int(env.sim.get_version()[0])
Expand Down Expand Up @@ -750,8 +755,8 @@ def __call__(
dynamic_friction_coeff = torch.minimum(dynamic_friction_coeff, friction_coeff)

# Index once at the end
dynamic_friction_coeff = dynamic_friction_coeff[env_ids[:, None], joint_ids]
viscous_friction_coeff = viscous_friction_coeff[env_ids[:, None], joint_ids]
dynamic_friction_coeff = dynamic_friction_coeff[env_ids_for_slice, joint_ids]
viscous_friction_coeff = viscous_friction_coeff[env_ids_for_slice, joint_ids]
else:
# For versions < 5.0.0, we do not set these values
dynamic_friction_coeff = None
Expand All @@ -777,7 +782,7 @@ def __call__(
distribution=distribution,
)
self.asset.write_joint_armature_to_sim(
armature[env_ids[:, None], joint_ids], joint_ids=joint_ids, env_ids=env_ids
armature[env_ids_for_slice, joint_ids], joint_ids=joint_ids, env_ids=env_ids
)

# joint position limits
Expand Down Expand Up @@ -805,7 +810,7 @@ def __call__(
)

# extract the position limits for the concerned joints
joint_pos_limits = joint_pos_limits[env_ids[:, None], joint_ids]
joint_pos_limits = joint_pos_limits[env_ids_for_slice, joint_ids]
if (joint_pos_limits[..., 0] > joint_pos_limits[..., 1]).any():
raise ValueError(
"Randomization term 'randomize_joint_parameters' is setting lower joint limits that are greater"
Expand Down