In [182]:
# Reload all imported modules before each cell runs, so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [183]:
import jax
import jax.numpy as jnp
from deluca.igpc.ilqr import iLQR
from deluca.envs import PlanarQuadrotor

from deluca.core import Agent
from deluca.core import field
from deluca.core import Obj
from functools import partial

In [202]:
class OpenLoopState(Obj):
    """State of an ``OpenLoop`` agent: the action sequence to replay and a step counter."""
    arr: jnp.ndarray = field(jaxed=True)  # actions to replay; the leading axis is indexed by the step counter h
    h: float = field(0.0, jaxed=True)  # current step; stored as a float and cast to int when used as an index

class OpenLoop(Agent):
    """Open-loop agent that replays a precomputed action sequence.

    Observations are ignored: at step ``h`` the agent emits row ``h`` of the
    action array held in its state and advances ``h`` by one.
    """

    def init(self, actions):
        """Return the initial agent state for replaying ``actions``.

        Args:
            actions: array whose leading axis indexes time; row ``h`` is the
                action emitted at step ``h``.

        Returns:
            An ``OpenLoopState`` with the step counter at its default (0.0).
        """
        return OpenLoopState(arr=actions)

    def setup(self):
        """Intentionally a no-op.

        This agent has no ``dt``/``RC`` attributes to derive anything from, so
        there is nothing to precompute (reading them would raise
        ``AttributeError``).
        """

    def __call__(self, state, obs):
        """Emit the action for the current step and advance the step counter.

        Args:
            state: ``OpenLoopState`` holding the action array and counter ``h``.
            obs: ignored (open-loop control).

        Returns:
            ``(next_state, action)`` where ``action`` is row ``h`` of ``state.arr``.
        """
        # ``h`` defaults to a Python float; ``jnp.asarray`` makes the int cast
        # work both eagerly and when ``h`` is a tracer under jit.
        t = jnp.asarray(state.h).astype(int)
        # Supports a traced index and any array rank. Like ``dynamic_slice``, it
        # clamps out-of-range indices, so steps past the end repeat the last row.
        action = jax.lax.dynamic_index_in_dim(state.arr, t, axis=0, keepdims=False)
        return OpenLoopState(arr=state.arr, h=state.h + 1), action
    
class LQGAgentState(Obj):
    """State of an ``LQGAgent``: per-step gains, reference trajectory, step scale and step counter."""
    k: jnp.ndarray = field(jaxed=True)  # additive (feedforward) term per step, leading axis = step; presumably (H, m) — confirm
    K: jnp.ndarray = field(jaxed=True)  # feedback gain applied to (obs - X_ref) per step; presumably (H, m, n) — confirm
    X_ref: jnp.ndarray = field(jaxed=True)  # reference states per step; presumably (H or H+1, n) — confirm
    U_ref: jnp.ndarray = field(jaxed=True)  # reference actions per step; presumably (H, m) — confirm
    alpha: float = field(0.0, jaxed=True)  # scale applied to k when forming the action
    h: float = field(0.0, jaxed=True)  # current step counter, stored as a float
    
