In [None]:
!pip install 'tf-agents[reverb]'
!pip install pyglet
!pip install pyvirtualdisplay -i https://pypi.tuna.tsinghua.edu.cn/simple

In [None]:
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import sequential
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
from tf_agents.agents.dqn import dqn_agent
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer, reverb_utils

import reverb

import tensorflow as tf

import PIL
import PIL.Image
# import pyvirtualdisplay
import numpy as np

In [None]:
# display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()


In [None]:
# Hyperparameters.
# Step size handed to the Adam optimizer when it is constructed.
learning_rate=1e-3
# Number of episodes passed to compute_avg_return for the evaluation run.
num_eval_episodes=100

## Environment

In [None]:
# Gym environment id used for the exploration, training and evaluation
# environments.
env_name = 'CartPole-v0'
# Python-side environment used by the inspection cells that follow.
env = suite_gym.load(env_name)

In [None]:
# Reset the environment before rendering it; the returned time step is shown
# as the cell output.
env.reset()

In [None]:
# Render the current environment state in 'rgb_array' mode and keep the result
# for display in the next cell.
frame = env.render(mode='rgb_array')

In [None]:
# Convert the rendered array to a PIL image; as the last expression of the
# cell it becomes the cell's displayed output.
PIL.Image.fromarray(frame)

In [None]:
# Show the environment's time-step spec under a heading line.
print('Time Step Spec:', env.time_step_spec(), sep='\n')

In [None]:
# Show the environment's action spec under a heading line.
print('Action Spec:', env.action_spec(), sep='\n')

In [None]:
# Reset the environment, keep the returned time step, and display it.
time_step = env.reset()
time_step


In [None]:
# Build a single action (the int32 scalar 1), apply it to the environment, and
# display the time step that comes back.
action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
next_time_step

In [None]:
# Load two separate Python environments from the same environment id; per
# their names, one is meant for training and one for evaluation.
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

# Wrap each in a TFPyEnvironment. `train_env` supplies the specs that the
# agent and the random policy are constructed from.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

## Agent

In [None]:
# Unit counts of the hidden Dense layers of the Q-network, one entry per layer.
fc_layer_params = (100, 50)
# Size of the action range (maximum - minimum + 1); used as the unit count of
# the Q-values output layer.
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

In [None]:
def dense_layer(num_units):
    """Create a ReLU-activated fully connected Keras layer.

    Args:
        num_units: Passed as the first argument to `tf.keras.layers.Dense`.

    Returns:
        A new `tf.keras.layers.Dense` layer whose kernel initializer is
        `VarianceScaling(scale=2.0, mode='fan_in',
        distribution='truncated_normal')`.
    """
    kernel_init = tf.keras.initializers.VarianceScaling(
        scale=2.0, mode='fan_in', distribution='truncated_normal')
    layer_options = dict(
        activation=tf.keras.activations.relu,
        kernel_initializer=kernel_init,
    )
    return tf.keras.layers.Dense(num_units, **layer_options)

In [None]:
# One hidden layer per entry of fc_layer_params, in the same order.
dense_layers = list(map(dense_layer, fc_layer_params))

In [None]:
# Output layer of the Q-network: `num_actions` units and no activation.
# Kernel weights are initialised uniformly in [-0.03, 0.03] and every bias
# starts at -0.2.
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))

In [None]:
# Q-network: the hidden layers followed by the Q-values output layer.
q_net = sequential.Sequential([*dense_layers, q_values_layer])

In [None]:
# Adam optimizer configured with the notebook-level learning_rate; it is
# handed to the DQN agent when the agent is constructed.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [None]:
# Variable, starting at 0, that the agent receives as its train-step counter.
train_step_counter = tf.Variable(0)

# DQN agent built from the training environment's specs, the Q-network and the
# optimizer defined in the preceding cells; TD errors use the element-wise
# squared loss.
agent = dqn_agent.DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
)
agent.initialize()

## Policy

In [None]:
# Keep handles to the agent's two policies: `agent.policy` under the name used
# for evaluation, `agent.collect_policy` under the name used for collection.
eval_policy = agent.policy
collect_policy = agent.collect_policy

### Random Policy Example

In [None]:
# Random policy built from the same time-step and action specs as the agent.
random_policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
)

In [None]:
# Separate environment instance for the random-policy demo. It is loaded from
# `env_name` rather than a repeated string literal so it cannot drift from the
# environment whose specs the random policy was built from.
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load(env_name))

In [None]:
# Reset the demo environment to obtain the time step that the random-policy
# rollout starts from.
time_step = example_environment.reset()

In [None]:
# Drive the demo environment with the random policy for 100 steps, printing
# the end-of-episode flag, the reward and the chosen action at every step.
for _ in range(100):
    # `policy.action` returns a policy-step object; keep it under its own name
    # instead of rebinding the earlier ndarray `action`.
    action_step = random_policy.action(time_step)

    # Pass only the action itself to the environment, exactly as
    # compute_avg_return does, rather than the whole policy-step object.
    time_step = example_environment.step(action_step.action)

    print(time_step.is_last(), time_step.reward, action_step.action)

## Metrics and Evaluation

In [None]:
def compute_avg_return(environment, policy, num_episodes=10):
    """Average the undiscounted episode return of `policy` over several episodes.

    Args:
        environment: Object exposing `reset()` and `step(action)`; both return
            a time step offering `is_last()` and a `reward` attribute.
        policy: Object whose `action(time_step)` result carries the chosen
            action in its `.action` attribute.
        num_episodes: How many episodes to play (default 10).

    Returns:
        Element 0 of `.numpy()` called on the summed rewards divided by
        `num_episodes`.
    """

    def play_one_episode():
        # Play from a fresh reset until the time step reports the episode end,
        # adding up every reward received along the way.
        step = environment.reset()
        reward_sum = 0.0
        while True:
            if step.is_last():
                return reward_sum
            chosen = policy.action(step)
            step = environment.step(chosen.action)
            reward_sum += step.reward

    running_total = 0.0
    for _ in range(num_episodes):
        running_total += play_one_episode()

    return (running_total / num_episodes).numpy()[0]

In [None]:
# Baseline: average return of the random policy on the demo environment over
# num_eval_episodes episodes (shown as the cell output).
compute_avg_return(example_environment, random_policy, num_eval_episodes)

## Replay Buffer

In [None]:
# Name of the Reverb table shared by the server, the replay buffer and the
# observer defined in this cell.
table_name = 'uniform_table'

# Maximum number of items the Reverb table may hold. This name was previously
# used without ever being defined, so the cell failed with a NameError on a
# fresh kernel. NOTE(review): 100000 is an assumed capacity; tune as needed.
replay_buffer_max_length = 100000

# Signature of the items stored in the table: the agent's collect data spec
# converted to tensor specs, with an outer dimension added.
replay_buffer_signature = tensor_spec.from_spec(
    agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
    replay_buffer_signature)

# Table configured with a Uniform sampler, a Fifo remover and a MinSize(1)
# rate limiter.
table = reverb.Table(
    table_name,
    max_size=replay_buffer_max_length,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

# Reverb server hosting the table; the replay buffer uses it as its local
# server.
reverb_server = reverb.Server([table])

# Replay buffer over that table, configured with sequence_length=2.
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    agent.collect_data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=reverb_server)

# Observer that adds trajectories to the table through the replay buffer's
# Python client, using the same sequence length as the buffer.
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
    replay_buffer.py_client,
    table_name,
    sequence_length=2)