# SAC (Soft Actor-Critic) Implementation

In [36]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import gymnasium as gym
import pickle
import time

In [37]:
class SACAgent:
    """Soft Actor-Critic agent with twin critics and a learnable temperature.

    The agent owns a Gymnasium environment, two critic networks with
    soft-updated target copies, a stochastic actor, and a fixed-size circular
    replay buffer stored as tensors on the training device.

    The critics are called with states and actions concatenated along dim 1.
    The actor must expose ``sample(states)`` returning ``(actions, log_probs)``.
    """

    def __init__(self,
                 env_id,
                 env_hardcore,
                 critic_network1,
                 critic_network2,
                 actor_network,
                 critic_learning_rate=3e-4,
                 actor_learning_rate=3e-4,
                 alpha_learning_rate=3e-4,
                 log_alpha_start=0.0,
                 discount_factor=0.99,
                 minibatch_size=256,
                 tau=0.005,
                 random_exploration_steps=10_000,
                 actor_exploration_steps=1000,
                 critic_gradient_clip=1.0,
                 actor_gradient_clip=1.0,
                 alpha_gradient_clip=0.0,
                 updates_per_step=1,
                 max_buffer_length=1_000_000):
        """Build the environment, networks, optimizers and replay buffer.

        Args:
            env_id: Gymnasium environment id passed to ``gym.make``.
            env_hardcore: Value of the ``hardcore`` keyword passed to ``gym.make``.
            critic_network1: First critic; deep-copied, the caller's module is untouched.
            critic_network2: Second critic; deep-copied as well.
            actor_network: Actor; deep-copied as well.
            critic_learning_rate: Adam learning rate for both critics.
            actor_learning_rate: Adam learning rate for the actor.
            alpha_learning_rate: Adam learning rate for ``log_alpha``.
            log_alpha_start: Initial value of ``log_alpha`` (``alpha = exp(log_alpha)``).
            discount_factor: Discount applied to the bootstrapped target.
            minibatch_size: Number of transitions sampled per update.
            tau: Interpolation factor of the soft target update.
            random_exploration_steps: Uniform-random steps stored before learning.
            actor_exploration_steps: Actor-driven steps stored before learning.
            critic_gradient_clip: Max gradient norm for the critics; ``<= 0`` disables clipping.
            actor_gradient_clip: Max gradient norm for the actor; ``<= 0`` disables clipping.
            alpha_gradient_clip: Max gradient norm for ``log_alpha``; ``<= 0`` disables clipping.
            updates_per_step: Gradient updates performed after each environment step.
            max_buffer_length: Capacity of the replay buffer in transitions.
        """
        # CPU or GPU?
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Environment.
        self.env_id = env_id
        self.env_hardcore = env_hardcore
        self.env = self._make_env()
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.shape[0]
        self.action_min = torch.tensor(self.env.action_space.low, device=self.device)
        self.action_max = torch.tensor(self.env.action_space.high, device=self.device)
        # May be None when the environment spec declares no reward threshold.
        self.max_reward = self.env.spec.reward_threshold

        # Critics and their targets start from identical weights.
        self.critic_network1 = copy.deepcopy(critic_network1)
        self.target_critic_network1 = copy.deepcopy(critic_network1)
        self.critic_network2 = copy.deepcopy(critic_network2)
        self.target_critic_network2 = copy.deepcopy(critic_network2)

        # Actor network (this implementation keeps no target actor).
        self.actor_network = copy.deepcopy(actor_network)

        # Move networks to correct device.
        self.actor_network.to(self.device)
        self.critic_network1.to(self.device)
        self.critic_network2.to(self.device)
        self.target_critic_network1.to(self.device)
        self.target_critic_network2.to(self.device)

        # Entropy temperature. The explicit float dtype keeps an integer
        # ``log_alpha_start`` (e.g. 0) from producing a non-differentiable tensor.
        self.target_entropy = -self.action_size
        self.log_alpha_start = log_alpha_start
        self.log_alpha = nn.Parameter(
            torch.tensor(float(log_alpha_start), dtype=torch.float32, device=self.device)
        )

        # Initialize optimizers.
        self.critic_learning_rate = critic_learning_rate
        self.actor_learning_rate = actor_learning_rate
        self.alpha_learning_rate = alpha_learning_rate
        self.critic1_optimizer = optim.Adam(self.critic_network1.parameters(), lr=critic_learning_rate)
        self.critic2_optimizer = optim.Adam(self.critic_network2.parameters(), lr=critic_learning_rate)
        self.actor_optimizer = optim.Adam(self.actor_network.parameters(), lr=actor_learning_rate)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_learning_rate)

        # Initialize hyperparameters.
        self.tau = tau
        self.minibatch_size = minibatch_size
        self.discount_factor = discount_factor
        self.random_exploration_steps = random_exploration_steps
        self.actor_exploration_steps = actor_exploration_steps
        self.critic_gradient_clip = critic_gradient_clip
        self.actor_gradient_clip = actor_gradient_clip
        self.alpha_gradient_clip = alpha_gradient_clip
        self.updates_per_step = updates_per_step

        # Replay buffer: one pre-allocated tensor per field, written circularly.
        self.buffer_width = 2 * self.state_size + self.action_size + 2
        self.max_buffer_length = max_buffer_length
        self.buffer_write_idx = 0
        self.buffer_fullness = 0
        self.buffer_states = torch.zeros((self.max_buffer_length, self.state_size),
                                         dtype=torch.float32,
                                         device=self.device)
        self.buffer_actions = torch.zeros((self.max_buffer_length, self.action_size),
                                          dtype=torch.float32,
                                          device=self.device)
        self.buffer_rewards = torch.zeros((self.max_buffer_length, 1),
                                          dtype=torch.float32,
                                          device=self.device)
        self.buffer_next_states = torch.zeros((self.max_buffer_length, self.state_size),
                                              dtype=torch.float32,
                                              device=self.device)
        self.buffer_terminals = torch.zeros((self.max_buffer_length, 1),
                                            dtype=torch.float32,
                                            device=self.device)

    @property
    def alpha(self):
        """Return the entropy temperature ``exp(log_alpha)`` as a tensor."""
        return self.log_alpha.exp()

    def _make_env(self, render_mode=None):
        """Create a new environment with this agent's id and hardcore flag."""
        return gym.make(id=self.env_id, hardcore=self.env_hardcore, render_mode=render_mode)

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Sample an action from the actor for a single (unbatched) state."""
        self.actor_network.eval()
        with torch.no_grad():
            state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            action_tensor, _ = self.actor_network.sample(state_tensor)
            return action_tensor.squeeze(0).cpu().numpy()

    def save_transition(self, state: np.ndarray, action: np.ndarray, reward: float, new_state: np.ndarray, terminal: bool):
        """Write one transition at the cursor, overwriting the oldest entry when full."""
        idx = self.buffer_write_idx
        self.buffer_states[idx] = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        self.buffer_actions[idx] = torch.as_tensor(action, dtype=torch.float32, device=self.device)
        self.buffer_rewards[idx] = float(reward)
        self.buffer_next_states[idx] = torch.as_tensor(new_state, dtype=torch.float32, device=self.device)
        self.buffer_terminals[idx] = 1.0 if terminal else 0.0

        self.buffer_write_idx = (idx + 1) % self.max_buffer_length
        self.buffer_fullness = min(self.buffer_fullness + 1, self.max_buffer_length)

    def sample_minibatch(self):
        """Sample ``minibatch_size`` stored transitions uniformly, with replacement.

        Returns:
            Tuple ``(states, actions, rewards, next_states, terminals)`` of tensors.

        Raises:
            RuntimeError: If no transition has been stored yet.
        """
        if self.buffer_fullness <= 0:
            raise RuntimeError("Cannot sample a minibatch from an empty replay buffer.")

        indices = torch.randint(0, self.buffer_fullness, (self.minibatch_size,), device=self.device)

        return (
            self.buffer_states[indices],
            self.buffer_actions[indices],
            self.buffer_rewards[indices],
            self.buffer_next_states[indices],
            self.buffer_terminals[indices],
        )

    def _critic_step(self, network, optimizer, state_actions, q_target):
        """Take one clipped MSE gradient step on a single critic."""
        network.train()
        with torch.enable_grad():
            loss = torch.mean((q_target - network(state_actions)) ** 2)

        optimizer.zero_grad()
        loss.backward()
        if self.critic_gradient_clip > 0.0:
            nn.utils.clip_grad_norm_(network.parameters(), self.critic_gradient_clip)
        optimizer.step()

    def update_critic_networks(self, minibatch):
        """Regress both critics onto the shared entropy-regularized target.

        The target uses the smaller of the two target-critic estimates at an
        action freshly sampled from the actor, minus ``alpha * log_prob``, and
        is masked by ``1 - terminal`` so terminal transitions do not bootstrap.
        """
        mb_states, mb_actions, mb_rewards, mb_next_states, mb_terminals = minibatch
        mb_state_actions = torch.cat([mb_states, mb_actions], dim=1)

        self.actor_network.eval()
        self.target_critic_network1.eval()
        self.target_critic_network2.eval()

        with torch.no_grad():
            next_actions, logp = self.actor_network.sample(mb_next_states)
            next_state_actions = torch.cat((mb_next_states, next_actions), dim=1)
            q1_next = self.target_critic_network1(next_state_actions)
            q2_next = self.target_critic_network2(next_state_actions)
            q_min_next = torch.min(q1_next, q2_next)
            q_target = mb_rewards + self.discount_factor * (1 - mb_terminals) * (q_min_next - self.alpha * logp)

        self._critic_step(self.critic_network1, self.critic1_optimizer, mb_state_actions, q_target)
        self._critic_step(self.critic_network2, self.critic2_optimizer, mb_state_actions, q_target)

    def update_actor_network_and_alpha(self, minibatch):
        """Update the actor, then the entropy temperature, from one minibatch."""
        mb_states, *_ = minibatch

        self.actor_network.train()
        self.critic_network1.eval()
        self.critic_network2.eval()

        with torch.enable_grad():
            pred_actions, logp = self.actor_network.sample(mb_states)
            pred_state_actions = torch.cat((mb_states, pred_actions), dim=1)
            q1_pred = self.critic_network1(pred_state_actions)
            q2_pred = self.critic_network2(pred_state_actions)
            q_min_pred = torch.min(q1_pred, q2_pred)
            # alpha is a constant for the actor step; detaching avoids computing
            # a log_alpha gradient that would be zeroed before it is ever used.
            actor_loss = -(q_min_pred - self.alpha.detach() * logp).mean()
            alpha_loss = -(self.log_alpha * (logp + self.target_entropy).detach()).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        if self.actor_gradient_clip > 0.0:
            nn.utils.clip_grad_norm_(self.actor_network.parameters(), self.actor_gradient_clip)
        self.actor_optimizer.step()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        if self.alpha_gradient_clip > 0.0:
            nn.utils.clip_grad_norm_([self.log_alpha], self.alpha_gradient_clip)
        self.alpha_optimizer.step()

    def _soft_update(self, target_network, local_network):
        """Blend ``local_network`` weights into ``target_network`` with factor ``tau``."""
        for w_target, w_local in zip(target_network.parameters(), local_network.parameters()):
            w_target.copy_(self.tau * w_local + (1 - self.tau) * w_target)

    def soft_update_target_critics(self):
        """Soft update the weights of both target critics."""
        with torch.no_grad():
            self._soft_update(self.target_critic_network1, self.critic_network1)
            self._soft_update(self.target_critic_network2, self.critic_network2)

    def get_settings(self, n_episodes, stop_after):
        """Return the run's hyperparameters as a plain dict."""
        return {
            "env_id": self.env_id,
            "env_hardcore": self.env_hardcore,
            "critic_learning_rate": self.critic_learning_rate,
            "actor_learning_rate": self.actor_learning_rate,
            "alpha_learning_rate": self.alpha_learning_rate,
            "log_alpha_start": self.log_alpha_start,
            "discount_factor": self.discount_factor,
            "minibatch_size": self.minibatch_size,
            "tau": self.tau,
            "random_exploration_steps": self.random_exploration_steps,
            "actor_exploration_steps": self.actor_exploration_steps,
            "critic_gradient_clip": self.critic_gradient_clip,
            "actor_gradient_clip": self.actor_gradient_clip,
            "alpha_gradient_clip": self.alpha_gradient_clip,
            "updates_per_step": self.updates_per_step,
            "max_buffer_length": self.max_buffer_length,
            "n_episodes": n_episodes,
            "stop_after": stop_after
        }

    def save_outputs(self, episode_rewards, episode_step_count, episode_run_times, episode_alphas, n_episodes, stop_after):
        """Save network weights, per-episode logs and settings to the working directory."""
        torch.save(self.critic_network1.state_dict(), "critic_network1.pth")
        torch.save(self.critic_network2.state_dict(), "critic_network2.pth")
        torch.save(self.actor_network.state_dict(), "actor_network.pth")

        pickled_outputs = {
            "episode_rewards.pkl": episode_rewards,
            "episode_step_counts.pkl": episode_step_count,
            "episode_run_times.pkl": episode_run_times,
            "episode_alphas.pkl": episode_alphas,
            "settings.pkl": self.get_settings(n_episodes, stop_after),
        }
        for filename, payload in pickled_outputs.items():
            with open(filename, "wb") as fp:
                pickle.dump(payload, fp)

    def show_test_episode(self):
        """Run one rendered episode with the current actor and print its summary."""
        print("\n========TEST RUN========")
        test_env = self._make_env(render_mode="human")
        try:
            s, _ = test_env.reset()
            test_episode_reward = 0
            test_episode_step_count = 0
            test_start_time = time.time()

            while True:
                a = self.select_action(s)
                s_, r, terminal, truncated, _ = test_env.step(a)
                test_episode_reward += r
                test_episode_step_count += 1

                if terminal or truncated:
                    break

                s = s_

            test_episode_run_time = time.time() - test_start_time
        finally:
            # Always release the render window, even if a step raised.
            test_env.close()

        print(f"Reward: {test_episode_reward:.2f} - Step Count: {test_episode_step_count} - Run Time: {test_episode_run_time:.2f}s\n")

    def _random_action(self, state):
        """Draw an action uniformly between the action-space bounds (``state`` is ignored)."""
        spread = self.action_max - self.action_min
        action_tensor = self.action_min + spread * torch.rand(self.action_size, device=self.device)
        return action_tensor.cpu().numpy()

    def _collect_transitions(self, n_steps, choose_action):
        """Store ``n_steps`` transitions from a temporary env using ``choose_action(state)``."""
        temp_env = self._make_env()
        try:
            state, _ = temp_env.reset()
            for _ in range(n_steps):
                action = choose_action(state)
                new_state, reward, terminal, truncated, _ = temp_env.step(action)
                self.save_transition(state, action, reward, new_state, terminal)

                if terminal or truncated:
                    state, _ = temp_env.reset()
                else:
                    state = new_state
        finally:
            temp_env.close()

    def random_exploration(self):
        """Fill the buffer with ``random_exploration_steps`` uniform-random transitions."""
        print("Performing Random Exploration...")
        self._collect_transitions(self.random_exploration_steps, self._random_action)
        print("Random Exploration Complete.")

    def actor_exploration(self):
        """Fill the buffer with ``actor_exploration_steps`` transitions from the actor."""
        print("Performing Actor Exploration...")
        self._collect_transitions(self.actor_exploration_steps, self.select_action)
        print("Actor Exploration Complete.")

    def _should_stop_early(self, episode_rewards, stop_after):
        """Return True once the last ``stop_after`` episodes all reached ``max_reward``.

        Early stopping is disabled when ``stop_after`` is None or non-positive,
        when the environment declares no reward threshold, or while fewer than
        ``stop_after`` episodes have been played.
        """
        if stop_after is None or stop_after <= 0 or self.max_reward is None:
            return False
        if len(episode_rewards) < stop_after:
            return False
        return all(ep_rew >= self.max_reward for ep_rew in episode_rewards[-stop_after:])

    def _run_training_episode(self):
        """Play one episode on ``self.env``, learning after every step.

        Returns:
            Tuple ``(episode_reward, episode_step_count)``.
        """
        state, _ = self.env.reset()
        episode_reward = 0
        episode_step_count = 0

        while True:
            # Select action and take step.
            action = self.select_action(state)
            new_state, reward, terminal, truncated, _ = self.env.step(action)

            # Only true termination is stored; truncation must still bootstrap.
            self.save_transition(state, action, reward, new_state, terminal)

            episode_step_count += 1
            episode_reward += reward

            if self.buffer_fullness >= self.minibatch_size:
                for _ in range(self.updates_per_step):
                    minibatch = self.sample_minibatch()
                    self.update_critic_networks(minibatch)
                    self.update_actor_network_and_alpha(minibatch)
                    self.soft_update_target_critics()

            if terminal or truncated:
                return episode_reward, episode_step_count

            state = new_state

    def learn(self, n_episodes=2000, display_every=50, stop_after=None):
        """Pre-fill the buffer, train for up to ``n_episodes`` episodes, then save outputs.

        Args:
            n_episodes: Maximum number of training episodes.
            display_every: Show a rendered test episode whenever the zero-based
                episode index is a multiple of this value; ``0`` or ``None``
                disables the test episodes.
            stop_after: Stop once this many consecutive most-recent episodes
                each reached the environment's reward threshold; ``None``
                disables early stopping.
        """
        self.random_exploration()
        self.actor_exploration()

        episode_rewards = []
        episode_step_counts = []
        episode_run_times = []
        episode_alphas = []

        for n in range(n_episodes):
            # Print episode number.
            print(f"Running Episode {n + 1}...")

            start_time = time.time()
            episode_reward, episode_step_count = self._run_training_episode()
            episode_run_time = time.time() - start_time
            alpha_value = self.alpha.item()

            # Print and save episode reward.
            print(f"Reward: {episode_reward:.2f} - Step Count: {episode_step_count} - Run Time: {episode_run_time:.2f}s - Alpha: {alpha_value:.5f}")
            episode_rewards.append(episode_reward)
            episode_step_counts.append(episode_step_count)
            episode_run_times.append(episode_run_time)
            episode_alphas.append(alpha_value)

            # Early stopping.
            if self._should_stop_early(episode_rewards, stop_after):
                break

            if display_every and n % display_every == 0:
                self.show_test_episode()

        self.save_outputs(episode_rewards, episode_step_counts, episode_run_times, episode_alphas, n_episodes, stop_after)

