In [1]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

In [2]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),"..","projects")))
print(sys.path)


['/Users/msimchowitz1/Documents/code/stanza/notebooks', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python311.zip', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11', '/usr/local/Cellar/python@3.11/3.11.3/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload', '', '/Users/msimchowitz1/Documents/code/stanza/.venv/lib/python3.11/site-packages', '/Users/msimchowitz1/Documents/code/stanza', '/Users/msimchowitz1/Documents/code/stanza/projects']


In [3]:
import jax.numpy as jnp
import jax
from jax.random import PRNGKey
from stanza.util.random import PRNGSequence

In [4]:
# first step: generate expert trajectories 
import stanza.envs as envs
import stanza
import stanza.policies as policies
from stanza.policies.mpc import MPC
from stanza.solver.ilqr import iLQRSolver
from stanza.util.logging import logger
my_horizon = 50
logger.info("Creating environment")
env = envs.create("pendulum")
my_key = PRNGSequence(PRNGKey(42))
#ilQR solver
solver_t = iLQRSolver()
expert_policy=MPC(
            # Sample action
            action_sample=env.sample_action(PRNGKey(0)),
            cost_fn=env.cost, 
            model_fn=env.step,
            horizon_length=my_horizon,
            solver=solver_t,
            receed=False
        )

def rollout_policy(rng_key, my_pol):
    # random init angle and angular velocity
    x_0 = env.reset(rng_key) 
    roll = policies.rollout(model = env.step,
                     state0 = x_0,
                     policy = my_pol,
                     length = my_horizon,
                     last_state = False)
    
    return roll.states, roll.actions


def batch_roll(rng_key, num_t, my_pol):
    roll_fun = jax.vmap(rollout_policy,in_axes=(0,None))
    rng_keys = jax.random.split(rng_key,num_t)
    return roll_fun(rng_keys,my_pol)



In [5]:
#{states,actions}
#rollout_expert(my_key)

from stanza.data import Data
from stanza.rl_tools.flax_models import Batch as Flax_Batch
from stanza.util import vmap_ravel_pytree
num_trajs = 100
exp_states, exp_actions = batch_roll(rng_key=next(my_key), 
                    num_t= num_trajs, my_pol = expert_policy )

#reminder "x" here is "angle" or "velocity"
print(jax.tree_map(lambda x: x.shape, exp_states))

#store the inital states
init_states = jax.tree_map(lambda x: x[:,0],exp_states)
print(jax.tree_map(lambda x: x.shape, init_states))

def tree_reshaper(x):
    return x.reshape((-1,) + x.shape[2:]) 

#flattened to comprise data set 
flat_states = jax.tree_map(lambda x: tree_reshaper(x),exp_states)
flat_actions = jax.tree_map(lambda x: tree_reshaper(x),exp_actions)

#make a data_set
dataset = Data.from_pytree((flat_states,flat_actions))
my_dataset = dataset.shuffle(next(my_key))

#note: nonstochastic environment here buddy
exp_next_states = jax.vmap(env.step)(exp_states,exp_actions,None)
exp_rewards = -jax.vmap(env.cost)(exp_states,exp_actions)
flat_next_states = jax.tree_map(lambda x: tree_reshaper(x),exp_next_states)
flat_rewards = jax.tree_map(lambda x: tree_reshaper(x),exp_rewards)
"""
Batch = collections.namedtuple(
    'Batch',
    ['observations', 'actions', 'rewards', 'masks', 'next_observations'])
"""

rl_info_keys = {'obs','act','next','rew'}
rl_dict = {'obs:':exp_states,'act':exp_actions,'next':exp_next_states,'rew':exp_rewards}
rl_dict_flat = jax.tree_map(lambda x: tree_reshaper(x),rl_dict)
rl_dict_ravel = dict()

for key in rl_info_keys:
    rl_dict_ravel[key] = vmap_ravel_pytree(rl_dict_flat[key])[0]

rl_dataset = Data.from_pytree(Flax_Batch(rl_dict_ravel['obs'],
                                         rl_dict_ravel['act'],
                                         rl_dict_ravel['next'],
                                           None, 
                                        rl_dict_ravel['rew']))
# could shuffle if you want.




State(angle=(100, 49), vel=(100, 49))
State(angle=(100,), vel=(100,))


In [6]:
import numpy as np

a = np.array([1,2,3])
print(a[None])

