In [15]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import stanza.envs as envs
import stanza.policies as policies
import optax
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from stanza import Partial
from stanza.rl.ppo import PPO
from stanza.train import Trainer
from stanza.rl import EpisodicEnvironment, ACPolicy
from stanza.rl.nets import MLPActorCritic
from stanza.util.rich import StatisticsTable, ConsoleDisplay, LoopProgress
from stanza.solver.ilqr import iLQRSolver
from stanza.util.random import PRNGSequence
from typing import Callable, Any


In [17]:
a = jnp.array([1,2,3])
a + 1
print(jnp.add(a,1))

class Thing:
    def f(x):
        return x
    
thing = Thing()

def call(f,x):
    return f(x)
    

print(call(thing.__class__.f,1))

[2 3 4]
1


In [18]:
from stanza.util.logging import logger
from stanza.policies.mpc import MPC
from stanza.data.trajectory import Timestep
from stanza.data import Data
env = envs.create("pendulum")
# will automatically reset when done
# or when 1000 timesteps have been reached
solver_t = iLQRSolver()
expert_policy=MPC(
            # Sample action
            action_sample=env.sample_action(PRNGKey(42)),
            cost_fn=env.cost, 
            model_fn=env.step,
            horizon_length=50,
            solver=solver_t,
            receed=False
        )

def rollout_mpc(key: PRNGKey):
    # An MPC policy
    rollout = policies.rollout(
        model=env.step,
        state0=env.reset(key),
        length=50,
        policy=expert_policy
    )
    #turns from Python/jax Data class into Stanza Dataset
    return Data.from_pytree(Timestep(rollout.states,rollout.actions))
    
    #logger.info(f'MPC Rollout with {solver} solver results')
    #logger.info('states: {}', rollout.states)
    #logger.info('actions: {}', rollout.actions)
    #cost = env.cost(rollout.states, rollout.actions)
    #logger.info('cost: {}', cost)


num_trajs = 100
def batch_roll(rng_key, num_t):
    roll_fun = jax.vmap(rollout_mpc)
    rng_keys = jax.random.split(rng_key,num_t)
    return roll_fun(rng_keys)

expert_data = Data.from_pytree(batch_roll(PRNGKey(42), num_trajs))

In [19]:
from stanza.goal_conditioned.roll_in_sampler import roll_in_sample, RollInSampler
from stanza.envs import Environment
from stanza.goal_conditioned import GCState, EndGoal
import chex
from stanza.data.trajectory import Timestep
from stanza.envs.goal_conditioned_envs.gc_pendulum import make_gc_pendulum_env


action_noiser = None
process_noiser = None

sampler = RollInSampler(env=env, traj_data=expert_data)

print(sampler.sample_gc_timestep(PRNGKey(41)))
gs_sampler = (lambda key: sampler.sample_gc_state(key))
gc_pendulum_env = make_gc_pendulum_env(env, gs_sampler)


Timestep(observation=GCObs(goal=EndGoal(end_state=State(angle=Array(2.4917722, dtype=float32), vel=Array(0.6079667, dtype=float32)), other_info=None), env_obs=State(angle=Array(0.6607413, dtype=float32), vel=Array(1.4535836, dtype=float32))), action=Array(0.0572015, dtype=float32))


In [20]:
from stanza.util.ipython import display_video
# BC Pretraining    
from stanza.rl.bc import BCState, BCTrainer
from stanza.rl.nets import transform_ac_to_a


def ac_mlp_init(env, rng):
    """
    initialize actor critic model
    returns model and initial params
    """
    model = MLPActorCritic(env.sample_action(next(rng)))
    init_params = model.init(
            next(rng),
            env.observe(env.sample_state(next(rng)))
        )
    return  model, init_params

def ac_bc_train(rng,data,model,init_params,num_iters):
    """
    takes an actor-critic model 
    and trains it with behavior cloning
    """
    logger.info("Training model")
    actor_apply = transform_ac_to_a(model.apply)

    display = ConsoleDisplay()
    display.add("train", StatisticsTable(), interval=100)
    display.add("train", LoopProgress(), interval=100)
    with display as w:
        trainer = BCTrainer()
        result =  trainer.train(ac_apply=actor_apply, 
                                    ac_params = init_params, dataset=data,
                                    rng_key = next(rng),
                                    max_iterations=num_iters,
                                    hooks=[w.train])
    return result


