In [1]:
from viewer import BallViewer
import jax 
import jax.numpy as np
from jax import jacfwd
from jax.numpy.linalg import pinv

import numpy as onp
import time 

In [2]:
viewer = BallViewer()
viewer.open()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7016/static/


<Visualizer using: <meshcat.visualizer.ViewerWindow object at 0x13e89e8c0> at path: <meshcat.path.Path object at 0x10c0329b0>>

$$
\dot{x} = f(x), x = \begin{bmatrix} q \\ \dot{q} \end{bmatrix}
$$

In [3]:
# xdot = f(x), x = [q, qdot]
def rk4_step(f, x, h):
    # one step of runge-kutta integration
    k1 = h * f(x)
    k2 = h * f(x + k1/2)
    k3 = h * f(x + k2/2)
    k4 = h * f(x + k3)
    return x + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4)

def euler_step(f, x, h):
    return x + f(x) * h

In [21]:
_m = 1. 
_r = 1.
_M = np.diag([_m, _m, (2./3.) * _m *  _r**2])
_M_inv = onp.linalg.inv(_M)
_g = 9.81

def f_analytical_ball(x):
    # \dot{x} = f(x)
    q, qdot = np.split(x, 2)
    b = np.array([0, -_g, 0])
    _M_inv = np.linalg.inv(_M)
    return np.concatenate([qdot, _M_inv @ b]) 

TypeError: diag requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [22]:
def phi(q):
    return q[1] - _r

def contact_loc(q):
    x,z,th = q
    return np.array([
        x + _r * np.cos(th),
        z + _r * np.sin(th)
    ])

contact_jac = jacfwd(contact_loc)
# def contact_jac(q):
#     return np.array([
#         [1.0, 0.0, _r],
#         [0.0, 1.0, 0.]
#     ])
q_c = np.array([0, _r, -np.pi/2.0])


In [33]:
q0    = np.array([0.0, 4.0, 0.0])
qdot0 = np.array([0.0, 0.0, -5.0])
x     = np.concatenate([q0, qdot0])
dfdx = jax.jit(jacfwd(f_analytical_ball))

In [34]:
dfdx(x)

DeviceArray([[0., 0., 0., 1., 0., 0.],
             [0., 0., 0., 0., 1., 0.],
             [0., 0., 0., 0., 0., 1.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0.]], dtype=float32)

In [37]:
%timeit dfdx(x).block_until_ready()

2.38 µs ± 5.68 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [9]:
%timeit contact_jac(q_c)

2.7 ms ± 29.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
def reset_map(q, qdot, q_c):
    # qdot+ 
    J = contact_jac(q_c)
    return (np.eye(qdot.shape[0]) - _M_inv@J.T@onp.diag([1.,1.])@pinv(J@_M_inv@J.T)@J)@qdot


In [14]:
q0    = np.array([0.0, 4.0, 0.0])
qdot0 = np.array([0.0, 0.0, -5.0])
x     = np.concatenate([q0, qdot0])

dt = 0.01
for t in range(500):
    x = rk4_step(f_analytical_ball, x, dt)
    # dxdt = f_analytical_ball(x)
    # x  = x + dt*dxdt
    # euler step 
    # x = euler_step(f_analytical_ball, x, dt)
    q, qdot = np.split(x, 2)
    if phi(q) <= 0:
        q_c = np.array([q[0], q[1], -np.pi/2.0])
        qdot = reset_map(q, qdot, q_c)
        x = np.concatenate([q,qdot])
    viewer.render(np.split(x,2)[0])
    time.sleep(dt)