# Imports

In [1]:
import collections
import functools
import math
import gym
import redis
import time
import multiprocessing
import numpy as np
import random
import tensorflow as tf
import tensorflow_probability as tfp
import pyoneer as pynr
import pyoneer.rl as pyrl

# Seed the environment.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Problem

In [2]:
def create_env_model(batch_size=None):
    # Create the gym env, wrapped in a vectorized manner.
    env_spec = 'CartPole-v1'

    if batch_size is None:
        gym_env = gym.make(env_spec)
    else:
        gym_env = pyrl.wrappers.Batch(lambda: gym.make(env_spec), batch_size)

    # Wrap it in a Model.
    env_model = pyrl.rollouts.Env(gym_env)
    return env_model

In [3]:
env_model = create_env_model(1)
env_outputs = env_model.reset()
print(env_outputs.next_state)
env_model.close()

tf.Tensor([[-0.00562286 -0.01017127 -0.04766121  0.04716188]], shape=(1, 4), dtype=float32)


# Solutions
- On-policy PPO
- Model-based, on-policy PPO
- On/Off-policy IMPALA

In [4]:
AgentPolicyOutput = collections.namedtuple(
    'AgentPolicyOutput', ['action', 'log_prob'])
AgentValueOutput = collections.namedtuple(
    'AgentValueOutput', ['value'])
AgentPolicyValueOutput = collections.namedtuple(
    'AgentPolicyValueOutput', ['log_prob', 'entropy', 'value'])


class CartPoleAgent(tf.Module):

    def __init__(self, action_spec):
        super(CartPoleAgent, self).__init__(name='CartPoleAgent')
        self._hidden = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self._logits = tf.keras.Sequential([
            tf.keras.layers.Dense(8, activation=tf.nn.relu),
            tf.keras.layers.Dense(2)])
        self._value = tf.keras.Sequential([
            tf.keras.layers.Dense(8, activation=tf.nn.relu),
            tf.keras.layers.Dense(1)])
        self._policy = tfp.distributions.Categorical

        self.action_spec = action_spec
        self.log_prob_spec = tf.nest.map_structure(
            lambda spec: tf.TensorSpec([], tf.dtypes.float32),
            self.action_spec)
        self.output_specs = AgentPolicyOutput(
            action=self.action_spec,
            log_prob=self.log_prob_spec)
        self.output_shapes = tf.nest.map_structure(
            lambda spec: spec.shape, self.output_specs)
        self.output_dtypes = tf.nest.map_structure(
            lambda spec: spec.dtype, self.output_specs)

    @tf.function
    def _scale_state(self, state):
        state = (state / [[2.4, 10., 1., 10.]])
        state = tf.concat(
            [state, tf.stack([tf.math.cos(state[..., 2] / math.pi),
                              tf.math.sin(state[..., 2] / math.pi)],
                             axis=-1)],
            axis=-1)
        return tf.clip_by_value(state, -1., 1.)

    @tf.function
    def initialize(self, env_outputs, agent_outputs):
        state = self._scale_state(env_outputs.state)
        hidden = self._hidden(state)
        _ = self._value(hidden)
        _ = self._logits(hidden)
    
    @tf.function
    def value(self, env_outputs, agent_outputs):
        state = self._scale_state(env_outputs.state)
        hidden = self._hidden(state)
        value = tf.squeeze(self._value(hidden), axis=-1)
        return AgentValueOutput(value=value)

    @tf.function
    def policy_value(self, env_outputs, agent_outputs):
        state = self._scale_state(env_outputs.state)
        hidden = self._hidden(state)
        logits = self._logits(hidden)
        policy = self._policy(logits=logits)
        entropy = policy.entropy()
        log_prob = policy.log_prob(agent_outputs.action)
        value = tf.squeeze(self._value(hidden), axis=-1)
        return AgentPolicyValueOutput(log_prob=log_prob,
                                      entropy=entropy,
                                      value=value)

    @tf.function
    def policy_value_with_nexts(self, env_outputs, agent_outputs):
        # Add bootstrap state
        def bootstrap_state(s_t, s_tp1):
            return tf.concat([s_t, s_tp1[:, -1:]], axis=1)

        bootstrapped_state = tf.nest.map_structure(
            bootstrap_state, env_outputs.state, env_outputs.next_state)

        state = self._scale_state(bootstrapped_state)
        hidden = self._hidden(state)
        value = tf.squeeze(self._value(hidden), axis=-1)

        logits = self._logits(hidden[:, :-1])
        policy = self._policy(logits=logits)
        log_prob = policy.log_prob(agent_outputs.action)
        entropy = policy.entropy()

        outputs = AgentPolicyValueOutput(log_prob=log_prob,
                                         entropy=entropy,
                                         value=value[:, :-1])
        bootstrap = AgentValueOutput(value=value[:, -1])
        return outputs, bootstrap

    @tf.function
    def reset(self, env_outputs, explore=True):
        initial_action = pynr.debugging.mock_spec(
            tf.TensorShape([env_outputs.state.shape[0]]), 
            self.action_spec, 
            tf.zeros)
        initial_log_prob = pynr.debugging.mock_spec(
            tf.TensorShape([env_outputs.state.shape[0]]), 
            self.log_prob_spec, 
            tf.zeros)
        return AgentPolicyOutput(
            action=initial_action,
            log_prob=initial_log_prob)

    @tf.function
    def step(self, env_outputs, agent_outputs, time_step, explore=True):
        state = env_outputs.next_state
        state = self._scale_state(state)
        hidden = self._hidden(state)
        logits = self._logits(hidden)
        policy = self._policy(logits=logits)

        if explore:
            action = policy.sample()
        else:
            action = policy.mode()

        action = tf.nest.map_structure(
            lambda t, s: tf.cast(t, s.dtype), 
            action, self.action_spec)
        log_prob = policy.log_prob(action)
        return AgentPolicyOutput(action=action,
                                 log_prob=log_prob)


