In [54]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
import stanza.envs as envs
import stanza.policies as policies

import jax.flatten_util
import jax
import jax.numpy as jnp

from jax.random import PRNGKey
from stanza.util.random import PRNGSequence
from stanza.util.logging import logger

env = envs.create("quadrotor")
# Seeded key stream shared by every cell below. Each next(rng) presumably yields
# a fresh subkey, so results depend on the order in which cells call next(rng).
rng = PRNGSequence(42)

In [56]:
# Peek at one random initial state. NOTE: this draws a key from `rng`, so
# deleting or re-running this cell shifts every later next(rng) draw.
sample_x0 = env.reset(next(rng))
print(sample_x0)

State(x=Array(-0.9451697, dtype=float32), z=Array(0.86506915, dtype=float32), phi=Array(0., dtype=float32), x_dot=Array(0., dtype=float32), z_dot=Array(0., dtype=float32), phi_dot=Array(0., dtype=float32))


In [57]:
# First step: generate expert trajectories with an MPC expert.
from stanza.policies.mpc import MPC
from stanza.solver.ilqr import iLQRSolver

# Episode length: used as the MPC planning horizon below and as the rollout
# length in rollout_policy.
my_horizon = 100
solver_t = iLQRSolver()
expert_policy = MPC(
    # Example action; presumably MPC only uses it for the action pytree
    # structure/shape — TODO confirm.
    action_sample=env.sample_action(PRNGKey(0)),
    cost_fn=env.cost,
    model_fn=env.step,
    # Reuse my_horizon instead of a duplicated literal so the planning horizon
    # cannot silently drift from the rollout length.
    horizon_length=my_horizon,
    solver=solver_t,
    # receed=False presumably disables receding-horizon re-planning (solve once,
    # then follow the plan) — TODO confirm against stanza's MPC.
    receed=False,
)

def rollout_policy(rng_key, my_pol):
    """Roll out a policy on `env` from a random initial state.

    Args:
        rng_key: PRNG key passed to ``env.reset`` to draw the initial state.
        my_pol: policy to execute (e.g. the MPC expert).

    Returns:
        ``(states, actions)`` of a ``my_horizon``-step rollout. With
        ``last_state=False`` the final state is presumably omitted so states
        and actions align one-to-one — TODO confirm against stanza.
    """
    # Random initial quadrotor state; which State fields are randomized is
    # decided by env.reset.
    x_0 = env.reset(rng_key)
    roll = policies.rollout(model=env.step,
                            state0=x_0,
                            policy=my_pol,
                            length=my_horizon,
                            last_state=False)
    return roll.states, roll.actions

def batch_roll(rng_key, num_t, my_pol):
    """Run `num_t` independent rollouts of `my_pol` in parallel via ``jax.vmap``.

    Each rollout gets its own key split from `rng_key`; the policy is shared
    (not mapped). Returns the rollout outputs stacked along a new leading axis
    of size `num_t`.
    """
    per_rollout_keys = jax.random.split(rng_key, num_t)
    vmapped_rollout = jax.vmap(rollout_policy, in_axes=(0, None))
    return vmapped_rollout(per_rollout_keys, my_pol)


In [58]:
from stanza.data import Data, PyTreeData
# Collect expert demonstrations: MPC rollouts from random initial states.
num_expert_trajectories = 200
expert_rollouts = batch_roll(PRNGKey(42), num_expert_trajectories, expert_policy)
# The outer Data indexes trajectories; each trajectory is wrapped as its own
# per-timestep Data, and flatten() presumably concatenates them into one stream
# of (state, action) samples — the pairs loss_fn unpacks.
expert_data = PyTreeData.from_data(
    Data.from_pytree(expert_rollouts)
        .map(lambda trajectory: Data.from_pytree(trajectory))
        .flatten(),
    chunk_size=4096,
)

In [59]:
from stanza.reporting.wandb import WandbDatabase

# W&B project handle; create() presumably starts a new run under it — TODO confirm.
wandb_project = WandbDatabase("dpfrommer-projects/quadrotor_mlp")
db = wandb_project.create()
# [blue]...[/blue] is console color markup for the logger.
logger.info(f"Logging to [blue]{db.name}[/blue]")


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668433283363506, max=1.0…

In [60]:
from stanza.nets.mlp import MLP
import chex

# Flatten one example action and state. The flat sizes fix the network's
# output/input widths, and action_uf maps a flat vector back to an action pytree.
action_flat, action_uf = jax.flatten_util.ravel_pytree(
    env.sample_action(PRNGKey(0))
)
state_flat, state_uf = jax.flatten_util.ravel_pytree(
    env.sample_state(PRNGKey(0))
)
# Layer widths: three of 100, then one output per flat action entry.
model = MLP([100] * 3 + [action_flat.shape[0]])

def loss_fn(state, params, _rng_key, sample):
    """Per-sample behavior-cloning loss: squared error of the predicted action.

    Args:
        state: auxiliary training state; returned unchanged.
        params: MLP parameters passed to ``model.apply``.
        _rng_key: unused.
        sample: ``(state, expert_action)`` pytree pair from the expert data.

    Returns:
        ``(state, loss, stats)`` where ``loss`` is the summed squared error over
        the flattened action and ``stats == {'loss': loss}``.
    """
    obs, expert_action = sample
    target_flat = jax.flatten_util.ravel_pytree(expert_action)[0]
    obs_flat = jax.flatten_util.ravel_pytree(obs)[0]
    # Catch dataset/env mismatches: flat sizes must match the reference example
    # action/state used to size the network.
    chex.assert_equal_shape([target_flat, action_flat])
    chex.assert_equal_shape([obs_flat, state_flat])
    pred_flat = model.apply(params, obs_flat)
    loss = jnp.square(target_flat - pred_flat).sum()
    return state, loss, {'loss': loss}

