In [1]:
# now I need to implement prioritized experience replay. The original implementation uses a sum tree.

# From ChatGPT
class SumTree:
    """Array-backed sum tree used for priority-proportional sampling.

    ``tree`` holds ``2 * capacity - 1`` numbers: the last ``capacity`` entries
    are leaves (one priority per data slot) and every internal node stores the
    sum of its two children, so ``tree[0]`` is the total priority.
    ``data`` is a ring buffer: once ``capacity`` items have been inserted, the
    next insert overwrites the oldest slot.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items."""
        self.capacity = capacity
        self.tree = [0] * (2 * capacity - 1)
        self.data = [None] * capacity
        self.write_index = 0  # next data slot to (over)write
        self.sum = 0  # mirrors tree[0]; refreshed on every insert

    def _propagate(self, index, diff):
        """Add ``diff`` to the node at ``index`` and to each of its ancestors."""
        self.tree[index] += diff
        while index != 0:
            index = (index - 1) // 2
            self.tree[index] += diff

    def _retrieve(self, index, value):
        """Walk down from ``index`` to the leaf whose priority slice contains ``value``.

        At each internal node, go left when ``value`` fits inside the left
        subtree's sum; otherwise subtract that sum and go right.
        Returns the tree index of the leaf reached.
        """
        while True:
            left = 2 * index + 1
            if left >= len(self.tree):
                # No children: this is a leaf.
                return index
            if value <= self.tree[left]:
                index = left
            else:
                value -= self.tree[left]
                index = left + 1

    def insert(self, priority, data):
        """Store ``data`` with the given ``priority`` in the next slot.

        When the buffer is full the write position wraps around, replacing the
        oldest item; its previous priority is removed from the sums because
        only the difference is propagated.
        """
        index = self.write_index + self.capacity - 1
        self.data[self.write_index] = data
        # Wrap around instead of running past the end of ``data``/``tree``.
        self.write_index = (self.write_index + 1) % self.capacity
        self._propagate(index, priority - self.tree[index])
        self.sum = self.tree[0]

    def query(self, value):
        """Return ``(data, data_index)`` for the leaf selected by ``value``.

        ``value`` is a position along the cumulative priority, expected to lie
        between 0 and the total priority (``tree[0]``).
        """
        index = self._retrieve(0, value)
        data_index = index - self.capacity + 1
        data = self.data[data_index]
        return data, data_index


In [2]:
# I need to implement my own sumtree if I want to not use 2x the entries and also be able to remove entries. 
class SumTree:
    """Work-in-progress sum tree that keeps one node per entry.

    The intent is to avoid the 2x storage of the leaf-array layout and to
    support removing entries. Only the constructor and ``_propagate`` exist so
    far; ``remove`` is a stub.
    """

    def __init__(self, capacity):
        """Allocate ``capacity`` empty node slots."""
        self.capacity = capacity
        self.tree = [None] * capacity
        self.size = 0
        self.sum = 0

        # Each occupied slot holds a mutable [value, data, sum] list (not a
        # tuple: _propagate updates the sum in place).
        # sum is the value of both sides plus the node's own value.

    def _propagate(self, index, diff):
        """Add ``diff`` to the stored sum of every ancestor of ``index``.

        The node at ``index`` itself is not touched; the caller is expected to
        have set its own entry already.
        """
        while index != 0:
            index = (index - 1) // 2
            self.tree[index][2] += diff

    def remove(self, index):
        """Remove the node at ``index`` (not implemented yet).

        Planned steps:
          1. get the value at ``index``
          2. reverse-propagate that value out of the ancestors' sums
          3. get the value of the last occupied index
          4. de-propagate that value from its ancestors
          5. move the last node into the freed spot

        Raises:
            NotImplementedError: always, until the steps above are written.
        """
        raise NotImplementedError("SumTree.remove is not implemented yet")