[[1 2 3]]


In [16]:
# train the RL stuff
from stanza.rl_tools.iql_learner import Learner
from typing import Dict
import tqdm
from absl import app, flags
from tensorboardX import SummaryWriter
import flax.linen as nn


flag_dict = {'log_interval': 1000, 'eval_interval': 5000, 
             'batch_size': 20,  'max_steps': int(1e6),
             'eval_episodes': 10,
             'tqdm': True,
             'seed': 42}




def iql_init():
    
    sample_action = env.sample_action(PRNGKey(0))
    sample_obs = env.sample_state(PRNGKey(0))
    ravel_action, _ = jax.flatten_util.ravel_pytree(sample_action)
    ravel_obs, _ = jax.flatten_util.ravel_pytree(sample_obs)

    agent = Learner(seed,
                    ravel_action[np.newaxis],
                    ravel_obs[np.newaxis],
                    max_steps=max_steps)

    summary_writer = None#
    #SummaryWriter(os.path.join(FLAGS.save_dir, 'tb',
                                                   # str(FLAGS.seed)),
                                   # write_to_disk=True) 
    #os.makedirs(FLAGS.save_dir, exist_ok=True)
    return agent, summary_writer




#TODO modify evaluate for your environments
def evaluate(agent: nn.Module, env: envs.Environment,
             num_episodes: int, traj_length: int) -> Dict[str, float]:
    """stats = {'return': [], 'length': []}

    for _ in range(num_episodes):
        
        observation, done = env.reset(), False

        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _, done, info = env.step(action)

        for k in stats.keys():
            stats[k].append(info['episode'][k])

    for k, v in stats.items():
        stats[k] = np.mean(v)

    return stats"""
    return 0

_,_ = iql_init()
#turn actions into obs/action

#TODO this should be implemted somewhere in stanza    
def sample_a_batch(a_dataset: Data, batchsize : int, key : PRNGKey):
    inds = jax.random.randint(key,minval = 0,maxval = a_dataset.length, shape = (batchsize,))
    start = a_dataset.start # iterator at the start
    batch_advance = jax.vmap(a_dataset.advance, in_axes=(None,0))
    iterators = batch_advance(start,inds)
    batch_get = jax.vmap(a_dataset.get)
    return batch_get(iterators)

def to_Flax_Batch(a_batch: Data):
    return Flax_Batch()


In [19]:
def train_iql(a_dataset : Data, batchsize: int, key: PRNGKey):
    
    eval_returns = []
    agent, summary_writer = iql_init() 
    for i in tqdm.tqdm(range(1, flag_dict['max_steps'] + 1)):
        #sample from environment
        key, subkey = jax.random.split(key)
        a_batch = sample_a_batch(a_dataset=a_dataset,batchsize=batchsize,key=subkey)
        print('hi')
        print(a_batch)
        update_info = agent.update(a_batch)

        if i % flag_dict['log_interval'] == 0:
            for k, v in update_info.items():
                if v.ndim == 0:
                    print(f'training/{k}', v, i)
                    #summary_writer.add_scalar(f'training/{k}', v, i)
                #else:
                    #print()
                    #summary_writer.add_histogram(f'training/{k}', v, i)
            #summary_writer.flush()

        if False: #i % FLAGS.eval_interval == 0:
            eval_stats = evaluate(agent, env, FLAGS.eval_episodes)

            for k, v in eval_stats.items():
                summary_writer.add_scalar(f'evaluation/average_{k}s', v, i)
            summary_writer.flush()

            eval_returns.append((i, eval_stats['return']))
            np.savetxt(os.path.join(FLAGS.save_dir, f'{FLAGS.seed}.txt'),
                       eval_returns,
                       fmt=['%d', '%.1f'])
    

train_iql(a_dataset=rl_dataset , batchsize = 20, key = PRNGKey(42))
                   


  0%|          | 0/1000000 [00:00<?, ?it/s]

