In [1]:
from collections import namedtuple
import numpy as np

In [31]:
class RingBuf:
    """
    Fixed-capacity circular buffer.

    Elements are written into a preallocated list. Once ``capacity`` elements
    have been appended, every further append overwrites the oldest element.
    """

    def __init__(self, capacity):
        """
        Initialize an empty list of capacity elements to be filled up.
        :param capacity: Maximum number of elements to store before wrapping.
        :raises ValueError: if capacity is not at least 1.
        """
        self.capacity = int(capacity)
        if self.capacity <= 0:
            # A zero-slot buffer would only fail later with ZeroDivisionError
            # on the first append/lookup, so reject it up front.
            raise ValueError("capacity must be a positive integer")
        # Total number of appends so far; it is never wrapped itself, the
        # modulo is applied on use.
        self.index = 0
        self.data = [None] * self.capacity

    def __getitem__(self, items):
        """
        Returns the items inside of the buffer. Handles wraparound.
        :param items: a single integer index, or a list / tuple / numpy array
            of integer indices. Indices are taken modulo the capacity.
        :return: one element for a single index, otherwise a list of elements.
        """
        if isinstance(items, (list, tuple, np.ndarray)):
            return [self.data[idx % self.capacity] for idx in items]
        return self.data[items % self.capacity]

    def append(self, element):
        """
        Internal append that will be fronted by the specific buffer that will
        provide a nice interface. Add a new element to the list.
        Overwrites in the order that elements
        were placed on buffer.
        :param element: new element to append.
        :return: the element replaced if wrapping around or None.
        """
        idx = self.index % self.capacity
        old_ele = self.data[idx]
        self.data[idx] = element
        self.index += 1
        return old_ele

    def sample(self, num):
        """
        Randomly sample (with replacement) a set of the elements in the buffer.
        :param num: Number of elements to sample.
        :return: a list of elements.
        :raises ValueError: if elements are requested from an empty buffer.
        """
        filled = self.size
        if num > 0 and filled == 0:
            raise ValueError("cannot sample from an empty buffer")
        # Only the first `filled` slots hold real elements, so draw from those.
        return [self.data[np.random.randint(0, filled)] for _ in range(num)]

    @property
    def size(self):
        """
        While the buffer is still growing the size equals self.index. Once the
        buffer is full the size stays at capacity while new elements simply
        overwrite the old ones.
        :return: number of elements currently stored.
        """
        return min(self.index, self.capacity)

    def __len__(self):
        """Number of elements currently stored (same as ``size``)."""
        return self.size



