In [None]:
# Report which platform JAX selected for computation (e.g. "cpu", "gpu", "tpu").
# jax.default_backend() is the public API for this; the jax.lib.xla_bridge
# module used previously is an internal detail that newer JAX releases deprecate.
import jax
print(jax.default_backend())

In [None]:
import numpy as np
import os
import time

import meshcat
import meshcat.geometry as g
import meshcat.transformations as tf
from meshcat.animation import Animation

import matplotlib.pyplot as plt
import h5py

In [None]:
# Create a new MeshCat visualizer; `vis` is the handle passed to
# visualize_hopper() further down to receive geometry and animations.
vis = meshcat.Visualizer()

In [None]:
# Open the stored reference trajectory read-only through h5py.
# NOTE(review): the handle `f2` is never closed in the visible cells; consider
# a `with h5py.File(...)` block (or an explicit f2.close()) once every read
# from it is known.
f2 = h5py.File("./costs/hopper_ref.jld2", "r")

In [None]:
# Reference trajectory as a list of NumPy arrays, one per entry of the "q" dataset.
# NOTE(review): assumes each entry carries the configuration vector as its
# first element - confirm against the layout of the JLD2 file.
q_ref = [np.array(list(entry[0])) for entry in f2["q"]]

In [None]:
def visualize_hopper(vis, q, dt):
    """Animate a planar hopper trajectory in a MeshCat visualizer.

    The hopper is drawn in the x-z plane (y is always 0) as a body sphere, a
    foot sphere, and a chain of small spheres evenly spaced between the two
    to represent the leg. One animation frame is created per entry of ``q``.

    Args:
        vis: MeshCat visualizer that receives the geometry and the animation.
        q: Sequence of configurations. For each entry, elements 0 and 1 are
            used as the body x/z position and elements 2 and 3 as the foot
            x/z position; any further elements are ignored.
        dt: Time between consecutive configurations. The animation frame
            rate is ``ceil(1 / dt)``.

    Raises:
        ValueError: If ``dt`` is not strictly positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive, got {}".format(dt))

    # Sphere radii for the body, the foot, and the beads that make up the leg.
    r_body = 0.2
    r_foot = 0.1
    r_leg = 0.5 * r_foot

    # Number of beads used to draw the leg.
    n_leg = 20

    # Hand the animation a plain integer frame rate rather than a NumPy float.
    fps = int(np.ceil(1 / dt))
    anim = Animation(default_framerate=fps)

    # create body
    vis["body"].set_object(g.Sphere(r_body),
                           g.MeshLambertMaterial(color=0xb82e3d, reflectivity=0.8))

    # create foot
    vis["foot"].set_object(g.Sphere(r_foot),
                           g.MeshLambertMaterial(color=0xb82e3d, reflectivity=0.8))

    # create leg
    for i in range(n_leg):
        vis["leg{}".format(i)].set_object(g.Sphere(r_leg),
                                          g.MeshPhongMaterial(color=0x3f2a32, reflectivity=0.8))

    # Every object is raised by the foot radius - presumably so the foot
    # sphere rests on the z = 0 plane instead of being centred on it.
    z_shift = np.array([0.0, 0.0, r_foot])

    # Fractions along the body -> foot segment at which leg beads are placed
    # (0 = body centre, 1 = foot centre).
    leg_fractions = np.linspace(0.0, 1.0, n_leg)

    for t in range(len(q)):
        q_t = q[t]
        p_body = np.array([q_t[0], 0.0, q_t[1]])
        p_foot = np.array([q_t[2], 0.0, q_t[3]])

        # Interpolate directly between body and foot. This places the beads
        # exactly where "unit direction * distance" would, but needs no
        # division by the leg length, so a zero-length leg (body and foot at
        # the same point) no longer produces NaN transforms.
        p_leg = p_body + leg_fractions[:, None] * (p_foot - p_body)

        with anim.at_frame(vis, t) as frame:
            frame["body"].set_transform(tf.translation_matrix(p_body + z_shift))
            frame["foot"].set_transform(tf.translation_matrix(p_foot + z_shift))
            for i in range(n_leg):
                frame["leg{}".format(i)].set_transform(tf.translation_matrix(p_leg[i] + z_shift))

    vis.set_animation(anim)

In [None]:
# Optional: preview the reference trajectory (uncomment the next line to run).
# visualize_hopper(vis, q_ref, 0.1)

## Hopper MPPI

In [None]:
from controller.mppi import MPPI
from dynamics.models_meshcat import hopper_dynamics

In [None]:
# Settings dictionary handed to the MPPI constructor; its "h" entry is also
# reused as the simulation time step in the rollout loop.
params = dict(
    seed=42,
    h=0.02,
    env_name='Hopper-meshcat',
    sample_type='cubic',  # alternative kept from the original cell: 'normal'
    n_knots=10,
    horizon=50,
    temperature=0.001,
    n_samples=1000,
    noise_sigma=[10.0, 1.0],
)

In [None]:
# Build the MPPI controller from the settings dictionary above.
# NOTE(review): the first positional argument is passed as None - confirm
# what MPPI expects in that slot.
controller_jax = MPPI(None, params)

In [None]:
# Reset the controller's planner before the rollout below.
# NOTE(review): presumably this clears any previously planned controls - confirm in MPPI.
controller_jax.reset_planner()

In [None]:
# Closed-loop rollout: query the controller, step the dynamics, log the result.

# Contact-state variables that hopper_dynamics takes in and hands back on every step.
stance_flag = True
stance_count = 0

tfinal = 2.2
h = params["h"]
# ceil(tfinal / h) + 1 sample points from 0 to tfinal inclusive.
tvec = np.linspace(0,tfinal,int(np.ceil(tfinal/h))+1)
# Multiplying by 1 yields a new array, so x does not alias q_ref[0].
x = q_ref[0]*1
# Logs filled by the loop: post-step state, applied control, and the costs
# returned by the controller.
q_sim = []  
costs_sim = []
u_sim = []
# One controller query and one dynamics step per entry of tvec; the value of
# ti itself is not used.
for ti in tvec:
    u, costs = controller_jax.get_u(x)
    x, stance_flag, stance_count = hopper_dynamics(x, u, h, stance_flag, stance_count)
    # q_sim stores the state *after* each step, so the initial state is not recorded.
    q_sim.append(x)
    u_sim.append(u)
    costs_sim.append(costs)

In [None]:
# Animate the simulated rollout; the simulation step h is passed as dt, from
# which visualize_hopper derives the animation frame rate.
visualize_hopper(vis, q_sim, h)

In [None]:
# Split the logged controls into one series per input channel.
u1 = [u_step[0] for u_step in u_sim]
u2 = [u_step[1] for u_step in u_sim]

In [None]:
# First control channel over the rollout (one sample per simulation step).
# Labels and a title are set so the figure can be read on its own.
plt.plot(u1)
plt.xlabel("simulation step")
plt.ylabel("u[0]")
plt.title("MPPI control input 0")
plt.show()

In [None]:
# Second control channel over the rollout (one sample per simulation step).
# Labels and a title are set so the figure can be read on its own.
plt.plot(u2)
plt.xlabel("simulation step")
plt.ylabel("u[1]")
plt.title("MPPI control input 1")
plt.show()

In [None]:
# Hand-built reference trajectory: one row per time step, Nx columns per row.
# NOTE(review): columns 0/1 and 3 presumably correspond to body x/z and foot z,
# matching the indexing used in visualize_hopper - confirm.
tfinal = 2.2
h = 0.02
Nt = int(np.ceil(tfinal / h)) + 1
Nx = 8

# Phase of a sinusoid with a period of 50 samples, shared by columns 1 and 3.
ref_phase = 2 * np.pi / 50.0 * np.arange(Nt)

x_ref = np.zeros((Nt, Nx))

# Column 0: linear ramp from -1 to 1 across all samples.
x_ref[:, 0] = np.linspace(-1.0, 1.0, Nt)

# Column 1: oscillation of amplitude 0.35 around 1.0, then held at exactly 1.0
# for the first 10 and the last 51 samples.
x_ref[:, 1] = 1.0 + 0.35 * np.sin(ref_phase - 3 * np.pi / 8)
x_ref[:10, 1] = 1.0
x_ref[-51:, 1] = 1.0

# Column 3: negated sinusoid with every negative value replaced by 0, then
# forced to 0 for the last 50 samples.
x_ref[:, 3] = -0.35 * np.sin(ref_phase)
x_ref[:, 3] = np.where(x_ref[:, 3] < 0, 0, x_ref[:, 3])
x_ref[-50:, 3] = 0

In [None]:
# Plot the two non-trivial reference columns on shared axes. Each curve gets
# a legend entry so the two series can be told apart.
plt.plot(x_ref[:,1], label="x_ref[:, 1]")
plt.plot(x_ref[:,3], label="x_ref[:, 3]")
plt.xlabel("time step index")
plt.ylabel("reference value")
plt.title("Hand-built reference trajectory")
plt.legend()
plt.show()