# Notes

- What we need for a good replay buffer: fixed-size, FIFO behavior, O(1) insertion at the end, O(1) sampling. Limited memory footprint.
- [deque](https://docs.python.org/3/library/collections.html#collections.deque) has O(1) insertion time at the end, but O(n) access time (which made me doubt its ability to make a good replay buffer and try an np.array-based solution)
- When we draw a mini-batch for DQN, it would be best to receive (separately) an array of states only, an array of actions, an array of rewards, an array of next states and a last array of "done", that we can pass these to the Q-network. What's the best way of doing that? Store them separately?

# Setting the frame

In [1]:
import gymnasium as gym
from gym import logger
import numpy as np
logger.set_level(gym.logger.DISABLED)
import torch

In [2]:
cartpole = gym.make('CartPole-v1')

In [4]:
state,_ = cartpole.reset()
action = cartpole.action_space.sample()
next_state, reward, done, trunc, _ = cartpole.step(action)

In [5]:
replay_buffer_size = int(1e6)
nb_samples = int(2e6)
nb_batches = int(1e4)
batch_size = 50

# Testing functions

In [6]:
from tqdm import trange

def test_insertion_tqdm(buffer, nb_samples):
    state, _ = cartpole.reset()
    for _ in trange(nb_samples):
        buffer.append(state, action, reward, next_state, done)

def test_sampling_tqdm(buffer, nb_batches):
    for _ in trange(nb_batches):
        buffer.sample(batch_size)

In [7]:
import timeit
import gc

def test_insertion_timeit(buffer, nb_samples):
    print("Insertion of", nb_samples, "samples:", 
      timeit.timeit('memory.append(state,action,reward,next_state,done)', 
                    globals=globals(), 
                    setup='gc.enable()', 
                    number=nb_samples))

def test_sampling_timeit(buffer, nb_batches):
    print("Sampling of", nb_batches, "batches:",
          timeit.timeit('memory.sample(batch_size)', 
                        globals=globals(), 
                        setup='gc.enable()', 
                        number=nb_batches))

# Replay buffer classes

In [8]:
from collections import deque, namedtuple
Transition = namedtuple('Transition', 
                        ('state', 'action', 'reward', 'next_state', 'done'))

# But for the sake of the exercise, we will wrap this in a dedicated class.

import random
    
class ReplayBuffer1(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    def append(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    def capacity(self):
        return self.memory.maxlen
    
class ReplayBuffer2(object):
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    def capacity(self):
        return self.memory.maxlen
    
class ReplayBuffer3(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self, batch_size)
    
class ReplayBuffer4(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append(Transition(state, action, reward, next_state, done))
    def sample(self, batch_size):
        return random.sample(self, batch_size)

class ReplayBuffer5(object):
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the buffer
        self.data = np.empty(capacity, dtype=Transition)
        self.index = 0 # index of the next cell to be filled
        self.size = 0 # number of elements in the buffer
    def append(self, *args):
        self.data[self.index] = Transition(*args)
        self.index = (self.index + 1) % self.capacity
        if self.size < self.capacity:
            self.size+=1
    def sample(self, batch_size):
        #indices = np.random.choice(self.size, size=batch_size, replace=False)
        #return self.memory[indices]
        return np.random.choice(self.data[:self.size], size=batch_size, replace=False)
    def __len__(self):
        return self.size
    
class ReplayBuffer6(object):
    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.data.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.data, batch_size)
        return list(map(np.array, list(zip(*batch))))
    def __len__(self):
        return len(self.data)
    def capacity(self):
        return self.data.maxlen

class ReplayBuffer7(deque):
    def __init__(self, capacity):
        super().__init__(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        super().append(Transition(state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self, batch_size)
        return list(map(np.array, list(zip(*batch))))
    def capacity(self):
        return self.maxlen
    
class ReplayBuffer8(object):
    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)
    def append(self, state, action, reward, next_state, done):
        self.data.append((state, action, reward, next_state, done))
    def sample(self, batch_size):
        batch = random.sample(self.data, batch_size)
        return list(map(torch.Tensor, list(zip(*batch))))
    def __len__(self):
        return len(self.data)
    def capacity(self):
        return self.data.maxlen

In [9]:
class ReplayBuffer9:
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the buffer
        self.data = []
        self.index = 0 # index of the next cell to be filled
    def append(self, s, a, r, s_, d):
        if len(self.data) < self.capacity:
            self.data.append(None)
        self.data[self.index] = (s, a, r, s_, d)
        self.index = (self.index + 1) % self.capacity
    def sample(self, batch_size):
        return random.sample(self.data, batch_size)
    def __len__(self):
        return len(self.data)
    
class ReplayBuffer10:
    def __init__(self, capacity):
        self.capacity = capacity # capacity of the buffer
        self.data = []
        self.index = 0 # index of the next cell to be filled
    def append(self, s, a, r, s_, d):
        if len(self.data) < self.capacity:
            self.data.append(None)
        self.data[self.index] = (s, a, r, s_, d)
        self.index = (self.index + 1) % self.capacity
    def sample(self, batch_size):
        batch = random.sample(self.data, batch_size)
        return list(map(torch.Tensor, list(zip(*batch))))
    def __len__(self):
        return len(self.data)

# Pseudo-unit testing

In [10]:
# init
memory = ReplayBuffer4(replay_buffer_size)
print(memory)
# len
print(len(memory))
# append
memory.append(state, action, reward, next_state, done)
print(memory)
print(len(memory))

ReplayBuffer4([], maxlen=1000000)
0
ReplayBuffer4([Transition(state=array([ 0.00410543,  0.03742514, -0.0147697 ,  0.02450623], dtype=float32), action=1, reward=1.0, next_state=array([ 0.00485393,  0.23275575, -0.01427957, -0.27279985], dtype=float32), done=False)], maxlen=1000000)
1


In [11]:
# init
memory = ReplayBuffer5(replay_buffer_size)
print(memory.data)
# len
print(len(memory))
# append
memory.append(state, action, reward, next_state, done)
print(memory.data)
print(len(memory))

[None None None ... None None None]
0
[Transition(state=array([ 0.00410543,  0.03742514, -0.0147697 ,  0.02450623], dtype=float32), action=1, reward=1.0, next_state=array([ 0.00485393,  0.23275575, -0.01427957, -0.27279985], dtype=float32), done=False)
 None None ... None None None]
1


# Time testing

In [12]:
memory = ReplayBuffer1(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1147464.96it/s]
100%|██████████| 10000/10000 [00:16<00:00, 611.18it/s]


Insertion of 2000000 samples: 0.8236342830005015
Sampling of 10000 batches: 16.067819274000612


In [13]:
memory = ReplayBuffer2(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2682985.16it/s]
100%|██████████| 10000/10000 [00:16<00:00, 619.53it/s]


Insertion of 2000000 samples: 0.41561654299948714
Sampling of 10000 batches: 15.2426448719998


In [14]:
memory = ReplayBuffer3(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2198896.59it/s]
100%|██████████| 10000/10000 [00:15<00:00, 627.69it/s]


Insertion of 2000000 samples: 0.5717196929999773
Sampling of 10000 batches: 15.4775857730001


In [15]:
memory = ReplayBuffer4(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:02<00:00, 997902.27it/s] 
100%|██████████| 10000/10000 [00:16<00:00, 620.28it/s]


Insertion of 2000000 samples: 1.0013630290004585
Sampling of 10000 batches: 15.992759149000449


In [16]:
memory = ReplayBuffer5(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1013634.74it/s]
100%|██████████| 10000/10000 [02:38<00:00, 62.94it/s]


Insertion of 2000000 samples: 1.1802581740003006
Sampling of 10000 batches: 159.70133984299991


In [17]:
memory = ReplayBuffer6(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2626853.78it/s]
100%|██████████| 10000/10000 [00:17<00:00, 584.65it/s]


Insertion of 2000000 samples: 0.4241142740002033
Sampling of 10000 batches: 16.318693213999723


In [18]:
memory = ReplayBuffer7(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:02<00:00, 985910.89it/s] 
100%|██████████| 10000/10000 [00:16<00:00, 596.94it/s]


Insertion of 2000000 samples: 1.0350060809996648
Sampling of 10000 batches: 16.589127244999872


In [19]:
memory = ReplayBuffer8(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:00<00:00, 2744802.58it/s]
  return list(map(torch.Tensor, list(zip(*batch))))
100%|██████████| 10000/10000 [00:17<00:00, 564.80it/s]


Insertion of 2000000 samples: 0.3841596349993779
Sampling of 10000 batches: 16.850195649999478


In [20]:
memory = ReplayBuffer9(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1836881.10it/s]
100%|██████████| 10000/10000 [00:00<00:00, 42478.31it/s]


Insertion of 2000000 samples: 0.7609448889998021
Sampling of 10000 batches: 0.23952373599968269


In [21]:
memory = ReplayBuffer10(replay_buffer_size)
test_insertion_tqdm(memory, nb_samples)
test_sampling_tqdm(memory, nb_batches)
test_insertion_timeit(memory, nb_samples)
test_sampling_timeit(memory, nb_batches)

100%|██████████| 2000000/2000000 [00:01<00:00, 1792246.89it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6025.21it/s]


Insertion of 2000000 samples: 0.7767183719997774
Sampling of 10000 batches: 1.6790301529999851
