In [1]:
import os
import yaml
import numpy as np
from scipy.interpolate import CubicSpline
import jax
import jax.numpy as jnp
import mujoco
from mujoco import rollout
import concurrent.futures
import threading

In [104]:
class GaitScheduler:
    """Cyclic cursor over a pre-computed gait table stored as a tab-separated file.

    The table is loaded as a 2-D array; every column is one phase step, so
    ``phase_length`` is the number of columns. ``indices`` always starts with the
    current phase column and continues cyclically, so ``indices[:h]`` gives the
    column indices for the next ``h`` steps.

    Other gait tables (for example ``walking_gait_S30ms_O40ms_H10cm.tsv`` or
    ``walking_gait_S10ms_O15ms_H20cm.tsv``) can be selected through ``gait_path``.
    """

    def __init__(self, gait_path = '../gaits/walking_gait_S10ms_O15ms_H10cm.tsv', phase_time = 0):
        """Load the gait table and position the cursor.

        Args:
            gait_path: Path to a tab-separated numeric file.
            phase_time: Initial phase column; wrapped into ``[0, phase_length)``.

        Raises:
            ValueError: If the file contains no data.
        """
        # ndmin=2 keeps a single-row (or single-column) table two-dimensional so
        # that shape[1] below is always valid.
        gait_array = np.loadtxt(gait_path, delimiter='\t', ndmin=2)
        if gait_array.size == 0:
            raise ValueError(f"Gait file {gait_path!r} contains no data")

        self.gait = gait_array
        self.phase_length = gait_array.shape[1]
        # Wrap the start phase and rotate the index window so that
        # indices[0] == phase_time holds from the start.
        self.phase_time = phase_time % self.phase_length
        self.indices = jnp.roll(jnp.arange(self.phase_length), -self.phase_time)

    def roll(self):
        """Advance the gait by one phase step, wrapping at the end of the cycle."""
        # Without the modulo, phase_time would run past the last column and
        # get_current_ref() would raise IndexError after one full cycle.
        self.phase_time = (self.phase_time + 1) % self.phase_length
        self.indices = jnp.roll(self.indices, -1)

    def get_current_ref(self):
        """Return the gait column for the current phase step."""
        return self.gait[:, self.phase_time]
    
