In [1]:
import os
import yaml
import numpy as np
from scipy.interpolate import CubicSpline
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

In [2]:
from functools import partial
import chex

In [3]:
cpu_device = jax.devices('cpu')[0]
cpu_device

CpuDevice(id=0)

In [4]:
gpu_device = jax.devices('gpu')[0]
gpu_device

cuda(id=0)

In [5]:
#mj_model = mujoco.MjModel.from_xml_path("../models/go1/go1_scene_jax_no_collision.xml")

In [6]:
def load_rollout_jax(step_fn):
    def rollout_aux(obs, actions):
        carry = (obs)
        _, output = jax.lax.scan(f=step_fn, init=carry, xs=actions)
        return output
    func = jax.jit(jax.vmap(rollout_aux, in_axes=(None, 0)))
    return func

In [7]:
class MPPI_JAX:
    def __init__(self, 
                 model_path = "../models/go1/go1_scene_jax_no_collision.xml",
                 config_path="configs/mppi.yml") -> None:
        # load the configuration file
        with open(config_path, 'r') as file:
            params = yaml.safe_load(file)

        # load model
        self.model = mujoco.MjModel.from_xml_path(model_path)
        #self.model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        #self.model.opt.iterations = 6
        #self.model.opt.ls_iterations = 6
    
        self.mjx_model = mjx.device_put(self.model)
        self.mjx_data = mjx.make_data(self.mjx_model)
        self.model.opt.timestep = params['dt']

        # mppi controller configuration
        self.temperature = params['lambda']
        self.horizon = params['horizon']
        self.n_samples = params['n_samples']
        self.noise_sigma = jnp.array(params['noise_sigma'])
        self.num_workers = params['n_workers']
        self.sampling_init = jnp.array([0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ])
        
        # Cost
        self.Q = jnp.diag(jnp.array(params['Q_diag']))
        self.R = jnp.diag(jnp.array(params['R_diag']))
        self.x_ref = jnp.concatenate([jnp.array(params['q_ref']), jnp.array(params['v_ref'])])

        # Get env parameters
        self.act_dim = 12
        self.act_max = jnp.array([0.863, 4.501, -0.888]*4)
        self.act_min = jnp.array([-0.863, -0.686, -2.818]*4)
        
        # Rollouts
        self.h = params['dt']
        self.sample_type = params['sample_type']
        self.n_knots = params['n_knots']
        self.rollout_func = jax.jit(self.rollout_jax(), device=gpu_device)
        self.random_generator = np.random.default_rng(params["seed"])
        
        self.trajectory = None
        self.reset_planner() 
        self.update(self.x_ref)
        self.reset_planner()     
    
    def rollout_jax(self):
        def step_wrapper_mujoco(carry, action):
            obs = carry
            data = mjx.make_data(self.mjx_model)
            data = data.replace(qpos=obs.qpos, qvel=obs.qvel, ctrl=action)
            data = mjx.step(self.mjx_model, data)
            
            next_obs = jnp.concatenate([data.qpos, data.qvel])
            cost = self.quadruped_cost(next_obs, action)
            carry = data
            output = (next_obs, cost)
            return carry, output
        return load_rollout_jax(step_wrapper_mujoco)
    
    def reset_planner(self):
        self.trajectory = np.zeros((self.horizon, self.act_dim))
        self.trajectory += self.sampling_init
            
    def generate_noise(self, size):
        return self.random_generator.normal(size=size) * self.noise_sigma
    
    def sample_delta_u(self):
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            return self.generate_noise(size)
        elif self.sample_type == 'cubic':
            indices = np.arange(self.n_knots)*self.horizon//self.n_knots
            size = (self.n_samples, self.n_knots, self.act_dim)
            knot_points = self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            return cubic_spline(np.arange(self.horizon))
        
    def perturb_action(self):
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            actions = self.trajectory + self.generate_noise(size)
            actions = np.clip(actions, self.act_min, self.act_max)
            return actions
        
        elif self.sample_type == 'cubic':
            indices_float = jnp.linspace(0, self.horizon - 1, num=self.n_knots)
            indices = jnp.round(indices_float).astype(int)
            size = (self.n_samples, self.n_knots, self.act_dim)
            knot_points = self.trajectory[indices] + self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            actions = cubic_spline(np.arange(self.horizon))
            actions = np.clip(actions, self.act_min, self.act_max)
            return actions
        
    def update(self, obs): 
        self.mjx_data = self.mjx_data.replace(qpos=obs[:19], qvel=obs[19:])
        actions = jnp.array(self.perturb_action())
        #self.rollout_func(self.state_rollouts, actions, np.repeat([np.concatenate([[0],obs])], self.n_samples, axis=0), num_workers=self.num_workers, nstep=self.horizon)
        _, costs = self.rollout_func(self.mjx_data, actions)
        costs_sum = costs.sum(axis=1)
        #print(costs_sum)
        costs_sum = jnp.where(jnp.isnan(costs_sum), 10000000, costs_sum)
        print(costs_sum)
        # MPPI weights calculation
        ## Scale parameters
        min_cost = np.min(costs_sum)
        max_cost = np.max(costs_sum)
        
        exp_weights = jnp.exp(-1/self.temperature * ((costs_sum - min_cost)/(max_cost - min_cost)))
        #print(exp_weights)
        weighted_delta_u = exp_weights.reshape(self.n_samples, 1, 1) * actions
        weighted_delta_u = jnp.sum(weighted_delta_u, axis=0) / (jnp.sum(exp_weights) + 1e-10)
        updated_actions = jnp.clip(weighted_delta_u, self.act_min, self.act_max)
    
        # Pop out first action from the trajectory and repeat last action
        self.trajectory = jnp.roll(updated_actions, shift=-1, axis=0)
        #self.trajectory[-1] = updated_actions[-1]
        self.trajectory = self.trajectory.at[-1].set(updated_actions[-1])

        # Output first action (MPC)
        action = updated_actions[0] 
        return action
    
    def quaternion_distance(self, q1, q2):
        return 1 - jnp.abs(jnp.dot(q1,q2))
    
    def quadruped_cost(self, x, u):
        kp = 40
        kd = 3
        # Compute the error terms
        x_error = x - self.x_ref
        # Assuming quaternion_distance is a function you've defined elsewhere
        x_error = x_error.at[3:7].set(self.quaternion_distance(x[3:7], self.x_ref[3:7]))
        u_error = kp*(u - x[7:19]) - kd*x[25:]

        # Compute the cost
        cost = jnp.dot(x_error, jnp.dot(self.Q, x_error)) + jnp.dot(u_error, jnp.dot(self.R, u_error))
        return cost

