In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import tqdm
import flax
import deluca.core
import matplotlib.pyplot as plt
from igpc import iGPC_closed
from deluca.envs import PlanarQuadrotor



In [3]:
## Testing iGPC_closed: with lr = 0.0 it should match ILC closed

def cost(x, u, sim):
    """Quadratic tracking cost for a single (state, action) pair.

    Penalizes the squared distance of the state vector ``x.arr`` from
    ``sim.goal_state``, plus 0.1 times the squared distance of the action
    ``u`` from ``sim.goal_action``.

    Args:
        x: state object whose vector is exposed as ``x.arr``
            (assumed 1-D, same length as ``sim.goal_state`` -- TODO confirm).
        u: action vector (assumed 1-D, same length as ``sim.goal_action``).
        sim: environment providing ``goal_state`` and ``goal_action``.

    Returns:
        ``0.1 * |u - goal_action|^2 + |x.arr - goal_state|^2``
        (a scalar for 1-D inputs).
    """
    action_err = u - sim.goal_action
    state_err = x.arr - sim.goal_state
    # `*` and `@` share precedence and bind left-to-right, so the original
    # `0.1 * a @ a` scales first and then contracts; keep that exact order.
    action_term = (0.1 * action_err) @ action_err
    state_term = state_err @ state_err
    return action_term + state_term


# Experiment settings: wind disturbance for the "true" env, and T iGPC iterations
# (the log below prints t = 0 .. T-1).
wind, T = 0.1, 10
# Only the true env sees the wind; the simulator env is created without it.
env_true, env_sim = PlanarQuadrotor.create(wind=wind), PlanarQuadrotor.create()
# Start from the goal action repeated env_sim.H times (presumably one row per horizon step).
U_initial = jnp.tile(env_sim.goal_action, (env_sim.H, 1))
# Pass T rather than a duplicated literal so the iteration count is configured in one place.
# With lr=0.0 this run should match ILC closed.
X, U, _, _, c = iGPC_closed(env_true, env_sim, cost, U_initial, T, lr=0.0)

iGPC: t = -1, r = 1, c = 168708.515625
iGPC: t = 0, r = 2, c = 27426.640625, alpha = 1.100000023841858
iGPC: t = 1, r = 3, c = 17136.1953125, alpha = 1.2100000381469727
iGPC: t = 2, r = 4, c = 3157.14453125, alpha = 1.3310000896453857
iGPC: t = 3, r = 5, c = 2199.740234375, alpha = 1.4641001224517822
iGPC: t = 4, r = 6, c = 1974.510009765625, alpha = 1.6105101108551025
iGPC: t = 5, r = 9, c = 980.0306396484375, alpha = 1.209999918937683
iGPC: t = 6, r = 10, c = 211.4761505126953, alpha = 1.3309998512268066
iGPC: t = 7, r = 11, c = 116.54098510742188, alpha = 1.4640998840332031
iGPC: t = 8, r = 12, c = 91.48587799072266, alpha = 1.6105098724365234
iGPC: t = 9, r = 14, c = 87.46642303466797, alpha = 1.6105098724365234


In [4]:
# With a higher learning rate (lr = 0.01), iGPC reaches a slightly better trajectory (final cost ~77.6 vs ~87.5 at lr = 0.0)

# `cost` is already defined in the previous cell (In [3]) with an identical body;
# it is reused here instead of being redefined, so both runs share one cost function.


# Same setup as the lr = 0.0 run: wind disturbance for the "true" env, T iGPC iterations.
wind, T = 0.1, 10
# Only the true env sees the wind; the simulator env is created without it.
env_true, env_sim = PlanarQuadrotor.create(wind=wind), PlanarQuadrotor.create()
# Start from the goal action repeated env_sim.H times (presumably one row per horizon step).
U_initial = jnp.tile(env_sim.goal_action, (env_sim.H, 1))
# Pass T rather than a duplicated literal so the iteration count is configured in one place.
X, U, _, _, c = iGPC_closed(env_true, env_sim, cost, U_initial, T, lr=0.01)

iGPC: t = -1, r = 1, c = 168708.515625
iGPC: t = 0, r = 2, c = 26733.255859375, alpha = 1.100000023841858
iGPC: t = 1, r = 3, c = 15575.7412109375, alpha = 1.2100000381469727
iGPC: t = 2, r = 4, c = 2660.698486328125, alpha = 1.3310000896453857
iGPC: t = 3, r = 5, c = 2046.8980712890625, alpha = 1.4641001224517822
iGPC: t = 4, r = 6, c = 1708.9422607421875, alpha = 1.6105101108551025
iGPC: t = 5, r = 9, c = 771.8240356445312, alpha = 1.209999918937683
iGPC: t = 6, r = 10, c = 184.49684143066406, alpha = 1.3309998512268066
iGPC: t = 7, r = 11, c = 129.6488037109375, alpha = 1.4640998840332031
iGPC: t = 8, r = 12, c = 105.90928649902344, alpha = 1.6105098724365234
iGPC: t = 9, r = 15, c = 77.5784912109375, alpha = 1.2099997997283936
