Are Contact Forces supported? #353
Contact impulses are not populated in the … RE: ETA, what do you need the contact forces for?
I am a developer of …
Hi @Kallinteris-Andreas, we have examples here loading a whole host of MuJoCo envs. I believe contacts are the only missing feature, but they can be added. Are you planning to add …
For now, I am evaluating if … From what I can tell … The inclusion of the …
Is there any update on this? I am a PhD student working on training robot controllers and would like to use Brax. One of my requirements is using the contact forces as inputs to the NN during training, so being able to get the contact forces from the system would let me use Brax.
Ok, at least in PBD, it seems one can use the …
With the release of …
Until we port MJX over, this isn't implemented in Brax. We'll want to add contact forces for the positional backend as well.
In v0.10.0, MJX is now used in the backend. With MjxEnv replaced by PipelineEnv, how can contacts and their corresponding forces be accessed? It seems this can no longer be done with the method mentioned in this comment. I also noticed the Contact class being used in pipeline initialization and stepping, but it seems a bit limiting. Are there any plans to include more data in this class?
Hi @willthibault, notice that the Contact class inherits from mjx.Contact, so you should have all the same data as before; please let us know otherwise. Edit: we narrowed down the diff compared to the older version and will push out a fix. For now, you can set …
Thanks for the follow-up, @btaba! If we have the same data as before, then we can use mj_contactForce, but only after converting the data from MJX to MuJoCo. Is there a better way to do this? I found that doing this inside a Brax training step only produces errors. I used something like this:
where d came from …
@willthibault, by the way, did you make any progress getting this to work in Brax/MJX? More generally, I find that things that work outside Brax raise errors inside the training step, e.g. when receiving certain data types or converting between MuJoCo and MJX.
Hi @AlexS28, thanks for following up on this. As you mentioned, the function I posted above doesn't work inside training steps: it isn't JAX-friendly and raises JAX tracer errors. I don't think MJX currently has an equivalent of this function. MJX also has limited contact functionality (only three-dimensional pyramidal contacts, as mentioned in the feature parity list). It looks like other contact options are coming in an update, but for now this is all we have. I have found a workaround that seems to work well enough for now; I am also working on bipedal walking for humanoids, so maybe this helps you. In an MJX data structure, contact forces are contained in efc_force. This link explains a little more about how it works and its relation to mj_contactForce. Unfortunately, because of the pyramidal contact, these forces are in a more complex format and a different frame. However, if you put your robot in a stable standing position in the simulator and read these forces, you'll notice that they are very sparse when your contacts are limited, such as only the feet and the floor. Summing the right indices of these forces (the non-zero entries) gave me a force roughly equal to the expected weight of my robot, and summing certain groupings matched the normal forces on the feet that I computed with mj_contactForce. Try running the function I shared above inside the loop of viewer.py while viewing efc_force and summing the right and left foot contacts like
I got a reading that looked something like this:
Accessing … Hopefully further MJX/Brax updates will make accessing contacts easier. Even elliptic contacts would make reading efc_force easier, but for pyramidal contacts some form of mj_contactForce would be useful.
Thanks @willthibault for the information, much appreciated! Your code worked for me as well.
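A rough sketch of the efc_force summation described above, for the pyramidal case only: following mju_decodePyramid, a contact's normal force is the sum of its 2*(dim-1) pyramid entries starting at its efc_address (a single entry when dim == 1), so summing those entries over all contacts of one foot gives that foot's normal force. The helper name `group_normal_force` and the per-foot contact index lists are assumptions; the indices have to be picked out for your own model, e.g. by checking which contacts involve the foot geoms.

```python
import numpy as np


def group_normal_force(efc_force, efc_address, dim, contact_ids):
    """Sums the pyramidal normal forces of the selected contacts (hypothetical helper).

    efc_force:   1-D array of constraint forces (d.efc_force).
    efc_address: per-contact start index into efc_force.
    dim:         per-contact condim.
    contact_ids: indices of the contacts to group, e.g. all left-foot contacts.
    """
    total = 0.0
    for i in contact_ids:
        adr, n = int(efc_address[i]), int(dim[i])
        # Frictionless contacts use one row; otherwise 2*(dim-1) pyramid edges.
        width = 1 if n == 1 else 2 * (n - 1)
        total += float(np.sum(efc_force[adr:adr + width]))
    return total
```

For example, `group_normal_force(np.asarray(d.efc_force), np.asarray(d.contact.efc_address), np.asarray(d.contact.dim), left_ids)`, where `left_ids` is a hypothetical list of left-foot contact indices. For a robot standing still, the sum over both feet should come out close to its weight, as reported above.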
If this is useful to anyone else, I implemented a jittable version of the contact force calculation. No guarantees that it is correct, obviously, but the output matches for all inputs I have tested it with.

```python
import jax.numpy as jp
import mujoco


# Given an mjx model `s` and mjx state `d`, calculates forces for all contacts.
# Vectorized port of mju_decodePyramid; returns an (ncon, 6) array whose
# entries past each contact's `dim` are zero.
def get_contact_forces(s, d):
    # Only pyramidal friction cones are handled.
    assert s.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL
    # mju_decodePyramid arguments:
    #   1: force:   result
    #   2: pyramid: d.efc_force + contact.efc_address
    #   3: mu:      contact.friction
    #   4: dim:     contact.dim
    contact = d.contact
    cnt = contact.dim.shape[0]  # number of contact slots
    # Gather, for each contact, the maximum number of potential pyramid
    # entries (10) starting at its efc_address, so that every contact's
    # force pyramid can be processed row-wise.
    efc_argmap = contact.efc_address[:, None] + jp.arange(10)[None, :]
    # OOB access clamps in jax, this is safe
    pyramid = d.efc_force[efc_argmap]  # shape (cnt, 10)
    # Normal forces:
    #   force[0] = 0
    #   for (int i=0; i < 2*(dim-1); i++) force[0] += pyramid[i];
    force_normal_mask = jp.arange(10)[None, :] < (2 * (contact.dim - 1))[:, None]
    force_normal = jp.sum(jp.where(force_normal_mask, pyramid, 0), axis=1)
    # Tangent forces:
    #   for (int i=0; i < dim-1; i++)
    #     force[i+1] = (pyramid[2*i] - pyramid[2*i+1]) * mu[i];
    pyramid_indexes = jp.arange(5) * 2
    force_tan_all = (pyramid[:, pyramid_indexes] - pyramid[:, pyramid_indexes + 1]) * contact.friction
    # Keep tangent component i only when i < dim-1.
    force_tan_mask = jp.arange(5)[None, :] < (contact.dim - 1)[:, None]
    force_tan = jp.where(force_tan_mask, force_tan_all, 0)
    # Full force array: [normal, tangent_0 .. tangent_4]
    forces = jp.concatenate((force_normal[:, None], force_tan), axis=1)
    # Special case: frictionless contacts
    #   if (dim == 1) { force[0] = pyramid[0]; return; }
    frictionless_mask = contact.dim == 1
    frictionless_forces = jp.concatenate((pyramid[:, 0:1], jp.zeros((cnt, 5))), axis=1)
    return jp.where(frictionless_mask[:, None], frictionless_forces, forces)
```
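To repeat the author's correctness check, one option is a plain NumPy, one-contact-at-a-time port of mju_decodePyramid that can be compared row by row with the batched output. The name `decode_pyramid_ref` is hypothetical; the loop simply mirrors the C snippets quoted in the comments of the function above.

```python
import numpy as np


def decode_pyramid_ref(pyramid, mu, dim):
    """Single-contact NumPy port of mju_decodePyramid (hypothetical helper).

    pyramid: 1-D array with at least 2*(dim-1) entries (1 entry if dim == 1).
    mu:      friction coefficients, at least dim-1 entries.
    Returns a length-6 array, zero-padded past `dim`.
    """
    force = np.zeros(6)
    if dim == 1:  # frictionless: the single pyramid row is the normal force
        force[0] = pyramid[0]
        return force
    # Normal force: sum of all 2*(dim-1) pyramid edges.
    force[0] = np.sum(pyramid[: 2 * (dim - 1)])
    # Tangent forces: difference of opposing edges, scaled by friction.
    for i in range(dim - 1):
        force[i + 1] = (pyramid[2 * i] - pyramid[2 * i + 1]) * mu[i]
    return force
```

For contact `i`, with `efc = np.asarray(d.efc_force)` and `adr = int(d.contact.efc_address[i])`, the result of `decode_pyramid_ref(efc[adr:adr + 10], np.asarray(d.contact.friction)[i], int(d.contact.dim[i]))` can be compared with row `i` of `get_contact_forces(s, d)` (up to float32 precision), for condim 1, 3, 4, and 6.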
Nice @hansihe, also see google-deepmind/mujoco@c6b1293.
Ah, I could have used that instead, then. I probably should have looked around more beforehand :)
For those interested in picking out forces associated with specific contacts (such as the feet collision contacts mentioned earlier), something like this works well:
where the forces provided are the output of @hansihe's function. For use during training, the ids could instead be set during the initialization of the environment.
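One way to do that selection, sketched under assumptions: `forces` is the (ncon, 6) output of @hansihe's function (column 0 is the normal force), and each contact's geom pair is available as an (ncon, 2) integer array. The helper name `foot_normal_force` and any geom names are hypothetical; the ids can be looked up once at environment initialization, e.g. with `mujoco.mj_name2id(m, mujoco.mjtObj.mjOBJ_GEOM, "left_foot")`, where `"left_foot"` stands in for your model's geom name. The function only uses calls that `jax.numpy` also provides (`isin`, `where`, `sum`), so swapping the import should make it usable inside a jitted step.

```python
import numpy as np


def foot_normal_force(forces, contact_geom, foot_geom_ids, floor_geom_id):
    """Total normal force from contacts between the given foot geoms and the floor.

    forces:        (ncon, 6) decoded contact forces; column 0 is the normal force.
    contact_geom:  (ncon, 2) geom ids of each contact pair (either order).
    foot_geom_ids: geom ids belonging to one foot.
    floor_geom_id: geom id of the floor.
    """
    foot_ids = np.asarray(foot_geom_ids)
    g1, g2 = contact_geom[:, 0], contact_geom[:, 1]
    # A contact counts if one side is a foot geom and the other is the floor.
    hits = (np.isin(g1, foot_ids) & (g2 == floor_geom_id)) | (
        np.isin(g2, foot_ids) & (g1 == floor_geom_id)
    )
    return np.sum(np.where(hits, forces[:, 0], 0.0))
```

Masking with `where` rather than boolean indexing keeps the output shape fixed, which is what JAX needs under `jit`.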
The reference Ant and the two Humanoid implementations do not include them, but issue #254 indicates that they might be implemented. Which is it?
If they are not implemented, is there an approximate ETA?
Thanks!