
Differentiation wrt system parameters #19

Open
bayerj opened this issue Jul 26, 2021 · 6 comments
Labels
enhancement New feature or request


@bayerj

bayerj commented Jul 26, 2021

Hey,

I can see how brax can be used to differentiate with respect to the system state. I wonder if there is a nice way to also differentiate with respect to, e.g., the mass of a body. Taking the basic tutorial as an example, I would like to have something like step(sys, ..., ball_mass=10.) that is a pure JAX function.
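For reference, here is a deliberately tiny version of the kind of pure JAX function described above. It is a hypothetical one-dimensional ball, not brax: the step function takes the mass as an ordinary argument, so jax.grad can differentiate a whole rollout with respect to it. The constant push force is an assumption added so that the mass affects the motion at all (under gravity alone, acceleration does not depend on mass).

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model, not brax: a 1-D ball under gravity plus a constant
# upward push, so that the mass actually changes the motion (a = F/m + g).
GRAVITY = -9.81  # m/s^2
PUSH = 20.0      # N, assumed constant applied force
DT = 0.01        # s


def step(pos, vel, ball_mass):
  """One semi-implicit Euler step; a pure function of its inputs."""
  acc = PUSH / ball_mass + GRAVITY
  vel = vel + DT * acc
  pos = pos + DT * vel
  return pos, vel


def final_height(ball_mass, num_steps=100):
  """Height of the ball after num_steps steps, starting at rest at 0."""
  def body(carry, _):
    return step(*carry, ball_mass), None

  init = (jnp.array(0.0), jnp.array(0.0))
  (pos, _), _ = jax.lax.scan(body, init, None, length=num_steps)
  return pos


# d(final height) / d(mass), evaluated at ball_mass = 10.
grad_wrt_mass = jax.grad(final_height)(10.0)
```

A heavier ball is pushed less by the same force, so this gradient comes out negative.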

I have found brax.physics.bodies.Body, which could be adapted via a .replace() call. However, I don't see a) how to find it starting from a System instance, or b) how to update that instance with a replaced version without potentially breaking things.
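To illustrate point b) in general terms: with immutable dataclasses registered as JAX pytrees, a .replace() call returns a new object instead of mutating the old one, so a nested update can be written as a pure function and differentiated. The Body and System classes below are hypothetical stand-ins, not brax's internal types; they only mimic the .replace() method mentioned above by delegating to dataclasses.replace.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Body:
  """Hypothetical stand-in for an internal body struct."""
  mass: jnp.ndarray
  inertia: jnp.ndarray

  def replace(self, **changes):
    return dataclasses.replace(self, **changes)


@dataclasses.dataclass(frozen=True)
class System:
  """Hypothetical stand-in for a system holding its bodies."""
  body: Body
  dt: float  # static setting, stored as pytree metadata rather than a leaf

  def replace(self, **changes):
    return dataclasses.replace(self, **changes)


# Register both classes as pytrees so jit/grad/tree_map can traverse them.
jax.tree_util.register_pytree_node(
    Body, lambda b: ((b.mass, b.inertia), None), lambda _, c: Body(*c))
jax.tree_util.register_pytree_node(
    System, lambda s: ((s.body,), s.dt), lambda dt, c: System(c[0], dt))


def with_mass(sys, mass):
  """Functional nested update: returns a new System, leaves `sys` unchanged."""
  return sys.replace(body=sys.body.replace(mass=mass))


def kinetic_energy(mass, sys, vel):
  sys = with_mass(sys, mass)
  return 0.5 * jnp.sum(sys.body.mass * vel ** 2)


sys = System(body=Body(mass=jnp.array([1.0, 2.0]), inertia=jnp.ones(2)), dt=0.01)
vel = jnp.array([3.0, 4.0])
dke_dmass = jax.grad(kinetic_energy)(jnp.array([10.0, 10.0]), sys, vel)
```

Because nothing is mutated in place, the original sys still holds its old masses after the call, and the gradient flows through the replaced field only.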

@cdfreeman-google
Collaborator

cdfreeman-google commented Jul 26, 2021

This is currently somewhat irritating to do. Brax ingests all of this data from the config protobuf, and fills fields of internal datastructures during the initialization of the system, starting hereabouts: https://github.com/google/brax/blob/main/brax/physics/system.py#L38

For things to behave properly, you'd have to essentially overwrite that mass value in all of the internal datastructures at the end of system initialization, so, something like this:

```python
# Sketch of a modified System.__init__ from brax/physics/system.py (linked
# above). It relies on that module's existing imports (jnp, config_pb2,
# colliders, joints, actuators, validate_config, vec_to_np).
class System:

  def __init__(self, config: config_pb2.Config, differentiable_mass_scale=1.0):
    self.config = validate_config(config)

    self.num_bodies = len(config.bodies)
    self.body_idx = {b.name: i for i, b in enumerate(config.bodies)}

    # 1.0 where a body's position/rotation is free, 0.0 where it is frozen.
    self.active_pos = 1. * jnp.logical_not(
        jnp.array([vec_to_np(b.frozen.position) for b in config.bodies]))
    self.active_rot = 1. * jnp.logical_not(
        jnp.array([vec_to_np(b.frozen.rotation) for b in config.bodies]))

    self.box_plane = colliders.BoxPlane(config)
    self.capsule_plane = colliders.CapsulePlane(config)
    self.capsule_capsule = colliders.CapsuleCapsule(config)

    self.num_joints = len(config.joints)
    self.joint_revolute = joints.Revolute.from_config(config)
    self.joint_universal = joints.Universal.from_config(config)
    self.joint_spherical = joints.Spherical.from_config(config)

    self.num_actuators = len(config.actuators)
    self.num_joint_dof = sum(len(j.angle_limit) for j in config.joints)

    self.angle_1d = actuators.Angle.from_config(config, self.joint_revolute)
    self.angle_2d = actuators.Angle.from_config(config, self.joint_universal)
    self.angle_3d = actuators.Angle.from_config(config, self.joint_spherical)
    self.torque_1d = actuators.Torque.from_config(config, self.joint_revolute)
    self.torque_2d = actuators.Torque.from_config(config, self.joint_universal)
    self.torque_3d = actuators.Torque.from_config(config, self.joint_spherical)

    # Re-init with the data we want to differentiate by: scale the mass that
    # each structure already holds. The field paths here are illustrative.
    self.box_plane.box = self.box_plane.box.replace(
        mass=self.box_plane.box.mass * differentiable_mass_scale)
    self.box_plane.plane = self.box_plane.plane.replace(
        mass=self.box_plane.plane.mass * differentiable_mass_scale)
    # etc. etc. everywhere mass is used
```

This is, of course, extremely silly and defeats the purpose of having a simple initialization scheme if it means you have to go through and replumb all of the data. I'll noodle a bit on how to simplify this. It shouldn't be too horrible to make this "just work" because all the proto is really doing is assigning fields, which should be easily traceable by Jax, but it's a little bit indirect.
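The "just work" idea can be sketched in a toy setting: if the config is unpacked into JAX arrays inside the function being differentiated, the multiplication by differentiable_mass_scale is simply part of the traced computation and jax.grad flows through it. ToySystem and the dict-based config below are hypothetical stand-ins for brax's System and its protobuf config.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the config protobuf: plain Python data.
CONFIG = {"bodies": [{"name": "ball", "mass": 2.0}, {"name": "box", "mass": 5.0}]}


class ToySystem:
  """Mimics the pattern above: unpack the config into arrays at init time."""

  def __init__(self, config, differentiable_mass_scale=1.0):
    base_mass = jnp.array([b["mass"] for b in config["bodies"]])
    # An ordinary JAX multiply, so it is differentiable whenever __init__
    # runs inside the function being differentiated.
    self.mass = base_mass * differentiable_mass_scale

  def momentum(self, vel):
    return jnp.sum(self.mass * vel)


def loss(scale, vel):
  # Build the system inside the traced function so `scale` is a tracer here.
  sys = ToySystem(CONFIG, differentiable_mass_scale=scale)
  return sys.momentum(vel)


vel = jnp.array([1.0, 2.0])
dloss_dscale = jax.grad(loss)(1.0, vel)
```

If the system were instead built once outside loss with a concrete scale, its masses would be plain constants and the loss would no longer depend on the scale, so the gradient would be lost.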

@bayerj
Author

bayerj commented Jul 28, 2021

Ok, I can try this. I see how this is maybe not ideal, but that's ok for now.

I'm just wondering how I can get something like a list of all the Body instances, since that is where the masses are stored.

@cdfreeman-google
Collaborator

cdfreeman-google commented Jul 28, 2021

So, you can definitely get a list of this data via:

```python
all_masses = [b.mass for b in some_brax_env.sys.config.bodies]
```

But, like I said above, this data is unpacked and repacked into brax-internal datastructures at system initialization in a way that isn't (currently) traceable by Jax. So while this will tell you what the masses were at init, modifying this data won't actually change the masses Brax knows about (unless you jump through the hoops I mentioned for the initialization).
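A minimal illustration of this point, using a hypothetical ToySystem and a plain dict in place of the protobuf config: the masses are copied into a JAX array when the system is built, so editing the config afterwards changes what you read back from the config but not what the system uses, until the system is rebuilt.

```python
import jax.numpy as jnp


class ToySystem:
  """Hypothetical stand-in: copies masses out of the config at init time."""

  def __init__(self, config):
    self.mass = jnp.array([b["mass"] for b in config["bodies"]])


# A plain dict stands in for the protobuf config.
config = {"bodies": [{"name": "ball", "mass": 2.0}, {"name": "box", "mass": 5.0}]}
sys = ToySystem(config)

all_masses = [b["mass"] for b in config["bodies"]]  # masses as configured: [2.0, 5.0]
config["bodies"][0]["mass"] = 10.0                  # edit the config afterwards...
print(sys.mass)                                     # ...the system still uses 2.0 and 5.0

rebuilt = ToySystem(config)                         # rebuilding picks up the edit
print(rebuilt.mass)                                 # first mass is now 10.0
```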

@erikfrey erikfrey added this to To do in P2 - Improve Differentiability via automation Nov 11, 2021
@erikfrey
Collaborator

fyi making this less irritating is on our roadmap of improvements.

@erikfrey erikfrey added the enhancement New feature or request label Nov 11, 2021
@RolandZhu

Hello, I appreciate the detailed explanation provided earlier. Has differentiation with respect to physical parameters been integrated into v2? I aim to identify system parameters using gradient descent, which would require gradients with respect to quantities like mass or friction. Are there resources that could guide me in using such a feature, or is it still in development?

@cdfreeman-google
Collaborator

Yes, this is indeed quite a bit simpler! Please take a look at our new introductory colabs. The system data is now all represented within a pytree (not a proto), so you can differentiate with respect to it out of the box. Some fields are a little harder to track down than others within the pytree, so do let us know if you can't figure out how to do what you want to do!
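For the system-identification use case, here is a sketch of the general pattern under stated assumptions: a hypothetical toy System pytree (not the brax v2 System class) with a single parameter, a synthetic trajectory generated from a known mass, and plain gradient descent applied to the whole pytree via jax.grad and jax.tree_util.tree_map. The sketch optimizes the inverse mass because, for this toy, the simulated positions are linear in it, which keeps plain gradient descent well-behaved; the mass is recovered as its reciprocal.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class System:
  """Hypothetical toy system pytree; not the brax v2 System class."""
  inv_mass: jnp.ndarray  # differentiable leaf: 1 / mass of a sliding block
  dt: float              # static setting, carried as pytree metadata

  def replace(self, **changes):
    return dataclasses.replace(self, **changes)


jax.tree_util.register_pytree_node(
    System, lambda s: ((s.inv_mass,), s.dt), lambda dt, c: System(c[0], dt))

FORCE = 1.0     # N, assumed known constant push
NUM_STEPS = 10


def rollout(sys):
  """Positions of a block pushed from rest by FORCE (semi-implicit Euler)."""
  def body(carry, _):
    pos, vel = carry
    vel = vel + sys.dt * FORCE * sys.inv_mass
    pos = pos + sys.dt * vel
    return (pos, vel), pos

  init = (jnp.zeros(()), jnp.zeros(()))
  _, positions = jax.lax.scan(body, init, None, length=NUM_STEPS)
  return positions


def loss(sys, observed):
  return jnp.mean((rollout(sys) - observed) ** 2)


true_sys = System(inv_mass=jnp.array(0.5), dt=0.1)  # true mass: 2.0
observed = rollout(true_sys)                         # synthetic "measurements"

sys = true_sys.replace(inv_mass=jnp.array(1.0))     # initial guess: mass 1.0
grad_fn = jax.jit(jax.grad(loss))
for _ in range(100):
  grads = grad_fn(sys, observed)  # a System whose leaf holds dloss/d(inv_mass)
  sys = jax.tree_util.tree_map(lambda p, g: p - 5.0 * g, sys, grads)

estimated_mass = 1.0 / sys.inv_mass
```

In a real brax v2 system the same loop would target whichever leaves hold the parameters of interest, such as masses or friction coefficients; which fields those are depends on the brax version.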

Development

No branches or pull requests

4 participants