hi
Batch(observations=State(angle=Array([ 1.7929245 ,  3.1420758 ,  3.1420658 ,  3.1412294 ,  3.1419492 ,
        3.1353874 ,  2.7189033 ,  1.70531   ,  2.273899  ,  3.1417966 ,
        3.1013107 ,  2.8646424 ,  3.1421866 ,  3.141653  ,  2.3994043 ,
        3.132509  ,  2.643619  , -0.13768749,  1.9886829 ,  2.820491  ],      dtype=float32), vel=Array([ 9.5495671e-01, -1.6166980e-04, -2.4479997e-04, -2.7765523e-04,
       -1.9456619e-04,  9.7514940e-03,  4.3176144e-01,  1.1603546e+00,
        7.9022473e-01, -1.1147910e-04,  5.1758386e-02,  2.7848008e-01,
       -1.9025607e-05, -1.8535415e-04,  6.4511955e-01,  1.3504710e-02,
        4.7808081e-01,  1.2592314e+00,  9.7278601e-01,  3.1426808e-01],      dtype=float32)), actions=Array([-2.5236082e-01, -2.2154899e-04, -3.2142874e-05,  1.9770727e-04,
        2.9018474e-06, -1.1724297e-02, -3.3656341e-01, -5.5576348e-01,
       -5.0384367e-01,  7.3714182e-06, -5.3151295e-02, -2.1151087e-01,
       -6.4755173e-04,  4.9462837e-05, -3.6843044e-01




TypeError: concatenate requires ndarray or scalar arguments, got <class 'stanza.envs.pendulum.State'> at position 0.

In [10]:
# making a net
import haiku as hk 
env_dim  = 1


#TODO: add to stanza.util "dumb max utils" 
#stanza.util.fluffy_dog
sample_action = env.sample_action(PRNGKey(0))
sample_state = env.sample_state(PRNGKey(0))

action_flat, action_unflatten = \
    jax.flatten_util.ravel_pytree(sample_action)
state_flat, state_unflatten = \
    jax.flatten_util.ravel_pytree(sample_state)

# 10 x 10 x 10 inner layer
def net(x):
    x_flat,_ =  jax.flatten_util.ravel_pytree(x)
    net = hk.nets.MLP((10,10,10,action_flat.shape[0]))
    y = net(x_flat)
    return action_unflatten(y)

hk_net = hk.transform(net)
params = hk_net.init(next(my_key), sample_state)



In [11]:
import optax 

optimizer = optax.adamw(optax.cosine_decay_schedule(1e-3, 5000*10), 
                        weight_decay=1e-6)

def loss_fn(params, rng_key, sample):
    x, y = sample
    out = hk_net.apply(params, rng_key, x)
    dif = jax.tree_map(lambda a,b:a-b, out, y)
    flat_dif, _ = jax.flatten_util.ravel_pytree(dif)

    #note the sum is trivial for 1d actions
    loss = jnp.sum(jnp.square(flat_dif))
    stats = {
        "loss": loss
    }
    return loss, stats

from stanza import Partial
from stanza.train import Trainer
from stanza.train.rich import RichReporter

# uses with the reporter only in this block
with RichReporter(iter_interval=50) as cb:
        trainer = Trainer(epochs=300, batch_size=30, optimizer=optimizer)
        res = trainer.train(
            Partial(loss_fn), my_dataset,
            PRNGKey(42), mlp_params,
            hooks=[cb], jit=True
        )





Output()

In [13]:
"""
train_params = res.fn_params
from stanza.policies import PolicyOutput
#maps state to action
def trained_policy(x):
    action = hk_net.apply(train_params, None, x.observation)
    return PolicyOutput(action)



trained_states, trained_actions = batch_roll(rng_key=next(my_key), 
                    num_t= num_trajs, my_pol = trained_policy )


#final_states = jax.tree_map(lambda x: x[:,my_horizon-1])

def average_loss(states,actions):
    cost_v= jax.vmap(env.cost)
    return jnp.mean(cost_v(states,actions))

print("loss on trained:")
print(average_loss(trained_states,trained_actions))
print("loss on expert:")
print(average_loss(exp_states,exp_actions))


def render_video(states,traj_number = 0):
    render_traj = jax.vmap(env.render)
    video = render_traj(jax.tree_map(lambda x: x[traj_number] , states))
    video = (255 * video).astype(jnp.uint8)
    return video

import ffmpegio
from IPython.display import Video
fps = 10
trained_vid = render_video(trained_states)
trained_file_name = "tained_policy_video.mp4"
ffmpegio.video.write(trained_file_name,
                     fps,trained_vid,
                     overwrite = True, loglevel = "quiet")
Video(trained_file_name,embed = True)
"""

loss on trained:
187.44075
loss on expert:
178.18272


In [None]:
#learn value from imitation:


