In [36]:
# Reload edited project modules automatically before each cell runs
# (autoreload mode 2), so changes to stanza sources are picked up live.
%load_ext autoreload
%autoreload 2
# Exclude jax and jaxlib from autoreloading; only project code is expected to
# change during a session (reloading jax itself is presumably unsafe — TODO confirm).
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
import sys
import os

# Make the sibling `projects/` directory importable. Only append it once, so
# re-running this cell does not keep growing sys.path with duplicate entries.
PROJECTS_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", "projects"))
if PROJECTS_DIR not in sys.path:
    sys.path.append(PROJECTS_DIR)
print(sys.path)


['/Users/daniel/Documents/code/stable_imitation/stanza/notebooks', '/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python311.zip', '/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11', '/opt/homebrew/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload', '', '/Users/daniel/Documents/code/stable_imitation/stanza/.venv/lib/python3.11/site-packages', '/Users/daniel/Documents/code/stable_imitation/stanza', '/Users/daniel/Documents/code/stable_imitation/stanza/projects', '/Users/daniel/Documents/code/stable_imitation/stanza/projects']


In [38]:
import jax.numpy as jnp
import jax
from jax.random import PRNGKey
from stanza.util.random import PRNGSequence

In [39]:
# first step: generate expert trajectories 
import stanza.envs as envs
import stanza
import stanza.policies as policies
from stanza.policies.mpc import MPC
from stanza.solver.ilqr import iLQRSolver
from stanza.util.logging import logger
# Length of each expert rollout; also used as the MPC planning horizon.
my_horizon = 50

logger.info("Creating environment")
env = envs.create("pendulum")

# Key stream feeding every later stochastic step (rollouts, shuffling).
my_key = PRNGSequence(PRNGKey(42))

# Expert controller: MPC backed by an iLQR solver, planning over `my_horizon`
# steps. receed=False presumably disables receding-horizon re-planning — TODO confirm.
solver_t = iLQRSolver()
expert_policy = MPC(
    # One sampled action, presumably used only as a structural template.
    action_sample=env.sample_action(PRNGKey(0)),
    cost_fn=env.cost,
    model_fn=env.step,
    horizon_length=my_horizon,
    solver=solver_t,
    receed=False,
)

def rollout_policy(rng_key, my_pol):
    """Roll out ``my_pol`` in ``env`` for ``my_horizon`` steps.

    The initial state comes from ``env.reset(rng_key)`` (for pendulum,
    presumably a random angle / angular velocity). The final state is not
    appended (``last_state=False``).

    Args:
        rng_key: JAX PRNG key used to reset the environment.
        my_pol: policy passed through to ``policies.rollout``.

    Returns:
        Tuple ``(states, actions)`` taken from the rollout result.
    """
    initial_state = env.reset(rng_key)
    trajectory = policies.rollout(
        model=env.step,
        state0=initial_state,
        policy=my_pol,
        length=my_horizon,
        last_state=False,
    )
    return trajectory.states, trajectory.actions

def batch_roll(rng_key, num_t, my_pol):
    """Run ``num_t`` independent rollouts of ``my_pol`` in parallel.

    ``rng_key`` is split into ``num_t`` keys and ``rollout_policy`` is
    vmapped over them; the policy is shared (not mapped) across the batch.

    Returns:
        ``(states, actions)`` from ``rollout_policy`` with a leading axis of
        size ``num_t`` on every leaf.
    """
    per_rollout_keys = jax.random.split(rng_key, num_t)
    batched_rollout = jax.vmap(rollout_policy, in_axes=(0, None))
    return batched_rollout(per_rollout_keys, my_pol)

In [None]:
from stanza.data import Data, PyTreeData
from stanza.rl_tools.flax_models import Batch
# Collect expert demonstrations: `num_trajs` rollouts of the MPC expert.
num_trajs = 100
exp_states, exp_actions = batch_roll(
    rng_key=next(my_key),
    num_t=num_trajs,
    my_pol=expert_policy,
)