# Strategies

In [5]:
class Strategy(object):

    def __init__(self, agent, explore):
        self.agent = agent
        self.explore = explore

    @tf.function
    def reset(self, *args, **kwargs):
        return self.agent.reset(*args, explore=self.explore, **kwargs)

    @tf.function
    def step(self, *args, **kwargs):
        return self.agent.step(*args, explore=self.explore, **kwargs)


# Experiments

In [6]:
HyperParameters = collections.namedtuple(
    'HyperParameters', 
    ['iterations',
     'epochs',
     'discounts',
     'lambdas',
     'epsilon',
     'value_scale',
     'entropy_scale',
     'eval_every',
     'learning_rate'])

# On-Policy PPO

In [7]:
explore_size = 32
exploit_size = 16
max_steps = 500

explore_env_model = create_env_model(explore_size)
exploit_env_model = create_env_model(explore_size)

agent_model = CartPoleAgent(explore_env_model.action_spec)

explore_strategy = Strategy(agent_model, True)
exploit_strategy = Strategy(agent_model, False)

explore_rollout = pyrl.rollouts.Rollout(explore_env_model, explore_strategy, max_steps)
exploit_rollout = pyrl.rollouts.Rollout(exploit_env_model, exploit_strategy, max_steps)

hparams = HyperParameters(
    iterations=50,
    epochs=5,
    discounts=.99,
    lambdas=.975,
    epsilon=.2,
    value_scale=.5,
    entropy_scale=.05,
    eval_every=1,
    learning_rate=1e-3,
)
optimizer = tf.keras.optimizers.Adam(hparams.learning_rate)
discounted_returns = tf.function(pyrl.targets.discounted_returns)
generalized_advantage_estimate = tf.function(pyrl.targets.generalized_advantage_estimate)

