# SAC Humanoid, etc.

* SAC - [T. Haarnoja et al., "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor." 2018](https://arxiv.org/abs/1801.01290)
* SAC with Automatically Adjusted Temperature - [T. Haarnoja et al., "Soft Actor-Critic Algorithms and Applications." 2018](https://arxiv.org/abs/1812.05905)
* SAC spinningup OpenAI - https://spinningup.openai.com/en/latest/algorithms/sac.html 2019/2020
  

## SAC - Soft Actor-Critic

**Soft Actor-Critic (SAC)** [(Haarnoja et al. 2018)](https://arxiv.org/abs/1801.01290) incorporates the entropy measure of the policy into the reward to encourage exploration: we expect to learn a policy that acts as randomly as possible while it is still able to succeed at the task. It is an off-policy actor-critic model following the maximum entropy reinforcement learning framework. Its precursor is [Soft Q-learning](https://arxiv.org/abs/1702.08165).

Three key components in SAC:

* An **actor-critic** architecture with separate policy and value function networks;
* An **off-policy** formulation that enables reuse of previously collected data for efficiency;
* Entropy maximization to enable stability and exploration.

The policy is trained with the objective to maximize the expected return and the entropy at the same time:

$$
J(\theta) = \sum_{t=1}^T \mathbb{E}_{(s_t, a_t) \sim \rho_{\pi_\theta}} [r(s_t, a_t) + \alpha \mathcal{H}(\pi_\theta(.\vert s_t))]
$$

where $\mathcal{H}(.)$ is the entropy measure and $\alpha$, known as the **temperature** parameter, controls how important the entropy term is. Entropy maximization leads to policies that can (1) explore more and (2) capture multiple modes of near-optimal strategies (i.e., if there exist multiple options that seem to be equally good, the policy should assign each an equal probability of being chosen).
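
To make the objective concrete, here is a minimal sketch in plain Python. It assumes a toy setting that is not part of the original notebook: two time steps, two discrete actions, and a reward of 1.0 whichever action is taken. The helper names `entropy` and `soft_objective` are hypothetical.

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def soft_objective(rewards, policies, alpha):
    """Sum over time steps of r_t + alpha * H(pi(.|s_t))."""
    return sum(r + alpha * entropy(pi) for r, pi in zip(rewards, policies))

rewards = [1.0, 1.0]                  # toy assumption: both actions earn the same reward
greedy = [[1.0, 0.0], [1.0, 0.0]]     # deterministic policy, zero entropy
uniform = [[0.5, 0.5], [0.5, 0.5]]    # maximum-entropy policy over two actions

j_greedy = soft_objective(rewards, greedy, alpha=0.2)
j_uniform = soft_objective(rewards, uniform, alpha=0.2)
print(j_greedy, j_uniform)
```

Because both actions are equally good in this toy case, the uniform policy scores higher by $\alpha \cdot 2 \log 2$, which is the "equal probability for equally good options" behaviour described above.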

Precisely, SAC aims to learn three functions:

* The policy $\pi_\theta$ with parameters $\theta$
* The soft Q-value function $Q_w$ parameterized by $w$
* The soft state value function $V_\psi$ parameterized by $\psi$ *(first version of the paper)*; **in modern versions $V$ is inferred from $Q$ and $\pi$, which is what we do here.**

Soft Q-value and soft state value are defined as:

$$
Q(s_t, a_t) = r(s_t, a_t) + \gamma \mathbb{E}_{s_{t+1} \sim \rho_{\pi}(s)} [V(s_{t+1})]  \text{; according to Bellman equation.}\\
\text{where }V(s_t) = \mathbb{E}_{a_t \sim \pi} [Q(s_t, a_t) - \alpha \log \pi(a_t \vert s_t)]  \text{; soft state value function.}
$$
$$
\text{Thus, } Q(s_t, a_t) = r(s_t, a_t) + \gamma \mathbb{E}_{(s_{t+1}, a_{t+1}) \sim \rho_{\pi}} [Q(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1} \vert s_{t+1})]
$$

$\rho_{\pi}(s)$ and $\rho_{\pi}(s,a)$ denote the state and the state-action marginals of the trajectory distribution induced by the policy $\pi$.
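
The two definitions above can be sketched for a discrete action set, where the expectation over actions is a plain weighted sum. This is an illustrative assumption (SAC in this notebook uses continuous actions and estimates the expectation from samples), and `soft_value` and `soft_q_backup` are hypothetical helper names.

```python
import math

def soft_value(q_values, probs, alpha):
    """V(s) = E_{a~pi}[Q(s, a) - alpha * log pi(a|s)] for a discrete action set."""
    return sum(p * (q - alpha * math.log(p))
               for q, p in zip(q_values, probs) if p > 0.0)

def soft_q_backup(reward, gamma, next_q_values, next_probs, alpha):
    """Q(s, a) = r + gamma * V(s'), with V(s') inferred from Q and pi."""
    return reward + gamma * soft_value(next_q_values, next_probs, alpha)

q_next = [1.0, 2.0]      # Q(s', a') for two actions
pi_next = [0.5, 0.5]     # pi(a'|s')
v_next = soft_value(q_next, pi_next, alpha=0.2)
q_now = soft_q_backup(reward=1.0, gamma=0.99,
                      next_q_values=q_next, next_probs=pi_next, alpha=0.2)
print(v_next, q_now)
```

Note that the soft value exceeds the plain expected Q-value (1.5 here) by $\alpha$ times the policy entropy.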


-----
#### Older version of SAC including parameterized V

The soft state value function is trained to minimize the mean squared error:
$$
J_V(\psi) = \mathbb{E}_{s_t \sim \mathcal{D}} [\frac{1}{2} \big(V_\psi(s_t) - \mathbb{E}[Q_w(s_t, a_t) - \log \pi_\theta(a_t \vert s_t)] \big)^2] \\
\text{with gradient: }\nabla_\psi J_V(\psi) = \nabla_\psi V_\psi(s_t)\big( V_\psi(s_t) - Q_w(s_t, a_t) + \log \pi_\theta (a_t \vert s_t) \big)
$$

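where $\mathcal{D}$ is the replay buffer. The first version of the paper absorbs the temperature $\alpha$ into the reward scale, which is why $\alpha$ does not appear explicitly in $J_V$.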

The soft Q function is trained to minimize the soft Bellman residual:

$$
J_Q(w) = \mathbb{E}_{(s_t, a_t) \sim \mathcal{D}} [\frac{1}{2}\big( Q_w(s_t, a_t) - (r(s_t, a_t) + \gamma \mathbb{E}_{s_{t+1} \sim \rho_\pi(s)}[V_{\bar{\psi}}(s_{t+1})]) \big)^2] \\
\text{with gradient: } \nabla_w J_Q(w) = \nabla_w Q_w(s_t, a_t) \big( Q_w(s_t, a_t) - r(s_t, a_t) - \gamma V_{\bar{\psi}}(s_{t+1})\big) 
$$

where $\bar{\psi}$ denotes the parameters of the target value network, maintained as an exponential moving average of $\psi$ (or only updated periodically in a “hard” way), just like the parameters of the target Q network in DQN, to stabilize training.

----

#### Modern versions of SAC without parameterized V

The soft Q function is trained to minimize the soft Bellman residual:

$$
J_Q(w) = \mathbb{E}_{(s_t, a_t) \sim \mathcal{D}} [\frac{1}{2}\big( Q_w(s_t, a_t) - (r(s_t, a_t) + \gamma \mathbb{E}_{s_{t+1} \sim \rho_\pi(s)}[V_{\bar{w}}(s_{t+1})]) \big)^2] \\
\text{where } V_{\bar{w}}(s_{t+1}) = \mathbb{E}_{a_{t+1} \sim \pi} [Q_{\bar{w}}(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1} \vert s_{t+1})] \\
\text{with gradient: } \nabla_w J_Q(w) = \nabla_w Q_w(s_t, a_t) \big( Q_w(s_t, a_t) - r(s_t, a_t) - \gamma (Q_{\bar{w}}(s_{t+1},a_{t+1}) - \alpha \log \pi(a_{t+1} \vert s_{t+1})) \big)
$$

The update makes use of a target soft Q-function with parameters $\bar{w}$, maintained as an exponential moving average of $w$ (or only updated periodically in a “hard” way), just like the parameters of the target Q network in DQN, to stabilize training.
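
The exponential moving average of the target parameters can be sketched as follows. This is a minimal illustration on plain Python lists rather than network weights; `polyak_update` is a hypothetical helper, and `tau=0.005` is just an example value.

```python
def polyak_update(target, online, tau):
    """Return tau * online + (1 - tau) * target, element-wise."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]

online_w = [1.0, -2.0]   # stand-in for the trained parameters w (held fixed here)
target_w = [0.0, 0.0]    # stand-in for the target parameters w-bar

for _ in range(1000):
    target_w = polyak_update(target_w, online_w, tau=0.005)

print(target_w)          # slowly approaches online_w
```

With a small `tau` the target lags behind the online parameters, which keeps the regression target for $Q_w$ from shifting abruptly. A "hard" update corresponds to copying the parameters every few steps, i.e. `tau=1.0` applied periodically.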

----

SAC updates the policy to minimize the **KL-divergence**:

$$
\pi_\text{new} 
= \arg\min_{\pi' \in \Pi} D_\text{KL} \Big( \pi'(.\vert s_t) \| \frac{\exp(\frac{1}{\alpha}Q^{\pi_\text{old}}(s_t, .))}{Z^{\pi_\text{old}}(s_t)} \Big) \\[6pt]
= \arg\min_{\pi' \in \Pi} D_\text{KL} \big( \pi'(.\vert s_t) \| \exp(\frac{1}{\alpha}Q^{\pi_\text{old}}(s_t, .) - \log Z^{\pi_\text{old}}(s_t)) \big) \\[6pt]
\text{objective for update: } J_\pi(\theta) = D_\text{KL} \big( \pi_\theta(. \vert s_t) \| \exp(\frac{1}{\alpha}Q_w(s_t, .) - \log Z_w(s_t)) \big) \\[6pt]
= \mathbb{E}_{a_t\sim\pi_\theta} \Big[ - \log \big( \frac{\exp(\frac{1}{\alpha}Q_w(s_t, a_t) - \log Z_w(s_t))}{\pi_\theta(a_t \vert s_t)} \big) \Big] \\[6pt]
= \frac{1}{\alpha} \mathbb{E}_{a_t\sim\pi_\theta} [ \alpha \log \pi_\theta(a_t \vert s_t) - Q_w(s_t, a_t) + \alpha \log Z_w(s_t) ]
$$

where $\Pi$ is the set of potential policies that we can model our policy as to keep them tractable; for example, $\Pi$ can be the family of Gaussian mixture distributions, expensive to model but highly expressive and still tractable. $Z^{\pi_\text{old}}(s_t)$ is the partition function that normalizes the distribution. It is usually intractable but does not contribute to the gradient with respect to $\theta$, and the positive factor $\frac{1}{\alpha}$ does not change the minimizer, so both can be dropped in practice. How to minimize $J_\pi(\theta)$ depends on our choice of $\Pi$.
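
As a sanity check of the KL objective, consider a discrete action set (an assumption made only for illustration). There the minimizer is the Boltzmann policy $\pi(a) \propto \exp(Q(a)/\alpha)$, and the objective without the constant $\log Z$ term evaluates to $-\alpha \log \sum_a \exp(Q(a)/\alpha)$ at that policy. The names `softmax` and `policy_objective` are hypothetical.

```python
import math

def softmax(q_values, alpha):
    """Boltzmann policy pi(a) proportional to exp(Q(a) / alpha)."""
    m = max(q_values)  # subtract the max for numerical stability
    exps = [math.exp((q - m) / alpha) for q in q_values]
    z = sum(exps)
    return [e / z for e in exps]

def policy_objective(probs, q_values, alpha):
    """E_{a~pi}[alpha * log pi(a) - Q(a)], with the log Z term dropped."""
    return sum(p * (alpha * math.log(p) - q)
               for p, q in zip(probs, q_values) if p > 0.0)

q = [1.0, 1.5, 0.0]
alpha = 0.5
pi_star = softmax(q, alpha)

j_star = policy_objective(pi_star, q, alpha)
j_uniform = policy_objective([1 / 3, 1 / 3, 1 / 3], q, alpha)
j_greedy = policy_objective([0.0, 1.0, 0.0], q, alpha)
print(j_star, j_uniform, j_greedy)
```

Both the uniform and the greedy policy give a larger objective value than `pi_star`, which matches the claim that the update projects the policy toward $\exp(Q/\alpha)/Z$.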

----
#### Clipped double-Q trick

One thing to note is that the authors suggest using two Q-functions to mitigate the positive bias in the policy improvement step, which is known to degrade the performance of value-based methods. In particular, we parameterize two Q-functions, with parameters $w_i$, and train them independently to optimize $J_Q(w_i)$. We then use the minimum of the two Q-functions for the value and policy gradient. Two Q-functions can significantly speed up training, especially on harder tasks.

To get the policy loss, the final step is that we need to substitute $Q^{\pi_{\theta}}$ with one of our function approximators. Unlike in TD3, which uses $Q_{w_1}$ (just the first Q approximator), SAC uses $\min_{j=1,2} Q_{w_j}$ (the minimum of the two Q approximators). The policy is thus optimized according to

$$
J_\pi(\theta) = \mathbb{E}_{a_t\sim\pi} [ \alpha \log \pi_\theta(a_t \vert s_t) - \min_{i=1,2} Q_{w_i}(s_t, a_t) ]
$$
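
The policy objective above can be sketched on a small batch of precomputed numbers. This is a plain-Python illustration with made-up values; in the notebook the log-probabilities and Q-values come from the actor and critic networks, and `policy_loss` is a hypothetical helper.

```python
def clipped_q(q1, q2):
    """Element-wise minimum of the two critics' estimates."""
    return [min(a, b) for a, b in zip(q1, q2)]

def policy_loss(log_probs, q1, q2, alpha):
    """Batch mean of alpha * log pi(a|s) - min(Q1(s, a), Q2(s, a))."""
    q_min = clipped_q(q1, q2)
    terms = [alpha * lp - q for lp, q in zip(log_probs, q_min)]
    return sum(terms) / len(terms)

log_probs = [-1.0, -0.5]   # log pi(a|s) for two sampled actions
q1 = [2.0, 3.0]            # first critic
q2 = [2.5, 1.0]            # second critic (disagrees strongly on the second sample)

loss = policy_loss(log_probs, q1, q2, alpha=0.2)
print(loss)
```

Taking the minimum means that an optimistic outlier from one critic (3.0 for the second sample) is ignored in favour of the more pessimistic estimate.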

Once we have defined the objective functions and gradients for soft action-state value, soft state value and the policy network, the soft actor-critic algorithm is straightforward. =)

____

## SAC with Automatically Adjusted Temperature

SAC is brittle with respect to the temperature parameter. Unfortunately it is difficult to adjust temperature, because the entropy can vary unpredictably both across tasks and during training as the policy becomes better. An improvement on SAC formulates a constrained optimization problem: while maximizing the expected return, the policy should satisfy a minimum entropy constraint:

$$
\max_{\pi_0, \dots, \pi_T} \mathbb{E} \Big[ \sum_{t=0}^T r(s_t, a_t)\Big] \quad \text{s.t. } \forall t\text{, } \mathcal{H}(\pi_t) \geq \mathcal{H}_0
$$

where $\mathcal{H}_0$ is a predefined minimum policy entropy threshold. 

[**Great proof done by lilianweng**](https://lilianweng.github.io/posts/2018-04-08-policy-gradient/#sac-with-automatically-adjusted-temperature)

This constrained maximization becomes the following dual problem.

$$
\min_{\alpha_T \ge 0} \max_{\pi_T} \mathbb{E} [r(s_T, a_T) - \alpha_T \log \pi(a_T \vert s_T)] - \alpha_T \mathcal{H}_0,
$$

where $\alpha_T$ is the dual variable. Furthermore, it can be rewritten as an optimization problem with respect to $\alpha$:

$$
J(\alpha) = \mathbb{E}_{a_t \sim \pi_t} [-\alpha \log \pi_t(a_t \mid s_t) - \alpha \mathcal{H}_0]
$$

The final algorithm is the same as SAC except that $\alpha$ is learned explicitly with respect to the objective $J(\alpha)$.
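
To see how minimizing $J(\alpha)$ moves the temperature, here is a minimal sketch of one gradient-descent step. It assumes that $\log \alpha$ is the trained variable (so that $\alpha$ stays positive) and uses hand-picked log-probabilities; `alpha_step` is a hypothetical helper, not the notebook's training code.

```python
import math

def alpha_step(log_alpha, log_probs, target_entropy, lr):
    """One gradient-descent step on J(alpha) = E[-alpha * (log pi + H0)] w.r.t. log(alpha)."""
    alpha = math.exp(log_alpha)
    mean_log_prob = sum(log_probs) / len(log_probs)
    grad = -alpha * (mean_log_prob + target_entropy)   # dJ / d(log alpha)
    return log_alpha - lr * grad

target_entropy = -1.0
log_alpha = math.log(0.2)

# Entropy estimate -E[log pi] = -2.0 is below the target of -1.0: alpha goes up.
log_alpha_up = alpha_step(log_alpha, [2.0, 2.0], target_entropy, lr=0.1)

# Entropy estimate 0.5 is above the target: alpha goes down.
log_alpha_down = alpha_step(log_alpha, [-0.5, -0.5], target_entropy, lr=0.1)

print(math.exp(log_alpha_up), math.exp(log_alpha_down))
```

So the temperature rises when the policy is less random than the target entropy $\mathcal{H}_0$ requires, and falls when the constraint is already satisfied.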

<img src="../assets/images/sac-algo.png" width="auto" height="auto">


## Import Modules 

In [1]:
import copy
import math
import io
from typing import List, Optional, Callable


import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Define Model - Soft Actor Critic - Continuous Control Task

It is important to understand the trick used to enforce action bounds while computing log probabilities. The actions need to be bounded to a finite interval. To that end, we apply an invertible squashing function $\tanh$ to the Gaussian samples (reparameterised so that gradients can flow through them) and use the change of variables formula to compute the likelihoods of the bounded actions.
1. Sample the action from a reparameterised Gaussian distribution ($a \sim \mathcal{N}(\mu, \sigma^2)$):
$$ a = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1) $$
2. Use $\tanh$ to bound the action:
$$ a_{\text{bounded}} = \tanh(a)$$
3. Since the action is now bounded to the range $[-1, 1]$, we use the **probability density conversion rule** (change of variables) to adjust the log probability:
$$ 
    \text{change of variables formula: } p(a_{\text{bounded}}) = p(a) \cdot \left| \frac{d a}{d a_{\text{bounded}}} \right| \\[6pt]
    \text{derivative for tanh: }\frac{d a_{\text{bounded}}}{d a} = 1 - \tanh^2(a) \\[6pt]
    \frac{d a}{d a_{\text{bounded}}} = \frac{1}{1 - \tanh^2(a)} \\[6pt]
    \log p(a_{\text{bounded}}) = \log p(a) - \log \left|\frac{d a_{\text{bounded}}}{d a} \right| \\[6pt]
    \log p(a_{\text{bounded}}) = \log p(a) - \log \left(1 - \tanh^2(a)\right)
$$

where $\log p(a)$ is the standard log probability under the Gaussian distribution and $\log(1 - \tanh^2(a))$ is subtracted to compensate for the density change caused by $\tanh$.

4. Scale the action to the environment's action range:
$$a_{\text{scaled}} = \text{action\_scale} \cdot a_{\text{bounded}} + \text{action\_bias}$$

P.S. The probability of the scaled action has to be adjusted in the same way:
$$
 p(a_{\text{scaled}}) = p(a_{\text{bounded}}) \cdot \left| \frac{d a_{\text{bounded}}}{d a_{\text{scaled}}} \right| \\[6pt]
 \frac{d a_{\text{bounded}}}{d a_{\text{scaled}}} = \frac{1}{\text{action\_scale}} \\[6pt]
 \log p(a_{\text{scaled}}) = \log p(a_{\text{bounded}}) - \log (\text{action\_scale}) \\[6pt]
 \log p(a_{\text{scaled}}) = \log p(a) -  \log \left(\text{action\_scale} \cdot (1 - \tanh^2(a))\right)
$$

5. A more numerically stable formula: the $\tanh$ correction term can be rewritten in an equivalent form that avoids precision loss (proof below):
   $$ \log(1 - \tanh^2(a)) = 2 \cdot (\log(2) - a - \text{softplus}(-2a)) $$

$$
    \log(1-\tanh^2(a)) = \log(\operatorname{sech}^2(a)) = 2 \cdot \log(\operatorname{sech}(a)) = 2 \cdot \log\left(\frac{2}{e^{-a} + e^a}\right) = \\[6pt]
    = 2 \cdot \log\left(\frac{2e^{-a}}{e^{-2a} + 1}\right) = 2 \cdot \left( \log(2e^{-a}) - \log(e^{-2a} + 1) \right) = \\[6pt]
    = 2 \cdot \left( \log(2) - a - \text{softplus}(-2a) \right)
$$
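
The identity can also be checked numerically. The sketch below uses only Python's `math` module, so `softplus`, `log_det_naive` and `log_det_stable` are hypothetical helpers rather than the PyTorch functions used later in the notebook. It compares both forms at a moderate value and at a large one, where the direct form is expected to break down in double precision because $\tanh(20)$ rounds to 1.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def log_det_naive(a):
    """log(1 - tanh(a)^2), computed directly."""
    return math.log(1.0 - math.tanh(a) ** 2)

def log_det_stable(a):
    """Equivalent form 2 * (log 2 - a - softplus(-2a))."""
    return 2.0 * (math.log(2.0) - a - softplus(-2.0 * a))

small = 0.5
agree = abs(log_det_naive(small) - log_det_stable(small))

large = 20.0
stable_large = log_det_stable(large)   # close to log(4) - 40
try:
    naive_large = log_det_naive(large)
except ValueError:                     # math.log(0.0) raises a domain error
    naive_large = None

print(agree, stable_large, naive_large)
```

For moderate inputs the two forms agree to machine precision; for large inputs only the softplus form stays finite and accurate.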

| **Formula**                        | **Advantages**                             | **Disadvantages**                          |
|-------------------------------------|--------------------------------------------|--------------------------------------------|
| $\log(1 - \tanh^2(a))$          | Simple, direct, mathematically exact       | Prone to numerical instability (overflow/underflow) |
| $2 \cdot (\log(2) - a - \text{softplus}(-2a))$ | Numerically stable, handles extreme values well | Slightly more computationally intensive   |

Source: https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-5f9724ed00e3100f302f7e1f3cd306c91863e2cb3ce30b026ebcff6a269ac1beR64
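
Steps 1 to 4 can be sanity-checked in one dimension: if the change-of-variables correction is right, the density of the squashed and scaled action must integrate to 1 over the action range. The sketch below is a standalone numerical check in plain Python with arbitrary example values for the mean, standard deviation, scale and bias; `normal_pdf` and `squashed_pdf` are hypothetical helpers.

```python
import math

def normal_pdf(x, mean, std):
    """Density of N(mean, std^2) at x."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def squashed_pdf(y, mean, std, scale, bias):
    """Density of y = scale * tanh(a) + bias with a ~ N(mean, std^2)."""
    t = (y - bias) / scale      # tanh(a), strictly inside (-1, 1)
    a = math.atanh(t)           # invert the squashing
    # p(y) = p(a) / (scale * (1 - tanh(a)^2))
    return normal_pdf(a, mean, std) / (scale * (1.0 - t * t))

mean, std, scale, bias = 0.3, 0.8, 2.0, 1.0
n = 20000
width = 2.0 * scale / n

# Midpoint rule over the open interval (bias - scale, bias + scale)
total = sum(
    squashed_pdf(bias - scale + (i + 0.5) * width, mean, std, scale, bias) * width
    for i in range(n)
)
print(total)   # should be close to 1
```

Dropping either correction factor (the `scale` or the `1 - t * t` term) makes the integral deviate from 1, which is a quick way to catch a missing term in the log-probability.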


In [2]:
# Prioritize device: CUDA > MPS > CPU
def set_device():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        print("CUDA is available. Using CUDA.")
    elif torch.backends.mps.is_available():
        DEVICE = torch.device("mps")
        print("MPS backend is available. Using MPS.")
    else:
        DEVICE = torch.device("cpu")
        print("Neither CUDA nor MPS is available. Using CPU.")
    return DEVICE

DEVICE = set_device()
# start on the CPU first (this overrides the device selected above)
DEVICE = torch.device("cpu")

def make_mlp(input_dim: int, 
             output_dim: Optional[int] = None, 
             hidden_layers: List[int] = [256,256], 
             activation: Callable[[], nn.Module] = nn.ReLU) -> nn.Sequential:
    """
    Return an MLP (fully-connected) feature extractor.
    """
    layers = []
    for size in hidden_layers:
        layers.append(nn.Linear(input_dim, size))
        layers.append(activation())
        input_dim = size

    if output_dim is not None:
        # Add the final output layer
        layers.append(nn.Linear(input_dim, output_dim))
    return nn.Sequential(*layers)

def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

LOG_STD_MAX = 2   # std = exp(2) ≈ 7.39
LOG_STD_MIN = -5  # std = exp(-5) ≈ 6.7e-3 (the 2.06e-9 figure corresponds to exp(-20), the bound used e.g. in Spinning Up)

LOG2 = math.log(2.0)

class Actor(nn.Module):
    """Actor (Policy) Model."""

    def __init__(self, obs_shape, action_size, action_low, action_high):
        """
        :param obs_shape: Tuple describing observation (e.g. (4,84,84) for images, or (24,) for vectors)
        :param action_size: Dimension of the action space (continuous)
        :param action_low:  Lower bound for the action (float or array)
        :param action_high: Upper bound for the action (float or array)
        """
        super(Actor, self).__init__()

        state_dim = obs_shape[0]
        # self.base = nn.Identity()
        self.fc_actor = make_mlp(state_dim, output_dim=None, hidden_layers=[256,256])
        feature_dim=256
            
        self.fc_mean = nn.Linear(feature_dim, action_size)
        self.fc_logstd = nn.Linear(feature_dim, action_size)
        
        # Action scaling
        self.register_buffer(
            "action_scale", torch.tensor((action_high - action_low) / 2, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((action_high + action_low) / 2, dtype=torch.float32)
        )
        self.register_buffer(
            "log_action_scale",  torch.log(self.action_scale)
        )

        self.apply(weights_init_)

    def forward(self, obs: torch.Tensor):
        """
        Forward pass: obs -> MLP -> (mean, std) of the pre-tanh Gaussian.
        Squashing with tanh and scaling to the action range happen in sample().
        """
        # features = self.base(obs)
        x = self.fc_actor(obs)
        
        # get mean
        mean = self.fc_mean(x)

        # get std 
        log_std = self.fc_logstd(x)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = log_std.exp()

        return mean, std
    
    def sample(self, obs: torch.Tensor, deterministic: bool = False, with_log_prob:bool = True):
        mean, std = self(obs)
        log_prob = None

        if deterministic:
            # used for evaluation of the policy 
            action = self.action_scale * torch.tanh(mean) + self.action_bias 

            return action, log_prob

        # sample actions
        distribution = torch.distributions.Normal(mean, std)

        # use reparametrisation trick on Normal distribution to calc the gradient base on actions
        #  (mean + std * N(0,1))
        action_raw = distribution.rsample()
        tanh_action = torch.tanh(action_raw)
        if with_log_prob:
            log_prob = distribution.log_prob(action_raw) #original gaussian distribution for action
            # Enforcing Action Bound, apply correction for Tanh squashing.
            # log_prob -= torch.log(self.action_scale * (1 - tanh_action.pow(2)) + 1e-6)
            # Enforcing Action Bound, use more numerically stable version for Tanh squashing, quite cool
            log_prob -= 2*(LOG2 - action_raw - F.softplus(-2 * action_raw)) 
            log_prob -= self.log_action_scale # Adjust for action scaling 
            log_prob = log_prob.sum(-1, keepdim=True)

        action = self.action_scale * tanh_action + self.action_bias

        return action, log_prob
    
    # Hypothetical way to rewrite the sample() code above using TransformedDistribution and Transforms
    # without 
    def _other_sample(self, obs: torch.Tensor, deterministic: bool = False):
        mean, std = self.forward(obs)

        if deterministic:
            # used for evaluation of the policy 
            action = mean
            action = self.action_scale *  torch.tanh(action) + self.action_bias
            return action, None
        
        base_distribution = torch.distributions.Normal(mean, std)
        # additional transforms to run on top
        # First tanh to bound between [-1, 1]
        tanh_transform = torch.distributions.transforms.TanhTransform(cache_size=1)
        # Then scale and shift to match action space
        scale_transform = torch.distributions.transforms.AffineTransform(self.action_bias, self.action_scale)
        squashed_and_scaled_dist = torch.distributions.TransformedDistribution(base_distribution, [tanh_transform, scale_transform])

        action = squashed_and_scaled_dist.rsample()
        
        log_prob = squashed_and_scaled_dist.log_prob(action).sum(-1, keepdim=True)
        # This is equivalent to the manual computation in sample(), where u is the
        # pre-tanh Gaussian sample (not the squashed and scaled action):
        # log_prob = base_distribution.log_prob(u)
        # log_prob -= 2 * (LOG2 - u - F.softplus(-2 * u))
        # log_prob -= self.log_action_scale
        # log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob


class Critic(nn.Module):
    """Double Critic (Q-Value) Model."""

    def __init__(self, obs_shape, action_size):
        """         
        :param obs_shape: Tuple describing observation (e.g. (4,84,84) for images, or (24,) for vectors)
        :param action_size: Dimension of the action space (continuous)
        """
        super(Critic, self).__init__()

        state_dim = obs_shape[0]
        self.fc_critic1 = make_mlp(state_dim + action_size, 1, [256,256])
        self.fc_critic2 = make_mlp(state_dim + action_size, 1, [256,256])

    def forward(self, obs: torch.Tensor, action: torch.Tensor):
        """
        Forward pass: concat(obs, action) -> two independent MLPs -> (value1, value2)
        """
        # features = self.base(obs)
        feature_actions = torch.cat([obs, action], dim=-1)
        value1 = self.fc_critic1(feature_actions)
        value2 = self.fc_critic2(feature_actions)
        return value1, value2 

MPS backend is available. Using MPS.


## Replay Buffer

The replay buffer collects and stores the following components for each step $t$ in a batch of $N$ parallel environments:
- state $s_{t}$
- actions $a_{t}$
- next state $s_{t+1}$
- rewards $r_{t+1}$
- dones/terminated $d_{t+1}$ - a binary mask indicating whether the environment has terminated (0 if active, 1 if it entered a terminal state)
                    **important: only true terminations are stored as done, so that for truncated (time-limit) episodes the target still includes the next-state value**

At each step $t$, for environment $b$, the collected data is:

$$ \{s_{t,b}, a_{t,b}, r_{t+1,b}, s_{t+1,b}, d_{t+1,b}\} $$
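
The reason for storing only true terminations in the done mask can be illustrated with the bootstrapped target $r + \gamma (1 - d) V(s')$. This is a minimal sketch with made-up numbers; `td_target` is a hypothetical helper, not part of the agent below.

```python
def td_target(reward, next_value, terminated, gamma=0.99):
    """r + gamma * (1 - d) * V(s'), where d marks true termination only."""
    return reward + gamma * (1.0 - float(terminated)) * next_value

# The agent reached a terminal state: nothing follows, so no bootstrap.
target_terminated = td_target(reward=1.0, next_value=10.0, terminated=True)

# The episode was cut by a time limit (truncated): it is stored with d = 0,
# so the value of the next state is still included.
target_truncated = td_target(reward=1.0, next_value=10.0, terminated=False)

print(target_terminated, target_truncated)
```

Treating a time-limit cut as a termination would teach the critic that the state just before the cut is worth only its immediate reward, which biases the value estimates downward.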

In [4]:
class ReplayBufferNumpy:
    """    
    Replay buffer used in off-policy algorithms like DDPG/SAC/TD3.
    
    :param obs_dim: Observation dimensions
    :param action_dim: Actions dimensions
    :param n_envs: Number of parallel environments 
    :param size: Max number of elements in the buffer
    :param device: Device (cpu, cuda, ...) on which the code should be run. 
    """
    def __init__(self,
                 obs_dim,
                 action_dim,
                 n_envs: int = 1,
                 size: int = int(1e6),
                 device: torch.device = torch.device("cpu")):
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device
        self.n_envs = n_envs

        self.pos = 0
        self.size = 0

        # Adjust buffer size 
        self.max_size = max(int(size // n_envs), 1)

        # Setup the data storage 
        
        self.obs = np.zeros((self.max_size, self.n_envs, *self.obs_dim), dtype=np.float32)
        self.next_obs = np.zeros((self.max_size, self.n_envs, *self.obs_dim), dtype=np.float32)
        self.actions = np.zeros((self.max_size, self.n_envs, *self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.max_size, self.n_envs), dtype=np.float32)
        self.terminates = np.zeros((self.max_size, self.n_envs), dtype=np.float32)

    def add(self,
            obs: np.ndarray,
            next_obs: np.ndarray,
            action: np.ndarray,
            reward: np.ndarray,
            terminates: np.ndarray
            ):
        self.obs[self.pos] = obs
        self.next_obs[self.pos] = next_obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.terminates[self.pos] = terminates

        self.pos = (self.pos + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self,
               batch_size:int = 32):
        """
        Sample elements from the replay buffer.
        
        :param batch_size: Number of elements to sample
        """
        batch_indices = np.random.randint(0, self.size, size=batch_size)
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_indices),))
        # in the end we return exactly batch_size transitions collected even from different agents
        # Gather indices for the first two dimensions
        indices = (batch_indices, env_indices)
        
        data = dict(
            obs=self.obs[indices],
            next_obs=self.next_obs[indices],
            actions=self.actions[indices],
            rewards=self.rewards[indices],
            # Only use dones that are not due to timeouts
            dones=self.terminates[indices]
        )
        return {k: self._to_torch(v) for k,v in data.items()}
        

    def _to_torch(self, data):
        return torch.tensor(data, dtype=torch.float32, device=self.device)
    
    def __len__(self):
        return self.size
    
    #
    # Internal helpers for dumping/loading via np.savez_compressed
    #
    def _dump_to_npz(self, file_obj):
        """Write the replay buffer data + metadata to a file-like object."""
        np.savez_compressed(
            file_obj,
            obs=self.obs,
            next_obs=self.next_obs,
            actions=self.actions,
            rewards=self.rewards,
            terminates=self.terminates,
            pos=self.pos,
            size=self.size,
            max_size=self.max_size,
            n_envs=self.n_envs
        )

    def _load_from_npz(self, file_obj):
        """Load the replay buffer data + metadata from a file-like object."""
        data = np.load(file_obj)
        self.obs = data["obs"]
        self.next_obs = data["next_obs"]
        self.actions = data["actions"]
        self.rewards = data["rewards"]
        self.terminates = data["terminates"]
        self.pos = int(data["pos"])
        self.size = int(data["size"])
        self.max_size = int(data["max_size"])
        self.n_envs = int(data["n_envs"])
    
    #
    # Public methods for file-based saving/loading
    #
    def save_as_numpy(self, file_path: str):
        """Save the replay buffer to a compressed .npz file on disk."""
        with open(file_path, "wb") as f:
            self._dump_to_npz(f)
        print(f"Replay buffer saved to {file_path}")

    def load_from_numpy(self, file_path: str):
        """Load the replay buffer from a compressed .npz file on disk."""
        with open(file_path, "rb") as f:
            self._load_from_npz(f)
        print(f"Replay buffer loaded from {file_path}")

    #
    # Public methods for in-memory (bytes) saving/loading
    #
    def save_as_bytes(self) -> bytes:
        """
        Serialize the replay buffer to an in-memory bytes object.
        Useful for storing in a single PyTorch checkpoint file.
        """
        buf = io.BytesIO()
        self._dump_to_npz(buf)
        return buf.getvalue()

    def load_from_bytes(self, replay_bytes: bytes):
        """
        Load the replay buffer data from a bytes object
        (the counterpart to save_as_bytes).
        """
        buf = io.BytesIO(replay_bytes)
        self._load_from_npz(buf)
        print("Replay buffer loaded from bytes")

## SAC Agent

In [8]:
from collections import deque
import time
import os

from helpers.utils import Logger

class SACAgent:
    """
    Soft Actor-Critic (SAC)
    Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.

    :param env(gym.vector.VectorEnv): Vector Gym Environment to learn from, consists of nEnvs
    :param policy_lr: learning rate of the Adam optimizer for the actor (policy) network
    :param critic_lr: learning rate of the Adam optimizer for the critic (Q-value) networks;
        also used for the temperature (alpha) optimizer
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1)
            θ_target = τ*θ_local + (1 - τ)*θ_target
    :param gamma: Discount factor
    :param alpha: Entropy regularization coefficient. (Equivalent to
        inverse of reward scale in the original SAC paper.)  Controlling exploration/exploitation trade-off.
    :param autotune: automatic tuning of the entropy coefficient
    :param device: Device (cpu, cuda, ...) on which the code should be run.
    """
    def __init__(self,
                env: gym.vector.VectorEnv,
                policy_lr: float = 3e-4,
                critic_lr: float = 1e-3,
                buffer_size: int = int(1e6),
                learning_starts: int = 10000,
                batch_size: int = 128,
                tau: float = 0.005,
                gamma: float = 0.99,
                alpha: float = 0.2,
                autotune: bool = True,
                device: torch.device = torch.device("cpu"),
                reward_to_achieve: int = 2500):

        self.env = env
        self.n_envs = env.num_envs

        self.mini_batch_size = batch_size
        self.learning_starts = learning_starts
        
        self.tau = tau
        self.gamma = gamma
        self.autotune = autotune    
        
        self.device = device 

        action_shape = self.env.single_action_space.shape
        obs_shape = self.env.single_observation_space.shape 
        action_low = self.env.single_action_space.low 
        action_high = self.env.single_action_space.high

        # initialize networks
        # Actor Network 
        self.actor = Actor(obs_shape, action_shape[0], action_low, action_high).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=policy_lr)
        
        # Critic Network (w/ Target Network)
        self.critic = Critic(obs_shape, action_shape[0]).to(self.device)
        self.critic_target = Critic(obs_shape, action_shape[0]).to(self.device)
        # hard copy the critic weights into the target critic
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

        # automatic entropy tuning
        if self.autotune:            
            # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
            self.target_entropy = -torch.prod(torch.Tensor(action_shape).to(self.device)).item()
            print(self.target_entropy)
            # https://github.com/rail-berkeley/softlearning/pull/142 - the temperature must stay positive;
            # this is easily done by training log_alpha, which guarantees alpha = exp(log_alpha) > 0.
            # The alpha passed in is used as the initial temperature.
            self.log_alpha = torch.log(alpha * torch.ones(1, device=device)).requires_grad_(True)
            # alpha itself is used as a constant, refreshed from log_alpha after log_alpha is trained
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=critic_lr)
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha = alpha


        # define Replay Memory 
        self.replay_memory = ReplayBufferNumpy(
            obs_shape,
            action_shape,
            n_envs=self.n_envs,
            size=buffer_size,
            device=self.device)
         
        # --------- Additional Parameters to think about
        self.rollout_steps = 1

        # logger initialization 
        run_name =(
            f"actor_lr{policy_lr}"
            f"_critic_lr{critic_lr}"
            f"_gamma{gamma}"
            f"_nenvs{env.num_envs}"
            f"_batch{self.mini_batch_size}"
            f"_alpha{alpha}"
            f"_autotune{autotune}"
            f"_{int(time.time())}"
        )

        run_name = "".join(run_name)
        self.logger = Logger(run_name=run_name, env=env.envs[0].spec.id, algo="SAC")
        # Log hyperparameters table
        hyperparams = {
            "actor_lr": policy_lr,
            "critic_lr": critic_lr,
            "gamma": gamma,
            "tau": tau,
            "buffer_size": buffer_size,
            "mini_batch_size": self.mini_batch_size,
            "alpha": alpha,
            "autotune": autotune
        }
        self.logger.log_hyperparameters(hyperparams)
        


        self.episode_num = 1
        self.best_reward = float('-inf')
        self.reward_to_achieve = reward_to_achieve 

    def select_action(self, state) -> np.ndarray:
        # during the warm-up phase (fewer than learning_starts steps; 10000 suggested), sample random actions
        if self._num_timesteps < self.learning_starts:
            return self.env.action_space.sample()

        with torch.no_grad():
            action, _ = self.actor.sample(state, with_log_prob=False)
            action = action.cpu().numpy()

        return action
    
    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        :param local_model: PyTorch model (weights will be copied from)
        :param target_model: PyTorch model (weights will be copied to)
        :param tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
    
    def collect_rollouts(self):
        """
        Collect experiences and store them into a ``ReplayBuffer``.
        """
        obs = self._last_obs

        for _ in range(self.rollout_steps): 
            
            actions = self.select_action(torch.tensor(obs, dtype=torch.float32, device=self.device))

            next_obs, rewards, terminates, _, infos = self.env.step(actions) 
            
            # Updating global number of steps agent done, while learning
            self._num_timesteps += self.n_envs
            
            if "final_info" in infos:
                source_info = infos["final_info"]
                self._extract_episode_data(source_info.get('episode'), source_info.get('_episode'))
            

            self.replay_memory.add(
                obs,
                next_obs,
                actions,
                rewards,
                terminates
            )

            # update most recent 
            obs = next_obs

        # remember last obs for next rollout call
        self._last_obs = obs 

    def learn(self):
        """
        Update policy and value parameters using a mini-batch sampled from the replay buffer.
        next_action, next_log_prob = actor(next_state)  # SAC has no target actor; the current policy is used
        Q1_next_target, Q2_next_target = critic_target(next_state, next_action)
        Q_next_target = min(Q1_next_target, Q2_next_target) - alpha * next_log_prob # max entropy target
        Q_targets = r + γ * (1 - done) * Q_next_target
        where:
            actor(state) -> action, log_prob
            critic_target(state, action) -> Q1-value, Q2-value
        """

        replay_data = self.replay_memory.sample(self.mini_batch_size)
        states = replay_data['obs']#[256,17]
        actions = replay_data['actions'] #[256,6]
        rewards = replay_data['rewards'].unsqueeze(-1)# [256,1]
        next_states = replay_data['next_obs'] # [256, 17]
        dones = replay_data['dones'].unsqueeze(-1)#[256,1]

        #---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models        
        with torch.no_grad():
            next_actions, next_log_prob = self.actor.sample(next_states)
            Q1_next_targets, Q2_next_targets = self.critic_target(next_states, next_actions)
            Q_next_targets = torch.min(Q1_next_targets, Q2_next_targets) 
            Q_next_targets -= self.alpha * next_log_prob
            # print(Q_next_targets.shape) #torch.Size([256, 1])
            # Compute Q targets for current states 
            Q_targets = rewards + (self.gamma * (1-dones) * Q_next_targets)            
         
        # Compute critic loss 
        Q1_expected, Q2_expected = self.critic(states, actions)
        Q1_loss = F.mse_loss(Q1_expected, Q_targets)
        Q2_loss = F.mse_loss(Q2_expected, Q_targets)
        critic_loss = Q1_loss + Q2_loss
        
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        #---------------------------- update actor (update policy weights) ------------------------------ #
        actions, log_probs = self.actor.sample(states)
        Q1_actions, Q2_actions = self.critic(states, actions)
        min_Q_actions = torch.min(Q1_actions, Q2_actions)
        self.actor_loss = ((self.alpha * log_probs) - min_Q_actions).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]

        # Minimize the loss
        self.actor_optimizer.zero_grad()
        self.actor_loss.backward()
        self.actor_optimizer.step()

        #--------------------------- adjust temperature (entropy coefficient) ---------------------- #
        if self.autotune:
            with torch.no_grad(): 
                _, log_probs = self.actor.sample(states)
            # https://github.com/rail-berkeley/softlearning/issues/37 -> old issue, it is ok to use alpha in the loss;
            # we train log_alpha and use alpha = exp(log_alpha), which keeps the temperature positive
            # https://github.com/rail-berkeley/softlearning/issues/136 & https://github.com/rail-berkeley/softlearning/pull/142
            # print(log_probs.shape) torch.Size([256, 1])
            self.alpha_loss = (self.log_alpha.exp() * (-log_probs.detach() - self.target_entropy)).mean()

            self.alpha_optimizer.zero_grad()
            self.alpha_loss.backward()
            self.alpha_optimizer.step()

            # update alpha after training 
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha_loss = torch.tensor(0.0)
        
        #-------------------------- update target network  weights (for critic) ----------------------- # 
        with torch.no_grad():
            self.soft_update(self.critic, self.critic_target, self.tau)

        return( 
            Q1_loss.item(),
            Q2_loss.item(),
            Q1_expected,
            Q2_expected,
            critic_loss.item(),
            self.actor_loss.item(),
            self.alpha_loss.item()
        )     

    def train(self, total_timesteps: int = 1_000_000, eval_frequency: int = 5000):
        """
        Train the agent
        """
        obs, _ = self.env.reset()
        self._last_obs = obs

        self._num_timesteps = 0
        self._learn_iterations = 0 #used to control the policy frequency
        self.episode_rewards = deque(maxlen=20)
        self.episode_lengths = deque(maxlen=20)
        eval_frequency = (eval_frequency // self.n_envs) * self.n_envs  # Make divisible by n_envs
    
        q1_loss, q2_loss, critic_loss, actor_loss, alpha_loss = 0.0,0.0,0.0,0.0,0.0
        q1_values, q2_values = torch.tensor([]), torch.tensor([])
        self.actor_loss = torch.tensor(0.0)
        self.alpha_loss = torch.tensor(0.0)
        while self._num_timesteps < total_timesteps:
            
            # Collect rollouts
            self.collect_rollouts()

            if self._num_timesteps > self.learning_starts:
                # we assume learning_starts > batch_size
                q1_loss, q2_loss,q1_values, q2_values, critic_loss, actor_loss, alpha_loss = self.learn()

            # let's introduce step logging
            if self._num_timesteps % 100 == 0:
                self.logger.add_scalar("losses/q1_loss", q1_loss, self._num_timesteps)
                self.logger.add_scalar("losses/q2_loss", q2_loss, self._num_timesteps)
                self.logger.add_scalar("losses/q1_value", q1_values.mean().item(), self._num_timesteps)
                self.logger.add_scalar("losses/q2_value", q2_values.mean().item(), self._num_timesteps)
                self.logger.add_scalar("losses/qf_loss", critic_loss / 2.0, self._num_timesteps)
                self.logger.add_scalar("losses/actor_loss", actor_loss, self._num_timesteps)
                if self.autotune:
                    self.logger.add_scalar("losses/alpha", self.alpha, self._num_timesteps)
                    self.logger.add_scalar("losses/alpha_loss", alpha_loss,  self._num_timesteps)

            # Print a training summary every eval_frequency steps (once at least one episode has finished)
            if self._num_timesteps % eval_frequency == 0 and len(self.episode_rewards) > 0:
                mean_rewards = np.mean(self.episode_rewards)
                median_rewards = np.median(self.episode_rewards)
                min_rewards = np.min(self.episode_rewards)
                max_rewards = np.max(self.episode_rewards)
                
                print(
                    "Num timesteps {}/{} Mean Length {:.1f}\n"
                    "Losses: Actor {:.4f} Critic {:.4f} Alpha {:.4f}\n"
                    "Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"
                    .format(
                        self._num_timesteps, total_timesteps, np.mean(self.episode_lengths),
                        actor_loss, critic_loss / 2.0, alpha_loss,
                        len(self.episode_rewards), mean_rewards,
                        median_rewards, min_rewards, max_rewards
                    )
                )

                if  mean_rewards > self.reward_to_achieve and mean_rewards > self.best_reward:
                    self.best_reward = mean_rewards
                    self._save_model('best-torch.model')
                    print(f"Saved best model with mean reward {mean_rewards:.2f}")
             
        print('Saving last...')
        self._save_model('torch.model')
        
    def _save_model(self, filename: str):
        """Save the current model state dictionaries and normalization info to a file in logger.dir_name."""
        model_path = os.path.join(self.logger.dir_name, filename)
        
        save_data = {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
        }

        torch.save(save_data, model_path)

    def _extract_episode_data(self, episode_data, episode_flags):
        """
        Extract data for environments where '_episode' is True, log it, and append
        the episode reward and length to self.episode_rewards / self.episode_lengths.
            {'r': -21.0, '_r': True, 'l': 944, '_l': True, 't': 5.089006, '_t': True}- example
            r - cumulative reward
            l - episode length
            t - elapsed time since beginning of episode

        :param episode_data: dict, data from environments
        :param episode_flags: np.ndarray, boolean array indicating done environments
        """
        done_envs = np.where(episode_flags)[0]  # Get indices of done environments
    
        for env_index in done_envs:
            env_specific_data = {}
            for key, value in episode_data.items():
                if isinstance(value, np.ndarray):  # Ensure it's an array
                    env_specific_data[key] = value[env_index]
            reward = env_specific_data['r']
            length = env_specific_data['l']
            self.logger.add_scalar('charts/reward', reward, self._num_timesteps)
            self.logger.add_scalar('charts/length', length, self._num_timesteps)        
            self.episode_rewards.append(reward)
            self.episode_lengths.append(length)
            self.episode_num += 1
   

### Agent Checkpoint Logic

I would like to save the agent at some step, then load it later and continue training from that step.

What do we need to store to be able to do this?
1. Actor (policy) network
2. Critic (Q) networks
3. Target critic (Q) networks
4. Entropy (temperature) parameter α, if auto-tuning is used
5. Optimizers (for actor, critic, and α), if we want to continue training with the same optimizer states
6. Replay buffer, if we want to continue training from the same experience

Wow, quite a lot, lol. I just want to be able to run Humanoid from a checkpoint for an additional 1M steps, so let's do it.

In [9]:
import io

def _dump_to_npz(replay_memory, file_obj):
        """Write the replay buffer data + metadata to a file-like object."""
        np.savez_compressed(
            file_obj,
            obs=replay_memory.obs,
            next_obs=replay_memory.next_obs,
            actions=replay_memory.actions,
            rewards=replay_memory.rewards,
            terminates=replay_memory.terminates,
            pos=replay_memory.pos,
            size=replay_memory.size,
            max_size=replay_memory.max_size,
            n_envs=replay_memory.n_envs
        )
def save_as_bytes(replay_memory) -> bytes:
        """
        Serialize the replay buffer to an in-memory bytes object.
        Useful for storing in a single PyTorch checkpoint file.
        """
        buf = io.BytesIO()
        _dump_to_npz(replay_memory, buf)
        return buf.getvalue()

def _load_from_npz(replay_memory, file_obj):
        """Load the replay buffer data + metadata from a file-like object."""
        data = np.load(file_obj)
        replay_memory.obs = data["obs"]
        replay_memory.next_obs = data["next_obs"]
        replay_memory.actions = data["actions"]
        replay_memory.rewards = data["rewards"]
        replay_memory.terminates = data["terminates"]
        replay_memory.pos = int(data["pos"])
        replay_memory.size = int(data["size"])
        replay_memory.max_size = int(data["max_size"])
        replay_memory.n_envs = int(data["n_envs"])
    

def load_from_bytes(replay_memory, replay_bytes: bytes):
        """
        Load the replay buffer data from a bytes object
        (the counterpart to save_as_bytes).
        """
        buf = io.BytesIO(replay_bytes)
        _load_from_npz(replay_memory, buf)
        print("Replay buffer loaded from bytes")



def save_checkpoint(agent:SACAgent, filename="checkpoint.pth"):
    # 1) Serialize replay buffer in memory
    replay_bytes = save_as_bytes(agent.replay_memory)

    # 2) Save everything in one PyTorch file
    torch.save({
        "actor_state_dict": agent.actor.state_dict(),
        "actor_optimizer_state_dict": agent.actor_optimizer.state_dict(),
        
        "critic_state_dict": agent.critic.state_dict(),
        "critic_optimizer_state_dict": agent.critic_optimizer.state_dict(),

        "critic_target_state_dict": agent.critic_target.state_dict(),
        # If using alpha auto-tuning
        "log_alpha": agent.log_alpha.item() if agent.autotune else 0.0, 
        "alpha_optimizer_state_dict": agent.alpha_optimizer.state_dict() if agent.autotune else 0.0,

        "num_timesteps": agent._num_timesteps,

        # The replay buffer, now as in-memory bytes
        "replay_buffer_bytes": replay_bytes 
    }, filename)

    print(f"Checkpoint saved to {filename}")

def load_checkpoint(agent:SACAgent, filename="checkpoint.pth"):
    checkpoint = torch.load(filename, map_location="cpu")  # or map_location=agent.device
    
    # Restore agent networks
    agent.actor.load_state_dict(checkpoint["actor_state_dict"])
    agent.critic.load_state_dict(checkpoint["critic_state_dict"])
    agent.critic_target.load_state_dict(checkpoint["critic_target_state_dict"])
    agent.actor_optimizer.load_state_dict(checkpoint["actor_optimizer_state_dict"])
    agent.critic_optimizer.load_state_dict(checkpoint["critic_optimizer_state_dict"])
    
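    # Note: train() as written resets _num_timesteps to 0, so this value only survives if that reset is skipped when resuming
    agent._num_timesteps = checkpoint["num_timesteps"]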

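    # If using alpha auto-tuning. Checking for the optimizer dict (saved as 0.0 when autotune is off)
    # avoids skipping a legitimately saved log_alpha of exactly 0.0 (alpha = 1.0).
    if agent.autotune and isinstance(checkpoint.get("alpha_optimizer_state_dict"), dict):
        # Update log_alpha in place so alpha_optimizer keeps tracking the same tensor
        with torch.no_grad():
            agent.log_alpha.fill_(checkpoint["log_alpha"])
        agent.alpha = agent.log_alpha.exp().item()
        agent.alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer_state_dict"])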

    # Restore the replay buffer bytes
    if "replay_buffer_bytes" in checkpoint:
        replay_bytes = checkpoint["replay_buffer_bytes"]
        load_from_bytes(agent.replay_memory, replay_bytes)

    print(f"Checkpoint loaded from {filename}")
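
The in-memory replay buffer serialization above relies on writing an `.npz` archive into a `BytesIO` object and reading it back. The round trip can be tried in isolation with the minimal sketch below; `arrays_to_bytes` and `bytes_to_arrays` are hypothetical helper names, and the tiny `obs` array stands in for real buffer contents.

```python
import io
import numpy as np


def arrays_to_bytes(**arrays) -> bytes:
    """Pack named arrays (and scalars) into compressed .npz bytes."""
    buf = io.BytesIO()
    np.savez_compressed(buf, **arrays)
    return buf.getvalue()


def bytes_to_arrays(blob: bytes) -> dict:
    """Unpack .npz bytes into a plain dict of arrays."""
    with np.load(io.BytesIO(blob)) as data:
        return {name: data[name] for name in data.files}


obs = np.arange(6, dtype=np.float32).reshape(3, 2)
blob = arrays_to_bytes(obs=obs, pos=2, size=3)  # scalars are stored as 0-d arrays
restored = bytes_to_arrays(blob)

print(type(blob).__name__, restored["obs"].shape, int(restored["pos"]), int(restored["size"]))
```

Scalars such as `pos` and `size` come back as 0-d arrays, which is why the loader above wraps them in `int(...)`.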

### Train Example - Ant-v5

Set up the environments

In [10]:
from helpers.envs import make_sync_vec, AutoresetMode

num_envs = 1
env_id = 'Ant-v5'

envs = make_sync_vec(env_id, 
                    num_envs=num_envs, 
                    wrappers=(gym.wrappers.RecordEpisodeStatistics,),
                    autoreset_mode=AutoresetMode.SAME_STEP)


agent = SACAgent(envs,  
                  policy_lr=3e-4,
                  critic_lr=3e-4,
                  buffer_size=1e6,
                  learning_starts=25000,
                  batch_size=256,
                  tau=0.005,
                  gamma=0.99,
                  alpha=1.0,
                  autotune=True,
                  reward_to_achieve=4000) 
    

agent.train(total_timesteps=1e6) 

-8.0
----------------------------------------------------------------------------
| Hyperparams                    | Values                                  |
----------------------------------------------------------------------------
| actor_lr                       | 0.0003                                  |
| critic_lr                      | 0.0003                                  |
| gamma                          | 0.99                                    |
| tau                            | 0.005                                   |
| buffer_size                    | 1000000.0                               |
| mini_batch_size                | 256                                     |
| alpha                          | 1.0                                     |
| autotune                       | True                                    |
----------------------------------------------------------------------------
Num timesteps 5000/1000000.0 Mean Length 116.0
Losses: Actor 0.0000 Cri

### Train Humanoid-v5

In [12]:
from helpers.envs import make_sync_vec, AutoresetMode

num_envs = 1
env_id = 'Humanoid-v5'

envs = make_sync_vec(env_id, 
                    num_envs=num_envs, 
                    wrappers=(gym.wrappers.RecordEpisodeStatistics,),
                    autoreset_mode=AutoresetMode.SAME_STEP)


agent = SACAgent(envs,  
                policy_lr=3e-4,
                critic_lr=3e-4,
                buffer_size=1e6,
                learning_starts=25000,
                batch_size=256,
                tau=0.005,
                gamma=0.99,
                alpha=0.1,
                autotune=True,
                device=DEVICE) 

agent.train(total_timesteps=1e6, eval_frequency=10000) 

-17.0
----------------------------------------------------------------------------
| Hyperparams                    | Values                                  |
----------------------------------------------------------------------------
| actor_lr                       | 0.0003                                  |
| critic_lr                      | 0.0003                                  |
| gamma                          | 0.99                                    |
| tau                            | 0.005                                   |
| buffer_size                    | 1000000.0                               |
| mini_batch_size                | 256                                     |
| alpha                          | 0.1                                     |
| autotune                       | True                                    |
----------------------------------------------------------------------------
Num timesteps 10000/1000000.0 Mean Length 23.8
Losses: Actor 0.0000 Cr

### Evaluate, make video

In [25]:
# Import required helpers for the evaluation and video making 
from helpers.utils import create_evaluation_env_model

import os
import numpy as np
import torch
import imageio
from IPython.display import Video, display

def eval_policy(env, policy, num_episodes=10):
    """Evaluate the policy over a number of episodes."""
    # Store rewards for each episode
    episode_rewards = []
    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        total_reward = 0
        while not done:
            state = torch.tensor(obs, dtype=torch.float32)
            if state.dim() == 1:  # add a batch dimension for a single (unvectorized) observation
                state = state.unsqueeze(0)
            # Select the action using the trained model
            action, _ = policy.sample(state, deterministic=True)

            # Step the environment
            obs, reward, terminated, truncated, _ = env.step(action.squeeze(0).detach().numpy())
            total_reward += reward
            done = terminated or truncated
           
        episode_rewards.append(total_reward) # Store the total reward for the episode
        # print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    # Calculate mean and standard deviation
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    print(f"Mean Reward: {mean_reward}, Standard Deviation: {std_reward}")

    return mean_reward, std_reward


def record_and_display(env, policy, out_directory, out_name, fps=30, min_reward=None):
    """
    Generate a replay video of the agent and save it.
    """
    images = []
    obs, _ = env.reset()
    total_reward = 0
    episode_length = 0
    num_saved_episodes = 0

    while True:
        state = torch.tensor(obs, dtype=torch.float32)
        if state.dim() == 1:  # add a batch dimension for a single (unvectorized) observation
            state = state.unsqueeze(0)
        
        # Select the action using the trained model
        action, _ = policy.sample(state, deterministic=True)
        
        obs, reward, terminated, truncated, _ = env.step(action.squeeze(0).detach().numpy())
        total_reward += reward
        images.append(env.render())
        episode_length += 1

        if terminated or truncated:
            obs, _ = env.reset()
            print(total_reward)
            if min_reward is None or total_reward >= min_reward:
                num_saved_episodes += 1
            else:
                images = images[:-episode_length]  # Remove unsatisfactory episode
            total_reward = 0
            episode_length = 0
            if num_saved_episodes == 2:  # Save 2 episodes
                break
    
    # Save the video
    video_path = os.path.join(out_directory, out_name)
    imageio.mimsave(video_path, images, fps=fps)

    display(Video(video_path, embed=True, width=640, height=480))

### Result Ant-v5

In [21]:
env, eval_model = create_evaluation_env_model('Ant-v5', 
                                              Actor, 
                                              './runs/Ant-v5/SAC/actor_lr0.0003_critic_lr0.0003_gamma0.99_nenvs1_batch256_alpha1.0_autotuneTrue_1741005896/torch.model') 

eval_policy(env, eval_model, num_episodes=10)

Episode 1: Total Reward = 4296.546073970611
Episode 2: Total Reward = 4012.826499535744
Episode 3: Total Reward = 4135.276634500188
Episode 4: Total Reward = 4224.3128726363375
Episode 5: Total Reward = 4274.455572229548
Episode 6: Total Reward = 1243.3679581188594
Episode 7: Total Reward = 4259.198290767898
Episode 8: Total Reward = 3278.8083138924812
Episode 9: Total Reward = 4319.6981346522125
Episode 10: Total Reward = 4041.298753650055
Mean Reward: 3808.578910395394, Standard Deviation: 903.0979765313413


(np.float64(3808.578910395394), np.float64(903.0979765313413))

In [22]:
record_and_display(env, 
             eval_model, 
             '../assets/videos', 
             'sac_antv5.mp4', 
             fps=60, min_reward=4000)

4211.855005612604
4126.5183026612685


<video width="640" height="480" controls>
  <source src="../assets/videos/sac_antv5.mp4" type="video/mp4">
</video>


### Result Humanoid-v5

In [28]:
env, eval_model = create_evaluation_env_model('Humanoid-v5', 
                                              Actor, 
                                              './runs/Humanoid-v5/SAC/actor_lr0.0003_critic_lr0.0003_gamma0.99_nenvs1_batch256_alpha0.1_autotuneTrue_1741020294/torch.model') 

eval_policy(env, eval_model, num_episodes=10)

Mean Reward: 4214.804822152573, Standard Deviation: 1075.3726589196633


(np.float64(4214.804822152573), np.float64(1075.3726589196633))

In [29]:
record_and_display(env, 
             eval_model, 
             '../assets/videos', 
             'sac_humanoidv5.mp4', 
             fps=60, min_reward=4000)

5032.008779290945
2590.341345289942
3754.9915283064843
3637.3431113344745
4498.389892091903


<video width="640" height="480" controls>
  <source src="../assets/videos/sac_humanoidv5.mp4" type="video/mp4">
</video>

```text
Performance
^
|    /\      /\/\  
|   /  \____/    \_/\___  SAC w/auto-tune
|  /               
| /         _____________  TD3
|/________/
+-----------------------> Training Steps
   Fast    Stable phase
   gains
```

It looks like SAC with autotune learns faster, but TD3 is more stable.
This pattern is expected: SAC often reaches higher peaks sooner but shows more variance and can sometimes deteriorate, whereas TD3 typically improves more steadily and stays more stable in the long run.

For the best results, consider:

* Using SAC with auto-tune for faster initial learning
* Potentially switching to TD3, or to a manually fixed SAC temperature, for the final policy refinement
* Lowering the learning rate for SAC after 1M steps
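
The last suggestion can be sketched as follows. PyTorch optimizers expose their settings through `optimizer.param_groups`, a list of dicts with an `"lr"` key, so the learning rate can be changed in place without rebuilding the optimizer. The helper name `lower_lr_after`, the 1M-step threshold default, and the 0.1 factor are illustrative assumptions, and a stand-in object replaces a real optimizer to keep the snippet dependency-free.

```python
from types import SimpleNamespace


def lower_lr_after(optimizer, num_timesteps, base_lr=3e-4, threshold=1_000_000, factor=0.1):
    """Set every param group's lr to base_lr, or base_lr * factor past the threshold."""
    lr = base_lr * factor if num_timesteps >= threshold else base_lr
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr


# Stand-in with the same `param_groups` layout as a torch.optim optimizer.
fake_optimizer = SimpleNamespace(param_groups=[{"lr": 3e-4}])

lr_early = lower_lr_after(fake_optimizer, num_timesteps=500_000)
lr_late = lower_lr_after(fake_optimizer, num_timesteps=1_200_000)
print(lr_early, lr_late, fake_optimizer.param_groups[0]["lr"])
```

In the agent, such a call would sit in the training loop and be applied to the actor and critic optimizers; whether to also slow down the temperature optimizer is a separate choice.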

# OFFTOPIC

#### Experiments on the torch distributions with transforms

In [11]:
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform, AffineTransform

# Parameters for the Gaussian distribution
mean = torch.zeros(32, 4)
std = torch.ones(32, 4)
base_dist = Normal(mean, std)

# Define the scaling factor and offset for actions
action_scale = torch.tensor(2.0)  # Scale to [-2, 2]
action_loc = torch.tensor(0.0)    # Centered around 0

# Combine Tanh and Affine transforms
tanh_transform = TanhTransform(cache_size=1)  # Tanh squashing
scale_transform = AffineTransform(loc=action_loc, scale=action_scale)  # Scaling
transforms = [tanh_transform, scale_transform]

# Create the transformed distribution
squashed_and_scaled_dist = TransformedDistribution(base_dist, transforms)

# Sampling and log probabilities
sample = squashed_and_scaled_dist.rsample()  # Squashed and scaled sample
log_prob = squashed_and_scaled_dist.log_prob(sample).sum(dim=-1, keepdim=True)  # Corrected log prob

print("Sampled action:", sample)
print("Log probability:", log_prob)

Sampled action: tensor([[ 0.5074,  1.5703,  1.4810, -1.4949],
        [ 1.5509, -1.8622, -0.5262,  0.0907],
        [ 1.2739,  0.3555, -0.2102,  0.4862],
        [ 0.6198,  1.5352, -0.8541,  1.5408],
        [-0.4747,  0.3214,  0.7239,  1.7032],
        [-1.2132,  1.1641,  0.5863, -0.2658],
        [-0.5931,  0.8703,  0.9703,  0.8083],
        [-1.2249,  1.7981,  0.4451, -1.2591],
        [ 0.7131,  0.5960,  1.0656,  0.8743],
        [ 0.0526, -1.0802,  0.9369, -0.4025],
        [-0.5002, -1.1977,  1.9006,  1.8159],
        [ 1.4157, -1.0389,  1.8482, -0.2980],
        [ 1.1150, -0.7821,  1.8687,  1.0501],
        [ 1.4550, -0.3507, -1.1966, -0.7126],
        [-1.3134, -0.6636,  1.4307, -0.2071],
        [ 0.9355, -1.7647,  1.6452, -1.8502],
        [-0.2547,  1.1819,  1.6270, -1.4446],
        [-1.5273,  0.4018,  1.8404, -0.9995],
        [-1.7603,  0.8722, -0.6005, -1.2477],
        [ 1.8503, -0.3827,  0.3971,  1.2566],
        [-1.7159, -1.8221, -0.0553,  1.6546],
        [ 0.7808, 
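
For a single action dimension, the log-probability returned by the transformed distribution above follows the change-of-variables rule: for a = c · tanh(u) with u ~ N(μ, σ), log p(a) = log N(u; μ, σ) − log(c · (1 − tanh²(u))). The sketch below is a pure-Python sanity check of that formula, not a reimplementation of `torch.distributions`; `squashed_log_prob` is a hypothetical helper, and the check simply integrates the resulting density over (−c, c) numerically.

```python
import math


def squashed_log_prob(a, scale=2.0, mean=0.0, std=1.0):
    """Log density of a = scale * tanh(u), u ~ Normal(mean, std), via change of variables."""
    u = math.atanh(a / scale)  # invert the squashing
    base = -0.5 * ((u - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)
    log_det = math.log(scale) + math.log(1.0 - (a / scale) ** 2)  # log |da/du|
    return base - log_det


# Midpoint-rule integral of the density over the open interval (-scale, scale).
scale, n = 2.0, 100_000
width = 2.0 * scale / n
total = sum(
    math.exp(squashed_log_prob(-scale + (i + 0.5) * width, scale)) * width
    for i in range(n)
)
print(round(total, 3))
```

A total close to 1 indicates the Jacobian correction is applied with the right sign. For multi-dimensional actions, the per-dimension log-probabilities are summed, as in the `.sum(dim=-1, keepdim=True)` call above.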

### Proof by Lilian Weng
<img src="../assets/images/sac-autotune-proof1.png" width="auto" height="auto">

<img src="../assets/images/sac-autotune-proof2.png" width="auto" height="auto">
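
For intuition about the automatic temperature adjustment, the scalar sketch below performs one gradient-descent step on the same loss shape as `alpha_loss` in `learn`, namely exp(log α) · (−log π − target entropy). `alpha_step` is a hypothetical helper, the learning rate and log-probabilities are made up, and the target entropy of −8.0 assumes the common heuristic of minus the action dimension.

```python
import math


def alpha_step(log_alpha, mean_log_prob, target_entropy, lr=0.1):
    """One gradient-descent step on L = exp(log_alpha) * (-mean_log_prob - target_entropy)."""
    grad = math.exp(log_alpha) * (-mean_log_prob - target_entropy)  # dL/d(log_alpha)
    return log_alpha - lr * grad


target_entropy = -8.0  # assumption: minus the action dimension
# Entropy estimate (-log pi = -10) is below the target: alpha should grow.
up = alpha_step(log_alpha=0.0, mean_log_prob=10.0, target_entropy=target_entropy)
# Entropy estimate (-log pi = -5) is above the target: alpha should shrink.
down = alpha_step(log_alpha=0.0, mean_log_prob=5.0, target_entropy=target_entropy)
print(math.exp(up), math.exp(down))
```

When the policy is less random than the target, the temperature rises and the entropy bonus weighs more in the actor and critic objectives; when it is more random than needed, the temperature falls.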