# Session 1: Introduction to Brax

# Brax: a differentiable physics engine

[Brax](https://github.com/google/brax) simulates physical systems made up of rigid bodies, joints, and actuators. Brax provides the function:

$$
\text{qp}_{t+1} = \text{step}(\text{system}, \text{qp}_t, \text{act})
$$

where:
* $\text{system}$ is the static description of the physical system: each body in the world, its mass and size, and so on
* $\text{qp}_t$ is the dynamic state of the system at time $t$: each body's position, rotation, velocity, and angular velocity
* $\text{act}$ is dynamic input to the system in the form of motor actuation
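To make the three roles concrete, here is a minimal, dependency-free sketch of a function with the same shape as $\text{step}$. `ToySystem`, `ToyQP`, and this `step` are hypothetical stand-ins, not Brax's classes: the "system" is a single point mass moving along z, the state is its height and vertical velocity, and `act` is a vertical force, integrated with semi-implicit Euler. Brax's real step handles full 3D rigid bodies, joints, and contacts.

```python
from dataclasses import dataclass, replace


# Hypothetical stand-in for the static system description.
@dataclass(frozen=True)
class ToySystem:
    mass: float = 1.0       # kg
    gravity: float = -9.8   # m/s^2 along z
    dt: float = 0.05        # seconds advanced per step


# Hypothetical stand-in for the dynamic state qp_t (1D, z only).
@dataclass(frozen=True)
class ToyQP:
    pos: float  # height z in meters
    vel: float  # vertical velocity in m/s


def step(system: ToySystem, qp: ToyQP, act: float) -> ToyQP:
    """Return qp_{t+1}; `act` is an actuation force in newtons along z."""
    accel = system.gravity + act / system.mass
    vel = qp.vel + accel * system.dt      # semi-implicit Euler: velocity first
    pos = qp.pos + vel * system.dt        # then position, using the new velocity
    return replace(qp, pos=pos, vel=vel)  # a new state; the input is untouched


system = ToySystem()
qp = ToyQP(pos=3.0, vel=0.0)   # released from rest, 3 m up
for _ in range(3):
    qp = step(system, qp, act=0.0)
```

Note that `step` returns a new state instead of mutating the old one, which fits the functional style JAX expects. Passing `act=9.8` to this toy model exactly cancels gravity for a 1 kg mass, so the state stays put.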

Brax simulations are differentiable: the gradient $\nabla \text{step}$ can be used for efficient trajectory optimization. Brax is also well suited to derivative-free optimization methods such as evolution strategies or reinforcement learning.
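Brax obtains these gradients through JAX's automatic differentiation. To show the idea without any dependencies, the sketch below differentiates a toy point-mass step using forward-mode dual numbers. The `Dual` class, `step`, and `final_height` are illustrative assumptions, not part of Brax or JAX. For this integrator the derivative of the final height with respect to the initial velocity is exactly $N \cdot dt$, which gives us something to check.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A forward-mode autodiff number: val + eps * dot, with eps**2 == 0."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (a + eps a')(b + eps b') = ab + eps (a'b + ab')
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__


def step(pos, vel, dt=0.05, gravity=-9.8):
    """Toy point-mass step (semi-implicit Euler); works on floats or Duals."""
    vel = vel + gravity * dt
    pos = pos + vel * dt
    return pos, vel


def final_height(v0, n_steps=10):
    """Height after n_steps, starting 3 m up with vertical velocity v0."""
    pos, vel = Dual(3.0), v0
    for _ in range(n_steps):
        pos, vel = step(pos, vel)
    return pos


# Seed d(v0)/d(v0) = 1, then read d(final height)/d(v0) from .dot
result = final_height(Dual(1.0, 1.0))
print(result.val, result.dot)
```

In JAX, the same number would come from applying `jax.grad` to a function of the initial velocity; the chain rule is pushed through every step of the simulation in the same way.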

Let's review how $\text{system}$, $\text{qp}_t$, and $\text{act}$ are used:

```python
#@title Colab setup and imports

!pip install git+https://github.com/manfreddiaz/brax.git@enhanced-viewer

import jax
import jax.numpy as jnp
from IPython.display import HTML, clear_output

import brax
from brax.io import html  # Brax's HTML viewer
```

## Brax Config

Here's a brax config that defines a bouncy ball:

```python
#@markdown  # A Bouncy Ball
ball_radius = 0.11 #@param { type:"slider", min: 0.01, max: 0.95, step:0.05 }
ball_length = 1 #@param { type:"slider", min: 1.0, max: 3.0, step:0.25 }

scene = brax.Config(
  dt=0.05,
  substeps=100
)
scene.gravity.z = -9.8  # gravity is -9.8 m/s^2 in z dimension

# ground is a frozen (immovable) infinite plane
ground = scene.bodies.add(
    name='ground'
)
ground.frozen.all = True
# so far the body has no shape, so we add a collider;
# Brax offers several collider types, and here we use the one
# made for infinite planes, called "plane" :)
plane = ground.colliders.add().plane
plane.SetInParent()  # for setting an empty oneof

# the ball weighs 1 kg and has equal rotational inertia along all axes;
# its initial position and rotation are set later, in the dynamic state (qp)
ball = scene.bodies.add(
    name='ball',
    mass=1.0
)
ball.inertia.x, ball.inertia.y, ball.inertia.z = 1.0, 1.0, 1.0

# the body has no (visible) shape yet, so we add a collider;
# here we use a capsule: a cylinder capped with two hemispheres
ball_capsule = ball.colliders.add().capsule
ball_capsule.radius = ball_radius
ball_capsule.length = ball_length
```
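The `dt` and `substeps` fields control timing. As we understand this version of Brax, each call to `step` advances the simulation by `dt` seconds and internally splits that interval into `substeps` smaller integration steps of `dt / substeps` seconds. Under that assumption, a quick calculation shows what this config implies; `timing` is a hypothetical helper, not a Brax function.

```python
def timing(dt: float, substeps: int, n_steps: int):
    """Return (inner integration step in s, total simulated time in s)."""
    inner_dt = dt / substeps      # size of each internal integration step
    sim_time = dt * n_steps       # simulated seconds after n_steps calls to step
    return inner_dt, sim_time


# the bouncy-ball config, rolled out for 100 steps
inner_dt, sim_time = timing(dt=0.05, substeps=100, n_steps=100)
print(f"inner step: {inner_dt * 1000:.2f} ms, simulated time: {sim_time:.1f} s")
```

More substeps means a finer internal step (better stability for stiff contacts and joints) at the cost of more computation per call.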


## Brax State

$\text{QP}$, Brax's dynamic state, is a structure with the following fields:


```python
qp = brax.QP(
    # position of each body in 3d (z is up, right-hand coordinates)
    pos = jnp.array([[0., 0., 0.],       # ground
                     [0., 0., 3.]]),     # ball is 3m up in the air
    # velocity of each body in 3d
    vel = jnp.array([[0., 0., 0.],       # ground
                     [0., 0., 0.]]),     # ball
    # rotation about center of body, as a quaternion (w, x, y, z)
    rot = jnp.array([[1., 0., 0., 0.],   # ground
                     [1., 0., 0., 0.]]), # ball
    # angular velocity about center of body in 3d
    ang = jnp.array([[0., 0., 0.],       # ground
                     [0., 0., 0.]])      # ball
)
```
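Each row of `rot` is a unit quaternion in (w, x, y, z) order, and `[1, 0, 0, 0]` is the identity (no rotation). A rotation by angle $\theta$ about a unit axis $u$ is $(\cos(\theta/2), \sin(\theta/2)\,u)$. The helper below is a hypothetical convenience for building such rows, not a Brax function.

```python
import math


def quat_from_axis_angle(axis, angle_deg):
    """Unit quaternion (w, x, y, z) for a rotation of angle_deg about axis."""
    norm = math.sqrt(sum(a * a for a in axis))
    half = math.radians(angle_deg) / 2.0
    s = math.sin(half) / norm      # scale so the vector part has length sin(half)
    return (math.cos(half), axis[0] * s, axis[1] * s, axis[2] * s)


identity = quat_from_axis_angle((0.0, 0.0, 1.0), 0.0)   # no rotation
tilted = quat_from_axis_angle((0.0, 1.0, 0.0), 90.0)    # 90 degrees about y
```

Whatever the axis, the result has unit norm, which is what a valid `rot` row requires.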

## Brax HTML Viewer

For visualization, we can use Brax's built-in HTML renderer, which is based on the three.js library. It gives fast visual feedback on the scene we are building.

```python
sys = brax.System(scene)
HTML(html.render(sys, [qp]))
```

## Brax Step Function

Let's observe $\text{step}(\text{system}, \text{qp}_t)$ with a few different variants of the config and $\text{qp}$:

```python
#@title Simulating the bouncy ball config { run: "auto"}
# let's add some elasticity to the scene
scene.elasticity = 0.85 #@param { type:"slider", min: 0, max: 0.95, step:0.05 }
# the System must be rebuilt from the config for the change to take effect
sys = brax.System(scene)

ball_velocity = 1 #@param { type:"slider", min:-5, max:5, step: 0.5 }

# give the ball (body index 1) an initial velocity along x,
# using JAX's functional array update: set vel[1, 0]
vel = qp.vel.at[1, 0].set(ball_velocity)
# roll out from a copy, so re-running the cell restarts from the initial qp
qp_t = qp.replace(vel=vel)

states = []
for i in range(100):
  qp_t, _ = sys.step(qp_t, [])
  states.append(qp_t)

HTML(html.render(sys, states))
```
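One way to read `elasticity` is as a coefficient of restitution $e$: at impact the ball's normal velocity flips sign and shrinks by a factor $e$. Since the height reached scales with the square of the launch speed ($v = \sqrt{2 g h}$), each bounce apex is then $e^2$ times the previous one. This is an idealization (no drag, point contact), and Brax's contact model may not match it exactly; `apex_heights` is a hypothetical helper for getting a rough feel for the slider.

```python
def apex_heights(h0: float, elasticity: float, n_bounces: int):
    """Apex heights of an idealized bouncing ball (no air drag).

    Assumes each impact reverses the vertical velocity and scales it by
    `elasticity`, so each apex is elasticity**2 times the previous one.
    """
    heights = [h0]
    for _ in range(n_bounces):
        heights.append(heights[-1] * elasticity ** 2)
    return heights


# dropped from 3 m with the slider's default elasticity
heights = apex_heights(h0=3.0, elasticity=0.85, n_bounces=3)
print([round(h, 3) for h in heights])
```

Under this model, an elasticity of 0.85 keeps roughly 72% of the height on every bounce; the horizontal `ball_velocity` does not affect the apex heights.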

# Joints

Joints constrain the motion of bodies so that they move in tandem:

```python
#@title A pendulum config for Brax
scene = brax.Config(
  dt=0.01,
  substeps=100
)
scene.gravity.z = -9.8  # gravity is -9.8 m/s^2 in z dimension

# note: these two sliders are not wired into the bodies below,
# which use fixed capsule sizes
ball_radius = 0.2 #@param { type:"slider", min: 0.01, max: 0.95, step:0.05 }
ball_length = 0.4 #@param { type:"slider", min: 0.01, max: 3.0, step:0.25 }

# start with a frozen anchor at the root of the pendulum
anchor = scene.bodies.add(
    name='anchor',
    mass=1.0
)
anchor.frozen.all = True
anchor.inertia.x, anchor.inertia.y, anchor.inertia.z = 1, 1, 1
# anchor_cap = anchor.colliders.add().capsule
# anchor_cap.radius = 0.1
# anchor_cap.length = 0.2

# now add a middle and bottom ball to the pendulum
middle = scene.bodies.add(
    name='middle',
    mass=1.0
)
middle.inertia.x, middle.inertia.y, middle.inertia.z = 1.0, 1.0, 1.0
middle_cap = middle.colliders.add().capsule
middle_cap.radius = 0.1
middle_cap.length = 1.0

bottom = scene.bodies.add(
    name='bottom',
    mass=1.0
)
bottom.inertia.x, bottom.inertia.y, bottom.inertia.z = 1.0, 1.0, 1.0
bottom_cap = bottom.colliders.add().capsule
bottom_cap.radius = 0.1
bottom_cap.length = 1.0

# connect anchor to middle with a joint that can swing a full turn
joint_anchor_middle = scene.joints.add(
    name='joint_anchor_middle',
    parent='anchor',
    child='middle',
    stiffness=10000,
    angular_damping=20
)
joint_anchor_middle.angle_limit.add(
    min = -180,
    max = 180
)
joint_anchor_middle.child_offset.z = -0.5  # offset of the joint from the child's center
joint_anchor_middle.rotation.z = 90        # joint frame rotation, in degrees

# connect middle to bottom with a joint limited to +/-10 degrees
joint_middle_bottom = scene.joints.add(
    name='joint_middle_bottom',
    parent='middle',
    child='bottom',
    stiffness=10000,
    angular_damping=20
)
joint_middle_bottom.angle_limit.add(
    min = -10,
    max = 10
)
joint_middle_bottom.child_offset.z = -0.8
joint_middle_bottom.rotation.y = 90

# ignore collisions between bodies
scene.collide_include.add()

# add a ground plane: a frozen (immovable) infinite plane
ground = scene.bodies.add(
    name='ground'
)
ground.frozen.all = True
plane = ground.colliders.add().plane
plane.SetInParent()  # for setting an empty oneof
```
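As we understand Brax's config, a joint with a single `angle_limit` entry, like the two above, behaves as a hinge: it removes all relative motion between parent and child except rotation about one axis. A chain's configuration then reduces to a handful of joint angles. The sketch below is a generic planar forward-kinematics illustration, not Brax code: `chain_positions` is hypothetical, it ignores the `child_offset` and `rotation` settings used above, and it treats each link as a rigid segment of given length.

```python
import math


def chain_positions(link_lengths, joint_angles_deg):
    """Planar forward kinematics for a chain hanging from an anchor at (0, 0).

    Each hinge angle is measured relative to the previous link; 0 means the
    link hangs straight down (-z). Returns the (x, z) end point of each link.
    """
    x, z, heading = 0.0, 0.0, 0.0
    points = []
    for length, angle in zip(link_lengths, joint_angles_deg):
        heading += math.radians(angle)      # relative angles accumulate
        x += length * math.sin(heading)
        z -= length * math.cos(heading)
        points.append((x, z))
    return points


at_rest = chain_positions([1.0, 1.0], [0.0, 0.0])   # hanging straight down
swung = chain_positions([1.0, 1.0], [90.0, 0.0])    # upper joint at 90 degrees
```

Because the second angle is relative to the first link, the lower body follows the upper one "in tandem", and an `angle_limit` of [-10, 10] on the lower joint would keep the second angle small.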

Here is our system at rest:

```python
# rather than building our own qp like last time, we ask brax.System to
# generate a default one for us, which is handy
sys = brax.System(scene)
qp = sys.default_qp()

HTML(html.render(sys, [qp]))
```


Let's observe $\text{step}(\text{system}, \text{qp}_t)$ by smacking the bottom ball with an initial impulse, simulating a pendulum swing.

```python
#@title Simulating the pendulum config { run: "auto"}
ball_impulse = 15 #@param { type:"slider", min:-15, max:15, step: 0.5 }

# start from the resting state each time the cell runs
qp = sys.default_qp()

# give the bottom ball (body index 2: anchor=0, middle=1, bottom=2, ground=3)
# an initial velocity along x
vel = qp.vel.at[2, 0].set(ball_impulse)
qp = qp.replace(vel=vel)

states = []
for i in range(50):
  qp, _ = sys.step(qp, [])
  states.append(qp)

HTML(html.render(sys, states))
```
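For a sense of the time scales in the viewer, a single rigid pendulum is a handy reference: for small swings its period is about $2\pi\sqrt{L/g}$, roughly 2 s for $L = 1$ m. The double pendulum above is more complex (and can be chaotic), so this is only a rough yardstick. The sketch integrates a hypothetical simple pendulum with a semi-implicit Euler update and checks the period numerically; `simulate_pendulum` and `estimate_period` are illustrative helpers, not Brax functions.

```python
import math


def simulate_pendulum(theta0_deg, length=1.0, g=9.8, dt=0.001, t_end=5.0):
    """Simple pendulum, semi-implicit Euler: theta'' = -(g / L) sin(theta)."""
    theta, omega = math.radians(theta0_deg), 0.0
    times, thetas = [0.0], [theta]
    for k in range(1, int(round(t_end / dt)) + 1):
        omega -= (g / length) * math.sin(theta) * dt   # update velocity first
        theta += omega * dt                            # then angle
        times.append(k * dt)
        thetas.append(theta)
    return times, thetas


def estimate_period(times, thetas):
    """Period from successive downward zero crossings of theta."""
    crossings = [t for t, a, b in zip(times[1:], thetas, thetas[1:]) if a > 0 >= b]
    gaps = [b - a for a, b in zip(crossings, crossings[1:])]
    return sum(gaps) / len(gaps)


times, thetas = simulate_pendulum(theta0_deg=5.0)
period = estimate_period(times, thetas)
small_angle = 2 * math.pi * math.sqrt(1.0 / 9.8)   # T = 2*pi*sqrt(L/g)
```

For comparison, the 50-step rollout above covers 0.5 s of simulated time (assuming each step advances `dt = 0.01` s), about a quarter of this reference period.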

# Actuators

Actuators provide dynamic input to the system during every physics step. They expose control parameters that let users manipulate the system interactively through the $\text{act}$ parameter.

```python
#@title A single actuator on the pendulum
actuated_pendulum = brax.Config()
actuated_pendulum.CopyFrom(scene)

# actuating the joint connecting the anchor and middle
angle = actuated_pendulum.actuators.add(
    name='actuator',
    joint='joint_anchor_middle',
    strength=100
).angle
angle.SetInParent()  # for setting an empty oneof
```

Let's observe $\text{step}(\text{system}, \text{qp}_t, \text{act})$ by raising the middle ball to a desired target angle:

```python
#@title Simulating the actuated pendulum config { run: "auto"}
target_angle = -31 #@param { type:"slider", min:-90, max:90, step: 1 }

sys = brax.System(actuated_pendulum)
qp = sys.default_qp()
act = jnp.array([target_angle])

states = []
for i in range(100):
  qp, _ = sys.step(qp, act)
  states.append(qp)

HTML(html.render(sys, states))
```
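Conceptually, an angle actuator acts like a servo: it applies torque to its joint to pull the joint angle toward the commanded target. We have not reproduced Brax's exact torque law here, so the sketch below uses a textbook proportional-derivative (PD) controller on a single hinge without gravity as a stand-in; `drive_joint`, its gains, and the unit inertia are illustrative assumptions. With `kp = 100` (echoing `strength=100`, though the units need not match) and `kd = 20`, the loop is critically damped, so it approaches the target without oscillating and settles in well under a second.

```python
import math


def drive_joint(target_deg, kp=100.0, kd=20.0, inertia=1.0, dt=0.01, n_steps=300):
    """Drive a single hinge toward target_deg with PD control (no gravity).

    torque = kp * (target - angle) - kd * angular_velocity
    Returns the joint angle in degrees after each step.
    """
    target = math.radians(target_deg)
    angle, omega = 0.0, 0.0
    history = []
    for _ in range(n_steps):
        torque = kp * (target - angle) - kd * omega
        omega += (torque / inertia) * dt    # semi-implicit Euler
        angle += omega * dt
        history.append(math.degrees(angle))
    return history


history = drive_joint(target_deg=-31.0)   # same target as the slider default
```

In the real pendulum, gravity also acts on the middle and bottom balls, so how closely the joint reaches the target depends on the actuator strength relative to that load.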