In [5]:
import numpy as np
from keras.optimizers import Adam
import tensorflow as tf
from keras.models import clone_model
import gymnasium as gym

In [6]:
class TD3Agent:
    """Twin Delayed DDPG (TD3) agent for continuous-action Gymnasium environments.

    The agent keeps its own copies of one actor and two critics (plus a target
    copy of each), stores transitions in a fixed-size NumPy ring buffer, and
    trains with clipped double-Q targets, target-policy smoothing noise and
    delayed actor / target-network updates.

    Every replay-buffer row is laid out as
    ``[state | action | reward | next_state | terminal]``.
    """

    def __init__(self,
                 env,
                 critic_network1,
                 critic_network2,
                 actor_network,
                 critic_learning_rate=1e-3,
                 actor_learning_rate=1e-3,
                 discount_factor=0.99,
                 minibatch_size=100,
                 tau=0.005,
                 exploratory_noise_std=0.1,
                 policy_noise=0.1,
                 noise_clip=0.5,
                 policy_delay=2,
                 warm_up=1000,
                 max_buffer_length=1_000_000):
        """Set up networks, optimizers, hyperparameters and the replay buffer.

        Args:
            env: Environment exposing 1-D ``observation_space`` / ``action_space``
                boxes and the Gymnasium ``reset`` / ``step`` API.
            critic_network1: Keras model fed the concatenated ``(state, action)``
                vector. The passed model is cloned, never trained directly.
            critic_network2: Second critic, same contract as ``critic_network1``.
            actor_network: Keras model mapping a state to an action; cloned as well.
            critic_learning_rate: Adam learning rate shared by both critics.
            actor_learning_rate: Adam learning rate of the actor.
            discount_factor: Discount applied to the bootstrapped next-state value.
            minibatch_size: Number of transitions drawn per update.
            tau: Interpolation factor of the soft target update.
            exploratory_noise_std: Std of the Gaussian noise added in ``select_action``.
            policy_noise: Std of the Gaussian noise added to target actions.
            noise_clip: Absolute bound applied to the target-action noise.
            policy_delay: The actor and the targets are updated on every
                ``policy_delay``-th environment step.
            warm_up: Minimum number of stored transitions before learning starts.
            max_buffer_length: Capacity of the replay buffer (rows).
        """
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.shape[0]
        self.action_min, self.action_max = env.action_space.low, env.action_space.high
        # May be None: either the environment has no spec or the spec defines no
        # reward threshold. Early stopping in learn() is skipped in that case.
        self.max_reward = getattr(env.spec, "reward_threshold", None)

        # Online and target networks. Both start from the weights of the model
        # passed in, so the caller's models stay untouched by training.
        self.critic_network1 = self._clone_with_weights(critic_network1)
        self.target_critic_network1 = self._clone_with_weights(critic_network1)

        self.critic_network2 = self._clone_with_weights(critic_network2)
        self.target_critic_network2 = self._clone_with_weights(critic_network2)

        self.actor_network = self._clone_with_weights(actor_network)
        self.target_actor_network = self._clone_with_weights(actor_network)

        # Initialize network optimizers.
        self.critic1_optimizer = Adam(learning_rate=critic_learning_rate)
        self.critic2_optimizer = Adam(learning_rate=critic_learning_rate)
        self.actor_optimizer = Adam(learning_rate=actor_learning_rate)

        # Initialize hyperparameters.
        self.minibatch_size = minibatch_size
        self.discount_factor = discount_factor
        self.tau = tau
        self.exploratory_noise_std = exploratory_noise_std
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.warm_up = warm_up

        # Initialize buffer: state + action + reward (1) + next state + terminal flag (1).
        self.buffer_width = 2 * self.state_size + self.action_size + 2
        self.max_buffer_length = max_buffer_length
        self.replay_buffer = np.zeros((self.max_buffer_length, self.buffer_width), dtype=np.float32)
        self.buffer_write_idx = 0
        self.buffer_fullness = 0

        # Column slices into a buffer row / minibatch. The reward and terminal
        # slices have width 1, so slicing keeps a trailing axis of length 1.
        reward_col = self.state_size + self.action_size
        next_state_col = reward_col + 1
        terminal_col = next_state_col + self.state_size
        self.state_slice = slice(0, self.state_size)
        self.state_action_slice = slice(0, reward_col)
        self.reward_slice = slice(reward_col, next_state_col)
        self.next_state_slice = slice(next_state_col, terminal_col)
        self.terminal_slice = slice(terminal_col, terminal_col + 1)

    @staticmethod
    def _clone_with_weights(network):
        """Return a structural copy of ``network`` that carries the same weights."""
        clone = clone_model(network)
        clone.set_weights(network.get_weights())
        return clone

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action at the given state.

        Runs the online actor on ``state``, adds zero-mean Gaussian noise with
        std ``exploratory_noise_std`` and clips the result to the action bounds.
        """
        # Input check.
        assert state.ndim == 1 and state.shape[0] == self.state_size
        # Forward pass on a batch of one; take the single resulting row.
        action = self.actor_network(np.expand_dims(state, axis=0), training=False).numpy()[0]
        # Add exploratory noise.
        noise = np.random.normal(0, self.exploratory_noise_std, self.action_size)
        action += noise
        action = np.clip(action, self.action_min, self.action_max)
        return action

    def save_transition(self, state: np.ndarray, action: np.ndarray, reward: float, new_state: np.ndarray, terminal: bool):
        """Write one transition into the ring buffer, overwriting the oldest row when full."""
        # Input check.
        assert state.ndim == 1 and state.shape[0] == self.state_size
        assert action.ndim == 1 and action.shape[0] == self.action_size
        assert new_state.ndim == 1 and new_state.shape[0] == self.state_size

        # Save transition; the terminal flag is stored as 1.0 / 0.0.
        transition = np.concatenate((state, action, [reward], new_state, [1.0 if terminal else 0.0]), dtype=np.float32)
        self.replay_buffer[self.buffer_write_idx] = transition

        # Update write index (wrapping around) and fullness (saturating at capacity).
        self.buffer_write_idx = (self.buffer_write_idx + 1) % self.max_buffer_length
        self.buffer_fullness = min(self.buffer_fullness + 1, self.max_buffer_length)

    def sample_minibatch(self) -> tf.Tensor:
        """Sample ``minibatch_size`` distinct stored rows from the replay buffer.

        Requires ``buffer_fullness >= minibatch_size`` because rows are drawn
        without replacement; ``learn`` checks this before calling.
        """
        indices = np.random.choice(self.buffer_fullness, size=self.minibatch_size, replace=False)
        minibatch = self.replay_buffer[indices]
        return tf.convert_to_tensor(minibatch, dtype=tf.float32)

    @tf.function
    def update_critic_networks(self, minibatch: tf.Tensor):
        """Run one gradient step on both critics against the clipped double-Q target."""
        mb_state_actions = minibatch[:, self.state_action_slice]
        mb_rewards = minibatch[:, self.reward_slice]      # (batch, 1)
        mb_next_states = minibatch[:, self.next_state_slice]
        mb_terminals = minibatch[:, self.terminal_slice]  # (batch, 1)

        # Target-policy smoothing: perturb the target action with clipped noise.
        next_actions = self.target_actor_network(mb_next_states, training=False)
        # tf.shape (dynamic) rather than .shape (static) so this also works when
        # the batch dimension is unknown at trace time.
        noise = tf.random.normal(shape=tf.shape(next_actions), mean=0.0, stddev=self.policy_noise)
        noise = tf.clip_by_value(noise, -self.noise_clip, self.noise_clip)
        next_actions = tf.clip_by_value(next_actions + noise, self.action_min, self.action_max)
        next_state_actions = tf.concat((mb_next_states, next_actions), axis=1)

        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        q1_next = self.target_critic_network1(next_state_actions, training=False)
        q2_next = self.target_critic_network2(next_state_actions, training=False)
        q_min_next = tf.minimum(q1_next, q2_next)
        # (1 - terminal) removes the bootstrap term for terminal transitions.
        # NOTE(review): the loss below assumes each critic outputs shape (batch, 1),
        # matching the (batch, 1) rewards/terminals; any other output shape would
        # broadcast silently - confirm against the supplied critic models.
        q_target = tf.stop_gradient(mb_rewards + self.discount_factor * (1 - mb_terminals) * q_min_next)

        # Train critic 1 network.
        with tf.GradientTape() as tape1:
            q1_expected = self.critic_network1(mb_state_actions, training=True)
            critic1_loss = tf.reduce_mean(tf.square(q1_expected - q_target))

        critic1_grads = tape1.gradient(critic1_loss, self.critic_network1.trainable_variables)
        self.critic1_optimizer.apply_gradients(zip(critic1_grads, self.critic_network1.trainable_variables))

        # Train critic 2 network.
        with tf.GradientTape() as tape2:
            q2_expected = self.critic_network2(mb_state_actions, training=True)
            critic2_loss = tf.reduce_mean(tf.square(q2_expected - q_target))

        critic2_grads = tape2.gradient(critic2_loss, self.critic_network2.trainable_variables)
        self.critic2_optimizer.apply_gradients(zip(critic2_grads, self.critic_network2.trainable_variables))

    @tf.function
    def update_actor_network(self, minibatch: tf.Tensor):
        """Run one gradient step on the actor, maximizing critic 1's value of its actions."""
        mb_states = minibatch[:, self.state_slice]

        with tf.GradientTape() as tape:
            pred_actions = self.actor_network(mb_states, training=True)
            pred_state_actions = tf.concat((mb_states, pred_actions), axis=1)
            q_values = self.critic_network1(pred_state_actions, training=False)
            # Minimizing -Q is gradient ascent on Q.
            actor_loss = -tf.reduce_mean(q_values)

        # Only the actor's variables are differentiated and updated here.
        actor_grads = tape.gradient(actor_loss, self.actor_network.trainable_variables)
        self.actor_optimizer.apply_gradients(zip(actor_grads, self.actor_network.trainable_variables))

    def _blend_into_target(self, local_network, target_network):
        """Set target weights to ``tau * local + (1 - tau) * target``."""
        blended_weights = [
            self.tau * w_local + (1 - self.tau) * w_target
            for w_local, w_target in zip(local_network.get_weights(), target_network.get_weights())
        ]
        target_network.set_weights(blended_weights)

    def soft_update_target_weights(self):
        """Soft update the weights of all three target networks."""
        self._blend_into_target(self.critic_network1, self.target_critic_network1)
        self._blend_into_target(self.critic_network2, self.target_critic_network2)
        self._blend_into_target(self.actor_network, self.target_actor_network)

    def save_network_weights(self):
        """Save the online critics' and actor's weights under ``TD3 Models/``."""
        import os

        # Create the output directory first so a finished training run is not
        # lost to a missing-folder error at save time.
        os.makedirs("TD3 Models", exist_ok=True)
        self.critic_network1.save_weights("TD3 Models/critic_network1.weights.h5")
        self.critic_network2.save_weights("TD3 Models/critic_network2.weights.h5")
        self.actor_network.save_weights("TD3 Models/actor_network.weights.h5")

    def _should_stop_early(self, episode_rewards, stop_after):
        """Return True once the last ``stop_after`` episodes all reached ``max_reward``."""
        if self.max_reward is None:
            # No reward threshold available: never stop early.
            return False
        if len(episode_rewards) < stop_after:
            # Not enough episodes yet to judge a streak of ``stop_after``.
            return False
        return all(ep_rwd >= self.max_reward for ep_rwd in episode_rewards[-stop_after:])

    def learn(self, n_episodes=1000, stop_after=10):
        """Train the agent by interacting with ``self.env``.

        Args:
            n_episodes: Maximum number of episodes to run.
            stop_after: Training stops early once this many consecutive most
                recent episodes each reach the environment's reward threshold
                (skipped when no threshold is available).

        Returns:
            List with the total reward of every episode that was run.

        Side effects:
            Prints progress per episode and saves the network weights via
            ``save_network_weights`` when the loop finishes.
        """
        episode_rewards = []
        step = 0  # Monitor time steps across all episodes.

        for n in range(n_episodes):
            # Print episode number.
            print("Episode:", n + 1)

            # Reset environment.
            state, _ = self.env.reset()

            episode_reward = 0  # Monitor reward.

            while True:
                step += 1

                # Select action.
                action = self.select_action(state)

                # Take step.
                new_state, reward, terminal, truncated, _ = self.env.step(action)

                # Store transition. Only `terminal` is recorded, so a truncated
                # episode still bootstraps from its last next-state value.
                self.save_transition(state, action, reward, new_state, terminal)

                # Update episode reward
                episode_reward += reward

                if self.buffer_fullness >= self.minibatch_size and self.buffer_fullness >= self.warm_up:
                    # Sample minibatch.
                    minibatch = self.sample_minibatch()

                    # Update critic networks.
                    self.update_critic_networks(minibatch)

                    if step % self.policy_delay == 0:
                        # Update actor network.
                        self.update_actor_network(minibatch)

                        # Update target weights.
                        self.soft_update_target_weights()

                if terminal or truncated:
                    break

                state = new_state

            # Print and save episode reward.
            print("Episode reward:", episode_reward)
            episode_rewards.append(episode_reward)

            # Early stopping.
            if self._should_stop_early(episode_rewards, stop_after):
                break

        self.save_network_weights()  # Save weights.
        return episode_rewards

