In [1]:
import collections
import tensorflow as tf

from tf_agents.agents import tf_agent
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.policies import q_policy
from tf_agents.policies import epsilon_greedy_policy
from tf_agents.policies import greedy_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
from tf_agents.utils import composite
from tf_agents.utils import training as training_lib


# Opt into TF2 behavior (eager execution, etc.) in case this runs on a TF 1.x
# install; on TF 2.x this should already be the default.
tf.compat.v1.enable_v2_behavior()

In [2]:
# Hyperparameters. The trailing `# @param` comments are Colab form annotations
# and must keep their exact syntax.

# Number of training-loop iterations; each performs one `agent.train` call.
num_iterations = 20000 # @param {type:"integer"}

# Number of random-policy steps intended to pre-fill the replay buffer before
# training; the initial `collect_data` call should use this value.
initial_collect_steps = 1000  # @param {type:"integer"} 
# Environment steps gathered with the agent's collect policy per iteration.
collect_steps_per_iteration = 1  # @param {type:"integer"}
# `max_length` of the uniform replay buffer.
replay_buffer_max_length = 100000  # @param {type:"integer"}

# Number of samples drawn from the replay buffer per training step.
batch_size = 64  # @param {type:"integer"}
# Adam learning rate.
learning_rate = 1e-3  # @param {type:"number"}
# Print the training loss whenever the train step is a multiple of this.
log_interval = 200  # @param {type:"integer"}

# Episodes averaged per evaluation of the greedy policy.
num_eval_episodes = 10  # @param {type:"integer"}
# Evaluate whenever the train step is a multiple of this.
eval_interval = 1000  # @param {type:"integer"}

In [3]:
env_name = 'CartPole-v0'

# Two independent copies of the environment: one stepped during data
# collection, one reserved for evaluation so evaluation never disturbs the
# training episode in progress.
train_py_env, eval_py_env = (suite_gym.load(env_name) for _ in range(2))

# Wrap both Python environments so they consume and produce TF tensors.
train_env, eval_env = (
    tf_py_environment.TFPyEnvironment(py_env)
    for py_env in (train_py_env, eval_py_env))

In [4]:
# Hidden fully-connected layer sizes: a single layer of 100 units.
fc_layer_params = (100,)

# Online Q-network mapping observations to one Q-value per discrete action.
q_net = q_network.QNetwork(
    input_tensor_spec=train_env.observation_spec(),
    action_spec=train_env.action_spec(),
    fc_layer_params=fc_layer_params,
)

In [5]:
class MyLossInfo(collections.namedtuple('MyLossInfo', ['td_loss', 'td_error'])):
    """Per-example diagnostics carried in `LossInfo.extra` by `MyAgent._loss`.

    Fields:
      td_loss: Masked element-wise TD loss.
      td_error: Masked difference between TD targets and predicted Q-values.
    """

def compute_td_targets(next_q_values, rewards, discounts):
    """One-step TD target `rewards + discounts * next_q_values`, as a constant.

    The result is wrapped in `tf.stop_gradient`, so no gradient flows back
    through the target into `next_q_values` (or rewards/discounts).
    """
    targets = rewards + discounts * next_q_values
    return tf.stop_gradient(targets)

