<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" align="left" src="https://i.creativecommons.org/l/by-nc-sa/4.0/80x15.png" /></a>&nbsp;| [Emmanuel Rachelson](https://personnel.isae-supaero.fr/emmanuel-rachelson?lang=en) | <a href="https://erachelson.github.io/RLclass_MVA/">https://erachelson.github.io/RLclass_MVA/</a>

<div style="font-size:22pt; line-height:25pt; font-weight:bold; text-align:center;">Chapter 5: Continuous actions in DQN algorithms</div>

In previous chapters, we introduced function approximation within value iteration methods, which enabled learning optimal value functions in very large state spaces. As we derived the corresponding algorithms, we always kept the policy implicit: it was the policy that is greedy with respect to the value function. This was easy since all the problems we considered had only a small number of discrete actions. When the action space is continuous (or when there are too many actions to enumerate efficiently), writing $\pi(s) \in \arg\max_a Q(s,a)$ does not translate as easily into an implementable policy. Choosing an action then boils down to solving a separate optimization problem in every state, which can quickly become impractical. In the present chapter, we make the policy explicit and build upon the actor-critic architecture to derive approximate value iteration algorithms for continuous action spaces.

<div class="alert alert-success">

**Learning outcomes**   
By the end of this chapter, you should be able to:
- explain the deterministic policy gradient theorem
- implement a DDPG algorithm and discuss its properties
- implement TD3, motivate it, and discuss its properties
- explain the derivation of soft policy iteration and soft actor-critic
- implement a SAC algorithm with fixed temperature
</div>

# Playground

We will be playing with different continuous environments in this chapter. Gymnasium provides a number of such environments, for example through the [Box2D](https://gymnasium.farama.org/environments/box2d/) and [MuJoCo](https://gymnasium.farama.org/environments/mujoco/) families.

To keep things computationally light, we will mostly play with the [inverted pendulum](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/) environment from the MuJoCo suite. Some exercises along the chapter open up to more difficult benchmarks. Feel free to try them (or others) to build a better sense of how the algorithms work.

In [None]:
import gymnasium as gym
from gymnasium.utils.save_video import save_video
from tqdm import trange

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array_list")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array_list")
s,_ = test_env.reset()
for t in range(1000):
    a = test_env.action_space.sample()
    s2,r,d,trunc,_ = test_env.step(a)
    s = s2
    if d or trunc:
        break

save_video(test_env.render(), "videos", fps=test_env.metadata["render_fps"], name_prefix="random_policy",)

In [None]:
from IPython.display import Video
Video("videos/random_policy-episode-0.mp4")

In [None]:
import gymnasium as gym
import numpy as np
from tqdm import trange

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
s,_ = test_env.reset()
returns = []
for _ in trange(50):
    cumulated_reward = 0
    s,_ = test_env.reset()
    for t in range(1000):
        a = test_env.action_space.sample()
        s2,r,d,trunc,_ = test_env.step(a)
        cumulated_reward += r
        s = s2
        if d or trunc:
            break
    returns.append(cumulated_reward)

print(np.mean(returns))
print(np.std(returns))

In [None]:
print(test_env.action_space)
print(test_env.observation_space)

# Deep deterministic policy gradients

## The deterministic policy gradient theorem

Let us, once again, restart from approximate value iteration (AVI) as a sequence of risk minimization problems.

<div class="alert alert-success">

**Approximate value iteration as a sequence of risk minimization problems**  
$$\pi_n \in \mathcal{G} Q_n,$$
$$L_n(\theta) = \frac{1}{2} \mathbb{E}_{(s,a) \sim \rho}\left[ \left( Q(s,a;\theta) - G^{\pi_n}_1(s,a,Q_n) \right)^2 \right],$$
$$\theta_{n+1} \in \arg\min_{\theta} L_n(\theta),$$
$$Q_{n+1}(s,a) = Q(s,a;\theta_{n+1}).$$
</div>

For deterministic policies, we have
$$\pi \in \mathcal{G} Q \Leftrightarrow \pi(s) \in \arg\max_{a \in A} \left[Q(s,a)\right], \forall s\in S.$$

Finding $\pi \in \mathcal{G}Q$ was relatively easy as long as there were few, discrete actions. But when actions are continuous, computing $\max_a Q(s,a)$ is a continuous optimization problem about which we generally know little (for instance, $Q(s,\cdot)$ need not be concave), and it has to be solved in every state.
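To make this cost concrete, here is a minimal sketch (not part of the course code) that computes a greedy action by brute-force grid search over a box of actions. The Q-function `toy_q` is a made-up, hypothetical example; the point is that the number of Q evaluations grows as `n_points ** action_dim`, and that the search has to be repeated in every visited state.

```python
import numpy as np

def toy_q(s, a):
    """Hypothetical Q-function: a concave bump centred on a state-dependent action."""
    best = np.tanh(s[: a.shape[-1]])              # the best action depends on the state
    return -np.sum((a - best) ** 2, axis=-1)

def greedy_by_grid(q, s, low, high, n_points):
    """Approximate argmax_a q(s, a) over the box [low, high] with a regular grid."""
    axes = [np.linspace(l, h, n_points) for l, h in zip(low, high)]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, len(low))
    values = q(s, grid)                            # one Q evaluation per grid point
    return grid[np.argmax(values)], grid.shape[0]

s = np.array([0.5, -1.0, 2.0])
for action_dim in (1, 2, 3):
    low, high = -np.ones(action_dim), np.ones(action_dim)
    a_star, n_evals = greedy_by_grid(toy_q, s, low, high, n_points=21)
    print(action_dim, n_evals, np.round(a_star, 2))   # 21, 441, 9261 evaluations
```

With 21 points per dimension, a 6-dimensional action space would already require about 86 million evaluations per state, which is why an explicit policy becomes attractive.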

Let us turn to what we called "weak optimality" in a previous chapter, that is, given a distribution $\rho_0(s)$ on starting states, the search for a policy that maximizes $J(\pi) = \mathbb{E}_{s\sim \rho_0} [V^\pi(s)]$. As indicated in previous chapters, finding a policy which maximizes this *average value across states* is not necessarily the same as finding a policy which *dominates any other policy in every state*. But in most practical cases, $J(\pi)$ is a very reasonable and interesting proxy for optimality.

Recall that $V^\pi(s) = Q^\pi(s,\pi(s))$. 
So, for a given function $Q$, instead of looking for $\pi(s) \in \arg\max_{a \in A} \left[Q(s,a)\right], \forall s\in S$, we can redefine *greediness* and look for 
$$\pi \in \arg\max_{\pi} J_Q(\pi), \quad \text{with } J_Q(\pi) = \mathbb{E}_{s\sim \rho_0} [Q(s,\pi(s))].$$
If $Q=Q^*$, then a maximizer of this quantity is a maximizer of $J(\pi)$.

Now if $\pi$ is a parameterized function $\pi_w$, then one can try to approximate $\pi_n \in \mathcal{G}Q_n$ by taking gradient steps on $J_{Q_n}(\pi_w)$. This is the key idea behind deterministic policy gradient algorithms.

It relies on the deterministic policy gradient theorem, introduced by Silver et al. (2014) in the **[Deterministic Policy Gradient Algorithms](https://proceedings.mlr.press/v32/silver14.html)** paper.
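The idea of approximating greediness by gradient ascent on $J_{Q}(\pi_w)$ can be illustrated with a minimal sketch. Assume a fixed, hypothetical critic $Q(s,a) = -(a - k^\top s)^2$ with scalar actions, a linear policy $\pi_w(s) = w^\top s$, and a standard normal $\rho_0$ (all three are assumptions for illustration). Since $Q$ is held fixed here, the plain chain rule applies, and a Monte Carlo estimate of $\nabla_w J_Q(\pi_w)$ drives $w$ towards the greedy solution $k$.

```python
import numpy as np

rng = np.random.default_rng(0)
k = np.array([1.0, -2.0, 0.5])        # hypothetical: the Q-greedy action is a = k . s

def dQ_da(s, a):
    # derivative of the fixed critic Q(s, a) = -(a - k . s)^2 with respect to a
    return -2.0 * (a - s @ k)

w = np.zeros(3)                        # parameters of the linear policy pi_w(s) = w . s
alpha = 0.1
for step in range(200):
    s = rng.standard_normal((256, 3))  # states drawn from rho_0 = N(0, I)
    a = s @ w                          # pi_w(s) for the whole batch
    # chain rule: dQ/da(s, pi_w(s)) * d pi_w(s)/dw, with d pi_w(s)/dw = s
    grad = np.mean(dQ_da(s, a)[:, None] * s, axis=0)
    w = w + alpha * grad               # gradient *ascent* on J_Q(pi_w)
print(np.round(w, 3))
```

With a critic that changes along with the policy, the chain rule alone is no longer enough, which is what the theorem addresses.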

Let us write $J(w) = J(\pi_w)$ for a parametric policy $\pi_w$.  
Also write $\rho^{\pi_w}(s) = \sum\limits_{t = 0}^\infty \gamma^t p(S_t=s|\rho_0,\pi_w)$, for all $s \in S$, for the discounted state occupancy measure of $\pi_w$ given $\rho_0$.
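For a finite MDP, this occupancy measure has a closed form: if $P_\pi[s,s'] = p(s'|s,\pi(s))$, then $(\rho^{\pi})^\top = \rho_0^\top \sum_t \gamma^t P_\pi^t = \rho_0^\top (I - \gamma P_\pi)^{-1}$. The sketch below, on a hypothetical 3-state chain, checks this against a truncated version of the definition. Note that $\rho^\pi$ is not a probability distribution: its total mass is $1/(1-\gamma)$.

```python
import numpy as np

gamma = 0.9
# Hypothetical 3-state chain: P_pi[s, s'] = p(s' | s, pi(s)) under a fixed policy.
P_pi = np.array([[0.5, 0.5, 0.0],
                 [0.0, 0.5, 0.5],
                 [1.0, 0.0, 0.0]])
rho0 = np.array([1.0, 0.0, 0.0])

# Closed form: rho^T = rho0^T (I - gamma P_pi)^{-1}, i.e. solve (I - gamma P_pi)^T rho = rho0
rho_closed = np.linalg.solve((np.eye(3) - gamma * P_pi).T, rho0)

# Truncated definition: sum_t gamma^t p(S_t = s | rho0, pi)
rho_sum, p_t = np.zeros(3), rho0.copy()
for t in range(500):
    rho_sum += gamma**t * p_t
    p_t = p_t @ P_pi                   # distribution of S_{t+1}
print(rho_closed, rho_closed.sum())    # total mass 1 / (1 - gamma)
```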

<div class="alert alert-success">

**Deterministic policy gradient theorem**  
Consider a deterministic policy $\pi_w: S\rightarrow A$ interacting with an MDP $(S, A, p, r)$ with a starting state distribution $\rho_0$.  
We will drop the $w$ subscripts wherever unambiguous, to improve readability.    
If $p(s'|s,a)$, $\nabla_a p(s'|s,a)$, $r(s,a)$, $\nabla_a r(s,a)$, $\rho_0(s)$, $\pi_w(s)$, and $\nabla_w\pi_w(s)$ all exist and are continuous in $(s,a,s')$, then 
$$\nabla_w J(w) = \mathbb{E}_{s\sim \rho^{\pi}} \left[ \nabla_a Q^{\pi}(s,a)|_{a=\pi(s)} \cdot \nabla_w \pi_w(s) \right].$$
</div>

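Note that $\nabla_a Q^{\pi}(s,a)$ is a row vector of size $\dim A$ and $\nabla_w \pi_w(s)$ is a Jacobian matrix of size $\dim A \times \dim w$, whose $i$-th row is the gradient of the $i$-th action component with respect to the parameters $w$; their product is thus a vector of size $\dim w$.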

Rewriting this theorem with partial derivatives, we have:
$$\nabla_w J(w) = \mathbb{E}_{s\sim \rho^{\pi}} \left[ \frac{\partial Q^{\pi}(s,a)}{\partial a}(s,\pi(s)) \cdot \frac{\partial \pi(s)}{\partial w}(s) \right].$$

This theorem looks like the chain rule applied to $J(\pi_w)$ but it's actually a bit more than that.

Let us write $J(w)=J(\pi_w)$ again.
\begin{align*}
J(w) &= \mathbb{E}_{s\sim \rho_0} [V^{\pi_w}(s)]\\
 &= \mathbb{E}_{s\sim \rho_0} [Q^{\pi_w}(s,\pi_w(s))]
\end{align*}

Let us take the gradient of this term with respect to $w$.
$$\nabla_w J(w) = \mathbb{E}_{s\sim \rho_0} \left[ \frac{\partial Q^{\pi_w}(s,\pi_w(s))}{\partial w} \right].$$

If we had had a fixed $Q$ instead of $Q^{\pi_w}$ in the expression above, we could have used the chain rule and we could have written:
$$\frac{\partial Q(s,\pi_w(s))}{\partial w}  = \frac{\partial Q(s,a)}{\partial a}(s,\pi_w(s)) \frac{\partial \pi_w(s)}{\partial w}(s).$$
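This fixed-$Q$ chain rule can be checked numerically. The sketch below assumes a hypothetical smooth critic `Q` and a policy $\pi_w(s) = \tanh(Ws)$ with $w$ the row-major flattening of $W$. Here `policy_jacobian` returns a $\dim A \times \dim w$ matrix whose $i$-th row is the gradient of action component $i$ with respect to $w$, and the product $\frac{\partial Q}{\partial a}\,\frac{\partial \pi_w}{\partial w}$ is compared with central finite differences.

```python
import numpy as np

rng = np.random.default_rng(1)
dim_s, dim_a = 4, 2
s = rng.standard_normal(dim_s)
B = rng.standard_normal((dim_a, dim_s))
c = np.array([0.3, -0.2])

def Q(s, a):
    # hypothetical fixed critic, smooth in a
    return -np.sum((a - c) ** 2) + a @ (B @ s)

def dQ_da(s, a):
    return -2.0 * (a - c) + B @ s              # vector of size dim_a

def policy(w, s):
    W = w.reshape(dim_a, dim_s)                # row-major: W[i, j] = w[i * dim_s + j]
    return np.tanh(W @ s)

def policy_jacobian(w, s):
    # (dim_a x dim_w) matrix: row i = gradient of action component i w.r.t. w
    a = policy(w, s)
    J = np.zeros((dim_a, dim_a * dim_s))
    for i in range(dim_a):
        J[i, i * dim_s:(i + 1) * dim_s] = (1.0 - a[i] ** 2) * s
    return J

w = 0.5 * rng.standard_normal(dim_a * dim_s)
chain_rule = dQ_da(s, policy(w, s)) @ policy_jacobian(w, s)   # shape (dim_w,)

eps = 1e-6
finite_diff = np.array([
    (Q(s, policy(w + eps * e, s)) - Q(s, policy(w - eps * e, s))) / (2 * eps)
    for e in np.eye(w.size)
])
print(chain_rule.shape, np.max(np.abs(chain_rule - finite_diff)))
```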

But we don't have a fixed $Q$, and as soon as $w$ changes infinitesimally, $Q^{\pi_w}$ changes too, so this chain rule is not so straightforward.

The full proof of the deterministic policy gradient theorem is in the appendix of the aforementioned paper, and we will not reproduce it here. In short, the derivation involves unfolding the sum over time steps of the reward random variables in $Q^\pi$, which leads to the introduction of $\rho^\pi$.  

Instead, we will try to provide intuition as to why this gradient ascent direction makes sense in an AVI context.
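Although we skip the proof, the theorem can be sanity-checked numerically on a toy problem where everything is analytic. Assume, purely for illustration, a deterministic scalar linear system $s_{t+1} = s_t + a_t$ with reward $r(s,a) = -(s^2 + a^2)$, the policy $\pi_w(s) = -ws$, and $\mathbb{E}_{\rho_0}[s_0^2] = 1$. Then $V^{\pi_w}(s) = -P(w)s^2$ with $P(w) = \frac{1+w^2}{1-\gamma(1-w)^2}$ (for $\gamma(1-w)^2 < 1$), and $Q^{\pi_w}(s,a) = -(s^2+a^2) - \gamma P(w)(s+a)^2$. The sketch compares a finite-difference $\nabla_w J(w)$ with the DPG expression, and with a naive variant that averages the same integrand over $\rho_0$ instead of $\rho^{\pi_w}$.

```python
import numpy as np

GAMMA = 0.9
E_S0_SQ = 1.0  # assumed E[s_0^2] under rho_0

def value_coeff(w):
    """P(w) such that V^pi(s) = -P(w) s^2 for pi_w(s) = -w s."""
    return (1.0 + w**2) / (1.0 - GAMMA * (1.0 - w) ** 2)

def J(w):
    return -value_coeff(w) * E_S0_SQ

def dQ_da(s, a, w):
    # Q^pi(s, a) = -(s^2 + a^2) - gamma P(w) (s + a)^2, differentiated in a only
    return -2.0 * a - 2.0 * GAMMA * value_coeff(w) * (s + a)

def integrand_coeff(w):
    # dQ/da(s, pi(s)) * dpi/dw(s) with dpi/dw(s) = -s; it equals this coefficient times s^2
    s = 1.0
    return dQ_da(s, -w * s, w) * (-s)

def dpg_gradient(w):
    # E_{rho^pi}[s^2] = E[s_0^2] * sum_t gamma^t (1 - w)^(2t)
    occupancy_s_sq = E_S0_SQ / (1.0 - GAMMA * (1.0 - w) ** 2)
    return integrand_coeff(w) * occupancy_s_sq

def naive_gradient(w):
    # same integrand, but averaged over rho_0 instead of rho^pi
    return integrand_coeff(w) * E_S0_SQ

w, eps = 0.3, 1e-6
finite_diff = (J(w + eps) - J(w - eps)) / (2 * eps)
print(finite_diff, dpg_gradient(w), naive_gradient(w))
```

In this example, the naive estimate has the right sign but is off by the factor $1/(1-\gamma(1-w)^2)$, which comes precisely from replacing $\rho_0$ by $\rho^{\pi_w}$.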

## Connecting the DPG and AVI

**The intuition.**  
Suppose we have a current $Q_n$ in the AVI sequence, and we are searching for $\pi_n$, with the intention of defining $Q_{n+1}$. Then, for each visited state $s$, $\pi_n(s)$ should be a $Q_n$-greedy action, $\pi_n(s) \in \arg\max_{a} Q_n(s,a)$. So, in each state, given a current $\pi_w$, the policy parameters should move in the direction of $\nabla_w Q_n(s,\pi_w(s))$, which increases $Q_n(s,\pi_w(s))$. Since this should hold in all states visited by $\pi_w$, averaging these gradients according to $\rho^{\pi_w}$ makes sense.

**Understanding the deterministic policy gradient theorem: where $\infty$-horizon improvement and one-step greediness coincide.**  
Let $\pi_w$ be the current policy in the AVI process and suppose there is no approximation error, so that $Q_n = Q^{\pi_w}$. We want to find a policy that dominates $\pi_w$. The deterministic policy gradient theorem tells us that the direction of overall ($\infty$-horizon) improvement $\nabla_w J(w)$ of $\pi_w$ can be computed from the gradients of the one-step-lookahead quantity $Q_n(s,\pi_w(s))$. In other words, at $w$ specifically, the deterministic policy gradient $\nabla_w J(w)$ coincides with $\mathbb{E}_{s\sim \rho^{\pi_w}} [\nabla_w Q_n(s,\pi_w(s))]$, where $Q_n$ is held fixed when differentiating. But this is only true at $w$, because of the tight coupling between $\rho^{\pi_w}$, $Q_n$, and $\pi_w$. Once the gradient step is taken, both $\rho^{\pi_w}$ and $Q^{\pi_w}$ change and need to be updated before performing further gradient steps: the gradient estimate could use $Q_n$ only because $Q_n$ was an estimate of $Q^{\pi_w}$. 

**Nuts and bolts of DPG.**   
Yet, provided we trust $Q_n$ to be close enough to $Q^{\pi_w}$ and provided we can draw from a distribution close enough to $\rho^{\pi_w}$, this enables building a Monte Carlo estimator of $\nabla_w J(w)$: draw states according to $\rho^{\pi_w}$, average the corresponding values of $Q_n(s,\pi_w(s))$, and take the gradient with respect to $w$.
In practice, we often have a replay buffer which is representative not only of $\rho^{\pi_w}$, but rather of a mix of successive $\rho^{\pi_w}$. Nonetheless, this replay buffer will enable defining a loss function on the policy. Consequently, given a replay buffer distribution $\rho_n$ at iteration $n$, we can redefine the AVI sequence as:

$$L^\pi_n(w) = \mathbb{E}_{s \sim \rho_n} \left[ Q_n(s,\pi_w(s)) \right]$$
$$w_n = w_{n-1} +\alpha \nabla_w L^\pi_n(w_{n-1})$$
$$\pi_n = \pi_{w_n}$$
$$L^Q_n(\theta) = \frac{1}{2} \mathbb{E}_{(s,a) \sim \rho_n}\left[ \left( Q(s,a;\theta) - G^{\pi_n}_1(s,a,Q_n) \right)^2 \right]$$
$$\theta_{n+1} \in \arg\min_\theta L^Q_n(\theta)$$
$$Q_{n+1}(s,a) = Q(s,a;\theta_{n+1})$$

Recall that the loss $L^Q_n$ on the Q-function's parameters is actually an off-policy loss, requiring only that $\rho_n$ covers states and actions which are likely to be encountered by $\pi_n$.

In turn, this provides us with a direct way to implement a general **[Deep Deterministic Policy Gradient](https://arxiv.org/abs/1509.02971)** or DDPG (Lillicrap et al., 2016) algorithm. 

## Deep Deterministic Policy Gradient

Recall from the previous chapter that a single gradient step is a poor substitute for actually solving $\arg\min_\theta L_n(\theta)$. To compensate for this, we previously introduced two mechanisms: either take several gradient steps with respect to a given, frozen *target* network (we did it for $Q$, but the same idea applies to $\pi$), or let the target network track the current one through a moving average. DDPG implements the latter. Let $\theta'$ and $w'$ denote the target networks' parameters; then $G^{\pi_n}_1(s,a,Q_n)$ in the Q-function loss becomes $G^{\pi_{w'}}_1(s,a,Q_{\theta'})$. 

So, after each drawn mini-batch:
$$w \leftarrow w + \alpha_w \nabla_w L^\pi_n(w),$$
$$\theta \leftarrow \theta - \alpha_\theta \nabla_\theta L^Q_n(\theta).$$
And the target networks are updated according to:
$$w' \leftarrow \tau w + (1-\tau) w',$$
$$\theta' \leftarrow \tau \theta + (1-\tau) \theta'.$$
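A minimal NumPy sketch of this soft (Polyak) update, with parameters stored as plain arrays rather than network weights. Freezing the online parameters (an assumption for illustration only) shows that the target closes the gap geometrically, as $(1-\tau)^k$ after $k$ updates, i.e. with a time constant of roughly $1/\tau$ updates.

```python
import numpy as np

def soft_update(target, online, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target (returns a new array)."""
    return tau * online + (1.0 - tau) * target

tau = 0.005
theta = np.array([1.0, -2.0])        # online parameters, frozen here for illustration
theta_target = np.zeros(2)
for _ in range(1000):
    theta_target = soft_update(theta_target, theta, tau)
# remaining gap relative to the online parameters: (1 - tau)^1000
print(theta_target, (1 - tau) ** 1000)
```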

In order to use libraries like `pytorch` which perform stochastic gradient descent (and not ascent), we redefine the policy update loss as $-Q(s,\pi(s))$.

Finally, we need to define the behavior policy. The policy update requires states drawn according to $\rho^\pi$, but the Q-function update requires trying a variety of actions in the encountered states. Since actions are continuous, we use an exploration policy that adds noise to the deterministic policy's action. Historically, DDPG used time-correlated Ornstein-Uhlenbeck noise, but this later turned out to be unnecessary in practice: simple uncorrelated Gaussian noise generally suffices.
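Both kinds of exploration noise can be sketched in a few lines of NumPy. The action bounds below are illustrative, `sigma=0.1` mirrors the `exploration_noise` value used in the training cells of this chapter, and the OU parameters are the commonly quoted DDPG values $\theta=0.15$, $\sigma=0.2$; the step `dt=1e-2` and the Euler discretization are assumptions. The lag-1 autocorrelation shows the difference: close to 1 for OU noise, close to 0 for i.i.d. Gaussian noise.

```python
import numpy as np

rng = np.random.default_rng(0)
low, high = np.array([-3.0]), np.array([3.0])     # illustrative action bounds
action_scale = (high - low) / 2.0

def gaussian_exploration(a, sigma=0.1):
    """a + N(0, (sigma * action_scale)^2), clipped back into the action box."""
    noisy = a + rng.normal(0.0, sigma * action_scale)
    return np.clip(noisy, low, high)

def ou_noise(n_steps, theta=0.15, sigma=0.2, dt=1e-2, x0=0.0):
    """Ornstein-Uhlenbeck process dx = -theta x dt + sigma dW, Euler-discretized."""
    x = np.empty(n_steps)
    x[0] = x0
    for t in range(1, n_steps):
        x[t] = x[t-1] - theta * x[t-1] * dt + sigma * np.sqrt(dt) * rng.standard_normal()
    return x

def lag1(x):
    """Correlation between consecutive samples."""
    return np.corrcoef(x[:-1], x[1:])[0, 1]

gauss = rng.normal(0.0, 0.2, size=10_000)
ou = ou_noise(10_000)
print(lag1(gauss), lag1(ou))       # near 0 vs. close to 1
```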

Overall, the pseudo-code of DDPG is:

```
Initialize theta=theta' and w=w'
Initialize replay buffer RB
s = env.init()
loop:
   Pick a = pi(s,w) + noise
   s',r,done = env.step(a)
   RB.append(s,a,r,s',done)
   minibatch = RB.sample()
   target_value(s,a) = r + gamma (1-done) Q(s',pi(s',w'),theta')
   critic_loss(theta) = MSE(Q(s,a,theta), target_value(s,a))
   critic_loss.gradient_descent_step()
   actor_objective(w) = mean(Q(s,pi(s,w),theta))
   actor_objective.gradient_ascent_step()
   theta' = tau theta + (1-tau) theta'
   w' = tau w + (1-tau) w'
   s = s' (or s = env.init() if the episode ended)
```
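The structure of this loop can be illustrated on a deliberately tiny problem. The sketch below is an assumption-laden toy, not DDPG itself: it is a one-step task ($\gamma = 0$, so the target value is just $r$ and no target networks are needed), the critic is linear in hand-picked features, and the critic step solves the minibatch regression exactly with least squares instead of taking an SGD step. The actor step is the deterministic policy gradient $\frac{\partial Q}{\partial a}(s,\pi_w(s))\,\frac{\partial \pi_w(s)}{\partial w}$ averaged over the minibatch.

```python
import numpy as np

rng = np.random.default_rng(0)

def reward(s, a):
    # hypothetical one-step task: the best action is a = 0.5 * s
    return -(a - 0.5 * s) ** 2

def features(s, a):
    # critic Q(s, a; theta) = theta . [s^2, a^2, s*a]
    return np.stack([s**2, a**2, s * a], axis=-1)

theta = np.zeros(3)       # critic parameters
w = 0.0                   # actor: pi_w(s) = w * s
buffer_s, buffer_a, buffer_r = [], [], []

for step in range(500):
    # act with Gaussian exploration noise and store the transition
    s = rng.uniform(-1.0, 1.0)
    a = w * s + rng.normal(0.0, 0.3)
    buffer_s.append(s); buffer_a.append(a); buffer_r.append(reward(s, a))
    if len(buffer_s) < 64:
        continue
    idx = rng.integers(len(buffer_s), size=64)
    S, A, R = (np.asarray(b)[idx] for b in (buffer_s, buffer_a, buffer_r))
    # critic step: with gamma = 0 the target value is r; fitted exactly by least squares
    theta, *_ = np.linalg.lstsq(features(S, A), R, rcond=None)
    # actor step: gradient ascent on mean Q(s, pi_w(s)), with dpi_w(s)/dw = s
    A_pi = w * S
    dQ_da = 2.0 * theta[1] * A_pi + theta[2] * S
    w += 0.1 * np.mean(dQ_da * S)
print(np.round(theta, 3), round(w, 3))
```

Since the critic can represent the reward exactly, it recovers $\theta = (-0.25, -1, 1)$ and the actor converges to the optimal gain $w = 0.5$.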

<div class="alert alert-warning">
    
**Exercise:**  
Declare a class for a Q-function neural network. Use two hidden layers with 256 neurons each.
</div>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Warning, this class only works for vector-shaped inputs (for images, it requires adjustments).
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)
    def forward(self, x, a):
        x = torch.cat([x, a], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

<div class="alert alert-warning">
    
**Exercise:**  
Declare a class for a policy neural network. Use two hidden layers with 256 neurons each.  
\[Optional\] Include an option for inputting upper and lower bounds to enable action spaces not centered on zero, and store an `action_scale` and an `action_bias` parameter.
</div>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Warning, this class only works for vector-shaped inputs (for images, it requires adjustments).
class policyNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mu = nn.Linear(256, action_dim)
        # action rescaling
        self.register_buffer(
            "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc_mu(x))
        return x * self.action_scale + self.action_bias
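The `tanh` squashing followed by `action_scale` and `action_bias` maps any network output into $[\text{low}, \text{high}]$. A minimal NumPy sketch of that mapping, with hypothetical asymmetric bounds:

```python
import numpy as np

low = np.array([-1.0, 0.0])          # hypothetical asymmetric action bounds
high = np.array([1.0, 4.0])
action_scale = (high - low) / 2.0    # same formulas as the buffers registered above
action_bias = (high + low) / 2.0

def rescale(pre_activation):
    """Map an unbounded network output to the box [low, high] through tanh."""
    return np.tanh(pre_activation) * action_scale + action_bias

z = np.array([[0.0, 0.0], [-50.0, 50.0], [0.3, -2.0]])
print(rescale(z))   # rows: box centre, (numerically) a corner, an interior point
```

Noise added after this mapping (as in the agent's exploration step) can push actions outside the box, which is why the agent clips them.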

<div class="alert alert-warning">
    
**Exercise:**  
Write a class that implements the DDPG pseudo-code.
</div>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from solutions.replay_buffer2 import ReplayBuffer
from tqdm import trange

class ddpg_agent:
    def __init__(self, config, value_network, policy_network):
        # networks
        device = "cuda" if next(value_network.parameters()).is_cuda else "cpu"
        self.scalar_dtype = next(value_network.parameters()).dtype
        self.Qfunction = value_network
        self.Q_target = deepcopy(self.Qfunction).to(device)
        self.pi = policy_network
        self.pi_target = deepcopy(self.pi).to(device)
        # parameters
        self.gamma = config['gamma'] if 'gamma' in config.keys() else 0.95
        buffer_size = config['buffer_size'] if 'buffer_size' in config.keys() else int(1e5)
        self.memory = ReplayBuffer(buffer_size, device)
        self.batch_size = config['batch_size'] if 'batch_size' in config.keys() else 100
        lr = config['learning_rate'] if 'learning_rate' in config.keys() else 0.001
        self.Q_optimizer = torch.optim.Adam(list(self.Qfunction.parameters()), lr=lr)
        self.pi_optimizer = torch.optim.Adam(list(self.pi.parameters()), lr=lr)
        self.tau = config['tau'] if 'tau' in config.keys() else 0.005
        self.exploration_noise = config['exploration_noise'] if 'exploration_noise' in config.keys() else 0.005
        self.delay_learning = config['delay_learning'] if 'delay_learning' in config.keys() else 1e4
        self.tqdm_disable = config['tqdm_disable'] if 'tqdm_disable' in config.keys() else True
        self.disable_episode_report = config['disable_episode_report'] if 'disable_episode_report' in config.keys() else True

    def hello(self):
        print("hello world")
    def train(self, env, max_steps):
        x,_ = env.reset()
        episode = 0
        episode_cum_reward = 0
        episode_return = []

        for time_step in trange(int(max_steps), disable=self.tqdm_disable):
            # step (policy + noise), add to rb
            if time_step > self.delay_learning:
                with torch.no_grad():
                    a = self.pi(torch.tensor(x, dtype=self.scalar_dtype, device=next(self.pi.parameters()).device))
                    a += torch.normal(0, self.pi.action_scale * self.exploration_noise)
                    a = a.cpu().numpy().clip(env.action_space.low, env.action_space.high)
            else:
                a = env.action_space.sample()
            y, r, done, trunc, _ = env.step(a)
            self.memory.append(x,a,r,y,done)
            episode_cum_reward += r
            
            # gradient step
            if time_step > self.delay_learning:
                X, A, R, Y, D = self.memory.sample(self.batch_size)
                ## Qfunction update
                with torch.no_grad():
                    next_actions = self.pi_target(Y)
                    QYA = self.Q_target(Y, next_actions)
                    #target = torch.addcmul(R, 1-D, QY, value=self.gamma)
                    target = R + self.gamma * (1-D) * QYA.view(-1)
                QXA = self.Qfunction(X, A).view(-1)
                Qloss = F.mse_loss(QXA,target)
                self.Q_optimizer.zero_grad()
                Qloss.backward()
                self.Q_optimizer.step()
                ## policy update
                pi_loss = -self.Qfunction(X, self.pi(X)).mean()
                self.pi_optimizer.zero_grad()
                pi_loss.backward()
                self.pi_optimizer.step()
                
                # target networks update
                for param, target_param in zip(self.pi.parameters(), self.pi_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                for param, target_param in zip(self.Qfunction.parameters(), self.Q_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                
            # if done, print episode info
            if done or trunc:
                x, _ = env.reset()
                episode_return.append(episode_cum_reward)
                if not self.disable_episode_report:
                    print("Episode ", '{:2d}'.format(episode), 
                          ", buffer size ", '{:4d}'.format(len(self.memory)), 
                          ", episode return ", '{:4.1f}'.format(episode_cum_reward), 
                          sep='')
                episode += 1
                episode_cum_reward = 0
            else:
                x=y
        return episode_return
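The agent imports `ReplayBuffer` from `solutions.replay_buffer2`, whose code is not shown here. For readers without that module, here is a hypothetical NumPy stand-in with the same `append(x, a, r, y, done)` / `sample(batch_size)` / `len()` call pattern. Unlike the course version, whose constructor also takes a `device` (and which presumably returns torch tensors on that device), this sketch returns NumPy arrays, so it would need a conversion step such as `torch.as_tensor(..., device=device)` before being plugged into `ddpg_agent`.

```python
import numpy as np

class NumpyReplayBuffer:
    """Hypothetical stand-in for solutions.replay_buffer2.ReplayBuffer.
    Fixed capacity, overwrites the oldest transitions, samples uniformly with replacement."""
    def __init__(self, capacity, seed=None):
        self.capacity = int(capacity)
        self.data = []
        self.index = 0                       # slot to overwrite once the buffer is full
        self.rng = np.random.default_rng(seed)

    def append(self, s, a, r, s_next, done):
        transition = (s, a, r, s_next, float(done))
        if len(self.data) < self.capacity:
            self.data.append(transition)
        else:
            self.data[self.index] = transition
        self.index = (self.index + 1) % self.capacity

    def sample(self, batch_size):
        idx = self.rng.integers(len(self.data), size=batch_size)
        batch = [self.data[i] for i in idx]
        # one array per field: states, actions, rewards, next states, done flags
        return tuple(np.asarray(field, dtype=np.float32) for field in zip(*batch))

    def __len__(self):
        return len(self.data)
```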

<div class="alert alert-warning">
    
**Exercise:**  
Test your code on Gymnasium's [MuJoCo inverted pendulum](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/) or [bipedal walker](https://gymnasium.farama.org/environments/box2d/bipedal_walker/) environments.  
Caveat: training might be long!
</div>

In [None]:
import gymnasium as gym
import torch

#env = gym.make("BipedalWalker-v3", render_mode="rgb_array")
env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
config = {'gamma': .99,
          'buffer_size': 1e6,
          'learning_rate': 3e-4,
          'batch_size': 256,
          'tau': 0.005,
          'delay_learning': 1e4,
          'exploration_noise': .1,
          'tqdm_disable': False
         }

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Qfunction = QNetwork(env).to(device)
policy = policyNetwork(env).to(device)

agent = ddpg_agent(config, Qfunction, policy)
episode_returns = agent.train(env, 2e4)

In [None]:
import matplotlib.pyplot as plt
plt.plot(episode_returns);

In [None]:
import gymnasium as gym

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array_list")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array_list")
scalar_dtype = next(policy.parameters()).dtype
s,_ = test_env.reset()
with torch.no_grad():
    for t in range(1000):
        a = policy(torch.tensor(s, dtype=scalar_dtype, device=device)).cpu().numpy()
        s2,r,d,trunc,_ = test_env.step(a)
        s = s2
        if d or trunc:
            break

save_video(test_env.render(), "videos", fps=test_env.metadata["render_fps"], name_prefix="ddpg_policy")

In [None]:
from IPython.display import Video
Video("videos/ddpg_policy-episode-0.mp4")

In [None]:
import gymnasium as gym
import numpy as np
from tqdm import trange

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
scalar_dtype = next(policy.parameters()).dtype
returns = []
for _ in trange(20):
    cumulated_reward = 0
    s,_ = test_env.reset()
    with torch.no_grad():
        for t in range(1000):
            a = policy(torch.tensor(s, dtype=scalar_dtype, device=device)).cpu().numpy()
            s2,r,d,trunc,_ = test_env.step(a)
            cumulated_reward += r
            s = s2
            if d or trunc:
                break
    returns.append(cumulated_reward)

print(np.mean(returns))
print(np.std(returns))

<div class="alert alert-warning">
    
**Exercise:**  
Modify the `QNetwork` and `policyNetwork` classes above to take images as inputs and run DDPG on Gymnasium's [car racing environment](https://gymnasium.farama.org/environments/box2d/car_racing/).
</div>

In [None]:
#%load solutions/no_solution_yet.py

<div class="alert alert-warning">
    
**Exercise (discussion):**  
Is DDPG off-policy?
</div>

# TD3: improving value function approximation in DDPG

Just like DQN and other VI-based algorithms, DDPG is prone to overestimating the Q-function. In the discrete-action setting, this was mitigated by [Double Q-learning](https://arxiv.org/abs/1509.06461) and its extension [Double DQN](https://ojs.aaai.org/index.php/AAAI/article/view/10295), which select the greedy action with the current (online) value network but evaluate it with the target network.
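A quick NumPy experiment (not taken from the cited papers) shows where the bias comes from. Assume all actions have true value 0 and we have two independent noisy estimates `q1`, `q2` with unit Gaussian noise. Taking the max of one noisy estimate is biased upwards (roughly 1.5 for 10 actions), selecting with `q1` but evaluating with `q2` removes this bias, and taking the minimum of both estimates at the selected action, in the spirit of TD3's clipped target, is pessimistic.

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials, n_actions = 10_000, 10
true_q = np.zeros(n_actions)            # every action is equally good: max_a Q(a) = 0

# two independent noisy estimates of the same Q-values
q1 = true_q + rng.standard_normal((n_trials, n_actions))
q2 = true_q + rng.standard_normal((n_trials, n_actions))

rows = np.arange(n_trials)
single = q1.max(axis=1)                      # max over a single noisy estimate
a_star = q1.argmax(axis=1)                   # select with q1 ...
double = q2[rows, a_star]                    # ... evaluate with q2
clipped = np.minimum(q1, q2)[rows, a_star]   # pessimistic min of both estimates
print(single.mean(), double.mean(), clipped.mean())
```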

Building on this idea, Fujimoto et al. introduced [Twin Delayed DDPG](https://arxiv.org/pdf/1802.09477.pdf) also known as TD3, as an improvement on DDPG which:
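- uses twin critics, i.e. two independently trained Q-networks (in the spirit of Double Q-learning),
- defines a new update target called *clipped double Q-learning*,
- delays the policy updates to let $Q_n$ fit $Q^\pi$ better: the critics are updated at every step, while the actor and the target networks are updated only once every two critic updates (`policy_update_freq` in the code below).

Together with the DDPG backbone, the twin critics and the delayed updates give the algorithm its name.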

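The clipped double Q-learning target uses the target actor $\pi_{w'}$ and the two target critics $\theta'_1, \theta'_2$:
$$y = r + \gamma \min_{i\in\{1,2\}} Q(s', \pi_{w'}(s'); \theta'_i)$$

Additionally, TD3 improves generalization across actions by adding a clipped noise term to the target action (*target policy smoothing*). This can be seen as a way to promote smoothness of the value function with respect to continuous actions (a very naive form of data augmentation). Hence, the target value above becomes:
$$y = r + \gamma \min_{i\in\{1,2\}} Q(s', \pi_{w'}(s') + \epsilon; \theta'_i)$$
where $\epsilon = \mathrm{clip}(\mathcal{N}(0,\sigma^2), -c, c)$ is a clipped, centered Gaussian noise (in the `td3_agent` code below, $\sigma$ and $c$ correspond to `action_noise_scale` and `action_noise_clip`, relative to the action scale, and the perturbed action is clipped back into the action bounds).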
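The target computation can be sketched in NumPy on a batch of transitions. The target actor and critics below are hypothetical stand-ins (simple closed-form functions), the bounds are illustrative, and the noise parameters mirror the `td3_agent` defaults (`action_noise_scale=0.2`, `action_noise_clip=0.5`, noise expressed relative to `action_scale`). The `(1 - done)` factor drops the bootstrap term on terminal transitions.

```python
import numpy as np

rng = np.random.default_rng(0)
low, high = np.array([-3.0]), np.array([3.0])     # illustrative action bounds
action_scale = (high - low) / 2.0
gamma, noise_scale, noise_clip = 0.99, 0.2, 0.5

def pi_target(s):                 # hypothetical target actor
    return np.clip(2.0 * s[:, :1], low, high)

def q1_target(s, a):              # hypothetical target critics
    return -(a[:, 0] - s[:, 0]) ** 2

def q2_target(s, a):
    return -(a[:, 0] - s[:, 0]) ** 2 - 0.1 * a[:, 0]

def td3_target(r, s_next, done):
    a_next = pi_target(s_next)
    # target policy smoothing: clipped Gaussian noise, scaled to the action range
    eps = np.clip(noise_scale * rng.standard_normal(a_next.shape), -noise_clip, noise_clip)
    a_next = np.clip(a_next + eps * action_scale, low, high)
    # clipped double Q-learning: pessimistic minimum of the two target critics
    q_min = np.minimum(q1_target(s_next, a_next), q2_target(s_next, a_next))
    return r + gamma * (1.0 - done) * q_min

batch = 5
r = rng.standard_normal(batch)
s_next = rng.standard_normal((batch, 2))
done = np.array([0, 0, 1, 0, 1], dtype=float)
print(td3_target(r, s_next, done))
```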

<div class="alert alert-warning">
    
**Exercise (discussion):**  
Does this make TD3 close to SARSA?
</div>

Overall, the pseudo-code of TD3 is:

```
Initialize theta1=theta1', theta2=theta2' and w=w'
Initialize replay buffer RB
s = env.init()
loop:
   Pick a = pi(s,w) + noise
   s',r,done = env.step(a)
   RB.append(s,a,r,s',done)
   minibatch = RB.sample()
   next_action = clip(pi(s',w') + clipped_noise, a_low, a_high)
   target_value(s,a) = r + gamma (1-done) min[ Q(s',next_action,theta1'), Q(s',next_action,theta2') ]
   critic1_loss(theta1) = MSE(Q(s,a,theta1), target_value(s,a))
   critic2_loss(theta2) = MSE(Q(s,a,theta2), target_value(s,a))
   critic_loss = critic1_loss + critic2_loss
   critic_loss.gradient_descent_step()
   only one step out of two:
      actor_objective(w) = mean(Q(s,pi(s,w),theta1))
      actor_objective.gradient_ascent_step()
      target networks update (theta1', theta2', w')
   s = s' (or s = env.init() if the episode ended)
```

<div class="alert alert-warning">
    
**Exercise:**  
Implement the pseudo-code above.
</div>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from solutions.replay_buffer2 import ReplayBuffer
from tqdm import trange

class td3_agent:
    def __init__(self, config, value_network1, value_network2, policy_network):
        # networks
        self.device = "cuda" if next(value_network1.parameters()).is_cuda else "cpu"
        self.scalar_dtype = next(value_network1.parameters()).dtype
        self.Qfunction1 = value_network1
        self.Qfunction2 = value_network2
        self.Q1_target = deepcopy(self.Qfunction1).to(self.device)
        self.Q2_target = deepcopy(self.Qfunction2).to(self.device)
        self.pi = policy_network
        self.pi_target = deepcopy(self.pi).to(self.device)
        # parameters
        self.gamma = config['gamma'] if 'gamma' in config.keys() else 0.95
        buffer_size = config['buffer_size'] if 'buffer_size' in config.keys() else int(1e5)
        self.memory = ReplayBuffer(buffer_size, self.device)
        self.batch_size = config['batch_size'] if 'batch_size' in config.keys() else 100
        lr = config['learning_rate'] if 'learning_rate' in config.keys() else 0.001
        self.Q_optimizer = torch.optim.Adam(list(self.Qfunction1.parameters()) + list(self.Qfunction2.parameters()), lr=lr)
        self.pi_optimizer = torch.optim.Adam(list(self.pi.parameters()), lr=lr)
        self.tau = config['tau'] if 'tau' in config.keys() else 0.005
        self.exploration_noise = config['exploration_noise'] if 'exploration_noise' in config.keys() else 0.005
        self.delay_learning = config['delay_learning'] if 'delay_learning' in config.keys() else 1e4
        self.action_noise_scale = config['action_noise_scale'] if 'action_noise_scale' in config.keys() else 0.2
        self.action_noise_clip = config['action_noise_clip'] if 'action_noise_clip' in config.keys() else 0.5
        self.policy_update_freq = config['policy_update_freq'] if 'policy_update_freq' in config.keys() else 2
        self.tqdm_disable = config['tqdm_disable'] if 'tqdm_disable' in config.keys() else True
        self.disable_episode_report = config['disable_episode_report'] if 'disable_episode_report' in config.keys() else True

    def hello(self):
        print("hello world")
    def train(self, env, max_steps):
        x,_ = env.reset()
        episode = 0
        episode_cum_reward = 0
        episode_return = []

        for time_step in trange(int(max_steps), disable=self.tqdm_disable):
            # step (policy + noise), add to rb
            if time_step > self.delay_learning:
                with torch.no_grad():
                    a = self.pi(torch.tensor(x, dtype=self.scalar_dtype, device=self.device))
                    a += torch.normal(0, self.pi.action_scale * self.exploration_noise)
                    a = a.cpu().numpy().clip(env.action_space.low, env.action_space.high)
            else:
                a = env.action_space.sample()
            y, r, done, trunc, _ = env.step(a)
            self.memory.append(x,a,r,y,done)
            episode_cum_reward += r
            
            # gradient step
            if time_step > self.delay_learning:
                X, A, R, Y, D = self.memory.sample(self.batch_size)
                ## Qfunction update
                with torch.no_grad():
                    # next action with noise
                    noise = torch.randn_like(A, device=self.device) * self.action_noise_scale
                    clipped_noise = noise.clamp(-self.action_noise_clip, self.action_noise_clip) * self.pi_target.action_scale
                    next_actions = self.pi_target(Y) + clipped_noise
                    next_actions=next_actions.clamp(env.action_space.low[0], env.action_space.high[0])
                    # clipped q-learning target
                    Q1YA = self.Q1_target(Y, next_actions)
                    Q2YA = self.Q2_target(Y, next_actions)
                    min_QYA = torch.min(Q1YA, Q2YA)
                    target = R + self.gamma * (1-D) * min_QYA.view(-1)
                # double q-network update
                Q1XA = self.Qfunction1(X, A).view(-1)
                Q2XA = self.Qfunction2(X, A).view(-1)
                Q1loss = F.mse_loss(Q1XA,target)
                Q2loss = F.mse_loss(Q2XA,target)
                Qloss = Q1loss + Q2loss
                self.Q_optimizer.zero_grad()
                Qloss.backward()
                self.Q_optimizer.step()
                ## policy update
                if time_step % self.policy_update_freq ==0:
                    pi_loss = -self.Qfunction1(X, self.pi(X)).mean()
                    self.pi_optimizer.zero_grad()
                    pi_loss.backward()
                    self.pi_optimizer.step()
                    # target networks update
                    for param, target_param in zip(self.pi.parameters(), self.pi_target.parameters()):
                        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                    for param, target_param in zip(self.Qfunction1.parameters(), self.Q1_target.parameters()):
                        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
                    for param, target_param in zip(self.Qfunction2.parameters(), self.Q2_target.parameters()):
                        target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            
            # if done, print episode info
            if done or trunc:
                x, _ = env.reset()
                episode_return.append(episode_cum_reward)
                if not self.disable_episode_report:
                    print("Episode ", '{:2d}'.format(episode), 
                          ", buffer size ", '{:4d}'.format(len(self.memory)), 
                          ", episode return ", '{:4.1f}'.format(episode_cum_reward), 
                          sep='')
                episode += 1
                episode_cum_reward = 0
            else:
                x=y
        return episode_return

<div class="alert alert-warning">
    
**Exercise:**  
Test your code on Gymnasium's [MuJoCo inverted pendulum](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/) or [bipedal walker](https://gymnasium.farama.org/environments/box2d/bipedal_walker/) environments.  
Caveat: training might be long!
</div>

In [None]:
import gymnasium as gym

# env = gym.make("BipedalWalker-v3", render_mode="rgb_array")
env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
config = {'gamma': .99,
          'buffer_size': 1e6,
          'learning_rate': 3e-4,
          'batch_size': 256,
          'tau': 0.005,
          'delay_learning': 1e4,
          'exploration_noise': .1,
          'action_noise_scale': 0.2,
          'action_noise_clip': 0.5,
          'policy_update_freq': 2,
          'tqdm_disable': False
         }

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Qfunction1 = QNetwork(env).to(device)
Qfunction2 = QNetwork(env).to(device)
policy = policyNetwork(env).to(device)

agent = td3_agent(config, Qfunction1, Qfunction2, policy)
episode_returns = agent.train(env, 2e4)

In [None]:
import matplotlib.pyplot as plt
plt.plot(episode_returns);

In [None]:
import gymnasium as gym
from gymnasium.utils.save_video import save_video

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array_list")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array_list")
scalar_dtype = next(policy.parameters()).dtype
s,_ = test_env.reset()
with torch.no_grad():
    for t in range(1000):
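        # send the state to the policy's device, and bring the action back to the CPU for numpy
        a = policy(torch.tensor(s, dtype=scalar_dtype, device=device)).cpu().numpy()
        s2,r,d,trunc,_ = test_env.step(a)
        s = s2
        if d or trunc: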
            break

save_video(test_env.render(), "videos", fps=test_env.metadata["render_fps"], name_prefix="ddpg_policy")

In [None]:
from IPython.display import Video
Video("videos/ddpg_policy-episode-0.mp4")

In [None]:
import gymnasium as gym
import numpy as np
from tqdm import trange

#test_env = gym.make("BipedalWalker-v3", render_mode="rgb_array")
test_env = gym.make("InvertedPendulum-v4", render_mode="rgb_array")
scalar_dtype = next(policy.parameters()).dtype
returns = []
for _ in trange(50):
    cumulated_reward = 0
    s,_ = test_env.reset()
    with torch.no_grad():
        for t in range(1000):
            a = policy(torch.tensor(s, dtype=scalar_dtype, device=device)).cpu().numpy()
            s2,r,d,trunc,_ = test_env.step(a)
            cumulated_reward += r
            s = s2
            if d or trunc:
                break
    returns.append(cumulated_reward)

print(np.mean(returns))
print(np.std(returns))

# Soft Actor-Critic (SAC): the maximum entropy principle within AVI

We now depart from the deterministic policies we have used so far. We still consider AVI sequences, but now with stochastic policies for continuous actions.

Soft Actor-Critic algorithms correspond to a series of papers:  
[Reinforcement Learning with Deep Energy-Based Policies](https://arxiv.org/abs/1702.08165) (ICML 2017)  
[Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor](https://arxiv.org/abs/1801.01290) (ICML 2018)  
[Soft Actor-Critic Algorithms and Applications](https://arxiv.org/abs/1812.05905) (arXiv only, 2019)  
One should also mention the complementary paper:  
[Soft Actor-Critic for Discrete Action Settings](https://arxiv.org/abs/1910.07207) (arXiv only, 2019)  

The most comprehensive presentation is given in [Soft Actor-Critic Algorithms and Applications](https://arxiv.org/abs/1812.05905), which forms the basis for this section.

## Why maximize the policy's entropy?

As we saw, DDPG (and TD3) is a straightforward way to extend DQN to continuous actions. In a nutshell, it uses an off-policy critic to learn $Q^\pi$ and takes gradient steps in the direction of the deterministic policy gradient. 

In large state and action spaces, DDPG-like algorithms can be inefficient, as they rely on a fixed exploration scheme (noise added to a deterministic policy) and on a delicate interplay between the actor and the critic(s). Optimizing a deterministic policy in an actor-critic architecture with continuous actions is prone to instabilities and errors. For example, suppose that in a given state, two actions are deemed equally good at a certain stage of the AVI process. DDPG's actor will "fall" towards one or the other, mostly depending on the noise in the sampled minibatches. If we are unlucky and the optimal action actually was the other one, it might be difficult to escape the local optimum originally found and converge to the true optimal policy.

Stochastic policies offer an alternative, by spreading probability mass across the different actions in each state. Intuitively, instead of encoding the knowledge of an $\arg\max$ in a given state, they can encode a ranking among actions. But for this, we need them to retain a certain level of entropy, so that all the probability mass does not end up concentrated on a single action. 
Transferring probability mass between actions is not as abrupt as jumping from one deterministic value to the other, so we expect the optimization process on such policies to be smoother. Hence we are interested in stochastic policies that maintain a certain level of action distribution entropy in each state.

Such policies have been shown to improve robustness to model misspecification, to ease gradient-based optimization by smoothing out the loss landscape, and to promote exploration.

The maximum entropy objective for policies is:
$$\pi^* = \arg\max_\pi \sum_t \gamma^t \mathbb{E}_{s_t,a_t} \left[r \left(s_t,a_t\right) + \alpha \mathcal{H}\left(\pi\left(s_t\right)\right) \right]$$
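where $\alpha$ is a temperature parameter trading off rewards against action distribution entropy, and $\mathcal{H}\left(\pi\left(s_t\right)\right) = \mathbb{E}_{a\sim\pi(s_t)}\left[-\log\pi(a|s_t)\right]$ is the entropy of the action distribution in $s_t$.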
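To make the objective concrete, here is a minimal numpy sketch that evaluates the entropy-augmented discounted return along a single trajectory. It assumes, for illustration only, a diagonal Gaussian policy whose per-step standard deviations are known, so that the entropy has the closed form $\frac{1}{2}\sum_i\log(2\pi e\sigma_i^2)$; the rewards, standard deviations and function names are hypothetical.

```python
import numpy as np

def gaussian_entropy(sigma):
    """Entropy of a diagonal Gaussian with standard deviations sigma (last axis = action dimensions)."""
    sigma = np.asarray(sigma, dtype=float)
    return 0.5 * np.sum(np.log(2 * np.pi * np.e * sigma**2), axis=-1)

def max_entropy_return(rewards, sigmas, gamma, alpha):
    """sum_t gamma^t [ r_t + alpha * H(pi(s_t)) ] along one trajectory."""
    rewards = np.asarray(rewards, dtype=float)
    entropies = gaussian_entropy(sigmas)          # one entropy value per time step
    discounts = gamma ** np.arange(len(rewards))
    return np.sum(discounts * (rewards + alpha * entropies))

rewards = [1.0, 0.5, 2.0]
sigmas = [[0.5, 0.5], [0.3, 0.4], [0.1, 0.2]]     # 2-dimensional actions, one row per time step
print(max_entropy_return(rewards, sigmas, gamma=0.99, alpha=0.0))  # plain discounted return (alpha = 0)
print(max_entropy_return(rewards, sigmas, gamma=0.99, alpha=0.2))
```

Note that the differential entropy of a continuous distribution can be negative (here, at the last step where $\sigma$ is small), so the entropy bonus is not always a positive reward: it penalizes nearly deterministic action distributions.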

## Soft Policy Iteration

From the maximum entropy objective function, one can define the soft value function:
$$V^\pi(s) = \sum_t \gamma^t \mathbb{E}_{s_t,a_t} \left[r \left(s_t,a_t\right) + \alpha \mathcal{H}\left(\pi\left(s_t\right)\right) | s_0=s, \pi \right]$$

One can derive the relationship between $V$ and $Q$ functions and, in turn, the policy evaluation operator:
$$T^\pi Q(s,a) = r(s,a) + \gamma \mathbb{E}_{s'\sim p(\cdot|s,a)} \left[ V(s') \right],$$
$$V(s) = \mathbb{E}_{a\sim \pi(s)} \left[ Q(s,a) - \alpha \log \pi(a|s) \right].$$

So $$T^\pi Q(s,a) = r(s,a) + \gamma \mathbb{E}_{s'\sim p(\cdot|s,a)} \mathbb{E}_{a'\sim \pi(s')} \left[ Q(s',a') - \alpha \log \pi(a'|s') \right].$$

This operator, like the classical one, is a contraction mapping. So repeatedly applying it defines a sequence $Q_{k+1} = T^\pi Q_k$ which converges to the **soft Q-function** of $\pi$.
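As an illustration, here is a minimal numpy sketch of soft policy evaluation by repeated application of $T^\pi$ on a small tabular MDP. Discrete actions are an assumption made only so that the expectations over $a'$ can be computed exactly; the random MDP, the uniform policy and the function names are hypothetical.

```python
import numpy as np

def soft_bellman_operator(Q, r, P, pi, gamma, alpha):
    """One application of T^pi to a tabular Q-function.
    Q, r, pi have shape (S, A); P has shape (S, A, S) with P[s, a, s'] = p(s'|s, a).
    pi must be strictly positive, otherwise 0 * log(0) yields nan."""
    # V(s) = E_{a ~ pi(s)}[Q(s, a) - alpha * log pi(a|s)]
    V = np.sum(pi * (Q - alpha * np.log(pi)), axis=1)
    # T^pi Q(s, a) = r(s, a) + gamma * E_{s' ~ p(.|s, a)}[V(s')]
    return r + gamma * P @ V

def soft_policy_evaluation(r, P, pi, gamma, alpha, tol=1e-10, max_iter=10_000):
    """Iterate Q_{k+1} = T^pi Q_k from Q_0 = 0 until successive iterates are tol-close."""
    Q = np.zeros_like(r)
    for _ in range(max_iter):
        Q_next = soft_bellman_operator(Q, r, P, pi, gamma, alpha)
        if np.max(np.abs(Q_next - Q)) < tol:
            return Q_next
        Q = Q_next
    return Q

# A small random MDP (hypothetical) with 3 states, 2 actions and a uniform policy
rng = np.random.default_rng(0)
S, A = 3, 2
r = rng.uniform(size=(S, A))
P = rng.uniform(size=(S, A, S))
P /= P.sum(axis=2, keepdims=True)   # normalize into transition probabilities
pi = np.full((S, A), 1.0 / A)
Q_soft = soft_policy_evaluation(r, P, pi, gamma=0.9, alpha=0.5)
Q_classic = soft_policy_evaluation(r, P, pi, gamma=0.9, alpha=0.0)
```

For a uniform policy over $|\mathcal{A}|$ actions, the entropy bonus is the same constant $\alpha\log|\mathcal{A}|$ in every state, so the soft Q-function is the classical one shifted by $\gamma\alpha\log|\mathcal{A}|/(1-\gamma)$. This gives a simple sanity check on `Q_soft - Q_classic`.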

Given the soft Q-function $Q^{\pi_n}$ of policy $\pi_n$, one can consider the stochastic policy that, in each state $s$, puts probability mass on each action $a$ in proportion to $\exp(\frac{1}{\alpha}Q^{\pi_n}(s,a))$. This softmax policy might not be representable within the family $\Pi$ of stochastic policies we work with. So, given such a family $\Pi$ (which includes $\pi_n$), one defines the next policy as the projection of this softmax onto $\Pi$, in the sense of the Kullback-Leibler divergence:
$$\pi_{n+1}(s) = \arg\min_{\pi \in \Pi} D_{KL} \left( \pi(s) \Bigg|\Bigg| \frac{\exp(\frac{1}{\alpha}Q^{\pi_n}(s,\cdot))}{Z^{\pi_n}(s)} \right),$$
where $Z^{\pi_n}(s)$ is a normalizing term so that the right-hand side of the divergence is indeed a probability density function. Then one can prove that $Q^{\pi_{n+1}}(s,a) \geq Q^{\pi_n}(s,a)$ for all $(s,a)$. The proof is very similar to that of the policy improvement phase in policy iteration. 

One could wonder why it is necessary to take a temperature equal to the entropy trade-off coefficient $\alpha$ when defining $\pi_{n+1}$ (after all, $\pi_{n+1}$ is "just" the projection on $\Pi$ of $Q^{\pi_n}$'s softmax, with a temperature). Interestingly, the proof that $Q^{\pi_{n+1}} \geq Q^{\pi_n}$ builds on this assumption.
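To see where the temperature enters, here is a sketch of the key step of the improvement argument. Write $\bar\pi_n(\cdot|s) = \exp(\frac{1}{\alpha}Q^{\pi_n}(s,\cdot))/Z^{\pi_n}(s)$. Multiplying the divergence by $\alpha$ gives, for any policy $\pi$:
$$\alpha D_{KL}\left(\pi(s) \,\|\, \bar\pi_n(\cdot|s)\right) = \mathbb{E}_{a\sim\pi(s)}\left[\alpha\log\pi(a|s) - Q^{\pi_n}(s,a)\right] + \alpha\log Z^{\pi_n}(s).$$
Since $\pi_n \in \Pi$, the minimizer $\pi_{n+1}$ has a divergence no larger than that of $\pi_n$. The term $\alpha\log Z^{\pi_n}(s)$ does not depend on $\pi$ and cancels, which leaves
$$\mathbb{E}_{a\sim\pi_{n+1}(s)}\left[Q^{\pi_n}(s,a) - \alpha\log\pi_{n+1}(a|s)\right] \geq \mathbb{E}_{a\sim\pi_{n}(s)}\left[Q^{\pi_n}(s,a) - \alpha\log\pi_{n}(a|s)\right] = V^{\pi_n}(s).$$
Plugging this inequality into $Q^{\pi_n}(s,a) = r(s,a) + \gamma\mathbb{E}_{s'\sim p(\cdot|s,a)}\left[V^{\pi_n}(s')\right]$ and unrolling it repeatedly yields $Q^{\pi_n} \leq Q^{\pi_{n+1}}$. With a softmax temperature $\beta \neq \alpha$, the same computation would produce a $\beta\log\pi$ term, and the left-hand side would no longer be the soft value $\mathbb{E}_{a\sim\pi}\left[Q - \alpha\log\pi\right]$ used in the soft Bellman equation, so the argument would break.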

In turn, this enables defining a soft policy iteration algorithm which converges to a policy $\pi^*$ whose soft Q-function $Q^{\pi^*}$ dominates that of any other policy in $\Pi$, that is $Q^{\pi^*}(s,a) \geq Q^{\pi}(s,a)$ for all $\pi \in \Pi$ and all $(s,a)$.
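Putting evaluation and improvement together, here is a minimal numpy sketch of exact soft policy iteration on a small tabular MDP. It assumes, for illustration only, discrete actions and $\Pi$ equal to the set of all stochastic policies, so that the KL projection is the softmax itself. Evaluation solves the linear system $V = r_\pi + \alpha\mathcal{H}_\pi + \gamma P_\pi V$ rather than iterating $T^\pi$; the MDP and all names are hypothetical.

```python
import numpy as np

def soft_evaluate(r, P, pi, gamma, alpha):
    """Exact soft Q-function of pi on a tabular MDP (pi must be strictly positive).
    r, pi: shape (S, A); P: shape (S, A, S)."""
    S = r.shape[0]
    r_pi = np.sum(pi * r, axis=1)                 # expected reward under pi(s)
    H_pi = -np.sum(pi * np.log(pi), axis=1)       # entropy of pi(s)
    P_pi = np.einsum("sa,sat->st", pi, P)         # state-to-state transitions under pi
    V = np.linalg.solve(np.eye(S) - gamma * P_pi, r_pi + alpha * H_pi)
    return r + gamma * P @ V                      # Q = r + gamma * E[V(s')]

def soft_improve(Q, alpha):
    """pi_{n+1}(a|s) proportional to exp(Q(s, a) / alpha)."""
    logits = Q / alpha
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    pi = np.exp(logits)
    return pi / pi.sum(axis=1, keepdims=True)

def soft_policy_iteration(r, P, gamma, alpha, n_iter=200):
    pi = np.full(r.shape, 1.0 / r.shape[1])       # start from the uniform policy
    Qs = []
    for _ in range(n_iter):
        Q = soft_evaluate(r, P, pi, gamma, alpha)
        Qs.append(Q)
        pi = soft_improve(Q, alpha)
    return pi, Qs

rng = np.random.default_rng(1)
S, A = 4, 3
r = rng.uniform(size=(S, A))
P = rng.uniform(size=(S, A, S))
P /= P.sum(axis=2, keepdims=True)
alpha, gamma = 0.5, 0.9
pi_star, Qs = soft_policy_iteration(r, P, gamma, alpha)
```

The successive soft Q-functions in `Qs` should be non-decreasing, as the improvement result states. At convergence, $\pi^*(a|s)\propto\exp(\frac{1}{\alpha}Q^{\pi^*}(s,a))$, hence $V^{\pi^*}(s) = \alpha\log\sum_a\exp(\frac{1}{\alpha}Q^{\pi^*}(s,a))$: the hard $\max$ of classical value iteration is replaced by a log-sum-exp, a "soft" maximum.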

## Soft Actor-Critic

As for previous algorithms, we wish to introduce function approximators for the value function and the policy. We will write $Q(s,a;\theta)$ for Q-function approximators, and $\pi_w(s)$ for parameterized policies. We will also define target network parameters $\theta'$.

We can write an off-policy loss on the value function's parameters:
$$L^Q(\theta) = \frac{1}{2} \mathbb{E}_{(s,a)\sim \rho} \left[ \left( r(s,a) + \gamma \mathbb{E}_{\substack{s'\sim p(\cdot|s,a)\\a'\sim\pi_{w}(s')}}\left[ Q(s',a';\theta') - \alpha \log \pi_{w}(a'|s') \right] - Q(s,a;\theta) \right)^2 \right].$$

The loss on the policy is the KL divergence defined earlier, multiplied by $\alpha$ and averaged across states in the replay buffer (the $\alpha\log Z(s)$ term is dropped, since it does not depend on $w$):
$$L^\pi(w) = \mathbb{E}_{\substack{s\sim \rho\\a\sim\pi_w(s)}} \left[ \alpha \log \pi_w(a|s) - Q(s,a;\theta) \right]. $$

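Taking the gradient $\nabla_w L^\pi(w)$ is feasible thanks to the reparametrization trick. Let us take a family of distributions $\Pi$ such that drawing $a$ from $\pi_w(s)$ is equivalent to drawing $\epsilon$ from a fixed distribution (e.g. a standard Gaussian) and running $\epsilon$ through a function $f_w(\epsilon,s)$ to obtain $a$. Then one has:
$$L^\pi(w) = \mathbb{E}_{\substack{s\sim \rho\\ \epsilon\sim\mathcal{N}}} \left[ \alpha \log \pi_w(f_w(\epsilon,s)|s) - Q(s,f_w(\epsilon,s);\theta) \right], $$
where $\log \pi_w(f_w(\epsilon,s)|s)$ is the log-density of the action $a = f_w(\epsilon,s)$. When $f_w(\cdot,s)$ is invertible, the change-of-variables formula gives $\log \pi_w(f_w(\epsilon,s)|s) = \log p(\epsilon) - \log\left|\det \frac{\partial f_w(\epsilon,s)}{\partial \epsilon}\right|$, where $p$ is the density of $\epsilon$, so this term can be computed from $\epsilon$ and differentiated with respect to $w$. Since the expectation is now taken over a distribution that does not depend on $w$, the gradient can be moved inside it and estimated from samples.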
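To see why moving the randomness into $\epsilon$ helps, here is a minimal numpy sketch of the pathwise (reparametrization) gradient estimator on a toy one-dimensional problem: a Gaussian "policy" $\mathcal{N}(\mu,\sigma^2)$ and a hypothetical fixed objective $f(a)=-(a-2)^2$ standing in for the critic (the entropy term is left out). With $a=\mu+\sigma\epsilon$, the gradient of the expectation is the expectation of $f'(a)\,\partial a/\partial\mu$ or $f'(a)\,\partial a/\partial\sigma$, which can be averaged over samples of $\epsilon$. In PyTorch, `torch.distributions.Normal.rsample()` draws samples in this reparametrized way, so that autograd applies the same chain rule.

```python
import numpy as np

def f_prime(a):
    """Derivative of the toy objective f(a) = -(a - 2)^2, a stand-in for a critic to maximize."""
    return -2.0 * (a - 2.0)

def reparam_gradients(mu, sigma, n_samples=200_000, seed=0):
    """Pathwise estimates of the gradients of E_{a ~ N(mu, sigma^2)}[f(a)]
    with respect to mu and sigma, writing a = mu + sigma * eps with eps ~ N(0, 1)."""
    eps = np.random.default_rng(seed).standard_normal(n_samples)
    a = mu + sigma * eps
    grad_mu = np.mean(f_prime(a))           # chain rule with da/dmu = 1
    grad_sigma = np.mean(f_prime(a) * eps)  # chain rule with da/dsigma = eps
    return grad_mu, grad_sigma

# Closed form: E[f(a)] = -(mu - 2)^2 - sigma^2, so the exact gradients are -2 (mu - 2) and -2 sigma
grad_mu, grad_sigma = reparam_gradients(mu=0.5, sigma=0.3)
print(grad_mu, grad_sigma)  # close to 3.0 and -0.6
```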

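Looking closely at these loss functions, one sees that they directly extend DDPG's losses to the case of stochastic policies: for instance, with $\alpha=0$ and a deterministic policy $\pi_w$, $L^\pi$ reduces to DDPG's actor loss $\mathbb{E}_{s\sim\rho}\left[-Q(s,\pi_w(s);\theta)\right]$.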

This provides a direct algorithm, which can be implemented using double critics as:

```
Initialize theta1=theta1', theta2=theta2' and w
Initialize replay buffer RB
s = env.reset()
loop:
   Sample a from pi(s,w)
   s',r,done = env.step(a)
   RB.append(s,a,r,s',done)
   minibatch of (s,a,r,s',done) = RB.sample()
   # critic update, on the minibatch
   next_action, next_log_prob = sample from pi(s',w)
   next_value(s') = min [ Q(s',next_action,theta1'), Q(s',next_action,theta2') ] - alpha * next_log_prob
   target_value(s,a) = r + gamma * (1 - done) * next_value(s')
   critic1_loss(theta1) = MSE(Q(s,a,theta1), target_value(s,a))
   critic2_loss(theta2) = MSE(Q(s,a,theta2), target_value(s,a))
   critic_loss = critic1_loss + critic2_loss
   critic_loss.gradient_descent_step()
   # actor update, on the minibatch states, with the current (non-target) critics
   action, log_prob = sample from pi(s,w) using the reparametrization trick
   actor_loss(w) = mean [ alpha * log_prob - min [ Q(s,action,theta1), Q(s,action,theta2) ] ]
   actor_loss.gradient_descent_step()
   target networks update: theta_i' = tau * theta_i + (1 - tau) * theta_i', for i = 1, 2
   s = s' (or s = env.reset() if the episode is over)
```
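The critic step of the pseudocode above can be sketched on a minibatch as follows. This is a minimal numpy illustration with hand-picked numbers, not the chapter's implementation: the function names are hypothetical, and the `done` mask, which stops bootstrapping from terminal states, is an addition to the formula for $L^Q$ that is common practice in implementations.

```python
import numpy as np

def sac_critic_target(r, done, q1_next, q2_next, next_log_prob, gamma, alpha):
    """Soft TD target with clipped double-Q, for a minibatch (1-D arrays of equal length).
    q1_next, q2_next: target-critic values Q(s', a'; theta_i') at a' ~ pi_w(s');
    next_log_prob: log pi_w(a'|s'); done: 1.0 if s' is terminal, else 0.0."""
    next_value = np.minimum(q1_next, q2_next) - alpha * next_log_prob
    return r + gamma * (1.0 - done) * next_value

def critic_losses(q1, q2, target):
    """MSE of both critics against the same target (treated as a constant)."""
    return np.mean((q1 - target) ** 2), np.mean((q2 - target) ** 2)

r = np.array([1.0, 0.0])
done = np.array([0.0, 1.0])
target = sac_critic_target(r, done,
                           q1_next=np.array([2.0, 3.0]), q2_next=np.array([1.5, 4.0]),
                           next_log_prob=np.array([-1.0, -0.5]), gamma=0.9, alpha=0.2)
print(target)  # approximately [2.53, 0.0]: 1 + 0.9 * (1.5 + 0.2), and no bootstrap on the terminal sample
```

Taking the minimum of the two target critics counters the overestimation bias, exactly as in TD3; the only difference with TD3's target is the $-\alpha\log\pi_w(a'|s')$ entropy term.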

<div class="alert alert-warning">
    
**Exercise:**  
Adjust the previous `policyNetwork` class to encode a policy that relies on an internal Gaussian distribution. It predicts a mean and a log standard deviation. To draw actions, it uses the reparametrization trick and passes its samples through a tanh function to ensure the actions remain bounded.  
Use the following tricks to ease your implementation.
</div>

**Trick 1**

Consider a random variable $u \in \mathbb{R}^D$, of probability density function $\mu(u|s)$, where $D$ is the action dimension.  
Let $a = \tanh(u)$, with $\tanh$ applied element-wise.  
Then $$\log\pi(a|s) = \log\mu(u|s) - \sum_{i=1}^D \log(1-\tanh^2(u_i)),$$
where $u_i$ is the $i$-th element of $u$.
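Trick 1 is the change-of-variables formula for an element-wise bijection: the Jacobian of $\tanh$ is diagonal, with entries $1-\tanh^2(u_i)$. Here is a minimal numpy sketch, assuming a diagonal Gaussian $\mu(u|s)$ with hypothetical mean and standard deviation, which checks that the resulting density of $a$ integrates to one over $(-1,1)$. It computes $\log(1-\tanh^2(u))$ through the identity $\log(1-\tanh^2(u)) = 2\left(\log 2 - u - \log(1+e^{-2u})\right)$, which stays finite when $\tanh(u)$ rounds to $\pm 1$ in floating point, unlike the direct formula.

```python
import numpy as np

def gaussian_log_pdf(u, mean, std):
    return -0.5 * ((u - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def log_one_minus_tanh_sq(u):
    """Numerically stable log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))."""
    return 2.0 * (np.log(2.0) - u - np.logaddexp(0.0, -2.0 * u))

def squashed_log_prob(u, mean, std):
    """log pi(a|s) for a = tanh(u), u ~ N(mean, diag(std^2)); the last axis holds action dimensions."""
    return np.sum(gaussian_log_pdf(u, mean, std) - log_one_minus_tanh_sq(u), axis=-1)

# Sanity check in 1-D: the density of a = tanh(u) should integrate to 1 over (-1, 1)
mean, std = 0.3, 0.8
a = np.linspace(-1 + 1e-6, 1 - 1e-6, 200_001)
u = np.arctanh(a)[:, None]                     # shape (N, 1): a single action dimension
density = np.exp(squashed_log_prob(u, mean, std))
total = np.sum(0.5 * (density[1:] + density[:-1]) * np.diff(a))  # trapezoidal rule
print(total)
```

If actions are further rescaled as $a = c\tanh(u) + b$ (as with `action_scale` and `action_bias` in the network code of this section), the log-density additionally loses $\sum_i\log c_i$.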

**Trick 2**

Standard deviations shouldn't grow or shrink unboundedly. To keep them in check, we constrain the predicted $\log\sigma$ to the interval $[-5, 2]$, so that no standard deviation is smaller than $e^{-5}\approx 6.7\times 10^{-3}$ or larger than $e^{2}\approx 7.4$. To make sure the predicted $\log\sigma$ is within this interval, we pass the network's raw output through a $\tanh$ function and rescale the result from $[-1, 1]$ to $[-5, 2]$.

The implementation below is freely adapted from the [CleanRL](https://docs.cleanrl.dev/rl-algorithms/sac/) library.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

LOG_STD_MAX = 2
LOG_STD_MIN = -5

class policy_network(nn.Module):
    """Squashed Gaussian policy: predicts a mean and a log standard deviation,
    samples with the reparametrization trick, then squashes the sample with tanh."""
    def __init__(self, env):
        super().__init__()
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, action_dim)
        self.fc_logstd = nn.Linear(256, action_dim)
        # action rescaling: maps tanh outputs in (-1,1) to the [low, high] action box
        self.register_buffer(
            "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
        )
        self.register_buffer(
            "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
        )

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        # Trick 2: squash log_std into [LOG_STD_MIN, LOG_STD_MAX]
        log_std = torch.tanh(log_std)
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)  # From SpinUp / Denis Yarats
        return mean, log_std

    def get_action(self, x):
        # x is expected to be a batch of states, of shape (batch_size, state_dim)
        mean, log_std = self(x)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound: Trick 1, including the action rescaling factor
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)  # sum over action dimensions
        # deterministic action (squashed mean), e.g. for evaluation
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

This implementation is unfinished: please check [CleanRL](https://docs.cleanrl.dev/rl-algorithms/sac/) for a full implementation.

## Going further: AVI under a minimal entropy constraint

One can derive a version of SAC with adjustable temperature by casting the optimization problem as a constrained one (maximizing rewards subject to a minimum entropy constraint), instead of a regularized one.
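As a pointer, in [Soft Actor-Critic Algorithms and Applications](https://arxiv.org/abs/1812.05905) the constrained problem is stated, in a finite-horizon undiscounted setting, with a lower bound $\bar{\mathcal{H}}$ on the expected entropy at every time step:
$$\max_{\pi} \mathbb{E}_{\rho_\pi}\left[\sum_t r(s_t,a_t)\right] \quad \text{s.t.} \quad \mathbb{E}_{(s_t,a_t)\sim\rho_\pi}\left[-\log\pi(a_t|s_t)\right] \geq \bar{\mathcal{H}} \quad \forall t.$$
The temperature $\alpha$ then plays the role of the dual variable of this constraint, and is adjusted by minimizing
$$J(\alpha) = \mathbb{E}_{\substack{s\sim \rho\\a\sim\pi_w(s)}}\left[-\alpha\log\pi_w(a|s) - \alpha\bar{\mathcal{H}}\right].$$
Its gradient, $\mathbb{E}\left[-\log\pi_w(a|s)\right] - \bar{\mathcal{H}}$, makes gradient descent decrease $\alpha$ when the policy's entropy exceeds the target and increase it otherwise. The paper uses the heuristic target $\bar{\mathcal{H}} = -\dim(\mathcal{A})$.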

More on this later.