# Soft Actor-Critic (SAC) Homework

Soft Actor-Critic (SAC) is an off-policy, maximum-entropy deep reinforcement learning algorithm. It follows the principles of (soft) policy iteration and alternates between two main steps:

1. **Policy Evaluation**: Learning the soft Q-values of the current policy. This is done with a SARSA-like temporal-difference update on the critic, in which the next action is sampled from the current policy.
2. **Policy Improvement**: Improving the policy based on the soft Q-values. The policy is updated by backpropagating through the critic so as to maximize the critic's output plus an entropy bonus.
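In symbols, here is a sketch of the quantities involved. It follows [1] and uses the two-critic (clipped double-Q) variant of [3]. $Q_{\theta_1}, Q_{\theta_2}$ are the critics and $Q_{\bar\theta_1}, Q_{\bar\theta_2}$ their slowly updated target copies. $\alpha$ is the temperature, $\bar{\mathcal{H}}$ the target entropy, $\mathcal{D}$ the replay buffer, and $m$ the bootstrap mask (0 at a true terminal state, 1 otherwise). $y$ is the policy-evaluation target and is treated as a constant (no gradient). $J_Q$ is the critic loss, summed over both critics as in this notebook. $J_\pi$ is the policy-improvement objective and $J(\alpha)$ the temperature objective from [1]:

$$
\begin{aligned}
y &= r + \gamma\, m \left( \min_{j=1,2} Q_{\bar\theta_j}(s', a') - \alpha \log \pi_\phi(a' \mid s') \right), \qquad a' \sim \pi_\phi(\cdot \mid s'), \\
J_Q(\theta) &= \sum_{j=1,2} \mathbb{E}_{(s, a, r, s', m) \sim \mathcal{D}} \left[ \left( Q_{\theta_j}(s, a) - y \right)^2 \right], \\
J_\pi(\phi) &= \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi_\phi(\cdot \mid s)} \left[ \alpha \log \pi_\phi(a \mid s) - \min_{j=1,2} Q_{\theta_j}(s, a) \right], \\
J(\alpha) &= \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi_\phi(\cdot \mid s)} \left[ -\alpha \left( \log \pi_\phi(a \mid s) + \bar{\mathcal{H}} \right) \right].
\end{aligned}
$$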

Although it was introduced in 2018, SAC remains an effective algorithm, especially with a high Replay Ratio (RR), i.e., the number of gradient update steps per environment step. However, increasing the RR can destabilize training.
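The following illustrative helper makes the definition concrete by counting the gradient updates performed under a given RR. `total_gradient_steps` is hypothetical and not part of the provided code. It assumes that updates start at the warm-up step `init_steps`, as in the training loop later in this notebook. At RR = 8 each stored transition is therefore reused roughly eight times as often as at RR = 1.

```python
def total_gradient_steps(training_steps: int, init_steps: int, replay_ratio: int) -> int:
    """Count gradient updates for a loop that starts updating at step == init_steps
    and performs `replay_ratio` updates per environment step (hypothetical helper)."""
    # environment steps init_steps, init_steps + 1, ..., training_steps (inclusive)
    update_env_steps = max(0, training_steps - init_steps + 1)
    return replay_ratio * update_env_steps


# Default settings of `train`: 50k environment steps, 1k warm-up steps
low_rr = total_gradient_steps(50_000, 1_000, replay_ratio=1)
high_rr = total_gradient_steps(50_000, 1_000, replay_ratio=8)
print(low_rr, high_rr)  # 49001 392008
```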

## Homework Tasks

In this homework, you will address the value overestimation associated with temporal-difference learning, which is particularly pronounced in high-RR settings. Your tasks are:

1. **Fill in the implementation of the SAC algorithm**
   - Develop the code for the policy evaluation and policy improvement steps of the SAC algorithm.
   - Implement a simple remedy to mitigate training instabilities.

2. **Train SAC Agent on the Ant Environment**
   - Use the Ant environment to train your SAC agent and observe its performance.

3. **Investigate Training Instabilities in High RR**
   - Analyze the training instabilities that arise when using a high Replay Ratio.

## References

1. Haarnoja, T., et al. (2018). "Soft Actor-Critic Algorithms and Applications." [arXiv:1812.05905](https://arxiv.org/abs/1812.05905)
2. [OpenReview: Replay Ratio in Off-Policy Algorithms](https://openreview.net/pdf?id=OpC-9aBBVJe)
3. Fujimoto, S., et al. (2018). "Addressing Function Approximation Error in Actor-Critic Methods." [arXiv:1802.09477](https://arxiv.org/abs/1802.09477)
4. [Training Instabilities in High Replay Ratio](https://arxiv.org/pdf/2403.05996)


Ensure your implementation is clear and well documented, and analyze the results thoroughly. Use a GPU runtime to speed up computation.

## Provided code

In [None]:
!pip install gymnasium[mujoco]

In [None]:
import torch
import torch.nn as nn
import numpy as np
import random
import os
import copy
from typing import OrderedDict
import gymnasium as gym
from google.colab import files

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class ReplayBuffer:
    def __init__(self, env, capacity: int):
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        self.states = np.empty((capacity, state_dim), dtype=np.float32)
        self.actions = np.empty((capacity, action_dim), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.masks = np.empty((capacity, 1), dtype=np.float32)
        self.next_states = np.empty((capacity, state_dim), dtype=np.float32)
        self.size = 0
        self.insert_index = 0
        self.capacity = capacity

    def add(self, state: np.ndarray, action: np.ndarray, reward: float, next_state: np.ndarray, mask: float):
        self.states[self.insert_index] = state
        self.actions[self.insert_index] = action
        self.rewards[self.insert_index] = reward
        self.masks[self.insert_index] = mask
        self.next_states[self.insert_index] = next_state
        self.insert_index = (self.insert_index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def to_torch(self, array):
        return torch.from_numpy(array).float()

    def sample(self, batch_size: int, num_batches: int):
        indxs = np.random.randint(self.size, size=(num_batches, batch_size))
        states = self.to_torch(self.states[indxs])
        actions = self.to_torch(self.actions[indxs])
        rewards = self.to_torch(self.rewards[indxs])
        next_states = self.to_torch(self.next_states[indxs])
        masks = self.to_torch(self.masks[indxs])
        return states, actions, rewards, next_states, masks
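
`sample` relies on NumPy fancy indexing. Indexing a `(capacity, dim)` array with an integer array of shape `(num_batches, batch_size)` returns an array of shape `(num_batches, batch_size, dim)`. Indices are drawn uniformly, with replacement, from the filled part of the buffer. `SAC.update` later loops over the first axis and performs one gradient step per mini-batch. The standalone illustration below uses toy sizes chosen arbitrarily:

```python
import numpy as np
import torch

capacity, state_dim = 100, 5
states = np.random.randn(capacity, state_dim).astype(np.float32)
size = 60  # only the first `size` rows hold valid data

num_batches, batch_size = 8, 256  # num_batches plays the role of the replay ratio
# One integer index per (mini-batch, sample); indices are drawn from [0, size)
indxs = np.random.randint(size, size=(num_batches, batch_size))
batch = torch.from_numpy(states[indxs]).float()
print(batch.shape)  # torch.Size([8, 256, 5])

# SAC.update iterates over the leading axis: one gradient step per mini-batch
for i in range(batch.shape[0]):
    minibatch = batch[i]  # shape (batch_size, state_dim)
```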

In [None]:
def set_seed(seed: int):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

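def get_masks(terminates, truncates):
    # Bootstrap (mask = 1.0) unless the episode truly terminated; a time-limit
    # truncation is not a terminal state, so it must not cut the bootstrap.
    return 0.0 if terminates else 1.0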

In [None]:
def weight_init(model):
    if isinstance(model, nn.Linear):
        nn.init.orthogonal_(model.weight.data)
        model.bias.data.fill_(0.0)

class Actor(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super(Actor, self).__init__()
        self.log_std_min = -10.0
        self.log_std_max = 2.0
        self.activation = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            self.activation,
            nn.Linear(hidden_dim, hidden_dim),
            self.activation,
            nn.Linear(hidden_dim, 2 * action_dim))
        self.apply(weight_init)

    def forward(self, state):
        mu, log_std = self.layers(state).chunk(2, dim=-1)
        # cap log_std between log_std_min and log_std_max - good for stability
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (torch.tanh(log_std) + 1)
        return mu, log_std.exp()
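
The `forward` pass above bounds the log standard deviation with a tanh followed by an affine map. The policy's standard deviation therefore always lies in `[exp(-10), exp(2)]`, about `4.5e-5` to `7.4`. A raw output of 0 maps to the midpoint `-4`. The quick standalone check below reuses the same bounds; `rescale_log_std` is an illustrative name.

```python
import torch

log_std_min, log_std_max = -10.0, 2.0

def rescale_log_std(raw: torch.Tensor) -> torch.Tensor:
    # tanh maps raw outputs to (-1, 1); the affine map sends (-1, 1) to (log_std_min, log_std_max)
    return log_std_min + 0.5 * (log_std_max - log_std_min) * (torch.tanh(raw) + 1)

raw = torch.tensor([-100.0, 0.0, 100.0])
log_std = rescale_log_std(raw)
print(log_std)  # values -10, -4, 2 (saturated low, midpoint, saturated high)
```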

# Task 1 - Fill in the implementation of the SAC Algorithm

## Task 1a - Implement a simple remedy to mitigate training instabilities

Implement two versions of the critic network: one with layer normalization and one without.

Common elements of the architecture:
* Input: state and action (concatenated)
* Output: a scalar Q-value for the state-action pair
* Depth: 2 hidden layers of size `hidden_dim`
* Activation: `self.activation`

Layer normalization should be applied after each hidden layer (but not after the output layer).
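Below is a minimal sketch of one possible layout. It assumes that LayerNorm sits between each hidden `nn.Linear` and its activation; this is a common choice, but the placement is an assumption here. `build_q_mlp` is a hypothetical helper, not part of the provided code. It also creates fresh `nn.ReLU` modules rather than reusing `self.activation`:

```python
import torch
import torch.nn as nn


def build_q_mlp(state_dim: int, action_dim: int, hidden_dim: int = 256,
                use_layernorm: bool = False) -> nn.Sequential:
    """Hypothetical helper: 2-hidden-layer MLP mapping concat(state, action) -> scalar Q.

    Assumption: LayerNorm is placed between each hidden Linear layer and its ReLU.
    """
    def hidden_block(in_dim: int) -> list:
        block = [nn.Linear(in_dim, hidden_dim)]
        if use_layernorm:
            block.append(nn.LayerNorm(hidden_dim))
        block.append(nn.ReLU())
        return block

    return nn.Sequential(
        *hidden_block(state_dim + action_dim),
        *hidden_block(hidden_dim),
        nn.Linear(hidden_dim, 1),  # output layer: no normalization
    )


q_net = build_q_mlp(state_dim=5, action_dim=2, use_layernorm=True)
q = q_net(torch.randn(32, 5 + 2))
print(q.shape)  # torch.Size([32, 1])
```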


In [None]:
class Critic(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256, use_layernorm: bool = False):
        super(Critic, self).__init__()
        self.activation = nn.ReLU()
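        if use_layernorm:
            ####### TODO #######
            # self.layers = ...
            raise NotImplementedError("Task 1a: critic with layer normalization")
            ####################
        else:
            ####### TODO #######
            # self.layers = ...
            raise NotImplementedError("Task 1a: critic without layer normalization")
            ####################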

        self.apply(weight_init)

    def forward(self, state, action):
        state_action = torch.concat((state, action), dim=1)
        q_value = self.layers(state_action)
        return q_value

class DoubleCritic(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256, use_layernorm: bool = False):
        super(DoubleCritic, self).__init__()
        self.critic1 = Critic(state_dim=state_dim, action_dim=action_dim, hidden_dim=hidden_dim, use_layernorm=use_layernorm)
        self.critic2 = Critic(state_dim=state_dim, action_dim=action_dim, hidden_dim=hidden_dim, use_layernorm=use_layernorm)

    def forward(self, state, action):
        q1 = self.critic1(state, action)
        q2 = self.critic2(state, action)
        return q1, q2

## Task 1b - Implement Update Steps of the SAC Algorithm
Please fill in the missing parts of the code in the following key sections of the SAC algorithm:
1. `target_Q` computation in method `update_critic`: compute the soft Q-target (backup) according to equations 3, 5, and 6 from [1]. Remember that you are using a double Q-learning setup and the maximum-entropy setting.
2. `actor_loss` computation in method `update_actor`: compute the actor loss according to equation 7 from [1]. Remember that you are using a double Q-learning setup.
3. `alpha_loss` computation in method `update_actor`: compute the temperature (alpha) loss according to equation 18 from [1]. The `entropy` value logged at the end of `update_actor` must also be defined here.

**References**
 1. Haarnoja, T., et al. (2018). "Soft Actor-Critic Algorithms and Applications." [arXiv:1812.05905](https://arxiv.org/abs/1812.05905)
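
The sketch below writes the three quantities as standalone functions on tensors. It assumes the shapes used in this notebook, `(batch_size, 1)` for rewards, masks, Q-values and log-probabilities, and the `log_alpha` parameterization of `SAC.reset`. The function names are illustrative. Inside `update_critic`/`update_actor` they would be fed the outputs of `self.get_action`, `self.critic`, `self.target_critic`, and `self.alpha`.

```python
import torch


def soft_q_target(rewards, masks, next_q1, next_q2, next_log_probs, alpha, discount=0.99):
    """Soft TD target y = r + gamma * m * (min(Q1', Q2') - alpha * log pi(a'|s')).

    next_q1/next_q2 come from the *target* critics at (s', a') with a' ~ pi(.|s');
    call under torch.no_grad(). All tensors have shape (batch_size, 1).
    """
    soft_next_value = torch.min(next_q1, next_q2) - alpha * next_log_probs
    return rewards + discount * masks * soft_next_value


def actor_loss_fn(q1, q2, log_probs, alpha):
    """J_pi = E[alpha * log pi(a|s) - min(Q1, Q2)(s, a)] with reparameterized a ~ pi(.|s).

    alpha (a tensor) is detached so that the actor step does not update the temperature.
    """
    return (alpha.detach() * log_probs - torch.min(q1, q2)).mean()


def alpha_loss_fn(log_alpha, log_probs, target_entropy):
    """J(alpha) = E[-alpha * (log pi(a|s) + target_entropy)], with alpha = exp(log_alpha).

    log_probs is detached: only log_alpha should receive gradients here.
    """
    return (-log_alpha.exp() * (log_probs.detach() + target_entropy)).mean()
```

Detaching `log_probs` in the alpha loss matters in this notebook. `update_actor` calls `actor_loss.backward()` first, which frees the actor's graph, so backpropagating through `log_probs` a second time would raise an error. When the policy entropy exceeds the target, the gradient on `log_alpha` is positive and a descent step lowers the temperature. The logged `entropy` can be estimated as `-log_probs.mean()`, a Monte Carlo estimate of the squashed policy's entropy.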

In [None]:
def gaussian_logprob(noise, log_std):
    residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
    return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)

def squash(pi, log_prob):
    if pi is not None:
        pi = torch.tanh(pi)
    if log_prob is not None:
        log_prob -= torch.log(nn.functional.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
    return pi, log_prob

class SAC(nn.Module):
    def __init__(self, env, use_layernorm: bool = False):
        super(SAC, self).__init__()
        self.device = DEVICE
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]
        self.hidden_dim = 256
        self.learning_rate = 3e-3
        self.discount = 0.99
        self.target_entropy = -np.prod(self.action_dim) / 2
        self.tau = 0.005
        self.use_layernorm = use_layernorm
        self.reset()
        self.logger = {'q_value': [], 'temperature': [], 'entropy': [], 'critic_loss': [], 'actor_loss': [], 'returns': []}

    def reset(self):
        self.actor = Actor(self.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
        self.critic = DoubleCritic(self.state_dim, self.action_dim, self.hidden_dim, self.use_layernorm).to(self.device)
        self.target_critic = copy.deepcopy(self.critic).to(self.device)
        self.optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
        self.log_alpha = torch.tensor(np.log(1.0)).to(self.device)
        self.log_alpha.requires_grad = True
        self.optimizer_log_alpha = torch.optim.Adam([self.log_alpha], lr=self.learning_rate, )

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def get_action(self, state, return_logprob=False, temperature=1.0):
        mu, std = self.actor(state)
        noise = torch.randn_like(mu)
        std = std * temperature
        action = mu + noise * std
        if return_logprob:
            log_prob = gaussian_logprob(noise, std.log())
            action, log_prob = squash(action, log_prob)
            return action, log_prob
        else:
            action, _ = squash(action, None)
            return action

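    def update(self, step, states, actions, rewards, next_states, masks):
        # inputs have shape (num_batches, batch_size, dim): one gradient step per mini-batch
        num_updates = states.shape[0]
        for i in range(num_updates):
            # log only on the last update of this call, so every replay ratio logs once per episode
            log_step = step if i == num_updates - 1 else 0
            self.update_critic(log_step, states[i], actions[i], rewards[i], next_states[i], masks[i])
            self.update_actor(log_step, states[i])
            self.update_target_critic(self.tau)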

    def update_target_critic(self, tau):
        for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def update_critic(self, step, states, actions, rewards, next_states, masks):
        with torch.no_grad():
            ####### TODO #######
            # target_Q = ...
            raise NotImplementedError("Task 1b: compute target_Q")
            ####################
        current_Q1, current_Q2 = self.critic(states, actions)
        critic_loss = nn.functional.mse_loss(current_Q1, target_Q) + nn.functional.mse_loss(current_Q2, target_Q)
        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()
        if step == 1:
            self.logger['q_value'].append(current_Q1.mean().detach().item())
            self.logger['critic_loss'].append(critic_loss.mean().detach().item())

    def update_actor(self, step, states):
        actions, log_probs = self.get_action(states, return_logprob=True, temperature=1.0)
        ####### TODO #######

        # actor_loss = ...
        ####################
        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()
        self.optimizer_log_alpha.zero_grad()
        ####### TODO #######

        # alpha_loss = ...
        ####################
        alpha_loss.backward()
        self.optimizer_log_alpha.step()
        if step == 1:
            self.logger['entropy'].append(entropy.detach().item())
            self.logger['temperature'].append(self.alpha.detach().item())
            self.logger['actor_loss'].append(actor_loss.detach().item())



# Task 2 - Train SAC Agent on the Ant Environment
Train the following models:

1. use_layernorm = False, replay_ratio = 1
2. use_layernorm = False, replay_ratio = 8
3. use_layernorm = True, replay_ratio = 8

In [None]:
def train(init_steps: int = 1000,
          training_steps: int = 50000,
          replay_ratio: int = 8,
          use_layernorm: bool = False):

    set_seed(0)
    env = gym.make('Ant-v4')

    agent = SAC(env, use_layernorm=use_layernorm).to(DEVICE)
    agent = torch.compile(agent)
    buffer = ReplayBuffer(env, training_steps)

    state, _ = env.reset(seed=np.random.randint(0,1e7))
    state = torch.from_numpy(state).float().to(DEVICE)

    returns = 0.0
    episode_step = 0

    for step in range(1, training_steps + 1):
        if step >= init_steps:
            action = agent.get_action(state).detach().cpu().numpy()
        else:
            action = env.action_space.sample()

        next_state, reward, terminal, truncate, _ = env.step(action)
        next_state = torch.from_numpy(next_state).float().to(DEVICE)

        mask = get_masks(terminal, truncate)

        returns += reward
        episode_step += 1

        buffer.add(state.cpu().numpy(), action, reward, next_state.cpu().numpy(), mask)

        if step >= init_steps:
            states, actions, rewards, next_states, masks = buffer.sample(256, replay_ratio)
            states = states.to(DEVICE)
            actions = actions.to(DEVICE)
            rewards = rewards.to(DEVICE)
            next_states = next_states.to(DEVICE)
            masks = masks.to(DEVICE)

            agent.update(episode_step, states, actions, rewards, next_states, masks)

        if terminal or truncate:
            state, _ = env.reset(seed=np.random.randint(0, 1e7))
            state = torch.from_numpy(state).float().to(DEVICE)

            if len(agent.logger['q_value']) > 0:
                agent.logger['returns'].append(returns)
                print(f"TrainStep: {step} Returns: {np.round(returns, 2)} Q-values: {np.round(agent.logger['q_value'][-1], 2)} Temperature: {np.round(agent.logger['temperature'][-1], 2)} Entropy: {np.round(agent.logger['entropy'][-1], 2)}")

            returns = 0.0
            episode_step = 0
        else:
            state = next_state

    return agent.logger

In [None]:
# train low rr agent ~ 10 min
history_lowrr = train(replay_ratio=1, use_layernorm=False)
np.savez("history_lowrr.npz", **history_lowrr)  # one array per logged metric
files.download("history_lowrr.npz")

In [None]:
# train high rr agent ~ 1h
history_noln = train(use_layernorm=False)
np.savez("history_noln.npz", **history_noln)  # one array per logged metric
files.download("history_noln.npz")

In [None]:
# train high rr LN agent ~ 1h
history_ln = train(use_layernorm=True)
np.savez("history_ln.npz", **history_ln)  # one array per logged metric
files.download("history_ln.npz")

In [None]:
# print average performance

print('Low RR, No LayerNorm: ', np.asarray(history_lowrr['returns']).mean())
print('High RR, No LayerNorm: ', np.asarray(history_noln['returns']).mean())
print('High RR, LayerNorm: ', np.asarray(history_ln['returns']).mean())

# Task 3 - Investigate Training Instabilities in High RR

## Task 3a (coding)
Plot the following time series for all tested models:

1. Returns (Y-axis), Episode (X-axis)
2. Q-values (Y-axis), Episode (X-axis)
3. Temperature (Y-axis), Episode (X-axis)
4. Entropy (Y-axis), Episode (X-axis)

The task will be graded on the **clarity of the presentation** and its **aesthetics**. The plots should provide all the information needed to observe the properties and answer the questions below.
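
The following hedged sketch shows one possible layout. It uses a 2x2 grid with one panel per metric and overlays all models in each panel, with faint raw curves behind a moving average. The temperature axis is log-scaled, a design choice because the temperature may span orders of magnitude. `plot_histories` and `moving_average` are illustrative names. The sketch assumes each logged list holds one value per episode.

```python
import matplotlib.pyplot as plt
import numpy as np


def moving_average(values, window: int):
    """Trailing moving average; returns the raw series if it is shorter than `window`."""
    values = np.asarray(values, dtype=float)
    if window <= 1 or len(values) < window:
        return values
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")


def plot_histories(histories: dict, window: int = 5):
    """histories maps a model label to a logger dict such as the one returned by `train`."""
    metrics = [("returns", "Returns"), ("q_value", "Q-value"),
               ("temperature", "Temperature"), ("entropy", "Entropy")]
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    for ax, (key, title) in zip(axes.flat, metrics):
        for label, history in histories.items():
            raw = np.asarray(history[key], dtype=float)
            smooth = moving_average(raw, window)
            # align each smoothed value with the episode index of its last input
            x = np.arange(len(raw) - len(smooth), len(raw))
            line, = ax.plot(x, smooth, label=label)
            ax.plot(np.arange(len(raw)), raw, color=line.get_color(), alpha=0.2)
        ax.set_title(title)
        ax.set_xlabel("Episode")
        ax.set_ylabel(title)
        ax.grid(alpha=0.3)
    axes.flat[2].set_yscale("log")  # temperature is always positive
    axes.flat[0].legend()
    fig.tight_layout()
    return fig
```

For example, `plot_histories({"Low RR": history_lowrr, "High RR": history_noln, "High RR + LN": history_ln})` followed by `plt.show()`.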



## Task 3b (writing)
Investigate the data and describe the issues you observe. In particular, contrast the tested model variants in the following respects:

1. Does the critic learn correct Q-values? Does it overestimate or underestimate?
2. Is the temperature mechanism enough to stabilize the entropy?
3. What happens to the entropy temperature parameter? Why would it take "large" values (i.e., > 10)?


## Task 3c (writing)
In the previous questions, we compared the sum of episodic rewards to the critic's output at the starting state. Is this approach correct, given that our critic learns soft Q-values? What are the problems with this approach?



## Task 3d (writing - bonus question)
Why would layer normalization help with value overestimation at a high RR?