In [3]:
# prioritized experience replay vs replay memory

from collections import deque
import random
import numpy as np

In [2]:
class ReplayMemory():
    def __init__(self, maxlen):
        self.memory = deque([], maxlen=maxlen)

    def append(self, transition):
        self.memory.append(transition)

    def sample(self, sample_size):
        return random.sample(self.memory, sample_size)

    def __len__(self):
        return len(self.memory)

In [None]:
r_memory = ReplayMemory(1000)
r_memory.append(1)
r_memory.append(2)
r_memory.append(3)
# Sample a single element from the replay memory
r_memory.sample(1)

[1]

In [35]:
class SumTree():
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float32)
        self.data_pointer = 0
    
    def update(self, idx, priority):
        tree_index = idx + self.capacity - 1
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change
        
    def get_leaf(self, idx):
        parent_index = 0
        print(self.tree)
        while True:
            print(f"Searching for leaf: idx={idx}, parent_index={parent_index}")
            left_child = 2 * parent_index + 1
            right_child = left_child + 1

            if left_child >= len(self.tree):
                leaf_index = parent_index
                break
            if idx <= self.tree[left_child]:
                parent_index = left_child
            else:
                idx -= self.tree[left_child]
                parent_index = right_child
        data_index = leaf_index - self.capacity + 1
        return data_index, self.tree[leaf_index]
    
    def total_sum(self):
        return self.tree[0]
    
    def print_tree(self):
        print(self.tree)

In [36]:
sum_tree = SumTree(4)
sum_tree.update(0, 1)
sum_tree.update(1, 2)

sum_tree.get_leaf(4)

[3. 3. 0. 1. 2. 0. 0.]
Searching for leaf: idx=4, parent_index=0
Searching for leaf: idx=1.0, parent_index=2
Searching for leaf: idx=1.0, parent_index=6


(3, np.float32(0.0))

In [37]:
sum_tree.total_sum()

np.float32(3.0)

In [38]:
sum_tree.print_tree()

[3. 3. 0. 1. 2. 0. 0.]


In [None]:
class PrioritizedReplayMemory():
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha  # alpha = 0.6 (controls how much prioritization is used) 0 = no prioritization, 1 = full prioritization
        self.priorities = SumTree(capacity)
        self.buffer = []
        self.position = 0
        self.max_priority = 1.0
        
    def store(self, state, action, reward, next_state, done, priority):
        experience = (state, action, reward, next_state, done)
        
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience
        
        self.priorities.update(self.position, priority ** self.alpha)
        self.max_priority = max(self.max_priority, priority)
        self.position = (self.position + 1) % self.capacity

    def update_priority(self, index, priority):
        self.priorities.update(index, priority ** self.alpha)
        self.max_priority = max(self.max_priority, priority)
    
    def sample(self, batch_size, beta=0.4):
        indices = []
        priorities = []
        experiences = []

        # sample based on priority
        total_priority = self.priorities.total_sum()
        segment_size = total_priority / batch_size

        for i in range(batch_size):
            a = segment_size * i
            b = segment_size * (i + 1)
            value = random.uniform(a, b)

            index, priority = self.priorities.get_leaf(value)
            indices.append(index)
            priorities.append(priority)
            experiences.append(self.buffer[index])

        # calculate importance sampling weights
        weights = []
        min_prob = min(priorities) / total_priority
        max_weight = (min_prob * len(self.buffer)) ** (-beta)

        for priority in priorities:
            prob = priority / total_priority
            weight = (prob * len(self.buffer)) ** (-beta)
            weights.append(weight / max_weight)
        
        return experiences, indices, weights

    def get_max_priority(self):
        return self.max_priority

In [72]:
prm = PrioritizedReplayMemory(3)
state = 0
new_state = 1
action = 2
reward = 1
done = True
priority = 1.0
prm.store(state, action, reward, new_state, done, priority)

state = 1
new_state = 2
action = 2
reward = 0
done = False
priority = 0.1
prm.store(state, action, reward, new_state, done, priority)

prm.sample(1, beta=0.4)  # will sample first element with priority 1.0 most of the time


[1.2511886  0.25118864 1.         0.25118864 0.        ]
Searching for leaf: idx=0.9921334981918335, parent_index=0
Searching for leaf: idx=0.7409448623657227, parent_index=2


([(0, 2, 1, 1, True)], [0], [np.float32(1.0)])

## Why Alpha = 0.6 is Common
This value was found empirically to provide a good balance:

Strong enough to accelerate learning by focusing on important experiences
Not too strong to completely ignore experiences with small TD-errors
Maintains diversity in the training batch
Proven effective across many different environments in the original PER paper

### Trade-offs
Higher Alpha (closer to 1.0):

- ✅ Faster initial learning
- ✅ Focus on most informative experiences
- ❌ Risk of overfitting to high-error experiences
- ❌ May ignore important but "boring" experiences

Lower Alpha (closer to 0.0):

- ✅ More diverse sampling
- ✅ Less risk of overfitting
- ❌ Slower learning
- ❌ Less benefit from prioritization

## The Problem Beta Solves
When we use prioritized sampling instead of uniform sampling, we introduce bias into our learning algorithm. Beta corrects this bias to maintain theoretical convergence guarantees.

## What is Importance Sampling?
Imagine you're trying to estimate the average height of people in a city, but your sampling method accidentally selects tall people more often. To get the correct average, you need to down-weight the tall people's contributions. That's importance sampling!


### Beta Values and Their Effects
#### Beta = 0 (No Bias Correction)
- all samples get equal weight in loss calculation
- Fastest learning but biased
- Essentially ignores the sampling bias

#### Beta = 1 (Full Bias Correction)
- Fully corrects the sampling bias
- Theoretically sound but slower learning
- High-priority samples get heavily down-weighted

#### Beta = 0.4 → 1.0 (Annealed)

- Starts with fast, biased learning
- Gradually becomes unbiased
- Best of both worlds approach