# Multi-process data loading
To increase sampling speed, especially for datasets with CPU-bound retrieval operations (e.g. file reading), data loading can be parallelized across multiple worker processes.

In [1]:
from buffers import ReplayBuffer
from buffers.dataset import IndividualFileDataset

# Benchmark settings: the worker counts to compare, the batch size passed to
# every buffer, and how many timed sample() calls are made per buffer.
NUM_WORKERS = (0, 2, 4, 8, 16)
BATCH_SIZE = 512
NUM_CALLS = 16

# One replay buffer per worker count. Every buffer gets the same dataset
# argument, batch size and sequence length, so num_workers is the only
# setting that varies between them.
replay_buffers = [
    ReplayBuffer(
        IndividualFileDataset(10000),
        batch_size=BATCH_SIZE,
        sequence_length=8,
        num_workers=worker_count,
    )
    for worker_count in NUM_WORKERS
]

In [2]:
from utils import add_sample_data

# Fill every buffer with the same number of steps. The count scales with the
# largest worker count so that there are at least as many episodes as workers.
num_steps = 200 * max(NUM_WORKERS)
for buffer in replay_buffers:
    add_sample_data(buffer, num_steps=num_steps)

# The dataloader is initialized on the first sampling operation, so draw one
# throwaway batch from each buffer before any timing is done.
for buffer in replay_buffers:
    _ = buffer.sample()

  if not isinstance(terminated, (bool, np.bool8)):


In [3]:
from utils import StopWatch

stopwatch = StopWatch()

# Time NUM_CALLS sample() calls for each buffer. The decorator registers the
# measurements under the name "<worker count>_workers".
# NOTE(review): with only NUM_CALLS calls per buffer, the timings may reflect
# batches that worker processes prepared ahead of time rather than
# steady-state throughput — confirm before drawing conclusions.
for worker_count, buffer in zip(NUM_WORKERS, replay_buffers):

    @stopwatch.get_duration(name=f"{worker_count}_workers")
    def sample():
        # Draw one batch and discard it; only the elapsed time matters here.
        buffer.sample()

    for _ in range(NUM_CALLS):
        sample()

In [4]:
# Print the timings collected by the stopwatch (shown below as one row per
# registered function name).
print(stopwatch)

+------------+------------------------+-----------------------+------------------------+---------+
|   Function |    Avg Duration (s)    |    Min Duration (s)   |    Max Duration (s)    | Calls # |
+------------+------------------------+-----------------------+------------------------+---------+
|  0_workers |  0.31592419743537903   |   0.289478063583374   |   0.3535332679748535   |    16   |
|  2_workers |  0.16262350976467133   |  5.7220458984375e-05  |   0.5968260765075684   |    16   |
|  4_workers |   0.0892651379108429   |  8.0108642578125e-05  |   0.7194643020629883   |    16   |
|  8_workers | 0.0001093447208404541  |  5.7220458984375e-05  | 0.0001842975616455078  |    16   |
| 16_workers | 0.00010056793689727783 | 6.937980651855469e-05 | 0.00016188621520996094 |    16   |
+------------+------------------------+-----------------------+------------------------+---------+
