In [1]:
import numpy as np
import tensorflow as tf
from keras.optimizers import Adam
from keras.models import clone_model
import gymnasium as gym
import tensorflow_probability as tfp






In [2]:
class SACAgent:

    def __init__(self,
                 env,
                 critic_network1,
                 critic_network2,
                 actor_network,
                 critic_learning_rate=1e-3,
                 actor_learning_rate=1e-3,
                 minibatch_size=100,
                 tau=0.005,
                 warm_up=1000,
                 alpha=0.05,
                 max_buffer_length=1_000_000):
        self.env = env
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.shape[0]
        self.action_min, self.action_max = env.action_space.low, env.action_space.high

        # Initialize critic network 1 and target critic network 1.
        self.critic_network1 = clone_model(critic_network1)
        self.critic_network1.set_weights(critic_network1.get_weights())

        self.target_critic_network1 = clone_model(critic_network1)
        self.target_critic_network1.set_weights(critic_network1.get_weights())

        # Initialize critic network 2 and target critic network 2.
        self.critic_network2 = clone_model(critic_network2)
        self.critic_network2.set_weights(critic_network2.get_weights())

        self.target_critic_network2 = clone_model(critic_network2)
        self.target_critic_network2.set_weights(critic_network2.get_weights())

        # Initialize actor network.
        self.actor_network = clone_model(actor_network)
        self.actor_network.set_weights(actor_network.get_weights())

        # Initialize network optimizers.
        self.critic1_optimizer = Adam(learning_rate=critic_learning_rate)
        self.critic2_optimizer = Adam(learning_rate=critic_learning_rate)
        self.actor_optimizer = Adam(learning_rate=actor_learning_rate)

        # Initialize hyperparameters.
        self.tau = tau
        self.minibatch_size = minibatch_size
        self.warm_up = warm_up
        self.alpha = alpha

        # Initialize buffer.
        self.buffer_width = 2 * self.state_size + self.action_size + 2
        self.max_buffer_length = max_buffer_length
        self.replay_buffer = np.zeros((self.max_buffer_length, self.buffer_width), dtype=np.float32)
        self.buffer_write_idx = 0
        self.buffer_fullness = 0

        # Initialize minibatch slicers.
        self.state_slice = slice(0, self.state_size)
        self.state_action_slice = slice(0, self.state_size + self.action_size)
        self.reward_slice = slice(self.state_size + self.action_size, self.state_size + self.action_size + 1)
        self.next_state_slice = slice(self.state_size + self.action_size + 1, 2 * self.state_size + self.action_size + 1)
        self.terminal_slice = slice(2 * self.state_size + self.action_size + 1, 2 * self.state_size + self.action_size + 2)

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action at the given state."""
        # Input check.
        assert state.ndim == 1 and state.shape[0] == 24
        # Do forward pass.
        mean, log_std = (output.numpy()[0] for output in self.actor_network(np.expand_dims(state, axis=0), training=False))
        std = np.exp(log_std)
        epsilon = np.random.normal(size=mean.shape)
        return np.tanh(mean + std * epsilon)

    def save_transition(self, state: np.ndarray, action: np.ndarray, reward: float, new_state: np.ndarray, terminal: bool):
        """Save a transition."""
        # Input check.
        assert state.ndim == 1 and state.shape[0] == self.state_size
        assert action.ndim == 1 and action.shape[0] == self.action_size
        assert new_state.ndim == 1 and new_state.shape[0] == self.state_size

        # Save transition.
        transition = np.concatenate((state, action, [reward], new_state, [1.0 if terminal else 0.0]), dtype=np.float32)
        self.replay_buffer[self.buffer_write_idx] = transition

        # Update write index and fullness.
        self.buffer_write_idx = (self.buffer_write_idx + 1) % self.max_buffer_length
        self.buffer_fullness = min(self.buffer_fullness + 1, self.max_buffer_length)

    def sample_minibatch(self) -> tf.Tensor:
        """Sample a minibatch from the replay buffer."""
        indices = np.random.choice(self.buffer_fullness, size=self.minibatch_size, replace=False)
        minibatch = self.replay_buffer[indices]
        return tf.convert_to_tensor(minibatch, dtype=tf.float32)

    #@tf.function
    def update_critic_networks(self, minibatch: tf.Tensor):
        """Update the critic networks."""
        mb_state_actions = minibatch[:, self.state_action_slice]
        mb_rewards = minibatch[:, self.reward_slice]
        mb_next_states = minibatch[:, self.next_state_slice]
        mb_terminals = minibatch[:, self.terminal_slice]

        mean, log_std = self.actor_network(mb_next_states, training=False)
        std = tf.exp(log_std)
        epsilon = tf.random.normal(shape=mean.shape)
        next_actions_raw = mean + epsilon * std
        next_actions = tf.tanh(next_actions_raw)

        normal = tfp.distributions.Normal(mean, std)
        log_prob_normal = tf.reduce_sum(normal.log_prob(next_actions_raw), axis=-1, keepdims=True)
        tanh_correction = tf.reduce_sum(tf.math.log(1 - tf.tanh(next_actions_raw)**2), axis=-1, keepdims=True)
        log_prob = log_prob_normal - tanh_correction

        next_state_actions = tf.concat((mb_next_states, next_actions), axis=1)
        q1_next = self.target_critic_network1(next_state_actions, training=False)
        q2_next = self.target_critic_network2(next_state_actions, training=False)
        q_min_next = tf.minimum(q1_next, q2_next)
        q_target = mb_rewards + (1 - mb_terminals) * (q_min_next - self.alpha * log_prob)
        print(q_target)

    @tf.function
    def update_actor_network(self, minibatch: tf.Tensor):
        pass

    def learn(self, n_episodes=1000):

        episode_rewards = []

        for n in range(n_episodes):
            # Print episode number.
            print("Episode:", n + 1)

            # Reset environment.
            state, _ = self.env.reset()

            episode_reward = 0 # Monitor reward.

            while True:
                # Select action.
                action = self.select_action(state)

                # Take step.
                new_state, reward, terminal, truncated, _ = self.env.step(action)

                # Store transition.
                self.save_transition(state, action, reward, new_state, terminal)

                # Update episode reward
                episode_reward += reward

                if self.buffer_fullness >= self.minibatch_size and self.buffer_fullness >= self.warm_up:
                    # Sample minibatch.
                    minibatch = self.sample_minibatch()

                    # Update critic networks.
                    self.update_critic_networks(minibatch)


In [3]:
from keras.models import Sequential, Model
from keras.layers import Input, Dense
from keras.initializers import RandomUniform

critic_network1 = Sequential([
    Input(shape=(28,)),
    Dense(256, activation="relu", kernel_initializer="he_uniform"),
    Dense(256, activation="relu", kernel_initializer="he_uniform"),
    Dense(1, activation="linear", kernel_initializer=RandomUniform(-0.003, 0.003))  # No activation.
])

critic_network2 = Sequential([
    Input(shape=(28,)),
    Dense(256, activation="relu", kernel_initializer="he_uniform"),
    Dense(256, activation="relu", kernel_initializer="he_uniform"),
    Dense(1, activation="linear", kernel_initializer=RandomUniform(-0.003, 0.003))  # No activation.
])

log_std_min = -20
log_std_max = 2

actor_input = Input(shape=(24,))
actor_fcl1 = Dense(256, activation="relu", kernel_initializer="he_uniform")(actor_input)
actor_fcl2 = Dense(256, activation="relu", kernel_initializer="he_uniform")(actor_fcl1)
mean_head = Dense(4, activation="linear", kernel_initializer="he_uniform")(actor_fcl2)
log_std_head = Dense(4, activation="linear", kernel_initializer="he_uniform")(actor_fcl2)
actor_network = Model(inputs=actor_input, outputs=[mean_head, log_std_head])

In [4]:
env = gym.make("BipedalWalker-v3", hardcore=False, render_mode=None)

  from pkg_resources import resource_stream, resource_exists


In [5]:
test = SACAgent(env, critic_network1, critic_network2, actor_network)

test.learn()

Episode: 1
log_prob_normal: (100, 1)
tanh_correction: (100, 1)
log_prob: (100, 1)
q_min_next: (100, 1)
tf.Tensor(
[[-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [ 7.3829532e-02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-3.2360002e-02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [ 1.5363693e-01]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1.0000000e+02]
 [-1

KeyboardInterrupt: 