In [138]:
# Make the parent directory importable so modules that live one level above
# this notebook can be imported by the cells below.
import sys
sys.path.append("..")
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell execution.
%load_ext autoreload
%autoreload 2
# Exclude jax and jaxlib from automatic reloading.
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [139]:
import jax
import jax.flatten_util
import jax.numpy as jnp

import stanza.envs as envs
import stanza.policies as policies
from stanza.envs.pendulum import State as PendulumState
from stanza.envs.quadrotor import State as QuadrotorState
from stanza.util.logging import logger

from learning_mpc.expert import make_expert

In [140]:
# Name of the environment to evaluate. The cells below branch on this value
# and handle exactly two options: "pendulum" or "quadrotor".
env_name = "quadrotor"

In [141]:
# Create the environment and build a 50 x 50 grid of evaluation states for it.
# `xs` / `ys` hold the two 1-D grid axes (reused later for plotting) and
# `eval_states` holds the flattened grid as a batched environment State.
env = envs.create(env_name)
if env_name == "pendulum":
    # Grid over angle in [0, 2*pi] and angular velocity in [-1, 1].
    angles = jnp.linspace(0.0, 2 * jnp.pi, 50)
    vels = jnp.linspace(-1, 1, 50)
    xs, ys = angles, vels
    # Default "xy" meshgrid indexing: outputs have shape (len(vels), len(angles)).
    angles_g, vels_g = jnp.meshgrid(angles, vels)
    angles_g = jnp.reshape(angles_g, (-1,))
    vels_g = jnp.reshape(vels_g, (-1,))
    eval_states = PendulumState(angle=angles_g, vel=vels_g)
elif env_name == "quadrotor":
    # Grid over z (from xs) and z_dot (from ys); every other state component
    # is held at zero.
    xs = jnp.linspace(-2, 2, 50)
    ys = jnp.linspace(-2, 2, 50)
    # Default "xy" meshgrid indexing: outputs have shape (len(ys), len(xs)).
    xs_g, ys_g = jnp.meshgrid(xs, ys)
    xs_g = jnp.reshape(xs_g, (-1,))
    ys_g = jnp.reshape(ys_g, (-1,))
    zeros_g = jnp.zeros_like(xs_g)
    eval_states = QuadrotorState(
        x=zeros_g, z=xs_g,
        phi=zeros_g,
        x_dot=zeros_g,
        z_dot=ys_g,
        phi_dot=zeros_g
    )
else:
    # Fail loudly instead of leaving xs / ys / eval_states undefined (or stale
    # from a previous run of this cell with a different env_name).
    raise ValueError(
        f"unsupported env_name {env_name!r}; expected 'pendulum' or 'quadrotor'"
    )

In [142]:
def eval_expert(eta):
    """Evaluate the expert policy for a given ``eta`` on every state in ``eval_states``.

    The expert is built with ``make_expert(env_name, env, eta=eta)`` and run
    through ``jax.pmap`` on the CPU backend. Because pmap maps the leading axis
    across devices, the evaluation states are processed in chunks no larger
    than the number of available CPU devices (capped at 64).

    Args:
        eta: value forwarded to ``make_expert`` as its ``eta`` keyword.

    Returns:
        The ``.action`` outputs of the expert for all evaluation states,
        concatenated along axis 0 in the same order as ``eval_states``.
    """
    expert = make_expert(env_name, env, eta=eta)
    batch_expert = jax.pmap(lambda x: expert(x), backend="cpu")

    # Number of evaluation states = leading dimension of any leaf of the pytree.
    N = jax.tree_util.tree_leaves(eval_states)[0].shape[0]
    batch_size = min(64, len(jax.devices("cpu")))
    n_batches = (N + batch_size - 1) // batch_size  # ceiling division
    batches = []
    for i in range(n_batches):
        start, stop = i * batch_size, (i + 1) * batch_size
        # The last chunk may be shorter than batch_size; slicing handles that.
        batch = jax.tree_util.tree_map(lambda x: x[start:stop], eval_states)
        batches.append(batch_expert(policies.PolicyInput(batch)).action)
    # Stitch the per-chunk action pytrees back together along the batch axis.
    data = jax.tree_util.tree_map(lambda *x: jnp.concatenate(x, axis=0), *batches)
    return data

In [143]:
# eta values to evaluate the expert at. Only the smallest one is active here;
# the full sweep would be [1e-8, 1e-4, 1e-2, 1e-1, 1].
etas = [1e-8]
# Maps each eta to the expert actions returned by eval_expert for that eta.
data = {}
for eta in etas:
    logger.info(f"evaluating eta {eta}")
    data.update({eta: eval_expert(eta)})

In [144]:
import plotly.graph_objects as go

def visualize(data, fig):
    """Add a 3-D surface of expert actions over the evaluation grid to ``fig``.

    Args:
        data: expert actions for the flattened evaluation grid, in the same
            order as ``eval_states``. For the quadrotor only column 0 is
            plotted (assumes a 2-D (N, k) action array -- TODO confirm).
        fig: a plotly ``go.Figure``; a ``go.Surface`` trace is appended to it.

    Returns:
        None. ``fig`` is modified in place.
    """
    if env_name == "quadrotor":
        # Keep only the first action component.
        data = data[:, 0]
    # The evaluation grid is built with jnp.meshgrid's default "xy" indexing
    # and flattened row-major, so its 2-D shape is (len(ys), len(xs)). That is
    # also the layout plotly expects: z[i][j] is the value at (x[j], y[i]).
    grid_shape = (ys.shape[0], xs.shape[0])
    fig.add_trace(go.Surface(
        x=xs, y=ys, z=jnp.reshape(data, grid_shape)
    ))
# Plot the action surface for the first configured eta. Indexing through
# `etas` (rather than repeating the literal key) keeps this cell working when
# the sweep list is edited.
fig = go.Figure()
visualize(data[etas[0]], fig)
fig