class MyAgent(tf_agent.TFAgent):
    """A minimal one-step DQN agent written directly against `tf_agent.TFAgent`.

    It is modeled on `dqn_agent.DqnAgent`: an online Q-network is regressed
    toward one-step TD targets computed with a target copy of that network,
    and the target copy is refreshed with `common.soft_variables_update`.

    Args:
      time_step_spec: Time-step spec of the environment.
      action_spec: Action spec of the environment. A single (non-nested) spec
        is assumed by the Q-value indexing below.
      q_network: Network returning per-action Q-values for an observation.
      optimizer: Optimizer used to apply gradients of the TD loss.
      epsilon_greedy: Exploration probability of `collect_policy`.
      gamma: Extra discount multiplied into the environment's discount.
      n_step_update: Must be 1; multi-step TD targets are not implemented.
      train_step_counter: Optional variable incremented on every train step.
      name: Optional name of the underlying `tf.Module`.
      target_update_tau: Mixing factor of the target update; 1.0 copies the
        online weights outright.
      target_update_period: Run the target update every this many train steps.
      td_errors_loss_fn: Element-wise loss `fn(td_targets, q_values)`;
        defaults to `common.element_wise_squared_loss`.

    Raises:
      ValueError: If `n_step_update` is not 1.
    """

    def __init__(self,
                 time_step_spec,
                 action_spec,
                 q_network,
                 optimizer,
                 epsilon_greedy=0.1,
                 gamma=1.0,
                 n_step_update=1,
                 train_step_counter=None,
                 name=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 td_errors_loss_fn=None):
        tf.Module.__init__(self, name=name)

        # TD targets here are always one-step, and for non-recurrent networks
        # `_experience_to_transitions` squeezes a length-1 time axis, which
        # fails for longer windows. Reject n-step settings up front instead of
        # crashing later or silently ignoring them.
        if n_step_update != 1:
            raise ValueError(
                'MyAgent only supports n_step_update=1; got {}.'.format(
                    n_step_update))

        self._q_network = q_network
        q_network.create_variables()
        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network, None, 'TargetQNetwork')

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._optimizer = optimizer
        if td_errors_loss_fn is None:
            td_errors_loss_fn = common.element_wise_squared_loss
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._update_target = self._get_target_updater(
            tau=target_update_tau, period=target_update_period)

        # Greedy policy for evaluation and epsilon-greedy policy for data
        # collection, both reading the online Q-network.
        policy = q_policy.QPolicy(
            time_step_spec, action_spec, q_network=self._q_network)
        collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
            policy, epsilon=self._epsilon_greedy)
        policy = greedy_policy.GreedyPolicy(policy)

        # Greedy policy over the target network; `_compute_next_q_values` uses
        # it to choose the next action when building TD targets.
        target_policy = q_policy.QPolicy(
            time_step_spec, action_spec, q_network=self._target_q_network)
        self._target_greedy_policy = greedy_policy.GreedyPolicy(target_policy)

        # One transition spans two adjacent time steps.
        train_sequence_length = n_step_update + 1

        super(MyAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy,
            collect_policy,
            train_sequence_length=train_sequence_length,
            train_step_counter=train_step_counter)

    def _get_target_updater(self, tau=1.0, period=1):
        """Build a callable that soft-updates the target network.

        Args:
          tau: Mixing factor passed to `common.soft_variables_update` for
            trainable variables (non-trainable ones always use 1.0).
          period: The update actually runs once every `period` calls.

        Returns:
          A `common.Periodically` wrapper around the update.
        """
        with tf.name_scope('update_targets'):

            def update():
                return common.soft_variables_update(
                    self._q_network.variables,
                    self._target_q_network.variables,
                    tau,
                    tau_non_trainable=1.0)

            return common.Periodically(update, period, 'periodic_update_targets')

    def _initialize(self):
        """Copy the online Q-network weights into the target network (tau=1)."""
        return common.soft_variables_update(
            self._q_network.variables, self._target_q_network.variables, tau=1.0)

    def _experience_to_transitions(self, experience):
        """Split a trajectory batch into `(time_steps, actions, next_time_steps)`.

        For a non-recurrent Q-network the time axis left by
        `trajectory.to_transition` (size 1 for a 2-step window) is squeezed
        away so every tensor is indexed by batch only.
        """
        transitions = trajectory.to_transition(experience)

        if not self._q_network.state_spec:
            transitions = tf.nest.map_structure(
                lambda x: composite.squeeze(x, 1), transitions)

        time_steps, policy_steps, next_time_steps = transitions
        actions = policy_steps.action
        return time_steps, actions, next_time_steps

    # Wrap `agent.train` with `common.function` to run this in graph mode.
    def _train(self, experience, weights):
        """Take one gradient step on the TD loss, then run the target updater.

        Args:
          experience: Batched `Trajectory` of `train_sequence_length` steps.
          weights: Optional per-example loss weights, forwarded to `_loss`.

        Returns:
          The `LossInfo` computed by `_loss`.
        """
        with tf.GradientTape() as tape:
            loss_info = self._loss(experience, training=True, weights=weights)
        tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')

        variables_to_train = self._q_network.trainable_weights
        assert list(variables_to_train), "No variables in the agent's q_network."
        grads = tape.gradient(loss_info.loss, variables_to_train)
        # Materialize as a list: zip() is a one-shot iterator in Python 3.
        grads_and_vars = list(zip(grads, variables_to_train))
        training_lib.apply_gradients(
            self._optimizer, grads_and_vars, global_step=self.train_step_counter)

        self._update_target()

        return loss_info

    def _loss(self, experience, training=False, weights=None):
        """Compute the one-step TD loss for a batch of experience.

        Args:
          experience: Batched `Trajectory` of `train_sequence_length` steps.
          training: Forwarded to the online Q-network call.
          weights: Optional per-example weights multiplied into the TD loss
            before averaging; `None` means uniform weighting.

        Returns:
          `tf_agent.LossInfo(loss, MyLossInfo(td_loss, td_error))`. The extra
          fields hold the masked, unweighted per-example values.
        """
        time_steps, actions, next_time_steps = self._experience_to_transitions(
            experience)

        with tf.name_scope('loss'):
            q_values = self._compute_q_values(
                time_steps, actions, training=training)
            next_q_values = self._compute_next_q_values(next_time_steps)

            td_targets = compute_td_targets(
                next_q_values,
                rewards=next_time_steps.reward,
                discounts=self._gamma * next_time_steps.discount)

            # A transition whose first step is LAST straddles an episode
            # boundary (its next step starts a new episode), so mask it out.
            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)
            td_loss = valid_mask * self._td_errors_loss_fn(td_targets, q_values)

            weighted_td_loss = td_loss
            if weights is not None:
                weighted_td_loss = td_loss * tf.cast(weights, td_loss.dtype)

            # Mean over the batch; masked entries still count in the
            # denominator.
            loss = tf.reduce_mean(input_tensor=weighted_td_loss)

            # Network losses (e.g. regularization) are summed, following the
            # Keras convention, rather than averaged over the number of terms.
            if self._q_network.losses:
                loss = loss + tf.add_n(self._q_network.losses)

            return tf_agent.LossInfo(
                loss, MyLossInfo(td_loss=td_loss, td_error=td_error))

    def _multi_dim_actions(self):
        """Whether the (first flattened) action spec has rank > 0.

        Lets `common.index_with_actions` handle both shape=() and shape=(1,)
        action specs.
        """
        return tf.nest.flatten(self._action_spec)[0].shape.rank > 0

    def _compute_q_values(self, time_steps, actions, training=False):
        """Return the online-network Q-value of each action actually taken."""
        q_values, _ = self._q_network(
            time_steps.observation, time_steps.step_type, training=training)
        return common.index_with_actions(
            q_values,
            tf.cast(actions, dtype=tf.int32),
            multi_dim_actions=self._multi_dim_actions())

    def _compute_next_q_values(self, next_time_steps):
        """Return target-network Q-values at the target-greedy next actions."""
        next_target_q_values, _ = self._target_q_network(
            next_time_steps.observation, next_time_steps.step_type)
        num_samples = (next_target_q_values.shape[0]
                       or tf.shape(next_target_q_values)[0])
        dummy_state = self._target_greedy_policy.get_initial_state(num_samples)
        # Choose greedy actions through the target greedy policy so that any
        # action constraints are respected and the greedy logic lives in one
        # place.
        greedy_actions = self._target_greedy_policy.action(
            next_time_steps, dummy_state).action
        return common.index_with_actions(
            next_target_q_values,
            greedy_actions,
            multi_dim_actions=self._multi_dim_actions())



