In [1]:
%load_ext cython

In [2]:
%matplotlib inline
import matplotlib.pyplot as plot
import numpy as np
import seaborn as sns

In [3]:
import collections

In [4]:
class ReplayDeque(object):
    def __init__(self, maxlen):
        self.data = collections.deque(maxlen=maxlen)

    def __len__(self):
        return len(self.data)

    def add(self, exp):
        self.data.append(exp)

    def sample(self, n):
        indices = np.random.choice(len(self.data), n)
        samples = [self.data[idx] for idx in indices]
        return indices, samples

In [5]:
class ReplayArray(object):
    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.write = 0
        self.count = 0
        self.data = np.empty((maxlen,), dtype=object)

    def __len__(self):
        return self.count

    def add(self, exp):
        self.data[self.write] = exp
        self.write = (self.write + 1) % self.maxlen
        self.count = min(self.maxlen, self.count + 1)

    def sample(self, n):
        indices = np.random.choice(self.count, n)
        samples = np.empty((n,), dtype=object)
        for i in range(n):
            samples[i] = self.data[indices[i]]
        return indices, samples

In [6]:
%%cython -c=-O2
import numpy as np
cimport numpy as np

cdef class cython_bufffer:
    cdef np.ndarray np_data
    cdef object[:] data
    cdef size_t maxlen
    cdef size_t write
    cdef size_t count
    
    def __init__(self, maxlen):
        self.maxlen = maxlen
        self.write = 0
        self.count = 0
        self.np_data = np.empty((maxlen,), dtype=object)
        self.data = self.np_data
    
    cpdef size_t get_count(self):
        return self.count
    
    cpdef add(self, exp : object):
        self.data[self.write] = exp
        self.write = (self.write + 1) % self.maxlen
        self.count = min(self.maxlen, self.count + 1)
    
    cpdef sample(self, n : size_t):
        cdef np.ndarray indices = np.random.choice(self.count, n).astype(np.int32)
        cdef np.ndarray samples = np.empty((n,), dtype=object)
        cdef size_t i
        for i in range(n):
            samples[i] = self.data[indices[i]]
        return indices, samples

class ReplayCython(object):
    def __init__(self, maxlen):
        self.buffer = cython_bufffer(maxlen)

    def __len__(self):
        return self.buffer.get_count()

    def add(self, exp):
        self.buffer.add(exp)

    def sample(self, n):
        return self.buffer.sample(n)



In [7]:
def run_bench(factory, count, batch_size):
    buf = factory()
    for i in range(count):
        buf.add((i, f'rec_{i}'))
        if len(buf) > batch_size:
            i, s = buf.sample(batch_size)

In [8]:
ff = lambda: ReplayCython(10000)
%timeit run_bench(ff, 100000, 32)
%timeit run_bench(ff, 1000000, 32)

1.95 s ± 29.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
20.2 s ± 663 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
ff = lambda: ReplayDeque(10000)
%timeit run_bench(ff, 100000, 32)
%timeit run_bench(ff, 1000000, 32)

2.3 s ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
23.7 s ± 774 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
ff = lambda: ReplayArray(10000)
%timeit run_bench(ff, 100000, 32)
%timeit run_bench(ff, 1000000, 32)

2.57 s ± 46.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
25.3 s ± 184 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
ff = lambda: ReplayCython(100000)
%timeit run_bench(ff, 1000000, 32)
%timeit run_bench(ff, 10000000, 32)

22.5 s ± 376 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4min 6s ± 26.4 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
ff = lambda: ReplayDeque(100000)
%timeit run_bench(ff, 1000000, 32)
%timeit run_bench(ff, 10000000, 32)

49.6 s ± 665 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
8min 23s ± 6.78 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
ff = lambda: ReplayArray(100000)
%timeit run_bench(ff, 1000000, 32)
%timeit run_bench(ff, 10000000, 32)

30 s ± 389 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5min 1s ± 462 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