In [38]:
class Critic(nn.Module):
    """Q-network mapping a concatenated state-action vector to one scalar value.

    The defaults (24 state features + 4 action features, 256 hidden units)
    reproduce the original fixed architecture, so existing checkpoints load
    unchanged.
    """

    def __init__(self, state_size=24, action_size=4, hidden_size=256):
        """Build the two-hidden-layer MLP.

        Args:
            state_size: Number of state features in the input.
            action_size: Number of action features in the input.
            hidden_size: Width of both hidden layers.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        """Return the Q-value estimate for each row of ``x``."""
        return self.net(x)

In [39]:
class Actor(nn.Module):
    """Gaussian policy whose samples are squashed into (-1, 1) by ``tanh``.

    The defaults (24 state features, 4 action features, 256 hidden units,
    log-std clamped to [-20, 2]) reproduce the original fixed architecture,
    so existing checkpoints load unchanged.
    """

    def __init__(self, state_size=24, action_size=4, hidden_size=256,
                 log_std_min=-20, log_std_max=2):
        """Build the shared trunk and the mean / log-std heads.

        Args:
            state_size: Number of input state features.
            action_size: Number of action dimensions.
            hidden_size: Width of both hidden layers.
            log_std_min: Lower clamp applied to the predicted log-std.
            log_std_max: Upper clamp applied to the predicted log-std.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.mean_linear = nn.Linear(hidden_size, action_size)
        self.log_std_linear = nn.Linear(hidden_size, action_size)

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for each row of ``x``."""
        x = self.net(x)
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        # Clamp so that exp(log_std) stays in a numerically safe range.
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mean, log_std

    def sample(self, x):
        """Draw a reparameterized action and its log-probability.

        Returns:
            Tuple ``(action, log_prob)``: ``action`` is ``tanh`` of the Gaussian
            sample; ``log_prob`` is summed over the last dimension and keeps
            that dimension with size 1.
        """
        mean, log_std = self.forward(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        a_raw = normal.rsample()
        action = torch.tanh(a_raw)

        # Change-of-variables correction for the tanh squash; the 1e-6 keeps
        # the log finite when tanh saturates.
        log_prob = normal.log_prob(a_raw) - torch.log(1 - action ** 2 + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob

# Training - Stage 1 - Complete Normal Mode

In [None]:
# Stage 1: train a freshly initialized agent until it completes normal mode.

# Networks are built in the same order as before (actor first, then critics).
actor, critic1, critic2 = Actor(), Critic(), Critic()

agent = SACAgent(
    env_id="BipedalWalker-v3",
    env_hardcore=False,
    critic_network1=critic1,
    critic_network2=critic2,
    actor_network=actor,
)

# Stop early after 10 consecutive episodes at the environment's reward threshold.
agent.learn(n_episodes=2000, stop_after=10)

In [5]:
# Show the hyperparameters recorded for the stage 1 run.

with open("stage 1/settings.pkl", "rb") as settings_file:
    stage1_settings = pickle.load(settings_file)

print(stage1_settings)

{'env_id': 'BipedalWalker-v3', 'env_hardcore': False, 'critic_learning_rate': 0.0003, 'actor_learning_rate': 0.0003, 'alpha_learning_rate': 0.0003, 'log_alpha_start': 0.0, 'discount_factor': 0.99, 'minibatch_size': 256, 'tau': 0.005, 'random_exploration_steps': 10000, 'actor_exploration_steps': 1000, 'critic_gradient_clip': 1.0, 'actor_gradient_clip': 1.0, 'alpha_gradient_clip': 0.0, 'updates_per_step': 1, 'max_buffer_length': 250000, 'n_episodes': 2000, 'stop_after': 10}


In [7]:
# Render one episode with the actor trained in stage 1.

trained_actor = Actor()
trained_actor.load_state_dict(
    torch.load("stage 1/actor_network.pth", map_location="cpu")
)

# The critics are never queried while watching; untrained ones just satisfy
# the SACAgent constructor.
critic1, critic2 = Critic(), Critic()

agent = SACAgent(
    env_id="BipedalWalker-v3",
    env_hardcore=False,
    critic_network1=critic1,
    critic_network2=critic2,
    actor_network=trained_actor,
)

agent.show_test_episode()


Reward: 304.69 - Step Count: 958 - Run Time: 21.22s



# Training - Stage 2 - Faster in Normal Mode

In [None]:
# Stage 2: fine-tune the stage 1 networks to walk faster in normal mode.

# Checkpoints are mapped to the CPU on load so that files written on a GPU
# machine also load on a CPU-only one; SACAgent then moves its copies of the
# networks to whichever device it selects.
stage1_actor = Actor()
stage1_actor.load_state_dict(torch.load("stage 1/actor_network.pth", map_location="cpu"))

stage1_critic1 = Critic()
stage1_critic1.load_state_dict(torch.load("stage 1/critic_network1.pth", map_location="cpu"))

stage1_critic2 = Critic()
stage1_critic2.load_state_dict(torch.load("stage 1/critic_network2.pth", map_location="cpu"))

# log_alpha_start=-4.605 starts the entropy temperature at exp(-4.605) ~= 0.01.
agent = SACAgent(env_id="BipedalWalker-v3",
                 env_hardcore=False,
                 critic_network1=stage1_critic1,
                 critic_network2=stage1_critic2,
                 actor_network=stage1_actor,
                 log_alpha_start=-4.605)

agent.learn(n_episodes=300, stop_after=None)

In [10]:
# Show the hyperparameters recorded for the stage 2 run.

with open("stage 2/settings.pkl", "rb") as settings_file:
    stage2_settings = pickle.load(settings_file)

print(stage2_settings)

{'env_id': 'BipedalWalker-v3', 'env_hardcore': False, 'critic_learning_rate': 0.0003, 'actor_learning_rate': 0.0003, 'alpha_learning_rate': 0.0003, 'log_alpha_start': -4.605, 'discount_factor': 0.99, 'minibatch_size': 256, 'tau': 0.005, 'random_exploration_steps': 1000, 'actor_exploration_steps': 10000, 'critic_gradient_clip': 1.0, 'actor_gradient_clip': 1.0, 'alpha_gradient_clip': 0.0, 'updates_per_step': 1, 'max_buffer_length': 250000, 'n_episodes': 300, 'stop_after': None}


In [11]:
# Render one episode with the actor trained in stage 2.

trained_actor = Actor()
trained_actor.load_state_dict(
    torch.load("stage 2/actor_network.pth", map_location="cpu")
)

# The critics are never queried while watching; untrained ones just satisfy
# the SACAgent constructor.
critic1, critic2 = Critic(), Critic()

agent = SACAgent(
    env_id="BipedalWalker-v3",
    env_hardcore=False,
    critic_network1=critic1,
    critic_network2=critic2,
    actor_network=trained_actor,
)

agent.show_test_episode()


Reward: 322.44 - Step Count: 736 - Run Time: 15.10s



# Training - Stage 3 - Complete Hardcore Mode

In [None]:
# Stage 3: continue from the stage 2 networks and train on hardcore mode.

# Checkpoints are mapped to the CPU on load so that files written on a GPU
# machine also load on a CPU-only one; SACAgent then moves its copies of the
# networks to whichever device it selects.
stage2_actor = Actor()
stage2_actor.load_state_dict(torch.load("stage 2/actor_network.pth", map_location="cpu"))

stage2_critic1 = Critic()
stage2_critic1.load_state_dict(torch.load("stage 2/critic_network1.pth", map_location="cpu"))

stage2_critic2 = Critic()
stage2_critic2.load_state_dict(torch.load("stage 2/critic_network2.pth", map_location="cpu"))

# log_alpha_start=-4.605 starts the entropy temperature at exp(-4.605) ~= 0.01.
agent = SACAgent(env_id="BipedalWalker-v3",
                 env_hardcore=True,
                 critic_network1=stage2_critic1,
                 critic_network2=stage2_critic2,
                 actor_network=stage2_actor,
                 log_alpha_start=-4.605)

agent.learn(n_episodes=2000, stop_after=3)

In [13]:
# Show the hyperparameters recorded for the stage 3 run.

with open("stage 3/settings.pkl", "rb") as settings_file:
    stage3_settings = pickle.load(settings_file)

print(stage3_settings)

{'env_id': 'BipedalWalker-v3', 'env_hardcore': True, 'critic_learning_rate': 0.0003, 'actor_learning_rate': 0.0003, 'alpha_learning_rate': 0.0003, 'log_alpha_start': -4.605, 'discount_factor': 0.99, 'minibatch_size': 256, 'tau': 0.005, 'random_exploration_steps': 1000, 'actor_exploration_steps': 10000, 'critic_gradient_clip': 1.0, 'actor_gradient_clip': 1.0, 'alpha_gradient_clip': 0.0, 'updates_per_step': 1, 'max_buffer_length': 250000, 'n_episodes': 2000, 'stop_after': 3}


In [14]:
# Render one hardcore-mode episode with the actor trained in stage 3.

trained_actor = Actor()
trained_actor.load_state_dict(
    torch.load("stage 3/actor_network.pth", map_location="cpu")
)

# The critics are never queried while watching; untrained ones just satisfy
# the SACAgent constructor.
critic1, critic2 = Critic(), Critic()

agent = SACAgent(
    env_id="BipedalWalker-v3",
    env_hardcore=True,
    critic_network1=critic1,
    critic_network2=critic2,
    actor_network=trained_actor,
)

agent.show_test_episode()


Reward: 180.33 - Step Count: 1600 - Run Time: 33.05s