mock_env_outputs = pynr.debugging.mock_spec(
    tf.TensorShape([1, max_steps]), explore_env_model.output_specs)
mock_agent_outputs = pynr.debugging.mock_spec(
    tf.TensorShape([1, max_steps]), agent_model.output_specs)
agent_model.initialize(mock_env_outputs, mock_agent_outputs)

explore_env_model.seed(42)
for iteration in range(hparams.iterations):
    if (iteration % hparams.eval_every) == 0:
        (_, eval_env_outputs) = exploit_rollout().outputs
        eval_returns = tf.reduce_sum(eval_env_outputs.reward * eval_env_outputs.weight, axis=1)
        tf.print(tf.reduce_mean(eval_returns))

    (agent_outputs, env_outputs) = explore_rollout().outputs
    agent_value_outputs = agent_model.value(env_outputs, agent_outputs)
    returns = discounted_returns(
        env_outputs.reward * env_outputs.weight, 
        discounts=hparams.discounts)
    advantages = generalized_advantage_estimate(
        env_outputs.reward * env_outputs.weight, 
        agent_value_outputs.value * env_outputs.weight,
        discounts=hparams.discounts, 
        lambdas=hparams.lambdas, 
        weights=env_outputs.weight)

    for _ in range(hparams.epochs):
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(agent_model.trainable_variables)

            agent_estimates_output = agent_model.policy_value(
                env_outputs, agent_outputs)

            ratio = tf.exp(
                agent_estimates_output.log_prob - agent_outputs.log_prob)
            surrogate1 = ratio * advantages
            surrogate2 = tf.clip_by_value(
                ratio,
                1 - hparams.epsilon,
                1 + hparams.epsilon) * advantages
            surrogate_loss = tf.minimum(surrogate1, surrogate2)
            policy_loss = -tf.reduce_sum(
                surrogate_loss * env_outputs.weight)

            value_loss = hparams.value_scale * tf.reduce_sum(
                (tf.square(agent_estimates_output.value - tf.stop_gradient(returns)) *
                 env_outputs.weight))

            entropy_loss = -hparams.entropy_scale * tf.reduce_sum(
                 agent_estimates_output.entropy * env_outputs.weight)

            loss = (policy_loss + value_loss + entropy_loss) / (explore_size * max_steps)

        variables = agent_model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

explore_env_model.close()
exploit_env_model.close()

W0903 14:01:32.349082 123145399345152 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0903 14:01:37.568933 123145398272000 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int64
W0903 14:01:37.571855 123145398272000 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int64
W0903 14:01:37.575312 123145398808576 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int64
W0903 14:01:37.579843 123145398272000 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int64


9.71875


W0903 14:01:43.949117 140736428594112 deprecation.py:323] From /Users/samwenke/.local/share/virtualenvs/pyoneer-K_ZVJbe4/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


9.28125
9.5625
9.09375
9.25
9.125
9.4375
9.21875
9.40625
10.40625
13.71875
88.84375
9.875
9.3125
9.1875
9.78125
10.6875
10.21875
13.3125
20.40625
28.78125
42.5625
65.4375
181.90625
450.96875
433.84375
422.625
447.03125
56
18.03125
25.59375
87.15625
184.5625
150.15625
359.78125
87.03125
99.4375
491.1875
148.3125
85.5625
377.21875
104.3125
121.4375
449.21875
80.78125
61.03125
372.5
62.3125
52.0625
376.5625


# Model-Based On-Policy PPO

In [8]:
ForwardOutput = collections.namedtuple(
    'ForwardOutput', ['deltas_norm'])


