In [1]:
# If you haven't installed tf-agents or gym yet, run:
try:
    %tensorflow_version 2.x
except:
    pass
!pip install --upgrade tensorflow-probability
!pip install tf-agents
!pip install gym

Looking in indexes: https://pypi.org/simple, https://artifactory.spotify.net/artifactory/api/pypi/pypi/simple/
Requirement already up-to-date: tensorflow-probability in /Users/lingh/.pyenv/versions/3.7.0/envs/my-virtual-env-3.7.0/lib/python3.7/site-packages (0.9.0)
Looking in indexes: https://pypi.org/simple, https://artifactory.spotify.net/artifactory/api/pypi/pypi/simple/
Looking in indexes: https://pypi.org/simple, https://artifactory.spotify.net/artifactory/api/pypi/pypi/simple/


In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tf_agents import specs
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step

# Switch TensorFlow to its 2.x behavior through the compat.v1 API. This runs
# once, right after the imports and before any cell below builds TF objects.
tf.compat.v1.enable_v2_behavior()

### Replay Buffer API

In [4]:
class ReplayBuffer(tf.Module):
  """Abstract base class for TF-Agents replay buffer.

  This class is an outline of the replay buffer API: apart from `gather_all`,
  every member is an empty stub that returns None. Concrete buffers are
  expected to supply the real behavior.
  """

  def __init__(self, data_spec, capacity):
    """Initializes the replay buffer.

    Args:
      data_spec: A spec or a list/tuple/nest of specs describing
        a single item that can be stored in this buffer
      capacity: number of elements that the replay buffer can hold.
    """

  # NOTE: the members below are defined at class level. Nesting them inside
  # `__init__` would turn them into unused local functions and leave the
  # class without any of its public API.

  @property
  def data_spec(self):
    """Spec of a single stored item (stub: returns None here)."""

  @property
  def capacity(self):
    """Number of elements the buffer can hold (stub: returns None here)."""

  def add_batch(self, items):
    """Writes a batch of `items` to the buffer (stub: no-op here)."""

  def get_next(self,
               sample_batch_size=None,
               num_steps=None,
               time_stacked=True):
    """Reads a sample from the buffer (stub: returns None here).

    Args:
      sample_batch_size: batch size of the requested sample.
      num_steps: number of timesteps in each sampled entry.
      time_stacked: not interpreted by this outline; presumably controls
        whether the timesteps are stacked into one tensor (TODO confirm).
    """

  def as_dataset(self,
                 sample_batch_size=None,
                 num_steps=None,
                 num_parallel_calls=None):
    """Exposes the buffer as a dataset (stub: returns None here).

    Args:
      sample_batch_size: batch size of each dataset element.
      num_steps: number of timesteps in each dataset element.
      num_parallel_calls: not interpreted by this outline; presumably the
        dataset read parallelism (TODO confirm).
    """

  def gather_all(self):
    """Returns whatever `self._gather_all()` returns.

    `_gather_all` is not defined in this outline, so a concrete subclass must
    provide it; otherwise this call raises AttributeError.
    """
    return self._gather_all()

  def clear(self):
    """Stub with no behavior defined in this outline."""

### TFUniformReplayBuffer

#### Creating the buffer

To create a TFUniformReplayBuffer we pass in:

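- the `data_spec`: the spec of the data elements that the buffer will store
- the `batch_size`: the number of elements added to the buffer in each `add_batch` call (the batch dimension of the buffer)
- the `max_length`: the maximum number of elements stored per batch segment, so the total capacity is `batch_size * max_length`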

Here is an example of creating a TFUniformReplayBuffer with sample data specs, batch_size 32 and max_length 1000.

In [5]:
# Nested spec of one stored item: an 'action' tensor paired with a
# ('lidar', 'camera') tuple, all float32.
data_spec = (
    tf.TensorSpec(shape=[3], dtype=tf.float32, name='action'),
    (
        tf.TensorSpec(shape=[5], dtype=tf.float32, name='lidar'),
        tf.TensorSpec(shape=[3, 2], dtype=tf.float32, name='camera'),
    ),
)

# Buffer dimensions used by the cells below.
batch_size, max_length = 32, 1000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=batch_size, max_length=max_length)

#### Writing to the buffer

In [8]:
# Build one item shaped like data_spec; each tensor is filled with its own
# constant (1, 2, 3) so the three parts are easy to tell apart.
action = tf.constant(
    np.full(data_spec[0].shape.as_list(), 1, dtype=np.float32))
lidar = tf.constant(
    np.full(data_spec[1][0].shape.as_list(), 2, dtype=np.float32))
camera = tf.constant(
    np.full(data_spec[1][1].shape.as_list(), 3, dtype=np.float32))

values = (action, (lidar, camera))

# Repeat every tensor batch_size times along a new leading axis, then write
# the whole batch to the buffer in a single call.
values_batched = tf.nest.map_structure(
    lambda item: tf.stack([item] * batch_size),
    values)

replay_buffer.add_batch(values_batched)

#### Reading from the buffer

In [9]:
# Write five more batches so there is more data to read back.
for _ in range(5):
    replay_buffer.add_batch(values_batched)

# Draw a single sample with batch size 10 and 1 timestep.
sample = replay_buffer.get_next(num_steps=1, sample_batch_size=10)

# View the buffer as a tf.data.Dataset and pull three elements from it.
dataset = replay_buffer.as_dataset(num_steps=2, sample_batch_size=4)
iterator = iter(dataset)

print("Iterator trajectories:")
# Each dataset element is a pair; only its first entry is kept.
trajectories = [next(iterator)[0] for _ in range(3)]
print(tf.nest.map_structure(lambda tensor: tensor.shape, trajectories))

# Fetch every element currently held by the buffer.
trajectories = replay_buffer.gather_all()

print("Trajectories from gather all:")
print(tf.nest.map_structure(lambda tensor: tensor.shape, trajectories))

Iterator trajectories:
[(TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2])))]
Trajectories from gather all:
(TensorShape([32, 8, 3]), (TensorShape([32, 8, 5]), TensorShape([32, 8, 3, 2])))


### PyUniformReplayBuffer

In [10]:
# Same total capacity as the TFUniformReplayBuffer created earlier
# (max_length 1000 x batch_size 32).
replay_buffer_capacity = 1000 * 32

# The Python buffer takes array specs, so the tensor spec nest is converted.
py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
    data_spec=tensor_spec.to_nest_array_spec(data_spec),
    capacity=replay_buffer_capacity,
)

### Using replay buffers during training

In [13]:
# Load the gym environment and wrap it as a TF environment.
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

# Q-network over the environment's observation and action specs.
q_net = q_network.QNetwork(
    tf_env.time_step_spec().observation,
    tf_env.action_spec(),
    fc_layer_params=(100,),
)

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=tf.compat.v1.train.AdamOptimizer(0.001),
)

# Replay buffer sized for the agent's collect data spec.
replay_buffer_capacity = 1000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_capacity,
)

# The driver passes what it collects to each observer; using add_batch as
# the observer writes it to the replay buffer.
replay_observer = [replay_buffer.add_batch]

collect_steps_per_iteration = 10
collect_op = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=replay_observer,
    num_steps=collect_steps_per_iteration,
).run()

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))


#### Reading data for a train step

In [14]:
num_train_steps = 10

# Expose the replay buffer as a Dataset yielding batches of 4 elements,
# each with 2 timesteps.
dataset = replay_buffer.as_dataset(num_steps=2, sample_batch_size=4)
iterator = iter(dataset)

for _ in range(num_train_steps):
    # Each dataset element is a pair; only its first entry is fed to train().
    trajectories = next(iterator)[0]
    loss = agent.train(experience=trajectories)