In [1]:
import os
import PIL.Image
from tensorflow.python.keras.utils.generic_utils import Progbar
from tf_agents.environments import suite_gym, parallel_py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.trajectories.trajectory import Trajectory
from tf_agents.networks import actor_distribution_network
from tf_agents.drivers import dynamic_episode_driver, dynamic_step_driver
from tf_agents.policies import random_tf_policy
import tensorflow as tf
import numpy as np
from reinforcement_learning.sac_training import NumberOfSafetyViolations
from tf_agents.trajectories import time_step as ts
from reinforcement_learning import sac_training
from reinforcement_learning import labeling_functions

In [2]:
# Single (non-parallel) Gym environment shared by the cells below.
walker_env_name = 'BipedalWalker-v2'
py_env = suite_gym.load(walker_env_name)
# Render once in 'human' mode, then reset the environment.
py_env.render(mode='human')
py_env.reset()
# TF wrapper around the very same Python environment.
tf_env = tf_py_environment.TFPyEnvironment(py_env)


In [None]:
def bad_state_detection(trajectory):
    """Driver observer reporting steps whose first observation feature is large.

    Prints a "bad state!" message when the absolute value of
    ``trajectory.observation[..., 0]`` exceeds ``pi / 3``; does nothing
    otherwise.

    Args:
        trajectory: item handed to the observer by the driver; only its
            ``observation`` attribute is read.

    Returns:
        None.
    """
    first_feature = trajectory.observation[..., 0]
    # reduce_any collapses the per-environment comparison to a single boolean,
    # so the test stays valid if the environment has a batch size above one.
    if tf.reduce_any(tf.math.abs(first_feature) > np.pi / 3):
        print("bad state!: {}".format(first_feature))


# Roll out a random policy for 15 episodes, rendering every step.
tf_env = tf_py_environment.TFPyEnvironment(py_env)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=tf_env.time_step_spec(),
    action_spec=tf_env.action_spec())
dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    policy,
    num_episodes=15,
    observers=[bad_state_detection, lambda _: py_env.render(mode='human')]
).run()

In [None]:
def bad_state_detection(trajectory):
    """Tell whether the first observation feature lies outside [-1, 1].

    When it does, the Python environment is rendered in 'human' mode and the
    full observation is printed before returning.

    Args:
        trajectory: object exposing an ``observation`` attribute that can be
            indexed with ``[..., 0]``.

    Returns:
        True if ``observation[..., 0]`` is below -1 or above 1, False otherwise.
    """
    first_feature = trajectory.observation[..., 0]
    # An earlier variant tested the reward instead:
    #   trajectory.reward[..., 0] <= -100
    if not (first_feature < -1. or first_feature > 1):
        return False
    py_env.render(mode='human')
    print(trajectory.observation)
    return True

# Keep stepping the environment with the current policy until
# bad_state_detection flags the freshly returned time step.
walk = True
while walk:
    action = policy.action(time_step=tf_env.current_time_step())
    time_step = tf_env.step(action)
    if bad_state_detection(time_step):
        walk = False


In [None]:
from reinforcement_learning import labeling_functions

# Track safety violations of a random policy over 15 rendered episodes.
labeling_function = labeling_functions['BipedalWalker-v2']
safety_violations = NumberOfSafetyViolations(labeling_function)
progressbar = Progbar(target=None, interval=0.5, stateful_metrics=['violation'])

tf_env = tf_py_environment.TFPyEnvironment(py_env)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=tf_env.time_step_spec(),
    action_spec=tf_env.action_spec())


def _advance_progressbar(_):
    """Observer: advance the bar by one and report the violation average."""
    return progressbar.add(1, [('violation', safety_violations.average())])


def _render_step(_):
    """Observer: render the Python environment in 'human' mode."""
    return py_env.render(mode='human')


episode_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    policy,
    num_episodes=15,
    observers=[safety_violations, _advance_progressbar, _render_step])
episode_driver.run()

# Displayed as the cell output.
safety_violations._num_episodes

