In [28]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import stanza.envs as envs
import stanza.policies as policies
import optax
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from stanza import Partial
from stanza.rl.ppo import PPO
from stanza.train import Trainer
from stanza.rl import EpisodicEnvironment, ACPolicy
from stanza.rl.nets import MLPActorCritic
from stanza.util.rich import StatisticsTable, ConsoleDisplay, LoopProgress
from stanza.solver.ilqr import iLQRSolver
from stanza.util.random import PRNGSequence


In [31]:
from stanza.util.logging import logger
from stanza.policies.mpc import MPC
from stanza.data.trajectory import Timestep
from stanza.data import Data
env = envs.create("pendulum")

# Expert controller: MPC over a 50-step horizon that plans with the
# environment's own dynamics (env.step) and cost (env.cost), solved by iLQR.
# NOTE(review): an earlier comment here claimed auto-reset after 1000 steps;
# that matches the EpisodicEnvironment(..., 1000) wrapper used further down,
# not this raw env — confirm.
solver_t = iLQRSolver()
expert_policy = MPC(
    # example action, presumably used by MPC to infer the action structure
    action_sample=env.sample_action(PRNGKey(42)),
    cost_fn=env.cost,
    model_fn=env.step,
    horizon_length=50,
    solver=solver_t,
    # NOTE(review): presumably disables receding-horizon re-planning — confirm
    receed=False,
)

def rollout_mpc(key: PRNGKey):
    """Roll out the MPC expert for 50 steps from an initial state drawn with ``key``.

    Args:
        key: PRNG key passed to ``env.reset`` to sample the initial state.

    Returns:
        A stanza ``Data`` object built from a ``Timestep`` pytree holding the
        rollout's states and actions.
    """
    initial_state = env.reset(key)
    expert_rollout = policies.rollout(
        model=env.step,
        state0=initial_state,
        length=50,
        policy=expert_policy,
    )
    timesteps = Timestep(expert_rollout.states, expert_rollout.actions)
    # Convert the Python/JAX pytree into a stanza Data container.
    return Data.from_pytree(timesteps)


num_trajs = 100


def batch_roll(rng_key, num_t):
    """Run ``num_t`` independent MPC expert rollouts in parallel.

    ``rng_key`` is split into ``num_t`` keys and ``rollout_mpc`` is vmapped
    over them, so every rollout starts from its own sampled initial state.
    """
    per_traj_keys = jax.random.split(rng_key, num_t)
    return jax.vmap(rollout_mpc)(per_traj_keys)


# Expert demonstrations that the goal sampler below draws trajectories from.
expert_data = Data.from_pytree(batch_roll(PRNGKey(42), num_trajs))

In [32]:
from stanza.goal_conditioned.roll_in_sampler import roll_in_sampler
from stanza.envs import Environment
from stanza.goal_conditioned import GCState, StartEndGoal
import chex
from stanza.data.trajectory import Timestep

# Noise models handed to roll_in_sampler; presumably None means "no noise" — confirm.
action_noiser = None
process_noiser = None


