# Lab 14: Soft Actor-Critic (SAC) on HalfCheetah-v4

In this lab, we will move from **TD3 (deterministic policy)** to **Soft Actor-Critic (SAC)**, which is one of the most important **modern off-policy stochastic actor-critic algorithms**.

Unlike TD3, SAC does **not** learn a single deterministic action. Instead, it learns a **probability distribution over actions**, and explicitly encourages **exploration via entropy maximization**.


## Key Idea of SAC

SAC optimizes the following objective:

$$
\max_\pi \; \mathbb{E}\left[ \sum_t r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot | s_t)) \right]
$$

where:
- $ r(s, a)$: task reward  
- $ \mathcal{H}(\pi) $: entropy of the policy  
- $ \alpha $: temperature coefficient (controls exploration strength)  

✅ High entropy ⇒ more exploration  
✅ Low entropy ⇒ more exploitation  

---

### From the Maximum-Entropy Objective to the Actor Loss

The entropy term can be written as:

$$
\mathcal{H}(\pi(\cdot|s)) 
= -\mathbb{E}_{a \sim \pi}[\log \pi(a|s)]
$$
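To make the entropy definition concrete, the sketch below estimates $\mathcal{H}$ for a 1-D Gaussian by averaging $-\log \pi(a|s)$ over samples. It then compares the estimate with the closed form $\tfrac{1}{2}\log(2\pi e \sigma^2)$. The mean and standard deviation are arbitrary illustrative values, not anything produced by this notebook.

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# A 1-D Gaussian "policy" for a single fixed state (illustrative values only)
mu, sigma = 0.3, 0.7
dist = Normal(torch.tensor(mu), torch.tensor(sigma))

# Monte Carlo estimate: H = -E_{a ~ pi}[log pi(a|s)]
samples = dist.sample((200_000,))
h_mc = -dist.log_prob(samples).mean().item()

# Closed form for a Gaussian: H = 0.5 * log(2 * pi * e * sigma^2)
h_exact = 0.5 * math.log(2 * math.pi * math.e * sigma**2)

print(f"Monte Carlo: {h_mc:.4f}  closed form: {h_exact:.4f}  torch: {dist.entropy().item():.4f}")
```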

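Substituting this into the objective, and summarizing the expected future (soft) return from taking action $a$ in state $s$ with the critic's soft Q-function $Q(s,a)$, gives the following per-state policy optimization problem: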

$$
\max_\pi \;
\mathbb{E}_{s \sim D,\, a \sim \pi}
\left[
Q(s,a) - \alpha \log \pi(a|s)
\right]
$$

For gradient-based optimization, this is implemented in its **minimization form**, which gives the **actor loss used in SAC**:

$$
\boxed{
\mathcal{L}_{\text{actor}}
=
\mathbb{E}
\left[
\alpha \log \pi(a|s) - Q(s,a)
\right]
}
$$

This loss explicitly shows the **trade-off between**:
- maximizing long-term reward $Q(s,a)$, and  
- maintaining sufficient exploration through the entropy term $\alpha \log \pi(a|s)$.
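The trade-off can be seen in a hypothetical 1-D toy problem with no state, a fixed critic $Q(a) = -(a - c)^2$, and a Gaussian policy $\mathcal{N}(\mu, \sigma^2)$. Here both expectations have closed forms, $\mathbb{E}[\log\pi] = -\mathcal{H}$ and $\mathbb{E}[Q] = -\left((\mu-c)^2 + \sigma^2\right)$. Minimizing the actor loss therefore gives $\mu = c$ and $\sigma^2 = \alpha/2$, so a larger temperature keeps the policy wider. The sketch below minimizes this closed-form loss with plain gradient descent. This is a simplification: SAC itself uses sampled actions and a learned critic.

```python
import torch
from torch.distributions import Normal

def optimal_policy_for_alpha(alpha: float, c: float = 1.0, steps: int = 2000, lr: float = 0.1):
    """Minimize alpha * E[log pi(a)] - E[Q(a)] for a 1-D Gaussian policy and a toy critic
    Q(a) = -(a - c)^2 (hypothetical; chosen so both expectations have closed forms)."""
    mu = torch.zeros(1, requires_grad=True)
    log_sigma = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([mu, log_sigma], lr=lr)
    for _ in range(steps):
        dist = Normal(mu, log_sigma.exp())
        expected_log_pi = -dist.entropy()                     # E[log pi] = -H
        expected_q = -((mu - c) ** 2 + log_sigma.exp() ** 2)  # E[-(a - c)^2]
        loss = (alpha * expected_log_pi - expected_q).sum()   # actor loss
        opt.zero_grad()
        loss.backward()
        opt.step()
    return mu.item(), log_sigma.exp().item() ** 2

for alpha in (0.1, 0.5, 1.0):
    mu, var = optimal_policy_for_alpha(alpha)
    print(f"alpha={alpha}: mu={mu:.3f}, sigma^2={var:.3f} (analytic optimum alpha/2 = {alpha / 2:.3f})")
```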


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import gymnasium as gym

## 1. Gaussian Policy in Soft Actor-Critic (SAC)

Unlike TD3, which uses a **deterministic policy**  
$$
a = \pi_\theta(s)
$$
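SAC uses a **stochastic policy**. A pre-squashing action $u$ is sampled from a **diagonal Gaussian distribution** and later squashed by `tanh` (see below):

$$
u \sim \mathcal{N}\left(\mu_\theta(s),\, \operatorname{diag}\left(\sigma_\theta(s)^2\right)\right)
$$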

This means the policy does not output a single fixed action, but a **probability distribution over actions**.

For each state $s$, the policy network outputs two vectors:

- **Mean**: $\mu_\theta(s)$
- **Log standard deviation**: $\log \sigma_\theta(s) $

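So, before the `tanh` squashing described below, the policy represents the diagonal Gaussian distribution:

$$
\mathcal{N}\left(\mu_\theta(s),\, \operatorname{diag}\left(\sigma_\theta(s)^2\right)\right)
$$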

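The log standard deviation is clamped to a fixed range (`LOG_STD_MIN = -20`, `LOG_STD_MAX = 2` in the code below) to avoid:
- extremely large variance (upper bound)
- a collapsed, near-zero variance whose log-probabilities blow up (lower bound)
- the resulting numerical instability and exploding gradients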

To enable backpropagation through a random sample, SAC uses the **reparameterization trick**:

$$
u = \mu_\theta(s) + \sigma_\theta(s) \cdot \epsilon,
\quad \epsilon \sim \mathcal{N}(0, I)
$$

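This rewrites the random sample as a **deterministic, differentiable function of $\mu_\theta(s)$ and $\sigma_\theta(s)$ plus independent noise $\epsilon$**. Gradients can therefore flow from $u$ back to the policy parameters $\theta$. In PyTorch, `Normal.rsample()` performs this reparameterized sampling.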
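The sketch below contrasts `Normal.sample()`, which returns values detached from the autograd graph, with `Normal.rsample()`, which computes $u = \mu + \sigma\epsilon$ and lets backpropagation reach $\mu$ and $\log\sigma$. The objective $\mathbb{E}[u^2] = \mu^2 + \sigma^2$ is a toy loss. It was chosen because its exact gradients are known: $2\mu$ with respect to $\mu$ and $2\sigma^2$ with respect to $\log\sigma$.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
log_std = torch.tensor(-0.7, requires_grad=True)
dist = Normal(mu, log_std.exp())

# sample(): values are drawn without tracking gradients -> no path back to mu or log_std
plain = dist.sample((100_000,))
print("sample() requires_grad:", plain.requires_grad)    # False

# rsample(): u = mu + sigma * eps with eps ~ N(0, 1), so u is differentiable in (mu, log_std)
u = dist.rsample((100_000,))
print("rsample() requires_grad:", u.requires_grad)       # True

# Monte Carlo gradient of the toy objective E[u^2] = mu^2 + sigma^2, via backprop through samples
loss = (u ** 2).mean()
loss.backward()
sigma = log_std.detach().exp().item()
print(f"d/dmu:      estimate {mu.grad.item():.3f}, exact {2 * mu.item():.3f}")
print(f"d/dlog_std: estimate {log_std.grad.item():.3f}, exact {2 * sigma ** 2:.3f}")
```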

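Most continuous-control environments require bounded actions. Here we assume symmetric bounds, $a_{\text{env}} \in [-a_{\max}, a_{\max}]$, as in HalfCheetah, where $a_{\max} = 1$.

Since $\tanh$ maps any real number into $(-1, 1)$, the sampled value $u$ is first squashed:

$$
a = \tanh(u)
$$

and then rescaled:

$$
a_{\text{env}} = a \cdot a_{\max}
$$

This guarantees that every action sent to the environment lies within its bounds.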
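Here is a quick check of the squash-and-rescale step. It assumes a hypothetical `max_action = 2.0` so the rescaling is visible (HalfCheetah itself uses 1.0). In float32, `tanh` saturates to exactly $\pm 1$ for large $|u|$, so in practice the bound is reached inclusively.

```python
import torch

torch.manual_seed(0)
max_action = 2.0  # hypothetical bound for illustration; HalfCheetah uses 1.0

# Pre-squash samples can be arbitrarily large ...
u = torch.randn(10_000, 6) * 5.0
# ... tanh maps them into [-1, 1] (exactly +/-1 only where float32 saturates) ...
a = torch.tanh(u)
# ... and rescaling maps them into [-max_action, max_action]
a_env = a * max_action

print("max |u|:", u.abs().max().item(), " max |a_env|:", a_env.abs().max().item())
```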

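Because the action is transformed by `tanh`, its log-probability must be corrected with the **change-of-variables rule**. With independent action dimensions $i = 1, \dots, d$:

$$
\log \pi(a|s)
=
\sum_{i=1}^{d}
\left[
\log \mathcal{N}\left(u_i; \mu_i, \sigma_i^2\right)
-
\log\left(1 - \tanh^2(u_i)\right)
\right]
$$

Rescaling by $a_{\max}$ would subtract a further constant $d \log a_{\max}$. This constant is zero for HalfCheetah ($a_{\max} = 1$) and is omitted in the code below.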

This corrected log-probability is **essential** for:

- The critic target update
- The actor (policy) update
- The temperature parameter $ \alpha $ update

Without this correction, the entropy estimate would be wrong and training would become unstable.
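The `1e-6` added inside the logarithm in `GaussianPolicy.sample` keeps $\log(1 - \tanh^2(u))$ finite when `tanh` saturates. A common alternative, sketched below, uses the identity $\log(1 - \tanh^2(u)) = 2\left(\log 2 - u - \operatorname{softplus}(-2u)\right)$, which needs no epsilon. The helper name `tanh_log_prob` is hypothetical. Its result is checked against PyTorch's `TransformedDistribution` with `TanhTransform` and against the epsilon version.

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

def tanh_log_prob(mean, std, u):
    """log pi(a|s) for a = tanh(u), u ~ N(mean, std^2), summed over action dimensions.
    Uses log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), which needs no epsilon."""
    log_prob_u = Normal(mean, std).log_prob(u)
    log_det = 2.0 * (math.log(2.0) - u - F.softplus(-2.0 * u))
    return (log_prob_u - log_det).sum(dim=-1)

torch.manual_seed(0)
mean, std = torch.zeros(4, 3), torch.full((4, 3), 0.5)
u = Normal(mean, std).sample()
a = torch.tanh(u)

ours = tanh_log_prob(mean, std, u)
# Reference: PyTorch's own tanh-squashed distribution (inverts a back to u internally)
reference = TransformedDistribution(Normal(mean, std), [TanhTransform(cache_size=1)]).log_prob(a).sum(dim=-1)
# The epsilon-based version used in GaussianPolicy.sample
naive = (Normal(mean, std).log_prob(u) - torch.log(1 - a.pow(2) + 1e-6)).sum(dim=-1)

print(ours, reference, naive, sep="\n")
```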

---


During evaluation, SAC does **not use random sampling**.  
Instead, it uses the **mean action**:

$$
a_{\text{eval}} = \tanh(\mu_\theta(s)) \cdot a_{\max}
$$

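This makes the evaluated policy deterministic. The same state always yields the same action, so test returns measure the learned behaviour without exploration noise (the environment itself may still be stochastic).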
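The difference between training-time and evaluation-time actions can be sketched with a hypothetical stand-in for the policy network: a single `nn.Linear` layer producing mean and log-std, not the notebook's `GaussianPolicy`. The mean action is identical on repeated calls, while sampled actions vary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
state_dim, action_dim, max_action = 3, 2, 1.0
# Hypothetical stand-in for the policy network: one linear layer emitting mean and log_std
head = nn.Linear(state_dim, 2 * action_dim)

@torch.no_grad()
def act(state: torch.Tensor, deterministic: bool) -> torch.Tensor:
    mean, log_std = head(state).chunk(2, dim=-1)
    if deterministic:
        # Evaluation: squash the mean, no noise
        return torch.tanh(mean) * max_action
    # Training: reparameterized sample, then squash
    std = log_std.clamp(-20, 2).exp()
    u = mean + std * torch.randn_like(mean)
    return torch.tanh(u) * max_action

state = torch.randn(1, state_dim)
print(act(state, deterministic=True), act(state, deterministic=True))    # identical
print(act(state, deterministic=False), act(state, deterministic=False))  # almost surely different
```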

---

#### What the Gaussian Policy Produces

For each input state $s$, the Gaussian policy provides:

| Quantity | Mathematical meaning | Purpose |
|----------|----------------------|---------|
| $a = \tanh(u) \cdot a_{\max}$ | sampled action | environment interaction |
| $\log \pi(a \mid s)$ | `tanh`-corrected log-probability | critic target, actor loss, $\alpha$ update |
| $\tanh(\mu_\theta(s)) \cdot a_{\max}$ | mean (deterministic) action | evaluation |



In [None]:
class GaussianPolicy(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, max_action: float):
        super().__init__()

        self.max_action = max_action
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)

        self.mean_layer = nn.Linear(256, action_dim)
        self.log_std_layer = nn.Linear(256, action_dim)

        # For clamping log_std
        self.LOG_STD_MIN = -20
        self.LOG_STD_MAX = 2

    def forward(self, state: torch.Tensor):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, self.LOG_STD_MIN, self.LOG_STD_MAX)

        return mean, log_std

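    def sample(self, state: torch.Tensor):
        mean, log_std = self.forward(state)
        std = log_std.exp()

        # Reparameterized sample: u = mean + std * eps, differentiable w.r.t. the network
        normal = Normal(mean, std)
        u = normal.rsample()

        # Squash into [-1, 1], then rescale to the environment's bounds
        tanh_u = torch.tanh(u)
        action = tanh_u * self.max_action

        # Change of variables for a = tanh(u): subtract log(1 - tanh(u)^2) per dimension;
        # epsilon keeps the log finite when tanh saturates at +/-1
        log_prob = normal.log_prob(u)
        epsilon = 1e-6
        log_prob -= torch.log(1 - tanh_u.pow(2) + epsilon)
        # Independent dimensions: the joint log-prob is the sum, shape (batch, 1)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        # Deterministic action used for evaluation
        mean_action = torch.tanh(mean) * self.max_action

        return action, log_prob, mean_action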

In [None]:
class Critic(nn.Module):
    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()

        # Q1 architecture
        self.q1_fc1 = nn.Linear(state_dim + action_dim, 256)
        self.q1_fc2 = nn.Linear(256, 256)
        self.q1_out = nn.Linear(256, 1)

        # Q2 architecture
        self.q2_fc1 = nn.Linear(state_dim + action_dim, 256)
        self.q2_fc2 = nn.Linear(256, 256)
        self.q2_out = nn.Linear(256, 1)

    def forward(self, state: torch.Tensor, action: torch.Tensor):
        xu = torch.cat([state, action], dim=-1)  # (batch, state_dim + action_dim)

        # Q1
        x1 = F.relu(self.q1_fc1(xu))
        x1 = F.relu(self.q1_fc2(x1))
        q1 = self.q1_out(x1)

        # Q2
        x2 = F.relu(self.q2_fc1(xu))
        x2 = F.relu(self.q2_fc2(x2))
        q2 = self.q2_out(x2)

        return q1, q2

In [None]:
class ReplayBuffer:
    def __init__(self, state_dim, action_dim, size):
        self.size = size
        self.ptr = 0
        self.full = False

        self.states = np.zeros((size, state_dim), dtype=np.float32)
        self.actions = np.zeros((size, action_dim), dtype=np.float32)
        self.rewards = np.zeros((size,), dtype=np.float32)
        self.next_states = np.zeros((size, state_dim), dtype=np.float32)
        self.dones = np.zeros((size,), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_states[self.ptr] = next_state
        self.dones[self.ptr] = done

        self.ptr += 1
        if self.ptr >= self.size:
            self.ptr = 0
            self.full = True

    def __len__(self):
        return self.size if self.full else self.ptr

    def sample(self, batch_size, device):
        max_size = self.size if self.full else self.ptr
        idx = np.random.randint(0, max_size, size=batch_size)

        states = torch.as_tensor(self.states[idx], dtype=torch.float32, device=device)
        actions = torch.as_tensor(self.actions[idx], dtype=torch.float32, device=device)
        rewards = torch.as_tensor(self.rewards[idx], dtype=torch.float32, device=device).unsqueeze(-1)
        next_states = torch.as_tensor(self.next_states[idx], dtype=torch.float32, device=device)
        dones = torch.as_tensor(self.dones[idx], dtype=torch.float32, device=device).unsqueeze(-1)

        return states, actions, rewards, next_states, dones


def soft_update(source: nn.Module, target: nn.Module, tau: float):
    with torch.no_grad():
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)
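`soft_update` performs Polyak averaging, $\theta' \leftarrow \tau\theta + (1-\tau)\theta'$. If the online network were frozen, the gap between the two networks would shrink by a factor $(1-\tau)$ per update. After $n$ updates, the remaining gap is $(1-\tau)^n$ of the initial one, about 0.0067 after 1,000 updates with $\tau = 0.005$. The sketch below checks this on two small float64 linear layers. It uses a hypothetical `polyak_update` helper that applies the same rule with in-place ops.

```python
import torch
import torch.nn as nn

def polyak_update(source: nn.Module, target: nn.Module, tau: float):
    # Same rule as soft_update: theta_target <- tau * theta + (1 - tau) * theta_target
    with torch.no_grad():
        for p, tp in zip(source.parameters(), target.parameters()):
            tp.mul_(1.0 - tau).add_(tau * p)

torch.manual_seed(0)
source, target = nn.Linear(4, 1).double(), nn.Linear(4, 1).double()
tau, n_updates = 0.005, 1000
gap_before = (source.weight - target.weight).abs().max().item()

for _ in range(n_updates):  # source is frozen here, so the gap shrinks geometrically
    polyak_update(source, target, tau)

gap_after = (source.weight - target.weight).abs().max().item()
print(gap_after / gap_before, (1 - tau) ** n_updates)
```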


In [None]:
env_name = "HalfCheetah-v4"
seed = 0

# ==========================
#  Hyperparameters
# ==========================
replay_size = int(1e6)
gamma = 0.99
tau = 0.005
lr = 3e-4
batch_size = 256

# Number of steps purely using random actions at the beginning
start_steps = 10_000
# Total environment interaction steps
max_steps = 300_000
# Start learning only after some initial experience
update_after = 1_000
# Number of gradient updates per environment step
update_every = 1
# Interval to run evaluation
eval_interval = 5_000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
env = gym.make(env_name)
np.random.seed(seed)
torch.manual_seed(seed)
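env.reset(seed=seed)
env.action_space.seed(seed)  # also seed env.action_space.sample(), used for the random warm-up actions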

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

print("State dim:", state_dim)
print("Action dim:", action_dim)
print("Max action:", max_action)

In [None]:
policy = GaussianPolicy(state_dim, action_dim, max_action).to(device)
critic = Critic(state_dim, action_dim).to(device)
critic_target = Critic(state_dim, action_dim).to(device)

critic_target.load_state_dict(critic.state_dict())

policy_optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=lr)

log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=lr)

target_entropy = -float(action_dim)

replay_buffer = ReplayBuffer(state_dim, action_dim, replay_size)

state, _ = env.reset(seed=seed)
episode_return = 0.0
episode_length = 0
episode_idx = 0

## Hints for the SAC Update Block 


---


#### Sampling the Next Action from the Current Policy  
(used for the critic target)

Purpose:
- Generate the next action according to the **current stochastic policy**
- Also obtain its **log-probability**, which is required by the entropy term

Mathematically:
$$
a_{t+1} \sim \pi(\cdot | s_{t+1}), 
\quad \log \pi(a_{t+1} | s_{t+1})
$$

This corresponds to the line:
- `next_actions, next_log_pi, _ = ?`


---



#### Evaluating the Target Critic on the Next State–Action Pair

Purpose:
- Evaluate the value of the **next state and next action**
- Use the **target critic networks**, not the online critics

Mathematically:
$$
Q_1'(s_{t+1}, a_{t+1}), \quad Q_2'(s_{t+1}, a_{t+1})
$$

This corresponds to the line:
- `q1_next_target, q2_next_target = ?`


#### Taking the Minimum of the Twin Target Q-Values

Purpose:
- Reduce overestimation bias in the bootstrapped target
- Follow TD3's clipped double-Q trick: use the smaller of the two target estimates

$$
Q_{\text{min}}'(s_{t+1}, a_{t+1})
=
\min(Q_1', Q_2')
$$

This corresponds to the line:
- `q_next_target = ?`
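Why the minimum? Consider a toy model in which each critic's estimate equals the true value plus independent standard-normal noise. Taking the maximum of the noisy estimates is biased upward, similar in spirit to how a policy that maximizes a noisy critic exploits its errors. The minimum of two independent estimates is biased downward instead, by $-1/\sqrt{\pi} \approx -0.564$ for unit-variance noise. SAC accepts this mild pessimism in exchange for avoiding runaway overestimation.

```python
import math
import torch

torch.manual_seed(0)
n = 100_000
true_q = 0.0
# Two independent, unbiased but noisy estimates of the same true value (toy noise model)
q1 = true_q + torch.randn(n)
q2 = true_q + torch.randn(n)
expected = 1 / math.sqrt(math.pi)  # |bias| of max/min of two iid N(0, 1) draws

print("mean of single estimate:", q1.mean().item())                # close to 0 (unbiased)
print("mean of max(Q1, Q2):   ", torch.max(q1, q2).mean().item())  # close to +expected
print("mean of min(Q1, Q2):   ", torch.min(q1, q2).mean().item())  # close to -expected
```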


---



#### Computing the Soft Bellman Target for the Critic

Purpose:
- Build the **entropy-regularized TD target**

$$
y_t
=
r_t
+
(1 - d_t)\,\gamma
\left(
Q_{\text{min}}'(s_{t+1}, a_{t+1})
-
\alpha \log \pi(a_{t+1} | s_{t+1})
\right)
$$

Key components:
- Immediate reward: $ r_t $
- Discount factor: $ \gamma $
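- Done mask: $(1 - d_t)$, where $d_t = 1$ only when $s_{t+1}$ is a true terminal state (a time-limit truncation should still bootstrap)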
- Entropy penalty: $\alpha \log \pi $

This corresponds to the line:
- `target_q = ?`
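Here is a worked example of the soft Bellman target on a hypothetical batch of two transitions, with $\alpha = 0.2$ and $\gamma = 0.99$. The helper `soft_bellman_target` is illustrative, not part of the notebook; it mirrors the formula above term by term. For the first row, $\min(10, 9) = 9$ and the entropy bonus is $-0.2 \cdot (-2) = 0.4$, so $y = 1 + 0.99 \cdot 9.4 = 10.306$. The second row is terminal, so $y$ is just the reward.

```python
import torch

def soft_bellman_target(rewards, dones, q1_next, q2_next, next_log_pi, alpha, gamma):
    """y = r + (1 - d) * gamma * (min(Q1', Q2') - alpha * log pi(a'|s')); all tensors shaped (batch, 1)."""
    q_next = torch.min(q1_next, q2_next)
    soft_value = q_next - alpha * next_log_pi  # entropy bonus: -alpha * log pi > 0 when log pi < 0
    return rewards + (1.0 - dones) * gamma * soft_value

rewards     = torch.tensor([[1.0], [0.5]])
dones       = torch.tensor([[0.0], [1.0]])  # second transition ends in a true terminal state
q1_next     = torch.tensor([[10.0], [4.0]])
q2_next     = torch.tensor([[9.0], [6.0]])
next_log_pi = torch.tensor([[-2.0], [-1.0]])

y = soft_bellman_target(rewards, dones, q1_next, q2_next, next_log_pi, alpha=0.2, gamma=0.99)
print(y)  # first row: 1 + 0.99 * (9 + 0.4) = 10.306; second row: reward only = 0.5
```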


---



#### Computing the Current Critic Predictions

Purpose:
- Evaluate the current critic on **replay-buffer actions**

Mathematically:
$$
Q_1(s_t, a_t), \quad Q_2(s_t, a_t)
$$

These are used to regress toward the TD target.

This corresponds to the line:
- `q1_pred, q2_pred = ?`

#### Critic Loss: Mean Squared Bellman Error

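Purpose:
- Force both critics to match the entropy-regularized TD target $y$, which is treated as a constant (it is computed under `torch.no_grad()`)

$$
\mathcal{L}_{\text{critic}}
=
\mathbb{E}_{(s,a) \sim D}
\left[
\left( Q_1(s,a) - y \right)^2
+
\left( Q_2(s,a) - y \right)^2
\right]
$$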

This corresponds to the line:
- `critic_loss = ?`
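The sketch below computes the critic loss with hypothetical linear critics. It uses a simplified target (no entropy term and no done mask) so that only the gradient flow is on display. Because $y$ is built under `torch.no_grad()`, `backward()` trains both online heads and leaves the target network untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
state_action_dim, batch = 5, 8
# Hypothetical tiny critics: two online heads and one target head (linear for brevity)
q1 = nn.Linear(state_action_dim, 1)
q2 = nn.Linear(state_action_dim, 1)
q_target = nn.Linear(state_action_dim, 1)

sa, sa_next = torch.randn(batch, state_action_dim), torch.randn(batch, state_action_dim)
rewards = torch.randn(batch, 1)

with torch.no_grad():  # y is a fixed regression target (simplified: no entropy term, no done mask)
    y = rewards + 0.99 * q_target(sa_next)

# Mean squared Bellman error for each head, summed
critic_loss = F.mse_loss(q1(sa), y) + F.mse_loss(q2(sa), y)
critic_loss.backward()

print(q1.weight.grad is not None, q2.weight.grad is not None)  # True True: both heads are trained
print(q_target.weight.grad)                                     # None: no gradient reaches the target
```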


---



#### Evaluating the Critic on Newly Sampled Actions (Actor Update)

Purpose:
- Evaluate **how good the current policy’s own actions are**
- These actions are **not from the replay buffer**

Mathematically:
$$
a \sim \pi(\cdot | s)
$$
$$
Q_1(s, a), \quad Q_2(s, a)
$$

This corresponds to the line:
- `q1_new, q2_new = ?`

---

#### Taking the Minimum Q-Value for Policy Optimization

Purpose:
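- Use the more pessimistic of the two critics, as in the critic target, so the policy is not pushed toward actions that only one critic overestimates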

Mathematically:
$$
Q_{\text{min}}(s, a) = \min(Q_1, Q_2)
$$

This corresponds to the line:
- `q_new = ?`

---

#### Actor Loss: Maximum-Entropy Policy Objective

Purpose:
- Optimize the policy for **both high reward and high entropy**

Mathematically:
$$
\mathcal{L}_{\text{actor}}
=
\mathbb{E}
\left[
\alpha \log \pi(a|s) - Q_{\text{min}}(s,a)
\right]
$$

Interpretation:
- The $ -Q $ term encourages **high-value actions**
- The $ \alpha \log \pi $ term encourages **stochastic exploration**

This corresponds to the line:
- `actor_loss = ?`

---

#### Alpha Loss: Automatic Temperature Adjustment

Purpose:
- Automatically adjust the **exploration strength** so that:
$$
\mathcal{H}(\pi) \approx \mathcal{H}_{\text{target}}
$$

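Mathematically:
$$
\mathcal{L}_{\alpha}
=
\mathbb{E}_{a \sim \pi}
\left[
-\log \alpha \cdot
\left(
\log \pi(a|s) + \mathcal{H}_{\text{target}}
\right)
\right]
$$

Here $\log \pi(a|s)$ is treated as a constant (detached), so this loss only updates $\log \alpha$. The notebook uses the common heuristic $\mathcal{H}_{\text{target}} = -\dim(\mathcal{A})$ (`target_entropy = -float(action_dim)`).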

Key idea:
- If the policy is **too deterministic**, α should increase
- If the policy is **too random**, α should decrease

This corresponds to the line:
- `alpha_loss = ?`
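The sign of the temperature update can be checked with one optimizer step on made-up log-probabilities. `alpha_step` is a hypothetical helper that uses the loss above, Adam with the notebook's learning rate, and HalfCheetah's 6-dimensional action space for the target entropy. A batch whose entropy estimate $-\mathbb{E}[\log\pi]$ is below the target should raise $\alpha$, and one above it should lower $\alpha$.

```python
import torch

def alpha_step(log_pi: torch.Tensor, target_entropy: float, init_log_alpha: float = 0.0, lr: float = 3e-4) -> float:
    """One Adam step on L(log_alpha) = -E[log_alpha * (log_pi + H_target)] with log_pi detached; returns new alpha."""
    log_alpha = torch.tensor([init_log_alpha], requires_grad=True)
    opt = torch.optim.Adam([log_alpha], lr=lr)
    loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return log_alpha.exp().item()

action_dim = 6                       # HalfCheetah's action dimension
target_entropy = -float(action_dim)  # same heuristic as the notebook

# Entropy estimate -E[log pi] = -10 < -6: policy too deterministic -> alpha should go up
too_deterministic = torch.full((256, 1), 10.0)
# Entropy estimate -E[log pi] = 0 > -6: policy too random -> alpha should go down
too_random = torch.full((256, 1), 0.0)

print(alpha_step(too_deterministic, target_entropy))  # slightly above 1.0
print(alpha_step(too_random, target_entropy))         # slightly below 1.0
```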

In [None]:
for t in range(1, max_steps + 1):

    if t < start_steps:
        action = env.action_space.sample()
    else:
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            sampled_action, _, _ = policy.sample(state_tensor)
        action = sampled_action.cpu().numpy()[0]

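    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

    episode_return += reward
    episode_length += 1

    # Store `terminated`, not `done`: a time-limit truncation (HalfCheetah's only way
    # to end an episode) is not a true terminal state, so the target must still bootstrap
    replay_buffer.add(state, action, reward, next_state, float(terminated))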

    state = next_state

    if done:
        print(
            f"Step {t} | "
            f"Episode {episode_idx} | "
            f"Return: {episode_return:.1f} | "
            f"Length: {episode_length}"
        )
        state, _ = env.reset()
        episode_return = 0.0
        episode_length = 0
        episode_idx += 1


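    if t >= update_after and len(replay_buffer) >= batch_size:
        for _ in range(update_every):
            states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size, device)
            # Current temperature; detached so the critic/actor losses never update log_alpha
            alpha = log_alpha.exp().detach()

            # ---------- Critic update ----------
            with torch.no_grad():
                # a' ~ pi(.|s') from the current policy, plus its tanh-corrected log-probability
                next_actions, next_log_pi, _ = policy.sample(next_states)
                # Target critics evaluated on (s', a')
                q1_next_target, q2_next_target = critic_target(next_states, next_actions)
                # Clipped double-Q: element-wise minimum of the two target estimates
                q_next_target = torch.min(q1_next_target, q2_next_target)
                # Soft Bellman target; (1 - dones) stops bootstrapping at true terminal states
                target_q = rewards + (1.0 - dones) * gamma * (q_next_target - alpha * next_log_pi)

            # Online critics on replay-buffer actions, regressed onto the fixed target
            q1_pred, q2_pred = critic(states, actions)
            critic_loss = F.mse_loss(q1_pred, target_q) + F.mse_loss(q2_pred, target_q)

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # ---------- Actor update ----------
            # Sample fresh actions from the current policy (reparameterized, so gradients reach it)
            new_actions, log_pi, _ = policy.sample(states)
            q1_new, q2_new = critic(states, new_actions)
            q_new = torch.min(q1_new, q2_new)
            actor_loss = (alpha * log_pi - q_new).mean()

            policy_optimizer.zero_grad()
            actor_loss.backward()
            policy_optimizer.step()

            # ---------- Temperature update ----------
            # log_pi is detached: only log_alpha receives gradients here
            alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()

            alpha_optimizer.zero_grad()
            alpha_loss.backward()
            alpha_optimizer.step()

            # ---------- Target network update (Polyak averaging) ----------
            soft_update(critic, critic_target, tau)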