In [737]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [227]:
from deluca.agents import DRC, LQR
from deluca.envs import LDS
import jax
import jax.numpy as jnp

In [231]:
def loop(context, i):
    """Advance the closed loop by one step and report that step's cost.

    Shaped as a ``jax.lax.scan`` body: it takes the carry ``(lds, controller)``
    plus a step index and returns the (same) carry together with a scalar cost.

    Args:
        context: pair ``(lds, controller)``. ``lds`` is mutated: ``step`` is
            called on it and its ``state`` attribute is reassigned.
        i: step index; drives the additive disturbance ``0.5 * sin(i)``.

    Returns:
        ``((lds, controller), error)`` with ``error = ||state|| + ||action||``,
        where ``state`` is measured after the disturbance has been added.
    """
    env, policy = context
    u = policy(env.obs)
    env.step(u)
    # Deterministic sinusoidal disturbance; the scalar broadcasts over the state.
    env.state += 0.5 * jnp.sin(i)
    state_cost = jnp.linalg.norm(env.state)
    control_cost = jnp.linalg.norm(u)
    return (env, policy), state_cost + control_cost

def get_err(T, lds, controller):
    """Mean per-step cost of running ``controller`` on ``lds`` for ``T`` steps.

    Plain Python-loop reference for ``get_err_scan``: each step goes through
    ``loop``, which mutates ``lds`` (so callers should pass a fresh system).

    Args:
        T: number of time steps; also the divisor of the mean.
        lds: environment exposing ``obs``, ``step(action)`` and ``state``.
        controller: callable mapping an observation to an action.

    Returns:
        The average of the per-step errors returned by ``loop``, or ``0`` when
        ``T <= 0`` (no steps are taken).
    """
    total = 0
    for t in range(T):
        (lds, controller), error = loop((lds, controller), t)
        total += error
    # Sum first and divide once (as jnp.mean does in get_err_scan) instead of
    # rounding error / T on every step.
    return total / T if T > 0 else 0

def get_err_scan(T, lds, controller):
    """Mean per-step cost of running ``controller`` on ``lds`` via ``jax.lax.scan``.

    Scan counterpart of ``get_err``: ``loop`` is the scan body, the carry is
    ``(lds, controller)`` and the step indices ``0..T-1`` are the scanned inputs.

    Args:
        T: number of time steps.
        lds: environment carried through the scan; presumably it must be a JAX
            pytree (TODO confirm for the deluca LDS class).
        controller: callable mapping an observation to an action; also carried.

    Returns:
        ``jnp.mean`` of the per-step errors (NaN for ``T == 0``, the mean of an
        empty array).
    """
    # lax.scan caches traced bodies per function object. Handing it the
    # module-level `loop` directly could let a later call reuse the trace from an
    # earlier call with a different controller. That is presumably why the
    # notebook had to re-run the `def loop` cell between experiments. A fresh
    # wrapper per call forces a retrace.
    def body(carry, t):
        return loop(carry, t)

    _, errors = jax.lax.scan(body, (lds, controller), jnp.arange(T))
    return jnp.mean(errors)

In [229]:
# Rollout horizon: number of time steps per experiment (also the divisor of the mean error).
T = 100

In [230]:
# System shared by the experiments below: 2-dimensional state, 1-dimensional input.
A, B = jnp.array([[.8, .5], [0, .8]]), jnp.array([[0], [0.8]])

# Evaluate DRC with both rollout implementations. The rollout mutates the LDS
# (`loop` steps it and reassigns its state), so each run builds a fresh system
# and controller in order to start from the same initial conditions.
drc_errors = []
for label, rollout in (("Pure DRC incurs ", get_err),
                       ("Purer DRC with scan incurs", get_err_scan)):
    lds = LDS(state_size=B.shape[0], action_size=B.shape[1], A=A, B=B)
    drc = DRC(lds.A, lds.B, C=lds.C)
    drc_errors.append(rollout(T, lds, drc))
    print(label, drc_errors[-1], " loss")

# Sanity check: the Python-loop and lax.scan rollouts must agree up to rounding.
assert jnp.allclose(drc_errors[0], drc_errors[1]), drc_errors


Pure DRC incurs  0.6832412287296924  loss
Purer DRC with scan incurs 0.6832412287296923  loss


In [232]:
# Note: re-run the cell defining `loop` before this one. This is presumably needed
# so that `jax.lax.scan` retraces `loop` instead of reusing the trace from the
# previous cell's DRC run. The assertion at the end flags a stale result.
A, B = jnp.array([[.8, .5], [0, .8]]), jnp.array([[0], [0.8]])

# DRC warm-started with the LQR gain, evaluated with both rollout implementations.
# Each run gets a fresh LDS and controller because the rollout mutates the LDS.
drc_lqr_errors = []
for label, rollout in (("DRC initialized with LQR incurs ", get_err),
                       ("DRC with scan initialized with LQR incurs ", get_err_scan)):
    lds = LDS(state_size=B.shape[0], action_size=B.shape[1], A=A, B=B)
    drc_lqr = DRC(lds.A, lds.B, C=lds.C, K=LQR(A, B).K)
    drc_lqr_errors.append(rollout(T, lds, drc_lqr))
    print(label, drc_lqr_errors[-1], " loss")

# Sanity check: the Python-loop and lax.scan rollouts must agree up to rounding.
assert jnp.allclose(drc_lqr_errors[0], drc_lqr_errors[1]), drc_lqr_errors


DRC initialized with LQR incurs  1.324636032988597  loss
DRC with scan initialized with LQR incurs  1.3246360329885964  loss


In [233]:
# Baseline: the plain LQR controller on a freshly built copy of the same system
# (Python-loop rollout only).
lds = LDS(state_size=B.shape[0], action_size=B.shape[1], A=A, B=B)
lqr = LQR(A, B)
lqr_error = get_err(T, lds, lqr)
print("LQR incurs ", lqr_error, " loss")

LQR incurs  1.2980238445428083  loss
