In [1]:
import collections
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.specs import tensor_spec
from tf_agents.specs import distribution_spec
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.trajectories import trajectory
from tf_agents.trajectories import policy_step
from tf_agents.drivers import dynamic_step_driver
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.networks import network
from tf_agents.policies import tf_policy
from tf_agents.policies import random_tf_policy
from tf_agents.agents import tf_agent
from tf_agents.metrics import tf_metrics
from tf_agents.utils import nest_utils
from tf_agents.utils import common

import matplotlib.pyplot as plt

# Turn on TF2 behaviour (eager execution etc.) even though several tf.compat.v1 APIs are used below.
tf.compat.v1.enable_v2_behavior()

In [2]:
# Restore previously saved actor/critic weights before training when True.
load_weights = False

# Gym environment id (alternative kept for reference: "Pendulum-v0").
#env_name = "Pendulum-v0"
env_name = "LunarLanderContinuous-v2"

# Number of collect+train iterations in the main loop.
num_iterations = 100000

# Data collection and replay buffer sizing.
initial_collect_steps = 10000
collect_steps_per_iteration = 1
replay_buffer_max_length = 100000

# Transitions sampled per gradient step.
batch_size = 256

# All three optimizers share one learning rate.
critic_learning_rate = actor_learning_rate = alpha_learning_rate = 3e-4

# Target-critic soft update and discounting.
target_update_tau = 0.005
target_update_period = 1
gamma = 0.99

# Hidden layer widths of the actor and critic MLPs.
actor_fc_layer_params = critic_fc_layer_params = (256, 256)

# Logging / evaluation cadence, measured in environment steps.
log_interval = 5000
num_eval_episodes = 10
eval_interval = 10000

# Episode length cap for the training environment.
max_episode_steps = 1000

In [3]:
# Training episodes are capped at `max_episode_steps`; the evaluation
# environment is loaded with suite_gym's default episode limit.
train_py_env = suite_gym.load(env_name, max_episode_steps=max_episode_steps)
eval_py_env = suite_gym.load(env_name)

# Wrap both Python environments so they exchange TF tensors.
train_env, eval_env = [tf_py_environment.TFPyEnvironment(py_env)
                       for py_env in (train_py_env, eval_py_env)]

# Specs used to build the networks, agent and replay buffer below.
observation_spec, action_spec = train_env.observation_spec(), train_env.action_spec()

In [4]:
def spec_means_and_magnitudes(action_spec):
    """Returns the midpoint and half-range of a bounded action spec.

    Args:
      action_spec: Bounded spec (or nest of them) exposing `minimum` and
        `maximum`.

    Returns:
      `(means, magnitudes)` as float32 tensors, where
      `means = (max + min) / 2` and `magnitudes = (max - min) / 2`.
    """
    def _midpoint(spec):
        return (spec.maximum + spec.minimum) / 2.0

    def _half_range(spec):
        return (spec.maximum - spec.minimum) / 2.0

    midpoints = tf.nest.map_structure(_midpoint, action_spec)
    half_ranges = tf.nest.map_structure(_half_range, action_spec)
    return (tf.cast(midpoints, dtype=tf.float32),
            tf.cast(half_ranges, dtype=tf.float32))

