In [5]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),"..","projects")))
print(sys.path)


['/Users/msimchowitz1/Documents/code/stanza/notebooks', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python311.zip', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload', '', '/Users/msimchowitz1/Documents/code/stanza/.venv/lib/python3.11/site-packages', '/Users/msimchowitz1/Documents/code/stanza', '/Users/msimchowitz1/Documents/code/stanza/projects', '/Users/msimchowitz1/Documents/code/stanza/projects']


In [7]:
import jax.numpy as jnp
import jax
from jax.random import PRNGKey
from stanza.util.random import PRNGSequence

In [8]:
# first step: generate expert trajectories 
import stanza.envs as envs
import stanza
import stanza.policies as policies
from stanza.policies.mpc import MPC
from stanza.solver.ilqr import iLQRSolver
from stanza.util.logging import logger
my_horizon = 50
logger.info("Creating environment")
env = envs.create("pendulum")
my_key = PRNGSequence(PRNGKey(42))
#ilQR solver
solver_t = iLQRSolver()
expert_policy=MPC(
            # Sample action
            action_sample=env.sample_action(PRNGKey(0)),
            cost_fn=env.cost, 
            model_fn=env.step,
            horizon_length=my_horizon,
            solver=solver_t,
            receed=False
        )

def rollout_policy(rng_key, my_pol):
    # random init angle and angular velocity
    x_0 = env.reset(rng_key) 
    roll = policies.rollout(model = env.step,
                     state0 = x_0,
                     policy = my_pol,
                     length = my_horizon,
                     last_state = False)
    
    return roll.states, roll.actions


def batch_roll(rng_key, num_t, my_pol):
    roll_fun = jax.vmap(rollout_policy,in_axes=(0,None))
    rng_keys = jax.random.split(rng_key,num_t)
    return roll_fun(rng_keys,my_pol)



In [9]:
#{states,actions}
#rollout_expert(my_key)

from stanza.data import Data

num_trajs = 100
exp_states, exp_actions = batch_roll(rng_key=next(my_key), 
                    num_t= num_trajs, my_pol = expert_policy )

#reminder "x" here is "angle" or "velocity"
print(jax.tree_map(lambda x: x.shape, exp_states))

#store the inital states
init_states = jax.tree_map(lambda x: x[:,0],exp_states)
print(jax.tree_map(lambda x: x.shape, init_states))


#flattened to comprise data set 
flat_states = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]),exp_states)
flat_actions = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]),exp_actions)

#make a data_set
dataset = Data.from_pytree((flat_states,flat_actions))
my_dataset = dataset.shuffle(next(my_key))






State(angle=(100, 49), vel=(100, 49))
State(angle=(100,), vel=(100,))


In [10]:
# making a net
import haiku as hk 
env_dim  = 1


#TODO: add to stanza.util "dumb max utils" 
#stanza.util.fluffy_dog
sample_action = env.sample_action(PRNGKey(0))
sample_state = env.sample_state(PRNGKey(0))

action_flat, action_unflatten = \
    jax.flatten_util.ravel_pytree(sample_action)
state_flat, state_unflatten = \
    jax.flatten_util.ravel_pytree(sample_state)

# 10 x 10 x 10 inner layer
def net(x):
    x_flat,_ =  jax.flatten_util.ravel_pytree(x)
    net = hk.nets.MLP((10,10,10,action_flat.shape[0]))
    y = net(x_flat)
    return action_unflatten(y)

hk_net = hk.transform(net)
mlp_params = hk_net.init(next(my_key), sample_state)



In [11]:
import optax 

optimizer = optax.adamw(optax.cosine_decay_schedule(1e-3, 5000*10), 
                        weight_decay=1e-6)

def loss_fn(params, rng_key, sample):
    x, y = sample
    out = hk_net.apply(params, rng_key, x)
    dif = jax.tree_map(lambda a,b:a-b, out, y)
    flat_dif, _ = jax.flatten_util.ravel_pytree(dif)

    #note the sum is trivial for 1d actions
    loss = jnp.sum(jnp.square(flat_dif))
    stats = {
        "loss": loss
    }
    return loss, stats

from stanza import Partial
from stanza.train import Trainer
from stanza.train.rich import RichReporter

# uses with the reporter only in this block
with RichReporter(iter_interval=50) as cb:
        trainer = Trainer(epochs=300, batch_size=30, optimizer=optimizer)
        res = trainer.train(
            Partial(loss_fn), my_dataset,
            PRNGKey(42), mlp_params,
            hooks=[cb], jit=True
        )





Output()

In [12]:
train_params = res.fn_params
from stanza.policies import PolicyOutput
#maps state to action
def trained_policy(x):
    action = hk_net.apply(train_params, None, x.observation)
    return PolicyOutput(action)



trained_states, trained_actions = batch_roll(rng_key=next(my_key), 
                    num_t= num_trajs, my_pol = trained_policy )


#final_states = jax.tree_map(lambda x: x[:,my_horizon-1])

def average_loss(states,actions):
    cost_v= jax.vmap(env.cost)
    return jnp.mean(cost_v(states,actions))

print("loss on trained:")
print(average_loss(trained_states,trained_actions))
print("loss on expert:")
print(average_loss(exp_states,exp_actions))


def render_video(states,traj_number = 0):
    render_traj = jax.vmap(env.render)
    video = render_traj(jax.tree_map(lambda x: x[traj_number] , states))
    video = (255 * video).astype(jnp.uint8)
    return video

import ffmpegio
from IPython.display import Video
fps = 10
trained_vid = render_video(trained_states)
trained_file_name = "tained_policy_video.mp4"
ffmpegio.video.write(trained_file_name,
                     fps,trained_vid,
                     overwrite = True, loglevel = "quiet")
Video.from_file(trained_file_name,embed = True)

loss on trained:
184.24336
loss on expert:
178.18272


AttributeError: type object 'Video' has no attribute 'from_file'