class MPPI:
    """Sampling-based predictive controller that scores MuJoCo rollouts with a JAX cost.

    Each call to ``update`` perturbs the current action plan, simulates all
    candidates with ``mujoco.rollout`` on a thread pool, scores every
    (sample, step) pair with ``quadruped_cost``, and replaces the plan with the
    exponentially weighted average of the candidates.

    The cost assumes a 37-element state laid out as
    ``[body pose (7), joint refs (12), body velocity (6), joint refs (12)]``,
    which is how ``quadruped_cost`` assembles its reference vector.
    """

    def __init__(self, model_path = "../models/go1/go1_scene_jax_no_collision.xml",
                 config_path="configs/mppi.yml") -> None:
        """Load model and configuration, build the cost function and warm it up.

        Args:
            model_path: MuJoCo XML model file.
            config_path: YAML file providing ``dt``, ``lambda``, ``horizon``,
                ``n_samples``, ``noise_sigma``, ``n_workers``, ``Q_diag``,
                ``R_diag``, ``q_ref``, ``v_ref``, ``sample_type``, ``n_knots``
                and ``seed``.
        """
        # load the configuration file
        with open(config_path, 'r') as file:
            params = yaml.safe_load(file)

        # Load model
        self.model = mujoco.MjModel.from_xml_path(model_path)
        self.model.opt.timestep = params['dt']

        # MPPI controller configuration
        self.temperature = params['lambda']
        self.horizon = params['horizon']
        self.n_samples = params['n_samples']
        self.noise_sigma = jnp.array(params['noise_sigma'])
        self.num_workers = params['n_workers']
        # Action the plan is initialised with: one 3-value pattern repeated four
        # times (presumably one triple per leg -- TODO confirm against the model).
        self.sampling_init = jnp.array([0.073,  1.34, -2.83,
                                        0.073,  1.34, -2.83,
                                        0.073,  1.34, -2.83,
                                        0.073,  1.34, -2.83])

        # Cost
        self.Q = jnp.diag(jnp.array(params['Q_diag']))
        self.R = jnp.diag(jnp.array(params['R_diag']))
        self.x_ref = jnp.concatenate([jnp.array(params['q_ref']), jnp.array(params['v_ref'])])
        self.q_ref = jnp.array(params['q_ref'])
        self.v_ref = jnp.array(params['v_ref'])
        # First 7 entries of q_ref followed by the first 6 entries of v_ref.
        self.body_ref = jnp.concatenate([self.q_ref[:7], self.v_ref[:6]])

        # Threading: each worker thread keeps its own MjData in thread-local storage
        self.thread_local = threading.local()

        # Get env parameters
        self.act_dim = 12
        self.act_max = [0.863, 4.501, -0.888]*4
        self.act_min = [-0.863, -0.686, -2.818]*4

        # Gait scheduler
        self.gait_scheduler = GaitScheduler()

        # Rollouts
        self.h = params['dt']
        self.sample_type = params['sample_type']
        self.n_knots = params['n_knots']
        self.random_generator = np.random.default_rng(params["seed"])

        self.rollout_func = self.threaded_rollout
        # Inner vmap: batch over samples (axis 0 of states/actions) for one phase index.
        # Outer vmap: batch over horizon steps (axis 1 of states/actions, axis 0 of
        # the phase indices). The result therefore has shape (horizon, n_samples).
        self.cost_func = jax.jit(jax.vmap(jax.vmap(self.quadruped_cost, in_axes=(0, 0, None, None, None, None, None)), in_axes=(1, 1, 0, None, None, None, None)))
        # Output buffer that the rollouts write into in place.
        self.state_rollouts = np.zeros((self.n_samples, self.horizon, mujoco.mj_stateSize(self.model, mujoco.mjtState.mjSTATE_FULLPHYSICS.value)))

        self.trajectory = None
        self.reset_planner()
        # Warm-up call (runs the jitted cost once), then discard the resulting plan.
        # NOTE(review): update() also advances self.gait_scheduler by one step and
        # reset_planner() does not undo that, so the gait phase starts at 1 after
        # construction -- confirm this is intended.
        self.update(self.x_ref)
        self.reset_planner()

    def reset_planner(self):
        """Reset the action plan to ``sampling_init`` repeated over the horizon."""
        self.trajectory = np.zeros((self.horizon, self.act_dim))
        self.trajectory += self.sampling_init

    def generate_noise(self, size):
        """Draw standard-normal noise of shape ``size`` scaled by ``noise_sigma``."""
        return self.random_generator.normal(size=size) * self.noise_sigma

    def sample_delta_u(self):
        """Sample action perturbations of shape ``(n_samples, horizon, act_dim)``.

        ``'normal'`` draws independent noise per step. ``'cubic'`` draws noise at
        ``n_knots`` knots and interpolates with a cubic spline.

        NOTE(review): the knots here are placed at ``k * horizon // n_knots``, so
        steps after the last knot are extrapolated, whereas ``perturb_action``
        spreads its knots up to ``horizon - 1`` -- confirm which is intended.

        Raises:
            ValueError: If ``sample_type`` is neither ``'normal'`` nor ``'cubic'``.
        """
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            return self.generate_noise(size)
        elif self.sample_type == 'cubic':
            indices = np.arange(self.n_knots)*self.horizon//self.n_knots
            size = (self.n_samples, self.n_knots, self.act_dim)
            knot_points = self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            return cubic_spline(np.arange(self.horizon))
        # Previously fell through and returned None, which only failed later.
        raise ValueError(f"Unknown sample_type: {self.sample_type!r}")

    def perturb_action(self):
        """Return ``n_samples`` noisy copies of the plan, clipped to the action limits.

        Returns:
            Array of shape ``(n_samples, horizon, act_dim)``.

        Raises:
            ValueError: If ``sample_type`` is neither ``'normal'`` nor ``'cubic'``.
        """
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            actions = self.trajectory + self.generate_noise(size)
        elif self.sample_type == 'cubic':
            # Knots are spread evenly from step 0 to step horizon - 1.
            indices_float = jnp.linspace(0, self.horizon - 1, num=self.n_knots)
            indices = jnp.round(indices_float).astype(int)
            size = (self.n_samples, self.n_knots, self.act_dim)
            # Perturb the plan only at the knots, then interpolate the full horizon.
            knot_points = self.trajectory[indices] + self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            actions = cubic_spline(np.arange(self.horizon))
        else:
            # Previously fell through and returned None, which only failed later.
            raise ValueError(f"Unknown sample_type: {self.sample_type!r}")
        return np.clip(actions, self.act_min, self.act_max)

    def update(self, obs):
        """Run one planning step and return the first action of the new plan.

        Args:
            obs: Current state vector; a leading 0 is prepended before it is used
                as the rollout initial state (presumably the time entry of the
                MuJoCo full-physics state -- TODO confirm).

        Returns:
            Array of shape ``(act_dim,)``.
        """
        actions = self.perturb_action()
        self.rollout_func(self.state_rollouts, actions, np.repeat([np.concatenate([[0],obs])], self.n_samples, axis=0), num_workers=self.num_workers, nstep=self.horizon)
        # Drop the first state column to match the layout the cost expects.
        costs = self.cost_func(self.state_rollouts[:,:,1:], actions,
                               self.gait_scheduler.indices[:self.horizon],
                               self.Q, self.R, self.gait_scheduler.gait, self.body_ref)

        self.gait_scheduler.roll()
        # Sum over the horizon axis -> one total cost per sample.
        costs_sum = np.asarray(costs.sum(axis=0))

        # MPPI weights calculation
        ## Scale costs to [0, 1] before exponentiating
        min_cost = np.min(costs_sum)
        max_cost = np.max(costs_sum)
        cost_range = max_cost - min_cost
        if cost_range == 0:
            # All samples cost the same: dividing by the zero range would turn
            # every weight (and the returned action) into NaN. Weight uniformly.
            normalized_costs = np.zeros_like(costs_sum)
        else:
            normalized_costs = (costs_sum - min_cost) / cost_range

        exp_weights = np.exp(-1/self.temperature * normalized_costs)
        weighted_delta_u = exp_weights.reshape(self.n_samples, 1, 1) * actions
        weighted_delta_u = np.sum(weighted_delta_u, axis=0) / (np.sum(exp_weights) + 1e-10)
        updated_actions = np.clip(weighted_delta_u, self.act_min, self.act_max)

        # Pop out first action from the trajectory and repeat last action
        self.trajectory = np.roll(updated_actions, shift=-1, axis=0)
        self.trajectory[-1] = updated_actions[-1]

        # Output first action (MPC)
        action = updated_actions[0]
        return action

    def thread_initializer(self):
        """Create this worker thread's private ``MjData``."""
        self.thread_local.data = mujoco.MjData(self.model)

    def call_rollout(self, initial_state, ctrl, state):
        """Roll out one chunk of samples, writing the states into ``state`` in place.

        The number of rollouts and steps is taken from ``state.shape``.
        """
        rollout.rollout(self.model, self.thread_local.data, skip_checks=True,
                        nroll=state.shape[0], nstep=state.shape[1],
                        initial_state=initial_state, control=ctrl, state=state)

    def threaded_rollout(self, state, ctrl, initial_state, num_workers=32, nstep=5):
        """Split the samples into ``num_workers`` chunks and roll them out in parallel.

        Args:
            state: Output buffer, sliced per chunk and filled in place.
            ctrl: Control sequences, first axis aligned with ``initial_state``.
            initial_state: One initial state per sample.
            num_workers: Number of chunks and of pool threads.
            nstep: Not used here; ``call_rollout`` takes the step count from
                ``state.shape[1]``.
        """
        n = initial_state.shape[0] // num_workers  # integer division

        chunks = []  # a list of tuples, one per worker
        for i in range(num_workers-1):
            chunks.append(
                (initial_state[i*n:(i+1)*n], ctrl[i*n:(i+1)*n], state[i*n:(i+1)*n]))

        # Last chunk, absorbing the remainder:
        chunks.append(
            (initial_state[(num_workers-1)*n:], ctrl[(num_workers-1)*n:],
                state[(num_workers-1)*n:]))

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=num_workers, initializer=self.thread_initializer) as executor:
            futures = []
            for chunk in chunks:
                futures.append(executor.submit(self.call_rollout, *chunk))
            # result() re-raises any exception raised inside a worker.
            for future in concurrent.futures.as_completed(futures):
                future.result()

    def quaternion_distance(self, q1, q2):
        """Return ``1 - |q1 . q2|``."""
        return 1 - jnp.abs(jnp.dot(q1,q2))

    def quadruped_cost(self, x, u, phase_time, Q, R, joints_gait, body_ref):
        """Quadratic tracking cost for a single state/action pair.

        Args:
            x: State vector (37 entries given the reference built below).
            u: Action vector; compared against ``x[7:19]``.
            phase_time: Column index into ``joints_gait``.
            Q: State-error weight matrix.
            R: Control-error weight matrix.
            joints_gait: Gait table; a column supplies 12 entries placed after the
                body pose and the remaining entries placed after the body velocity.
            body_ref: 13-entry body reference (7 pose entries, then 6 velocity entries).

        Returns:
            Scalar cost ``x_err^T Q x_err + u_err^T R u_err``.
        """
        # Gain applied to the action/joint-position mismatch. A derivative term
        # with gain kd = 3 on the error of x[25:] was disabled in the original.
        kp = 30

        joints_ref = joints_gait[:, phase_time]
        x_ref = jnp.concatenate([body_ref[:7], joints_ref[:12], body_ref[7:], joints_ref[12:]])

        # Compute the error terms
        x_error = x - x_ref

        # The scalar quaternion distance is broadcast into all four quaternion
        # slots, so it enters the cost once per slot.
        x_error = x_error.at[3:7].set(self.quaternion_distance(x[3:7], x_ref[3:7]))
        u_error = kp*(u - x[7:19])
        # Compute the cost
        cost = jnp.dot(x_error, jnp.dot(Q, x_error)) + jnp.dot(u_error, jnp.dot(R, u_error))
        return cost

