# Let's play around with prioritized replay

In [1]:
import gymnasium as gym
import math
import random
from collections import namedtuple
import heapq
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
device = torch.device("mps") 

In [2]:
env = gym.make("Blackjack-v1") # is the stochasticity going to help in highlighting the usefulness of a distributional approach? let's hope so...

## DQN, Replay Memory and Scheduler

In [3]:
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 64)
        self.layer3 = nn.Linear(64, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [4]:
# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/replay_buffer.py
class ReplayMemory:
    def __init__(self, capacity, alpha, obs_shape, action_shape):
        self.capacity = capacity
        self.alpha = alpha

        self.priority_sum = torch.zeros(2 * self.capacity)
        self.priority_min = torch.full((2 * self.capacity,), float('inf'))

        self.max_priority = 1.

        self.data = {
            'state': torch.zeros((capacity, *obs_shape)),
            'action': torch.zeros((capacity, *action_shape), dtype=torch.int32),
            'reward': torch.zeros(capacity, dtype=torch.float32),
            'next_state': torch.zeros((capacity, *obs_shape)),
            'done': torch.zeros(capacity, dtype=torch.bool)
        }
        self.next_idx = 0

        self.size = 0

    def add(self, obs, action, reward, next_obs, done):
        idx = self.next_idx

        self.data['state'][idx] = obs
        self.data['action'][idx] = action
        self.data['reward'][idx] = reward
        self.data['done'][idx] = done

        if next_obs is None:
            self.data['next_state'][idx] = torch.zeros_like(next_obs)
        else:
            self.data['next_state'][idx] = next_obs

        self.next_idx = (idx + 1) % self.capacity
        self.size = min(self.capacity, self.size + 1)

        priority_alpha = self.max_priority ** self.alpha
        self._set_priority_min(idx, priority_alpha)
        self._set_priority_sum(idx, priority_alpha)

    def _set_priority_min(self, idx, priority_alpha):
        idx += self.capacity
        self.priority_min[idx] = priority_alpha

        while idx >= 2:
            idx //= 2
            self.priority_min[idx] = min(self.priority_min[2 * idx], self.priority_min[2 * idx + 1])

    def _set_priority_sum(self, idx, priority):
        idx += self.capacity
        self.priority_sum[idx] = priority

        while idx >= 2:
            idx //= 2
            self.priority_sum[idx] = self.priority_sum[2 * idx] + self.priority_sum[2 * idx + 1]

    def _sum(self):
        return self.priority_sum[1]

    def _min(self):
        return self.priority_min[1]

    def find_prefix_sum_idx(self, prefix_sum):
        idx = 1
        while idx < self.capacity:
            if self.priority_sum[idx * 2] > prefix_sum:
                idx = 2 * idx
            else:
                prefix_sum -= self.priority_sum[idx * 2]
                idx = 2 * idx + 1

        return idx - self.capacity

    def sample(self, batch_size, beta):
            samples = {
                'weights': torch.zeros(size=(batch_size,), dtype=torch.float32),
                'indexes': torch.zeros(size=(batch_size,), dtype=torch.int32)
            }

            for i in range(batch_size):
                p = random.random() * self._sum()
                idx = self.find_prefix_sum_idx(p)
                samples['indexes'][i] = idx

            prob_min = self._min() / self._sum()
            max_weight = (prob_min * self.size) ** (-beta)

            for i in range(batch_size):
                idx = int(samples['indexes'][i])
                prob = self.priority_sum[idx + self.capacity] / self._sum()
                weight = (prob * self.size) ** (-beta)
                samples['weights'][i] = weight / max_weight

            for k, v in self.data.items():
                samples[k] = v[samples['indexes'].long()]

            return samples

    def update_priorities(self, indexes, priorities):
        for idx, priority in zip(indexes, priorities):
            self.max_priority = max(self.max_priority, priority)

            priority_alpha = priority ** self.alpha
            self._set_priority_min(idx, priority_alpha)
            self._set_priority_sum(idx, priority_alpha)

    def is_full(self):
        return self.capacity == self.size

In [5]:
class EpsilonScheduler:
    def __init__(self, eps_start, eps_end, eps_decay):
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay
        self.steps_done = 0

    def get_epsilon(self):
        epsilon = self.eps_end + (self.eps_start - self.eps_end) * math.exp(-1. * self.steps_done / self.eps_decay)
        self.steps_done += 1
        return epsilon

    def reset(self):
        self.steps_done = 0

# Hyperparameters and initialisations

In [6]:
BATCH_SIZE = 32
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
LR = 0.001
TARGET_UPDATE = 10
ALPHA = 0.6  # Prioritization level
BETA_START = 0.4  # Importance-sampling weight
BETA_FRAMES = 10_000  # Number of frames over which beta will be annealed to 1
MEMORY_CAPACITY = 10_000

n_actions = env.action_space.n
n_observations = len(env.reset()[0])

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
epsilon_scheduler = EpsilonScheduler(EPS_START, EPS_END, EPS_DECAY)

### Action selection and single optimisation step

In [7]:
def select_action(state, policy_net, eps_scheduler):
    sample = random.random()
    epsilon = eps_scheduler.get_epsilon()

    if sample < epsilon:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
    else:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
def optimize_model(memory, batch_size=BATCH_SIZE, beta=BETA_START):
    if len(memory) < batch_size:
        return

    transitions, weights, indexes = memory.sample(batch_size, BETA_START)
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    weights = torch.tensor(weights, dtype=torch.float32)

    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(batch_size)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1), reduction='none')
    loss = (loss * weights).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    new_priorities = torch.abs(state_action_values - expected_state_action_values.unsqueeze(1)).detach().numpy()
    memory.update_priorities(indexes, new_priorities + 1e-5)