class CartPoleEnv(tf.Module):

    def __init__(self, env):
        super(CartPoleEnv, self).__init__(name='CartPoleEnv')
        self._hidden = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self._logits = tf.keras.Sequential([
            tf.keras.layers.Dense(8, activation=tf.nn.relu),
            tf.keras.layers.Dense(4)])
        self.deltas_moments = pynr.moments.StreamingMoments([4])
        self.output_specs = env.output_specs
        self.output_shapes = tf.nest.map_structure(
            lambda spec: spec.shape, self.output_specs)
        self.output_dtypes = tf.nest.map_structure(
            lambda spec: spec.dtype, self.output_specs)

    @tf.function
    def _scale_state(self, state):
        state = (state / [[2.4, 10., 1., 10.]])
        state = tf.concat(
            [state, tf.stack([tf.math.cos(state[..., 2] / math.pi),
                              tf.math.sin(state[..., 2] / math.pi)],
                             axis=-1)],
            axis=-1)
        return tf.clip_by_value(state, -1., 1.)

    @tf.function
    def terminals(self, next_state, time_step):
        is_terminal = pynr.debugging.mock_spec(
            next_state.shape[:1], 
            self.output_specs.terminal, 
            tf.ones)

        if time_step > 200:
            return is_terminal

        state_abs = tf.abs(next_state)
        return tf.where(
            tf.logical_or(tf.greater(state_abs[:, 0], 2.4), 
                          tf.greater(state_abs[:, 2], 12.)), 
            is_terminal, ~is_terminal)

    @tf.function
    def rewards(self, next_state):
        return pynr.debugging.mock_spec(
            next_state.shape[:1], 
            self.output_specs.reward, 
            tf.ones)

    @tf.function
    def forward(self, env_outputs, agent_outputs):
        state = self._scale_state(env_outputs.state)
        hidden = self._hidden(
            tf.concat([
                state, 
                tf.cast(agent_outputs.action[..., None], 
                        tf.dtypes.float32)
            ], axis=-1))
        deltas_norm = self._logits(hidden)
        return ForwardOutput(deltas_norm=deltas_norm)

    @tf.function
    def reset(self, size, seed):
        seed = tfp.distributions.SeedStream(seed, salt='forward_reset')
        states_loc = [0., 0., 0., 0.]
        states_scale_diag = [.25, .25, .25, .25]
        initial_state_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=states_loc, scale_diag=states_scale_diag)
        next_initial_state = initial_state_distribution.sample([size], seed=seed())
        initial_state = pynr.debugging.mock_spec(
            tf.TensorShape([size]), self.output_specs.state)
        initial_reward = pynr.debugging.mock_spec(
            tf.TensorShape([size]), self.output_specs.reward)
        initial_terminal = pynr.debugging.mock_spec(
            tf.TensorShape([size]), self.output_specs.terminal)
        initial_weight = pynr.debugging.mock_spec(
            tf.TensorShape([size]), self.output_specs.weight, 
            tf.ones)
        return pyrl.rollouts.Transition(
            state=initial_state, 
            next_state=next_initial_state,
            reward=initial_reward,
            terminal=initial_terminal,
            weight=initial_weight)

    @tf.function
    def step(self, agent_outputs, env_outputs, time_step):
        state = env_outputs.next_state
        state = self._scale_state(state)
        action = agent_outputs.action
        hidden = self._hidden(
            tf.concat([
                state, 
                tf.cast(agent_outputs.action[..., None], 
                        tf.dtypes.float32)
            ], axis=-1))
        deltas_norm = self._logits(hidden)
        deltas = self.deltas_moments.denormalize(
            deltas_norm, env_outputs.weight[..., None])

        delta_high = self.deltas_moments.mean + 3. * self.deltas_moments.std
        deltas = tf.clip_by_value(deltas, -delta_high, delta_high)

        next_state = env_outputs.next_state + deltas
        terminal = self.terminals(next_state, time_step)
        terminal = tf.logical_or(terminal, env_outputs.terminal)
        reward = self.rewards(next_state)
        weight = tf.cast(~env_outputs.terminal, tf.dtypes.float32)

        return pyrl.rollouts.Transition(
            state=env_outputs.next_state, 
            next_state=next_state,
            reward=reward,
            terminal=terminal,
            weight=weight)