In [105]:
# Build the controller with its default model/config paths; the constructor also
# runs one warm-up update() before resetting the plan.
mppi = MPPI()

In [106]:
# Draw one batch of noisy, clipped candidate action sequences around the current plan.
actions = mppi.perturb_action()

In [107]:
# Use the configured reference state as the observation for this manual replay.
obs = mppi.x_ref

In [108]:
# Simulate every candidate from the same initial state (obs with a leading 0
# prepended); the resulting states are written in place into mppi.state_rollouts.
mppi.rollout_func(mppi.state_rollouts, actions, np.repeat([np.concatenate([[0],obs])], mppi.n_samples, axis=0), num_workers=mppi.num_workers, nstep=mppi.horizon)

In [109]:
# Advance the gait phase by one step.
# NOTE(review): MPPI.update() rolls the scheduler AFTER evaluating the cost, while
# this replay rolls BEFORE the cost cell below -- the phase used here is one step
# ahead of what update() would use; confirm this is intended.
mppi.gait_scheduler.roll()

In [110]:
# Drop the first state column (the slot filled by the 0 prepended to the initial
# state) so the remaining layout matches what the cost function indexes.
states = mppi.state_rollouts[:,:,1:]

In [111]:
# Evaluate the jitted JAX cost for every (horizon step, sample) pair, using the
# next `horizon` gait columns as joint references.
costs = mppi.cost_func(states, actions, 
                       mppi.gait_scheduler.indices[:mppi.horizon],
                       mppi.Q, mppi.R, mppi.gait_scheduler.gait, mppi.body_ref)