class SimpleActorDistributionNetwork(network.DistributionNetwork):
    """Actor MLP emitting a tanh-squashed Normal distribution over actions.

    The network maps observations through ReLU dense layers to a Normal
    `loc` and `scale`. The Normal is then pushed through
    Chain([Shift(means), Scale(magnitudes), Tanh()]). tfp's Chain applies its
    list right-to-left, so samples are `means + magnitudes * tanh(x)`, which
    keeps them inside the action spec's bounds.
    """
    def __init__(self,
                 input_tensor_spec,
                 output_tensor_spec,
                 fc_layer_params,
                 name="ActorNormalDistributionNetwork"):
        """Builds the layers and the squashing bijector.

        Args:
          input_tensor_spec: Observation spec.
          output_tensor_spec: Bounded action spec; its element count sets the
            width of the mean / std heads.
          fc_layer_params: Iterable of hidden-layer widths.
          name: Network name, also used as a prefix for hidden-layer names.
        """
        
        # The output spec is handed to the parent constructor. Its builder
        # reads self._bijector_chain lazily, so it may be created before that
        # attribute is assigned further down.
        output_spec = self._output_distribution_spec(output_tensor_spec, name) 
        
        super(SimpleActorDistributionNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            output_spec=output_spec,
            name=name)

        num_actions = output_tensor_spec.shape.num_elements()
          
        self._encoding_layers = []
        for num_units in fc_layer_params:
            self._encoding_layers.append(tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
                name='%s/dense' % name))
        
        # means layer for distribution
        init_means_output_factor = 0.1
        std_bias_initializer_value = 0.0
        
        self._means_projection_layer = tf.keras.layers.Dense(
            num_actions,
            activation=None,
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=init_means_output_factor),
            bias_initializer=tf.keras.initializers.Zeros(),
            name='means_projection_layer')

        # standard dev layer for distribution. Its output is treated as log-std
        # in call() (clipped, then exponentiated). Its kernel reuses the
        # means' init scale.
        self._stddev_projection_layer = tf.keras.layers.Dense(
            num_actions,
            activation=None,
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=init_means_output_factor),
            bias_initializer=tf.keras.initializers.Constant(
                value=std_bias_initializer_value),
            name='stddev_projection_layer')
        
        # Scale: map tanh's (-1, 1) range onto [minimum, maximum] of the action spec.
        # NOTE(review): actions exactly at the bounds invert to atanh(+/-1) = inf,
        # so log_prob of externally supplied boundary actions is not finite.
        action_means, action_magnitudes = spec_means_and_magnitudes(output_tensor_spec)
        bijectors = [tfp.bijectors.Shift(action_means),
                     tfp.bijectors.Scale(action_magnitudes),
                     tfp.bijectors.Tanh()]
        self._bijector_chain = tfp.bijectors.Chain(bijectors)
        
        
    def _output_distribution_spec(self, sample_spec, network_name):
        """Builds a DistributionSpec for the squashed Normal.

        Parameter specs (one per Normal parameter reported by
        `Normal.param_static_shapes`) share `sample_spec`'s dtype. The builder
        wraps `Normal(**params)` in a TransformedDistribution using
        `self._bijector_chain`.
        """
        input_param_shapes = tfp.distributions.Normal.param_static_shapes(
            sample_spec.shape)

        input_param_spec = {
            name: tensor_spec.TensorSpec(  
                shape=shape,
                dtype=sample_spec.dtype,
                name=network_name + '_' + name)
            for name, shape in input_param_shapes.items()
        }

        def distribution_builder(*args, **kwargs):            
            distribution = tfp.distributions.Normal(*args, **kwargs)
            return tfp.distributions.TransformedDistribution(distribution=distribution, bijector=self._bijector_chain)

        return distribution_spec.DistributionSpec(distribution_builder, input_param_spec, sample_spec=sample_spec)

    
    def call(self, observations, step_type, network_state, training=False):        
        """Returns `(action_distribution, network_state)`.

        `step_type` is unused, and `network_state` is passed through unchanged.
        """
        encoding = observations
        
        for layer in self._encoding_layers:
            encoding = layer(encoding, training=training)
        
        means = self._means_projection_layer(encoding, training=training)

        # Interpret the head as log-std, clip it to [-20, 2] for numerical
        # stability, then exponentiate to get a positive scale.
        stds = self._stddev_projection_layer(encoding, training=training)
        stds = tf.clip_by_value(stds, -20, 2)
        stds = tf.exp(stds)
        
        return self.output_spec.builder(loc=means, scale=stds), network_state


