Differentiation wrt system parameters #19
This is currently somewhat irritating to do. Brax ingests all of this data from the config protobuf, and fills fields of internal datastructures during the initialization of the system, starting hereabouts: https://github.com/google/brax/blob/main/brax/physics/system.py#L38 For things to behave properly, you'd have to essentially overwrite that mass value in all of the internal datastructures at the end of system initialization, so, something like this:
This is, of course, extremely silly and defeats the purpose of having a simple initialization scheme if it means you have to go through and replumb all of the data. I'll noodle a bit on how to simplify this. It shouldn't be too horrible to make this "just work" because all the proto is really doing is assigning fields, which should be easily traceable by Jax, but it's a little bit indirect.
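To illustrate why init-time field assignment matters for gradients, here is a toy sketch in plain JAX. It is not Brax code: Body, build_body and final_velocity are hypothetical names. The assumption is that initialization stores a derived field (here an inverse mass) next to the raw mass. Rebuilding every field from the traced mass gives the expected gradient. Overwriting only the mass on an already-built structure leaves the derived copy stale, so the gradient is silently zero.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Body(NamedTuple):
    """Hypothetical internal body record; a NamedTuple is a JAX pytree."""
    mass: jnp.ndarray
    inv_mass: jnp.ndarray  # derived at init time from `mass`


def build_body(mass):
    # Stand-in for system initialization: every derived field is computed
    # from the raw parameter with jnp ops, so the whole build is traceable.
    mass = jnp.asarray(mass, dtype=jnp.float32)
    return Body(mass=mass, inv_mass=1.0 / mass)


def final_velocity(body, force=2.0, dt=0.1, steps=10):
    # Explicit Euler under a constant force. Only `inv_mass` is read here,
    # mirroring internals that use a derived field instead of `mass`.
    v = 0.0
    for _ in range(steps):
        v = v + force * body.inv_mass * dt
    return v


def loss_rebuilt(mass):
    # Correct: rebuild all derived fields from the traced parameter.
    return final_velocity(build_body(mass))


def loss_stale(mass):
    # Pitfall: overwrite only `mass` on a body built earlier; `inv_mass`
    # still holds the old constant, so the output ignores `mass` entirely.
    body = build_body(1.0)._replace(mass=mass)
    return final_velocity(body)


print(jax.grad(loss_rebuilt)(2.0))  # approx. -0.5 (analytic: -force * dt * steps / mass**2)
print(jax.grad(loss_stale)(2.0))    # 0.0
```

The same reasoning applies at any scale: every structure that caches something computed from a parameter has to be rebuilt inside the function being differentiated.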
OK, I can try this. I see that it's not ideal, but that's fine for now. I'm just wondering how I could get something like a list of all the…
So, you can definitely get a list of this data via:
But, as I said above, this data is unpacked and repackaged into Brax-internal datastructures at system initialization in a way that isn't (currently) traceable by Jax. So, while this will tell you what the masses were at init, modifying this data won't actually change the masses Brax knows about (unless you jump through the hoops I mentioned above for the end of initialization).
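For contrast, if the internal datastructures were JAX pytrees (an assumption: as described above, the version discussed here repackages the config data in a way that isn't traceable), listing every parameter would be a generic tree traversal. The System and Body classes below are hypothetical stand-ins, not Brax types.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Body(NamedTuple):
    # Hypothetical per-body parameters, one array entry per body.
    mass: jnp.ndarray
    friction: jnp.ndarray


class System(NamedTuple):
    # Hypothetical container; NamedTuples are treated as pytree nodes by JAX.
    bodies: Body


system = System(bodies=Body(mass=jnp.array([1.0, 10.0]),
                            friction=jnp.array([0.6, 0.6])))

# Enumerate every leaf together with its path inside the pytree.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(system)
paths = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]
for path, (_, leaf) in zip(paths, leaves_with_paths):
    print(path, leaf)

# Updating a parameter produces a new system; the original is unchanged.
heavier = system._replace(bodies=system.bodies._replace(mass=system.bodies.mass * 2))
```

Because the whole object is a pytree, the same structure could also be passed directly to jax.grad, with no separate list of parameters to keep in sync.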
FYI, making this less irritating is on our roadmap of improvements.
Hello, thanks for the detailed explanation above. Has differentiation with respect to physical parameters been integrated into v2? I'd like to identify system parameters with gradient descent, which requires gradients with respect to quantities such as mass or friction. Are there resources that show how to use this, or is the feature still in development?
Yes, this is indeed quite a bit simpler! Please take a look at our new introductory colabs. The…
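As a generic illustration of the system-identification workflow asked about above (not code from the colabs, and not Brax's API), here is a minimal sketch in plain JAX. It assumes a toy 1-D block pushed by a constant force against viscous damping, which stands in for friction. simulate, loss and all constants are hypothetical. Synthetic observations are generated with a true mass of 2.0, and gradient descent on the squared velocity error recovers it from an initial guess of 1.0.

```python
import jax
import jax.numpy as jnp

F, DT, STEPS = 1.0, 0.1, 50  # constant push force, time step, horizon (toy values)


def simulate(mass, damping=0.5):
    """Velocity trajectory of a block pushed by F against viscous damping.

    `mass` and `damping` are ordinary function arguments, so jax.grad can
    differentiate the whole rollout with respect to either of them.
    """
    def step(v, _):
        v = v + DT * (F - damping * v) / mass  # explicit Euler update
        return v, v

    _, vs = jax.lax.scan(step, jnp.float32(0.0), None, length=STEPS)
    return vs


observed = simulate(2.0)  # synthetic "measurements" produced with the true mass


def loss(mass):
    # Mean squared error between simulated and observed velocities.
    return jnp.mean((simulate(mass) - observed) ** 2)


grad_fn = jax.jit(jax.grad(loss))
mass, lr = 1.0, 2.0  # initial guess and step size, tuned for this toy problem
for _ in range(200):
    mass = mass - lr * grad_fn(mass)

print(float(mass))  # approaches the true value 2.0
```

Friction would be fitted the same way: make `damping` part of the differentiated arguments (for example by passing a dict of parameters to jax.grad) and update it alongside the mass. With several parameters of different scales, an adaptive optimizer is usually easier to tune than a single step size.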
Hey,
I can see how Brax can be used to differentiate with respect to the system state. I wonder if there is a nice way to also differentiate with respect to, e.g., the mass of a body. Taking the basic tutorial as an example, I would like to have something like
step(sys, ..., ball_mass=10.)
which is a pure JAX function. I have found brax.physics.bodies.Body, which could be adapted via a .replace() call. However, I don't see a) how to find it starting from a System instance, and b) how to update that instance with a replaced version without potentially breaking things.
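A minimal sketch of the requested interface in plain JAX, assuming the system is an immutable pytree. System, Body, State and step are hypothetical stand-ins, not Brax's classes. NamedTuple._replace plays the role of .replace(): it returns a new object, so the caller's system is never mutated and the step stays a pure function that jax.grad and jax.jit can trace.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Body(NamedTuple):
    # Hypothetical stand-in for an internal body record.
    mass: jnp.ndarray


class System(NamedTuple):
    body: Body
    gravity: float = -9.81


class State(NamedTuple):
    pos: jnp.ndarray
    vel: jnp.ndarray


def step(sys: System, state: State, force, dt=0.01, ball_mass=None) -> State:
    """Pure semi-implicit Euler step; `ball_mass` optionally overrides the body's mass."""
    if ball_mass is not None:
        # Build a new System with the replaced Body; the caller's `sys` is untouched.
        sys = sys._replace(body=sys.body._replace(mass=jnp.asarray(ball_mass)))
    acc = force / sys.body.mass + sys.gravity
    vel = state.vel + dt * acc
    return State(pos=state.pos + dt * vel, vel=vel)


system = System(body=Body(mass=jnp.array(1.0)))
state0 = State(pos=jnp.array(0.0), vel=jnp.array(0.0))


def height_after(mass, steps=100):
    # Roll out a constant upward push and return the final height.
    state = state0
    for _ in range(steps):
        state = step(system, state, force=20.0, ball_mass=mass)
    return state.pos


dheight_dmass = jax.jit(jax.grad(height_after))(10.0)
print(dheight_dmass)  # analytic: -0.505 * 20 / 10**2 = -0.101
```

Threading the parameter through as an argument, rather than mutating a shared system object, is what keeps the step pure: the override exists only inside the traced call.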