In [112]:
# Sum over axis 0 (the horizon axis of the vmapped cost) -> one total cost per sample.
costs_sum = costs.sum(axis=0)
costs_sum

Array([36855.96 , 35630.77 , 37493.094, 36326.848, 37771.19 , 36509.465,
       37507.91 , 37759.492, 36800.926, 36881.44 , 37271.46 , 37890.79 ,
       36977.965, 37190.555, 36639.04 , 36238.504, 36328.05 , 37099.324,
       35415.598, 36718.27 , 36676.65 , 37018.465, 37894.37 , 37299.3  ,
       37787.516, 37765.56 , 36597.19 , 36721.27 , 37220.418, 36296.73 ],      dtype=float32)

In [113]:
def quaternion_distance_np(q1, q2):
    """Return ``1 - |q1 . q2|`` using NumPy.

    ``q1`` may be a single vector or a 2-D batch of row vectors; the dot product
    with ``q2`` then yields one distance per row.
    """
    alignment = np.abs(np.dot(q1, q2))
    return 1 - alignment

def quadruped_cost_np(x, u, x_ref, Q, R, quat_ref):
    """Batched NumPy version of the quadratic tracking cost.

    Args:
        x: States, one row per (sample, step) pair.
        u: Actions, one row per (sample, step) pair; compared against ``x[:, 7:19]``.
        x_ref: Reference states with the same shape as ``x``.
        Q: State-error weight matrix.
        R: Control-error weight matrix.
        quat_ref: Reference quaternion compared against ``x[:, 3:7]``.

    Returns:
        1-D array with one cost per row.
    """
    # Gain on the action/joint-position mismatch (a kd = 3 derivative term on
    # x[:, 25:] is disabled in this version).
    kp = 30

    state_err = x - x_ref

    # The per-row quaternion distance replaces all four quaternion error columns.
    quat_err = quaternion_distance_np(x[:, 3:7], quat_ref)
    for column in range(3, 7):
        state_err[:, column] = quat_err

    ctrl_err = kp * (u - x[:, 7:19])

    # Row-wise quadratic forms: err_i^T W err_i for every row i.
    state_cost = np.einsum('ij,ik,jk->i', state_err, state_err, Q)
    ctrl_cost = np.einsum('ij,ik,jk->i', ctrl_err, ctrl_err, R)
    return state_cost + ctrl_cost

