# Soft Actor Critic

Some background: https://spinningup.openai.com/en/latest/algorithms/sac.html
Nice explanation: https://towardsdatascience.com/soft-actor-critic-demystified-b8427df61665
Nice video: https://www.youtube.com/watch?v=LN29DDlHp1U&ab_channel=YannBouteiller

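off-policy: it learns from transitions stored in a replay buffer, which were not necessarily generated by the policy currently being trained.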

actor-critic: an actor network chooses the actions, and critic networks estimate the Q-value of those actions.

encourages exploration by maximizing entropy: instead of maximizing only the reward of the actions, it also maximizes the entropy of the policy, which keeps the actions random and encourages exploration



Issues:
- Not a lot of exploration - Local optima are an issue
- Unseen scenarios yield poor performance

![SAC logic](img/SAC.png)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import torch.nn.functional as F

## Hyperparams


### ALPHA 
Tradeoff between exploration and exploitation. 

Balances the return vs the entropy of the policy

Participates in
- Actor loss: 
$$  \text{Loss} = - \alpha \times \text{entropy} - Q $$
with $\text{entropy} = -\log p$. This way we optimize both the entropy (more exploration), and the Q value (reward). 
- reparametrization trick temperature parameter

It is learned iteratively; we only control its initial value. It is adjusted so that the entropy of the policy (the negative log-probability of the sampled actions) moves towards a target entropy value.


### GAMMA
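Discount factor that sets the relative importance of current and future rewards: GAMMA = 0 only takes the current reward into account, while GAMMA = 1 gives future rewards the same weight as the current one. The relevant formula is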

$$ \text{target } Q (s_t, a_t)= r(s_t, a_t) +  \gamma \times (1-\text{done}) \times \left( \text{min } Q_{\text{critics}} + \alpha \times \text{entropy}   \right) $$
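with $r$ the reward sampled from the replay buffer, $\text{min } Q_{\text{critics}}$ the smaller of the two target critics' estimates for the next state and the next action sampled from the target actor, and $\text{entropy} = -\log p$ of that next action.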


### TAU
Polyak coefficient. Determines how much the target networks move towards the trained networks at each update. In this case there are a target actor and two target critics.

$$ \text{target} = \tau \times \text{data} + (1 - \tau) \times \text{target} $$

- TAU = 0 means targets do not update
- TAU = 1 means the targets are replaced by an exact copy of the trained networks at every update. The targets then follow the learning faster, but this introduces instabilities


In [None]:
# Hyperparameters
GAMMA = 0.99  # discount factor applied to future rewards
TAU = 0.005  # Polyak coefficient for the soft target-network updates
ALPHA = 0.2  # initial entropy temperature (it is learned during training)
BATCH_SIZE = 256  # number of transitions sampled per training step
# Replay buffer capacity. It is a count that is compared with a list length
# and used in index arithmetic, so it is an int rather than the float 1e6.
BUFFER_SIZE = 1_000_000
LEARNING_RATE = 3e-4  # learning rate of every Adam optimizer
EPISODES = 1000  # number of episodes played by the run loop


# Create an environment
env = gym.make('Pendulum-v0')

## Actor

In [None]:
from torch.distributions import Normal

# Define the actor network
class Actor(nn.Module):
    """Gaussian policy network with tanh-squashed, bounded actions.

    The network maps a batch of states to the mean and log standard
    deviation of a diagonal Gaussian. Actions are drawn from that Gaussian,
    squashed with ``tanh`` and scaled by ``max_action``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        max_action: Positive bound; actions lie in ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.layer_1 = nn.Linear(state_dim, 256)
        self.layer_2 = nn.Linear(256, 256)
        self.layer_3 = nn.Linear(256, action_dim)
        self.log_std_layer = nn.Linear(256, action_dim)
        # Bounds keep exp(log_std) away from 0 and from exploding.
        self.log_std_min = -20
        self.log_std_max = 2
        self.max_action = max_action

    def forward(self, state):
        """Return ``(mu, log_std)`` of the pre-squash Gaussian for ``state``."""
        x = F.relu(self.layer_1(state))
        x = F.relu(self.layer_2(x))
        mu = self.layer_3(x)
        log_std = self.log_std_layer(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        return mu, log_std

    def sample(self, state):
        """Draw a differentiable action and its log-probability.

        Args:
            state: Float tensor of shape ``(batch, state_dim)``.

        Returns:
            Tuple ``(action, log_prob)`` where ``action`` has shape
            ``(batch, action_dim)`` and ``log_prob`` has shape ``(batch, 1)``.
        """
        mu, log_std = self.forward(state)
        normal_distribution = Normal(mu, log_std.exp())
        # rsample() uses the reparametrization trick so gradients reach mu/std.
        z = normal_distribution.rsample()
        squashed = torch.tanh(z)
        action = squashed * self.max_action
        # Change of variables for a = max_action * tanh(z):
        # |da/dz| = max_action * (1 - tanh(z)^2). The correction must use the
        # unscaled tanh output; using the scaled action makes the argument of
        # the log negative (NaN) as soon as |action| > 1.
        log_prob = normal_distribution.log_prob(z) - torch.log(
            self.max_action * (1 - squashed.pow(2)) + 1e-7
        )
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob

    def select_action(self, state):
        """Sample one action for a single, unbatched state.

        Args:
            state: Array-like of shape ``(state_dim,)``.

        Returns:
            NumPy array of shape ``(action_dim,)``, detached from autograd.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        with torch.no_grad():
            action, _ = self.sample(state_tensor)
        return action.squeeze(0).cpu().numpy()

## Critic

In [None]:
class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with a single value.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        hidden_units = 256
        self.layer_1 = nn.Linear(state_dim + action_dim, hidden_units)
        self.layer_2 = nn.Linear(hidden_units, hidden_units)
        self.layer_3 = nn.Linear(hidden_units, 1)

    def forward(self, x, u):
        """Return the Q-value for states ``x`` and actions ``u``.

        Both inputs are batched 2-D tensors; they are concatenated along
        dimension 1 before going through the network. The output has one
        column per row of the batch.
        """
        state_action = torch.cat((x, u), dim=1)
        hidden = self.layer_1(state_action).relu()
        hidden = self.layer_2(hidden).relu()
        return self.layer_3(hidden)

## Replay Buffer


Stores the transitions experienced by the agent as tuples (state, action, next_state, reward, done)

Used during training to break the correlation between consecutive samples.

Allows efficient reuse of experiences in training

Allows off-policy learning, since the recorded experiences were not necessarily generated by the policy that is currently being trained

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions.

    Each stored item is a tuple ``(state, action, next_state, reward, done)``.
    Once ``max_size`` items are stored, new items overwrite the oldest ones.

    Args:
        max_size: Maximum number of transitions kept. Integral floats such
            as ``1e6`` are accepted and converted to ``int``.

    Raises:
        ValueError: If ``max_size`` is smaller than 1.
    """

    def __init__(self, max_size):
        # The capacity is compared with len() and used for index arithmetic,
        # so keep it (and therefore ptr) a true int.
        max_size = int(max_size)
        if max_size < 1:
            raise ValueError("max_size must be at least 1")
        self.storage = []
        self.max_size = max_size
        self.ptr = 0  # index of the next slot to overwrite once full

    def add(self, data):
        """Store one transition, overwriting the oldest one when full."""
        if len(self.storage) < self.max_size:
            self.storage.append(data)
            return
        self.storage[self.ptr] = data
        self.ptr = (self.ptr + 1) % self.max_size

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(states, actions, next_states, rewards, dones)`` of float32
            tensors whose first dimension is ``batch_size``. Each field keeps
            the shape it was stored with, so scalar rewards and done flags
            come back with shape ``(batch_size,)``.
        """
        ind = np.random.randint(0, len(self.storage), size=batch_size)
        batch = [self.storage[i] for i in ind]
        # Stack each field with NumPy first: building a tensor straight from
        # a Python list of ndarrays is very slow.
        return tuple(
            torch.as_tensor(
                np.asarray([transition[field] for transition in batch], dtype=np.float32)
            )
            for field in range(5)
        )

## SAC algorithm

In [None]:
class SAC:
    """Soft Actor-Critic agent: networks, optimizers, training and run loop.

    Uses the module-level hyperparameters GAMMA, TAU, ALPHA, BATCH_SIZE,
    BUFFER_SIZE, LEARNING_RATE and EPISODES.

    Args:
        env: Environment exposing ``observation_space``, ``action_space``,
            ``reset()`` and ``step(action)``. The run loop assumes the classic
            gym API where ``reset()`` returns the observation and ``step()``
            returns ``(next_state, reward, done, info)``.
    """

    def __init__(self, env):
        # Environment
        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]
        self.max_action = float(env.action_space.high[0])

        # Actor and its target copy. The target starts as an exact copy and
        # then follows the actor slowly through soft updates.
        self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
        self.actor_target = Actor(self.state_dim, self.action_dim, self.max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LEARNING_RATE)

        # Twin critics and their target copies
        self.critic_1 = Critic(self.state_dim, self.action_dim)
        self.critic_2 = Critic(self.state_dim, self.action_dim)
        self.critic_target_1 = Critic(self.state_dim, self.action_dim)
        self.critic_target_2 = Critic(self.state_dim, self.action_dim)
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())
        self.critic_optimizer_1 = optim.Adam(self.critic_1.parameters(), lr=LEARNING_RATE)
        self.critic_optimizer_2 = optim.Adam(self.critic_2.parameters(), lr=LEARNING_RATE)

        # Entropy temperature. log(alpha) is the learned parameter so that
        # alpha stays positive.
        self.log_alpha = torch.tensor(
            float(np.log(ALPHA)), dtype=torch.float32, requires_grad=True
        )
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=LEARNING_RATE)
        # alpha must exist before the first call to train(); it is kept
        # detached so the actor and critic losses never backpropagate into it.
        self.alpha = self.log_alpha.exp().detach()
        # Common heuristic for continuous actions: target entropy = -|A|.
        self.target_entropy = -float(self.action_dim)

        self.replay_buffer = ReplayBuffer(int(BUFFER_SIZE))

    def train(self, iterations):
        """Run ``iterations`` gradient steps on batches from the replay buffer.

        Each step updates both critics, the actor and alpha, then soft-updates
        the three target networks.
        """
        for _ in range(iterations):
            # Sample a batch of experiences from the replay buffer
            state, action, next_state, reward, done = self.replay_buffer.sample(BATCH_SIZE)
            # Force column shapes. A (batch,) reward or done flag would
            # broadcast against the (batch, 1) Q-values into (batch, batch).
            action = action.reshape(-1, self.action_dim)
            reward = reward.reshape(-1, 1)
            done = done.reshape(-1, 1)

            # Target Q value: reward plus the discounted, entropy-regularized
            # minimum of the two target critics on an action sampled from the
            # target actor at the next state.
            with torch.no_grad():
                next_action, next_log_prob = self.actor_target.sample(next_state)
                next_q = torch.min(
                    self.critic_target_1(next_state, next_action),
                    self.critic_target_2(next_state, next_action),
                )
                # -alpha * log_prob is the entropy bonus
                target_q = next_q - self.alpha * next_log_prob
                target_q = reward + (1 - done) * GAMMA * target_q

            # Update the critics: each one regresses onto the target Q value
            for critic, optimizer in (
                (self.critic_1, self.critic_optimizer_1),
                (self.critic_2, self.critic_optimizer_2),
            ):
                critic_loss = F.mse_loss(critic(state, action), target_q)
                optimizer.zero_grad()
                critic_loss.backward()
                optimizer.step()

            # Update the actor: maximize the critics' Q value plus the
            # entropy bonus for actions sampled from the current policy.
            new_action, log_prob = self.actor.sample(state)
            q_new = torch.min(
                self.critic_1(state, new_action),
                self.critic_2(state, new_action),
            )
            actor_loss = (self.alpha * log_prob - q_new).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update alpha so that the policy entropy (-log_prob) moves
            # towards the target entropy.
            alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp().detach()

            # Soft update on target parameters
            self._soft_update(self.critic_1, self.critic_target_1)
            self._soft_update(self.critic_2, self.critic_target_2)
            self._soft_update(self.actor, self.actor_target)

    @staticmethod
    def _soft_update(source, target):
        """Polyak-average the parameters of ``source`` into ``target`` using TAU."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(TAU * param + (1 - TAU) * target_param)

    def _select_action(self, state):
        """Sample one action from the actor for a single unbatched state.

        Returns a NumPy array of shape ``(action_dim,)`` that can be passed
        to ``env.step``.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        with torch.no_grad():
            action, _ = self.actor.sample(state_tensor)
        return action.squeeze(0).cpu().numpy()

    def run(self):
        """Play EPISODES episodes, storing transitions and training online.

        One training step is run after every environment step once the
        replay buffer holds at least BATCH_SIZE transitions. The total reward
        of each episode is printed when the episode ends.
        """
        for episode in range(EPISODES):
            state = self.env.reset()
            episode_reward = 0
            done = False

            while not done:
                action = self._select_action(state)
                next_state, reward, done, _ = self.env.step(action)
                self.replay_buffer.add((state, action, next_state, reward, float(done)))
                state = next_state
                episode_reward += reward

                if len(self.replay_buffer.storage) >= BATCH_SIZE:
                    self.train(1)

            print(f"Episode: {episode + 1}, Reward: {episode_reward}")