In [5]:
class SimpleActorPolicy(tf_policy.Base):
    """Policy whose action distribution is produced directly by an actor network.

    Args:
      time_step_spec: Spec of incoming time steps.
      action_spec: Spec of the emitted actions.
      actor_network: Network called as
        `actor_network(observation, step_type, policy_state, training=...)`
        and returning `(distribution, new_state)`.
      training: Flag forwarded to the network as `training`.
    """

    def __init__(self,
        time_step_spec,
        action_spec,
        actor_network,
        training=False):

        # Build the network's variables up front, then remember the network
        # and the training flag before handing off to the base policy.
        actor_network.create_variables()
        self._actor_network, self._training = actor_network, training

        super().__init__(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            policy_state_spec=actor_network.state_spec)

    def _variables(self):
        """All variables of the wrapped actor network."""
        return self._actor_network.variables

    def _distribution(self, time_step, policy_state):
        """Runs the actor on `time_step` and wraps its output in a PolicyStep."""
        network_output = self._actor_network(
            time_step.observation,
            time_step.step_type,
            policy_state,
            training=self._training)
        action_distribution, next_policy_state = network_output
        return policy_step.PolicyStep(action_distribution, next_policy_state)

In [6]:
# Critic
class SimpleCriticNetwork(network.Network):
    """Q-network: an MLP over the concatenated (observation, action) features.

    Args:
      input_tensor_spec: Tuple `(observation_spec, action_spec)`.
      fc_layer_params: Iterable of hidden-layer widths. Each hidden layer is a
        ReLU Dense layer with Glorot-uniform kernel initialisation.
      name: Network name, also used as a prefix for hidden-layer names.
    """
    def __init__(self,
                 input_tensor_spec,
                 fc_layer_params,
                 name='SimpleCriticNetwork'):
        
        super(SimpleCriticNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            name=name)
          
        self._encoding_layers = []
        for num_units in fc_layer_params:
            self._encoding_layers.append(tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
                name='%s/dense' % name))
        
        # Linear head producing one Q-value per input row.
        self._value = tf.keras.layers.Dense(
            1,
            activation=None,
            kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
            name='value')


    def call(self, inputs, step_type=(), network_state=(), training=False):
        """Returns `(q_values, network_state)`.

        Args:
          inputs: `(observations, actions)`. Presumably float tensors with the
            same dtype and the same outer (batch[, time]) dimensions; the
            feature axis is last. TODO confirm when switching environments.
          step_type: Unused.
          network_state: Passed through unchanged.
          training: Forwarded to every layer.

        Returns:
          Q-values with the trailing size-1 axis removed, i.e. shape `[B]` for
          `[B, D]` inputs and `[B, T]` for `[B, T, D]` inputs.
        """
        observations, actions = inputs
        # Concatenate on the feature (last) axis. This equals axis=1 for rank-2
        # inputs and also supports inputs carrying an extra time dimension.
        encoding = tf.concat([observations, actions], axis=-1)
        
        for layer in self._encoding_layers:
            encoding = layer(encoding, training=training)

        value = self._value(encoding, training=training)
        # Drop only the unit output axis rather than flattening every outer
        # dimension, so callers can still reduce over a time axis if present.
        return tf.squeeze(value, axis=-1), network_state

In [7]:
# Per-component losses returned in LossInfo.extra by SacAgent's train step.
SacLossInfo = collections.namedtuple(
    'SacLossInfo', ['critic_loss', 'actor_loss', 'alpha_loss'])