In [None]:
# Parallel environments.
num_parallel_environments = 16
# One zero-argument constructor per parallel environment.
walker_constructors = [
    (lambda: suite_gym.load('BipedalWalker-v2'))
    for _ in range(num_parallel_environments)
]
tf_env = tf_py_environment.TFPyEnvironment(
    parallel_py_environment.ParallelPyEnvironment(walker_constructors))
tf_env.reset()

In [None]:
# Run a random policy for 5000 steps in the current tf_env and report the
# safety-violation statistics gathered by the observer.
labeling_function = labeling_functions['BipedalWalker-v2']
safety_violations = NumberOfSafetyViolations(labeling_function)

policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=tf_env.time_step_spec(),
    action_spec=tf_env.action_spec())

step_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    policy,
    num_steps=5000,
    observers=[safety_violations])
step_driver.run()

print('Safety violations')
print('episodes', safety_violations._num_episodes)
print('result=', safety_violations.result())
print('average=', safety_violations.average())

In [None]:
import importlib
from tf_agents.environments import suite_gym
from reinforcement_learning import labeling_functions
from reinforcement_learning import sac_training

# Re-execute the sac_training module so the learner below is built from its
# current source.
importlib.reload(sac_training)

walker_env_name = 'BipedalWalker-v2'
learner = sac_training.SACLearner(
    env_name=walker_env_name,
    env_suite=suite_gym,
    num_iterations=10 ** 6,
    num_parallel_environments=8,
    labeling_function=labeling_functions[walker_env_name],
)

In [None]:
# Run the train_and_eval routine of the SACLearner built in the preceding cell
# (configured there with num_iterations=int(1e6) and 8 parallel environments).
learner.train_and_eval()

In [None]:
# Before running this cell, load the single py environment
# (py_env is both wrapped into tf_env and rendered by the observer below).
stochastic_policy_dir = "../saves/BipedalWalker-v2/policy"
policy = tf.compat.v2.saved_model.load(stochastic_policy_dir)
tf_env = tf_py_environment.TFPyEnvironment(py_env)

# Roll out the loaded policy for 15 episodes, rendering every step.
saved_policy_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    policy,
    num_episodes=15,
    observers=[lambda _: py_env.render(mode='human')])
saved_policy_driver.run()

In [None]:
from reinforcement_learning import sac_training

# Learner rooted at the parent directory; the cells below use it to save and
# locate the permissive variance policy.
walker_env_name = 'BipedalWalker-v2'
learner = sac_training.SACLearner(
    env_name=walker_env_name,
    env_suite=suite_gym,
    labeling_function=labeling_functions[walker_env_name],
    save_directory_location='..',
)

In [None]:
# Multiplier handed to learner.save_permissive_variance_policy and embedded in
# the "permissive_variance_policy-multiplier=<value>" directory name that the
# evaluation cell loads the policy from.
variance_multiplier = 3.

In [None]:
# Save the permissive variance policy for the multiplier chosen above.
# NOTE(review): presumably written under
# <learner.save_directory_location>/policy/permissive_variance_policy-multiplier=<value>,
# which is where the next cell loads it from -- confirm in sac_training.
learner.save_permissive_variance_policy(variance_multiplier=variance_multiplier)

In [None]:
# Before running this cell, load the single py environment
permissive_policy_name = "permissive_variance_policy-multiplier={}".format(
    variance_multiplier)
stochastic_policy_dir = os.path.join(
    learner.save_directory_location, 'policy', permissive_policy_name)
policy = tf.compat.v2.saved_model.load(stochastic_policy_dir)
safety_violations = NumberOfSafetyViolations(
    labeling_function=labeling_functions['BipedalWalker-v2'])

# Evaluate the loaded policy over 1000 episodes, counting safety violations.
# To watch the rollouts, add `lambda _: py_env.render(mode='human')` in front
# of safety_violations in the observers list.
evaluation_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    policy,
    num_episodes=1000,
    observers=[safety_violations])
evaluation_driver.run()