def render_video(rollout, env, 
                        file_name : str, fps=30):
    imgs = jax.vmap(env.render)(rollout.states)
    display_video(file_name, imgs, fps=fps)
    

In [21]:
from stanza.goal_conditioned import GCHighLevelEnvironment
from stanza.envs.pendulum import State as PendulumState
import math

#my_goal = PendulumState(angle=math.pi, vel=0.)
rng = PRNGSequence(42)
data = expert_data
delta_t = 0
sampler_hl = RollInSampler(env=env, traj_data=data, delta_t_min=delta_t,
                               delta_t_max=delta_t)
gs_sampler_hl = (lambda key: sampler_hl.sample_gc_state(key))

gc_pendulum_env_hl = GCHighLevelEnvironment(gs_sampler=gs_sampler_hl, base_env=env)
num_bc_data_hl = 5


sample_bc_hl = (lambda key: sampler_hl.sample_gc_timestep_high_level(key))
gc_bc_data_hl = jax.vmap(sample_bc_hl)(jax.random.split(
        next(rng),num_bc_data_hl))
gc_bc_data_hl = Data.from_pytree(gc_bc_data_hl)
print(gc_bc_data_hl)


PyTreeData(data=Timestep(observation=State(angle=Array([3.1408978 , 3.135722  , 0.47154176, 2.1327589 , 3.140598  ],      dtype=float32), vel=Array([-2.6174914e-04,  9.1760885e-03,  1.3286523e+00,  9.2453557e-01,
        2.4871987e-03], dtype=float32)), action=EndGoal(end_state=State(angle=Array([3.1408978 , 3.135722  , 0.47154176, 2.1327589 , 3.140598  ],      dtype=float32), vel=Array([-2.6174914e-04,  9.1760885e-03,  1.3286523e+00,  9.2453557e-01,
        2.4871987e-03], dtype=float32)), other_info=None)))


In [90]:
net_hl, init_params_hl = ac_mlp_init(gc_pendulum_env_hl, rng)
num_iters_hl = 20000


result_hl = ac_bc_train(rng=rng,data=gc_bc_data_hl,model=net_hl,
                            init_params=init_params_hl,num_iters=num_iters_hl)


ac_apply_hl = Partial(net_hl.apply, result_hl.fn_params)
policy_hl = ACPolicy(ac_apply_hl,use_mean=True)



Starting Training high level BC


Output()