class SacAgent(tf_agent.TFAgent):
    """Soft Actor-Critic agent with twin critics and a learned entropy weight.

    The agent keeps two online critics and two target critics. The targets are
    hard-copied in `_initialize` and then soft-updated with rate
    `target_update_tau` every `target_update_period` train steps. It also
    keeps a shared actor network (wrapped by two SimpleActorPolicy instances)
    and a trainable `log_alpha`; `exp(log_alpha)` weights the entropy terms.

    Args:
      time_step_spec: Spec of environment time steps.
      action_spec: Spec of the continuous actions.
      critic_network: Critic taking `(observation, action)`. It becomes
        critic 1 and is copied to build critic 2 and both target critics.
      actor_network: Distribution network producing the action distribution.
      actor_optimizer: Optimizer exposing `apply_gradients`, used for the actor.
      critic_optimizer: Optimizer used for both online critics.
      alpha_optimizer: Optimizer used for `log_alpha`.
      target_update_tau: Soft-update rate for the target critics.
      target_update_period: Train steps between target updates.
      gamma: Discount factor in the TD target.
      train_step_counter: Optional counter, incremented once per train step.
      name: Accepted for API compatibility; currently unused.
    """
    def __init__(self,
        time_step_spec,
        action_spec,
        critic_network,
        actor_network,
        actor_optimizer,
        critic_optimizer,
        alpha_optimizer,
        target_update_tau=1.0,
        target_update_period=1,
        gamma=1.0,
        train_step_counter=None,
        name=None):

        # Critic 1 is the network passed in. Critic 2 and both targets are
        # copies with their own variables.
        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables()
        
        self._target_critic_network_1 = critic_network.copy(name='TargetCriticNetwork1')
        self._target_critic_network_1.create_variables()

        self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
        self._critic_network_2.create_variables()
        
        self._target_critic_network_2 = critic_network.copy(name='TargetCriticNetwork2')
        self._target_critic_network_2.create_variables()
        
        actor_network.create_variables()
        self._actor_network = actor_network

        # Both policies share the same actor network and differ only in the
        # `training` flag forwarded to it.
        policy = SimpleActorPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=False)

        self._train_policy = SimpleActorPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        # alpha = exp(log_alpha), so alpha starts at 1.0.
        self._log_alpha = tf.compat.v2.Variable(0.0, trainable=True, dtype=tf.float32, name='initial_log_alpha')
        
        # Target entropy heuristic: minus the total number of action dimensions.
        # np.prod replaces the deprecated np.product alias (removed in NumPy 2.0).
        flat_action_spec = tf.nest.flatten(action_spec)
        target_entropy = -np.sum([
          np.prod(single_spec.shape.as_list())
          for single_spec in flat_action_spec
        ])

        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        # Element-wise squared error. critic_loss applies `weights` per sample
        # before averaging over the batch. A pre-reduced (scalar) loss would
        # make per-sample weights collapse into mean(weights) * loss.
        self._td_errors_loss_fn = tf.math.squared_difference
        self._gamma = gamma
        self._target_entropy = target_entropy

        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, 
            period=self._target_update_period)

        # Experience is sampled as pairs of adjacent steps (one transition).
        train_sequence_length = 2

        super(SacAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy=policy,
            collect_policy=policy,
            train_sequence_length=train_sequence_length,
            train_step_counter=train_step_counter)

    def _initialize(self):
        """Hard-copies (tau=1.0) both online critics into their target critics."""
        common.soft_variables_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables,
            tau=1.0)
        common.soft_variables_update(
            self._critic_network_2.variables,
            self._target_critic_network_2.variables,
            tau=1.0)

    def _experience_to_transitions(self, experience):
        """Splits a trajectory batch into `(time_steps, actions, next_time_steps)`."""
        transitions = trajectory.to_transition(experience)
        time_steps, policy_steps, next_time_steps = transitions
        actions = policy_steps.action
        if (self.train_sequence_length is not None and
            self.train_sequence_length == 2):
            # Two-step sequences yield a singleton time axis (axis 1); squeeze
            # it away so downstream networks see plain [batch, ...] tensors.
            time_steps, actions, next_time_steps = tf.nest.map_structure(
                lambda t: tf.squeeze(t, axis=1),
                (time_steps, actions, next_time_steps))
        return time_steps, actions, next_time_steps

    def _train(self, experience, weights):
        """Runs one SAC update and returns a LossInfo.

        The update takes one gradient step each for the critics, the actor
        and log_alpha, in that order. It then increments the train step
        counter and runs the periodic target-critic update.
        """
        # Get Transitions
        time_steps, actions, next_time_steps = self._experience_to_transitions(experience)

        # Train Critic
        trainable_critic_variables = (
            self._critic_network_1.trainable_variables +
            self._critic_network_2.trainable_variables)
        
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(trainable_critic_variables)
            critic_loss = self.critic_loss(
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn=self._td_errors_loss_fn,
                gamma=self._gamma,
                weights=weights,
                training=True)

        tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
        critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
        self._critic_optimizer.apply_gradients(list(zip(critic_grads, trainable_critic_variables)))

        # Train Actor (only actor variables are watched, so the critics used
        # inside actor_loss are not updated here).
        trainable_actor_variables = self._actor_network.trainable_variables
        
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(trainable_actor_variables)
            actor_loss = self.actor_loss(time_steps, weights=weights)

        tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
        actor_grads = tape.gradient(actor_loss, trainable_actor_variables)        
        self._actor_optimizer.apply_gradients(list(zip(actor_grads, trainable_actor_variables)))

        # Update Alpha
        alpha_variable = [self._log_alpha]
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            assert alpha_variable, 'No alpha variable to optimize.'
            tape.watch(alpha_variable)
            alpha_loss = self.alpha_loss(time_steps, weights=weights)
            
        tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
        alpha_grads = tape.gradient(alpha_loss, alpha_variable)
        self._alpha_optimizer.apply_gradients(list(zip(alpha_grads, alpha_variable)))

        self.train_step_counter.assign_add(1)
        self._update_target()
        
        total_loss = critic_loss + actor_loss + alpha_loss

        extra = SacLossInfo(critic_loss=critic_loss,
                            actor_loss=actor_loss,
                            alpha_loss=alpha_loss)

        return tf_agent.LossInfo(loss=total_loss, extra=extra)


    def _get_target_updater(self, tau=1.0, period=1):
        """Returns a callable that soft-updates both target critics every `period` calls."""
        with tf.name_scope('update_target'):

            def update():
                # Trainable variables move by `tau`; non-trainable ones are copied.
                critic_update_1 = common.soft_variables_update(
                    self._critic_network_1.variables,
                    self._target_critic_network_1.variables,
                    tau,
                    tau_non_trainable=1.0)
                critic_update_2 = common.soft_variables_update(
                    self._critic_network_2.variables,
                    self._target_critic_network_2.variables,
                    tau,
                    tau_non_trainable=1.0)
                return tf.group(critic_update_1, critic_update_2)

            return common.Periodically(update, period, 'update_targets')

    def _actions_and_log_probs(self, time_steps):
        """Samples actions from the train policy and returns `(actions, log_pi)`."""
        batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
        policy_state = self._train_policy.get_initial_state(batch_size)
        action_distribution = self._train_policy.distribution(time_steps, policy_state=policy_state).action

        # Sample actions and log_pis from transformed distribution.
        actions = tf.nest.map_structure(lambda d: d.sample(), action_distribution)
        log_pi = common.log_probability(action_distribution, actions, self.action_spec)

        return actions, log_pi

    def critic_loss(self,
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn,
        gamma=1.0,
        reward_scale_factor=1.0,
        weights=None,
        training=False):
        """Sum of both critics' TD losses against the clipped double-Q soft target.

        The target is
        `reward_scale_factor * r + gamma * discount * (min(Q1', Q2') - alpha * log_pi')`,
        evaluated with the target critics at actions freshly sampled for
        `next_time_steps`, and gradients are stopped through it.

        Args:
          td_errors_loss_fn: Element-wise loss `fn(targets, predictions)`.
          weights: Optional weights multiplied into the per-sample loss.

        Returns:
          Scalar loss averaged over the batch.
        """

        with tf.name_scope('critic_loss'):
            tf.nest.assert_same_structure(actions, self.action_spec)
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)
            tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

            next_actions, next_log_pis = self._actions_and_log_probs(next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = (
                tf.minimum(target_q_values1, target_q_values2) -
                tf.exp(self._log_alpha) * next_log_pis)

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input, time_steps.step_type, training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input, time_steps.step_type, training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if weights is not None:
                critic_loss *= weights

            if nest_utils.is_batched_nested_tensors(
                time_steps, self.time_step_spec, num_outer_dims=2):
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

            # Take the mean across the batch.
            critic_loss = tf.reduce_mean(input_tensor=critic_loss)

            return critic_loss

    def actor_loss(self, time_steps, weights=None):
        """Mean of `alpha * log_pi - min(Q1, Q2)` over freshly sampled actions."""
        with tf.name_scope('actor_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps)
            target_input = (time_steps.observation, actions)
            target_q_values1, _ = self._critic_network_1(target_input,
                                                       time_steps.step_type,
                                                       training=False)
            target_q_values2, _ = self._critic_network_2(target_input,
                                                       time_steps.step_type,
                                                       training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)
            actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
            if nest_utils.is_batched_nested_tensors(
                  time_steps, self.time_step_spec, num_outer_dims=2):
                # Sum over the time dimension.
                actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
            if weights is not None:
                actor_loss *= weights
            actor_loss = tf.reduce_mean(input_tensor=actor_loss)

            return actor_loss
    
    def alpha_loss(self, time_steps, weights=None):
        """Mean of `log_alpha * stop_gradient(-log_pi - target_entropy)`.

        Only `log_alpha` receives gradient; the entropy gap is treated as a
        constant.
        """
        with tf.name_scope('alpha_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            unused_actions, log_pi = self._actions_and_log_probs(time_steps)
            entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
            alpha_loss = (self._log_alpha * entropy_diff)

            if nest_utils.is_batched_nested_tensors(time_steps, self.time_step_spec, num_outer_dims=2):
                # Sum over the time dimension.
                alpha_loss = tf.reduce_sum(input_tensor=alpha_loss, axis=1)

            if weights is not None:
                alpha_loss *= weights

            alpha_loss = tf.reduce_mean(input_tensor=alpha_loss)

            return alpha_loss

In [8]:
# Actor: observations -> squashed distribution over actions.
actor_net = SimpleActorDistributionNetwork(
    input_tensor_spec=observation_spec,
    output_tensor_spec=action_spec,
    fc_layer_params=actor_fc_layer_params)

# Critic: (observation, action) -> one Q-value per sample.
critic_net = SimpleCriticNetwork(
    input_tensor_spec=(observation_spec, action_spec),
    fc_layer_params=critic_fc_layer_params)

In [9]:
# Agent
global_step = tf.compat.v1.train.get_or_create_global_step()

# One Adam optimizer per trained component (actor, critics, entropy weight).
actor_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=actor_learning_rate)
critic_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=critic_learning_rate)
alpha_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=alpha_learning_rate)

