# Intro 
In this notebook, we will be looking at an extension to the DQN-based model that we have been using so far. This extension is called Prioritized Experience Replay (PER). Main paper for PER is [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952).

The general idea is to  sample experiences from the replay buffer in a non-uniform way. Instead of sampling experiences uniformly at random, we will sample experiences with a probability that is proportional to the absolute TD error. This way, we will be more likely to sample experiences that are more surprising or unexpected. This can help the agent learn faster and more effectively.

As discussed previously, TD error is defined as:

$$
\delta = r + \gamma \max_{a'} Q(s', a') - Q(s, a)
$$

# Simulation Environment

In [7]:
# Game of Pong Simulation environment
import gymnasium as gym
import gymnasium.utils.seeding as seeding
from gymnasium.wrappers import AtariPreprocessing, RecordVideo
import ale_py



import numpy as np
import random
from collections import namedtuple, deque
import torch
import torch.nn.functional as F
import torch.optim as optim

BUFFER_SIZE = int(1e5)  # replay buffer size
BATCH_SIZE = 64         # minibatch size
GAMMA = 0.99            # discount factor
TAU = 1e-3              # for soft update of target parameters
LR = 5e-4               # learning rate 
UPDATE_EVERY = 4        # how often to update the network

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


DefaultRandomSeed = 10 
# Create the Pong environment
env = gym.make("ALE/Pong-v5",frameskip=1)
env.np_random, _ = seeding.np_random(DefaultRandomSeed)
env.reset(seed=DefaultRandomSeed)
env = AtariPreprocessing(env) # Frame skipping, Grayscale, Resize (To 84*84), Stack 4 frames




# Example interaction with the environment
for _ in range(1000):
    action = env.action_space.sample()  # Take a random action
    observation, reward , terminated, truncated, info = env.step(action)  # Apply the action
    

    if terminated or truncated:
        state = env.reset()  # Reset the environment if done

print(terminated)
print (truncated)
print(info)
print(observation.shape)


A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]


False
False
{'lives': 0, 'episode_frame_number': 291, 'frame_number': 4026}
(84, 84)


# PER Algorithm walkthrough

This section provides a high-level overview of the PER algorithm. The algorithm is based on the DQN algorithm, with some modifications to the replay buffer and sampling process. The following steps are written based on the the pseduo code provided in the paper.

1. Initialize Replay Buffer ($\mathcal{H}$) with capacity $N$.
2. In each episode
    -  In each step (Until episode is done or terminated)
        - Run action selection policy
        - Store transition ($s_t$, $a_t$, $r_{t+1}$, $s_{t+1}$, $p_{t}$) in $\mathcal{H}$
        - for each transition in mininbatch (1:$\mathcal{K}$)
            - Sample transition j with the probability of $P({j}) = \frac{p_{j}^{\alpha}}{\sum_{i}p_{i}^{\alpha}}$
            - Compute the importance sampling weight $w_{j} = \left( \frac{1}{N} \cdot \frac{1}{P(j)} \right)^{\beta}$
            - Compute the TD error $\delta_{j} = r_{j} + \gamma \max_{a^{\prime}} Q(s_{j+1},a^{\prime}, \theta) - Q(s_{j},a_{j}, \theta)$ (TD error deponds on the type of DQN algorithm)
            - update the priority of transition $j$ in $\mathcal{H}$ as $p_{j} = |\delta_{j}| + \epsilon$
            - Accumulate the weight change $\Delta \ = \Delta  + w_{j} \delta_{j} \nabla_{\theta} Q(s_{j},a_{j}, \theta)$
        - Update the Q network weights using the accumulated weight change $\theta = \theta + \eta \Delta$
        - - every C steps update $\hat{Q}$ network weights using the following equation
            $\bar{\theta} = \tau*\theta + (1 - \tau)*\bar{\theta}$

# Model

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNetwork(nn.Module):
    def __init__(self, input_shape, num_actions,seed):
        super(QNetwork, self).__init__()
        print(input_shape)
        self.conv1 = nn.Conv2d(in_channels=input_shape[0], out_channels=16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)
        self.fc1 = nn.Linear(in_features=self._feature_size(input_shape), out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_actions)
        self.seed = torch.manual_seed(seed)

    def forward(self, x):
        # input : Observations 
        # Ouput : Q value of different actions
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def _feature_size(self, input_shape):
        return nn.Sequential(
            nn.Conv2d(input_shape[0], 16, 8, 4),
            nn.ReLU(),
            nn.Conv2d(16, 32, 4, 2),
            nn.ReLU()
        ).forward(torch.zeros(1, *input_shape)).view(1, -1).size(1)