FrozenDict({
    params: {
        Dense_0: {
            kernel: Array([[-2.04953879e-01,  1.92737352e-04,  4.19668630e-02,
                    -5.31767681e-02, -1.99737269e-02, -2.60993596e-02,
                     7.14368597e-02, -2.35323352e-03, -1.01984359e-01,
                     6.12589158e-02,  2.64281183e-02, -4.57116961e-02,
                    -9.77786332e-02, -1.10825107e-01,  1.86495914e-03,
                     7.06536099e-02,  2.32658312e-02, -1.05940789e-01,
                     5.60421571e-02, -1.42390534e-01, -1.62041709e-01,
                    -4.77451757e-02,  9.55187902e-02, -2.98569370e-02,
                     3.03765642e-03,  2.63751447e-02,  1.25678405e-01,
                    -2.79715452e-02,  5.13940975e-02,  8.73752870e-03,
                    -4.85498793e-02, -5.06567955e-02, -5.97479418e-02,
                     3.78095247e-02, -1.32614719e-02, -3.70910577e-02,
                    -1.60377100e-02, -3.94100323e-02, -4.13296223e-02,
                    -2.

TypeError: ACPolicy.__init__() got an unexpected keyword argument 'use_mean'

In [81]:
from stanza.policies import PolicyInput


rng = PRNGSequence(52)
sample_obs = gc_pendulum_env_hl.observe(gc_pendulum_env_hl.sample_state(next(rng)))
print(ac_apply_hl(sample_obs))
print(sample_obs)
my_input = PolicyInput(observation=sample_obs, rng_key=next(rng))
print(my_input)
print(policy_hl(my_input))
print("Done Training high level BC")

print("Rollout out high level BC")

r = policies.rollout(gc_pendulum_env_hl.step, 
        env.reset(next(rng)), policy_hl, 
        model_rng_key=next(rng),
        policy_rng_key=next(rng),
        observe=env.observe,
        length=30)

print(r.states)

if False:
        print("displaying video?")
        render_video(rollout=r,env=gc_pendulum_env_hl, 
                      file_name='file_thing', fps=30)

(MultivariateNormalDiag(mean=EndGoal(end_state=State(angle=Array(3.14115, dtype=float32), vel=Array(4.574904e-06, dtype=float32)), other_info=None), scale_diag=EndGoal(end_state=State(angle=Array(1., dtype=float32), vel=Array(1., dtype=float32)), other_info=None)), Array(0.20623134, dtype=float32))
State(angle=Array(3.1420527, dtype=float32), vel=Array(2.5804233e-05, dtype=float32))
PolicyInput(observation=State(angle=Array(3.1420527, dtype=float32), vel=Array(2.5804233e-05, dtype=float32)), policy_state=None, rng_key=Array([ 428234626, 1076777262], dtype=uint32))
PolicyOutput(action=EndGoal(end_state=State(angle=Array(2.8017738, dtype=float32), vel=Array(-0.7037486, dtype=float32)), other_info=None), policy_state=Array(-2.1430993, dtype=float32), info={'log_prob': Array(-2.1430993, dtype=float32), 'value': Array(0.20623134, dtype=float32)})
Done Training high level BC
Rollout out high level BC
State(angle=Array([ 0.15667033,  3.263031  ,  5.295289  ,  2.5288498 ,  2.1755962 ,
        

In [26]:
#Set up net and env
ep_env = EpisodicEnvironment(gc_pendulum_env, 1000)
from stanza.goal_conditioned.bilevel_policy import make_trivial_bi_policy

net = MLPActorCritic(
    ep_env.sample_action(PRNGKey(0))
)
init_params = net.init(PRNGKey(42),
    ep_env.observe(ep_env.sample_state(PRNGKey(0))))

#net.apply(init_params,gs_sampler(PRNGKey(0)))


ac_apply = Partial(net.apply, init_params)
policy = ACPolicy(ac_apply)
bipo = make_trivial_bi_policy(policy)
print("policy_made")

start_state = ep_env.reset(PRNGKey(42))
m_key = PRNGKey(31231)
p_key = PRNGKey(43232)
print("rolling out policy")
roll_len = 5

def print_roll_info(r):
    print(r.actions)
    print(r.states)
    print(r.final_policy_state)
    print(r.info)
    print(jax.vmap(ep_env.observe)(r.states))


r = policies.rollout(ep_env.step, 
    start_state , policy, 
    model_rng_key=m_key,
    policy_rng_key=p_key,
    observe=ep_env.observe,
    length=roll_len)
print_roll_info(r)

print("rolling out bipo")
r = policies.rollout(ep_env.step, 
    start_state , bipo, 
    model_rng_key=m_key,
    policy_rng_key=p_key,
    observe=ep_env.observe,
    length=roll_len)
print("done rollout")
print_roll_info(r)

#net.apply(init_params, ep_env.sample_state(PRNGKey(0)))
#actor_apply = transform_ac_to_mean(net.apply)
#actor_apply(init_params,ep_env.sample_state(PRNGKey(0)))


policy_made
rolling out policy
[ 1.5575948  -1.1119164   0.08042021  2.0939438 ]
-3.0967412
{'log_prob': Array([-2.1214223 , -1.5447118 , -0.92164135, -3.0967412 ], dtype=float32), 'value': Array([-0.00346179,  0.00905811, -0.00111795, -0.0009095 ], dtype=float32)}
GCState(goal=EndGoal(end_state=State(angle=Array([3.1154735, 3.1154735, 3.1154735, 3.1154735, 3.1154735], dtype=float32), vel=Array([0.03439788, 0.03439788, 0.03439788, 0.03439788, 0.03439788],      dtype=float32)), other_info=None), env_state=State(angle=Array([2.9252484, 2.9712086, 3.0781543, 3.1388757, 3.2019532], dtype=float32), vel=Array([0.22980095, 0.5347287 , 0.3036069 , 0.31538695, 0.7309675 ],      dtype=float32)))
rolling out bipo
done rollout
[ 1.5575948  -1.1119164   0.08042021  2.0939438 ]
BLPolicyState(state_low_level=None, state_high_level=Array(-3.0967412, dtype=float32), chunk_time=Array(1, dtype=int32, weak_type=True), current_goal=Array(2.0939438, dtype=float32), info_high_level={'log_prob': Array(-3.0967

In [11]:



rng_bc = PRNGKey(41)
# note, use > 256 data_points!!
num_bc_data = 500
gsa_sampler = (lambda key: sampler.sample_gc_timestep(key))

gc_data = jax.vmap(gsa_sampler)(jax.random.split(PRNGKey(40),num_bc_data))
gc_data = Data.from_pytree(gc_data)
actor_apply = transform_ac_to_mean(net.apply)


display = ConsoleDisplay()
display.add("train", StatisticsTable(), interval=100)
display.add("train", LoopProgress(), interval=100)

with display as w:
    trainer = BCTrainer()
    result =  trainer.train(ac_apply=actor_apply, 
                                ac_params = init_params, dataset=gc_data,
                                rng_key = rng_bc,
                                max_iterations=10000,
                                 hooks=[w.train])
new_params = result.fn_params


Output()

In [None]:
# RL Training 

display = ConsoleDisplay()
display.add("ppo", StatisticsTable(), interval=1)
display.add("ppo", LoopProgress("RL"), interval=1)

ppo = PPO(
    trainer = Trainer(
        optimizer=optax.chain(
            optax.clip_by_global_norm(0.5),
            optax.adam(3e-4, eps=1e-5)
        )
    )
)

with display as dh:
    trained_params = ppo.train(
        PRNGKey(42),
        ep_env, net.apply,
        new_params,
        rl_hooks=[dh.ppo]
    )

ac_apply = Partial(net.apply, trained_params.fn_params)
policy = ACPolicy(ac_apply)

r = policies.rollout(ep_env.step, 
    ep_env.reset(PRNGKey(42)), policy, 
    model_rng_key=PRNGKey(31231),
    policy_rng_key=PRNGKey(43232),
    observe=ep_env.observe,
    length=200)

print(jax.vmap(ep_env.observe)(r.states))

Output()

NameError: name 'params' is not defined

In [None]:


def goal_reward(state, next_state, end_state):
        angle_diff = next_state.angle - state.angle
        vel_diff = next_state.vel - state.vel
        angle_rew = 32 * angle_diff * jnp.sign(end_state.angle - next_state.angle)
        vel_rew = vel_diff * jnp.sign(end_state.vel-next_state.vel)
        return angle_rew + vel_rew

def cost_to_goal( x, u, x_goal):
        x = jnp.stack((x.angle, x.vel), -1)
        x_goal = jnp.stack((x_goal.angle, x_goal.vel), -1)
        diff = (x - x_goal)
        x_cost = jnp.sum(diff[:-1]**2)
        xf_cost = jnp.sum(diff[-1]**2)
        if u == None:
            u_cost = 0
        else:
            u_cost = jnp.sum(u**2)
        return 5*xf_cost + 2*x_cost + u_cost

def gc_reward(gc_state, action, next_state ):
    env_state, goal = gc_state.env_state, gc_state.goal
    end_state = goal.end_state
    
    return goal_reward(env_state,next_state,end_state)
    #return 3 - (1 * cost_to_goal(env_state, action, end_state))

def g_done(gc_state):
        x = gc_state.env_state
        x_goal = gc_state.goal.end_state
        return (cost_to_goal(x =x,u=None,x_goal = x_goal) < .03*.03)

