In [1]:
import numpy as np
from collections import defaultdict
from RingBuffer import RingBuf, WeightedRingBuf

In [2]:
def make_hist(iterable):
    """Count how many times each distinct item occurs in ``iterable``.

    Returns a plain ``dict`` mapping each item to its occurrence count.
    Items must be hashable because they are used as dictionary keys.
    """
    counts = {}
    for item in iterable:
        counts[item] = counts.get(item, 0) + 1
    return counts

### Test plain ring buffer.

Test that the RingBuf overwrites entries; 'd' should replace 'a'. 



In [3]:
# Buffer built with size argument 3; a fourth append is then expected to
# overwrite the oldest entry ('a').
rbuf = RingBuf(3)
for letter in ('a', 'b', 'c'):
    rbuf.append(letter)
# Kept as the cell's last expression so its return value is displayed.
rbuf.append('d')

(0, 'a')

Check basic interactions with the RingBuf:
- iteration
- getitem with: list, tuple, set, slice, index

In [4]:
# Iteration: after 'd' replaced 'a', iteration should yield 'd', 'b', 'c'.
expected_order = ('d', 'b', 'c')
for i, ele in enumerate(rbuf):
    # str(ele) keeps the failure message usable for non-string elements.
    assert ele == expected_order[i],\
        "At index: " + str(i) + " Got ele: " + str(ele)

# single index getitem
assert rbuf[0] == 'd', "Failed single index lookup"

# getitem with list
assert rbuf[[0, 1]] == ['d', 'b'], "Failed list of indices lookup"

# getitem with tuple
assert rbuf[(1, 2)] == ['b', 'c'], "Failed tuple of indices lookup"

# getitem with set (the literal {1, 1, 1} collapses to the single index 1)
assert rbuf[{1, 1, 1}] == ['b'], "Failed set lookup"

# getitem with slice
assert rbuf[::2] == ['d', 'c'], "Failed slice lookup"

Check list accessing with wraparound indexing.

In [5]:
# Index 3 is one past the last slot of the 3-element buffer; the expected
# result below shows it mapping back to slot 0 ('d').
res = rbuf[list(range(4))]
print(res)
assert res == ['d', 'b', 'c', 'd'], 'ERROR'

['d', 'b', 'c', 'd']


Try to sample from the buffer. Do this many times to see that each element appears about the same number of times.

In [6]:
# 900 independent single-element draws, flattened into one list.
eles = [ele for _ in range(900) for ele in rbuf.sample(1)]
hist = make_hist(eles)
print(hist)
# Three elements over 900 draws: each should appear roughly 300 times.
for count in hist.values():
    assert abs(count - 300) < 50, "ERROR: Sampling seems non-uniform."

{'d': 334, 'c': 278, 'b': 288}


### Test weighted ring buffer.

Requires an element class that has a writable `weight` attribute.

In [7]:
class W():
    """Minimal test element: a named item carrying a writable ``weight``."""

    def __init__(self, weight, name):
        self.name = name
        self.weight = weight

    def __str__(self):
        # %s applies str() to each value, matching plain concatenation.
        return 'weight=%s name=%s' % (self.weight, self.name)

# (weight, name) pairs for the test elements; 'b' and 'i' get a weight of 0.
w_list = tuple(
    W(weight, name) for weight, name in (
        (.01, 'a'), (0, 'b'), (.5, 'c'), (3.4, 'd'), (.4, 'e'),
        (2, 'f'), (1.1, 'g'), (.1, 'h'), (0, 'i'), (5, 'j'),
    )
)

Test that the WeightedRingBuf overwrites entries; 'j' should replace 'a'.

Also show that it can handle list accessing with wraparound indexing.

In [8]:
# Ten elements go into a buffer built with size argument 9, so the last
# one ('j') is expected to overwrite the first ('a').
wbuf = WeightedRingBuf(9)
for ele in w_list:
    wbuf.append(ele)

