In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import tqdm
import flax
import time
import deluca.core
import matplotlib.pyplot as plt
from ilqr import iLQR
from deluca.envs import PlanarQuadrotor
from wilhelm.sim.quadrotor import PlanarQuadrotor as wpq
import numpy as np



In [3]:
def cost(x, u, sim):
    """Quadratic tracking cost for a single (state, action) pair.

    Args:
        x: state object; its ``.arr`` attribute holds the state vector.
        u: action vector.
        sim: environment exposing ``goal_state`` and ``goal_action``.

    Returns:
        ``0.1 * ||u - goal_action||^2 + ||x.arr - goal_state||^2``
        (squared norms assume 1-D vectors, where ``@`` is a dot product).
    """
    action_err = u - sim.goal_action
    state_err = x.arr - sim.goal_state
    # Action deviation is down-weighted by 0.1 relative to state deviation.
    return 0.1 * action_err @ action_err + state_err @ state_err

env = PlanarQuadrotor.create()
# Initial action sequence: env.goal_action repeated env.H times along a new leading
# axis (shape (env.H, action_dim), assuming goal_action is 1-D).
U0 = jnp.tile(env.goal_action, (env.H, 1))
# Optional warm start for U0: either from a file, or by running a few iLQR iterations:
# warmup_steps = 8
# X, U, k, K, c = iLQR(env, cost, U0, warmup_steps, verbose=False)

In [4]:
## First comparison - simple one-step env evaluation (deluca/JAX vs. wilhelm/numpy)
# Timing methodology:
# - JAX dispatches work asynchronously, so we block on every array leaf of the
#   returned state before stopping the clock; otherwise only dispatch is timed.
# - time.perf_counter() is the high-resolution monotonic clock meant for intervals.
# - The first iteration of each loop is noticeably slower in the recorded output
#   (presumably one-time warm-up); compare the later iterations.
state = env.init()
for i in range(5):
    t0 = time.perf_counter()
    state = env(state, U0[0])
    for leaf in jax.tree_util.tree_leaves(state):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    print("Time", time.perf_counter() - t0)

## Comparing to wilhelm
wenv = wpq()
wenv.reset()
# Separate name for wilhelm's actions so that U0 (the JAX action sequence from the
# setup cell) keeps its meaning for any later cell that uses it.
wU0 = np.tile(np.zeros_like(wenv.goal_action), (wenv.H, 1))
for i in range(5):
    t0 = time.perf_counter()
    wenv.step(wU0[0])
    print("Time", time.perf_counter() - t0)


##### Notice a ~20x slowdown in the forward step
##### Fixing this is the first step

Time 0.14448881149291992
Time 0.009276866912841797
Time 0.005244255065917969
Time 0.0048258304595947266
Time 0.004687786102294922
Time 0.04680180549621582
Time 0.0003101825714111328
Time 0.0004787445068359375
Time 0.00029277801513671875
Time 0.0004856586456298828


In [8]:
## Second comparison - one env step plus Jacobians w.r.t. state and action
# Same timing methodology as the first comparison: perf_counter, and block on all
# JAX results before stopping the clock so asynchronous dispatch doesn't hide cost.
# Read the action directly (equal to U0[0] from the setup cell) so this cell does not
# depend on whether U0 was reassigned by a previously executed cell.
jax_action = env.goal_action
# Build the Jacobian function once instead of on every iteration.
# NOTE(review): it is not jitted, so env's Python code is traced on every call;
# wrapping it in jax.jit is a natural next experiment (confirm env is jit-compatible).
env_jac = jax.jacfwd(env, argnums=(0, 1))
state = env.init()
for i in range(5):
    t0 = time.perf_counter()
    state = env(state, jax_action)
    jac = env_jac(state, jax_action)
    for leaf in jax.tree_util.tree_leaves((state, jac)):
        if hasattr(leaf, "block_until_ready"):
            leaf.block_until_ready()
    print("Time", time.perf_counter() - t0)

## Comparing to wilhelm
wenv = wpq()
wenv.reset()
# Build wilhelm's actions from wilhelm's own goal action and horizon, and keep its
# actions/state in separate names so the JAX `U0` and `state` are not overwritten.
wU0 = np.tile(np.zeros_like(wenv.goal_action), (wenv.H, 1))
for i in range(5):
    t0 = time.perf_counter()
    wstate, _ = wenv.step(wU0[0])
    _, _ = wenv.f_x(wstate, wU0[0]), wenv.f_u(wstate, wU0[0])
    print("Time", time.perf_counter() - t0)

### Notice roughly a 100x time difference now

Time 0.02914571762084961
Time 0.03044605255126953
Time 0.03442502021789551
Time 0.030633926391601562
Time 0.029471158981323242
Time 0.17121410369873047
Time 0.00043892860412597656
Time 0.0004811286926269531
Time 0.0004947185516357422
Time 0.00035309791564941406