class LQGAgent(Agent):
    """Time-varying affine feedback controller around a reference trajectory.

    At step ``h`` the action is::

        u = U_ref[h] + alpha * k[h] + K[h] @ (x - X_ref[h])

    NOTE(review): the per-step shapes (``k``: (H, m), ``K``: (H, m, n),
    ``X_ref``: (H or H+1, n)) are assumptions about whatever produces the
    gains (e.g. an iLQR backward pass) — confirm against the producer.
    """

    def init(self, actions, k=None, K=None, X_ref=None, alpha=0.0):
        """Build the initial agent state.

        ``init(actions)`` alone still works and gives pure open-loop replay of
        ``actions`` (no feedforward correction, no feedback).

        Args:
            actions: reference actions ``U_ref``; leading axis = step.
            k: additive per-step term, same shape as ``actions``; defaults to zeros.
            K: per-step feedback gains; must be given together with ``X_ref``.
            X_ref: per-step reference states; must be given together with ``K``.
            alpha: scale applied to ``k`` (same default as ``LQGAgentState.alpha``).

        Returns:
            An ``LQGAgentState`` with the step counter at its default (0.0).

        Raises:
            ValueError: if exactly one of ``K`` / ``X_ref`` is given.
        """
        if (K is None) != (X_ref is None):
            raise ValueError("K and X_ref must be provided together")
        if k is None:
            k = jnp.zeros_like(actions)
        return LQGAgentState(k=k, K=K, X_ref=X_ref, U_ref=actions, alpha=alpha)

    def setup(self):
        """Intentionally a no-op.

        This agent has no ``dt``/``RC`` attributes to derive anything from, so
        there is nothing to precompute (reading them would raise
        ``AttributeError``).
        """

    def __call__(self, state, obs):
        """Compute the action for the current step and advance the step counter.

        Args:
            state: ``LQGAgentState``.
            obs: current observation. It is used directly, or through its
                ``.arr`` attribute when it is a state object.

        Returns:
            ``(next_state, action)``.
        """
        # ``h`` defaults to a Python float; ``jnp.asarray`` makes the int cast
        # work both eagerly and when ``h`` is a tracer under jit.
        t = jnp.asarray(state.h).astype(int)

        def at(arr):
            # Entry ``t`` along the leading (step) axis, for arrays of any rank
            # (K is a per-step matrix, so a fixed 2-D slice would be wrong).
            return jax.lax.dynamic_index_in_dim(arr, t, axis=0, keepdims=False)

        action = at(state.U_ref) + state.alpha * at(state.k)
        if state.K is not None:
            # NOTE(review): the env observation may be a raw vector or a state
            # object exposing the vector as ``.arr`` (the cost function reads
            # ``x.arr`` from env states) — confirm what the env emits.
            x = getattr(obs, "arr", obs)
            action = action + at(state.K) @ (x - at(state.X_ref))

        next_state = LQGAgentState(
            k=state.k,
            K=state.K,
            X_ref=state.X_ref,
            U_ref=state.U_ref,
            alpha=state.alpha,
            h=state.h + 1,
        )
        return next_state, action

In [203]:
@partial(jax.jit, static_argnums=(5,6,))
def rollout(env, env_start_state, env_init_obs, agent, agent_start_state, H, cost_func):
    """Simulate ``agent`` acting in ``env`` for ``H`` steps inside ``jax.lax.scan``.

    Args:
        env: callable ``env(state, action) -> (next_state, next_obs)``.
        env_start_state: initial environment state.
        env_init_obs: initial observation handed to the agent.
        agent: callable ``agent(agent_state, obs) -> (next_agent_state, action)``.
        agent_start_state: initial agent state.
        H: number of steps (static under jit).
        cost_func: ``cost_func(state, action, env) -> scalar`` (static under jit).

    Returns:
        Tuple ``(env_states, obs, agent_states, actions, costs)``. Each element is
        stacked along a new leading axis of length ``H``. Entry ``t`` of the
        first three is the value *after* step ``t``. ``costs[t]`` is evaluated
        on the state *before* step ``t`` together with ``actions[t]``.
    """
    def step(carry, _):
        state, obs, agent_state = carry
        new_agent_state, u = agent(agent_state, obs)
        new_state, new_obs = env(state, u)
        stage_cost = cost_func(state, u, env)
        new_carry = (new_state, new_obs, new_agent_state)
        return new_carry, new_carry + (u, stage_cost)

    carry0 = (env_start_state, env_init_obs, agent_start_state)
    _, trajectory = jax.lax.scan(step, carry0, None, length=H)
    return trajectory

In [204]:
def cost(x, u, sim):
    """Quadratic tracking cost for one step.

    Penalizes the action's distance from ``sim.goal_action`` (weight 0.1) plus
    the squared distance of the state vector ``x.arr`` from ``sim.goal_state``.
    """
    du = u - sim.goal_action
    dx = x.arr - sim.goal_state
    # Same left-to-right evaluation as ``0.1 * du @ du``: (0.1 * du) @ du.
    return 0.1 * du @ du + dx @ dx

# Planar quadrotor environment and its initial (state, observation) pair.
env = PlanarQuadrotor.create()
env_state, env_obs = env.init()
# Initial actions: a zero action (shaped like env.goal_action) repeated for each of the env.H steps.
U0 = jnp.tile(jnp.zeros_like(env.goal_action), (env.H, 1))

In [205]:
# Open-loop agent that replays U0 step by step; U_state carries the actions and the step counter.
U = OpenLoop()
U_state = U.init(U0)

In [207]:
import time

