In [1]:
import numpy as np
from gym import logger; logger.set_level(logger.DISABLED)
import tensorflow as tf; tf.enable_eager_execution()
import tensorflow.contrib.eager as tfe
import tensorflow_probability as tfp
import trfl
import pyoneer as pynr
import pyoneer.rl as pyrl
import pyoneer.contrib.interactor as pyactor

from pybullet_envs.bullet.kukaGymEnv import KukaGymEnv

current_dir=/home/eddie/anaconda3/lib/python3.6/site-packages/pybullet_envs/bullet


  return _inspect.getargspec(target)


In [2]:
from gym.envs.registration import register, registry

# Register a discrete-action variant of the Kuka environment under its own id.
# gym's `register` rejects an id that is already registered, so the call is
# guarded to keep this cell re-runnable in a live kernel.
KUKA_DISCRETE_ENV_ID = 'KukaBulletDiscrete-v0'

if KUKA_DISCRETE_ENV_ID not in registry.env_specs:
    register(
        id=KUKA_DISCRETE_ENV_ID,
        entry_point='pybullet_envs.bullet.kukaGymEnv:KukaGymEnv',
        kwargs={'isDiscrete': True},
    )

In [3]:
class Policy(tf.keras.Model):
    """Policy network that maps states to a categorical action distribution.

    Inputs are passed through `state_normalizer`, then a 64-unit
    swish-activated dense layer, then a dense layer producing `action_size`
    logits.
    """

    def __init__(self, state_normalizer, action_size):
        """Create the normalizer reference and the two dense layers.

        Args:
            state_normalizer: callable applied to the inputs before the
                hidden layer.
            action_size: number of units (logits) in the output layer.
        """
        super(Policy, self).__init__()
        self.state_normalizer = state_normalizer
        self.hidden = tf.layers.Dense(
            units=64,
            activation=pynr.nn.swish,
            kernel_initializer=tf.initializers.variance_scaling(scale=2.0))
        self.outputs = tf.layers.Dense(
            units=action_size,
            kernel_initializer=tf.initializers.variance_scaling(scale=2.0))

    def call(self, inputs, **kwargs):
        """Return a `Categorical` distribution whose logits come from `inputs`.

        Extra keyword arguments are accepted and ignored.
        """
        logits = self.outputs(self.hidden(self.state_normalizer(inputs)))
        return tfp.distributions.Categorical(logits=logits)


In [4]:
class Value(tf.keras.Model):
    """Value network: normalizer -> 64-unit swish layer -> single output."""

    def __init__(self, state_normalizer):
        """Create the normalizer reference and the two dense layers.

        Args:
            state_normalizer: callable applied to the states before the
                hidden layer.
        """
        super(Value, self).__init__()
        self.state_normalizer = state_normalizer
        # A single initializer instance is shared by both layers.
        shared_init = tf.initializers.variance_scaling(scale=2.0)
        self.hidden = tf.layers.Dense(
            units=64,
            activation=pynr.nn.swish,
            kernel_initializer=shared_init)
        self.value = tf.layers.Dense(
            units=1,
            kernel_initializer=shared_init)

    def call(self, states, training=False, reset_state=True):
        """Return the output of the 1-unit value layer for `states`.

        `training` and `reset_state` are part of the signature but are not
        read by this implementation.
        """
        features = self.hidden(self.state_normalizer(states))
        return self.value(features)



In [5]:
def actor(strategy, deterministic=False):
    """Build a rollout callback that picks actions using `strategy`.

    Args:
        strategy: callable on a state; must also expose `.policy` when
            `deterministic` is true.
        deterministic: when true, take the mode of `strategy.policy(state)`
            instead of calling `strategy(state)`.

    Returns:
        A function `(i, state, action, reward, done, is_initial_state)` that
        returns the selected actions cast to `tf.int32`. Only `state` is read.
    """
    def actor_fn(i, state, action, reward, done, is_initial_state):
        chosen = (
            strategy.policy(state).mode() if deterministic
            else strategy(state))
        return tf.cast(chosen, tf.int32)

    return actor_fn



In [6]:
# NOTE: these hyperparameters have not been adapted to the KukaBullet
# environment; this cell is only a sample run.
num_iterations = 100
num_epochs = 5  # number of agent.fit calls per batch of explore rollouts
num_explore_episodes = 256
num_explore_max_steps = 200
num_exploit_episodes = 10
num_exploit_max_steps = 200
# The outer loop stops early once the mean exploit return exceeds this value.
# NOTE(review): the output recorded below this cell shows returns around -860,
# so this threshold was not reached in that run.
returns_threshold = 199.

