In [1]:
from ExperienceBuffer import ExpBuf, WeightedExpBuf
from collections import defaultdict
import random

In [2]:
def make_hist(iterable, hasher = hash):
    """Tally how many items of `iterable` fall under each key.

    Args:
        iterable: the items to count.
        hasher: callable mapping an item to its (hashable) bucket key.
            Defaults to the builtin `hash`.

    Returns:
        A plain `dict` from bucket key to occurrence count, with keys in
        first-seen order.
    """
    counts = {}
    for item in iterable:
        key = hasher(item)
        counts[key] = counts.get(key, 0) + 1
    return counts

In [3]:
# Hand-written fixture, one row per transition:
#   (state, action, reward, next_state, is_terminal)
# zip() over the rows transposes them into the five per-field lists that the
# test cells iterate over.
states, actions, rewards, next_states, is_terminals = map(list, zip(
    ('a', 1,  0, 'b', False),
    ('b', 2,  1, 'c', False),
    ('c', 3,  0, 'd', False),
    ('d', 4, -1, 'e', True),
))

### Test the plain experience buffer.

In [4]:
# Capacity of 3 keeps the hand-written expectations below short. Four
# transitions are appended; the key check below expects only 'b', 'c' and 'd'
# to be sampled afterwards.
ebuf = ExpBuf(3)
for transition in zip(states, actions, rewards, next_states, is_terminals):
    ebuf.append(*transition)

# 900 independent draws of a single experience each.
eles = [ebuf.sample(1) for _ in range(900)]

# Bucket each draw by the string form of the first element of every field.
hist = make_hist(eles, lambda sample: str([field[0] for field in sample]))

# NOTE(review): the last field in these keys is the opposite of the
# `is_terminals` values that were appended -- presumably ExpBuf stores a
# "not terminal" flag; confirm against ExpBuf.
exp_keys = {
    "['b', 2, 1, 'c', True]",
    "['c', 3, 0, 'd', True]",
    "['d', 4, -1, 'e', False]",
}
assert set(hist) == exp_keys, \
    'Unexpected experiences in sample ' + str(hist.keys())

# Uniform sampling over 3 experiences puts ~300 of the 900 draws in each
# bucket; accept anything strictly within 50 of that.
for key, n_drawn in hist.items():
    kv_str = "{} : {}".format(key, n_drawn)
    assert 250 < n_drawn < 350, "Non uniform sampling: " + kv_str
    print(kv_str)

['b', 2, 1, 'c', True] : 338
['d', 4, -1, 'e', False] : 272
['c', 3, 0, 'd', True] : 290


### Weighted Experience Buffer

In [5]:
# Same 3-slot capacity as the plain-buffer test so the expectations stay
# easy to write out by hand.
wbuf = WeightedExpBuf(3)
for transition in zip(states, actions, rewards, next_states, is_terminals):
    wbuf.append(*transition)

# Weights for 'b', 'c' and 'd' respectively. 'b' is given weight 0, so the
# key check below expects it never to be drawn.
weights = [0, .5, 3.4]
# Per the original author's note, indices wrap around the ring buffer.
wbuf.update_weights([1, 2, 3], weights)
net_weight = sum(weights)

# 1000 independent draws of a single experience, bucketed by state.
eles = [wbuf.sample(1) for _ in range(1000)]
hist = make_hist(eles, lambda sample: sample[1][0])

assert set(hist) == {'c', 'd'}, str(hist.keys())
# Each state's observed frequency must be within 0.1 of weight / net_weight.
for state, weight in (('c', .5), ('d', 3.4)):
    assert abs(hist[state]/1000 - weight/net_weight) < .1, \
        'Improper sampling: ' + str(hist)
print(hist)

{'c': 123, 'd': 877}


Change the weight of 'b' from 0 to 1.3 and retest sampling: all three experiences should now be drawn in proportion to their weights.

In [6]:
# Raise b's weight from 0 to 1.3 and check that all three experiences are now
# sampled in proportion to their weights.
b_weight = 1.3
wbuf.update_weights([1], [b_weight])

# Recompute the total from the base `weights` list instead of doing
# `net_weight += 1.3`: the in-place increment made this cell non-idempotent,
# because every re-run inflated `net_weight` by another 1.3 and skewed the
# expected probabilities below. On a first run the value is unchanged.
net_weight = sum(weights) + b_weight

n_samples = 1000
eles = [wbuf.sample(1) for _ in range(n_samples)]
hist = make_hist(eles, lambda exp: exp[1][0])  # Use the state as hash

# Expected sampling weight per state after the update.
expected_weights = {'b': b_weight, 'c': .5, 'd': 3.4}

assert set(hist.keys()) == set(expected_weights), str(hist.keys())
# Each state's observed frequency must be within 0.1 of weight / net_weight.
for state, weight in expected_weights.items():
    assert abs(hist[state]/n_samples - weight/net_weight) < .1, \
        'Improper sampling: ' + str(hist)
print(hist)

{'c': 105, 'b': 234, 'd': 661}
