In [1]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

In [2]:
import stanza.envs as envs
import stanza.policies as policies

import jax.flatten_util
import jax
import jax.numpy as jnp

from jax.random import PRNGKey
from stanza.util.random import PRNGSequence
from stanza.util.logging import logger

# Notebook-wide key stream: later cells draw fresh keys with next(rng).
SEED = 42
rng = PRNGSequence(SEED)
# Environment used for every rollout and for the model's input/output sizes.
env = envs.create("quadrotor")

In [3]:
# Sanity check: draw one reset state (this consumes one key from `rng`).
initial_state = env.reset(next(rng))
print(initial_state)

State(x=Array(-0.9451697, dtype=float32), z=Array(0.86506915, dtype=float32), phi=Array(0., dtype=float32), x_dot=Array(0., dtype=float32), z_dot=Array(0., dtype=float32), phi_dot=Array(0., dtype=float32))


In [4]:
# first step: generate expert trajectories 
from stanza.policies.mpc import MPC
from stanza.solver.ilqr import iLQRSolver
# Expert: MPC solved with iLQR over a fixed planning horizon. The same
# horizon is reused as the rollout length when generating demonstrations.
my_horizon = 50
solver_t = iLQRSolver()
expert_policy = MPC(
    model_fn=env.step,
    cost_fn=env.cost,
    solver=solver_t,
    horizon_length=my_horizon,
    # Example action drawn from the env (deterministic key).
    action_sample=env.sample_action(PRNGKey(0)),
    # NOTE(review): receed=False presumably means the plan is not re-solved
    # at every step -- confirm against stanza.policies.mpc.MPC.
    receed=False,
)

def rollout_policy(rng_key, my_pol, length=None):
    """Roll out `my_pol` in `env` from an initial state drawn by `env.reset`.

    Args:
        rng_key: PRNG key handed to `env.reset` to sample the initial state.
        my_pol: policy passed to `policies.rollout`.
        length: rollout length. `None` (the default) uses the module-level
            `my_horizon`, the same horizon the MPC expert plans over.

    Returns:
        `(states, actions)` from the rollout. `last_state=False` presumably
        omits the final state so states and actions pair up per step --
        confirm against `stanza.policies.rollout`.
    """
    # NOTE(review): an earlier comment claimed reset randomises the angle and
    # angular velocity, but the reset sample printed in this notebook has
    # random x/z with phi and all velocities at 0 -- confirm against the env.
    x_0 = env.reset(rng_key)
    if length is None:
        length = my_horizon
    roll = policies.rollout(
        model=env.step,
        state0=x_0,
        policy=my_pol,
        length=length,
        last_state=False,
    )
    return roll.states, roll.actions

def batch_roll(rng_key, num_t, my_pol):
    """Run `num_t` independent rollouts of `my_pol`, vectorised with `jax.vmap`.

    Each rollout receives its own key split from `rng_key`; the policy is
    shared across rollouts (not mapped). Outputs gain a leading axis of
    size `num_t` (vmap's default `out_axes=0`).
    """
    per_rollout_keys = jax.random.split(rng_key, num_t)
    vectorised_rollout = jax.vmap(rollout_policy, in_axes=(0, None))
    return vectorised_rollout(per_rollout_keys, my_pol)


In [5]:
from stanza.data import Data
# Collect 200 expert rollouts. vmap in batch_roll adds a leading rollout axis;
# each rollout presumably carries a time axis next (confirm shapes if the
# rollout API changes).
expert_rollouts = batch_roll(PRNGKey(42), 200, expert_policy)
# Merge the (rollout, time) axes so each dataset sample is ONE (state, action)
# pair. The model is initialised on a single flattened state and predicts a
# single flattened action, so whole-trajectory samples would mismatch shapes.
expert_pairs = jax.tree_util.tree_map(
    lambda leaf: leaf.reshape((-1,) + leaf.shape[2:]), expert_rollouts
)
expert_data = Data.from_pytree(expert_pairs)

In [6]:
from stanza.nets.mlp import MLP

# Flatten one example action and state to learn the flat sizes; the unravel
# functions are kept (action_uf rebuilds structured actions in the policy).
example_action_tree = env.sample_action(PRNGKey(0))
example_state_tree = env.sample_state(PRNGKey(0))
action_flat, action_uf = jax.flatten_util.ravel_pytree(example_action_tree)
state_flat, state_uf = jax.flatten_util.ravel_pytree(example_state_tree)
# Presumably layer widths: 16-16-8 hidden, output sized to one flat action.
model = MLP([16, 16, 8, action_flat.shape[0]])

