In [1]:
# Load IPython's autoreload extension so edits to imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before each
# cell is executed.
%autoreload 2
# Exclude jax and jaxlib from automatic reloading.
%aimport -jax
%aimport -jaxlib

In [2]:
from stanza.runtime import activity
from stanza.util.random import PRNGSequence
from stanza.dataclasses import dataclass
from stanza.util.rich import StatisticsTable, ConsoleDisplay, LoopProgress
from stanza.util.logging import logger
import stanza.envs as envs
from stanza.rl.nets import transform_ac_to_a
from stanza.data.trajectory import Timestep
from stanza.data import PyTreeData
from jax.random import PRNGKey
from stanza.rl.bc import BCState, BCTrainer
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from stanza import Partial
from stanza.rl import ACPolicy
import stanza.policies as policies
from stable_imitation.render import save_rollout_video
from stanza.data import Data
import math
from stanza.rl.ppo import PPO
from stanza.rl.nets import MLPActorCritic
from stable_imitation.data_collect import generate_data
from stanza.goal_conditioned import GCHighLevelEnvironment
from stanza.goal_conditioned.roll_in_sampler import  RollInSampler
from stanza.envs.goal_conditioned_envs.gc_pendulum import make_gc_pendulum_env
from stanza.envs.pendulum import State as PendulumState
from stanza.goal_conditioned.bilevel_policy import BiPolicy, fixed_time_update

from stanza.util.logging import logger
from stanza.policies.mpc import MPC
from stanza.data.trajectory import Timestep
from stanza.data import Data, field



In [21]:
# Candidate values for the proportion experiment. Deliberately NOT named
# `list`: binding that name at cell level would shadow the builtin `list`
# type for every later cell in the kernel session.
candidates = [1, 2, 3]


def even_prop(x: int, y: int):
    """Return the unnormalized weight for the pair (x, y).

    This baseline ignores both arguments and always returns 1.0, i.e. every
    candidate gets the same weight.
    """
    return 1.


props = [even_prop(1, y) for y in candidates]
print(props)
print(sum(props))


def give_ratios(num, prop_fun=even_prop):
    """Print and return the normalized weights of the module-level candidates.

    Args:
        num: Currently unused; kept so existing calls keep working.
            NOTE(review): presumably meant to control how many candidates are
            weighted -- confirm the intended meaning before wiring it in.
        prop_fun: Callable ``(x, y) -> weight`` evaluated as ``prop_fun(1, y)``
            for every ``y`` in ``candidates``. Defaults to ``even_prop``.

    Returns:
        A jax array of the weights divided by their sum.
    """
    # Use the caller-supplied weighting function (previously the argument was
    # accepted but `even_prop` was always called).
    weights = jnp.array([prop_fun(1, y) for y in candidates])
    ratios = weights / weights.sum()
    print(ratios)
    return ratios


give_ratios(3)
# Draw ten indices from {0, 1, 2} twice with the SAME key, once with weights
# that sum to 2 and once with weights that sum to 1, to compare the draws.
# NOTE(review): whether jax.random.choice tolerates unnormalized `p` should be
# confirmed against the jax documentation.
key = PRNGKey(3)
x = jax.random.choice(
    key,
    a=3,
    shape=(10,),
    p=2*jnp.array([.3,.3,.4]),
    axis=0,
)
y = jax.random.choice(
    key,
    a=3,
    shape=(10,),
    p=3*jnp.array([.1,.1,4./30]),
    axis=0,
)

print(x, y, sep="\n")

# Null-safe call pattern: only invoke `.out()` when `a` actually holds an
# object; otherwise propagate None.
a = None
if a is None:
    b = None
else:
    b = a.out()
print(b)



[1.0, 1.0, 1.0]
3.0
[0.33333334 0.33333334 0.33333334]


TypeError: choice requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [5]:
"""
Utility Methods
"""

from stanza.util.ipython import display_video
# BC Pretraining    
from stanza.rl.bc import BCState, BCTrainer
from stanza.rl.nets import transform_ac_to_a


def ac_mlp_init(env, rng):
    """
    Build an MLP actor-critic model for `env` and initialize its parameters.

    Three keys are drawn from `rng`, in this order: one to sample an action
    (passed to the model constructor), one for `model.init`, and one to sample
    the state whose observation is fed to `model.init`.

    Returns a `(model, init_params)` tuple.
    """
    action_sample = env.sample_action(next(rng))
    model = MLPActorCritic(action_sample)

    # Draw the init key before the state key to keep the key order stable.
    init_key = next(rng)
    observation_sample = env.observe(env.sample_state(next(rng)))

    init_params = model.init(init_key, observation_sample)
    return model, init_params