## Forward Models

In [9]:
class ForwardModel(object):

    def __init__(self, env, size):
        self.env = env
        self.size = size
        self._seed = tf.Variable(0, trainable=False)

    def seed(self, random_seed):
        self._seed.assign(random_seed)

    @tf.function
    def reset(self, *args, **kwargs):
        return self.env.reset(
            *args, 
            size=self.size, 
            seed=self._seed,
            **kwargs)

    @tf.function
    def step(self, *args, **kwargs):
        return self.env.step(*args, **kwargs)

ForwardHyperParameters = collections.namedtuple(
    'ForwardHyperParameters', 
    ['iterations',
     'epochs',
     'learning_rate'])

In [10]:
explore_size = 32
exploit_size = 16
forward_explore_size = 512
forward_exploit_size = 16
max_steps = 500

explore_env_model = create_env_model(explore_size)
exploit_env_model = create_env_model(exploit_size)

forward_model = CartPoleEnv(explore_env_model)
agent_model = CartPoleAgent(explore_env_model.action_spec)

explore_strategy = Strategy(agent_model, True)
exploit_strategy = Strategy(agent_model, False)

explore_rollout = pyrl.rollouts.Rollout(explore_env_model, explore_strategy, max_steps)
exploit_rollout = pyrl.rollouts.Rollout(exploit_env_model, exploit_strategy, max_steps)

explore_forward_model = ForwardModel(forward_model, forward_explore_size)
exploit_forward_model = ForwardModel(forward_model, forward_exploit_size)

explore_forward_rollout = pyrl.rollouts.Rollout(explore_forward_model, explore_strategy, max_steps)
exploit_forward_rollout = pyrl.rollouts.Rollout(exploit_forward_model, exploit_strategy, max_steps)

hparams = HyperParameters(
    iterations=10,
    epochs=10,
    discounts=.99,
    lambdas=.975,
    epsilon=.2,
    value_scale=.5,
    entropy_scale=.05,
    eval_every=1,
    learning_rate=1e-3,
)
forward_hparams = ForwardHyperParameters(
    iterations=30,
    epochs=15,
    learning_rate=1e-2,
)

optimizer = tf.keras.optimizers.Adam(hparams.learning_rate)
forward_optimizer = tf.keras.optimizers.Adam(forward_hparams.learning_rate)

discounted_returns = tf.function(pyrl.targets.discounted_returns)
generalized_advantage_estimate = tf.function(pyrl.targets.generalized_advantage_estimate)