# NOTE(review): 'b' and 'i' have weight 0, yet .1 is expected here, so
# min_weight presumably ignores zero weights -- confirm in WeightedRingBuf.
assert wbuf.min_weight == .1, wbuf.min_weight

# Check that the wbuf contains ['j', 'b', 'c', ..., 'i']
expected_eles = [w_list[-1], *w_list[1:(len(w_list)-1)]]
# The name lists are wrapped in str(): concatenating a list to a str would
# raise TypeError and hide the real assertion failure.
assert wbuf[:] == expected_eles,\
    'expected: ' + str([ele.name for ele in expected_eles]) +\
    " actual: " + str([ele.name for ele in wbuf[:]])

# Check that the total_weight property is correct. Float math so
# tolerance of 1 trillionth
net_weight = sum(ele.weight for ele in wbuf)
assert abs(net_weight - wbuf.total_weight) < 1e-12,\
    'total_weight=' + str(wbuf.total_weight) +\
    ' net_weight=' + str(net_weight)
expected_net_weight = sum(ele.weight for ele in expected_eles)
assert abs(expected_net_weight - wbuf.total_weight) < 1e-12,\
    'total_weight=' + str(wbuf.total_weight) +\
    ' expected_net_weight=' + str(expected_net_weight)
print('total_weight =', wbuf.total_weight)

total_weight = 12.5


Test that sampling is done in proportion to the weights of the elements. Sampling returns a unique set of ids, so we repeatedly sample.