# Separate environments for exploration (training data) and evaluation.
explore_env = pyactor.batch_gym_make("KukaBulletDiscrete-v0")
exploit_env = pyactor.batch_gym_make("KukaBulletDiscrete-v0")

# Observation bounds are clipped to [-20, 20] before being handed to the
# normalizer (high first, then low).
state_normalizer = pynr.features.HighLowNormalizer(
    np.clip(explore_env.observation_space.high, -20., 20.),
    np.clip(explore_env.observation_space.low, -20., 20.), 
    dtype=tf.float32)
# Two policies with identical architecture sharing one normalizer.
policy = Policy(state_normalizer, explore_env.action_space.n)
behavioral_policy = Policy(state_normalizer, explore_env.action_space.n)
value = Value(state_normalizer)
global_step = tfe.Variable(0, dtype=tf.int64)
# Exploration rate: initial value .5, decay rate .99, decay_steps set to
# num_iterations, driven by global_step.
epsilon = tf.train.exponential_decay(
    .5, global_step, num_iterations, .99)
strategy = pyrl.strategies.EpsilonGreedyStrategy(policy, epsilon)
agent = pyrl.agents.ProximalPolicyOptimizationAgent(
    policy=policy, 
    behavioral_policy=behavioral_policy,
    value=value,
    optimizer=tf.train.AdamOptimizer(1e-3))

# Call both policies once on an initial state before copying variables --
# presumably so the layer variables exist at that point; TODO confirm.
# NOTE(review): the meaning of the `1` passed to reset is not visible here.
state = explore_env.reset(1)
policy(state)
behavioral_policy(state)
# First argument is the target, second the source: behavioral_policy receives
# policy's current variables.
trfl.update_target_variables(
    behavioral_policy.trainable_variables, 
    policy.trainable_variables)

# NOTE(review): the outer and inner loops (and the fit result) all bind `_`;
# harmless here since the value is never read.
for _ in range(num_iterations):
    # Collect exploration data with the epsilon-greedy strategy.
    explore_rollouts = pyactor.batch_rollout(
        explore_env,
        actor(strategy),
        episodes=num_explore_episodes,
        max_steps=num_explore_max_steps)

    # Several fitting passes over the same batch of rollouts.
    for _ in range(num_epochs):
        _ = agent.fit(
            states=explore_rollouts.states,
            actions=explore_rollouts.actions,
            rewards=explore_rollouts.rewards,
            weights=explore_rollouts.weights,
            global_step=global_step)

    # Re-sync behavioral_policy with the freshly updated policy.
    trfl.update_target_variables(
        behavioral_policy.trainable_variables, 
        policy.trainable_variables)

    # Evaluate using the mode of the policy distribution (deterministic actor).
    exploit_rollouts = pyactor.batch_rollout(
        exploit_env,
        actor(strategy, deterministic=True),
        episodes=num_exploit_episodes,
        max_steps=num_exploit_max_steps)
    # Sum rewards over the last axis, then average the sums.
    mean_episodic_exploit_returns = tf.reduce_mean(
        tf.reduce_sum(exploit_rollouts.rewards, axis=-1))
    print(mean_episodic_exploit_returns)
    if mean_episodic_exploit_returns.numpy() > returns_threshold:
        break

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_poi

  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)
  result = entry_point.load(False)


tf.Tensor(-861.8031, shape=(), dtype=float32)
tf.Tensor(-843.1719, shape=(), dtype=float32)
tf.Tensor(-839.8451, shape=(), dtype=float32)
tf.Tensor(-864.9084, shape=(), dtype=float32)
tf.Tensor(-865.618, shape=(), dtype=float32)
tf.Tensor(-847.1625, shape=(), dtype=float32)
tf.Tensor(-844.366, shape=(), dtype=float32)
tf.Tensor(-863.8612, shape=(), dtype=float32)
tf.Tensor(-866.5583, shape=(), dtype=float32)
tf.Tensor(-870.3812, shape=(), dtype=float32)
tf.Tensor(-864.47656, shape=(), dtype=float32)
tf.Tensor(-842.6764, shape=(), dtype=float32)
tf.Tensor(-896.1838, shape=(), dtype=float32)
tf.Tensor(-868.2803, shape=(), dtype=float32)
tf.Tensor(-846.6245, shape=(), dtype=float32)
tf.Tensor(-857.21387, shape=(), dtype=float32)
tf.Tensor(-892.5197, shape=(), dtype=float32)
tf.Tensor(-857.31445, shape=(), dtype=float32)
tf.Tensor(-862.9977, shape=(), dtype=float32)
tf.Tensor(-845.7367, shape=(), dtype=float32)
tf.Tensor(-841.71466, shape=(), dtype=float32)
tf.Tensor(-863.234, shape=(), dt

KeyboardInterrupt: 