agent = SacAgent(
    train_env.time_step_spec(),
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=actor_optimizer,
    critic_optimizer=critic_optimizer,
    alpha_optimizer=alpha_optimizer,
    target_update_tau=target_update_tau,
    target_update_period=target_update_period,
    gamma=gamma,
    train_step_counter=global_step)

# Hard-copy the online critics into the target critics.
agent.initialize()

In [10]:
# Load Weights
def load_model_weights(model_dir=None):
    """Restores the actor and both online critics, then re-syncs the targets.

    Args:
      model_dir: Directory holding the `actor/`, `critic1/` and `critic2/`
        weight files. Defaults to `./<env_name>`, the directory the
        training loop saves into.
    """
    if model_dir is None:
        model_dir = "./{}".format(env_name)
    agent._actor_network.load_weights("{}/actor/saved_actor".format(model_dir))
    agent._critic_network_1.load_weights("{}/critic1/saved_critic".format(model_dir))
    agent._critic_network_2.load_weights("{}/critic2/saved_critic".format(model_dir))
    # The target critics were hard-copied from the *untrained* critics when
    # the agent was initialised. Re-run that copy so they match the loaded
    # weights; otherwise TD targets start from random networks.
    agent.initialize()

if load_weights:
    load_model_weights()
    # Report success only after the weights were actually restored.
    print("Loaded Weights")