In [32]:
class WeightedBuf():
    """
    This class is a buffer which can hold elements, but with weighted selection.
    Can contain any element so long as it has a weight property which is a
    non-negative numeric value (ele.weight).

    Implemented as a binary tree where each internal node holds the sum of the
    weights of its children. ``self.tree`` is a list of levels: ``tree[0]`` is
    the root (a single entry), every following internal level is twice as wide,
    and ``tree[-1]`` is the RingBuf holding the elements themselves (leaves).
    The parent of entry ``i`` on one level is entry ``i // 2`` on the level
    above.
    """

    def __init__(self, capacity):
        """
        :param capacity: maximum number of elements held before the oldest
            elements start being overwritten.
        """
        self.capacity = int(capacity)
        # Total number of appends so far; modulo capacity it is the leaf slot
        # that the next append will write to.
        self.index = 0
        self.tree = self.make_tree(self.capacity)

    def __getitem__(self, idx):
        """
        Retrieve an item from the underlying buffer.
        :param idx: idx of the element (or a list of indices).
        :return: whatever the leaf buffer returns for idx.
        """
        return self.tree[-1][idx]

    def append(self, ele):
        """
        Appends an element to the buffer and updates the weights of its
        parent nodes.
        :param ele: element to store; must expose a numeric ``weight``.
        :return: old element (None while the buffer is still filling up).
        """
        slot = self.index % self.capacity
        old_ele = self.tree[-1].append(ele)
        # An empty slot contributed nothing to the sums, so the whole weight
        # of the new element is the change; otherwise only the difference is.
        old_weight = 0 if old_ele is None else old_ele.weight
        # Only the ancestors change here: the leaf already holds `ele` with
        # its own weight, so it must not be adjusted by the delta as well.
        self._propagate(slot, ele.weight - old_weight)
        self.index += 1
        return old_ele

    def make_tree(self, capacity):
        """
        Create a tree with all weights initialized to 0.
        :param capacity: number of leaves.
        :return: list of levels, internal levels as numpy arrays of widths
            1, 2, 4, ... followed by the RingBuf of leaves.
        """
        # Always create the root so the total weight exists even when
        # capacity is 1.
        tree = [np.zeros(1)]
        width = 2
        while width < capacity:
            tree.append(np.zeros(width))
            width *= 2
        tree.append(RingBuf(capacity))
        return tree

    def get_leaf(self):
        """
        Randomly select a leaf from the tree based on the weights in the tree.
        Returns the id of the leaf (the index).
        :raises ValueError: if the total weight in the tree is not positive.
        """
        total = self.tree[0][0]
        if total <= 0:
            raise ValueError("cannot select a leaf: total weight is not positive")
        # A continuous draw works for integer and fractional weights alike.
        val = np.random.uniform(0, total)
        # `idx` is always the index of the LEFT child (on the next level down)
        # of the node chosen so far.
        idx = 0
        for depth in range(1, len(self.tree) - 1):
            left_weight = self.tree[depth][idx]
            if val >= left_weight:
                # Descend into the right sibling; its left child is next.
                val -= left_weight
                idx = (idx + 1) * 2
            else:
                idx *= 2
        leaves = self.tree[-1]
        stored = len(leaves)
        if idx + 1 < stored and val >= leaves[idx].weight:
            idx += 1
        # Accumulated floating point error in the sums could in principle
        # steer the walk past the last stored element; keep the index valid.
        return min(idx, stored - 1)

    def update_weight(self, idx, delta):
        """
        Change the weight of a stored element and go up each row updating the
        parent nodes with the new weight.
        :param idx: index of the changed experience.
        :param delta: change in weights for that index.
        """
        # The leaf buffer wraps indices, the internal levels do not.
        idx %= self.capacity
        self[idx].weight += delta
        self._propagate(idx, delta)

    def _propagate(self, idx, delta):
        """Add delta to every ancestor of leaf idx, from its parent to the root."""
        for depth in range(len(self.tree) - 2, -1, -1):
            idx //= 2
            self.tree[depth][idx] += delta

    def sample(self, num):
        """
        Sample a number of unique experiences. Shouldn't cause too much change
        in the effective weights since we are assuming the batch_size is much
        smaller than the experience buffer.
        :param num: how many experiences to sample. Elements with zero weight
            are never drawn, so num must not exceed the number of stored
            elements that have a positive weight or this will not terminate.
        :return: a unique set of indices for the leaves in the tree, as a list.
        :raises ValueError: if num exceeds the number of stored elements.
        """
        if num > len(self.tree[-1]):
            raise ValueError("cannot sample more unique elements than are stored")
        idxs = set()
        while len(idxs) < num:
            idxs.add(self.get_leaf())
        return list(idxs)


In [37]:
# Push four values through a three-slot buffer: the fourth append wraps
# around and overwrites slot 0.
rbuf = RingBuf(3)
for value in range(4):
    rbuf.append(value)
# Index 3 wraps back onto slot 0.
rbuf[[0, 1, 2, 3]]

[3, 1, 2, 3]

In [45]:
class W():
    """Minimal element carrying a ``weight`` and a ``name``."""

    def __init__(self, weight, name):
        """
        :param weight: numeric weight of the element.
        :param name: label used when printing the element.
        """
        self.weight = weight
        self.name = name

    def __str__(self):
        """Return the element as ``weight=<weight> name=<name>``."""
        # Must read the instance attributes; the bare names `weight` and
        # `name` are not defined in this scope.
        return 'weight=' + str(self.weight) + ' name=' + str(self.name)

In [47]:
# Push four weighted elements through a three-slot weighted buffer; the
# fourth one ('d') wraps around and replaces the first ('a').
wbuf = WeightedBuf(3)
for entry in (W(10, 'a'), W(0, 'b'), W(1, 'c'), W(10, 'd')):
    wbuf.append(entry)
# Index 3 wraps back onto slot 0.
wbuf[[0, 1, 2, 3]]

[<__main__.W object at 0x7f7166895d68>, <__main__.W object at 0x7f7166895cf8>, <__main__.W object at 0x7f71945e0cc0>, <__main__.W object at 0x7f7166895d68>]


True