In [6]:
# Adam optimizer for the Q-network's TD loss.
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Counts gradient steps; the training loop reads it to pace logging and eval.
train_step_counter = tf.Variable(0)

# MyAgent is a hand-written stand-in for `dqn_agent.DqnAgent`, built from the
# same specs, Q-network, optimizer and step counter (squared TD loss built in).
agent = MyAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    train_step_counter=train_step_counter,
)

# Sync the target network with the online network before training starts.
agent.initialize()

In [7]:
def compute_avg_return(environment, policy, num_episodes=10):
    """Run `policy` for `num_episodes` full episodes and return the mean return.

    Args:
      environment: Environment exposing `reset()` and `step(action)`; it is
        reset at the start of every episode. Assumed to have batch size 1.
      policy: Policy whose `action(time_step)` result has an `.action` field.
      num_episodes: Number of episodes to average over; must be positive.

    Returns:
      The average undiscounted episode return, read from the first (assumed
      only) batch entry.

    Raises:
      ValueError: If `num_episodes` is not positive.
    """
    if num_episodes <= 0:
        raise ValueError(
            'num_episodes must be positive, got {}.'.format(num_episodes))

    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        # `is_last()` returns a one-element tensor; using it with `not` relies
        # on the environment batch size being 1.
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

def collect_step(environment, policy, buffer):
    """Advance `environment` one step under `policy` and record the transition.

    The (current step, policy step, next step) triple is packed with
    `trajectory.from_transition` and appended via `buffer.add_batch`.
    """
    current = environment.current_time_step()
    step_out = policy.action(current)
    following = environment.step(step_out.action)
    buffer.add_batch(trajectory.from_transition(current, step_out, following))