def tree_reshaper(x):
    """Merge the two leading axes of ``x`` into a single axis.

    An array of shape ``(A, B, *rest)`` becomes ``(A * B, *rest)``. Here it
    presumably turns ``(trajectory, time, ...)`` rollout leaves into a flat
    batch of per-step samples.
    """
    trailing_dims = tuple(x.shape[2:])
    return x.reshape((-1, *trailing_dims))
# Collapse the (trajectory, time, ...) leading axes of every leaf into one
# flat batch of per-step samples. `jax.tree_util.tree_map` replaces the
# deprecated `jax.tree_map` alias.
flat_states = jax.tree_util.tree_map(tree_reshaper, exp_states)
flat_actions = jax.tree_util.tree_map(tree_reshaper, exp_actions)
# Next states are recomputed with the model rather than read off the
# trajectory, so the final step of each rollout also gets a successor.
# NOTE(review): None is passed as env.step's third argument (presumably the
# RNG key) — assumes the dynamics are deterministic; confirm.
flat_next_states = jax.vmap(env.step)(flat_states, flat_actions, None)
# Reward is the negated environment cost.
# NOTE(review): env.cost is also used as the MPC trajectory cost (cost_fn);
# confirm it is valid on a single (state, action) pair when vmapped here.
flat_rewards = -jax.vmap(env.cost)(flat_states, flat_actions)

dataset = Data.from_pytree((flat_states, flat_actions,
                            flat_next_states, flat_rewards))
my_dataset = dataset.shuffle(next(my_key))
def data_map(x):
    """Convert one ``(state, action, next_state, reward)`` sample to a ``Batch``.

    State, action and next state pytrees are flattened to 1-D vectors with
    ``ravel_pytree``; the reward is passed through unchanged and an all-True
    mask of the reward's shape is attached.

    Args:
        x: tuple ``(s, a, ns, r)`` as stored in ``my_dataset``.

    Returns:
        ``Batch`` built positionally from
        (flat state, flat action, reward, mask, flat next state).
    """
    # A bare `import jax` does not guarantee the `jax.flatten_util` submodule
    # is loaded, so import it explicitly instead of relying on side effects.
    from jax.flatten_util import ravel_pytree

    s, a, ns, r = x
    f_s, _ = ravel_pytree(s)
    f_a, _ = ravel_pytree(a)
    f_ns, _ = ravel_pytree(ns)
    # Every transition is marked valid; presumably there are no terminal
    # states in these fixed-length rollouts.
    # NOTE(review): mask is bool — confirm the learner's expected dtype.
    mask = jnp.ones_like(r, dtype=bool)
    return Batch(f_s, f_a, r, mask, f_ns)


# Materialize the mapped samples as PyTreeData so the training cell can index
# it (`get`, `start`) and draw random batches (`sample_batch`).
rl_dataset = PyTreeData.from_data(my_dataset.map(data_map))

In [43]:
# train the RL stuff
from stanza.rl_tools.iql_learner import Learner
from typing import Dict
from stanza.util.attrdict import AttrDict
import stanza.util as util

from rich.progress import track

FLAGS = AttrDict({
    'log_interval': 1000, 'eval_interval': 5000,
    'batch_size': 20,  'max_steps': int(1e6),
    'eval_episodes': 10, 'tqdm': True,
    'seed': 42
})
# NOTE(review): log_interval, eval_interval, eval_episodes and tqdm are not
# read anywhere below — TODO wire up logging/evaluation or drop them.

# One example with a leading batch axis of size 1 initializes the learner
# (presumably so it can infer observation/action shapes).
sample = rl_dataset.get(rl_dataset.start)

agent = Learner(FLAGS.seed,
                sample.observations[jnp.newaxis],
                sample.actions[jnp.newaxis],
                max_steps=FLAGS.max_steps)

# Seed the batch-sampling key stream from FLAGS.seed instead of repeating the
# literal, so changing the seed in one place changes the whole run.
rng = PRNGSequence(FLAGS.seed)
for _ in track(range(FLAGS.max_steps)):
    # Fresh random minibatch each step; no pre-loop sample is needed.
    batch = rl_dataset.sample_batch(FLAGS.batch_size, next(rng))
    agent.update(batch)

Output()