def model_policy(params, input):
    """Learned policy: run the MLP on the flattened observation.

    Args:
        params: MLP parameters (bound via ``stanza.Partial`` during validation).
        input: policy input exposing ``.observation`` (presumably the env state
            pytree).

    Returns:
        ``policies.PolicyOutput`` wrapping the predicted action, unflattened
        into the environment's action pytree via ``action_uf``.
    """
    obs_flat = jax.flatten_util.ravel_pytree(input.observation)[0]
    return policies.PolicyOutput(action_uf(model.apply(params, obs_flat)))


In [61]:
from stanza.train import Trainer, batch_loss
from stanza.train.validate import Validator
from stanza.util.loop import every_kth_iteration, every_iteration, LoggerHook
from stanza.util.rich import ConsoleDisplay, StatisticsTable, LoopProgress
from stanza.reporting.jax import JaxDBScope

import stanza
import optax

iterations = 20000
# AdamW whose learning rate cosine-decays from 1e-3 over the whole training run.
lr_schedule = optax.cosine_decay_schedule(1e-3, iterations)
optimizer = optax.adamw(lr_schedule, weight_decay=1e-4)

display = ConsoleDisplay()
# Statistics table and progress bar on the "train" stream, presumably refreshed
# every 100 iterations.
for widget_cls in (StatisticsTable, LoopProgress):
    display.add("train", widget_cls(), interval=100)

def val_fn(_, params, _rng_key, rng_key):
    """Closed-loop validation cost of the learned policy from one initial state.

    Args:
        _: unused (presumably the trainer/validator state slot).
        params: MLP parameters for ``model_policy``.
        _rng_key: unused.
        rng_key: key passed to ``env.reset`` to draw the initial state.

    Returns:
        ``{'cost': cost}`` with ``cost = env.cost(states, actions)`` of a
        ``my_horizon``-step rollout.
    """
    x0 = env.reset(rng_key)
    policy = stanza.Partial(model_policy, params)
    # Use the same horizon as the expert rollouts (instead of a duplicated
    # literal) so validation cost stays comparable with the expert's.
    rollout = policies.rollout(env.step, x0,
                               policy, length=my_horizon,
                               last_state=False)
    cost = env.cost(rollout.states, rollout.actions)
    return {'cost': cost}

def batch_val_fn(_, params, _rng_key, rng_keys):
    """Evaluate ``val_fn`` for every key in `rng_keys` and average each statistic.

    Args:
        _: unused; ``None`` is forwarded to ``val_fn`` in its place.
        params: MLP parameters, shared across all evaluations.
        _rng_key: forwarded unchanged (not mapped) to ``val_fn``.
        rng_keys: stacked keys, one initial state per entry (mapped on axis 0).

    Returns:
        The ``val_fn`` stats pytree with every leaf replaced by its mean over
        the keys.
    """
    bval_fn = jax.vmap(val_fn, in_axes=(None, None, None, 0))
    stats = bval_fn(None, params, _rng_key, rng_keys)
    # jax.tree_map is deprecated (removed in newer JAX); use jax.tree_util.
    return jax.tree_util.tree_map(jnp.mean, stats)

# Validation: closed-loop rollouts of the learned policy from 64 random initial
# states (one key per rollout), scored by batch_val_fn.
# NOTE(review): every_kth_iteration(1) presumably validates on EVERY training
# iteration (64 rollouts each), which looks expensive — confirm intent.
# The two next(rng) draws are kept in the original order.
validator_key = next(rng)
val_rng_keys = jax.random.split(next(rng), 64)
validator = Validator(
    validator_key,
    Data.from_pytree(val_rng_keys),
    condition=every_kth_iteration(1),
    stat_fn=batch_val_fn,
)

# Wrap the W&B database; JaxDBScope presumably lets statistics be logged from
# jitted training code — TODO confirm.
dbs = JaxDBScope(db)

# Train the behavior-cloning MLP on the expert (state, action) samples.
with display as display_handle, dbs as dbs_handle:
    # Console logging of stats every 1000 iterations.
    logger_hook = LoggerHook(every_kth_iteration(1000))
    # DB logging every iteration; buffer=100 presumably batches writes — TODO confirm.
    db_logger_hook = dbs_handle.statistic_logging_hook(log_cond=every_kth_iteration(1), buffer=100)
    trainer = Trainer(
        # batch_loss presumably maps loss_fn over each 128-sample batch and
        # aggregates it — TODO confirm.
        loss_fn=batch_loss(loss_fn), batch_size=128,
        optimizer=optimizer,
        max_iterations=iterations,
        train_hooks=[validator, db_logger_hook, logger_hook,
                     display_handle.train]
    )
    
    logger.info("Initializing model...")
    # Initialize with a flat example state so the input width matches the data.
    # NOTE: the order of the two next(rng) calls fixes which keys init and
    # training receive.
    init_params = model.init(next(rng), state_flat)
    logger.info("Training...")
    params = trainer.train(expert_data, rng_key=next(rng), init_params=init_params, jit=True)


Output()