In [50]:
import numpy as np

import tensorflow as tf

from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.drivers import dynamic_episode_driver, dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.policies import random_tf_policy
from tf_agents.trajectories import trajectory
from tf_agents.metrics import tf_metrics

# Try `DynamicEpisodeDriver`: collect whole episodes with a random policy, recording them in TF metrics and a replay buffer

In [49]:
# Run a uniformly random policy on CartPole for a fixed number of episodes with
# DynamicEpisodeDriver. Every collected trajectory is fed to four TF metrics and to
# a replay buffer; the cell then reports what was collected.

ENV_NAME = "CartPole-v0"
NUM_EPISODES = 2                  # episodes collected by one `driver.run()` call
REPLAY_BUFFER_MAX_LENGTH = 10000  # `max_length` of the replay buffer
SEED = 42

# Seed TensorFlow's global RNG so the random policy's action sampling is repeatable.
# NOTE(review): the Gym environment's own RNG (CartPole's initial state) is not seeded
# here, so episode lengths may still differ between runs.
tf.random.set_seed(SEED)

env = suite_gym.load(ENV_NAME)
tf_env = tf_py_environment.TFPyEnvironment(env)

my_random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())

# The averaging metrics take a `batch_size` (default 1); size them to the env's
# batch so the cell also works with a batched environment.
avg_episode_length = tf_metrics.AverageEpisodeLengthMetric(batch_size=tf_env.batch_size)
avg_return = tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size)
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()

# Take the buffer's spec from the policy rather than building it by hand, so it
# always matches the trajectories the driver emits. In particular, a hard-coded
# `policy_info=()` would be wrong for a policy that emits info.
collect_data_spec = my_random_policy.trajectory_spec
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=REPLAY_BUFFER_MAX_LENGTH,
)

driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    my_random_policy,
    observers=[avg_episode_length, avg_return, num_episodes, env_steps, replay_buffer.add_batch],
    num_episodes=NUM_EPISODES,
)


final_time_step, policy_state = driver.run()

# Sanity check: the driver stopped after exactly the requested number of episodes.
assert num_episodes.result().numpy() == NUM_EPISODES

print("final_time_step:\n", final_time_step)
# NOTE(review): the frame count can exceed the step count (35 vs 33 in the recorded
# run, with 2 episodes). The extra frames are presumably the episode-boundary
# trajectories, which EnvironmentSteps does not count.
print("\nNumber of observations in the replay_buffer:", replay_buffer.num_frames().numpy())
print("Episode number:", num_episodes.result().numpy())
print("Steps:", env_steps.result().numpy(), "(do not count the last observation in an episode)")
print("Avg. episode duration:", avg_episode_length.result().numpy())
print("Avg. return:", avg_return.result().numpy())

final_time_step:
 TimeStep(step_type=<tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>, reward=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, discount=<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, observation=<tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[ 0.02972457, -0.00550822, -0.04527468,  0.03920949]],
      dtype=float32)>)

Number of observations in the replay_buffer: 35
Episode number: 2
Steps: 33 (do not count the last observation in an episode)
Avg. episode duration: 16.5
Avg. return: 16.5
