In [None]:
import numpy as np
import os
import time

import meshcat
import meshcat.geometry as g
import meshcat.transformations as tf
from meshcat.animation import Animation

import matplotlib.pyplot as plt
import h5py

In [None]:
import jax
import jax.numpy as jnp

from utils.jax_utils import get_coords_from_angle, wrap_angle

In [None]:
# Create the meshcat visualizer that every scene and animation in this notebook is drawn into.
vis = meshcat.Visualizer()

In [None]:
# Start from an empty scene: remove everything currently drawn under the viewer's root path.
vis.delete()

In [None]:
def visualize_cartpole(vis, model, q, dt):
    """Build a cart-pole scene in a meshcat viewer and play back a trajectory.

    Any existing ``cart_pole`` subtree under ``vis`` is deleted and rebuilt, then a
    keyframe animation with one frame per entry of ``q`` is attached to ``vis``.

    Args:
        vis: meshcat visualizer (or sub-path) to draw into.
        model: unused; kept so existing call sites keep working.
        q: sequence of configurations. ``q[t][0]`` is used as the cart's x
            translation and ``q[t][1]`` as the pole angle (radians) about the
            y axis, added on top of the pivot's pi/2 base rotation. Rows may be
            lists, tuples, NumPy or JAX arrays; entries past index 1 are ignored.
        dt: time between consecutive entries of ``q``; must be positive.

    Raises:
        ValueError: if ``dt`` is not positive (including NaN).
    """
    if not dt > 0:
        raise ValueError(f"dt must be positive, got {dt!r}")

    cart_pole = vis["cart_pole"]
    # Start from a clean subtree so a previous call's scene is not left behind.
    cart_pole.delete()
    cart = cart_pole["cart"]
    pivot = cart["pivot"]
    pole = pivot["pole"]

    pole_material = g.MeshPhongMaterial(color=0x3f2a32, reflectivity=0.8)
    cart_material = g.MeshPhongMaterial(color=0xb82e3d, reflectivity=0.8)

    cart.set_object(g.Box([0.5, 0.3, 0.2]), cart_material)
    # The pole box is 1 long along its local x axis; shifting it by half its
    # length puts one end on the pivot, so it rotates about that end.
    pole.set_object(g.Box([1, 0.05, 0.05]), pole_material)
    pole.set_transform(tf.translation_matrix([0.5, 0, 0]))
    pivot.set_transform(tf.rotation_matrix(np.pi/2, [0, 1, 0]))

    # Use the exact sample rate so playback matches the trajectory's timing.
    # Rounding 1/dt up (np.ceil) plays back too fast whenever 1/dt is not an integer.
    fps = 1.0 / dt
    anim = Animation(default_framerate=fps)

    for frame_idx, config in enumerate(q):
        # Convert the row once (rows may be JAX arrays) rather than indexing the
        # original object element by element.
        config = np.asarray(config, dtype=float)
        with anim.at_frame(vis, frame_idx) as frame:
            frame["cart_pole/cart"].set_transform(
                tf.translation_matrix([config[0], 0, 0]))
            frame["cart_pole/cart/pivot"].set_transform(
                tf.rotation_matrix(np.pi/2 + config[1], [0, 1, 0]))
    vis.set_animation(anim)

In [None]:
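# Optional synthetic trajectory for checking the animation without running the controller:
# 41 samples over 2 s (spacing 0.05, matching the dt used in the next cell), with cart
# position and pole angle both equal to pi*sin(2*pi*t). Uncomment together with the next cell.
#q = [[np.pi*np.sin(2*np.pi*t), np.pi*np.sin(2*np.pi*t)] for t in np.linspace(0,2, 41)]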

In [None]:
# Preview the synthetic trajectory `q` from the previous cell (dt = 0.05).
#visualize_cartpole(vis, None, q, 0.05)

## Cart-pole control with MPPI

In [None]:
from controller.mppi import MPPI
from dynamics.models_meshcat import cartpole_dynamics_rk4, cartpole_dynamics_sim_rk4

In [None]:
# Configuration shared by the MPPI controller and the simulation cells below.
# NOTE(review): this notebook reads only 'seed' and 'h' directly; the remaining
# keys are consumed inside MPPI, so their exact meaning is defined there.
params = dict(
    seed=42,                      # seeds the benchmark RNG
    h=0.02,                       # step passed to the dynamics and used as the animation dt
    env_name='Cartpole-meshcat',
    sample_type='cubic',          # alternative tried before: 'normal'
    n_knots=30,                   # presumably spline knots for 'cubic' sampling -- confirm in MPPI
    horizon=60,
    temperature=0.001,
    n_samples=1000,
    noise_sigma=0.8,
)

In [None]:
# Build the MPPI controller from `params`.
# NOTE(review): the first argument is passed as None here -- presumably an optional
# env/model handle this setup does not need; confirm against MPPI's constructor.
controller_jax = MPPI(None, params)
# Presumably clears any stored plan before the closed-loop run in the next cells.
controller_jax.reset_planner()

In [None]:
# Closed-loop simulation: at every step MPPI computes a control from the current
# state, and that control is applied to the simulation dynamics.
# NOTE(review): re-running this cell without re-running the previous one presumably
# starts from whatever planner state the last run left behind.
x0 = np.array([0, 0.0, 0.0, 0.0])  # x[0:2] (cart position, pole angle) is what gets animated
tfinal = 10
h = params["h"]
Nt = int(np.ceil(tfinal/h))+1
tvec = np.linspace(0, tfinal, Nt)

# Record the initial configuration so q_sim[k] is the state after k steps, aligned
# with tvec[k]. The trajectory then has Nt entries and spans (Nt-1)*h. Previously it
# skipped x0 and ran one step past the end.
x = x0.copy()
q_sim = [x0[0:2].copy()]
for _step in range(Nt - 1):
    u, _ = controller_jax.get_u(x)
    x = cartpole_dynamics_sim_rk4(x, u, h)
    q_sim.append(x[0:2])

In [None]:
# Play back the closed-loop trajectory, one frame per simulation step h.
visualize_cartpole(vis, model=None, q=q_sim, dt=h)

## Performance: timing batched rollouts

In [None]:
from jax_wrappers.rollout_functions import load_rollout_jax, step_wrapper_cartpole

In [None]:
# Build the rollout function around the cart-pole step wrapper; it is timed below as
# rollout_jax(x0, acts, h), with acts holding one action sequence per sample.
rollout_jax = load_rollout_jax(step_wrapper_cartpole)

In [None]:
# Seeded generator for the benchmark inputs, so the timed rollouts are reproducible.
rng = np.random.default_rng(seed=params["seed"])

In [None]:
# Benchmark batch dimensions: number of rollouts, steps per rollout, action size.
# NOTE(review): horizon here (20) differs from params["horizon"] (60) used by the
# controller -- confirm the benchmark is meant to use the shorter horizon.
n_samples, horizon, act_dim = 1000, 20, 1

In [None]:
# Standard-normal action sequences, one per rollout: shape (n_samples, horizon, act_dim).
size = n_samples, horizon, act_dim
acts = rng.normal(loc=0.0, scale=1.0, size=size)

In [None]:
%timeit rollout_jax(x0, acts, params["h"])