Prioritised experience replay (PER) implementation from paper: https://arxiv.org/pdf/1511.05952.pdf

Nice PER blog post: https://adventuresinmachinelearning.com/prioritised-experience-replay/
SumTree blog post: https://adventuresinmachinelearning.com/sumtree-introduction-python/

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
import seaborn as sns
import copy
import numpy as np

First, let's implement a SumTree: a binary tree that supports weighted sampling and weight updates in O(log n) time.

Leaf nodes -> the weight of each sample that can be drawn, i.e. if we have 10 possible samples, then we will have 10 leaf nodes.

Each leaf node represents a unique sample, and its weight is that sample's sampling weight. Each parent node's weight is the sum of its children's weights. Therefore, the weight of the root node is the sum of all sample (leaf node) weights.

In [None]:
class SumTree:
    '''Binary sum tree for weighted sampling, stored as a networkx DiGraph.

    Leaf nodes hold one sample each (node attribute ``value``) together with its
    sampling ``weight``. Every internal node's ``weight`` is the sum of its two
    children's weights, so the root's weight is the total weight of all leaves.
    Edges point from parent to child.

    Args:
        leaf_values: the samples to store, one per leaf.
        leaf_weights: the sampling weight of each sample (same length as
            ``leaf_values``).
        debug_mode: if True, print a trace of every tree traversal.
    '''
    def __init__(self, leaf_values, leaf_weights, debug_mode=False):
        self.debug_mode = debug_mode
        self.init_sum_tree(leaf_values, leaf_weights)

    def init_sum_tree(self, leaf_values, leaf_weights):
        '''(Re)build the tree from scratch.

        Leaves get ids ``0 .. len(leaf_values) - 1``; internal nodes get
        consecutive ids after that, built bottom-up by pairing adjacent nodes.
        When a level has an odd number of nodes, its last node is carried up
        unpaired to the next level.

        Raises:
            ValueError: if there are no leaves, or if the number of values and
                weights differ.
        '''
        leaf_values = list(leaf_values)
        leaf_weights = list(leaf_weights)
        if not leaf_values:
            raise ValueError('SumTree requires at least one leaf')
        if len(leaf_values) != len(leaf_weights):
            raise ValueError(
                f'got {len(leaf_values)} leaf values but {len(leaf_weights)} leaf weights')

        self.tree = nx.DiGraph()

        # add leaf nodes
        leaf_ids = list(range(len(leaf_values)))
        for _id, v, w in zip(leaf_ids, leaf_values, leaf_weights):
            self.tree.add_node(_id, value=v, weight=w, is_leaf=True)

        child_ids = list(leaf_ids)
        last_idx = child_ids[-1]
        while len(child_ids) > 1:
            # with an odd count, zip() below drops the last node, so carry it up
            left_over_node = child_ids[-1] if len(child_ids) % 2 != 0 else None
            inodes = iter(child_ids)
            child_ids = []  # parents created at this level become the next level's children
            for left, right in zip(inodes, inodes):
                last_idx += 1
                parent_id = last_idx
                child_ids.append(parent_id)
                self.tree.add_node(parent_id,
                                   value=None,
                                   weight=self.tree.nodes[left]['weight'] + self.tree.nodes[right]['weight'],
                                   is_leaf=False)
                # retrieve() treats the first successor as the left child, relying on
                # networkx keeping edges in insertion order
                self.tree.add_edge(parent_id, left)
                self.tree.add_edge(parent_id, right)
            if left_over_node is not None:
                child_ids.append(left_over_node)

        self.tree.graph['root_id'] = child_ids[-1]
        self.tree.graph['leaf_ids'] = leaf_ids

    def retrieve(self, weight, node_id):
        '''Walk down from ``node_id`` to the leaf that ``weight`` falls into.

        ``weight`` is an offset into the cumulative weight of ``node_id``'s
        subtree. At each internal node we go left if ``weight`` is strictly
        below the left child's weight; otherwise we subtract the left child's
        weight and go right. The strict comparison means a zero-weight leaf is
        never selected for an offset in ``[0, total)``.

        Returns:
            the id of the selected leaf node.
        '''
        while not self.tree.nodes[node_id]['is_leaf']:
            if self.debug_mode:
                print(f'curr node: {node_id} | weight to consider: {weight}')

            # get children of current node (left first, see init_sum_tree)
            children = list(self.tree.successors(node_id))
            if self.debug_mode:
                print(f'children: {children}')

            left_weight = self.tree.nodes[children[0]]['weight']
            if weight < left_weight:
                # keep weight same, traverse to left-hand child
                if self.debug_mode:
                    print(f'traverse to LHS child w/ weight to consider: {weight}')
                node_id = children[0]
            else:
                # subtract left-hand child's weight, traverse to right-hand child
                weight -= left_weight
                if self.debug_mode:
                    print(f'traverse to RHS child w/ updated weight to consider: {weight}')
                node_id = children[1]

        if self.debug_mode:
            print(f'reached leaf node {node_id}')
        return node_id

    def update(self, node_id, new_weight):
        '''Set the weight of leaf ``node_id`` to ``new_weight`` and fix up all ancestor weights.'''
        change = new_weight - self.tree.nodes[node_id]['weight']
        self.tree.nodes[node_id]['weight'] = new_weight
        parents = list(self.tree.predecessors(node_id))
        if parents:
            # a single-leaf tree has no parent: the leaf is the root
            self.propagate_changes(change, parents[0])

    def propagate_changes(self, change, node_id):
        '''Add ``change`` to the weight of ``node_id`` and of every ancestor up to the root.'''
        while True:
            self.tree.nodes[node_id]['weight'] += change
            parents = list(self.tree.predecessors(node_id))
            if not parents:
                # node_id is the root node, no further change to propagate
                return
            node_id = parents[0]

    def sample(self):
        '''Return a leaf id drawn with probability proportional to its weight.'''
        # get total tree weight
        total_tree_sum = self.tree.nodes[self.tree.graph['root_id']]['weight']
        if self.debug_mode:
            print(f'tree sum: {total_tree_sum}')

        # choose random offset in [0, total_tree_sum)
        random_weight = np.random.uniform(0, total_tree_sum)
        if self.debug_mode:
            print(f'randomly sampled weight: {random_weight}')

        # traverse tree from root node to retrieve node id of this randomly sampled weight in O(log(n))
        return self.retrieve(random_weight, self.tree.graph['root_id'])

    def render(self):
        '''Draw the tree with graphviz's ``dot`` layout (via pydot), each node labelled by its id.'''
        plt.figure()

        pos = graphviz_layout(self.tree, prog='dot')
        node_labels = {node: node for node in self.tree.nodes}
        # draw_networkx_nodes' `label` is a legend label, so node ids are drawn separately
        nx.draw_networkx_nodes(self.tree, pos)
        nx.draw_networkx_edges(self.tree, pos)
        nx.draw_networkx_labels(self.tree, pos, labels=node_labels)

        plt.show()

