In [6]:
import numpy as np
from collections import deque, namedtuple
import random

In [7]:
d = deque(maxlen=10000)

In [10]:
%%timeit

for i in range(100000):
    d.append((i, i+1, i+2, i+3))

15.3 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%%timeit

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

for i in range(100000):
    t = Transition(i, i+1, i+2, i+3)
    d.append(t)


42.3 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
import numpy

class SumTree:
    write = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = numpy.zeros( 2*capacity - 1 )
        self.data = numpy.zeros( capacity, dtype=object )

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2

        self.tree[parent] += change

        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s-self.tree[left])

    def total(self):
        return self.tree[0]

    def add(self, p, data):
        idx = self.write + self.capacity - 1

        self.data[self.write] = data
        self.update(idx, p)

        self.write += 1
        if self.write >= self.capacity:
            self.write = 0

    def update(self, idx, p):
        change = p - self.tree[idx]

        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1

        return (idx, self.tree[idx], self.data[dataIdx])

In [245]:
replay_memory = SumTree(10)

In [18]:
for i in range(10):
    rb.add(i,None)

In [246]:
P = [6, 48, 31, 26, 49, 43, 93, 74, 79, 13]

for p in P:
    replay_memory.add(p, None)

In [247]:
from collections import defaultdict
p_sum = defaultdict(int)

for i in range(10000):
    s = random.randint(0, replay_memory.total() - 1)
    _, p, _ = replay_memory.get(s)
    p_sum[int(p)] += 1

for p in P:
    print("{}, {}, {}".format(p, p / sum(P), p_sum[p] / 10000))

6, 0.012987012987012988, 0.014
48, 0.1038961038961039, 0.1071
31, 0.0670995670995671, 0.0684
26, 0.05627705627705628, 0.0596
49, 0.10606060606060606, 0.1
43, 0.09307359307359307, 0.0859
93, 0.2012987012987013, 0.2049
74, 0.16017316017316016, 0.1583
79, 0.170995670995671, 0.1737
13, 0.02813852813852814, 0.0281


In [177]:
replay_memory.get(10000)

(14, 43.0, None)

In [241]:
replay_memory = SumTree(3)
replay_memory.add(1,None)
replay_memory.add(2,None)
replay_memory.add(3,None)



In [242]:
from collections import defaultdict
p_sum = defaultdict(int)

for i in range(10000):
    s = random.randint(0, replay_memory.total() - 1)
    _, p, _ = replay_memory.get(s)
    p_sum[int(p)] += 1

for p in [1,2,3]:
    print("{}, {}, {}".format(p, p / sum([1,2,3]), p_sum[p] / 10000))

1, 0.16666666666666666, 0.0
2, 0.3333333333333333, 0.4966
3, 0.5, 0.5034


In [213]:
replay_memory.get()

(2, 7.0, None)

In [6]:
%%timeit
for index in range(100000):
    if len(l) < 10000:
        l.append(None)  # メモリが満タンでないときは足す
        
    index = (index) % 10000  # 保存するindexを1つずらす

    l[index] = 10

    

12.7 ms ± 61.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%%timeit

for _ in range(100000):
    random.sample(d, 256)

15.7 s ± 143 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit

for _ in range(100000):
    np.random.choice(l, 256)

44.4 s ± 173 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
d

In [14]:
np.random.(l,2)

TypeError: random_sample() takes at most 1 positional argument (2 given)