print("avg number of safety violations per episode", safety_violations.average())

In [6]:
import variational_action_discretizer
from variational_action_discretizer import VariationalActionDiscretizer

# Load the saved VAE-MDP model (action discretizer) from its save directory.
vae_mdp = variational_action_discretizer.load(
    '../saves/BipedalWalker-v2/models/vae_LS14_MC5_CER10.0_KLA0.0_TD1.00-0.90_1e-06-2e-06_step410000_eval_elbo52.650/permissive_variance_policy-multiplier=3.0/action_discretizer/LA6_MC16_CER1.0-decay=0.001_KLA0.0-growth=5e-06_TD0.20-0.13_1e-06-2e-06_decoder_divergence0.1_params=one_output_per_action/step200000/eval_elbo-1.073'
)
print("VAE MDP loaded")
tf_env.reset()
policy = vae_mdp.generate_random_policy(tf_env)

# Look the labeling function up once instead of on every loop iteration.
labeling_function = labeling_functions['BipedalWalker-v2']
safety_violations = NumberOfSafetyViolations(
    labeling_function=labeling_function)

# All-zero (state, action, reward) used as the "previous" transition for the
# very first encoding, replicated along the environment's batch dimension.
initial_state = tf.zeros(
    shape=tf_env.time_step_spec().observation.shape, dtype=tf.float32)
initial_action = tf.zeros(shape=tf_env.action_spec().shape, dtype=tf.float32)
initial_reward = tf.zeros(
    shape=(1, ), dtype=tf.float32
)

state, action, reward = [
    tf.stack([initial_state for _ in range(tf_env.batch_size)]),
    tf.stack([initial_action for _ in range(tf_env.batch_size)]),
    tf.stack([initial_reward for _ in range(tf_env.batch_size)]),
]

num_steps = 5000
for _ in range(num_steps):
    py_env.render(mode='human')
    # Query the environment once per iteration so that the observation, step
    # type, reward and discount all come from the same TimeStep snapshot.
    current_time_step = tf_env.current_time_step()
    next_state = current_time_step.observation
    next_label = tf.expand_dims(
        tf.cast(labeling_function(next_state), tf.float32),
        axis=-1
    )
    # Encode the last transition; the sampled code z replaces the raw
    # observation in the time step handed to the policy.
    z = vae_mdp.binary_encode(state, action, reward, next_state, next_label).sample()
    state = next_state
    action = policy.action(
        time_step=ts.TimeStep(
            step_type=current_time_step.step_type,
            reward=current_time_step.reward,
            discount=current_time_step.discount,
            observation=z,
        )
    ).action
    time_step = tf_env.step(action)
    reward = tf.expand_dims(time_step.reward, axis=-1)  # scalar reward
    safety_violations(time_step)

print("avg number of safety violations per episode", safety_violations.average())

VAE MDP loaded


KeyboardInterrupt: 

In [7]:
# Inspect the tensors left in the kernel by the rollout loop of the previous
# cell (the recorded output was captured after that loop was interrupted).
state, action, reward, next_state, next_label


(<tf.Tensor: shape=(1, 24), dtype=float32, numpy=
 array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(1, 4), dtype=float32, numpy=array([[0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.]], dtype=float32)>,
 <tf.Tensor: shape=(1, 24), dtype=float32, numpy=
 array([[ 2.7470193e-03,  3.3419803e-06, -4.3579831e-04, -1.6000021e-02,
          9.2256993e-02,  1.0116901e-03,  8.6004359e-01,  5.9883302e-04,
          1.0000000e+00,  3.2633666e-02,  1.0116381e-03,  8.5365462e-01,
         -6.8835629e-04,  1.0000000e+00,  4.4081375e-01,  4.4581985e-01,
          4.6142250e-01,  4.8954991e-01,  5.3410250e-01,  6.0246068e-01,
          7.0914847e-01,  8.8593131e-01,  1.0000000e+00,  1.0000000e+00]],
       dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=bool, numpy=array([False])>)