In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import tqdm
import flax
import time
import deluca.core
import matplotlib.pyplot as plt
from deluca.igpc.ilqr import iLQR
from deluca.envs import PlanarQuadrotor
from wilhelm.sim.quadrotor import PlanarQuadrotor as wpq
import numpy as np



In [3]:
def cost(x, u, sim):
    return 0.1 * (u - sim.goal_action) @ (u - sim.goal_action) + (x.arr - sim.goal_state) @ (x.arr - sim.goal_state)

env = PlanarQuadrotor.create()
#Initial actions
U0 = jnp.tile(env.goal_action, (env.H, 1))
# Warm up with either file or iLQR
# warmup_steps = 8
# X, U, k, K, c = iLQR(env_sim, cost, U0, warmup_steps, verbose=False)

# First comparison - simple one step env evaluation

## deluca

In [19]:
%%timeit
state = env.init()
for i in range(1000):
    state, _ = env(state, U0[0])

4.43 s ± 114 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## deluca with jax.lax.scan

In [28]:
@jax.jit
def loop(carry, args):
    env, state = carry
    state, _ = env(state, U0[0])
    return (env, state), args

In [29]:
%%timeit
jax.lax.scan(loop, (env, env.init()), jnp.array(list(range(1000))))

1.76 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## wilhelm

In [22]:
%%timeit
wenv = wpq()
wenv.reset()
U0 = np.tile(np.zeros_like(wenv.goal_action), (wenv.H, 1))
for i in range(1000):
    wenv.step(U0[0])

88.2 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
## First comparison - simple one step env evaluation
print("deluca")
state = env.init()
for i in range(5):
    s = time.time()
    state, _ = env(state, U0[0])
    print("Time", time.time()-s)
    
## Comparing to wilhelm
print("wilhelm")
wenv = wpq()
wenv.reset()
U0 = np.tile(np.zeros_like(wenv.goal_action), (wenv.H, 1))
for i in range(5):
    s = time.time()
    wenv.step(U0[0])
    print("Time", time.time()-s)
    
    
##### Notice a 20x slowdown in the forward step
##### Fixing this is the first step

deluca
Time 0.006835222244262695
Time 0.0058498382568359375
Time 0.0055637359619140625
Time 0.005511045455932617
Time 0.00557708740234375
wilhelm
Time 0.04752230644226074
Time 0.0003154277801513672
Time 0.00018477439880371094
Time 8.749961853027344e-05
Time 0.00018525123596191406


# Second comparison - env + jacobian

## deluca

In [26]:
%%timeit
state = env.init()
for i in range(1000):
    state, _ = env(state, U0[0])
    _ = jax.jacfwd(env, argnums=(0,1))(state, U0[0])

32.9 s ± 641 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## deluca with jax.lax.scan

In [30]:
@jax.jit
def loop(carry, args):
    env, state = carry
    state, _ = env(state, U0[0])
    _ = jax.jacfwd(env, argnums=(0,1))(state, U0[0])
    return (env, state), args

In [31]:
%%timeit
jax.lax.scan(loop, (env, env.init()), jnp.array(list(range(1000))))

1.84 ms ± 5.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## wilhelm

In [27]:
%%timeit
wenv = wpq()
wenv.reset()
U0 = np.tile(np.zeros_like(env.goal_action), (env.H, 1))
for i in range(1000):
    state, _ = wenv.step(U0[0])
    _, _ = wenv.f_x(state, U0[0]), wenv.f_u(state, U0[0])

267 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
## First comparison - simple one step env evaluation
state = env.init()
for i in range(5):
    s = time.time()
    state = env(state, U0[0])
    _ = jax.jacfwd(env, argnums=(0,1))(state, U0[0])
    print("Time", time.time()-s)
    
## Comparing to wilhelm
wenv = wpq()
wenv.reset()
U0 = np.tile(np.zeros_like(env.goal_action), (env.H, 1))
for i in range(5):
    s = time.time()
    state, _ = wenv.step(U0[0])
    _, _ = wenv.f_x(state, U0[0]), wenv.f_u(state, U0[0])
    print("Time", time.time()-s)

### Notice about a 100X time difference now

Time 0.02914571762084961
Time 0.03044605255126953
Time 0.03442502021789551
Time 0.030633926391601562
Time 0.029471158981323242
Time 0.17121410369873047
Time 0.00043892860412597656
Time 0.0004811286926269531
Time 0.0004947185516357422
Time 0.00035309791564941406
