# Notes

- What we need for a good replay buffer: fixed-size, FIFO behavior, O(1) insertion at the end, O(1) sampling. Limited memory footprint.
- [deque](https://docs.python.org/3/library/collections.html#collections.deque) has O(1) insertion time at the end, but O(n) access time (which made me doubt its ability to make a good replay buffer and try an np.array-based solution)
- When we draw a mini-batch for DQN, it would be best to receive (separately) an array of states only, an array of actions, an array of rewards, an array of next states and a last array of "done", that we can pass these to the Q-network. What's the best way of doing that? Store them separately?

# Setting the frame

In [3]:
import gym
from gym import logger
import numpy as np
logger.set_level(gym.logger.DISABLED)
import torch

In [4]:
cartpole = gym.make('CartPole-v1')

In [5]:
state = cartpole.reset()
action = cartpole.action_space.sample()
next_state, reward, done, _ = cartpole.step(action)

In [6]:
replay_buffer_size = int(1e6)
nb_samples = int(2e6)
nb_batches = int(1e4)
batch_size = 50

# Testing functions

In [7]:
from tqdm import trange

def test_insertion_tqdm(buffer, nb_samples):
    state = cartpole.reset()
    for _ in trange(nb_samples):
        buffer.append(state, action, reward, next_state, done)

def test_sampling_tqdm(buffer, nb_batches):
    for _ in trange(nb_batches):
        buffer.sample(batch_size)

In [8]:
import timeit
import gc

def test_insertion_timeit(buffer, nb_samples):
    print("Insertion of", nb_samples, "samples:", 
      timeit.timeit('memory.append(state,action,reward,next_state,done)', 
                    globals=globals(), 
                    setup='gc.enable()', 
                    number=nb_samples))

def test_sampling_timeit(buffer, nb_batches):
    print("Sampling of", nb_batches, "batches:",
          timeit.timeit('memory.sample(batch_size)', 
                        globals=globals(), 
                        setup='gc.enable()', 
                        number=nb_batches))

# Replay buffer classes

In [9]:
from collections import deque, namedtuple
Transition = namedtuple('Transition', 
                        ('state', 'action', 'reward', 'next_state', 'done'))

# But for the sake of the exercise, we will wrap this in a dedicated class.

import random
    
class ReplayBuffer1(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    def append(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    def capacity(self):
        return self.memory.maxlen
    
class ReplayBuffer2(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    def capacity(self):
        return self.memory.maxlen
    
class ReplayBuffer3(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self, batch_size)
    
class ReplayBuffer4(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append(Transition(state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self, batch_size)

class ReplayBuffer5(object):
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the buffer
        self.data = np.empty(capacity, dtype=Transition)
        self.index = 0 # index of the next cell to be filled
        self.size = 0 # number of elements in the buffer
        
    def append(self, *args):
        self.data[self.index] = Transition(*args)
        self.index = (self.index + 1) % self.capacity
        if self.size < self.capacity:
            self.size+=1
        
    def sample(self, batch_size):
        #indices = np.random.choice(self.size, size=batch_size, replace=False)
        #return self.memory[indices]
        return np.random.choice(self.data[:self.size], size=batch_size, replace=False)
    
    def __len__(self):
        return self.size
    
class ReplayBuffer6(object):
    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.data.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.data, batch_size)
        return list(map(np.array, list(zip(*batch))))
    def __len__(self):
        return len(self.data)
    def capacity(self):
        return self.data.maxlen

class ReplayBuffer7(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append(Transition(state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self, batch_size)
        return list(map(np.array, list(zip(*batch))))
    def capacity(self):
        return self.maxlen
    
class ReplayBuffer8(object):
    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.data.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.data, batch_size)
        return list(map(torch.Tensor, list(zip(*batch))))
    def __len__(self):
        return len(self.data)
    def capacity(self):
        return self.data.maxlen

In [20]:
class ReplayBuffer9:
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the buffer
        self.data = []
        self.index = 0 # index of the next cell to be filled
        
    def append(self, s, a, r, s_, d):
        if len(self.data) < self.capacity:
            self.data.append(None)
        self.data[self.index] = (s, a, r, s_, d)
        self.index = (self.index + 1) % self.capacity
        
    def sample(self, batch_size):
        return random.sample(self.data, batch_size)
    
    def __len__(self):
        return len(self.data)

# Pseudo-unit testing

In [10]:
# init
memory = ReplayBuffer4(replay_buffer_size)
print(memory)
# len
print(len(memory))
# append
memory.append(state, action, reward, next_state, done)
print(memory)
print(len(memory))

ReplayBuffer4([], maxlen=1000000)
0
ReplayBuffer4([Transition(state=array([-0.01975132, -0.04026621,  0.01258357,  0.00629829]), action=1, reward=1.0, next_state=array([-0.02055665,  0.15467303,  0.01270954, -0.28238796]), done=False)], maxlen=1000000)
1


In [11]:
# init
memory = ReplayBuffer5(replay_buffer_size)
print(memory.data)
# len
print(len(memory))
# append
memory.append(state, action, reward, next_state, done)
print(memory.data)
print(len(memory))

[None None None ... None None None]
0
[Transition(state=array([-0.01975132, -0.04026621,  0.01258357,  0.00629829]), action=1, reward=1.0, next_state=array([-0.02055665,  0.15467303,  0.01270954, -0.28238796]), done=False)
 None None ... None None None]
1


# Time testing

In [12]:
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1195226.05it/s]
100%|██████████| 10000/10000 [00:15<00:00, 644.42it/s]


Insertion of 2000000 samples: 0.8660166679983377
Sampling of 10000 batches: 15.171046824994846


In [13]:
memory = ReplayBuffer2(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2490685.99it/s]
100%|██████████| 10000/10000 [00:15<00:00, 640.37it/s]


Insertion of 2000000 samples: 0.40381351800169796
Sampling of 10000 batches: 15.266013815999031


In [14]:
memory = ReplayBuffer3(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2021437.05it/s]
100%|██████████| 10000/10000 [00:15<00:00, 633.41it/s]


Insertion of 2000000 samples: 0.5816797989973566
Sampling of 10000 batches: 14.921481103003316


In [15]:
memory = ReplayBuffer4(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1045744.91it/s]
100%|██████████| 10000/10000 [00:15<00:00, 633.04it/s]


Insertion of 2000000 samples: 1.1129553169957944
Sampling of 10000 batches: 15.083680347001064


In [16]:
memory = ReplayBuffer5(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:02<00:00, 946996.43it/s] 
100%|██████████| 10000/10000 [02:37<00:00, 63.40it/s]


Insertion of 2000000 samples: 1.2960227980001946
Sampling of 10000 batches: 149.89062911699875


In [17]:
memory = ReplayBuffer6(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2453348.81it/s]
100%|██████████| 10000/10000 [00:16<00:00, 624.97it/s]


Insertion of 2000000 samples: 0.3957371369979228
Sampling of 10000 batches: 15.264003774005687


In [18]:
memory = ReplayBuffer7(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1044286.72it/s]
100%|██████████| 10000/10000 [00:16<00:00, 624.28it/s]


Insertion of 2000000 samples: 1.096159065993561
Sampling of 10000 batches: 15.458037463999062


In [19]:
memory = ReplayBuffer8(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2547823.82it/s]
100%|██████████| 10000/10000 [00:16<00:00, 597.04it/s]


Insertion of 2000000 samples: 0.4225345800005016
Sampling of 10000 batches: 16.56810394200147


In [21]:
memory = ReplayBuffer9(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1597173.58it/s]
100%|██████████| 10000/10000 [00:00<00:00, 27997.68it/s]


Insertion of 2000000 samples: 0.8166492579985061
Sampling of 10000 batches: 0.32161630100017646