explore_env_model.seed(42)
explore_forward_model.seed(42)
for forward_iteration in range(forward_hparams.iterations): 
    (agent_outputs, env_outputs) = explore_rollout().outputs

    # Train the forward model
    deltas = (env_outputs.next_state - env_outputs.state)
    forward_model.deltas_moments.update_state(deltas, env_outputs.weight[..., None])
    deltas_norm = forward_model.deltas_moments.normalize(
        deltas, env_outputs.weight[..., None])

    for _ in range(forward_hparams.epochs):
        with tf.GradientTape() as tape:
            forward_estimates_output = forward_model.forward(
                env_outputs, agent_outputs)
            forward_loss = tf.reduce_mean(
                tf.square(deltas_norm - forward_estimates_output.deltas_norm), 
                axis=-1)
            loss = tf.reduce_sum(
                forward_loss * env_outputs.weight)
            loss = loss / (explore_size * max_steps)

        variables = forward_model.trainable_variables
        grads = tape.gradient(loss, variables)
        forward_optimizer.apply_gradients(zip(grads, variables))

    # Train the policy
    for iteration in range(hparams.iterations):
        if (iteration % hparams.eval_every) == 0:
            exploit_forward_model.seed(42 + forward_explore_size + 1)
            (_, eval_forward_outputs) = exploit_forward_rollout().outputs
            returns = tf.reduce_sum(eval_forward_outputs.reward * eval_forward_outputs.weight, axis=1)
            tf.print('Forward', tf.reduce_mean(returns))

            exploit_env_model.seed(42 + explore_size + 1)
            (_, eval_env_outputs) = exploit_rollout().outputs
            returns = tf.reduce_sum(eval_env_outputs.reward * eval_env_outputs.weight, axis=1)
            tf.print('Real World', tf.reduce_mean(returns))

        (agent_outputs, env_outputs) = explore_forward_rollout().outputs
        agent_value_outputs = agent_model.value(env_outputs, agent_outputs)
        returns = discounted_returns(
            env_outputs.reward * env_outputs.weight, discounts=hparams.discounts)
        advantages = generalized_advantage_estimate(
            env_outputs.reward * env_outputs.weight, agent_value_outputs.value * env_outputs.weight,
            discounts=hparams.discounts, lambdas=hparams.lambdas, weights=env_outputs.weight)

        for _ in range(hparams.epochs):
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(agent_model.trainable_variables)

                agent_estimates_output = agent_model.policy_value(
                    env_outputs, agent_outputs)

                ratio = tf.exp(
                    agent_estimates_output.log_prob - agent_outputs.log_prob)
                surrogate1 = ratio * advantages
                surrogate2 = tf.clip_by_value(
                    ratio,
                    1 - hparams.epsilon,
                    1 + hparams.epsilon) * advantages
                surrogate_loss = tf.minimum(surrogate1, surrogate2)
                policy_loss = -tf.reduce_sum(
                    surrogate_loss * env_outputs.weight)
                value_loss = hparams.value_scale * tf.reduce_sum(
                    (tf.square(agent_estimates_output.value - returns) *
                     env_outputs.weight))
                entropy_loss = -hparams.entropy_scale * tf.reduce_sum(
                     agent_estimates_output.entropy * env_outputs.weight)
                loss = (policy_loss + value_loss + entropy_loss) / (explore_size * max_steps)

            variables = agent_model.trainable_variables
            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(grads, variables))

explore_env_model.close()
exploit_env_model.close()

Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 202
Real World 9.25
Forward 94.3125
Real World 9.5625
Forward 83.8125
Real World 9.5625
Forward 86.9375
Real World 9.5625
Forward 95.1875
Real World 9.5625
Forward 88.3125
Real World 9.5625
Forward 88.8125
Real World 9.5625
Forward 94.25
Real World 9.5625
Forward 99.5
Real World 9.5625
Forward 96.5
Real World 9.5625
Forward 98.125
Real World 9.5625
Forward 66.6875
Real World 9.5625
Forward 75.125
Real World 9.5625
Forward 69.5625
Real World 9.5625
Forward 68.5
Real World 9.5625
Forward 68.8125
Real World 9.5625
Forward 68.5625
Real World 9.5625
Forward 67.4375
Real World 9.5625
Forward 71.375
Real World 9.5625
Forward 72.0625
Real World 9.5625
Forward 71
Real World 9.5625
Forward 65.9375
Real World 9.5625
Forward 65.1875
Real Worl

Forward 202
Real World 10.125
Forward 202
Real World 10.1875
Forward 202
Real World 10.1875
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.1875
Forward 202
Real World 10.125
Forward 202
Real World 10.375
Forward 202
Real World 10.1875
Forward 202
Real World 10.1875
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.125
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.5625
Forward 202
Real World 10.5625
Forward 202
Real World 10.5625
Forward 202
Real World 10.5
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.375
Forward 202
Real World 10.5625
Forward 202
Real World 10.4375
Forward 202
Real World 10.5625
Forward 202
Real World 10.4375


# On/Off-Policy IMPALA

