# Policy Gradients
<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/matyama/deep-rl-hands-on/blob/main/11_policy_gradients.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
        Run in Google Colab
    </a>
  </td>
</table>

In [4]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit

echo "Running on Google Colab, therefore installing dependencies..."
pip install "ptan>=0.7" tensorboardX

Running on Google Colab, therefore installing dependencies...


In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

## Common Imports

In [17]:
# flake8: noqa: E402,I001

import sys
import time
from collections import deque
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple

import gym
import numpy as np
import ptan
import torch
import torch.nn as nn
from ptan.experience import ExperienceFirstLast
from tensorboardX import SummaryWriter

## Values and Policy
Contrary to the value iteration methods (Q-Learning), which try to estimate the state values (or state-action values), the *policy gradient* technique focuses directly on the policy $\pi(s)$.

Direct policy modeling has several advantages:
* From a certain point of view, we don't care that much about the expected discounted rewards but rather about the decision/action $\pi(s)$ to take in each state $s$
* As we saw earlier with the *Categorical DQN*, learning a distribution helps to better capture the underlying MDP (especially in stochastic environments)
* It becomes quite hard to determine the best action to take when the action space is large or even continuous. The DQN model of $Q(s, a)$ is highly non-linear and the optimization problem $a^* = \arg\max_a Q(s, a)$ can be hard to solve.

In the value iteration case our DQN parametrized the state-action values as $DQN(s) \to Q_\mathbf{w}(s, \cdot)$. Similarly, we will represent the policy as a probability distribution over actions $\pi_\mathbf{w}(s)$ parametrized by the NN.

*Modelling the output as action (class) probabilities is a typical technique in classification tasks that gives us a smooth representation (intuitively, changing NN weights $\mathbf{w}$ a bit changes $\pi$ a bit as well - compared to the case with discrete action labels which would change in steps).*
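
The smoothness argument above can be checked numerically. The following is a minimal sketch in plain Python (the logit values are made up for illustration): nudging one logit slightly leaves the hard action label unchanged but moves the action probabilities by a small, non-zero amount.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a two-action environment such as CartPole
probs_before = softmax([1.0, 0.5])
# Nudge the first logit slightly, as a small weight update would
probs_after = softmax([1.01, 0.5])

# The hard action label (argmax) does not change ...
label_before = probs_before.index(max(probs_before))
label_after = probs_after.index(max(probs_after))
# ... while the probabilities shift by a small, non-zero amount
shift = abs(probs_after[0] - probs_before[0])
print(label_before, label_after, shift)
```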

In [7]:
class PGN(nn.Module):
    """
    Policy Gradient Network that consumes states (observations)
    and outputs action logits (scores).

    Note: Logits should be converted to log-probabilities with
    `log_softmax` (rather than `softmax` followed by `log`) for better
    numerical stability during optimization.
    """

    def __init__(self, input_shape: int, n_actions: int) -> None:
        super().__init__()

        # `input_shape` is the number of features in a flat observation
        # Simple, shallow feed-forward network that outputs action logits
        self.net = nn.Sequential(
            nn.Linear(input_shape, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

## Gradients of the Policy

*Policy Gradient* methods are closely related to the *Cross-Entropy Method* introduced earlier. The policy gradient is the direction in which we want to change the NN weights to maximize the accumulated reward. Its scale is proportional to the state-action value $Q(s, a)$ and its direction is given by the gradient of the log-probability of the action taken:
$$
\nabla J \approx \mathbb{E}[Q(s, a) \nabla \log(\pi(a | s))]
$$
where the expectation means that we average the gradient over several steps.

Equivalently, we can say that we minimize the loss function $\mathcal{L} = -Q(s, a) \log(\pi(a | s))$ (Note: SGD minimizes the loss function, whereas we want to maximize the objective $J$, hence the minus sign).
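
To make the loss more concrete, here is a small sketch for a single transition with a softmax policy. It relies on the standard identity $\partial \log \pi(a | s) / \partial z_j = 1[j = a] - \pi(j | s)$ for logits $z$ and verifies the analytic gradient with finite differences. The logits, the action and the Q value are illustrative assumptions, not values from the notebook.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def pg_loss(logits: List[float], action: int, q_value: float) -> float:
    # L = -Q(s, a) * log pi(a | s) for a single transition
    return -q_value * log_softmax(logits)[action]


def pg_loss_grad(
    logits: List[float], action: int, q_value: float
) -> List[float]:
    # Analytic gradient w.r.t. the logits: -Q * (onehot(a) - pi)
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return [
        -q_value * ((1.0 if j == action else 0.0) - p)
        for j, p in enumerate(probs)
    ]


logits, action, q_value = [0.2, -0.1], 0, 5.0  # illustrative values
grad = pg_loss_grad(logits, action, q_value)

# Finite-difference check of the analytic gradient
eps = 1e-6
base = pg_loss(logits, action, q_value)
numeric = []
for j in range(len(logits)):
    bumped = list(logits)
    bumped[j] += eps
    numeric.append((pg_loss(bumped, action, q_value) - base) / eps)

print(grad, numeric)
```

With a positive $Q(s, a)$ the loss gradient for the logit of the taken action is negative, so a gradient descent step raises that logit and makes the action more likely.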

Recall that in the *Cross-Entropy Method* we sampled a few episodes from the environment and trained only on transitions from the above-average ones. This corresponds to having $Q(s, a) = 1$ for the good transitions and $Q(s, a) = 0$ otherwise. In general, policy gradient methods differ in how they treat the $Q$ values, but in any case we want finer-grained values of $Q(s, a)$ than just $0$ and $1$:
1. for better separation of episodes
1. to incorporate the discount factor and thus the uncertainty about future rewards

## The REINFORCE method
The outline of the *REINFORCE* method is the following:
1. Initialize NN weights randomly
1. Play $N$ full episodes and collect experiences $(s, a, r, s')$
1. Compute actual $Q$ values for every played episode $k$ and step $t$: $Q_{k, t} = \sum_{i=t}^{T_k} \gamma^{i - t} r_{k, i}$, where $T_k$ is the last step of episode $k$
1. Compute the loss for all transitions: $\mathcal{L} = - \sum_{k, t} Q_{k, t} \log(\pi(a_{k, t} | s_{k, t}))$
1. Do one SGD step by minimizing the loss and update NN weights
1. Repeat from step 2 until convergence
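
As a worked example of the $Q$ value computation in the outline: the value of a step is the discounted sum of rewards from that step to the end of the episode. For three steps with reward $1$ each and $\gamma = 0.9$ this gives $1 + 0.9 + 0.81 = 2.71$, $1 + 0.9 = 1.9$ and $1$. The sketch below evaluates that sum directly; the helper name is made up for illustration.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    # Direct evaluation of Q_t = sum_{i=t}^{T} gamma^(i - t) * r_i
    return [
        sum(gamma ** (i - t) * rewards[i] for i in range(t, len(rewards)))
        for t in range(len(rewards))
    ]


# Three steps with reward 1 each, like a very short CartPole episode
qs = discounted_returns([1.0, 1.0, 1.0], gamma=0.9)
print(qs)
```

This direct form is quadratic in the episode length; accumulating the sum while walking the rewards backwards gives the same values in linear time.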

Properties of the REINFORCE method:
* We **don't need an explicit exploration policy** because we explore automatically using the policy our NN outputs.
* **On-policy** method: we can't train on data from old policies, so no ER buffer is needed. On the other hand, value methods typically need fewer interactions with the environment.
* We train on actual Q values and not estimated ones, so we **don't need a target NN** to break the correlation in the Q value approximation either.
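
The first property can be illustrated with a few lines of plain Python: exploration comes from sampling the action according to the policy's probabilities instead of always taking the most likely one. This is only a sketch with a made-up two-action distribution; it is not how `ptan` implements its agent.

```python
import random
from typing import List


def sample_action(probs: List[float], rng: random.Random) -> int:
    # Draw one action index according to the policy's probabilities
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # fixed seed only to make the sketch repeatable
probs = [0.7, 0.3]      # hypothetical policy output for one state
actions = [sample_action(probs, rng) for _ in range(10_000)]

# The less likely action is still tried roughly 30% of the time
freq_action_1 = sum(actions) / len(actions)
print(freq_action_1)
```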

### CartPole REINFORCE

In [8]:
def compute_q_values(rewards: List[float], gamma: float) -> Iterable[float]:
    qs = []
    sum_r = 0.0

    for r in reversed(rewards):
        sum_r *= gamma
        sum_r += r
        qs.append(sum_r)

    return reversed(qs)


def train_reinforce(
    env_name: str,
    gamma: float = 0.99,
    learning_rate: float = 0.01,
    n_played_episodes: int = 4,
    reward_bound: int = 195,
    log_period: int = 10,
) -> None:

    # Create the environment
    env = gym.make(env_name)

    # Create PG network
    net = PGN(
        input_shape=env.observation_space.shape[0],
        n_actions=env.action_space.n,
    )
    print(net)

    # Initialize an agent
    #  - Notice: We instruct it to apply softmax to the PGN output
    agent = ptan.agent.PolicyAgent(
        net,
        preprocessor=ptan.agent.float32_preprocessor,
        apply_softmax=True,
    )

    # Create experience source and optimizer

    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env=env,
        agent=agent,
        gamma=gamma,
    )

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    with SummaryWriter(comment=f"-{env_name}-reinforce") as writer:

        done_episodes = 0
        batch_episodes = 0

        batch_states, batch_actions, batch_q_values = [], [], []

        episode_rewards = []
        total_rewards = []

        # Interact with the environment and consume experiences
        for i, exp in enumerate(exp_source):

            # Add the new experience to current batch
            batch_states.append(exp.state)
            batch_actions.append(int(exp.action))

            # Buffer immediate rewards during each episode
            episode_rewards.append(exp.reward)

            # Compute Q values from immediate rewards when episode ends
            if exp.last_state is None:
                batch_q_values += compute_q_values(episode_rewards, gamma)
                episode_rewards.clear()
                batch_episodes += 1

            # Handle new rewards
            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:

                done_episodes += 1

                # Collect total rewards
                reward = new_rewards[0]
                total_rewards.append(reward)

                # Compute the mean reward over last 100 episodes
                mean_rewards = float(np.mean(total_rewards[-100:]))

                # Log training progress
                if done_episodes % log_period == 0:
                    print(
                        f"{i}: reward: {reward:.2}, "
                        f"mean_100: {mean_rewards:.2}, "
                        f"episodes: {done_episodes}"
                    )

                # Record metrics for TensorBoard
                writer.add_scalar("reward", reward, i)
                writer.add_scalar("reward_100", mean_rewards, i)
                writer.add_scalar("episodes", done_episodes, i)

                # Check if the learned policy is good enough
                if mean_rewards > reward_bound:
                    print(f"Solved in {i} steps and {done_episodes} episodes!")
                    break

            # Play N episodes to accumulate Q values before training step
            if batch_episodes < n_played_episodes:
                continue

            n_states = len(batch_states)

            # Reset gradients
            optimizer.zero_grad()

            # Convert batch parts to tensors
            states = torch.FloatTensor(batch_states)
            actions = torch.LongTensor(batch_actions)
            q_values = torch.FloatTensor(batch_q_values)

            # Compute action scores (logits)
            #  - Note: There's just single pass through the PGN (DQN has 2)
            logits = net(states)

            # Compute the loss function defined in the previous section
            log_action_prob = nn.functional.log_softmax(logits, dim=1)
            exp_values = q_values * log_action_prob[range(n_states), actions]
            loss = -exp_values.mean()

            # Compute gradient of the loss function and make one SGD step
            loss.backward()
            optimizer.step()

            # Reset current batch
            batch_episodes = 0
            batch_states.clear()
            batch_actions.clear()
            batch_q_values.clear()


# Run REINFORCE to solve the CartPole environment
train_reinforce(env_name="CartPole-v0")

PGN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)
291: reward: 1.1e+01, mean_100: 2.9e+01, episodes: 10
631: reward: 3.5e+01, mean_100: 3.2e+01, episodes: 20
916: reward: 2.4e+01, mean_100: 3.1e+01, episodes: 30
1335: reward: 4.6e+01, mean_100: 3.3e+01, episodes: 40
1824: reward: 1.1e+02, mean_100: 3.6e+01, episodes: 50
2503: reward: 5.3e+01, mean_100: 4.2e+01, episodes: 60
3004: reward: 7.4e+01, mean_100: 4.3e+01, episodes: 70
3908: reward: 1.6e+02, mean_100: 4.9e+01, episodes: 80
4719: reward: 1.9e+01, mean_100: 5.2e+01, episodes: 90
5634: reward: 8.3e+01, mean_100: 5.6e+01, episodes: 100
6491: reward: 3.1e+01, mean_100: 6.2e+01, episodes: 110
7345: reward: 6.5e+01, mean_100: 6.7e+01, episodes: 120
8312: reward: 1.1e+02, mean_100: 7.4e+01, episodes: 130
9956: reward: 2e+02, mean_100: 8.6e+01, episodes: 140
11851: reward: 2e+02, mean_100: 1e+02, episodes: 150
13513: re

### REINFORCE issues

#### Complete episodes
The first drawback of REINFORCE, and of PG methods in general, is that they are much **less sample efficient**. To estimate Q values as well as possible, we need quite a lot of interactions with the environment from full episodes. Moreover, the length of the episodes we must play only increases in complex environments (e.g. episodes in Atari Pong might have thousands of steps).

In the DQN scenario we used our own $Q(s, a)$ to estimate $V(s)$ in the one-step Bellman update: $Q(s, a) = r_a + \gamma V(s')$. But in PG we don't have Q values - these are approximated from episodes completed in the environment.

There are two approaches dealing with this issue:
* We use the NN to estimate $V(s)$ as well as action logits and use these state values to obtain Qs. This approach implements the *actor-critic method* which will be described later.
* The other way is to unroll the Bellman equation $N$ steps ahead, which implicitly exploits the fact that the contribution of future values is discounted by powers of $\gamma < 1$.
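
The second approach can be sketched as follows. The helper below sums at most `n` discounted rewards starting at step `t`; the part it leaves out is exactly $\gamma^n$ times the return from step $t + n$, which is why truncation hurts less the smaller $\gamma^n$ is. The reward sequence and the value $\gamma = 0.9$ are illustrative assumptions.

```python
from typing import List


def n_step_return(
    rewards: List[float], t: int, n: int, gamma: float
) -> float:
    # Discounted sum of at most n rewards starting at step t
    window = rewards[t : t + n]
    return sum(gamma ** i * r for i, r in enumerate(window))


gamma = 0.9
rewards = [1.0] * 50  # hypothetical episode with reward 1 per step

full = n_step_return(rewards, t=0, n=len(rewards), gamma=gamma)
short = n_step_return(rewards, t=0, n=10, gamma=gamma)

# The dropped part equals gamma ** n times the return from step t + n
tail = gamma ** 10 * n_step_return(rewards, t=10, n=40, gamma=gamma)
print(short, full, tail)
```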

#### High variance of gradients
Recall that the policy gradient $\nabla J$ is proportional to $Q(s, a)$. The problem with rewards (and thus Q values) is that their scale is heavily environment-dependent and can differ a lot between episodes. As a result, the gradient has high variance: one lucky episode can dominate the final gradient.

To prevent training instabilities due to high variance, one can subtract a *baseline* value from $Q$:
* Constant value, typically the mean of the discounted rewards
* Moving average of the discounted rewards
* The state value $V(s)$
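
Why subtracting a baseline is safe can be seen on a tiny example. For a softmax policy, $\sum_a \pi(a | s) \nabla \log \pi(a | s) = 0$, so a baseline that does not depend on the action leaves the expected gradient unchanged while it can shrink its variance. The sketch below looks only at the gradient component of the first logit for a two-action policy; the probabilities and Q values are made up for illustration.

```python
from typing import List, Tuple


def grad_samples(
    probs: List[float], q_values: List[float], baseline: float
) -> List[float]:
    # Sample for action a of d/dz_0 [(Q_a - b) * log pi(a)]
    #   = (Q_a - b) * (1[a == 0] - pi_0)
    return [
        (q - baseline) * ((1.0 if a == 0 else 0.0) - probs[0])
        for a, q in enumerate(q_values)
    ]


def mean_and_var(
    samples: List[float], probs: List[float]
) -> Tuple[float, float]:
    # Expectation and variance under the policy's action distribution
    mean = sum(p * x for p, x in zip(probs, samples))
    var = sum(p * (x - mean) ** 2 for p, x in zip(probs, samples))
    return mean, var


probs = [0.5, 0.5]       # uniform policy over two actions
q_values = [10.0, 12.0]  # hypothetical Q values of both actions
baseline = sum(q_values) / len(q_values)

mean_raw, var_raw = mean_and_var(
    grad_samples(probs, q_values, 0.0), probs
)
mean_base, var_base = mean_and_var(
    grad_samples(probs, q_values, baseline), probs
)
print(mean_raw, var_raw, mean_base, var_base)
```

In this toy case both variants agree on the expected gradient, but only the raw Q values produce samples that swing between large positive and large negative values.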

#### Exploration
Even though we can get rid of exploration strategies (e.g. epsilon-greedy) because we can sample from the current policy, the agent can still converge to a sub-optimal policy. Fortunately, we can benefit from the fact that we have represented the policy as a probability distribution and add an *entropy bonus* to the loss function.

The entropy of a policy is
$$
H(\pi(\cdot | s)) = - \sum_a \pi(a | s) \log(\pi(a | s))
$$
and we subtract it, scaled by a small coefficient $\beta$, from the loss (or rather its mean over the batch states $s$) in order to push the agent away from local optima by promoting a more uniform distribution over actions (in a local optimum some action $a$ will have $\pi(a | s) = 1$, which corresponds to $H(\pi) = 0$; minimizing the loss then maximizes $H$ as well).
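
A quick numeric check of the entropy formula, as a plain-Python sketch with made-up distributions: the uniform policy over two actions has the maximal entropy $\log 2$, while a deterministic policy has entropy $0$.

```python
import math
from typing import List


def entropy(probs: List[float]) -> float:
    # H = -sum_a p_a * log(p_a), with the convention 0 * log(0) = 0
    return -sum(p * math.log(p) for p in probs if p > 0.0)


h_uniform = entropy([0.5, 0.5])        # maximal for two actions: log(2)
h_peaked = entropy([0.99, 0.01])       # nearly deterministic policy
h_deterministic = entropy([1.0, 0.0])  # fully deterministic policy
print(h_uniform, h_peaked, h_deterministic)
```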

#### Correlation between samples
As mentioned before, we cannot use an experience replay buffer as we did in DQN to break correlations between experiences from one episode because PG is an *on-policy* method. If we did use old experiences, we'd compute gradient of an old policy, not the current one.

A typical trick to solve this problem for PG methods is to sample from several independent copies of the same environment at once. This gives us an i.i.d. set of experiences for the SGD step (or close enough).
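
One simple way to picture this is a round-robin over several environment copies, so that neighbouring transitions in a batch come from different episodes. The sketch below uses toy generators as stand-ins for environments; the names are hypothetical and this is not a description of how `ptan` schedules its environments internally.

```python
from itertools import count
from typing import Iterator, List, Tuple

Transition = Tuple[int, int]  # (environment id, step index)


def toy_episode_stream(env_id: int) -> Iterator[Transition]:
    # Stand-in for an environment: yields (env_id, step) forever
    for step in count():
        yield env_id, step


def round_robin(
    streams: List[Iterator[Transition]], batch_size: int
) -> List[Transition]:
    # Take one transition from each stream in turn until the batch is full
    batch: List[Transition] = []
    while len(batch) < batch_size:
        for stream in streams:
            batch.append(next(stream))
            if len(batch) == batch_size:
                break
    return batch


streams = [toy_episode_stream(env_id) for env_id in range(4)]
batch = round_robin(streams, batch_size=8)
print(batch)
```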

## Policy Gradient Method

### CartPole PG

In [9]:
def smooth(old: Optional[float], val: float, alpha: float = 0.95) -> float:
    return val if old is None else old * alpha + (1 - alpha) * val


# Hyperparameters
GAMMA = 0.99
LEARNING_RATE = 0.001
ENTROPY_BETA = 0.01
BATCH_SIZE = 8
REWARD_STEPS = 10
REWARD_BOUND = 195
LOG_PERIOD = 10

# Initialize environment, PGN, the agent, exp. source and optimizer
#  - We pass gamma directly to the exp. source to discount the rewards
#  - We also use `REWARD_STEPS`-ahead technique with our exp. source
#    instead of playing full episodes to approximate Q values

env = gym.make("CartPole-v0")

net = PGN(
    input_shape=env.observation_space.shape[0],
    n_actions=env.action_space.n,
)
print(net)

agent = ptan.agent.PolicyAgent(
    net,
    preprocessor=ptan.agent.float32_preprocessor,
    apply_softmax=True,
)

exp_source = ptan.experience.ExperienceSourceFirstLast(
    env=env,
    agent=agent,
    gamma=GAMMA,
    steps_count=REWARD_STEPS,
)

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

with SummaryWriter(comment="-cartpole-pg") as writer:

    done_episodes = 0

    bs_smoothed = entropy = l_entropy = l_policy = l_total = None

    batch_states, batch_actions, batch_scales = [], [], []

    total_rewards = []
    step_rewards = []
    reward_sum = 0.0

    # Run the training loop
    for i, exp in enumerate(exp_source):

        # Accumulate discounted rewards over `REWARD_STEPS`-ahead
        reward_sum += exp.reward

        # Compute current baseline value as the mean reward up until now
        baseline = reward_sum / (i + 1)

        # Track the baseline value
        writer.add_scalar("baseline", baseline, i)

        # Add new experience to the batch
        #  - Notice: We subtract the baseline value from the reward to
        #    reduce the variance of the gradient scales (Q values)
        batch_states.append(exp.state)
        batch_actions.append(int(exp.action))
        batch_scales.append(exp.reward - baseline)

        # Handle new rewards as before
        #  - Logs training progress and metrics as in the previous example
        #  - The termination condition is also the same

        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:

            done_episodes += 1

            reward = new_rewards[0]
            total_rewards.append(reward)

            mean_rewards = float(np.mean(total_rewards[-100:]))

            if done_episodes % LOG_PERIOD == 0:
                print(
                    f"{i}: reward: {reward:.2}, "
                    f"mean_100: {mean_rewards:.2}, "
                    f"episodes: {done_episodes}"
                )

            writer.add_scalar("reward", reward, i)
            writer.add_scalar("reward_100", mean_rewards, i)
            writer.add_scalar("episodes", done_episodes, i)

            if mean_rewards > REWARD_BOUND:
                print(f"Solved in {i} steps and {done_episodes} episodes!")
                break

        # Wait for the batch to fill up
        if len(batch_states) < BATCH_SIZE:
            continue

        # Convert batch to tensors
        states = torch.FloatTensor(batch_states)
        batch_actions_t = torch.LongTensor(batch_actions)
        batch_scale = torch.FloatTensor(batch_scales)

        # Clear gradients
        optimizer.zero_grad()

        # Compute the policy part of the loss function
        logits = net(states)
        log_action_prob = nn.functional.log_softmax(logits, dim=1)
        exp_values = (
            batch_scale * log_action_prob[range(BATCH_SIZE), batch_actions_t]
        )
        policy_loss = -exp_values.mean()

        # Compute the entropy bonus to the loss function
        #  - The batch entropy gets its own name so that it does not
        #    overwrite the smoothed `entropy` metric tracked below
        action_prob = nn.functional.softmax(logits, dim=1)
        entropy_t = -(action_prob * log_action_prob).sum(dim=1).mean()
        entropy_loss = -ENTROPY_BETA * entropy_t
        loss = policy_loss + entropy_loss

        # Compute gradient of the loss function and apply one optimization step
        loss.backward()
        optimizer.step()

        # Compute KL divergence: D(previous policy || new policy)

        with torch.no_grad():

            new_logits = net(states)
            new_action_prob = nn.functional.softmax(new_logits, dim=1)

            kl_div = (
                -(action_prob * (new_action_prob / action_prob).log())
                .sum(dim=1)
                .mean()
            )

        # Record KL divergence in TensorBoard
        writer.add_scalar("kl_div", kl_div.item(), i)

        # Compute additional gradient metrics: max absolute value and
        # the root-mean-square of each parameter's gradient

        grad_max = 0.0
        grad_means = 0.0
        grad_count = 0

        for p in net.parameters():
            grad_max = max(grad_max, p.grad.abs().max().item())
            grad_means += (p.grad ** 2).mean().sqrt().item()
            grad_count += 1

        # Do smooth updates to tracked metrics
        #  - Note: We use mixing hyperparameter alpha = 0.95
        bs_smoothed = smooth(bs_smoothed, float(np.mean(batch_scales)))
        entropy = smooth(entropy, entropy_t.item())
        l_entropy = smooth(l_entropy, entropy_loss.item())
        l_policy = smooth(l_policy, policy_loss.item())
        l_total = smooth(l_total, loss.item())

        # Record metrics for TensorBoard
        writer.add_scalar("baseline", baseline, i)
        writer.add_scalar("entropy", entropy, i)
        writer.add_scalar("loss_entropy", l_entropy, i)
        writer.add_scalar("loss_policy", l_policy, i)
        writer.add_scalar("loss_total", l_total, i)
        writer.add_scalar("grad_l2", grad_means / grad_count, i)
        writer.add_scalar("grad_max", grad_max, i)
        writer.add_scalar("batch_scales", bs_smoothed, i)

        # Batch cleanup
        batch_states.clear()
        batch_actions.clear()
        batch_scales.clear()

PGN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)
259: reward: 1.7e+01, mean_100: 2.6e+01, episodes: 10
499: reward: 3e+01, mean_100: 2.5e+01, episodes: 20
694: reward: 1.4e+01, mean_100: 2.3e+01, episodes: 30
1011: reward: 2.3e+01, mean_100: 2.5e+01, episodes: 40
1350: reward: 3.9e+01, mean_100: 2.7e+01, episodes: 50
1826: reward: 3.4e+01, mean_100: 3e+01, episodes: 60
2508: reward: 5.1e+01, mean_100: 3.6e+01, episodes: 70
2999: reward: 9.8e+01, mean_100: 3.7e+01, episodes: 80
3569: reward: 2.5e+01, mean_100: 4e+01, episodes: 90
4205: reward: 4e+01, mean_100: 4.2e+01, episodes: 100
4778: reward: 4.5e+01, mean_100: 4.5e+01, episodes: 110
5837: reward: 1.2e+02, mean_100: 5.3e+01, episodes: 120
6574: reward: 7e+01, mean_100: 5.9e+01, episodes: 130
7997: reward: 1.2e+02, mean_100: 7e+01, episodes: 140
9379: reward: 1.3e+02, mean_100: 8e+01, episodes: 150
10830: reward: 9.7
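
The `kl_div` metric logged in the training loop above measures how far a single optimization step moved the policy. As a plain-Python sketch of the same quantity, $D(p \,\|\, q) = \sum_a p_a \log(p_a / q_a)$, with made-up probabilities:

```python
import math
from typing import List


def kl_divergence(p: List[float], q: List[float]) -> float:
    # D(p || q) = sum_a p_a * log(p_a / q_a); terms with p_a = 0 add 0
    return sum(
        pa * math.log(pa / qa) for pa, qa in zip(p, q) if pa > 0.0
    )


old_policy = [0.6, 0.4]    # hypothetical probabilities before the update
new_policy = [0.65, 0.35]  # ... and after it
kl_small = kl_divergence(old_policy, new_policy)
kl_same = kl_divergence(old_policy, old_policy)
print(kl_small, kl_same)
```

Spikes in this value indicate that a single update changed the policy a lot, which is usually a sign of unstable training.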

### Atari Pong PG

#### PGN

In [10]:
class AtariPGN(nn.Module):
    def __init__(self, input_shape: Tuple[int, ...], n_actions: int) -> None:
        super().__init__()

        # Convolutional layers
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        fc_input_size = self._get_fc_inputs(input_shape)

        # Dense layers
        self.fc = nn.Sequential(
            nn.Linear(fc_input_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_fc_inputs(self, shape: Tuple[int, ...]) -> int:
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inputs = x.float() / 256
        conv_out = self.conv(inputs).view(inputs.size()[0], -1)
        return self.fc(conv_out)
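
The network above finds the input size of the dense part by pushing a dummy tensor through the convolutions. The same number can be derived by hand with the formula for an unpadded convolution, $\lfloor (n - k) / s \rfloor + 1$. The sketch below assumes $84 \times 84$ input frames, which is an assumption about the preprocessing wrappers and not something stated in this notebook.

```python
def conv_out_size(size: int, kernel: int, stride: int) -> int:
    # Output length of an unpadded convolution along one dimension
    return (size - kernel) // stride + 1


side = 84  # assumed frame side length after preprocessing
# (kernel_size, stride) of the three convolutional layers
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    side = conv_out_size(side, kernel, stride)

fc_input_size = 64 * side * side  # 64 channels in the last conv layer
print(side, fc_input_size)
```

Under that assumption the spatial size shrinks as 84, 20, 9, 7, so the first dense layer receives $64 \cdot 7 \cdot 7 = 3136$ features.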

#### Reward Tracker

In [24]:
class RewardTracker:
    def __init__(
        self,
        writer: SummaryWriter,
        stop_reward: float,
        window_size: int = 100,
    ) -> None:
        self.writer = writer
        self.stop_reward = stop_reward
        self.window_size = window_size

    def __enter__(self) -> "RewardTracker":
        self.ts = time.time()
        self.ts_frame = 0
        self.total_rewards = []
        return self

    def __exit__(self, *args: Any) -> None:
        self.writer.close()

    def add_reward(self, reward: float, frame: int) -> bool:
        """
        Returns an indication of whether a termination condition was reached.
        """

        self.total_rewards.append(reward)

        fps = (frame - self.ts_frame) / (time.time() - self.ts)

        self.ts_frame = frame
        self.ts = time.time()

        mean_reward = np.mean(self.total_rewards[-self.window_size :])

        if frame % self.window_size == 0:
            print(
                f"{frame}: done {len(self.total_rewards)} games, "
                f"mean reward {mean_reward:.3}, speed {fps:.2} fps"
            )

        sys.stdout.flush()

        self.writer.add_scalar("fps", fps, frame)
        self.writer.add_scalar("reward_100", mean_reward, frame)
        self.writer.add_scalar("reward", reward, frame)

        return mean_reward > self.stop_reward

#### Environment & Baseline Mean Buffer

In [25]:
def make_env(name: str) -> gym.Env:
    return ptan.common.wrappers.wrap_dqn(gym.make(name))


class MeanBuffer:
    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.sum = 0.0

    def __iadd__(self, value: float) -> "MeanBuffer":
        if len(self.buffer) == self.capacity:
            self.sum -= self.buffer[0]

        self.sum += value
        self.buffer.append(value)
        return self

    def mean(self) -> float:
        return self.sum / len(self.buffer) if self.buffer else 0.0

#### Training Loop

In [None]:
# Hyperparameters
GAMMA = 0.99
LEARNING_RATE = 0.0001
ENTROPY_BETA = 0.01
BATCH_SIZE = 128
REWARD_STEPS = 10
BASELINE_STEPS = 1000000
GRAD_L2_CLIP = 0.1
ENV_NAME = "PongNoFrameskip-v4"
ENV_COUNT = 32
STOP_REWARD = 18

# Determine where the computations will take place
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize multiple instances of the Pong environment
envs = [make_env(ENV_NAME) for _ in range(ENV_COUNT)]

# Initialize TensorBoard writer
writer = SummaryWriter(comment="-pong-pg-" + ENV_NAME)

# Create the PGN
net = AtariPGN(
    input_shape=envs[0].observation_space.shape,
    n_actions=envs[0].action_space.n,
).to(device)
print(net)

# Create the agent and experience source
agent = ptan.agent.PolicyAgent(net, apply_softmax=True, device=device)
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env=envs,
    agent=agent,
    gamma=GAMMA,
    steps_count=REWARD_STEPS,
)

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)

# Create a sliding window buffer for the moving-average baseline
baseline_buffer = MeanBuffer(BASELINE_STEPS)

total_rewards = []
batch_states = []
batch_actions = []
batch_scales = []
m_baseline = []
m_batch_scales = []
m_loss_entropy = []
m_loss_policy = []
m_loss_total = []
m_grad_max = []
m_grad_mean = []

sum_reward = 0.0

done_episodes = 0
train_step = 0

# Initialize reward tracker
with RewardTracker(writer, stop_reward=STOP_REWARD) as tracker:

    # Run the training loop
    for frame, exp in enumerate(exp_source):

        # Compute the baseline as the moving average of discounted rewards
        baseline_buffer += exp.reward
        baseline = baseline_buffer.mean()

        # Add new experience to current batch
        batch_states.append(np.asarray(exp.state))
        batch_actions.append(int(exp.action))
        batch_scales.append(exp.reward - baseline)

        new_rewards = exp_source.pop_total_rewards()
        if new_rewards:

            # Record new reward and check for termination
            solved = tracker.add_reward(reward=new_rewards[0], frame=frame)

            # Stop if the mean reward of the last 100 episodes is good enough
            if solved:
                print(f"Solved in {frame} frames!")
                break

        # Let the batch fill up
        if len(batch_states) < BATCH_SIZE:
            continue

        train_step += 1

        # Convert batch data to tensors
        states = np.asarray(batch_states)
        states = torch.FloatTensor(states).to(device)
        batch_actions_t = torch.LongTensor(batch_actions).to(device)
        batch_scale = torch.FloatTensor(batch_scales).to(device)

        # Clear gradients
        optimizer.zero_grad()

        # Compute the policy part of the loss function
        logits = net(states)
        log_action_prob = nn.functional.log_softmax(logits, dim=1)
        exp_values = (
            batch_scale * log_action_prob[range(BATCH_SIZE), batch_actions_t]
        )
        policy_loss = -exp_values.mean()

        # Entropy bonus to the loss function
        action_prob = nn.functional.softmax(logits, dim=1)
        entropy = -(action_prob * log_action_prob).sum(dim=1).mean()
        entropy_loss = -ENTROPY_BETA * entropy

        # Compute total loss and its gradient
        loss = policy_loss + entropy_loss
        loss.backward()

        # Use gradient clipping (by l2 norm) before making next gradient step
        nn.utils.clip_grad_norm_(net.parameters(), GRAD_L2_CLIP)
        optimizer.step()

        # Compute KL divergence: D(previous policy || new policy)
        #  - Metric only, so no gradients need to be tracked
        with torch.no_grad():
            new_logits = net(states)
            new_action_prob = nn.functional.softmax(new_logits, dim=1)
            kl_div = (
                -(action_prob * (new_action_prob / action_prob).log())
                .sum(dim=1)
                .mean()
            )

        writer.add_scalar("kl_div", kl_div.item(), frame)

        # Compute gradient metrics: max and l2 norms
        grad_max = 0.0
        grad_means = 0.0
        grad_count = 0
        for p in net.parameters():
            grad_max = max(grad_max, p.grad.abs().max().item())
            grad_means += (p.grad ** 2).mean().sqrt().item()
            grad_count += 1

        # Metrics tracking
        writer.add_scalar("baseline", baseline, frame)
        writer.add_scalar("entropy", entropy.item(), frame)
        writer.add_scalar("batch_scales", np.mean(batch_scales), frame)
        writer.add_scalar("batch_scales_std", np.std(batch_scales), frame)
        writer.add_scalar("loss_entropy", entropy_loss.item(), frame)
        writer.add_scalar("loss_policy", policy_loss.item(), frame)
        writer.add_scalar("loss_total", loss.item(), frame)
        writer.add_scalar("grad_l2", grad_means / grad_count, frame)
        writer.add_scalar("grad_max", grad_max, frame)

        # Batch cleanup
        batch_states.clear()
        batch_actions.clear()
        batch_scales.clear()
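
The training loop above clips gradients by their L2 norm before each optimizer step. The idea can be sketched in plain Python: treat all gradients as one long vector and, if its norm exceeds the threshold, rescale the whole vector so that its direction is preserved. The helper name and the small epsilon guard are assumptions of this sketch, not a copy of the PyTorch implementation.

```python
import math
from typing import List


def clip_by_global_norm(
    grads: List[List[float]], max_norm: float
) -> List[List[float]]:
    # Treat all gradients as one long vector and compute its L2 norm
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # Shrink only when the norm exceeds the threshold; never scale up
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [[g * scale for g in grad] for grad in grads]


grads = [[3.0, 0.0], [0.0, 4.0]]  # hypothetical gradients with norm 5
clipped = clip_by_global_norm(grads, max_norm=0.1)
clipped_norm = math.sqrt(sum(g * g for grad in clipped for g in grad))
print(clipped, clipped_norm)
```

Because every component is multiplied by the same factor, clipping bounds the step size without changing the update direction.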