def loss_fn(_, params, rng_key, sample):
    """Per-sample behaviour-cloning loss: squared error to the expert action.

    Args:
        _: unused leading argument; `None` is returned in its place
           (presumably a function-state slot in stanza's loss convention).
        params: MLP parameters for `model.apply`.
        rng_key: PRNG key (unused; the model is deterministic).
        sample: `(state, action)` pytree pair from the expert dataset.

    Returns:
        `(None, loss, {'loss': loss})` where `loss` is the summed squared
        error between the flattened expert action and the model prediction.

    Raises:
        ValueError: if the prediction and target shapes differ, e.g. when the
            dataset yields whole trajectories instead of single steps. The
            check runs at trace time because shapes are static.
    """
    x, y = sample
    y_flat, _ = jax.flatten_util.ravel_pytree(y)
    x_flat, _ = jax.flatten_util.ravel_pytree(x)
    a_flat = model.apply(params, x_flat)
    # Guard against silent broadcasting (e.g. a size-1 output vs. many targets).
    if y_flat.shape != a_flat.shape:
        raise ValueError(
            f"target action shape {y_flat.shape} does not match "
            f"model output shape {a_flat.shape}"
        )
    loss = jnp.sum(jnp.square(y_flat - a_flat))
    stats = {'loss': loss}
    return None, loss, stats

def model_policy(params, input):
    """Learned policy: flatten the observation, run the MLP, unflatten the action.

    Returns a `policies.PolicyOutput` wrapping the structured action rebuilt
    with `action_uf`.
    """
    obs_vec = jax.flatten_util.ravel_pytree(input.observation)[0]
    structured_action = action_uf(model.apply(params, obs_vec))
    return policies.PolicyOutput(structured_action)


In [7]:
from stanza.reporting.wandb import WandbDatabase

# wandb project ("entity/project") that receives this run.
WANDB_PROJECT = "dpfrommer-projects/quadrotor_mlp"
db = WandbDatabase(WANDB_PROJECT).create()
logger.info(f"Logging to [blue]{db.name}[/blue]")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaxsimchowitz92[0m ([33mdpfrommer-projects[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from stanza.train import Trainer, batch_loss
from stanza.train.validate import Validator
from stanza.util.loop import every_kth_iteration, every_iteration, LoggerHook
from stanza.util.rich import ConsoleDisplay, StatisticsTable, LoopProgress
from stanza.reporting.jax import JaxDBScope

import optax

iterations = 10000
# AdamW whose learning rate cosine-decays from 1e-3 over `iterations` steps.
optimizer = optax.adamw(optax.cosine_decay_schedule(1e-3, iterations), weight_decay=1e-4)

# Rich console output (statistics table + progress bar) every 100 iterations.
# Named `console_display` so it does not shadow IPython's built-in `display`.
console_display = ConsoleDisplay()
console_display.add("train", StatisticsTable(), interval=100)
console_display.add("train", LoopProgress(), interval=100)

# TODO: validate on held-out expert rollouts, e.g. with
# stanza.train.validate.Validator.

# Hand the wandb database created earlier to the scope; it was previously
# built with no arguments, leaving `db` unused. NOTE(review): confirm that
# JaxDBScope takes the target database as its first argument.
dbs = JaxDBScope(db)

with console_display as display_handle, dbs as dbs_handle:
    logger_hook = LoggerHook(every_kth_iteration(100))
    db_logger_hook = dbs_handle.statistic_logging_hook(log_cond=every_kth_iteration(1), buffer=100)
    trainer = Trainer(
        loss_fn=batch_loss(loss_fn),
        batch_size=128,
        optimizer=optimizer,
        # Stop where the LR schedule ends instead of leaving the run length
        # unrelated to `iterations`. NOTE(review): confirm the field name.
        max_iterations=iterations,
        train_hooks=[db_logger_hook, logger_hook, display_handle.train],
    )

    logger.info("Initializing model...")
    init_params = model.init(next(rng), state_flat)
    logger.info("Training...")
    # Previously `init_params` was computed but never passed, and no training
    # key was supplied. NOTE(review): assumes train(dataset, rng_key,
    # init_params, ...); it may return a results object (e.g. with
    # `.fn_params`) rather than raw params -- confirm.
    params = trainer.train(expert_data, next(rng), init_params, jit=True)
