In [2]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F 
import random
import numpy as np
from EXITrl.helpers import print_weight_size, copy_params, update_params, ExperienceReplay, convert_to_tensor
import gym

### Experience Replay

In [25]:
# Feed four transitions into a replay buffer built as ExperienceReplay(3, 3) and
# read a batch back. Each tuple is passed positionally to `remember`
# (presumably state, action, reward, done -- confirm against EXITrl.helpers).
memory = ExperienceReplay(3, 3)
for step in (
    ([1, 10], 2, 3, False),
    ([11, 110], 22, 33, True),
    ([111, 1110], 222, 333, False),
    ([1111, 11110], 2222, 3333, True),
):
    memory.remember(*step)

# `recall` yields four values; unpack them and show them as the cell result.
a, b, c, d = memory.recall()
(a, b, c, d)

(tensor([[   11.,   110.],
         [  111.,  1110.],
         [ 1111., 11110.]]),
 tensor([  22.,  222., 2222.]),
 tensor([  33.,  333., 3333.]),
 tensor([1., 0., 1.]))

### Prioritized Experience Replay
https://github.com/rlcode/per<br>
[medium](https://medium.freecodecamp.org/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682#8dd3)<br>
[paper](https://arxiv.org/pdf/1511.05952.pdf)<br>
[more explanation on stackexchange](https://datascience.stackexchange.com/questions/32873/prioritized-replay-what-does-importance-sampling-really-do)

For the rank-based variant
- we can approximate the cumulative density function with a piecewise
linear function with k segments of equal probability
- we sample a segment, and then sample uniformly among the transitions within it
- choose k to be the size of the minibatch, and sample exactly one transition
from each segment


In [11]:
import numpy


# SumTree
# a binary tree data structure where the parent’s value is the sum of its children
class SumTree:
    """Array-backed binary tree in which every parent stores the sum of its children.

    ``tree`` has ``2 * capacity - 1`` nodes laid out heap-style: node ``i`` has
    children ``2 * i + 1`` and ``2 * i + 2``. The last ``capacity`` nodes are the
    leaves; leaf ``capacity - 1 + k`` holds the priority of ``data[k]``. The
    root (``tree[0]``) is therefore the sum of all stored priorities.

    Items are written in a ring: after ``capacity`` insertions the next ``add``
    overwrites the oldest slot.

    Attributes:
        capacity: number of leaves, i.e. the maximum number of stored items.
        tree: float numpy array of node values (internal sums followed by leaf
            priorities).
        data: numpy object array of length ``capacity`` holding the items.
        n_entries: number of slots filled so far, capped at ``capacity``.
        write: index into ``data`` of the slot the next ``add`` will use.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1, got {}".format(capacity))

        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        self.data = numpy.zeros(capacity, dtype=object)
        self.n_entries = 0
        # Per-instance write cursor (kept on the instance, not on the class).
        self.write = 0

    # update to the root node
    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to the root.

        The root has no ancestors, so calling this with ``idx == 0`` is a
        no-op. That matters when ``capacity == 1`` and the only leaf is the root.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    # find sample on leaf node
    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf selected by offset ``s``.

        At each internal node the search goes left when ``s`` is at most the
        left child's value; otherwise it goes right with the left child's
        value subtracted from ``s``. Returns the index of the leaf in ``tree``.
        """
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                # No children: idx is a leaf.
                return idx

            left_sum = self.tree[left]
            if s <= left_sum:
                idx = left
            else:
                # Rebind instead of subtracting in place so a caller-owned
                # value passed as ``s`` is never mutated.
                s = s - left_sum
                idx = left + 1

    def total(self):
        """Return the root value, i.e. the sum of all leaf priorities."""
        return self.tree[0]

    # store priority and sample
    def add(self, p, data):
        """Store ``data`` with priority ``p`` in the next ring-buffer slot.

        Once the buffer is full the oldest slot is overwritten. The stored
        items and the tree are printed after every insertion.
        """
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        # Advance the cursor, wrapping back to slot 0 at the end of the buffer.
        self.write = (self.write + 1) % self.capacity

        if self.n_entries < self.capacity:
            self.n_entries += 1
        # Debug trace of the buffer and the tree after each insertion.
        print(' data:', self.data)
        print('t:', self.tree)

    # update priority
    def update(self, idx, p):
        """Set the value of tree node ``idx`` to ``p`` and refresh its ancestors.

        ``idx`` is an index into ``tree`` (as returned by ``get``), not into
        ``data``.
        """
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    # get priority and sample
    def get(self, s):
        """Look up the leaf selected by offset ``s``.

        Returns:
            A tuple ``(tree_index, priority, item)`` for the selected leaf.
        """
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

In [20]:
# Fill a three-slot SumTree with four items whose priorities are 10, 20, 30, 40;
# the fourth insertion wraps around and replaces the item in the first slot.
capacity = 3
tree = SumTree(capacity)
for i in (1, 2, 3, 4):
    tree.add(p=10 * i, data={'a': i})

# Query the leaf selected by offset 20; the result is the cell's displayed value.
tree.get(s=20)

 data: [{'a': 1} 0 0]
t: [10.  0. 10.  0.  0.]
 data: [{'a': 1} {'a': 2} 0]
t: [30. 20. 10. 20.  0.]
 data: [{'a': 1} {'a': 2} {'a': 3}]
t: [60. 50. 10. 20. 30.]
 data: [{'a': 4} {'a': 2} {'a': 3}]
t: [90. 50. 40. 20. 30.]


(3, 20.0, {'a': 2})