In [131]:
import numpy as np
from collections import defaultdict
from random import sample

In [132]:
def simple_hist(iterable):
    hist = defaultdict(int)
    for i in iterable:
        hist[i] += 1
    return dict(hist)

In [158]:
class RingBuf:
    def __init__(self, capacity):
        """
        Initialize an empty list of capacity elements to be filled up.
        :param capacity: Maximum number of elements to store before wrapping.
        """
        self.capacity = int(capacity)
        self.index = 0
        self.data = [None] * self.capacity

    def __getitem__(self, items):
        """
        Returns the items inside of the buffer. Handles wraparound.
        """
        if isinstance(items, list):
            return [self.data[idx % self.capacity] for idx in items]
        return self.data[items % self.capacity]

    def append(self, element):
        """
        Internal append that will be fronted by the specific buffer that will
        provide a nice interface. Add a new element to the list.
        Overwrites in the order that elements
        were placed on buffer.
        :param element: new element to append.
        :return: the element replaced if wrapping around or None.
        """
        idx = self.index % self.capacity
        old_ele = self.data[idx]
        self.data[idx] = element
        self.index += 1
        return old_ele

    def sample(self, num):
        """
        Randomly sample a set of the elements in the buffer.
        :param num: Number of elements to sample.
        :return: a list of elements.
        """
        ids = sample(range(self.size), num)
        return [self.data[i] for i in ids]

    @property
    def size(self):
        """
        While the buffer is still growing self.index. Once the buffer is full
        the length stays stagnant at capacity while new elements simple overwrite
        the old ones.
        :return:
        """
        return min(self.index, self.capacity)

    def __len__(self):
        return self.size



In [159]:
rbuf = RingBuf(3)
rbuf.append('a')
rbuf.append('b')
rbuf.append('c')
rbuf.append('d')
rbuf[[0, 1, 2, 3]]

['d', 'b', 'c', 'd']

In [160]:
# 'a' should never appear since it should be overwritten.
# All others should be around 1/3 since sampling is uniform
eles = list()
for i in range(900):
    eles += rbuf.sample(1)
simple_hist(eles)

{'b': 313, 'c': 301, 'd': 286}

In [183]:
class WeightedBuf():
    """
    This class is a buffer which can hold elements, but with weighted selection.
    Can ontain any element so long as it has a weight property which is a
    numeric value (ele.weight).

    Implemented as a binary tree where each node holds the sum of the weights
    of its children.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.index = 0
        self.tree = self.make_tree(capacity)

    def __getitem__(self, idx):
        """
        Retrieve an item from the underlying buffer.
        :param idx: idx of the element.
        :return:
        """
        return self.tree[-1][idx]

    def append(self, ele):
        """
        Appends an element to the buffer and updates the weights of its
        parent nodes.
        :param ele:
        :return: old element
        """
        weight = ele.weight
        ele.weight = 0  # Will be set in update_weight
        old_ele = self.tree[-1].append(ele)
        
        self.update_weight(self.index, weight)
        
        self.index += 1
        return old_ele

    def make_tree(self, capacity):
        """
        Create a tree with all weights initialized to 0.
        """
        c = 1
        tree = []
        while c < capacity:
            tree.append(np.zeros(c))
            c *= 2
        tree.append(RingBuf(capacity))
        return tree

    def get_leaf(self):
        """
        Randomly select a leaf from the tree based on the weights in the tree.
        Returns the id of the leaf (the index).
        """
        val = np.random.randint(0, self.tree[0][0])
        idx = 0
        for depth in range(1, len(self.tree) - 1):
            left_weight = self.tree[depth][idx]
            if val >= left_weight:
                val -= left_weight
                idx = (idx + 1) * 2
            else:
                idx *= 2
        left_weight = self[idx].weight
        return idx + (val >= left_weight)

    def update_weight(self, index, weight):
        """
        Resets the weight of the element identified by index, then
        goes up each row updating the parent nodes with the new weight.
        :param index: index of the changed experience.
        :param weight: new weight to give the element.
        """
        idx = index % self.capacity
        delta = weight - self[idx].weight
        self[idx].weight = weight
        idx //= 2
        for depth in range(-2, -len(self.tree) - 1, -1):
            self.tree[depth][idx] += delta
            idx //= 2

    def sample(self, num):
        """
        Sample a number of unique experiences. Shouldn't cause too much change
        in the effective weights since we are assuming the batch_size is much
        smaller than the experience buffer.
        :param num: how many experiences to sample.
        Returns a unique set of indices for the leaves in the tree.
        """
        idxs = set()
        while len(idxs) < num:
            idxs.add(self.get_leaf())
        return list(idxs)

In [187]:
class W():
    def __init__(self, weight, name):
        self.weight = weight
        self.name = name
    def __str__(self):
        return 'weight=' + str(self.weight) + ' name=' + str(self.name)

In [188]:
wbuf = WeightedBuf(3)
wbuf.append(W(10, 'a'))
wbuf.append(W(0, 'b'))
wbuf.append(W(1, 'c'))
wbuf.append(W(10, 'd'))
print(wbuf.tree)
[str(w) for w in wbuf[[0, 1, 2, 3]]]

[array([11.]), array([10.,  1.]), <__main__.RingBuf object at 0x7f716681c9b0>]


['weight=10 name=d', 'weight=0 name=b', 'weight=1 name=c', 'weight=10 name=d']

In [189]:
# 'a' should never appear since it should be overwritten.
# All others should appear in proportion to their weights
ids = list()
for i in range(1100):
    ids += wbuf.sample(1)
simple_hist([w.name for w in wbuf[ids]])

{'c': 99, 'd': 1001}

In [190]:
wbuf.update_weight(0, 3)
ids = list()
for i in range(1100):
    ids += wbuf.sample(1)
simple_hist([w.name for w in wbuf[ids]])

{'c': 73, 'd': 1027}