# JAX dispatches work asynchronously. Each timing must therefore wait for the
# result (block_until_ready); otherwise only dispatch latency is measured.
# One untimed warm-up call keeps jit compilation out of every timed iteration.
L = rollout(env, env_state, env_obs, U, U_state, env.H, cost)
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), L)
for _ in range(10):
    s = time.perf_counter()
    L = rollout(env, env_state, env_obs, U, U_state, env.H, cost)
    jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), L)
    print(time.perf_counter() - s)
# Stacked environment states from the last rollout.
states = L[0]

0.0014836788177490234
0.0022170543670654297
0.001252889633178711
0.0008630752563476562
0.0009517669677734375
0.0009610652923583984
0.0010900497436523438
0.0075299739837646484
0.0007767677307128906
0.0007150173187255859


In [189]:
# Total cost over the horizon: element 4 of the rollout output holds the per-step costs.
L[4].sum()

DeviceArray(357221.4, dtype=float32)

In [124]:
# NOTE(review): jax is already imported in the imports cell; this re-import is redundant.
import jax
# Flatten the stacked rollout states into (leaves, treedef) to inspect their structure.
jax.tree_util.tree_flatten(states)

([DeviceArray([[ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -1.7881394e-08,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -3.5762788e-08,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -5.3644182e-08,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -7.1525577e-08,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -8.9406967e-08,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -1.0728836e-07,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -1.2516975e-07,  0.0000000e+00],
               [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
                 0.0000000e+00, -

In [139]:
from deluca.envs.classic._planar_quadrotor import PlanarQuadrotorState

def r(a,b,c):
    """Return the mean of ``a``; ``b`` and ``c`` are accepted but ignored.

    The extra parameters let ``r`` be called as ``jax.vmap(r)(*leaves)`` on a
    pytree with three leaves.
    """
    del b, c  # unused
    mean_a = jnp.mean(a)
    return mean_a

leaves, treedef = jax.tree_util.tree_flatten(states)
# Map r over the leading (time) axis of every leaf; r only uses the first leaf.
# NOTE(review): assumes `states` flattens to exactly three leaves (r takes three args) — confirm.
jax.vmap(r)(*leaves)

DeviceArray([0.33333334, 0.33333334, 0.33333334, 0.3333333 , 0.3333333 ,
             0.3333333 , 0.3333333 , 0.3333333 , 0.3333333 , 0.3333333 ,
             0.33333328, 0.33333328, 0.33333328, 0.33333328, 0.33333328,
             0.33333328, 0.33333328, 0.33333328, 0.33333328, 0.33333328,
             0.33333328, 0.33333328, 0.33333328, 0.33333325, 0.33333325,
             0.33333325, 0.33333325, 0.33333325, 0.33333325, 0.33333322,
             0.33333322, 0.33333322, 0.33333322, 0.33333322, 0.33333322,
             0.33333322, 0.33333316, 0.33333316, 0.33333316, 0.33333316,
             0.33333313, 0.33333313, 0.33333313, 0.3333331 , 0.33333308,
             0.33333308, 0.33333308, 0.33333305, 0.33333305, 0.33333302,
             0.33333302, 0.333333  , 0.333333  , 0.333333  , 0.333333  ,
             0.33333296, 0.33333293, 0.33333293, 0.33333293, 0.3333329 ,
             0.33333287, 0.33333287, 0.33333287, 0.33333284, 0.3333328 ,
             0.3333328 , 0.3333328 , 0.3333328 , 0.

In [119]:
# Display the stacked environment states from the last rollout.
# NOTE(review): executed as In[119], before the timing cell (In[207]) that defines `states`; re-run to refresh.
states

PlanarQuadrotorState(arr=DeviceArray([[ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -1.7881394e-08,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -3.5762788e-08,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -5.3644182e-08,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -7.1525577e-08,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -8.9406967e-08,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -1.0728836e-07,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -1.2516975e-07,  0.0000000e+00],
             [ 1.0000000e+00,  1.0000000e+00,  0.0000000e+00,
               0.0000000e+00, -1.43051

In [96]:
# NOTE(review): executed as In[96], before the current `rollout` definition (In[203]);
# the output below may not reflect the current return structure — re-run to confirm.
print(len(L[1][2].arr))

100


In [83]:
# NOTE(review): executed as In[83], before the current `rollout` definition (In[203]);
# the output below may not reflect the current return structure — re-run to confirm.
print(len(L[0]))

5