In [8]:
state, _ = env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
action = select_action(state, policy_net, epsilon_scheduler)

# init memory here because i need select_action()
memory = ReplayMemory(capacity=MEMORY_CAPACITY, alpha=0.6, obs_shape=state.shape, action_shape=action.shape)

## Training loop

In [9]:
num_episodes = 100
for episode in range(num_episodes):
    state, _ = env.reset()
    state = torch.tensor([state], dtype=torch.float32)

    while True:
        action = select_action(state, policy_net, epsilon_scheduler)
        next_state, reward, done, _, _ = env.step(action.item())
        reward = torch.tensor([reward], dtype=torch.float32)

        if not done:
            next_state = torch.tensor([next_state], dtype=torch.float32)
        else:
            next_state = None

        # TODO something is wrong with next state handling here...
        memory.add(state, action, next_state, reward, done)
        state = next_state

        optimize_model(memory)

        if done:
            break

    if episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

TypeError: can't assign a NoneType to a torch.FloatTensor

## Evaluation

In [None]:
def evaluate_agent(env, num_episodes, device, policy_net=None):
    agent = "Random" if policy_net is None else "Trained"
    if policy_net is None:
        action_selection = lambda state: torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
    else:
        action_selection = lambda state: policy_net(state).max(1)[1].view(1, 1)

    wins = 0
    total_earnings = 0
    for _ in range(num_episodes):
        state, info = env.reset()
        state = torch.tensor([state], dtype=torch.float32, device=device)

        while True:
            with torch.no_grad():
                action = action_selection(state)
            next_state, reward, done, _, _ = env.step(action.item())
            total_earnings += reward

            state = torch.tensor([next_state], dtype=torch.float32, device=device)

            if done:
                if reward > 0:
                    wins += 1
                break

    win_rate = wins / num_episodes
    average_earnings = total_earnings / num_episodes
    print(f'{agent} agent => Evaluation over {num_episodes} episodes. Win Rate: {win_rate:.2f}; Avg. Earnings: {average_earnings:.2f}')

    return win_rate

In [None]:
num_evaluation_episodes = 1000
evaluate_agent(env, num_evaluation_episodes, device, policy_net=policy_net);
evaluate_agent(env, num_evaluation_episodes, device);

Trained agent => Evaluation over 1000 episodes. Win Rate: 0.00; Avg. Earnings: -1.00
Random agent => Evaluation over 1000 episodes. Win Rate: 0.28; Avg. Earnings: -0.39


Not too shabby.