def ac_bc_train(rng, data, model, init_params, num_iters):
    """
    Train an actor-critic model with behavior cloning.

    `model.apply` is passed through `transform_ac_to_a` before being handed to
    `BCTrainer.train`, together with `init_params`, the dataset `data`, one key
    drawn from `rng`, and `num_iters` as the iteration cap. Progress is shown
    through a `ConsoleDisplay` whose "train" hook is refreshed every 100 steps.

    Returns the object produced by `BCTrainer.train`.
    """
    logger.info("Training model")
    actor_only_apply = transform_ac_to_a(model.apply)

    console = ConsoleDisplay()
    # Both widgets are registered under the same "train" hook name.
    for widget_cls in (StatisticsTable, LoopProgress):
        console.add("train", widget_cls(), interval=100)

    with console as hooks:
        bc_trainer = BCTrainer()
        outcome = bc_trainer.train(
            ac_apply=actor_only_apply,
            ac_params=init_params,
            dataset=data,
            rng_key=next(rng),
            max_iterations=num_iters,
            hooks=[hooks.train],
        )
    return outcome


def render_video(rollout, env, file_name: str, fps=30):
    """
    Render every state in `rollout.states` with `env.render` (vectorized via
    `jax.vmap`) and hand the resulting frames to `display_video`.
    """
    render_all = jax.vmap(env.render)
    frames = render_all(rollout.states)
    display_video(file_name, frames, fps=fps)


# Alias used by the training cell. This rebinds `save_rollout_video`, which the
# import cell also imports from stable_imitation.render.
save_rollout_video = render_video
    

In [6]:
"""
Data Collection
"""

import stanza.envs as envs
from stanza.policies.mpc import MPC
from stanza.data import Data
from stanza.data.trajectory import Timestep
from jax.random import PRNGKey
from stanza.solver.ilqr import iLQRSolver
import stanza.policies as policies
import jax

def generate_data(config, rng_key, num_traj):
    """
    Generate expert trajectory data for the environment named in `config`.

    Args:
        config: object exposing `env_name` and `traj_length`.
        rng_key: PRNG key forwarded to the data generator.
        num_traj: number of trajectories to generate.

    Returns:
        The dataset produced by `generate_mpc_data` (pendulum only).

    Raises:
        ValueError: if `config.env_name` has no data generator here. The
            previous behavior was to fall through and return None, which only
            failed later, far from the cause.
    """
    env = envs.create(config.env_name)
    if config.env_name == 'pendulum':
        return generate_mpc_data(rng_key, env, num_traj,
                                 config.traj_length)
    raise ValueError(
        f"generate_data: no data generator for env_name={config.env_name!r}"
    )

def generate_mpc_data(rng_key, env, num_traj, traj_length):
    """
    Roll out an iLQR-based MPC expert on `env` and collect the trajectories.

    Args:
        rng_key: PRNG key; split into `num_traj` keys, one per rollout reset.
        env: environment providing `sample_action`, `cost`, `step`, `reset`.
        num_traj: number of independent rollouts (vectorized with `jax.vmap`).
        traj_length: number of steps in each rollout.

    Returns:
        A `Data` object wrapping the batched per-rollout `Data` of
        `Timestep(states, actions)`.
    """
    expert_policy = MPC(
        # The sampled action is passed as `action_sample`; a fixed key is used.
        action_sample=env.sample_action(PRNGKey(42)),
        cost_fn=env.cost,
        model_fn=env.step,
        # NOTE(review): the planning horizon is still fixed at 50 while the
        # rollout length now follows `traj_length`; confirm whether the two
        # should be tied together when `receed=False`.
        horizon_length=50,
        solver=iLQRSolver(),
        receed=False
    )

    def rollout_mpc(key: PRNGKey):
        """Run one expert rollout from the state given by `env.reset(key)`."""
        rollout = policies.rollout(
            model=env.step,
            state0=env.reset(key),
            # Honor the caller-supplied length (was hard-coded to 50, which
            # silently ignored the `traj_length` argument).
            length=traj_length,
            policy=expert_policy,
            last_state=False
        )
        return Data.from_pytree(Timestep(rollout.states, rollout.actions))

    rng_keys = jax.random.split(rng_key, num_traj)
    roll_func = jax.vmap(rollout_mpc)
    return Data.from_pytree(roll_func(rng_keys))




In [7]:

"""
BC Config Variables
"""