Consider we have 4 possible samples (214, 342, 42, and 123) with weights 1, 4, 2, and 3 respectively:

In [None]:
samples = [214, 342, 42, 123]
weights = [1, 4, 2, 3]
# pair each sample with its sampling weight (order of `samples` is preserved)
sample_to_weight = dict(zip(samples, weights))
print(f'Possible samples and their corresponding sampling weight: {sample_to_weight}')

In [None]:
sum_tree = SumTree(samples, weights)
sum_tree.render()

# dump every node's stored sample and weight, then the tree-level metadata
graph = sum_tree.tree
for node, attrs in graph.nodes(data=True):
    print(f'Node {node} value: {attrs["value"]} | weight: {attrs["weight"]}')
print(f'Root node id: {graph.graph["root_id"]}')
print(f'Leaf node ids: {graph.graph["leaf_ids"]}')

In [None]:
# sample tree many times; each leaf should be selected with probability
# (leaf weight) / (total weight)
selected_nodes = [sum_tree.sample() for _ in range(10000)]

# histplot returns the Axes it drew on; discrete=True gives one bar centred on
# each integer leaf id instead of continuous auto-binning
ax = sns.histplot(selected_nodes, stat='probability', discrete=True)
ax.set_xlabel('Leaf node id')
ax.set_ylabel('Probability')
plt.show()

In [None]:
# same experiment with 3 samples and random weights in [0, 1)
samples = range(3)
weights = [np.random.uniform() for _ in range(len(samples))]
sample_to_weight = {sample: weight for sample, weight in zip(samples, weights)}
print(f'Possible samples and their corresponding sampling weight: {sample_to_weight}')

sum_tree = SumTree(samples, weights, debug_mode=False)
sum_tree.render()
for node in sum_tree.tree.nodes:
    print(f'Node {node} value: {sum_tree.tree.nodes[node]["value"]} | weight: {sum_tree.tree.nodes[node]["weight"]}')
print(f'Root node id: {sum_tree.tree.graph["root_id"]}')
print(f'Leaf node ids: {sum_tree.tree.graph["leaf_ids"]}')

# sample tree; empirical frequencies should approach weight / sum(weights)
selected_nodes = [sum_tree.sample() for _ in range(10000)]

# one bar per integer leaf id (histplot returns the Axes it drew on)
ax = sns.histplot(selected_nodes, stat='probability', discrete=True)
ax.set_xlabel('Leaf node id')
ax.set_ylabel('Probability')
plt.show()