In [9]:
# Draw n_samples ids in batches of batch_size and tally the sampled names.
ids = list()
n_samples = 1000
batch_size = 2
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"
for i in range(n_samples // batch_size):
    ids += wbuf.sample(batch_size)
hist = make_hist([w.name for w in wbuf[ids]])

assert 'a' not in hist.keys(), "'a' should be overwritten"

# An earlier cell confirmed that wbuf has the elements we want
for ele in wbuf:
    if ele.name in {'b', 'i'}:
        assert ele.name not in hist.keys(),\
            "'b' and 'i' have 0 weight so shouldn't be sampled." +\
            str(hist)
        continue

    # .get(..., 0): an element that was never drawn counts as 0 samples
    # rather than raising KeyError before the tolerance check can run.
    count = hist.get(ele.name, 0)
    weight = ele.weight
    # Observed sampling frequency vs. the element's share of total weight.
    assert abs(count/n_samples - weight/wbuf.total_weight) < .1,\
        'Improper sampling for W(' + ele.name + ', ' +\
        str(ele.weight) + ') with total_weight=' +\
        str(wbuf.total_weight) + ': ' + str(hist)

print(hist)

{'g': 111, 'h': 9, 'e': 47, 'd': 269, 'j': 342, 'c': 47, 'f': 175}


`sample_n_subsets` should generally produce a similar distribution to `sample`, as long as no single weight represents a significant portion of `total_weight`.

In [10]:
# Same proportionality check as for sample(), but using sample_n_subsets().
ids = list()
n_samples = 1000
batch_size = 2
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"
for i in range(n_samples // batch_size):
    ids += wbuf.sample_n_subsets(batch_size)
hist = make_hist([w.name for w in wbuf[ids]])

assert 'a' not in hist.keys(), "'a' should be overwritten"

# An earlier cell confirmed that wbuf has the elements we want
for ele in wbuf:
    if ele.name in {'b', 'i'}:
        assert ele.name not in hist.keys(),\
            "'b' and 'i' have 0 weight so shouldn't be sampled." +\
            str(hist)
        continue

    # .get(..., 0): an element that was never drawn counts as 0 samples
    # rather than raising KeyError before the tolerance check can run.
    count = hist.get(ele.name, 0)
    weight = ele.weight
    # Observed sampling frequency vs. the element's share of total weight.
    assert abs(count/n_samples - weight/wbuf.total_weight) < .1,\
        'Improper sampling for W(' + ele.name + ', ' +\
        str(ele.weight) + ') with total_weight=' +\
        str(wbuf.total_weight) + ': ' + str(hist)

print(hist)

{'c': 48, 'h': 5, 'e': 30, 'd': 236, 'j': 400, 'g': 101, 'f': 180}


Test Exclusion

In [11]:
# 'b' and 'i' have 0 weight.
# Manually exclude indices 0, 3 and 7, i.e. 'j', 'd' and 'h'.
# n_samples is reused from the earlier sampling cells.
ids = [idx
       for _ in range(n_samples)
       for idx in wbuf.sample(1, exclude={0, 3, 7})]
hist = make_hist([w.name for w in wbuf[ids]])

# Only the remaining non-zero-weight names may show up.
assert set(hist) == {'c', 'e', 'f', 'g'},\
    'Exclusion failure: ' + str(hist)
print(hist)

{'c': 111, 'g': 276, 'f': 510, 'e': 103}


Change the weight of 'j' to be over 50% of the total_weight and check that the tree gets updated. This mutates the tree and throws off other tests since it has an extreme outlier.

In [12]:
# Remember the total before the update so the doubling can be verified.
old_total_weight = wbuf.total_weight
# Raise slot 0's weight by the current total; the overall total should
# then be exactly twice what it was.
boosted_weight = wbuf[0].weight + wbuf.total_weight
wbuf.update_weight(0, boosted_weight)
total_error = abs(2*old_total_weight - wbuf.total_weight)
assert total_error < 1e-12,\
    'total_weight=' + str(wbuf.total_weight) +\
    ' old_total_weight=' + str(old_total_weight)
print('total_weight =', wbuf.total_weight)

total_weight = 25.0


To test `sample_n_subsets`, now check that when sampling with a batch_size of 2, 'j' is selected every time, since it represents the entire first half of the weights. I feel like I may be testing implementation details and not the public interface...

In [13]:
# Number of size-2 batches; n_samples is reused from the earlier cells.
n_batches = n_samples // 2
ids = list()
for i in range(n_batches):
    ids += wbuf.sample_n_subsets(2)
hist = make_hist([w.name for w in wbuf[ids]])

# 'j' is expected exactly once per batch. The expected count is derived
# from n_samples instead of a hard-coded 500, and .get makes a missing
# 'j' fail the assertion (with the histogram) instead of raising KeyError.
assert hist.get('j', 0) == n_batches,\
    "'j' should be drawn once per batch: " + str(hist)
print(hist)

{'g': 84, 'h': 6, 'e': 25, 'd': 207, 'j': 500, 'c': 27, 'f': 151}


For regular sampling it is possible for neither element to be 'j', so the count of 'j' cannot possibly be more than 500, but it can be less. (This could fail, but with large enough numbers it probably won't be selected in every single batch...)

In [14]:
# n_samples // 2 batches of two ids each, flattened into one list.
ids = [idx for _ in range(n_samples // 2) for idx in wbuf.sample(2)]
hist = make_hist([w.name for w in wbuf[ids]])

# 'j' can be drawn at most once per batch, so the number of batches is
# an upper bound on its count.
assert n_samples // 2 >= hist['j']
print(hist)

{'g': 72, 'h': 6, 'e': 37, 'd': 233, 'j': 467, 'c': 45, 'f': 140}


In [15]:
# Give slots 1 and 8 (previously the zero-weight 'b' and 'i') non-zero
# weights, then check that the minimum weight is still .1.
wbuf.update_weight(1, 3)
wbuf.update_weight(8, 4)
# The message previously referenced the undefined name `wbug`, which would
# raise NameError on failure and hide the diagnostic.
assert wbuf.min_weight == .1,\
    'min_weight=' + str(wbuf.min_weight) +\
    ' all_weights=' + str([ele.weight for ele in wbuf])
wbuf.min_weight

0.1