In [7]:
from keras.models import Sequential
from keras.layers import Input, Dense
from keras.initializers import RandomUniform


# Layer sizes shared by all networks below. STATE_SIZE and ACTION_SIZE must
# match the observation / action dimensions of the environment handed to
# TD3Agent (the critics take the concatenated state-action vector as input).
STATE_SIZE = 24
ACTION_SIZE = 4
HIDDEN_UNITS = 256
FINAL_LAYER_INIT_LIMIT = 0.003  # Output-layer weights start in [-limit, limit].


def build_critic(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_units=HIDDEN_UNITS):
    """Build a critic mapping a concatenated (state, action) vector to one Q-value."""
    return Sequential([
        Input(shape=(state_size + action_size,)),
        Dense(hidden_units, activation="relu", kernel_initializer="he_uniform"),
        Dense(hidden_units, activation="relu", kernel_initializer="he_uniform"),
        # Linear output: the Q-value is unbounded.
        Dense(1, activation="linear",
              kernel_initializer=RandomUniform(-FINAL_LAYER_INIT_LIMIT, FINAL_LAYER_INIT_LIMIT)),
    ])


def build_actor(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_units=HIDDEN_UNITS):
    """Build an actor mapping a state vector to an action with components in [-1, 1]."""
    return Sequential([
        Input(shape=(state_size,)),
        Dense(hidden_units, activation="relu", kernel_initializer="he_uniform"),
        Dense(hidden_units, activation="relu", kernel_initializer="he_uniform"),
        # Tanh to map outputs to [-1, 1].
        Dense(action_size, activation="tanh",
              kernel_initializer=RandomUniform(-FINAL_LAYER_INIT_LIMIT, FINAL_LAYER_INIT_LIMIT)),
    ])


# Two separately built critics, so each gets its own random initial weights.
critic_network1 = build_critic()
critic_network2 = build_critic()
actor_network = build_actor()

In [9]:
# Create the training environment: the standard (non-hardcore) BipedalWalker,
# with rendering switched off.
env = gym.make(
    "BipedalWalker-v3",
    hardcore=False,
    render_mode=None,
)

  from pkg_resources import resource_stream, resource_exists


In [None]:
# TODO: Check the tensor shapes passed to tf.reduce_mean in update_actor_network and update_critic_networks (predicted Q-values and targets must have matching shapes; keepdims=True may be needed).

In [10]:
# Build the agent with its default hyperparameters, then train it.
# `ers` receives the list of per-episode rewards returned by learn().
test = TD3Agent(
    env=env,
    critic_network1=critic_network1,
    critic_network2=critic_network2,
    actor_network=actor_network,
)
ers = test.learn()

Episode: 1
Episode reward: -93.891754
Episode: 2
Episode reward: -93.71252
Episode: 3
Episode reward: -118.1867
Episode: 4
Episode reward: -119.80537
Episode: 5


KeyboardInterrupt: 