In [8]:
mppi = MPPI_JAX()

[5755.0044 5612.051  5729.599  5635.6353 5629.41   5794.2524 5666.199
 5463.289  5457.498  5324.941  5482.0264 5707.7344 5788.5176 5702.5234
 5535.9243 5469.3076 5528.4365 5650.5728 5709.2734 5655.968  5716.8506
 5806.6113 5499.584  5775.3623 5394.427  5604.5566 5386.8857 5561.8193
 5538.831  5849.9697 5639.3154 5299.9834 5730.534  5379.0425 5723.2266
 5509.3506 5678.961  5836.605  5597.151  5601.8955 5681.7856 5793.5635
 5642.7627 5563.2314 5547.4863 5436.945  5409.2188 5503.5474 5138.695
 5529.458  5442.4814 5604.9688 5716.3877 5521.6943 5823.8203 5850.446
 5569.4307 5456.9355 5682.017  5414.083  5727.106  5730.038  5514.164
 5741.2573 5685.5283 5784.166  5639.781  5512.339  5664.751  5722.5444
 5786.751  5399.8926 5799.9023 5690.748  5530.708  5754.835  5713.2354
 5486.162  5648.7754 5643.619  5559.2144 5276.6416 5776.8535 5550.7837
 5709.927  5490.3574 5704.3823 5726.0864 5826.467  5483.064  5814.756
 5614.274  5659.934  5684.5234 5365.405  5531.5654 5728.169  5713.175
 5683.9297 5

OverflowError: An overflow was encountered while parsing an argument to a jitted computation, whose argument path is x.

In [None]:
actions = mppi.perturb_action()

In [None]:
actions.shape

In [None]:
%%timeit
mppi.update(mppi.x_ref)

In [None]:
obs = mppi.x_ref

In [None]:
actions[0,0].shape

In [None]:
mppi.mjx_data = mppi.mjx_data.replace(qpos=obs[:19], qvel=obs[19:], ctrl=actions[0,0])

In [None]:
mppi.mjx_data.qvel

In [None]:
mppi.mjx_data.qvel

In [None]:
costs = mppi.update(mppi.x_ref)

In [None]:
costs

In [None]:
mppi.mjx_data.qpos.shape

In [None]:
mppi.mjx_data.qvel.shape

In [None]:
mppi.x_ref

In [None]:
mppi.mjx_data.qpos

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import mujoco_viewer

In [None]:
import matplotlib.pyplot as plt

In [None]:
import copy as cp

In [None]:
mujoco.mjtSolver.mjSOL_CG

In [None]:
#model_sim = mujoco.MjModel.from_xml_path("../models/go1/go1_scene_jax_no_collision.xml")
model_sim = mujoco.MjModel.from_xml_path("../models/go1/scene_opt_pd.xml")

In [None]:
model_sim.opt.solver

In [None]:
model_sim.opt.iterations

In [None]:
model_sim.opt.ls_iterations

In [None]:
dt_sim = 0.01
model_sim.opt.timestep = dt_sim

data_sim = mujoco.MjData(model_sim)

In [None]:
viewer = mujoco_viewer.MujocoViewer(model_sim, data_sim, 'offscreen')

In [None]:
# reset robot (keyframes are defined in the xml)
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 1) # stand position
mujoco.mj_forward(model_sim, data_sim)
q_init = cp.deepcopy(data_sim.qpos) # save reference pose
v_init = cp.deepcopy(data_sim.qvel) # save reference pose