def calculate_total_cost(states, actions, phase_time, joints_gait, body_ref, Q=None, R=None):
    """NumPy cross-check of the per-sample total rollout cost.

    Args:
        states: Array of shape ``(num_samples, num_steps, state_dim)``.
        actions: Array of shape ``(num_samples, num_steps, act_dim)``.
        phase_time: Gait column index for each of the ``num_steps`` steps.
        joints_gait: Gait table; columns are selected by ``phase_time``.
        body_ref: 1-D body reference (7 pose entries followed by velocity entries).
        Q: State-error weight matrix; defaults to the global ``mppi.Q``.
        R: Control-error weight matrix; defaults to the global ``mppi.R``.

    Returns:
        1-D array with one summed cost per sample, rounded to 2 decimals.

    Raises:
        ValueError: If ``phase_time`` does not have one entry per step.
    """
    if Q is None:
        Q = mppi.Q
    if R is None:
        R = mppi.R

    num_samples = states.shape[0]
    num_pairs = states.shape[1]

    phase_idx = np.asarray(phase_time)
    if phase_idx.size != num_pairs:
        raise ValueError(
            f"phase_time has {phase_idx.size} entries but states has {num_pairs} steps")

    # Flatten states and actions to two dimensions, treating all pairs per sample as a batch
    states = states.reshape(-1, states.shape[2])
    actions = actions.reshape(-1, actions.shape[2])
    num_rows = states.shape[0]  # num_samples * num_pairs

    # One joint-reference row per step, repeated for every sample in the same
    # sample-major order as the flattened states. The sizes are derived from the
    # inputs instead of the previously hard-coded 30 samples / 1200 rows.
    joints_ref = np.asarray(joints_gait)[:, phase_idx].T
    joints_ref = np.tile(joints_ref, (num_samples, 1))

    body_ref = np.asarray(body_ref)
    # Take the reference quaternion from the body_ref argument (as the JAX cost
    # does) rather than from the global controller.
    quat_ref = body_ref[3:7]
    body_ref = np.repeat(body_ref[np.newaxis, :], num_rows, axis=0)

    x_ref = np.concatenate([body_ref[:,:7], joints_ref[:,:12], body_ref[:,7:], joints_ref[:,12:]], axis=1)

    # Compute batch costs
    costs = quadruped_cost_np(states, actions, x_ref, Q, R, quat_ref)
    # Sum costs for each sample
    total_costs = costs.reshape(num_samples, num_pairs).sum(axis=1)

    return total_costs.round(2)

In [114]:
# Recompute the per-sample totals with the NumPy implementation to compare
# against the JAX result (costs_sum) above.
costs_np = calculate_total_cost(states, actions, 
                     mppi.gait_scheduler.indices[:mppi.horizon], 
                     mppi.gait_scheduler.gait, mppi.body_ref)
costs_np

array([36856.15, 35630.97, 37493.29, 36327.05, 37771.41, 36509.65,
       37508.12, 37759.69, 36801.14, 36881.65, 37271.66, 37890.99,
       36978.17, 37190.76, 36639.24, 36238.7 , 36328.24, 37099.54,
       35415.78, 36718.47, 36676.86, 37018.66, 37894.58, 37299.51,
       37787.73, 37765.76, 36597.39, 36721.47, 37220.61, 36296.92])

In [115]:
costs_np.shape

(30,)

In [116]:
# Select which cost vector (here the NumPy cross-check) feeds the weight computation below.
cost_selc = costs_np

In [117]:
# Replay of the weight computation in MPPI.update(): scale the costs to [0, 1],
# exponentiate, and average the candidate actions with those weights.
min_cost = np.min(cost_selc)
max_cost = np.max(cost_selc)
cost_range = max_cost - min_cost

if cost_range == 0:
    # Identical costs would make the division 0/0 and every weight NaN;
    # fall back to uniform weights.
    normalized_costs = np.zeros_like(cost_selc)
else:
    normalized_costs = (cost_selc - min_cost) / cost_range