def gsa_sampler(key: PRNGKey, encode_start = False):
    """Sample a goal-conditioned (state, action) pair from the expert data.

    Picks a random expert trajectory, a goal offset ``delta_t`` in [3, 8),
    a roll-in length ``roll_len`` in [3, 8), and a start time ``start_t`` in
    [roll_len, traj_len - delta_t). The start state/action come from
    ``roll_in_sampler`` targeting ``start_t``. The goal end state is the
    expert observation at ``start_t + delta_t``.

    Args:
        key: PRNG key from which all randomness in this call is derived.
        encode_start: if True, the goal also records the start state.

    Returns:
        ``Timestep(observation=GCState(goal, env_state=start_state), action=start_action)``.
    """
    rng = PRNGSequence(key)
    # Keys are drawn in the order they are consumed below.
    traj_key = next(rng)
    delta_key = next(rng)
    roll_key = next(rng)
    start_key = next(rng)
    noise_key = next(rng)
    env_key = next(rng)

    rand_traj = expert_data.sample(traj_key)
    traj_len = rand_traj.length
    delta_t = jax.random.randint(delta_key, (), minval=3, maxval=8)
    roll_len = jax.random.randint(roll_key, (), minval=3, maxval=8)
    # randint's maxval is exclusive, so start_t + delta_t <= traj_len - 1.
    # minval=roll_len presumably leaves room for the roll-in before start_t.
    start_t = jax.random.randint(start_key, (), minval=roll_len,
                                 maxval=traj_len - delta_t)

    start_state, start_action = roll_in_sampler(
        traj=rand_traj,
        target_time=start_t,
        noise_rng_key=noise_key,
        roll_len=roll_len,
        env=env,
        env_rng_key=env_key,
        action_noiser=action_noiser,
        process_noiser=process_noiser,
    )

    end_state = rand_traj.get(start_t + delta_t).observation
    goal = StartEndGoal(start_state=start_state if encode_start else None,
                        end_state=end_state)
    return Timestep(observation=GCState(goal=goal, env_state=start_state),
                    action=start_action)

def gs_sampler(key):
    """Sample a goal-conditioned state (no start-state encoding); passed to GCEnvironment."""
    return gsa_sampler(key, encode_start=False).observation


my_gc_state = gs_sampler(PRNGKey(42))
print(my_gc_state)


GCState(goal=StartEndGoal(start_state=None, end_state=State(angle=Array(3.1417477, dtype=float32), vel=Array(0.00079618, dtype=float32))), env_state=State(angle=Array(3.139656, dtype=float32), vel=Array(0.00422384, dtype=float32)))


In [33]:
import math


def goal_reward(state, next_state, end_state):
    """Shaped reward for progress from ``state`` to ``next_state`` towards ``end_state``.

    For each coordinate (angle, vel), the per-step change is multiplied by the
    sign of (goal - next_state): moving towards the goal is rewarded, moving
    away is penalised. The angle term is weighted by 32.
    """
    toward_angle = jnp.sign(end_state.angle - next_state.angle)
    toward_vel = jnp.sign(end_state.vel - next_state.vel)
    angle_term = 32 * (next_state.angle - state.angle) * toward_angle
    vel_term = (next_state.vel - state.vel) * toward_vel
    return angle_term + vel_term

def cost_to_goal(x, u, x_goal):
    """Quadratic cost of state(s) ``x`` and optional controls ``u`` relative to ``x_goal``.

    ``x`` and ``x_goal`` expose ``.angle`` and ``.vel``. If ``x`` holds a
    trajectory (leading time axis), all but the last state contribute a
    running cost and the last state a terminal cost. If ``x`` holds a single
    state, that state is treated as terminal.

    Args:
        x: a state, or a trajectory of states stacked along a leading axis.
        u: controls, or ``None`` for no control cost.
        x_goal: goal state (broadcast against ``x``).

    Returns:
        ``5 * terminal_cost + 2 * running_cost + sum(u**2)``.
    """
    x = jnp.stack((x.angle, x.vel), -1)
    x_goal = jnp.stack((x_goal.angle, x_goal.vel), -1)
    diff = x - x_goal
    if diff.ndim == 1:
        # Single state: diff is (angle, vel). Slicing [:-1]/[-1] here would
        # wrongly treat the angle as "running" and the velocity as "terminal".
        x_cost = 0.0
        xf_cost = jnp.sum(diff**2)
    else:
        x_cost = jnp.sum(diff[:-1]**2)
        xf_cost = jnp.sum(diff[-1]**2)
    # `is None`, not `== None`: array == None is elementwise in numpy and
    # ambiguous inside an `if`.
    u_cost = 0 if u is None else jnp.sum(u**2)
    return 5*xf_cost + 2*x_cost + u_cost