def collect_data(env, policy, buffer, steps):
    """Call `collect_step` `steps` times, adding each transition to `buffer`."""
    for _step_index in range(steps):
        collect_step(environment=env, policy=policy, buffer=buffer)

# Uniformly random policy used only to seed the replay buffer.
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

# Pre-fill the buffer with random experience, sized by the
# `initial_collect_steps` hyperparameter.
collect_data(train_env, random_policy, replay_buffer, steps=initial_collect_steps)

# Sample windows of `agent.train_sequence_length` adjacent steps (2 for
# one-step TD: steps t and t+1 form one transition), matching what the agent
# expects in `train`.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=agent.train_sequence_length).prefetch(3)

iterator = iter(dataset)

In [8]:
# (Optional) Compile `agent.train` into a TF graph with `common.function` for speed.
agent.train = common.function(agent.train)

# Reset the train step counter to 0 before the loop.
agent.train_step_counter.assign(0)

# Baseline: evaluate the untrained greedy policy once.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):
    # Gather fresh experience with the exploring collect policy.
    collect_data(train_env, agent.collect_policy, replay_buffer,
                 collect_steps_per_iteration)

    # One gradient step on a batch sampled from the replay buffer.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()
    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))
    if step % eval_interval == 0:
        avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step, avg_return))
        returns.append(avg_return)

step = 200: loss = 11.006783485412598
step = 400: loss = 3.7969253063201904
step = 600: loss = 4.921608924865723
step = 800: loss = 8.25903034210205
step = 1000: loss = 6.7414398193359375
step = 1000: Average Return = 12.199999809265137
step = 1200: loss = 29.57880210876465
step = 1400: loss = 20.229900360107422
step = 1600: loss = 10.648536682128906
step = 1800: loss = 13.900444984436035
step = 2000: loss = 30.028514862060547
step = 2000: Average Return = 21.200000762939453
step = 2200: loss = 11.769752502441406
step = 2400: loss = 9.325913429260254
step = 2600: loss = 20.44812774658203
step = 2800: loss = 9.726699829101562
step = 3000: loss = 47.587310791015625
step = 3000: Average Return = 50.900001525878906
step = 3200: loss = 19.15411376953125
step = 3400: loss = 59.7524299621582
step = 3600: loss = 33.184104919433594
step = 3800: loss = 24.32809829711914
step = 4000: loss = 47.7923469543457
step = 4000: Average Return = 68.80000305175781
step = 4200: loss = 50.71818923950195
step