# Prioritized replay buffer

To implement this method, we have to introduce several changes to our code.
First of all, we need a new replay buffer that tracks priorities, samples a batch
according to them, calculates weights, and lets us update the priorities once the
loss is known. The second change is the loss function itself. Now we not only
need to incorporate a weight for every sample, but we also need to pass the loss
values back to the replay buffer to adjust the priorities of the sampled transitions.

In the example file Chapter08/05_dqn_prio_replay.py, we have all those changes
implemented. For the sake of simplicity, the new priority replay buffer class uses
a storage scheme very similar to that of our previous replay buffer. Unfortunately,
the new requirements for prioritization make it impossible to sample in time that
is independent of the buffer size (O(1)). If we use simple lists, every time we
sample a new batch, we need to process all the priorities, which makes sampling
O(N) in the buffer size. That's not a big deal if the buffer is small, such as 100k
samples, but it may become an issue for real-life large buffers of millions of
transitions. There are other storage schemes that support sampling in O(log N)
time, for example, the segment tree data structure. You can find such an
implementation in the OpenAI Baselines project: https://github.com/openai/baselines.
The PTAN library also provides an efficient prioritized replay buffer in the class
ptan.experience.PrioritizedReplayBuffer. You can update the example to use the
more efficient version and check its effect on training performance.

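To make the O(log N) claim concrete, here is a minimal sketch of a sum tree, which is a segment tree whose inner nodes store the sum of their children. The class and method names (SumTree, update, find_prefix, sample) are hypothetical and chosen for illustration; this is not the code from OpenAI Baselines or PTAN, and it assumes non-negative priorities. Both updating one priority and picking an index proportionally to its priority walk a single root-to-leaf path.

```python
import random


class SumTree:
    """Hypothetical sketch: O(log N) priority updates and proportional sampling.

    Leaves hold priorities; every inner node holds the sum of its two children.
    Not the OpenAI Baselines or PTAN implementation.
    """

    def __init__(self, capacity):
        # Round the capacity up to a power of two so the tree is complete;
        # padding leaves keep priority 0 and are therefore never sampled.
        size = 1
        while size < capacity:
            size *= 2
        self.size = size
        self.tree = [0.0] * (2 * size)  # node 1 is the root, leaves start at `size`

    def update(self, idx, priority):
        # Set one leaf, then refresh the sums on its path to the root.
        pos = idx + self.size
        self.tree[pos] = priority
        pos //= 2
        while pos >= 1:
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos //= 2

    def total(self):
        return self.tree[1]

    def find_prefix(self, value):
        # Descend to the leaf whose cumulative-sum interval contains `value`.
        pos = 1
        while pos < self.size:
            left = 2 * pos
            if value < self.tree[left]:
                pos = left
            else:
                value -= self.tree[left]
                pos = left + 1
        return pos - self.size

    def sample(self, rng=random):
        # Draw one leaf index with probability priority / total.
        return self.find_prefix(rng.random() * self.total())


tree = SumTree(capacity=4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, p)
print(tree.total())           # 10.0
print(tree.find_prefix(2.5))  # 1, because 2.5 falls into [1.0, 3.0), the interval of leaf 1
```

A batch would be drawn by calling sample() once per element; each call costs O(log N) instead of the O(N) pass over all priorities that the list-based buffer needs.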
But, for now, let's take a look at the naïve version, whose source code is in lib/dqn_extra.py.

``` python
import numpy as np

# beta is annealed linearly from BETA_START to 1.0 over BETA_FRAMES frames
BETA_START = 0.4
BETA_FRAMES = 100000


class PrioReplayBuffer:
    def __init__(self, exp_source, buf_size, prob_alpha=0.6):
        self.exp_source_iter = iter(exp_source)
        self.prob_alpha = prob_alpha
        self.capacity = buf_size
        self.pos = 0          # next slot to write in the circular buffer
        self.buffer = []
        self.priorities = np.zeros(
            (buf_size, ), dtype=np.float32)
        self.beta = BETA_START

    def update_beta(self, idx):
        # linear schedule, clipped at 1.0
        v = BETA_START + idx * (1.0 - BETA_START) / \
            BETA_FRAMES
        self.beta = min(1.0, v)
        return self.beta

    def __len__(self):
        return len(self.buffer)

    def populate(self, count):
        # new transitions get the highest priority seen so far
        max_prio = self.priorities.max() if \
            self.buffer else 1.0
        for _ in range(count):
            sample = next(self.exp_source_iter)
            if len(self.buffer) < self.capacity:
                self.buffer.append(sample)
            else:
                self.buffer[self.pos] = sample
            self.priorities[self.pos] = max_prio
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        # only the filled part of the priority array is valid
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        # P(i) = p_i^alpha / sum_k p_k^alpha
        probs = prios ** self.prob_alpha
        probs /= probs.sum()

        indices = np.random.choice(len(self.buffer),
                                   batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        # importance-sampling weights, normalized by the batch maximum
        total = len(self.buffer)
        weights = (total * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return samples, indices, \
               np.array(weights, dtype=np.float32)

    def update_priorities(self, batch_indices,
                          batch_priorities):
        for idx, prio in zip(batch_indices,
                             batch_priorities):
            self.priorities[idx] = prio
```

At the beginning of the module, we define the parameters for the β increase rate:
β is increased from 0.4 to 1.0 during the first 100k frames.
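The schedule itself is just a clipped linear function of the frame index. The sketch below restates the arithmetic of update_beta() as a standalone function; the name beta_by_frame is hypothetical, and the two constants take the values stated above.

```python
# Values from the text: beta grows from 0.4 to 1.0 over the first 100k frames.
BETA_START = 0.4
BETA_FRAMES = 100_000


def beta_by_frame(frame_idx, beta_start=BETA_START, beta_frames=BETA_FRAMES):
    """Hypothetical helper: linear annealing of beta, clipped at 1.0 (same arithmetic as update_beta)."""
    value = beta_start + frame_idx * (1.0 - beta_start) / beta_frames
    return min(1.0, value)


for frame in (0, 25_000, 50_000, 100_000, 150_000):
    print(frame, round(beta_by_frame(frame), 3))
```

After 100k frames β stays at 1.0, so the importance-sampling correction is applied in full for the rest of training.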

The priority replay buffer class stores samples in a circular buffer (which lets us
keep a fixed number of entries without reallocating the list) and keeps the
priorities in a NumPy array. We also store an iterator over the experience source
object to pull samples from the environment.

The populate() method pulls the given number of transitions from the
ExperienceSource object and stores them in the buffer. As our transition storage
is implemented as a circular buffer, there are two situations to handle:

- When the buffer hasn't reached its maximum capacity, we just append the new
transition to the buffer.
- If the buffer is already full, we overwrite the oldest transition, whose slot is
tracked by the pos field.

In both cases, the slot's priority is set to the maximum priority currently stored
in the buffer (or 1.0 if the buffer is empty), so fresh experience has a good chance
of being sampled soon, and pos is advanced modulo the buffer's size.

The update_beta() method needs to be called periodically to increase beta according
to the schedule.
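The bookkeeping in populate() can be tried without an environment. The following sketch is a hypothetical, stripped-down stand-in (TinyPrioStore and its add() method are not part of the example code): it uses plain Python lists instead of a NumPy array and receives one transition per call instead of pulling from an experience source, but it follows the same rules for appending, overwriting, and assigning the maximum priority.

```python
class TinyPrioStore:
    """Hypothetical stand-in that mirrors the logic of populate() above."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.priorities = [0.0] * capacity
        self.pos = 0  # slot that the next transition will be written to

    def add(self, sample):
        # New samples get the current maximum priority (1.0 for an empty store).
        max_prio = max(self.priorities) if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)      # still filling up: append
        else:
            self.buffer[self.pos] = sample  # full: overwrite the oldest entry
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity


store = TinyPrioStore(capacity=3)
for t in ["a", "b", "c"]:
    store.add(t)
store.priorities[1] = 5.0  # pretend "b" came back with a large loss
store.add("d")             # buffer is full: "d" replaces the oldest entry, "a"
print(store.buffer)        # ['d', 'b', 'c']
print(store.priorities)    # [5.0, 5.0, 1.0]
print(store.pos)           # 1
```

populate() computes the maximum once per call rather than once per transition; the result is the same, because the slot being written always receives that maximum, so the maximum cannot change inside the loop.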

In the sample() method, we convert priorities to probabilities using our α
hyperparameter.

Then, using those probabilities, we sample our buffer to obtain a batch of samples.
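The sketch below shows what the α exponent does to a small set of priorities and how indices are then drawn. It uses plain Python instead of NumPy: priorities_to_probs is a hypothetical helper, and random.choices stands in for np.random.choice. Both draw with replacement by default, so an index can appear more than once in a batch. With α = 0 every entry is equally likely (uniform replay), with α = 1 the probabilities are proportional to the priorities, and the α = 0.6 used in this example sits in between.

```python
import random


def priorities_to_probs(priorities, alpha):
    """Hypothetical helper: P(i) = p_i**alpha / sum_k p_k**alpha, as in sample()."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


prios = [1.0, 2.0, 3.0, 4.0]
print(priorities_to_probs(prios, alpha=0.0))  # alpha = 0: uniform
print(priorities_to_probs(prios, alpha=1.0))  # alpha = 1: proportional to priority
print(priorities_to_probs(prios, alpha=0.6))  # the value used in this example

# Stdlib counterpart of np.random.choice(len(prios), batch_size, p=probs):
# draw a batch of indices with replacement according to the probabilities.
rng = random.Random(42)
probs = priorities_to_probs(prios, alpha=0.6)
indices = rng.choices(range(len(prios)), weights=probs, k=8)
print(indices)
```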

As the last step, we calculate weights for the samples in the batch and return three
objects: the batch, the indices, and the weights. The indices of the batch samples are
required to update the priorities of the sampled items.
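The weight of a sampled transition is w_i = (N · P(i))^(−β), divided by the largest weight in the batch so that all weights are at most 1. The sketch below evaluates this for a made-up four-entry buffer; importance_weights is a hypothetical helper that mirrors the last lines of sample(). Rarely sampled transitions get the largest weights, which compensates for the bias introduced by non-uniform sampling, and β = 1 applies the full correction.

```python
def importance_weights(probs, indices, beta):
    """Hypothetical helper: w_i = (N * P(i)) ** -beta, normalized by the batch maximum."""
    n = len(probs)
    weights = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max(weights)
    return [w / max_w for w in weights]


probs = [0.1, 0.2, 0.3, 0.4]  # sampling probabilities of a 4-entry buffer (made up)
batch = [0, 3, 3, 1]          # sampled indices, with a repetition
print(importance_weights(probs, batch, beta=0.4))  # partial correction
print(importance_weights(probs, batch, beta=1.0))  # full correction
```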

The last method of the priority replay buffer allows us to set new priorities for
the processed batch. It's the caller's responsibility to call it with the losses
calculated for the batch.

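The feedback loop closes when new priorities are written back. The sketch below uses hypothetical helpers and plain lists to show how one call that mirrors update_priorities() shifts the sampling probabilities: a slot that came back with a large loss becomes much more likely to be drawn again, while one with a tiny loss becomes less likely. With plain Python indexing, if the same index appears twice in a batch, the later value simply overwrites the earlier one.

```python
def update_priorities(priorities, batch_indices, batch_priorities):
    """Hypothetical helper: same loop as PrioReplayBuffer.update_priorities, on a list."""
    for idx, prio in zip(batch_indices, batch_priorities):
        priorities[idx] = prio


def probs_of(priorities, alpha=0.6):
    """Hypothetical helper: priorities -> sampling probabilities with exponent alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


priorities = [1.0, 1.0, 1.0, 1.0]  # fresh entries all start with the same priority
before = probs_of(priorities)
# Suppose slots 1 and 2 were sampled and their per-sample losses came back as:
update_priorities(priorities, [1, 2], [4.0, 0.01])
after = probs_of(priorities)
print([round(p, 3) for p in before])
print([round(p, 3) for p in after])
```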
The next custom function in our example is the loss calculation. As the MSELoss
class in PyTorch doesn't support weights (which is understandable, as MSE is a loss
used in regression problems, while weighting is more common in classification
losses), we need to calculate the MSE explicitly and multiply the result by the
weights:

``` python
import torch

from lib import common


def calc_loss(batch, batch_weights, net, tgt_net,
              gamma, device="cpu"):
    states, actions, rewards, dones, next_states = \
        common.unpack_batch(batch)

    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.BoolTensor(dones).to(device)
    batch_weights_v = torch.tensor(batch_weights).to(device)

    # Q(s, a) for the actions that were actually taken
    actions_v = actions_v.unsqueeze(-1)
    state_action_vals = net(states_v).gather(1, actions_v)
    state_action_vals = state_action_vals.squeeze(-1)
    with torch.no_grad():
        # Bellman target from the target network; terminal states have no future value
        next_states_v = torch.tensor(next_states).to(device)
        next_s_vals = tgt_net(next_states_v).max(1)[0]
        next_s_vals[done_mask] = 0.0
        exp_sa_vals = next_s_vals.detach() * gamma + rewards_v
    # per-sample squared error, scaled by the importance-sampling weights
    l = (state_action_vals - exp_sa_vals) ** 2
    losses_v = batch_weights_v * l
    # mean loss for backprop, plus per-sample values (with a small epsilon) as new priorities
    return losses_v.mean(), \
           (losses_v + 1e-5).data.cpu().numpy()
```

In the last part of the loss calculation, we implement the same MSE loss but write
the expression explicitly rather than using the library class. This allows us to take
the sample weights into account and to keep an individual loss value for every
sample. Those values are passed to the priority replay buffer to update the
priorities. A small value (1e-5) is added to every loss to handle a zero loss, which
would otherwise give the entry a zero priority, so it would never be sampled again.
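Stripped of tensors, the return value of calc_loss() comes down to the arithmetic below. This is a plain-Python sketch with a hypothetical helper and made-up numbers, not the PyTorch code itself; it shows that a perfectly predicted transition still receives a small positive priority, and that the importance weights scale each sample's contribution to the mean loss.

```python
def weighted_losses(q_values, targets, weights, eps=1e-5):
    """Hypothetical helper mirroring calc_loss() without PyTorch.

    Returns (mean loss used for the gradient step, new priorities for the buffer).
    """
    losses = [w * (q - y) ** 2 for q, y, w in zip(q_values, targets, weights)]
    mean_loss = sum(losses) / len(losses)
    # eps keeps a perfectly predicted sample from getting priority 0
    new_prios = [loss_i + eps for loss_i in losses]
    return mean_loss, new_prios


q = [1.0, 0.5, 2.0]   # Q(s, a) predicted by the network (made up)
y = [1.0, 1.5, 0.0]   # Bellman targets (made up)
w = [1.0, 0.5, 0.25]  # importance-sampling weights from the buffer (made up)
loss, prios = weighted_losses(q, y, w)
print(loss, prios)
```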

In the main section of the script, only two things change: the creation of the
replay buffer and our processing function. Buffer creation is straightforward, so we
will look only at the new processing function, which has several changes:

- Our batch now contains three entities: the batch of data, the indices of the
sampled items, and the samples' weights.
- We call our new loss function, which accepts the weights and additionally returns
the items' new priorities. They are passed to the buffer.update_priorities function
to reprioritize the items that we have sampled.
- We call the update_beta method of the buffer to change the beta parameter
according to the schedule.

``` python
import sys
sys.path.append("../Chapter08/")
```

``` python
import gym
import ptan
import random

import torch
import torch.optim as optim

from ignite.engine import Engine

from lib import dqn_model, common, dqn_extra

NAME = "05_prio_replay"
PRIO_REPLAY_ALPHA = 0.6


def calc_loss(batch, batch_weights, net, tgt_net,
              gamma, device="cpu"):
    states, actions, rewards, dones, next_states = \
        common.unpack_batch(batch)

    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.BoolTensor(dones).to(device)
    batch_weights_v = torch.tensor(batch_weights).to(device)

    actions_v = actions_v.unsqueeze(-1)
    state_action_vals = net(states_v).gather(1, actions_v)
    state_action_vals = state_action_vals.squeeze(-1)
    with torch.no_grad():
        next_states_v = torch.tensor(next_states).to(device)
        next_s_vals = tgt_net(next_states_v).max(1)[0]
        next_s_vals[done_mask] = 0.0
        exp_sa_vals = next_s_vals.detach() * gamma + rewards_v
    l = (state_action_vals - exp_sa_vals) ** 2
    losses_v = batch_weights_v * l
    return losses_v.mean(), \
           (losses_v + 1e-5).data.cpu().numpy()


random.seed(common.SEED)
torch.manual_seed(common.SEED)
params = common.HYPERPARAMS['pong']
# fall back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make(params.env_name)
env = ptan.common.wrappers.wrap_dqn(env)
env.seed(common.SEED)

net = dqn_model.DQN(env.observation_space.shape, env.action_space.n).to(device)

tgt_net = ptan.agent.TargetNet(net)
selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=params.epsilon_start)
epsilon_tracker = common.EpsilonTracker(selector, params)
agent = ptan.agent.DQNAgent(net, selector, device=device)

exp_source = ptan.experience.ExperienceSourceFirstLast(
    env, agent, gamma=params.gamma)
buffer = dqn_extra.PrioReplayBuffer(
    exp_source, params.replay_size, PRIO_REPLAY_ALPHA)
optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)


def process_batch(engine, batch_data):
    # the buffer now yields (samples, indices, importance-sampling weights)
    batch, batch_indices, batch_weights = batch_data
    optimizer.zero_grad()
    loss_v, sample_prios = calc_loss(
        batch, batch_weights, net, tgt_net.target_model,
        gamma=params.gamma, device=device)
    loss_v.backward()
    optimizer.step()
    # write the new per-sample priorities back to the buffer
    buffer.update_priorities(batch_indices, sample_prios)
    epsilon_tracker.frame(engine.state.iteration)
    if engine.state.iteration % params.target_net_sync == 0:
        tgt_net.sync()
    return {
        "loss": loss_v.item(),
        "epsilon": selector.epsilon,
        "beta": buffer.update_beta(engine.state.iteration),
    }


engine = Engine(process_batch)
common.setup_ignite(engine, params, exp_source, NAME)
engine.run(common.batch_generator(buffer, params.replay_initial, params.batch_size))
```

``` text
Episode 1: reward=-21, steps=908, speed=0.0 f/s, elapsed=0:00:27
Episode 2: reward=-19, steps=1130, speed=0.0 f/s, elapsed=0:00:27
Episode 3: reward=-20, steps=969, speed=0.0 f/s, elapsed=0:00:27
Episode 4: reward=-20, steps=1064, speed=0.0 f/s, elapsed=0:00:27
Episode 5: reward=-21, steps=838, speed=0.0 f/s, elapsed=0:00:27
Episode 6: reward=-20, steps=955, speed=0.0 f/s, elapsed=0:00:27
Episode 7: reward=-20, steps=894, speed=0.0 f/s, elapsed=0:00:27
Episode 8: reward=-21, steps=780, speed=0.0 f/s, elapsed=0:00:27
Episode 9: reward=-19, steps=976, speed=0.0 f/s, elapsed=0:00:27
Episode 10: reward=-20, steps=925, speed=0.0 f/s, elapsed=0:00:27
Episode 11: reward=-21, steps=882, speed=61.1 f/s, elapsed=0:00:32
Episode 12: reward=-20, steps=935, speed=61.1 f/s, elapsed=0:00:48
Episode 13: reward=-20, steps=1020, speed=61.1 f/s, elapsed=0:01:04
Episode 14: reward=-19, steps=993, speed=61.1 f/s, elapsed=0:01:21
Episode 15: reward=-21, steps=897, speed=61.1 f/s, elapsed=0:01:36
```