def gc_reward(gc_state, action, next_state ):
    """Goal-conditioned reward: progress of ``next_state`` towards the goal's end state.

    ``action`` is unused; it is presumably part of the reward signature that
    GCEnvironment expects.
    """
    # Disabled alternative (cost-based reward):
    #   3 - cost_to_goal(gc_state.env_state, action, gc_state.goal.end_state)
    current = gc_state.env_state
    target = gc_state.goal.end_state
    return goal_reward(current, next_state, target)

def g_done(gc_state):
    """Return True once the env state's quadratic cost to the goal end state is below 0.03**2."""
    cost = cost_to_goal(x=gc_state.env_state, u=None,
                        x_goal=gc_state.goal.end_state)
    return cost < .03 * .03



In [34]:
from stanza.goal_conditioned import GCEnvironment
# Goal-conditioned pendulum: GC states presumably sampled via gs_sampler,
# reward from gc_reward, termination from g_done.
gc_pendulum_env = GCEnvironment(
    env=env,
    gs_sampler=gs_sampler,
    gc_reward=gc_reward,
    g_done=g_done,
)

In [51]:
#Set up net and env
ep_env = EpisodicEnvironment(gc_pendulum_env, 1000)
from stanza.rl.nets import transform_ac_to_mean
from stanza.goal_conditioned.bilevel_policy import make_trivial_bi_policy

net = MLPActorCritic(
    ep_env.sample_action(PRNGKey(0))
)
init_params = net.init(PRNGKey(42),
    ep_env.observe(ep_env.sample_state(PRNGKey(0))))

net.apply(init_params,gs_sampler(PRNGKey(0)))


ac_apply = Partial(net.apply, init_params)
policy = ACPolicy(ac_apply)
bipo = make_trivial_bi_policy(policy)
print("policy_made")

start_state = ep_env.reset(PRNGKey(42))
m_key = PRNGKey(31231)
p_key = PRNGKey(43232)
print("rolling out policy")
roll_len = 5

def print_roll_info(r):
    """Print a rollout's actions, final policy state, info dict, and per-step observations."""
    for field_value in (r.actions, r.final_policy_state, r.info):
        print(field_value)
    observe_all = jax.vmap(ep_env.observe)
    print(observe_all(r.states))


def _rollout_from_start(pol):
    """Roll out ``pol`` for ``roll_len`` steps from the shared start state and RNG keys."""
    return policies.rollout(
        ep_env.step,
        start_state,
        pol,
        model_rng_key=m_key,
        policy_rng_key=p_key,
        observe=ep_env.observe,
        length=roll_len,
    )


r = _rollout_from_start(policy)
print_roll_info(r)

print("rolling out bipo")
r = _rollout_from_start(bipo)
print("done rollout")
print_roll_info(r)