In [None]:
print("Configuration: {}".format(q_init)) # save reference pose

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
# reset robot (keyframes are defined in the xml)
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 0) # stand position
mujoco.mj_forward(model_sim, data_sim)
q_ref_mj = cp.deepcopy(data_sim.qpos) # save reference pose
v_ref_mj = cp.deepcopy(data_sim.qvel) # save reference pose

In [None]:
print("Configuration: {}".format(q_ref_mj)) # save reference pose

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
q_curr = cp.deepcopy(data_sim.qpos) # save reference pose
v_curr = cp.deepcopy(data_sim.qvel) # save reference pose
x = jnp.concatenate([q_curr, v_curr])

In [None]:
tfinal = 5
tvec = jnp.linspace(0,tfinal,int(jnp.ceil(tfinal/dt_sim))+1)

In [None]:
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 1)
mujoco.mj_forward(model_sim, data_sim)

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
mppi.reset_planner()

In [None]:
mppi.trajectory

In [None]:
%%time
anim_imgs = []
sim_inputs = []
for ticks, ti in enumerate(tvec):
    #if ticks % 1 == 0:
    q_curr = cp.deepcopy(data_sim.qpos) # save reference pose
    v_curr = cp.deepcopy(data_sim.qvel) # save reference pose
    x = jnp.concatenate([q_curr, v_curr])
    u_joints = mppi.update(x)    
    data_sim.ctrl[:] = u_joints
    mujoco.mj_step(model_sim, data_sim)
    mujoco.mj_forward(model_sim, data_sim)
    img = viewer.read_pixels()
    anim_imgs.append(img)
    sim_inputs.append(u_joints)

In [None]:
fig, ax = plt.subplots()
skip_frames = 10
interval = dt_sim*1000*skip_frames

def animate(i):
    ax.clear()
    ax.imshow(anim_imgs[i * skip_frames])  # Display the image, skipping frames
    ax.axis('off')

# Create animation, considering the reduced frame rate due to skipped frames
ani = FuncAnimation(fig, animate, frames=len(anim_imgs) // skip_frames, interval=interval)  # 50 ms for 20 Hz

# Display the animation
HTML(ani.to_jshtml())

In [None]:
#mppi.reset_planner()

In [None]:
actions = jnp.array(mppi.perturb_action())

In [None]:
_, costs = mppi.rollout_func(mppi.mjx_data, actions) 

In [None]:
costs