In [1]:
from collections import defaultdict
import random, os, sys
import numpy as np

# Not a python script so can't use __file__ like elsewhere.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.insert(1, parent_dir)
from utils.ExperienceBuffer import ExpBuf, WeightedExpBuf


In [2]:
def make_hist(iterable, hasher = hash):
    hist = defaultdict(int)
    for i in iterable:
        hist[hasher(i)] += 1
    return dict(hist)

In [3]:
states = ['a', 'b', 'c', 'd']
actions = [1, 2, 3, 4]
rewards = [0, 1, 0, -1]
next_states = ['b', 'c', 'd', 'e']
is_terminals = [False, False, False, True]

### Test the plain experience buffer.

In [4]:
ebuf = ExpBuf(3)  # keep small since lots of hand written stuff.
for s, a, r, ns, it in zip(states, actions, rewards, next_states, is_terminals):
    ebuf.append(s, a, r, ns, it)
    
eles = list()
for i in range(900):
    eles.append(ebuf.sample(1))
    
hist = make_hist(eles, lambda exp: str([e[0] for e in exp]))

exp_keys = {"['b', 2, 1, 'c', True]",
            "['c', 3, 0, 'd', True]",
            "['d', 4, -1, 'e', False]"}
assert set(hist.keys()) == exp_keys, \
    'Unexpected experiences in sample ' + str(hist.keys())

for k, count in hist.items():
    kv_str = k + " : " + str(count)
    assert abs(count - 300) < 50, "Non uniform sampling: " + kv_str
    print(kv_str)

['b', 2, 1, 'c', True] : 313
['c', 3, 0, 'd', True] : 261
['d', 4, -1, 'e', False] : 326


### Weighted Experience Buffer

In [9]:
a = .4
b = .5
b_f = 1
b_anneal = 5
e = .01
wbuf = WeightedExpBuf(capacity=3, alpha=a, beta_i=b, beta_f=b_f,
                      beta_anneal=b_anneal, weight_offset=e)
for s, act, r, ns, term in zip(states, actions, rewards, next_states, is_terminals):
    wbuf.append(s, act, r, ns, term)
    
raw_weights = np.array([0, .5, 3.4])  # b, c, d
tot_weight = sum((raw_weights + e) ** a)
wbuf.update_weights([1, 2, 3], raw_weights)  # indices wrap around for ring
assert abs(tot_weight - wbuf.total_weight) < 1e-12,\
    'expected=' + str(tot_weight) + ' actual=' + str(wbuf.total_weight)

Use batch_size=1, since the way we sample, sample_n_subsets, breaks down the elements in the sum tree. We want this effect for actual sampling when experiences are commoditized (num_exp >> batch_size and weigh_exp << total_weight) but here that would throw off the numbers.

In [8]:
batch_size = 1
n_samples = 1000
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"

index, state, action, reward, next_state, not_terminal, IS_weight = [], [], [], [], [], [], []
for i in range(n_samples // batch_size):
    ids, s, acts, r, ns, nt, IS = wbuf.sample(batch_size)
    index += list(ids)
    state += list(s)
    action += list(acts)
    reward += list(r)
    next_state += list(ns)
    IS_weight += list(IS)
    
hist = make_hist(state, lambda exp: exp)  # Use the state as hash

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())

expected = lambda weight: ((weight + e) ** a) / wbuf.total_weight

assert abs(hist['b']/n_samples - expected(0)) < .1,\
    'count=' + str(hist['b']) + ' expected=' +\
    str(int(expected(0) * n_samples))
assert abs(hist['c']/n_samples - expected(.5)) < .1,\
    'count=' + str(hist['c']) + ' expected=' +\
    str(int(expected(.5) * n_samples))
assert abs(hist['d']/n_samples - expected(3.4)) < .1,\
    'count=' + str(hist['d']) + ' expected=' +\
    str(int(expected(3.4) * n_samples))

print(hist)

{'d': 637, 'c': 306, 'b': 57}


Take 4 samples to get the 3 unique unexperiences transitions, and check that the 4th starts to repeat. Then inspect the values returned.

In [18]:
ids, s, acts, r, ns, nt, IS = wbuf.sample(3)
index += list(ids)
state += list(s)
action += list(acts)
reward += list(r)
next_state += list(ns)
IS_weight += list(IS)

assert wbuf.sample(1)[1] in state
    
print(index)
print(state)
print(action)
print(reward)
print(next_state)
print(IS_weight)

[0, 1, 2, 1, 0, 1, 2]
['d', 'b', 'c', 'b', 'd', 'b', 'c']
[4, 2, 3, 2, 4, 2, 3]
[-1, 1, 0, 1, -1, 1, 0]
['e', 'c', 'd', 'c', 'e', 'c', 'd']
[0.04700444715149313, 0.19054607179632474, 0.07416229446542934, 0.19054607179632474, 0.04700444715149313, 0.19054607179632474, 0.07416229446542934]


Use batch_size=2, and simply confirm that we see the smoothing effect

In [15]:
batch_size = 2
n_samples = 1000
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"

index, state, action, reward, next_state, not_terminal, IS_weight = [], [], [], [], [], [], []
for i in range(n_samples // batch_size):
    ids, s, acts, r, ns, nt, IS = wbuf.sample(batch_size)
    index += list(ids)
    state += list(s)
    action += list(acts)
    reward += list(r)
    next_state += list(ns)
    IS_weight += list(IS)
    
hist = make_hist(state, lambda exp: exp)  # Use the state as hash

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())

expected = lambda weight: ((weight + e) ** a) / wbuf.total_weight

assert hist['b']/n_samples > expected(0),\
    'count=' + str(hist['b']) + ' expected=' +\
    str(int(expected(0) * n_samples))
assert hist['c']/n_samples > expected(.5),\
    'count=' + str(hist['c']) + ' expected=' +\
    str(int(expected(.5) * n_samples))
assert hist['d']/n_samples < expected(3.4),\
    'count=' + str(hist['d']) + ' expected=' +\
    str(int(expected(3.4) * n_samples))

print(hist)

{'d': 500, 'c': 424, 'b': 76}


Change the weight of 'b' to be 1.3 and retest sampling 

In [None]:
wbuf.update_losses([1], [1.3]) # b's weight goes from 0 ==> 1.3
net_weight += 1.3

eles = list()
for i in range(1000):
    eles.append(wbuf.sample(1))
hist = make_hist(eles, lambda exp: exp[1][0])  # Use the state as hash

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())
assert abs(hist['b']/1000 - 1.3/net_weight) < .1, \
    'Improper sampling: ' + str(hist)
assert abs(hist['c']/1000 - .5/net_weight) < .1, \
    'Improper sampling: ' + str(hist)
assert abs(hist['d']/1000 - 3.4/net_weight) < .1, \
    'Improper sampling: ' + str(hist)
print(hist)