In [11]:
def actor(host, port, actor_id, num_actors):
    print('Starting Actor!')
    max_steps = 500
    explore_env_model = create_env_model()
    agent_model = CartPoleAgent(explore_env_model.action_spec)
    explore_strategy = Strategy(agent_model, True)
    explore_rollout = pyrl.rollouts.Rollout(explore_env_model, explore_strategy, max_steps)

    mock_env_outputs = pynr.debugging.mock_spec(
        tf.TensorShape([1, max_steps]), explore_env_model.output_specs)
    mock_agent_outputs = pynr.debugging.mock_spec(
        tf.TensorShape([1, max_steps]), agent_model.output_specs)
    agent_model.initialize(mock_env_outputs, mock_agent_outputs)

    # Connect to the redis server.
    pipe = redis.Redis(host=host, port=port, db=0)

    # Control flow for rollouts.
    cond = pynr.distributed.Condition(
        pipe, 'WaitCondition')

    # Queue for rollouts.
    queue = pynr.distributed.Queue(
        pipe, 'RolloutQueue',
        dtypes=(tf.dtypes.int32, 
                (agent_model.output_dtypes, explore_env_model.output_dtypes)))

    # Parameter server.
    parameters = pynr.distributed.Register(
        pipe, 'Parameters',
        dtypes=tf.nest.map_structure(lambda var: var.dtype,
                                     agent_model.variables))

    # Queried to determine when to sync parameters.
    sync = pynr.distributed.MultiEvent(
        pipe, actor_id, num_actors, 'SyncParameters')

    explore_env_model.seed(42 + actor_id)
    while True:
        cond.wait(actor_id)
        # Sync parameters only if we need to.
        if sync.get():
            sync.unset()
            tf.nest.map_structure(
                lambda dst, src: dst.assign(src),
                agent_model.variables,
                parameters.get())
        values = explore_rollout().outputs
        queue.enqueue((tf.cast(actor_id, tf.dtypes.int32), 
                       values))