In [11]:
# Replay buffer holding the collected trajectories used for training.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

# Seed the buffer with `initial_collect_steps` steps from a random policy.
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

dynamic_step_driver.DynamicStepDriver(
    train_env, 
    random_policy,
    observers=[replay_buffer.add_batch],
    num_steps=initial_collect_steps).run()

# Collection driver for the training loop. Each run takes
# `collect_steps_per_iteration` steps with the agent's collect policy and
# updates the episode / environment-step metrics.
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps, replay_buffer.add_batch]

collect_op = dynamic_step_driver.DynamicStepDriver(
    train_env, 
    agent.collect_policy,
    observers=observers,
    num_steps=collect_steps_per_iteration)

# Training dataset of sampled step sequences. The sequence length must match
# what agent.train() expects, so it is read from the agent rather than
# duplicated as a literal.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=agent.train_sequence_length).prefetch(3)

iterator = iter(dataset)

In [12]:
def compute_avg_return(environment, policy, num_episodes=10):
    """Mean undiscounted return of `policy` over `num_episodes` episodes.

    Args:
      environment: Environment with `reset()` / `step(action)` returning time
        steps that expose `is_last()` and `reward`. Assumed to have batch size
        1, since the first element of the result is returned.
      policy: Object whose `action(time_step)` returns a step with `.action`.
      num_episodes: Number of episodes to average over.

    Returns:
      The average return for the (single) batch entry.
    """
    returns_sum = 0.0
    for _episode in range(num_episodes):
        ts = environment.reset()
        ret = 0.0
        while not ts.is_last():
            ts = environment.step(policy.action(ts).action)
            ret += ts.reward
        returns_sum += ret

    mean_return = returns_sum / num_episodes
    return mean_return.numpy()[0]