@dataclass
class BCGCConfig:
    """Settings for the goal-conditioned behavior-cloning experiment."""

    # --- environment and data collection ---
    # Name passed to `envs.create`.
    env_name: str = "pendulum"
    # Seed for the notebook-wide PRNG key sequence.
    seed: int = 42
    # Trajectory length forwarded to the expert data generator.
    traj_length: int = 50
    # Number of expert trajectories to generate.
    num_traj: int = 500

    # --- training budgets ---
    # Iteration cap for the low-level BC run.
    num_iters_ll: int = 5000
    # Iteration cap for the high-level BC run.
    num_iters_hl: int = 5000

    # --- high-level policy ---
    # Output file for the high-level rollout video.
    bc_hl_video_filename: str = "bc_rollout_hl.mp4"
    # Used as both delta_t_min and delta_t_max of the high-level sampler.
    delta_t: int = 2

    # --- bilevel policy videos ---
    bc_bl_video_filename_no_noise: str = "bc_rollout_bl_no_noise.mp4"
    bc_bl_video_filename_with_noise: str = "bc_rollout_bl_with_noise.mp4"





In [44]:
"""Set up configuration goal conditioend env"""

# Experiment configuration, base environment and the notebook-wide key stream.
config = BCGCConfig()
logger.info("Initializing")
env = envs.create(config.env_name)
rng = PRNGSequence(PRNGKey(config.seed))

# Expert trajectories used by every sampler below.
logger.info("Generating data")
data = generate_data(
    config=config,
    rng_key=next(rng),
    num_traj=config.num_traj,
)

# Roll-in sampler over the expert data; it supplies goal-conditioned states.
sampler = RollInSampler(env=env, traj_data=data)


def gs_sampler(key):
    """Sample one goal-conditioned state from the roll-in sampler."""
    return sampler.sample_gc_state(key)


# Build the goal-conditioned pendulum environment on top of the base env.
# (There should be a better way to do this.)
gc_pendulum_env = make_gc_pendulum_env(env, gs_sampler)



##TODO test sampling + noising logic



