In [11]:
%load_ext autoreload
%autoreload 2
%aimport -jax
%aimport -jaxlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import stanza.envs as envs
import stanza.policies as policies
import optax
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from stanza import Partial
from stanza.rl.ppo import PPO
from stanza.train import Trainer
from stanza.rl import EpisodicEnvironment, ACPolicy
from stanza.rl.nets import MLPActorCritic
from stanza.util.rich import StatisticsTable, ConsoleDisplay, LoopProgress
from stanza.solver.ilqr import iLQRSolver
from stanza.util.random import PRNGSequence


In [13]:
from stanza.util.logging import logger
from stanza.policies.mpc import MPC
from stanza.data.trajectory import Timestep
from stanza.data import Data
# Base pendulum environment; its step/cost/reset functions are used below for
# the MPC expert and for generating expert rollouts.
env = envs.create("pendulum")
# NOTE(review): the original remark here read "will automatically reset when done
# or when 1000 timesteps have been reached". That seems to describe the
# EpisodicEnvironment wrapper built in a later cell (constructed with 1000),
# not this raw environment — confirm.
solver_t = iLQRSolver()
# Expert controller: MPC driven by the environment's own step and cost functions,
# solved with the iLQR solver instance created above.
expert_policy=MPC(
            # Example action drawn with a fixed key and handed to MPC as action_sample
            # (presumably a shape/dtype template — TODO confirm against MPC).
            action_sample=env.sample_action(PRNGKey(42)),
            cost_fn=env.cost, 
            model_fn=env.step,
            horizon_length=50,
            solver=solver_t,
            # presumably toggles receding-horizon replanning — TODO confirm against MPC
            receed=False
        )

def rollout_mpc(key: PRNGKey):
    """Roll the MPC expert out once and package the result as stanza ``Data``.

    Args:
        key: PRNG key passed to ``env.reset`` to draw the initial state.

    Returns:
        ``Data.from_pytree`` applied to a ``Timestep`` built from the rollout's
        states and actions.
    """
    initial_state = env.reset(key)
    trajectory = policies.rollout(
        model=env.step,
        state0=initial_state,
        length=50,
        policy=expert_policy,
    )
    # Pair states with actions, then wrap the pytree as a stanza Data object.
    timesteps = Timestep(trajectory.states, trajectory.actions)
    return Data.from_pytree(timesteps)


# Number of expert trajectories to generate (passed to batch_roll as num_t below).
num_trajs = 100
def batch_roll(rng_key, num_t):
    """Generate ``num_t`` MPC rollouts in one vectorised call.

    Args:
        rng_key: PRNG key that is split into one sub-key per rollout.
        num_t: number of rollouts (and therefore sub-keys) to produce.

    Returns:
        The output of ``rollout_mpc`` mapped over the split keys with ``jax.vmap``.
    """
    per_rollout_keys = jax.random.split(rng_key, num_t)
    return jax.vmap(rollout_mpc)(per_rollout_keys)

# Expert dataset: num_trajs MPC rollouts generated from a fixed seed.
# NOTE(review): rollout_mpc already returns Data.from_pytree(...) for each rollout,
# so the batch is wrapped a second time here — confirm this nesting is intended.
expert_data = Data.from_pytree(batch_roll(PRNGKey(42), num_trajs))

In [21]:
from stanza.goal_conditioned.roll_in_sampler import roll_in_sampler
from stanza.envs import Environment
from stanza.goal_conditioned import GCState, StartEndGoal
import chex


# Noise hooks forwarded by gs_sampler to roll_in_sampler as action_noiser /
# process_noiser. None presumably disables the corresponding noise — TODO confirm
# against roll_in_sampler.
action_noiser = None
process_noiser = None