policy_made
rolling out policy
[ 1.5577074  -1.1118168   0.08042035  2.093885  ]
-3.0967412
{'log_prob': Array([-2.1214223 , -1.5447118 , -0.92164135, -3.0967412 ], dtype=float32), 'value': Array([-0.01819413, -0.01556416, -0.02678034, -0.02567974], dtype=float32)}
GCState(goal=StartEndGoal(start_state=None, end_state=State(angle=Array([3.1417477, 3.132184 , 3.132184 , 3.132184 , 3.132184 ], dtype=float32), vel=Array([0.00079618, 0.01403313, 0.01403313, 0.01403313, 0.01403313],      dtype=float32))), env_state=State(angle=Array([3.139656 , 3.1171174, 3.1235383, 3.0853243, 3.0506372], dtype=float32), vel=Array([ 0.00422384,  0.03210456, -0.19106929, -0.1734356 ,  0.24595097],      dtype=float32)))
rolling out bipo
uno
dos
tres
yay
action: GCObs(goal=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, env_obs=GCState(goal=StartEndGoal(start_state=None, end_state=State(angle=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, vel=Traced<ShapedArray(float32[])>w

In [None]:
from stanza.goal_conditioned.bilevel_policy import BilevelPolicy, make

In [None]:
# BC Pretraining
from stanza.rl.bc import BCState, BCTrainer
from stanza.rl.nets import transform_ac_to_mean

# Toggle for behavior-cloning pretraining (replaces an unnamed `if False:` guard).
# TODO: the BC result is not yet fed into the PPO cell below.
RUN_BC_PRETRAINING = False

rng_bc = PRNGKey(41)
num_bc_data = 500
# One gsa_sampler draw per key -> a batch of goal-conditioned (state, action) samples.
gc_samples = jax.vmap(gsa_sampler)(jax.random.split(PRNGKey(40), num_bc_data))
gc_data = Data.from_pytree(gc_samples)
# presumably a mean-action (deterministic) actor derived from the actor-critic net
actor_apply = transform_ac_to_mean(net.apply)

if RUN_BC_PRETRAINING:
    trainer = BCTrainer()
    result = trainer.train(ac_apply=actor_apply,
                           ac_params=init_params, dataset=gc_data,
                           rng_key=rng_bc,
                           max_iterations=100,
                           epochs=1)
    print(result)


start_state State(angle=Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([ 1.3752402 ,  3.1374502 ,  3.1254973 ,  3.0145657 ,  3.136479  ,
        2.650888  ,  1.7250981 ,  2.3411443 ,  3.1419103 ,  3.1080596 ,
        0.9580077 ,  3.1380174 ,  0.9652712 ,  3.14218   ,  3.1348963 ,
        3.0978627 ,  3.051627  ,  3.1241927 ,  1.1899076 ,  3.1360989 ,
        3.11285   ,  3.1418827 ,  1.7543236 ,  3.1256258 ,  3.135288  ,
        3.1419873 ,  2.4421604 ,  3.0656493 ,  3.1154256 ,  3.1417675 ,
        3.1419783 ,  2.733592  ,  3.1386871 ,  1.476709  ,  3.1410859 ,
        3.1418498 ,  3.1364276 ,  3.1394615 ,  3.1419108 ,  3.1182451 ,
        2.5323641 ,  3.1414995 ,  3.101314  ,  2.5007079 ,  3.1416533 ,
        3.1396437 ,  2.292484  ,  3.1320982 ,  2.5276804 ,  3.1379838 ,
        3.1417348 ,  3.1202052 ,  3.141729  ,  2.3477166 ,  3.0979772 ,
        3.1420856 ,  2.9801254 ,  3.1414688 ,  3.140793  ,  3.1287007 ,
        3.1304533 ,  3.0200248 ,  3.09148

In [None]:
# RL Training

display = ConsoleDisplay()
display.add("ppo", StatisticsTable(), interval=1)
display.add("ppo", LoopProgress("RL"), interval=1)

# PPO with global-norm gradient clipping followed by Adam.
ppo = PPO(
    trainer=Trainer(
        optimizer=optax.chain(
            optax.clip_by_global_norm(0.5),
            optax.adam(3e-4, eps=1e-5),
        )
    )
)

with display as dh:
    trained_params = ppo.train(
        PRNGKey(42),
        ep_env, net.apply,
        # `params` was never defined in this notebook (NameError on a fresh
        # kernel; it only worked via hidden kernel state). Start from the
        # network's initial parameters.
        init_params,
        rl_hooks=[dh.ppo],
    )

# Evaluate the trained policy with a 200-step rollout from a fixed start.
ac_apply = Partial(net.apply, trained_params.fn_params)
policy = ACPolicy(ac_apply)

r = policies.rollout(ep_env.step,
    ep_env.reset(PRNGKey(42)), policy,
    model_rng_key=PRNGKey(31231),
    policy_rng_key=PRNGKey(43232),
    observe=ep_env.observe,
    length=200)

print(jax.vmap(ep_env.observe)(r.states))

Output()

GCState(goal=StartEndGoal(start_state=None, end_state=State(angle=Array([3.1417477, 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.132184 ,
       3.132184 , 3.132184 , 3.132184 , 3.132184 , 3.13218