In [1]:
import os
import yaml
import numpy as np
from scipy.interpolate import CubicSpline
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

In [2]:
from functools import partial
import chex

In [3]:
cpu_device = jax.devices('cpu')[0]
cpu_device

CpuDevice(id=0)

In [4]:
gpu_device = jax.devices('gpu')[0]
gpu_device

cuda(id=0)

In [5]:
#mj_model = mujoco.MjModel.from_xml_path("../models/go1/go1_scene_jax_no_collision.xml")

In [6]:
def load_rollout_jax(step_fn):
    def rollout_aux(obs, actions):
        carry = (obs)
        _, output = jax.lax.scan(f=step_fn, init=carry, xs=actions)
        return output
    func = jax.jit(jax.vmap(rollout_aux, in_axes=(None, 0)))
    return func

In [9]:
class MPPI_JAX:
    def __init__(self, 
                 model_path = "../models/go1/go1_scene_jax_no_collision.xml",
                 config_path="configs/mppi.yml") -> None:
        # load the configuration file
        with open(config_path, 'r') as file:
            params = yaml.safe_load(file)

        # load model
        self.model = mujoco.MjModel.from_xml_path(model_path)
        #self.model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        #self.model.opt.iterations = 6
        #self.model.opt.ls_iterations = 6
    
        self.mjx_model = mjx.device_put(self.model)
        self.mjx_data = mjx.make_data(self.mjx_model)
        self.model.opt.timestep = params['dt']

        # mppi controller configuration
        self.temperature = params['lambda']
        self.horizon = params['horizon']
        self.n_samples = params['n_samples']
        self.noise_sigma = jnp.array(params['noise_sigma'])
        self.num_workers = params['n_workers']
        self.sampling_init = jnp.array([0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ,  
                                        0.073,  1.34 , -2.83 ])
        
        # Cost
        self.Q = jnp.diag(jnp.array(params['Q_diag']))
        self.R = jnp.diag(jnp.array(params['R_diag']))
        self.x_ref = jnp.concatenate([jnp.array(params['q_ref']), jnp.array(params['v_ref'])])

        # Get env parameters
        self.act_dim = 12
        self.act_max = jnp.array([0.863, 4.501, -0.888]*4)
        self.act_min = jnp.array([-0.863, -0.686, -2.818]*4)
        
        # Rollouts
        self.h = params['dt']
        self.sample_type = params['sample_type']
        self.n_knots = params['n_knots']
        self.rollout_func = jax.jit(self.rollout_jax(), device=gpu_device)
        self.random_generator = np.random.default_rng(params["seed"])
        
        self.trajectory = None
        self.reset_planner() 
        self.update(self.x_ref)
        self.reset_planner()     
    
    def rollout_jax(self):
        def step_wrapper_mujoco(carry, action):
            obs = carry
            data = mjx.make_data(self.mjx_model)
            data = data.replace(qpos=obs.qpos, qvel=obs.qvel, ctrl=action)
            data = mjx.step(self.mjx_model, data)
            
            next_obs = jnp.concatenate([data.qpos, data.qvel])
            cost = self.quadruped_cost(next_obs, action)
            carry = data
            output = (next_obs, cost)
            return carry, output
        return load_rollout_jax(step_wrapper_mujoco)
    
    def reset_planner(self):
        self.trajectory = np.zeros((self.horizon, self.act_dim))
        self.trajectory += self.sampling_init
            
    def generate_noise(self, size):
        return self.random_generator.normal(size=size) * self.noise_sigma
    
    def sample_delta_u(self):
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            return self.generate_noise(size)
        elif self.sample_type == 'cubic':
            indices = np.arange(self.n_knots)*self.horizon//self.n_knots
            size = (self.n_samples, self.n_knots, self.act_dim)
            knot_points = self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            return cubic_spline(np.arange(self.horizon))
        
    def perturb_action(self):
        if self.sample_type == 'normal':
            size = (self.n_samples, self.horizon, self.act_dim)
            actions = self.trajectory + self.generate_noise(size)
            actions = np.clip(actions, self.act_min, self.act_max)
            return actions
        
        elif self.sample_type == 'cubic':
            indices_float = jnp.linspace(0, self.horizon - 1, num=self.n_knots)
            indices = jnp.round(indices_float).astype(int)
            size = (self.n_samples, self.n_knots, self.act_dim)
            knot_points = self.trajectory[indices] + self.generate_noise(size)
            cubic_spline = CubicSpline(indices, knot_points, axis=1)
            actions = cubic_spline(np.arange(self.horizon))
            actions = np.clip(actions, self.act_min, self.act_max)
            return actions
        
    def update(self, obs): 
        self.mjx_data = self.mjx_data.replace(qpos=obs[:19], qvel=obs[19:])
        actions = jnp.array(self.perturb_action())
        #self.rollout_func(self.state_rollouts, actions, np.repeat([np.concatenate([[0],obs])], self.n_samples, axis=0), num_workers=self.num_workers, nstep=self.horizon)
        _, costs = self.rollout_func(self.mjx_data, actions)
        costs_sum = costs.sum(axis=1)
        #print(costs_sum)
        costs_sum = jnp.where(jnp.isnan(costs_sum), 10000000, costs_sum)
        print(costs_sum)
        # MPPI weights calculation
        ## Scale parameters
        min_cost = np.min(costs_sum)
        max_cost = np.max(costs_sum)
        
        exp_weights = jnp.exp(-1/self.temperature * ((costs_sum - min_cost)/(max_cost - min_cost)))
        #print(exp_weights)
        weighted_delta_u = exp_weights.reshape(self.n_samples, 1, 1) * actions
        weighted_delta_u = jnp.sum(weighted_delta_u, axis=0) / (jnp.sum(exp_weights) + 1e-10)
        updated_actions = jnp.clip(weighted_delta_u, self.act_min, self.act_max)
    
        # Pop out first action from the trajectory and repeat last action
        self.trajectory = jnp.roll(updated_actions, shift=-1, axis=0)
        #self.trajectory[-1] = updated_actions[-1]
        self.trajectory = self.trajectory.at[-1].set(updated_actions[-1])

        # Output first action (MPC)
        action = updated_actions[0] 
        return action
    
    def quaternion_distance(self, q1, q2):
        return 1 - jnp.abs(jnp.dot(q1,q2))
    
    def quadruped_cost(self, x, u):
        kp = 40
        kd = 3
        # Compute the error terms
        x_error = x - self.x_ref
        # Assuming quaternion_distance is a function you've defined elsewhere
        x_error = x_error.at[3:7].set(self.quaternion_distance(x[3:7], self.x_ref[3:7]))
        u_error = kp*(u - x[7:19]) - kd*x[25:]

        # Compute the cost
        cost = jnp.dot(x_error, jnp.dot(self.Q, x_error)) + jnp.dot(u_error, jnp.dot(self.R, u_error))
        return cost

In [10]:
mppi = MPPI_JAX()

[5755.0044 5612.053  5729.599  5635.6353 5629.41   5794.2524 5666.199
 5463.289  5457.498  5324.941  5482.0264 5707.7344 5788.5176 5702.5234
 5535.9243 5469.3076 5528.4365 5650.5728 5709.273  5655.968  5716.8506
 5806.6113 5499.584  5775.3623 5394.427  5604.5566 5386.8857 5561.8193
 5538.831  5849.9697 5639.3154 5299.9834 5730.534  5379.0425 5723.2266
 5509.3506 5678.961  5836.605  5597.151  5601.8955 5681.7856 5793.5635
 5642.7627 5563.2314 5547.4863 5436.945  5409.2188 5503.5474 5138.695
 5529.458  5442.4814 5604.9688 5716.3877 5521.6943 5823.8203 5850.446
 5569.4316 5456.9355 5682.017  5414.083  5727.106  5730.038  5514.164
 5741.2573 5685.5283 5784.166  5639.78   5512.339  5664.751  5722.5444
 5786.751  5399.8926 5799.9023 5690.748  5530.708  5754.835  5713.2354
 5486.162  5648.7754 5643.619  5559.2144 5276.6416 5776.8535 5550.7837
 5709.927  5490.357  5704.3823 5726.0864 5826.467  5483.064  5814.756
 5614.274  5659.934  5684.5234 5365.405  5531.5654 5728.169  5713.175
 5683.9297 5

In [11]:
actions = mppi.perturb_action()

In [12]:
actions.shape

(100, 40, 12)

In [13]:
%%timeit
mppi.update(mppi.x_ref)

[5621.57   5612.879  5810.8096 5543.3574 5748.824  5635.465  5658.918
 5682.5684 5464.7676 5595.5586 5631.352  5877.8965 5804.311  5689.2227
 5568.3467 5716.0195 5580.968  5745.2983 5755.29   5557.9297 5624.801
 5576.0186 5802.3306 5599.1885 5279.2476 5551.431  5676.8037 5628.706
 5481.374  5752.319  5685.6724 5712.2485 5798.0303 5581.219  5678.09
 5586.778  5552.1943 5685.5303 5712.353  5624.611  5609.461  5434.927
 5691.3066 5626.205  5254.843  5681.284  5623.276  5577.6426 5687.3325
 5702.8486 5587.2246 5537.1216 5705.6704 5628.746  5709.2676 5786.259
 5779.007  5711.1562 5724.8154 5813.979  5694.241  5879.846  5470.8057
 5749.015  5753.082  5671.6133 5455.1904 5682.1245 5445.628  5619.025
 5751.874  5802.785  5654.4814 5593.7246 5753.1006 5714.251  5751.5137
 5670.784  5551.373  5521.271  5675.748  5565.4053 5756.948  5529.6904
 5569.1895 5670.8174 5725.73   5785.1333 5784.159  5686.909  5727.1973
 5792.213  5647.38   5790.763  5853.154  5775.8115 5467.7305 5635.218
 5484.636  5441

[1791.3896 1541.1484 1794.8928 1598.5979 1755.691  1700.3197 1700.3032
 1847.7329 1819.1113 1528.0872 1673.3123 1973.6049 1824.3586 1671.8527
 1857.5479 1798.0701 1760.0032 1869.3186 1937.6985 1528.0408 1742.7758
 1669.7302 1687.267  1863.4363 2009.3425 1546.1162 1804.4188 1728.8066
 1602.5554 1700.2446 1727.8342 1771.3965 1790.8958 1864.2256 1761.5481
 1704.0503 1739.5978 1674.3845 1776.6768 1618.6646 1658.0942 1708.6154
 1862.7118 1563.8138 1821.7926 1585.0331 1834.3997 1811.1038 1694.6194
 1684.4996 1756.9138 1684.3018 1754.5732 1816.552  1726.782  1701.7083
 1648.2498 1714.4052 1762.8116 1930.689  1762.2184 1570.0249 1601.9197
 1700.6775 1688.1515 1832.1089 1853.7026 1898.8015 1726.6632 1597.5355
 1752.2146 1775.8428 1696.8799 1608.5049 1642.0435 1684.2085 1871.6025
 1990.3278 1685.9841 1671.1963 1660.3984 1630.3773 1508.265  1551.0847
 1705.053  1765.3738 1674.6747 1736.5773 1762.268  1764.7146 1762.3828
 1535.764  1838.0552 1638.9016 1488.522  1585.1631 1555.9453 1598.4524
 1414.

[585.6542  537.7245  473.66446 556.34827 469.6872  690.3432  583.4259
 526.0316  566.04877 685.8231  486.42307 505.19858 559.69086 522.2412
 609.3562  624.7559  638.06165 551.1677  502.24927 504.52368 587.3807
 636.1343  533.3473  654.208   533.44855 558.8708  608.2573  706.20544
 688.12366 594.17004 688.7191  530.4713  522.7029  542.3462  539.4076
 569.51465 495.51276 613.5262  523.2072  509.23932 551.0161  496.71494
 595.2993  459.59116 599.6782  636.6619  570.9723  557.57544 613.5465
 584.58575 555.2263  546.3635  585.59033 568.2394  505.0704  522.47327
 564.42786 602.5681  561.8084  543.7338  540.8853  628.70685 508.15784
 594.4619  598.264   560.90985 570.1887  625.9621  526.1769  606.1743
 617.45184 503.70496 703.6937  605.02856 668.143   535.7136  543.5919
 513.83673 589.2575  566.42847 593.823   604.8624  670.4485  633.2678
 620.2528  554.5366  623.08435 552.48975 597.3753  613.81805 676.328
 681.5969  655.4558  557.7605  586.4234  536.3236  555.5908  643.16815
 528.042   528.2

[268.3079  211.10101 234.22821 228.44275 239.61897 257.14767 257.56436
 275.55554 206.57599 241.0956  278.84015 199.80453 262.76852 317.35843
 321.16745 252.4335  229.80698 280.84494 230.29309 248.74664 249.6903
 248.05472 297.4812  208.90923 261.17706 239.0585  272.5951  328.53375
 294.35797 232.48132 294.00415 260.01953 283.03003 244.18149 298.82092
 206.52936 274.4236  268.0188  297.9643  261.26434 209.21155 291.99466
 229.24632 295.0752  237.63867 221.73418 216.8401  207.42648 273.09644
 256.66803 266.0456  276.03894 222.97421 258.67346 229.36818 292.00854
 256.5147  295.96228 221.78705 213.2287  216.86209 265.91788 282.67465
 270.1658  296.94812 243.89798 249.35904 282.08826 269.52655 268.60333
 268.71674 208.8226  269.65472 303.89813 250.22339 248.78091 248.3724
 216.84387 251.60527 287.94962 271.381   248.44467 228.54214 271.99084
 263.48532 248.48944 252.58847 208.33789 235.16664 190.14177 233.6275
 249.52936 269.92322 272.50604 251.69261 265.90668 233.57205 270.55835
 258.5369

[173.1932  182.19753 188.68121 221.21239 175.08574 152.77576 203.1826
 198.5502  203.2547  225.89577 185.50903 215.27786 215.61198 212.02032
 214.37564 171.36685 234.75597 210.73402 229.95831 174.91096 254.80116
 205.09656 172.54825 237.59804 239.31927 210.46423 188.8183  171.80148
 213.58435 176.40839 208.4572  192.22766 160.4448  257.00967 179.42664
 207.55264 199.59938 229.09523 212.9964  206.27545 187.54388 184.66553
 187.93204 192.92978 213.31638 249.66577 210.80945 180.68617 203.5995
 175.37985 243.45593 190.0785  167.41174 196.5618  197.92471 203.71698
 212.05533 192.35992 179.82745 255.56372 271.9878  206.59674 164.59622
 199.00621 196.85309 188.5262  225.09674 202.51257 251.46252 185.40338
 175.37485 185.98593 209.58517 225.68683 217.35825 208.94775 246.63626
 201.08081 195.42607 195.98616 181.42084 177.46454 199.60165 177.7212
 178.30212 226.3558  222.80238 207.00772 197.30574 190.9358  250.47287
 200.61647 203.46693 213.09839 245.31487 204.98691 211.54817 177.38513
 179.7243

[140.34596 156.06436 207.3523  157.13828 162.50085 144.47574 153.31427
 132.0105  157.0281  146.98996 138.76248 221.83795 171.0485  142.76555
 238.8874  160.49207 145.4094  185.98979 118.26272 131.38208 132.56651
 152.99194 131.81522 144.05112 155.01607 191.00291 193.65958 158.24477
 144.53557 152.59784 145.66208 132.59894 166.66582 142.66324 160.31433
 148.97792 137.06741 123.95628 147.0737  147.51616 152.07736 146.51732
 132.96222 132.05267 153.97275 141.64594 146.28432 155.7799  195.35818
 158.90616 176.24896 157.33269 147.23889 132.93427 150.53232 138.85944
 150.32281 146.01384 158.34631 164.1086  148.45483 182.3988  131.42116
 203.0475  159.89996 145.60689 193.19278 144.48868 180.142   139.98233
 191.09998 150.50282 172.81064 158.00754 138.55551 157.05527 152.78989
 143.74217 189.38766 150.11768 152.7397  146.39365 201.08778 164.42694
 180.78116 142.97577 149.61229 129.19063 174.23805 154.11429 151.37729
 167.29193 160.04639 119.86655 190.268   172.10364 162.72313 157.13725
 153.8

[116.952896 130.80289   98.58923   98.45128  216.20828  146.54738
 148.45253  111.00552  141.45984  127.70859  123.68995  126.09367
 145.1695    96.10728  144.9894   128.84135  153.12161   99.99641
 113.50164  106.91638  114.39163  105.66717  139.9016   114.77598
 147.07011  143.9357   109.29763  189.69025  159.87848  130.97992
 159.26727  140.71555  131.29248  133.17725  133.63428  125.796875
 124.58899  130.02054  175.3464   158.1265   136.74287  150.65103
 148.57066  113.80167  137.27301  112.1546   131.2829   134.48703
  96.90603  100.69676  113.55762  119.857315 137.41733  100.03863
 148.7539   153.4611   160.81589  140.66663   98.423325 133.64867
 104.42903  138.45906  175.03331   97.71896  137.80603  114.043
 123.261154 125.36328  140.59015  114.49893  114.428894  95.725716
 154.02927  107.49891  115.451385 125.64086  128.54175  127.90177
 113.67996  146.21442  150.13287  169.22827  144.01743  124.21576
 134.05544  114.99161   94.728745 152.59506  109.58989  114.165985
 134.6968

[122.515884 113.19801  120.586945  99.7771    99.7021   102.46915
  96.64566  131.61119  121.87875  113.999084 148.31871  114.54704
  98.55452  119.226746  93.20596  120.45735  130.61647  129.67377
 101.17589  146.56514  114.22969  109.9777   121.07597   98.990486
 133.29852  129.02887  133.39503  110.90906  107.74786  128.46863
  98.5519   141.41945  111.72896  113.54021  115.20438  121.33074
  97.386665 142.3165    98.07599  137.8879   106.06935  124.48489
 147.69429  134.05005  115.32524   97.80698  103.84916  101.48296
 140.13977  102.11631  103.23014  139.04977   94.28691  150.84122
 139.82416  138.56375  118.46147  128.7652   126.65037  119.69483
 123.913315 186.40187  102.37739  154.68188  153.87448  140.27402
 111.98705  136.6536   129.2216   168.98816  117.51078  102.558105
 159.45386  158.69872  127.430115 120.57718  196.37195   97.56033
  96.73328  138.58588  113.20806  119.43838  112.25016  121.31328
 116.812294 139.84454  138.55222  210.3527   114.488    140.525
 134.23337

[158.35309  142.4805   132.34715  127.93594  164.9069   149.88013
 130.55255  149.43575  141.7776   163.89886  158.61977  144.7987
 150.05225  136.34648  126.000145 166.59366  139.75365   99.24781
 177.17068   94.69651  131.05884  153.08932  158.24866  192.56456
 155.98483  155.69641  128.44781  137.87549  137.24667  112.87265
 168.63164  118.75335  136.15086  107.78511  146.77182  150.01282
 162.52075  150.24445  142.18752  133.67534  122.06726  153.07693
 128.05849  162.57545  173.54803  128.93457  148.79163  176.02538
 138.98529  129.4041   126.4048   197.62483  102.26553  135.67346
 130.68527  193.1041   137.80655  126.54107  112.57379  127.09691
 142.66457  152.09656  146.74359  123.2479   127.27257  159.16869
 139.9605   110.16148  132.21434  138.11575  122.57822  130.17174
 133.95242  131.61707  136.83995  175.89755  128.60559  118.55409
 131.79147  129.90027  114.55659  123.28914  147.71997  159.71988
 135.57117  158.22014  133.09105  186.70477  146.7626   130.85275
 130.9592  

[116.29944  144.9821   153.80904  131.74356  189.85217  186.42865
 129.57028  115.00534  136.30116  168.51724  304.16797  145.71942
 173.84653  122.62685  153.53265  198.15067  139.38889  123.68922
 128.08722  159.50452  135.93675  163.74991  157.332    168.22672
 113.55482  190.12514  148.46541  158.5791   193.12294  177.69637
 184.28326  155.51732  251.15744  142.05342  170.18367  152.04517
 126.16514  146.78178  163.81532  206.54164  158.32623  122.97243
 136.32056  115.23544  141.05588  132.85254  139.95224  153.29971
 124.604095 155.2508   164.88982  183.84497   95.40164  167.61688
 153.33302  166.92807  142.85728  120.683174 134.12619  186.91187
 152.1948   157.64636  124.15433  141.56735  119.40737  165.91064
 131.49176  183.70874  145.73811  133.64722  173.24124  183.26715
 181.13837  126.92894  100.44562  145.36566  128.82074  139.90657
 113.542755 171.71588  174.0469   142.74988  151.00021  155.38205
 154.36105  149.04218  153.73941  132.27339  110.26889  163.55371
 154.78497

In [14]:
obs = mppi.x_ref

In [15]:
actions[0,0].shape

(12,)

In [16]:
mppi.mjx_data = mppi.mjx_data.replace(qpos=obs[:19], qvel=obs[19:], ctrl=actions[0,0])

In [None]:
mppi.mjx_data.qvel

In [None]:
mppi.mjx_data.qvel

In [None]:
costs = mppi.update(mppi.x_ref)

In [None]:
costs

In [None]:
mppi.mjx_data.qpos.shape

In [None]:
mppi.mjx_data.qvel.shape

In [None]:
mppi.x_ref

In [None]:
mppi.mjx_data.qpos

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import mujoco_viewer

In [None]:
import matplotlib.pyplot as plt

In [None]:
import copy as cp

In [None]:
mujoco.mjtSolver.mjSOL_CG

In [None]:
#model_sim = mujoco.MjModel.from_xml_path("../models/go1/go1_scene_jax_no_collision.xml")
model_sim = mujoco.MjModel.from_xml_path("../models/go1/scene_opt_pd.xml")

In [None]:
model_sim.opt.solver

In [None]:
model_sim.opt.iterations

In [None]:
model_sim.opt.ls_iterations

In [None]:
dt_sim = 0.01
model_sim.opt.timestep = dt_sim

data_sim = mujoco.MjData(model_sim)

In [None]:
viewer = mujoco_viewer.MujocoViewer(model_sim, data_sim, 'offscreen')

In [None]:
# reset robot (keyframes are defined in the xml)
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 1) # stand position
mujoco.mj_forward(model_sim, data_sim)
q_init = cp.deepcopy(data_sim.qpos) # save reference pose
v_init = cp.deepcopy(data_sim.qvel) # save reference pose

In [None]:
print("Configuration: {}".format(q_init)) # save reference pose

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
# reset robot (keyframes are defined in the xml)
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 0) # stand position
mujoco.mj_forward(model_sim, data_sim)
q_ref_mj = cp.deepcopy(data_sim.qpos) # save reference pose
v_ref_mj = cp.deepcopy(data_sim.qvel) # save reference pose

In [None]:
print("Configuration: {}".format(q_ref_mj)) # save reference pose

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
q_curr = cp.deepcopy(data_sim.qpos) # save reference pose
v_curr = cp.deepcopy(data_sim.qvel) # save reference pose
x = jnp.concatenate([q_curr, v_curr])

In [None]:
tfinal = 5
tvec = jnp.linspace(0,tfinal,int(jnp.ceil(tfinal/dt_sim))+1)

In [None]:
mujoco.mj_resetDataKeyframe(model_sim, data_sim, 1)
mujoco.mj_forward(model_sim, data_sim)

In [None]:
img = viewer.read_pixels()
plt.imshow(img)

In [None]:
mppi.reset_planner()

In [None]:
mppi.trajectory

In [None]:
%%time
anim_imgs = []
sim_inputs = []
for ticks, ti in enumerate(tvec):
    #if ticks % 1 == 0:
    q_curr = cp.deepcopy(data_sim.qpos) # save reference pose
    v_curr = cp.deepcopy(data_sim.qvel) # save reference pose
    x = jnp.concatenate([q_curr, v_curr])
    u_joints = mppi.update(x)    
    data_sim.ctrl[:] = u_joints
    mujoco.mj_step(model_sim, data_sim)
    mujoco.mj_forward(model_sim, data_sim)
    img = viewer.read_pixels()
    anim_imgs.append(img)
    sim_inputs.append(u_joints)

In [None]:
fig, ax = plt.subplots()
skip_frames = 10
interval = dt_sim*1000*skip_frames

def animate(i):
    ax.clear()
    ax.imshow(anim_imgs[i * skip_frames])  # Display the image, skipping frames
    ax.axis('off')

# Create animation, considering the reduced frame rate due to skipped frames
ani = FuncAnimation(fig, animate, frames=len(anim_imgs) // skip_frames, interval=interval)  # 50 ms for 20 Hz

# Display the animation
HTML(ani.to_jshtml())

In [None]:
#mppi.reset_planner()

In [None]:
actions = jnp.array(mppi.perturb_action())

In [None]:
_, costs = mppi.rollout_func(mppi.mjx_data, actions) 

In [None]:
costs