def gs_sampler(key: PRNGKey, encode_start = False):
    """Sample a goal-conditioned start state from the expert dataset.

    A random expert trajectory is drawn, a start time and a goal offset are
    sampled, ``roll_in_sampler`` produces the environment start state for that
    start time, and the expert observation ``delta_t`` steps later becomes the
    goal's end state.

    Args:
        key: PRNG key; every random draw below consumes a fresh sub-key from it.
        encode_start: when truthy the sampled start state is also stored in the
            goal; otherwise the goal's ``start_state`` is ``None``.

    Returns:
        ``GCState`` whose ``env_state`` is the sampled start state and whose
        ``goal`` is a ``StartEndGoal``.
    """
    keys = PRNGSequence(key)
    traj = expert_data.sample(next(keys))
    horizon = traj.length

    # Both offsets are drawn with minval=3, maxval=8.
    goal_offset = jax.random.randint(next(keys), (), minval=3, maxval=8)
    roll_in_steps = jax.random.randint(next(keys), (), minval=3, maxval=8)
    # The start time is bounded below by the roll-in length and above by
    # horizon - goal_offset, so the index start + goal_offset used below is
    # drawn relative to the trajectory length.
    t_start = jax.random.randint(
        next(keys), (), minval=roll_in_steps, maxval=horizon - goal_offset
    )

    # Draw the two remaining keys in the same order the call consumes them.
    noise_key = next(keys)
    env_key = next(keys)
    sampled_start = roll_in_sampler(
        traj=traj,
        target_time=t_start,
        noise_rng_key=noise_key,
        roll_len=roll_in_steps,
        env=env,
        env_rng_key=env_key,
        action_noiser=action_noiser,
        process_noiser=process_noiser,
    )

    target_state = traj.get(t_start + goal_offset).observation

    goal = StartEndGoal(
        start_state=sampled_start if encode_start else None,
        end_state=target_state,
    )
    return GCState(goal=goal, env_state=sampled_start)
    
# Smoke test: draw one goal-conditioned state with a fixed seed and inspect it.
# NOTE(review): with the default encode_start=False the goal's start_state is set
# to None by gs_sampler, yet the saved output shows a populated start_state —
# the stored output may be stale; re-run to confirm.
my_gc_state = gs_sampler(PRNGKey(42))
print(my_gc_state)


GCState(goal=StartEndGoal(start_state=State(angle=Array(3.139656, dtype=float32), vel=Array(0.00422384, dtype=float32)), end_state=State(angle=Array(3.1417477, dtype=float32), vel=Array(0.00079618, dtype=float32))), env_state=State(angle=Array(3.139656, dtype=float32), vel=Array(0.00422384, dtype=float32)))


In [49]:
import math


def goal_reward(state, next_state, end_state):
    """Reward a transition for moving angle and velocity toward the goal.

    Each component's change over the step is multiplied by the sign of the
    remaining gap between the goal and ``next_state``; the angle term is
    weighted by 32.

    Args:
        state: state before the step; must expose ``angle`` and ``vel``.
        next_state: state after the step; must expose ``angle`` and ``vel``.
        end_state: goal state; must expose ``angle`` and ``vel``.

    Returns:
        ``32 * d_angle * sign(goal_angle - next_angle) + d_vel * sign(goal_vel - next_vel)``.
    """
    d_angle = next_state.angle - state.angle
    d_vel = next_state.vel - state.vel
    # +1 when the goal value is still above next_state's value, -1 when below, 0 when equal.
    angle_direction = jnp.sign(end_state.angle - next_state.angle)
    vel_direction = jnp.sign(end_state.vel - next_state.vel)
    return 32 * d_angle * angle_direction + d_vel * vel_direction

def cost_to_goal(x, u, x_goal):
    """Quadratic cost of state(s) ``x`` and optional control(s) ``u`` relative to ``x_goal``.

    ``angle`` and ``vel`` are stacked along a new last axis, the goal is
    subtracted, and the squared error is split along the FIRST axis: every
    entry but the last is weighted by 2, the last entry by 5.

    For a trajectory (``angle``/``vel`` of shape ``(T,)``) this means running
    cost on the first ``T - 1`` states and terminal cost on the final state.
    For a single state (scalar ``angle``/``vel``) the stacked vector has shape
    ``(2,)``, so the result is ``2 * angle_err**2 + 5 * vel_err**2``.
    NOTE(review): confirm that asymmetric weighting is intended for single states.

    Args:
        x: object exposing ``angle`` and ``vel``.
        u: control array, or ``None`` to omit the control cost.
        x_goal: object exposing ``angle`` and ``vel``.

    Returns:
        ``5 * terminal_cost + 2 * running_cost + control_cost``.
    """
    x_vec = jnp.stack((x.angle, x.vel), -1)
    goal_vec = jnp.stack((x_goal.angle, x_goal.vel), -1)
    diff = x_vec - goal_vec
    x_cost = jnp.sum(diff[:-1] ** 2)
    xf_cost = jnp.sum(diff[-1] ** 2)
    # Identity test, not equality: `u == None` is evaluated elementwise on a
    # NumPy array and makes the `if` raise for more than one element.
    if u is None:
        u_cost = 0
    else:
        u_cost = jnp.sum(u ** 2)
    return 5 * xf_cost + 2 * x_cost + u_cost

