In [1]:
from collections import defaultdict
import random, os, sys
import numpy as np

# A notebook has no __file__ (unlike the scripts elsewhere), so the parent of
# the current working directory is used to make the project importable.
notebook_dir = os.getcwd()
parent_dir = os.path.dirname(notebook_dir)
if parent_dir not in sys.path:
    # Insert at position 1 so that sys.path[0] keeps its priority.
    sys.path.insert(1, parent_dir)
from utils.ExperienceBuffer import ExpBuf, WeightedExpBuf


In [2]:
def make_hist(iterable, hasher = hash):
    """Build a histogram of the keys produced by ``hasher``.

    Every element of ``iterable`` is mapped to a key with ``hasher`` (the
    builtin ``hash`` by default).  Returns a plain dict that maps each key
    to the number of elements that produced it.
    """
    counts = {}
    for item in iterable:
        key = hasher(item)
        counts[key] = counts.get(key, 0) + 1
    return counts

In [3]:
# Four chained transitions a -> b -> c -> d -> e; only the last is terminal.
# Each row is (state, action, reward, next_state, is_terminal).
transitions = [
    ('a', 1, 0, 'b', False),
    ('b', 2, 1, 'c', False),
    ('c', 3, 0, 'd', False),
    ('d', 4, -1, 'e', True),
]
# Split the rows into the per-field lists used by the cells below.
states, actions, rewards, next_states, is_terminals = map(list, zip(*transitions))

### Test the plain experience buffer.

In [4]:
# Capacity 3 with four transitions appended: the hand-written expectations
# below only list 'b', 'c' and 'd'.  Kept small since they are hand written.
ebuf = ExpBuf(3)
for transition in zip(states, actions, rewards, next_states, is_terminals):
    ebuf.append(*transition)

# Draw 900 batches holding a single experience each.
eles = [ebuf.sample(1) for _ in range(900)]

# Key every batch by the string form of its only experience, i.e. the first
# entry of each field in the sampled batch.
hist = make_hist(eles, lambda exp: str([field[0] for field in exp]))

exp_keys = {"['b', 2, 1, 'c', True]",
            "['c', 3, 0, 'd', True]",
            "['d', 4, -1, 'e', False]"}
observed_keys = set(hist.keys())
assert observed_keys == exp_keys, \
    'Unexpected experiences in sample ' + str(hist.keys())

# 900 draws over 3 experiences: about 300 each if sampling is uniform.
for key, count in hist.items():
    kv_str = key + " : " + str(count)
    assert abs(count - 300) < 50, "Non uniform sampling: " + kv_str
    print(kv_str)

['b', 2, 1, 'c', True] : 307
['d', 4, -1, 'e', False] : 290
['c', 3, 0, 'd', True] : 303


### Weighted Experience Buffer

In [5]:
# Hyper-parameters handed to the weighted buffer (keyword they are passed as).
a = .4        # alpha
b = .5        # beta_i
b_f = 1       # beta_f
b_anneal = 5  # beta_anneal
e = .01       # weight_offset
wbuf = WeightedExpBuf(capacity=3, alpha=a, beta_i=b, beta_f=b_f,
                      beta_anneal=b_anneal, weight_offset=e)


def effective_weight(raw_weight):
    """Weight the test expects for ``raw_weight``: ``(raw_weight + e) ** a``."""
    return (raw_weight + e) ** a


def P_select(raw_weight):
    """Expected selection probability: effective weight / ``wbuf.total_weight``."""
    return effective_weight(raw_weight) / wbuf.total_weight


for transition in zip(states, actions, rewards, next_states, is_terminals):
    wbuf.append(*transition)

raw_weights = np.array([0, .5, 3.4])  # b, c, d
tot_weight = sum(effective_weight(raw_weights))
wbuf.update_weights([1, 2, 3], raw_weights)  # indices wrap around for ring
# The buffer's running total must match the sum computed by hand above.
weight_error = abs(tot_weight - wbuf.total_weight)
assert weight_error < 1e-12,\
    'expected=' + str(tot_weight) + ' actual=' + str(wbuf.total_weight)

The first batch of 3 samples should contain exactly one sample of each experience.

In [6]:
# Draw one batch of three and keep every returned field as a plain list.
ids, s, acts, r, ns, nt, IS = wbuf.sample(3)
index = list(ids)
state = list(s)
action = list(acts)
reward = list(r)
next_state = list(ns)
not_terminal = list(nt)
IS_weight = list(IS)

# Every stored experience must show up in the batch.
assert set(ids) == {0, 1, 2}, ids
assert set(state) == {'b', 'c', 'd'}, state
assert set(action) == {2, 3, 4}, action
assert set(reward) == {1, 0, -1}, reward
assert set(next_state) == {'c', 'd', 'e'}, next_state
assert set(not_terminal) == {True, False}, not_terminal
assert set(IS_weight) == {1},\
    'weights should be 1 for first sampling: actual=' + str(IS_weight)

Use batch_size=1 here, because the way we sample (sample_n_subsets) breaks down the elements in the sum tree. We want this effect for actual sampling, when experiences are commoditized (num_exp >> batch_size and weight_exp << total_weight), but here it would throw off the numbers.

In [7]:
# Sampling configuration shared by the cells below.
batch_size = 1
n_samples = 1000
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"


def check_count(hist, key, raw_weight):
    """Assert that ``key`` was sampled at roughly its expected rate.

    The observed rate ``hist[key] / n_samples`` has to lie within 0.1 of
    ``P_select(raw_weight)``; otherwise an AssertionError reporting the
    observed and the expected count is raised.
    """
    exp_rate = P_select(raw_weight)
    deviation = abs(hist[key] / n_samples - exp_rate)
    assert deviation < .1,\
        'count=' + str(hist[key]) + ' expected=' +\
        str(int(exp_rate * n_samples))

In [8]:
# Fresh accumulators; not_terminal is reset but not collected in this cell.
index, state, action, reward, next_state, not_terminal, IS_weight = ([] for _ in range(7))
for _batch in range(n_samples // batch_size):
    ids, s, acts, r, ns, nt, IS = wbuf.sample(batch_size)
    for collected, batch in ((index, ids), (state, s), (action, acts),
                             (reward, r), (next_state, ns), (IS_weight, IS)):
        collected.extend(batch)

hist = make_hist(state, lambda exp: exp)  # Use the state as hash
print(hist)

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())

# Raw weights assigned to b, c and d when the buffer was set up.
for key, raw_weight in (('b', 0), ('c', .5), ('d', 3.4)):
    check_count(hist, key, raw_weight)

{'c': 281, 'd': 666, 'b': 53}


In [9]:
# Recompute the importance-sampling weights expected for the samples collected
# in IS_weight: (capacity * P_select) ** -beta, scaled so the largest is 1.
selection_rates = P_select(raw_weights)
raw_IS_weights = ((wbuf.capacity * selection_rates) ** -wbuf.beta)
exp_IS_weights = set(raw_IS_weights / max(raw_IS_weights))

# The buffer and this cell may reach the same values along different floating
# point paths, so match them with a tolerance instead of exact set equality.
# Every observed weight must be an expected one and vice versa.
observed_IS_weights = np.unique(IS_weight)
expected_IS_sorted = np.array(sorted(exp_IS_weights))
unexpected = [w for w in observed_IS_weights
              if not np.isclose(expected_IS_sorted, w).any()]
missing = [w for w in expected_IS_sorted
           if not np.isclose(observed_IS_weights, w).any()]
assert not unexpected and not missing,\
    'expected=' + str(exp_IS_weights) +\
    ' actual=' + str(set(IS_weight))

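Use batch_size=2 and simply confirm that we see the smoothing effect: the low-weight experiences 'b' and 'c' are sampled more often than their selection probability, and 'd' less often.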

In [10]:
batch_size = 2
n_samples = 1000
assert n_samples % batch_size == 0,\
    "Not a general requirement for wbuf"

# Fresh accumulators; not_terminal is reset but not collected in this cell.
index, state, action, reward, next_state, not_terminal, IS_weight = ([] for _ in range(7))
for _batch in range(n_samples // batch_size):
    ids, s, acts, r, ns, nt, IS = wbuf.sample(batch_size)
    for collected, batch in ((index, ids), (state, s), (action, acts),
                             (reward, r), (next_state, ns), (IS_weight, IS)):
        collected.extend(batch)

hist = make_hist(state, lambda exp: exp)  # Use the state as hash
print(hist)

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())


def expected(weight):
    """Selection probability of a raw weight: (weight + e) ** a / total weight."""
    return ((weight + e) ** a) / wbuf.total_weight


# 'b' and 'c' have to come up more often than their selection probability,
# 'd' less often.
observed_rate = {key: count / n_samples for key, count in hist.items()}

assert observed_rate['b'] > expected(0),\
    'count=' + str(hist['b']) + ' expected=' +\
    str(int(expected(0) * n_samples))
assert observed_rate['c'] > expected(.5),\
    'count=' + str(hist['c']) + ' expected=' +\
    str(int(expected(.5) * n_samples))
assert observed_rate['d'] < expected(3.4),\
    'count=' + str(hist['d']) + ' expected=' +\
    str(int(expected(3.4) * n_samples))

{'c': 420, 'd': 500, 'b': 80}


Change the weight of 'b' to 1.3 and retest sampling. Now that 'd' doesn't dominate as much, check that the sample_n_subset sampling is spread out nicely.

In [11]:
wbuf.update_weights(np.array([1]), np.array([1.3])) # b's weight goes from 0 ==> 1.3

# Draw exactly n_samples single-experience batches, because check_count
# divides the counts by n_samples.
eles = [wbuf.sample(1) for _ in range(n_samples)]
hist = make_hist(eles, lambda exp: exp[1][0])  # Use the state as hash
print(hist)

assert set(hist.keys()) == {'b', 'c', 'd'}, str(hist.keys())

# check_count compares each observed rate with P_select of the raw weight and,
# on failure, reports the count of the key that was actually checked.
check_count(hist, 'b', 1.3)
check_count(hist, 'c', .5)
check_count(hist, 'd', 3.4)

{'b': 314, 'd': 457, 'c': 229}


Update weights a whole bunch and check that the beta stays at the final value.

In [12]:
# Update the weights more often than beta_anneal (5); presumably each update
# advances the beta annealing schedule -- confirm against WeightedExpBuf.
for update_round in range(10):
    wbuf.update_weights([1], np.array([1.3]))
# Compare the buffer's *current* beta with the final value.  Checking
# wbuf.beta_f would only compare the configured constant with itself.
assert abs(b_f - wbuf.beta) < 1e-12,\
    'expected=' + str(b_f) + ' actual=' + str(wbuf.beta)