In [7]:
from viewer import BallViewer
import jax 
import jax.numpy as np
from jax import jacfwd
from jax.numpy.linalg import pinv
import numpy as onp
import time 

In [8]:
# Create the ball visualizer and open it. The cell output below shows the URL
# where the (meshcat) scene can be viewed; later cells draw into it via
# `viewer.render(q)`.
viewer = BallViewer()
viewer.open()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


<Visualizer using: <meshcat.visualizer.ViewerWindow object at 0x7f2956d3da50> at path: <meshcat.path.Path object at 0x7f2822d01d80>>

In [12]:
def rk4_step(f, x, h):
    """Advance the autonomous ODE xdot = f(x) by one classical Runge-Kutta (RK4) step.

    Args:
        f: vector field, called as ``f(state)``; must return a value that can be
           scaled by ``h`` and added to ``x``.
        x: current state.
        h: step size.

    Returns:
        The state after one step of size ``h``.
    """
    # Stage increments (already scaled by the step size).
    d1 = h * f(x)
    d2 = h * f(x + d1 / 2)   # midpoint probe using the initial slope
    d3 = h * f(x + d2 / 2)   # midpoint probe using the refined slope
    d4 = h * f(x + d3)       # endpoint probe
    weighted_sum = d1 + 2 * d2 + 2 * d3 + d4
    return x + (1 / 6) * weighted_sum

In [23]:

# Ball parameters. Generalized coordinates are q = (x, z, theta).
_m = 1.
_r = 1.
# Generalized mass matrix: mass m for the x and z translations, and
# (2/3) m r^2 (the moment of inertia of a thin spherical shell) for theta.
_M = onp.diag([_m, _m, (2./3.) * _m *  _r**2])
_M_inv = onp.linalg.inv(_M)
_g = 9.81

def f_analytical_ball(x):
    """Continuous-time dynamics of the free ball, xdot = f(x).

    Args:
        x: state ``concat(q, qdot)`` with ``q = (x, z, theta)``.

    Returns:
        ``concat(qdot, qddot)`` where ``qddot = M^-1 b`` and ``b`` is the
        generalized gravity force. Contact is not modelled here.
    """
    _, qdot = np.split(x, 2)
    # Generalized *force* of gravity (m * g along z), so that M^-1 @ b gives an
    # acceleration of exactly -g for any mass. Using -g directly as the force
    # would give -g/m, which only coincided with the truth because _m == 1.
    b = np.array([0., -_m * _g, 0.])
    return np.concatenate([qdot, _M_inv @ b])

def phi(q):
    """Signed gap between the ball and the ground.

    ``q[1]`` is the height z of the ball centre, so the gap is ``z - r``:
    positive when separated, zero at touchdown, negative when penetrating.
    """
    centre_height = q[1]
    return centre_height - _r

def contact_loc(q):
    """Position of the point on the ball's rim at angle ``theta``.

    Args:
        q: ``(x, z, theta)`` -- ball-centre coordinates and a rim angle.

    Returns:
        The 2-vector ``(x + r cos(theta), z + r sin(theta))``. With
        ``theta = -pi/2`` this is the lowest point of the ball.
    """
    x, z, th = q
    return np.array([
        x + _r * np.cos(th),
        z + _r * np.sin(th)
    ])


# d(contact_loc)/dq as a (2, 3) matrix, obtained by forward-mode autodiff.
# Analytically it is [[1, 0, -r sin(theta)], [0, 1, r cos(theta)]]; at
# theta = -pi/2 this is [[1, 0, r], [0, 1, 0]] (up to round-off in cos(-pi/2)),
# i.e. spinning moves the lowest point horizontally by r * thetadot.
contact_jac = jacfwd(contact_loc)

def reset_map(q, qdot, q_c, restitution=(0.9, 0.9)):
    """Impact map: post-impact generalized velocity for a collision at a contact point.

    Args:
        q: configuration at impact. Not used by the computation; kept for
           interface compatibility.
        qdot: pre-impact generalized velocity.
        q_c: configuration at which the contact Jacobian ``J = contact_jac(q_c)``
             is evaluated.
        restitution: per-direction coefficients ``(e_0, e_1)`` for the two rows
             of ``J`` (the two components of the contact-point velocity). The
             default ``(0.9, 0.9)`` reproduces the previous hard-coded
             ``diag([1.9, 1.9])`` factor.

    Returns:
        ``qdot - M^-1 J^T (J M^-1 J^T)^+ (I + E) J qdot``. When ``J M^-1 J^T`` is
        invertible, this maps the contact velocity ``v_c = J qdot`` to ``-E v_c``.
    """
    J = contact_jac(q_c)
    # Inverse effective mass seen at the contact point.
    A = J @ _M_inv @ J.T
    # (I + E) must act on the contact velocity *before* A^+. Placing it between
    # J^T and A^+ (as before) only gives the right answer when E is a multiple
    # of the identity.
    one_plus_e = np.diag(1.0 + np.asarray(restitution))
    impulse = pinv(A) @ (one_plus_e @ (J @ qdot))
    return qdot - _M_inv @ J.T @ impulse

# Initial state: ball centre at x = 0, z = 4, moving with xdot = -1 and spinning
# with thetadot = 5.
q0    = np.array([0.0, 4.0, 0.0])
qdot0 = np.array([-1.0, 0.0, 5.0])
x     = np.concatenate([q0, qdot0])

dt = 0.01
n_steps = 500
for _ in range(n_steps):
    x = rk4_step(f_analytical_ball, x, dt)
    q, qdot = np.split(x, 2)
    # Apply the impact map only while the ball penetrates the ground AND is still
    # moving into it. The ground normal is +z, so the normal contact velocity is
    # qdot[1] (the second row of the contact Jacobian at theta = -pi/2 is
    # [0, 1, ~0]). Without this check, a ball still below ground on the step after
    # a bounce would be "bounced" again, flipping its upward velocity back down
    # and letting it sink through the floor.
    if phi(q) <= 0 and qdot[1] < 0:
        q_c = np.array([q[0], q[1], -np.pi/2.0])
        qdot = reset_map(q, qdot, q_c)
        x = np.concatenate([q, qdot])
    viewer.render(q)
    time.sleep(0.001)