[2;36m[20:19:15][0m[2;36m [0m[37mINFO  [0m - Initializing                                  ]8;id=376370;file:///var/folders/c4/13bl08593w34g7qzjvs_pbyc0000gp/T/ipykernel_86870/192119518.py\[2m192119518.py[0m]8;;\[2m:[0m]8;id=747751;file:///var/folders/c4/13bl08593w34g7qzjvs_pbyc0000gp/T/ipykernel_86870/192119518.py#4\[2m4[0m]8;;\
[2;36m          [0m[2;36m [0m[37mINFO  [0m - Generating data                               ]8;id=302678;file:///var/folders/c4/13bl08593w34g7qzjvs_pbyc0000gp/T/ipykernel_86870/192119518.py\[2m192119518.py[0m]8;;\[2m:[0m]8;id=564817;file:///var/folders/c4/13bl08593w34g7qzjvs_pbyc0000gp/T/ipykernel_86870/192119518.py#8\[2m8[0m]8;;\


In [47]:
""" 
Compare Sampler to Baseline
"""
from stanza.goal_conditioned.roll_in_sampler import make_no_noise_roll_in,make_no_noise_roll_in_v2

# Two implementations of the no-noise roll-in sampler, built from the same
# environment and expert data so their outputs can be compared directly.
sampler_1 = make_no_noise_roll_in(env=env, traj_data=data)
sampler_2 = make_no_noise_roll_in_v2(env=env, traj_data=data)


def gs_sampler_1(key):
    """Sample a goal/state/action tuple from the first sampler."""
    return sampler_1.sample_goal_state_action(key)


def gs_sampler_2(key):
    """Sample a goal/state/action tuple from the second (v2) sampler."""
    return sampler_2.sample_goal_state_action(key)


# Same key for both samplers; printed one after the other so any debug output
# emitted while sampling stays next to the matching result.
a_key = PRNGKey(42)

print(gs_sampler_1(a_key))
print(gs_sampler_2(a_key))



#do things

start_t 39
some roll_len 1 3
some roll_len 2 3
(State(angle=Array(3.1420307, dtype=float32), vel=Array(3.2278345e-05, dtype=float32)), EndGoal(end_state=State(angle=Array(3.1420372, dtype=float32), vel=Array(-7.248179e-05, dtype=float32)), other_info=None), Array(-0.00056599, dtype=float32), 0)
start_t 39
some roll_len 1 3
some roll_len 2 3
start_index 36
3
roll len 3
hi
timestep Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
end_loop_state (State(angle=Array(3.1420307, dtype=float32), vel=Array(3.2278345e-05, dtype=float32)), Array(39, dtype=int32), Array([1813963473,  208843817], dtype=uint32), Array([3555094720,  894980426], dtype=uint32), Array(2, dtype=int32, weak_type=True))
(State(angle=Array(3.1420307, dtype=float32), vel=Array(3.2278345e-05, dtype=float32)), EndGoal(end_state=State(angle=Array(3.1420372, dtype=float32), vel=Array(-7.248179e-05, dtype=float32)), other_info=None), Array(-0.00056599, dtype=float32), Array(2, dtype=int32, weak_type=

In [6]:
"""
Run BC
"""

logger.info("Setting up low level BC")
num_bc_data_ll = 500


def sample_bc_ll(key):
    """Draw one goal-conditioned timestep from the roll-in sampler."""
    return sampler.sample_gc_timestep(key)


# One key per sample; the vmapped draw is wrapped directly as a dataset.
gc_bc_data_ll = Data.from_pytree(
    jax.vmap(sample_bc_ll)(jax.random.split(next(rng), num_bc_data_ll))
)
net_ll, init_params_ll = ac_mlp_init(gc_pendulum_env, rng)

logger.info("Starting Training low level BC")

result_ll = ac_bc_train(
    rng=rng,
    data=gc_bc_data_ll,
    model=net_ll,
    init_params=init_params_ll,
    num_iters=config.num_iters_ll,
)

# Bind the trained parameters to the network and wrap it as a policy.
ac_apply_ll = Partial(net_ll.apply, result_ll.fn_params)
policy_ll = ACPolicy(ac_apply_ll, use_mean=True)

logger.info("Done Training low level BC")


# Sets up the high-level BC policy.
logger.info("Setting up high level BC")

# The high-level sampler uses config.delta_t for both delta_t_min and
# delta_t_max.
sampler_hl = RollInSampler(
    env=env,
    traj_data=data,
    delta_t_min=config.delta_t,
    delta_t_max=config.delta_t,
)


def gs_sampler_hl(key):
    """Sample one goal-conditioned state from the high-level sampler."""
    return sampler_hl.sample_gc_state(key)


gc_pendulum_env_hl = GCHighLevelEnvironment(
    gs_sampler=gs_sampler_hl,
    base_env=env,
)
num_bc_data_hl = 1000


def sample_bc_hl(key):
    """Draw one high-level goal-conditioned timestep."""
    return sampler_hl.sample_gc_timestep_high_level(key)


# One key per sample; the vmapped draw is wrapped directly as a dataset.
gc_bc_data_hl = Data.from_pytree(
    jax.vmap(sample_bc_hl)(jax.random.split(next(rng), num_bc_data_hl))
)

net_hl, init_params_hl = ac_mlp_init(gc_pendulum_env_hl, rng)

logger.info("Starting Training high level BC")

result_hl = ac_bc_train(
    rng=rng,
    data=gc_bc_data_hl,
    model=net_hl,
    init_params=init_params_hl,
    num_iters=config.num_iters_hl,
)

# Bind the trained parameters to the network and wrap it as a policy.
ac_apply_hl = Partial(net_hl.apply, result_hl.fn_params)
policy_hl = ACPolicy(ac_apply_hl, use_mean=True)

logger.info("Done Training high level BC")

logger.info("Rollout out high level BC")

# Roll the high-level policy through the high-level environment's step
# function. Three keys are drawn from `rng` in order: reset, model, policy.
r = policies.rollout(
    gc_pendulum_env_hl.step,
    env.reset(next(rng)),
    policy_hl,
    length=200,
    model_rng_key=next(rng),
    policy_rng_key=next(rng),
    observe=env.observe,
)

save_rollout_video(
    rollout=r,
    env=gc_pendulum_env_hl,
    file_name=config.bc_hl_video_filename,
    fps=30,
)

is_update_time = Partial(fixed_time_update, t_max=1)

# Roll out the bilevel policy twice: first with the low-level policy built with
# use_mean=True, then with use_mean=False. The first rollout is saved under the
# "no_noise" filename and the second under the "with_noise" filename.
rolls = []
for parity in (True, False):
    policy_ll = ACPolicy(ac_apply_ll, use_mean=parity)
    bi_policy = BiPolicy(
        policy_low=policy_ll,
        policy_high=policy_hl,
        is_update_time=is_update_time,
    )
    # Keys are drawn from `rng` in order: reset, model, policy.
    r = policies.rollout(
        env.step,
        env.reset(next(rng)),
        bi_policy,
        length=200,
        model_rng_key=next(rng),
        policy_rng_key=next(rng),
        observe=env.observe,
    )
    rolls.append(r)

save_rollout_video(
    rollout=rolls[0],
    env=env,
    file_name=config.bc_bl_video_filename_no_noise,
    fps=30,
)

save_rollout_video(
    rollout=rolls[1],
    env=env,
    file_name=config.bc_bl_video_filename_with_noise,
    fps=30,
)

logger.info("done")

Output()

Output()