net = QNetwork((2, 84, 84), 4,10) 
print(net)


(2, 84, 84)
QNetwork(
  (conv1): Conv2d(2, 16, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
  (fc1): Linear(in_features=2592, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=4, bias=True)
)


# Agent

## Prioritized Experience Replay

In [9]:
class Prioritized_ReplayBuffer:
    def __init__(self, buffer_size, batch_size, seed,alpha=0.6,beta=0.4,beta_increment=1e-5):
        self.buffer_size = buffer_size
        self.memory = deque(maxlen=self.buffer_size)
        self.priorities = deque(maxlen=self.buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.priorities = namedtuple("Priorities", field_names=["priority"])
        self.seed = random.seed(seed)

        self._alpha = alpha # alpha is the importance sampling factor. alpha = 0 means uniform sampling
        self._beta = beta
        self._max_priority = 1.0
    
    def __len__(self):
        return len(self.memory)
    
    def add(self, state, action, reward, next_state, done):
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)
        self.priorities.append(self._max_priority)

    def sample(self):
        # Sample the batch
        probs = np.array(self.priorities**self._alpha)/np.sum(self.priorities**self._alpha)
        idx = np.random.choice(np.arange(len(self.memory)), size=self.batch_size, p=probs, replace=False)
        experiences = [self.memory[i] for i in idx]
        weights = (self.buffer_size*probs[idx])**(-self._beta)
        weights = weights/np.max(weights)
        return experiences, idx , weights
    
    def update_priority(self, idx, priority):
        self.priorities[idx] = priority + 1e-5
        self._max_priority = max(self._max_priority, priority)


In [10]:
class DQNAgent():
    def __init__(self, state_size, action_size, seed ,random_policy=False,beta=0.4):
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.random = random_policy
        self.beta = beta

        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
        self.qnetwork_target = QNetwork(state_size, action_size, seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = Prioritized_ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, seed)
        self.t_step = 0

    def load_weights(self, model_weights):
        self.qnetwork_local.load_state_dict(torch.load('models/{}'.format(model_weights)))
    
    def save_weights(self, model_weights):
        torch.save(self.qnetwork_local.state_dict(), 'models/{}'.format(model_weights))

    def step(self, state, action, reward, next_state, done):
        if self.random:
            return
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)
        self.t_step = (self.t_step + 1) % UPDATE_EVERY

        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > BATCH_SIZE:
                experiences, idx,weights = self.memory.sample()
                self.learn(experiences, idx,weights,GAMMA)
    
    def act(self, state, eps=0.):
        if self.random:
            return np.random.randint(self.action_size)
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))
    
    def learn(self, experiences, idx,weights, gamma):
        states, actions, rewards, next_states, dones = experiences
        current = self.qnetwork_local(states).gather(1, actions)
        next_qvalues = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
        target = rewards + (gamma * next_qvalues * (1 - dones))
        # calculate the TD error
        td_error = target - current
        # update the priority
        for i in range(len(idx)):
            self.memory.update_priority(idx[i], td_error[i].item())
        
        loss = (td_error* weights).pow(2)   

        self.qnetwork_local.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

        def soft_update(self, local_model, target_model, tau):
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)
        




# Training

In [11]:
class DQN_Training():
    def __init__(self, env, agent, n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
        self.env = env
        self.agent = agent
        self.n_episodes = n_episodes
        self.max_t = max_t
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
    
    def train(self):
        
        for ep_number in range(1, self.n_episodes+1):
            state,_ = self.env.reset()
            score = 0
            eps = max(self.eps_end, self.eps_start*self.eps_decay)
            episode_step = 0
            episode_reward = 0
            while not done:
                episode_step += 1
                action = self.agent.act(state, eps)
                next_state, reward, done, terminated, truncated, info = self.env.step(action)
                if terminated or truncated or episode_step >= self.max_t:
                    done = True
                self.agent.step(state, action, reward, next_state, done)
                episode_reward += reward
                state = next_state
            print('Episode: {}\tScore: {:.2f}'.format(ep_number, score))