def learner(host, port, num_actors):
    batch_size = 32
    exploit_size = 16
    max_steps = 500
    exploit_env_model = create_env_model()
    agent_model = CartPoleAgent(exploit_env_model.action_spec)
    exploit_strategy = Strategy(agent_model, False)
    exploit_rollout = pyrl.rollouts.Rollout(exploit_env_model, exploit_strategy, max_steps)

    mock_env_outputs = pynr.debugging.mock_spec(
        tf.TensorShape([1, max_steps]), exploit_env_model.output_specs)
    mock_agent_outputs = pynr.debugging.mock_spec(
        tf.TensorShape([1, max_steps]), agent_model.output_specs)
    agent_model.initialize(mock_env_outputs, mock_agent_outputs)

    hparams = HyperParameters(
        iterations=100,
        discounts=.99,
        value_scale=.5,
        epochs=None,
        lambdas=None,
        epsilon=None,
        entropy_scale=.05,
        eval_every=1,
        learning_rate=1e-3,
    )
    optimizer = tf.keras.optimizers.Adam(hparams.learning_rate)
    v_trace_returns = tf.function(pyrl.targets.v_trace_returns)
    temporal_difference = tf.function(pyrl.targets.temporal_difference)

    # Connect to the redis server.
    pipe = redis.Redis(host=host, port=port, db=0)

    # Control flow for rollouts.
    cond = pynr.distributed.Condition(
        pipe, 'WaitCondition')

    # Queue for rollouts.
    queue = pynr.distributed.Queue(
        pipe, 'RolloutQueue',
        dtypes=(tf.dtypes.int32, 
                (agent_model.output_dtypes, exploit_env_model.output_dtypes)))

    # Parameter server.
    parameters = pynr.distributed.Register(
        pipe, 'Parameters',
        dtypes=tf.nest.map_structure(lambda var: var.dtype,
                                     agent_model.variables))

    # Queried to determine when to sync parameters.
    sync = pynr.distributed.MultiEvent(
        pipe, num_actors, num_actors, 'SyncParameters')

    def reader_fn():
        while True:
            actor_id, values = queue.dequeue()
            cond.notify(actor_id)
            yield tf.nest.map_structure(
                lambda t: tf.squeeze(t, axis=0), values)

    def set_n_step_shape_fn(shape):
        return tf.TensorShape([max_steps] + shape.as_list())

    agent_output_shapes = tf.nest.map_structure(
        set_n_step_shape_fn, agent_model.output_shapes)
    env_output_shapes = tf.nest.map_structure(
        set_n_step_shape_fn, exploit_env_model.output_shapes)

    # Stage N batches ahead of time
    prefetch_size = 1
    actor = tf.data.Dataset.from_generator(
        reader_fn,
        output_types=(agent_model.output_dtypes, exploit_env_model.output_dtypes),
        output_shapes=(agent_output_shapes, env_output_shapes))
    actor = actor.batch(batch_size)
    actor = actor.prefetch(prefetch_size)
    actor_reader = iter(actor)
    
    print('Starting learner!')
    for iteration in range(hparams.iterations):
        parameters.set(agent_model.variables)
        sync.set_all()
        cond.notify_all()

        if (iteration % hparams.eval_every) == 0:
            exploit_env_model.seed(42 + num_actors + 1)
            (_, eval_env_outputs) = exploit_rollout().outputs
            returns = tf.reduce_sum(eval_env_outputs.reward * eval_env_outputs.weight, axis=1)
            tf.print(tf.reduce_mean(returns))

        (agent_outputs, env_outputs) = next(actor_reader)

        # Estimate gradients here.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(agent_model.trainable_variables)

            (agent_estimates_output,
             agent_value_output_last) = agent_model.policy_value_with_nexts(
                 env_outputs, agent_outputs)

            masked_discounts = tf.cast(
                ~env_outputs.terminal, tf.dtypes.float32) * hparams.discounts
            returns = v_trace_returns(
                env_outputs.reward * env_outputs.weight,
                agent_estimates_output.value * env_outputs.weight,
                agent_estimates_output.log_prob * env_outputs.weight,
                agent_outputs.log_prob * env_outputs.weight,
                last_value=agent_value_output_last.value,
                discounts=masked_discounts,
                weights=env_outputs.weight)

            returns_next = tf.concat(
                [returns[:, 1:], tf.expand_dims(
                    agent_value_output_last.value, axis=1)],
                axis=1)

            clipped_is = tf.math.minimum(
                1., tf.exp(
                    agent_estimates_output.log_prob - agent_outputs.log_prob))
            clipped_is = tf.stop_gradient(clipped_is)
            returns = tf.stop_gradient(
                env_outputs.reward + masked_discounts * returns_next)

            advantages = clipped_is * temporal_difference(
                returns * env_outputs.weight,
                agent_estimates_output.value * env_outputs.weight,
                back_prop=True)

            policy_loss = -tf.reduce_sum(
                (agent_estimates_output.log_prob * advantages) *
                env_outputs.weight)
            value_loss = hparams.value_scale * tf.reduce_sum(
                (tf.square(advantages) * env_outputs.weight))
            entropy_loss = -hparams.entropy_scale * tf.reduce_sum(
                agent_estimates_output.entropy * env_outputs.weight)
            loss = ((policy_loss + value_loss + entropy_loss) /
                    (batch_size * max_steps))

        variables = agent_model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

    explore_env_model.close()

In [None]:
num_actors = 2
host = '127.0.0.1'
port = '6380'

actor_processes = []
for actor_id in range(num_actors):
    p = multiprocessing.Process(target=actor, args=(host, port, actor_id, num_actors,))
    p.start()
    actor_processes.append(p)

time.sleep(2)
learner_proc = multiprocessing.Process(target=learner, args=(host, port, num_actors,))
learner_proc.start()
learner_proc.join()
for p in actor_processes:
    p.join()

Starting Actor!
Starting Actor!