def gc_reward(gc_state, action, next_state ):
    """Reward for a goal-conditioned transition.

    Delegates to ``goal_reward`` using the current environment state, the next
    state and the goal's end state.

    Args:
        gc_state: goal-conditioned state exposing ``env_state`` and ``goal.end_state``.
        action: accepted but not used by the current reward.
        next_state: state after the step, forwarded to ``goal_reward``.

    Returns:
        The value of ``goal_reward(env_state, next_state, end_state)``.
    """
    # A cost-based alternative, 3 - cost_to_goal(env_state, action, end_state),
    # was left disabled in the original cell.
    return goal_reward(gc_state.env_state, next_state, gc_state.goal.end_state)

def g_done(gc_state):
    """Report whether the environment state is within tolerance of the goal.

    Args:
        gc_state: goal-conditioned state exposing ``env_state`` and ``goal.end_state``.

    Returns:
        Result of comparing the control-free ``cost_to_goal`` against ``.03 * .03``.
    """
    tolerance = .03
    goal_cost = cost_to_goal(
        x=gc_state.env_state,
        u=None,
        x_goal=gc_state.goal.end_state,
    )
    return goal_cost < tolerance * tolerance


# Re-sample a goal-conditioned state with the same fixed seed (42).
# NOTE(review): repeats the sampling done in the earlier cell and the value does
# not appear to be used by the cells visible below — confirm it is still needed.
my_gc_state = gs_sampler(PRNGKey(42))



In [50]:
from stanza.goal_conditioned import GCEnvironment
# Goal-conditioned wrapper around the pendulum: gs_sampler supplies goal/start
# states, gc_reward the reward and g_done the goal-reached test.
gc_pendulum_env = GCEnvironment(env = env, gs_sampler = gs_sampler,
                            gc_reward = gc_reward, g_done = g_done)

In [51]:
# Wrap the goal-conditioned pendulum in an episodic environment.
# presumably 1000 is the maximum episode length in timesteps — TODO confirm
ep_env = EpisodicEnvironment(gc_pendulum_env, 1000)


# Actor-critic MLP, constructed from a sample action of the environment.
net = MLPActorCritic(
    ep_env.sample_action(PRNGKey(0))
)
# Initialise the network parameters on the observation of a sampled state.
params = net.init(PRNGKey(42),
    ep_env.observe(ep_env.sample_state(PRNGKey(0))))

# Console display: a statistics table and a progress bar, both registered under
# the "ppo" name that is passed to training as dh.ppo below.
display = ConsoleDisplay()
display.add("ppo", StatisticsTable(), interval=1)
display.add("ppo", LoopProgress("RL"), interval=1)

# PPO with global-norm gradient clipping (0.5) followed by Adam (lr 3e-4, eps 1e-5).
ppo = PPO(
    trainer = Trainer(
        optimizer=optax.chain(
            optax.clip_by_global_norm(0.5),
            optax.adam(3e-4, eps=1e-5)
        )
    )
)

# Train inside the display context so the registered hooks can render.
with display as dh:
    trained_params = ppo.train(
        PRNGKey(42),
        ep_env, net.apply,
        params,
        rl_hooks=[dh.ppo]
    )

# Bind the trained function parameters to the network and expose it as a policy.
ac_apply = Partial(net.apply, trained_params.fn_params)
policy = ACPolicy(ac_apply)

# Evaluate: roll the trained policy out for 200 steps with fixed model and policy keys.
r = policies.rollout(ep_env.step, 
    ep_env.reset(PRNGKey(42)), policy, 
    model_rng_key=PRNGKey(31231),
    policy_rng_key=PRNGKey(43232),
    observe=ep_env.observe,
    length=200)

# NOTE(review): this prints the observation of every rollout state, which produces
# a very long output — consider showing only a slice or a summary.
print(jax.vmap(ep_env.observe)(r.states))

Output()

GCState(goal=StartEndGoal(start_state=State(angle=Array([3.139656 , 3.1171174, 3.1171174, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
       2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975, 2.5913975,
 