exp_weights = np.exp(-1/mppi.temperature * normalized_costs)
weighted_delta_u = exp_weights.reshape(mppi.n_samples, 1, 1) * actions
weighted_delta_u = np.sum(weighted_delta_u, axis=0) / (np.sum(exp_weights) + 1e-10)

In [118]:
# Clip the weighted-average plan to the per-joint action limits.
updated_actions = np.clip(weighted_delta_u, np.array(mppi.act_min), np.array(mppi.act_max))

In [119]:
# Shift the plan one step forward: np.roll moves the first action to the end,
# and that last slot is then overwritten with the final planned action.
# This writes to a local `trajectory`; mppi.trajectory is not modified here.
trajectory = np.roll(updated_actions, shift=-1, axis=0)
trajectory[-1] = updated_actions[-1]

# The first action of the updated plan is the one that would be applied (MPC).
action = updated_actions[0]  

In [120]:
action

array([ 0.08408074,  1.31533897, -2.6558156 ,  0.07340127,  1.40455151,
       -2.80647206,  0.09779222,  1.35396206, -2.70680928,  0.07467818,
        1.36480606, -2.79702997])

In [121]:
mppi.body_ref

Array([0.  , 0.  , 0.27, 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  ], dtype=float32)

In [122]:
action

array([ 0.08408074,  1.31533897, -2.6558156 ,  0.07340127,  1.40455151,
       -2.80647206,  0.09779222,  1.35396206, -2.70680928,  0.07467818,
        1.36480606, -2.79702997])

In [94]:
# Full gait table held by the scheduler (one column per phase step).
joints_gait = mppi.gait_scheduler.gait

In [95]:
joints_gait[1]

array([0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       1.0149, 1.0958, 1.148 , 1.1736, 1.1736, 1.148 , 1.0958, 1.0149,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   ])

In [96]:
# Gait column indices for the next `horizon` steps, starting at the current phase.
phase_time = mppi.gait_scheduler.indices[:mppi.horizon]

In [97]:
# Select the gait columns for the planning horizon.
joints_ref = joints_gait[:, phase_time]

In [98]:
joints_ref[1]

array([0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 1.0149, 1.0958,
       1.148 , 1.1736, 1.1736, 1.148 , 1.0958, 1.0149, 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ,
       0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   , 0.9   ])

In [None]:
# Inspect the controller's body reference. A bare `body_ref` is not defined at
# this point on a fresh kernel (it is only assigned in the next cell).
mppi.body_ref

In [None]:
# Rebuild the flattened per-(sample, step) reference arrays that
# calculate_total_cost constructs internally. Sizes come from the controller
# configuration instead of hard-coded 30 samples / 1200 rows.
n_rollout_rows = mppi.n_samples * mppi.horizon
joints_ref = joints_gait[:, phase_time]
joints_ref = joints_ref.T
joints_ref = np.tile(joints_ref, (mppi.n_samples, 1, 1))
joints_ref = joints_ref.reshape(-1, joints_ref.shape[2])
body_ref = mppi.body_ref
body_ref = np.repeat(body_ref[np.newaxis,:], n_rollout_rows, axis=0)

In [None]:
# Interleave body and joint references into the state layout used by the cost:
# body[:7], joints[:12], body[7:], joints[12:].
x_ref = np.concatenate([body_ref[:,:7], joints_ref[:,:12], body_ref[:,7:], joints_ref[:,12:]], axis=1)

In [None]:
body_ref.shape

In [None]:
x_ref.shape

In [None]:
states.shape

In [None]:
# Flatten the (sample, step) axes into one batch axis, keeping the last axis.
states_r = states.reshape(-1, states.shape[2])
actions_r = actions.reshape(-1, actions.shape[2])

In [None]:
states_r.shape

In [None]:
# Step-by-step rebuild of joints_ref, 1/4: select the horizon's gait columns.
joints_ref = joints_gait[:, phase_time]

In [None]:
# 2/4: transpose so that rows are horizon steps.
# NOTE: this cell overwrites its own input, so re-running it alone transposes back.
joints_ref = joints_ref.T

In [None]:
joints_ref.shape

In [None]:
# 3/4: repeat the per-step references once per sample; the count comes from the
# controller instead of a hard-coded 30.
# NOTE: this cell overwrites its own input, so re-running it alone tiles again.
joints_ref = np.tile(joints_ref, (mppi.n_samples, 1, 1))

In [None]:
# 4/4: flatten the (sample, step) axes into one batch axis.
joints_ref = joints_ref.reshape(-1, joints_ref.shape[2])

In [None]:
joints_ref.shape