# Optionally compile agent.train into a TF graph via common.function for speed.
agent.train = common.function(agent.train)

# Restart the train-step counter from zero before the main loop.
agent.train_step_counter.assign(0)

# Baseline: evaluate the untrained policy once before any gradient step.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]
print("avg_return={}; max_return={}".format(avg_return, np.amax(returns)))

for _ in range(num_iterations):
    # Step the environment with the collect policy; the driver's observers
    # add the new data to the replay buffer and update the metrics.
    collect_op.run()

    # One gradient update on a batch sampled from the replay buffer.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    # Logging / evaluation cadence is measured in environment steps.
    step = env_steps.result().numpy()
    episodes = num_episodes.result().numpy()

    if step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, train_loss))

    if step % eval_interval != 0:
        continue

    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: episodes={1}: Average Return = {2}'.format(step, episodes, avg_return))

    # Save the actor and both online critics whenever this evaluation beats
    # every earlier one (including the pre-training baseline).
    is_new_best = avg_return > max(returns)
    if is_new_best:
        print("Save Weights: avg_return={}; max_return={}".format(avg_return, np.amax(returns)))
        agent._actor_network.save_weights("./{}/actor/saved_actor".format(env_name))
        agent._critic_network_1.save_weights("./{}/critic1/saved_critic".format(env_name))
        agent._critic_network_2.save_weights("./{}/critic2/saved_critic".format(env_name))
    returns.append(avg_return)


avg_return=-206.75851440429688; max_return=-206.75851440429688
step = 5000: loss = 170.78155517578125
step = 10000: loss = 124.37114715576172
step = 10000: episodes=20: Average Return = -53.99135208129883
Save Weights: avg_return=-53.99135208129883; max_return=-206.75851440429688
step = 15000: loss = 23.42526626586914
step = 20000: loss = -1.7273131608963013
step = 20000: episodes=37: Average Return = -1.349188208580017
Save Weights: avg_return=-1.349188208580017; max_return=-53.99135208129883
step = 25000: loss = 4.47935152053833
step = 30000: loss = 44.95140838623047
step = 30000: episodes=48: Average Return = 99.681640625
Save Weights: avg_return=99.681640625; max_return=-1.349188208580017
step = 35000: loss = 14.585715293884277
step = 40000: loss = -18.631175994873047
step = 40000: episodes=87: Average Return = 193.220703125
Save Weights: avg_return=193.220703125; max_return=99.681640625
step = 45000: loss = 41.61185836791992
step = 50000: loss = 26.60292625427246
step = 50000: epi

KeyboardInterrupt: 

In [None]:
# Watch the trained policy: run 10 rendered evaluation episodes and print each return.
for _ in range(10):
    time_step = eval_env.reset()
    rewards = 0.0
    while not time_step.is_last():
        time_step = eval_env.step(agent.policy.action(time_step).action)
        rewards += time_step.reward
        eval_py_env.render()
    print(rewards)