In [11]:
from collections import Counter, defaultdict
from random import sample

import numpy as np

from ExperienceBuffer import RingBuf, WeightedRingBuf

In [12]:
def simple_hist(iterable) -> dict:
    """Count how many times each distinct value occurs in ``iterable``.

    Args:
        iterable: Any iterable of hashable values.

    Returns:
        A plain ``dict`` mapping each value seen to its occurrence count.
        Values that never occur are absent rather than mapped to 0.
    """
    # Counter is the stdlib tally; convert back to a plain dict so the
    # displayed result looks the same as before.
    return dict(Counter(iterable))

In [13]:
# Capacity-3 ring buffer fed four items, so the fourth append wraps around
# and overwrites the first ('a').
rbuf = RingBuf(3)
for item in 'abcd':
    rbuf.append(item)
# Index 3 is past the capacity; the recorded output returns 'd' there,
# consistent with modulo wrap-around (TODO confirm in RingBuf.__getitem__).
rbuf[[0, 1, 2, 3]]

['d', 'b', 'c', 'd']

In [14]:
# 'a' must never appear: it was overwritten when 'd' was appended to the
# 3-slot buffer. If sampling is uniform, each of the 3 live elements should
# show up around N_UNIFORM_DRAWS / 3 = 300 times.
N_UNIFORM_DRAWS = 900
eles = [e for _ in range(N_UNIFORM_DRAWS) for e in rbuf.sample(1)]
ring_hist = simple_hist(eles)
assert 'a' not in ring_hist, "overwritten element 'a' was sampled"
ring_hist

{'b': 325, 'c': 282, 'd': 293}

In [15]:
class W:
    """Minimal test item carrying a ``weight`` and a ``name``.

    Instances are appended to a WeightedRingBuf in the cells below
    (presumably the buffer reads ``.weight`` for sampling — confirm).

    Attributes:
        weight: Numeric weight given at construction.
        name: Label used to identify the item in printed output.
    """

    def __init__(self, weight, name):
        self.name = name
        self.weight = weight

    def __str__(self):
        # %s applies str() to each field, matching 'weight=<w> name=<n>'.
        return 'weight=%s name=%s' % (self.weight, self.name)

In [17]:
# Capacity-3 weighted buffer; the fourth append ('d') overwrites 'a'.
wbuf = WeightedRingBuf(3)
for weight, name in [(10, 'a'), (0, 'b'), (1, 'c'), (10, 'd')]:
    wbuf.append(W(weight, name))
print(wbuf.tree)
list(map(str, wbuf[[0, 1, 2, 3]]))

[array([11.]), array([10.,  1.]), <ExperienceBuffer.RingBuf object at 0x7f82442f1ac8>]


['weight=10 name=d', 'weight=0 name=b', 'weight=1 name=c', 'weight=10 name=d']

In [18]:
# 'a' must never appear: it was overwritten by 'd'. The live weights are
# d=10, c=1, b=0 (total 11), so if sampling is proportional to weight,
# N_WEIGHTED_DRAWS = 1100 draws should give d ~ 1000, c ~ 100 and 'b'
# essentially never.
N_WEIGHTED_DRAWS = 1100
ids = [idx for _ in range(N_WEIGHTED_DRAWS) for idx in wbuf.sample(1)]
weighted_hist = simple_hist([w.name for w in wbuf[ids]])
assert 'a' not in weighted_hist, "overwritten element 'a' was sampled"
weighted_hist

{'c': 93, 'd': 1007}

In [19]:
# Re-weight slot 0 (currently 'd', weight 10) down to 3, then show the
# internal tree and the items as stored.
slot, new_weight = 0, 3
wbuf.update_weight(slot, new_weight)
print(wbuf.tree)
list(map(str, wbuf[[0, 1, 2, 3]]))

[array([4.]), array([3., 1.]), <ExperienceBuffer.RingBuf object at 0x7f82442f1ac8>]


['weight=3 name=d', 'weight=0 name=b', 'weight=1 name=c', 'weight=3 name=d']

In [20]:
# After update_weight(0, 3) the live weights are d=3, c=1, b=0 (total 4), so
# if sampling is proportional to weight, N_UPDATED_DRAWS = 400 draws should
# give d ~ 300, c ~ 100; 'a' was overwritten and must never appear.
N_UPDATED_DRAWS = 400
ids = [idx for _ in range(N_UPDATED_DRAWS) for idx in wbuf.sample(1)]
updated_hist = simple_hist([w.name for w in wbuf[ids]])
assert 'a' not in updated_hist, "overwritten element 'a' was sampled"
updated_hist

{'c': 94, 'd': 306}