In [None]:
import subprocess
import sys
import time

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import transforms

In [None]:
# Load the CIFAR-10 training split (downloading it into CIFAR10_ROOT on first
# use). Each sample is passed through ToTensor, so the DataLoader cells below
# receive tensors rather than PIL images.
CIFAR10_ROOT = "../../assets/cifar10"
to_tensor = transforms.Compose([transforms.ToTensor()])

ds = CIFAR10(root=CIFAR10_ROOT, train=True, download=True, transform=to_tensor)

In [None]:
# Sweep batch sizes and record the mean wall-clock time (ms) of one full pass
# over `ds`. The DataLoader uses its default num_workers (0), so all loading
# runs in this process. The first WARMUP_EPOCHS passes are discarded so that
# one-off start-up costs do not skew the average.
times = {}  # batch size -> mean epoch time in milliseconds
batch_sizes = [16, 32, 64, 128, 256]
num_epochs = 6
WARMUP_EPOCHS = 1

# Without at least one measured epoch the average below would divide by zero.
if num_epochs <= WARMUP_EPOCHS:
    raise ValueError(
        f"num_epochs ({num_epochs}) must exceed WARMUP_EPOCHS ({WARMUP_EPOCHS}) "
        "so that at least one epoch is measured"
    )


def time_epochs(loader, n_epochs, n_warmup):
    """Return the wall-clock time in ms of every pass over ``loader`` after warm-up.

    Each batch is fetched and immediately discarded, so this measures only the
    cost of producing batches, not any model computation.

    Args:
        loader: any iterable of batches (here a DataLoader).
        n_epochs: total number of full passes to run.
        n_warmup: number of leading passes to run but not record.

    Returns:
        A list of ``n_epochs - n_warmup`` durations in milliseconds.
    """
    measured_ms = []
    for epoch in range(n_epochs):
        start = time.perf_counter()
        for _ in loader:
            pass
        elapsed_ms = (time.perf_counter() - start) * 1000
        if epoch >= n_warmup:
            measured_ms.append(elapsed_ms)
    return measured_ms


for batch_size in batch_sizes:
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
    epoch_ms = time_epochs(loader, num_epochs, WARMUP_EPOCHS)
    avg_time = sum(epoch_ms) / len(epoch_ms)
    print(f"Batch Size: {batch_size}, Average Epoch Time: {avg_time:.2f} milliseconds")
    times[batch_size] = avg_time

In [None]:
# Plot mean epoch time against batch size. Distinct names are used so the
# experiment cell's `batch_sizes` config is not overwritten. The swept batch
# sizes grow geometrically, so a log-2 x-axis spaces them evenly instead of
# bunching the small sizes together.
plotted_sizes = sorted(times)
plotted_ms = [times[b] for b in plotted_sizes]

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(plotted_sizes, plotted_ms, marker="o")
ax.set_xscale("log", base=2)
# Label exactly the measured batch sizes rather than powers-of-two notation.
ax.set_xticks(plotted_sizes)
ax.set_xticklabels([str(b) for b in plotted_sizes])
ax.set(
    title="Data-loading time per epoch vs. batch size",
    xlabel="Batch Size",
    ylabel="Average Time per Epoch (milliseconds)",
)
ax.grid(True)
fig.tight_layout()
plt.show()

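**Caution:** Results depend on your hardware and operating system (macOS in particular). With `num_workers > 0`, the `DataLoader` starts worker processes. On macOS and Windows these are created with the `spawn` start method, which is often unreliable inside Jupyter; for example, objects defined in the notebook cannot be re-imported by the workers. We therefore run this experiment as a standalone script in a separate Python process.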

In [None]:
# Run the num_workers sweep as a standalone script in a fresh interpreter.
# sys.executable makes the script use the kernel's own Python environment; a
# bare `python` on PATH may be a different installation. check_returncode()
# raises if the script fails, so Run-All stops here instead of the next cell
# silently displaying a stale or missing image. Output is shown once the script
# finishes (it is captured, not streamed).
experiment = subprocess.run(
    [sys.executable, "num_workers_experiment.py"],
    capture_output=True,
    text=True,
)
print(experiment.stdout, end="")
if experiment.stderr:
    print(experiment.stderr, end="", file=sys.stderr)
experiment.check_returncode()

In [None]:
# Display the results image read from disk. It is expected to have been
# written by num_workers_experiment.py in the previous cell (TODO confirm
# against the script).
img = plt.imread("num_workers_vs_epoch_time.png")

fig, ax = plt.subplots()
ax.imshow(img)
ax.set_axis_off()  # it is a picture of a plot; its pixel axes carry no meaning
plt.show()

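A common rule of thumb is to set `num_workers` to the number of available CPU cores. Treat it only as a starting point: the best value depends on how expensive each sample is to load and transform, the batch size, and what else is competing for the CPU, so measure it as above.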

In [None]:
import os

# os.cpu_count() reports every logical CPU in the system, including ones this
# process is not allowed to run on (e.g. under `taskset` or
# `docker --cpuset-cpus`). Where the OS exposes the CPU affinity mask, count
# only the CPUs actually usable here; that is the number relevant when
# choosing num_workers.
if hasattr(os, "sched_getaffinity"):
    cpu_cores = len(os.sched_getaffinity(0))
else:
    cpu_cores = os.cpu_count()  # may be None if it cannot be determined
print(f"Number of CPU cores available: {cpu_cores}")

## Further Optimizations

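When `pin_memory=True`, the `DataLoader` places each batch in page-locked (pinned) host memory before returning it. Copies from pinned memory to the GPU are faster, and they can run asynchronously when you move the batch with `tensor.to(device, non_blocking=True)`. The option only matters when batches are subsequently copied to a GPU.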

```python
DataLoader(..., pin_memory=True)
```

`prefetch_factor` controls how many batches each worker loads in advance. It can only be set when `num_workers > 0` and defaults to 2 in that case. The total number of batches prefetched across all workers is `num_workers * prefetch_factor`.

```python
DataLoader(..., num_workers=4, prefetch_factor=4)
```