In [3]:
def test_sum_tree_edge_cases():
    """Smoke-test a ``SumTree`` exposing ``insert(priority, data)`` and ``query(value)``.

    ``query`` returns a ``(data, data_index)`` pair, so the assertions unpack
    that pair rather than comparing it with a bare value.
    """
    capacity = 10
    sum_tree = SumTree(capacity)

    # Test empty tree: nothing has been stored, so the selected slot has no data.
    data, _ = sum_tree.query(0)
    assert data is None

    # Test adding values
    priority = 0.1
    for i in range(capacity):
        sum_tree.insert(priority, i)

    # Test querying values. Every item owns an equal slice of the total
    # priority, so aiming at the midpoint of each slice must hit every item
    # exactly once. Midpoints keep the check robust against float rounding and
    # independent of the order in which leaves are laid out in the tree.
    sampled = {sum_tree.query(priority * (k + 0.5)) for k in range(capacity)}
    assert sampled == {(i, i) for i in range(capacity)}

    # TODO: Test removing values

    # TODO: Test inserting more values than capacity


In [4]:
# Ok, I think I need to implement this as a basic binary tree with the sum as an extra target.
# good diagrams on rotation here: https://betsybaileyy.github.io/AVL_Tree/
class BinaryTreeNode:
    """Node of a binary tree carrying a payload and a cached height.

    A new node is detached (no parent, no children) and has height 1.
    """

    def __init__(self, data):
        self.data = data
        self.height = 1
        # Links are filled in when the node is attached to a tree.
        self.left = self.right = self.parent = None

In [5]:
def height(node):
    """Return the cached height of ``node``, treating a missing node as height 0."""
    return 0 if node is None else node.height

In [None]:
def rotate_left(node):
    """
    Rotate the subtree rooted at ``node`` to the left.

      A
    D   B
       E C

       B
     A   C
    D E

    return the new root so we don't have to deal with the parent.

    ``node`` (A) must have a right child (B). B becomes the subtree root, A
    becomes B's left child, and B's former left child (E) becomes A's right
    child. The ``parent`` links of A, B and E and the cached ``height`` of A
    and B are refreshed. B inherits A's old parent link, but that parent's own
    ``left``/``right`` pointer is not modified: the caller re-attaches the
    returned node.

    Raises:
        ValueError: if ``node`` has no right child to rotate up.
    """
    a = node
    b = node.right
    if b is None:
        raise ValueError("rotate_left requires a node with a right child")

    # Hand E over from B to A.
    e = b.left
    a.right = e
    if e is not None:
        e.parent = a

    # Lift B above A.
    b.parent = a.parent
    b.left = a
    a.parent = b

    # Recompute heights bottom-up: A first, because B's height depends on it.
    for n in (a, b):
        n.height = 1 + max(
            (child.height for child in (n.left, n.right) if child is not None),
            default=0,
        )
    return b
    

In [6]:
# Alternative: keep a basic sum tree and simply update its values as new data arrives.
# Store the transitions in a regular replay buffer and keep each transition's buffer index in the sum tree;
# once the buffer starts overwriting old entries, reuse that index and propagate the new priority back up the tree.
# The same mechanism also works for raising the priority of earlier states:
# index into the tree at the right leaf and increase its priority.


In [39]:
"""
From https://github.com/rlcode/per/blob/master/SumTree.py
"""
import numpy


# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree:
    """Binary sum tree stored in a flat array, paired with a ring buffer of samples.

    ``tree`` has ``2 * capacity - 1`` entries. The last ``capacity`` entries
    are leaves holding one priority per data slot; every internal node holds
    the sum of its two children, so ``tree[0]`` is the total priority.
    ``data`` holds the samples; ``write`` is the next slot to (over)write and
    wraps around once ``capacity`` samples have been added.
    """

    # Class-level default kept so ``SumTree.write`` still resolves; every
    # instance gets its own cursor in __init__.
    write = 0

    def __init__(self, capacity):
        """Create an empty tree able to hold ``capacity`` samples."""
        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        self.data = numpy.zeros(capacity, dtype=object)
        self.n_entries = 0
        self.write = 0

    # update to the root node
    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of ``idx``, up to and including the root.

        Written as a loop that stops at the root, so calling it with
        ``idx == 0`` (the only leaf when ``capacity == 1``) is a no-op instead
        of computing a parent of -1 and recursing forever.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    # find sample on leaf node
    def _retrieve(self, idx, s):
        """Descend from ``idx`` to the leaf whose priority slice contains ``s``.

        Go left when ``s`` fits within the left child's sum; otherwise subtract
        that sum and go right. Returns the tree index of the leaf.
        """
        while True:
            left = 2 * idx + 1
            if left >= len(self.tree):
                # No children: idx is a leaf.
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1

    def total(self):
        """Return the sum of all priorities (the root of the tree)."""
        return self.tree[0]

    def data_idx(self, idx):
        """Map a tree index to a data slot, wrapping modulo ``capacity``.

        For a leaf index this is the inverse of ``tree_idx``; the modulo keeps
        the result inside ``[0, capacity)`` for neighbouring indices too.
        """
        return (idx + self.capacity + 1) % self.capacity

    def tree_idx(self, d_idx):
        """Map a data slot to the tree index of its leaf."""
        return d_idx + self.capacity - 1

    # store priority and sample
    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next slot.

        Overwrites the oldest sample once the buffer is full. Returns the tree
        index of the leaf that was written, suitable for a later ``update``.
        """
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

        if self.n_entries < self.capacity:
            self.n_entries += 1
        return idx

    # update priority
    def update(self, idx, p):
        """Set the priority of the leaf at tree index ``idx`` to ``p`` and fix up the sums."""
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        """Return ``(tree index, priority, data)`` for the leaf selected by ``s``.

        ``s`` is a position along the cumulative priority, expected to lie
        between 0 and ``total()``.
        """
        idx = self._retrieve(0, s)
        dataIdx = self.data_idx(idx)

        return (idx, self.tree[idx], self.data[dataIdx])

In [45]:
# Fill a 64-slot tree with priority i for item i, checking along the way that
# the index helpers round-trip and that the root tracks the running total.
tree = SumTree(64)
total = 0
for i in range(64):
    leaf = tree.tree_idx(i)
    assert tree.data_idx(leaf) == i
    # Neighbouring tree indices must still map to a valid data slot.
    for neighbour in (leaf + 1, leaf - 1):
        assert 0 <= tree.data_idx(neighbour) < tree.capacity
    tree.add(i, i)
    total += i
    assert tree.total() == total
total

2016

In [34]:
# Inspect the last entry of the tree array, i.e. the leaf for data slot capacity - 1.
tree.tree[-1]

63.0

In [35]:
# Sample at cumulative-priority targets 1 through 4; each result is (tree index, priority, data).
tree.get(1), tree.get(2), tree.get(3), tree.get(4), 

((64, 1.0, 1), (65, 2.0, 2), (65, 2.0, 2), (66, 3.0, 3))

In [47]:
# Root of the tree: the sum of all leaf priorities.
tree.total()

64.0

In [48]:
# Reset every leaf's priority to 1, then show the resulting total.
for i in range(64):
    leaf = tree.tree_idx(i)
    tree.update(leaf, 1)
tree.total()

64.0

In [14]:
# actually the default implementation just kinda works. How did I doubt them?
# I think I just need to figure out the default value and then it'll be fine.

In [37]:
# replay buffer. Store (s, a, r, s_n, d) tuples
class PrioritizedReplayBuffer:
    """Replay buffer of ``(s, a, r, s_n, d)`` transitions sampled in proportion to priority.

    Storage and sampling are delegated to a ``SumTree``; new transitions enter
    with priority 1 and can be re-weighted later through ``update``.
    """

    def __init__(self, max_size=1000000):
        """Create a buffer holding at most ``max_size`` transitions."""
        self.tree = SumTree(max_size)

    def add(self, s, a, r, s_n, d):
        """Store one transition with the default priority of 1."""
        self.tree.add(1, (s, a, r, s_n, d))

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` entries, each chosen with probability proportional to its priority.

        Returns a list of whatever ``self.tree.get`` yields for each draw.
        Draws are independent, so the same entry may appear more than once.
        """
        # Imported here because the notebook never imports ``random`` at the top.
        import random

        total = self.tree.total()
        return [self.tree.get(random.random() * total) for _ in range(batch_size)]

    def add_all(self, sarsd):
        """Store every transition from ``sarsd``, an iterable of ``(s, a, r, s_n, d)`` tuples."""
        for transition in sarsd:
            # Add to this buffer, not to whatever global happens to be named ``buffer``.
            self.add(*transition)

    def update(self, idx, p):
        """Set the priority of the entry at tree index ``idx`` to ``p``."""
        self.tree.update(idx, p)