# DataLoader

The DataLoader is an iterator. At each iteration, it yields a batch of training examples from a dataset. It stops when it exhausts the training examples in the dataset.

## Inputs

It takes as input:
* dataset: A dictionary-like object that we can use to look up a training example by integer key, from 0 to `len(dataset) - 1`. Each lookup typically requires reading data from disk (I/O) and processing that data (compute).
* batch_size: The number of training examples in each batch (the last batch may be smaller).
* shuffle: If True, visit the training examples in a random order, sampling without replacement so that each example appears in exactly one batch. Otherwise, visit them in the order they appear in the dataset.
* collate_fn: A function that takes a list of training examples as input and outputs a batch.
* num_workers: If 0, prepare every batch in the main process. Otherwise, use `num_workers` worker processes to prepare batches concurrently.
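The sketch below shows how `batch_size` and `shuffle` decide which indices land in each batch. `batch_indices` is a hypothetical helper that the loaders below never call. It assumes, as those loaders do, that shuffling means drawing one random permutation per pass. It uses NumPy's `default_rng` only so that the seed stays local to the call.

```python
import numpy as np

def batch_indices(num_examples, batch_size, shuffle, seed=None):
    """Hypothetical helper: split range(num_examples) into batches of at most batch_size."""
    if shuffle:
        # Sampling without replacement is the same as taking a random permutation.
        order = np.random.default_rng(seed).permutation(num_examples)
    else:
        order = np.arange(num_examples)
    # Consecutive slices of the order; only the last slice can be shorter.
    return [order[i:i + batch_size].tolist() for i in range(0, num_examples, batch_size)]

print(batch_indices(5, 2, shuffle=False))  # [[0, 1], [2, 3], [4]]
```

With `shuffle=True`, every index still appears exactly once. Only the grouping changes.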

We define a toy Dataset class. Its `__getitem__` sleeps for 0.1 seconds to simulate the cost of loading a training example:

In [1]:
import time

class Dataset:
    
    def __init__(self, features, labels):
        self.examples = list(zip(features, labels))
       
    def __getitem__(self, i):
        time.sleep(0.1) # simulate work
        return self.examples[i]

    def __len__(self):
        return len(self.examples)

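We also define a `collate_fn` that stacks the features and labels of a list of training examples into two arrays:

In [2]:
import numpy as np

def collate_fn(examples):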
    """[(x1,y1),...,(xn,yn)] -> X, y"""
    z = list(zip(*examples))
    data = np.stack(z[0], axis=0)
    target = np.stack(z[1], axis=0)
    return (data, target)
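To see what `collate_fn` does to the shapes, here is the same transpose-then-stack sequence written out step by step. It uses three made-up examples, and the feature length of 4 is arbitrary:

```python
import numpy as np

# Three made-up (feature, label) pairs; each feature is a vector of length 4.
examples = [(np.full(4, float(i)), i) for i in range(3)]

# zip(*...) transposes [(x1, y1), (x2, y2), ...] into (x1, x2, ...) and (y1, y2, ...).
features, labels = zip(*examples)

X = np.stack(features, axis=0)  # stacks along a new leading batch axis: shape (3, 4)
y = np.stack(labels, axis=0)    # scalar labels become a 1-D array: shape (3,)
print(X.shape, y.shape)  # (3, 4) (3,)
```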

## Single process


When `num_workers == 0`, the implementation is straightforward: the main process looks up and collates every batch itself.

In [3]:
import numpy as np

def _process_batch_idxs(batch_idxs, dataset, collate_fn):
    return collate_fn([dataset[idx] for idx in batch_idxs])

class SingleProcessDataLoader:

    def __init__(self, dataset, batch_size, shuffle, collate_fn, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.num_workers = num_workers

        num_examples = len(self.dataset)
        self.num_remaining_examples = num_examples
        
        if shuffle:
            self.sampler = iter(np.random.permutation(num_examples))
        else:
            self.sampler = iter(range(num_examples))            

    def __iter__(self):
        # PyTorch's DataLoader will restart
        # every time you use it in a new loop.[1,2]
        # E.g., the first and second loops will yield
        # the same batches:
        #
        # dataloader = DataLoader(...)
        # for batch in dataloader:
        #    ...
        # for batch in dataloader:
        #    ...
        #
        # This is not the case for e.g. the code below:
        #
        # it = iter(range(10))
        # for entry in it:
        #   ...
        # for entry in it:
        #   ...
        #
        # If we do not break out of the first
        # loop early, then the second loop will
        # immediately break, because the iterator
        # will be exhausted.
        #
        # We implement the behavior of iter(range(10)).
        # In order to implement PyTorch's behavior,
        # we would move most of the logic of the class
        # to DataLoaderIter and replace "self" here
        # with "DataLoaderIter(self)".
        #
        # [1]: https://stackoverflow.com/questions/60311307/how-does-one-reset-the-dataloader-in-pytorch
        # [2]: "When __iter__ is a generator it's returned as a new closure every time. 
        #       It'd be the best not to modify the internal state (i.e. assign to self),
        #       but to have these as local variables. Now, if you start iterating over the dataloader,
        #       stop in the middle, take a new iterator, and continue iterating over the first one,
        #       you'll start from the beginning."
        #      (https://github.com/pytorch/pytorch/pull/44/commits/\
        #       55afe5137eb57d27de76208827652dfc192896df#discussion_r79853797)
        return self
    
    def _next_batch_idxs(self):
        # Note that the `next` call can raise a StopIteration
        # and stop the `self` iterator, but it shouldn't happen
        # in practice, because sampler is the size of the dataset
        # and we keep track of self.num_remaining_examples.
        n = min(self.num_remaining_examples, self.batch_size)
        batch_idxs = [next(self.sampler) for _ in range(n)]
        self.num_remaining_examples -= len(batch_idxs)
        return batch_idxs
    
    def __next__(self):
        if self.num_workers == 0:
            if self.num_remaining_examples == 0:
                raise StopIteration
            batch_idxs = self._next_batch_idxs()
            batch = _process_batch_idxs(batch_idxs, self.dataset, self.collate_fn)
            return batch
            
        raise NotImplementedError

Here's the basic usage. With 20 training examples and `batch_size=2`, the loader yields 10 batches:

In [4]:
np.random.seed(0)
n = 20
features = np.random.randn(n, 100)
labels = np.random.permutation(n)
dataset = Dataset(features, labels)
dataloader = SingleProcessDataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
for i, (X, y) in enumerate(dataloader):
    print(i)

0
1
2
3
4
5
6
7
8
9
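The comment in `__iter__` describes how PyTorch's loader restarts for every new loop. Here is a minimal sketch of that pattern. `ReiterableLoader` and `_LoaderIter` are hypothetical names, and the sketch batches only indices (no dataset or `collate_fn`) to stay short. All per-pass state moves into a separate object, and `__iter__` creates a fresh one each time:

```python
class _LoaderIter:
    """Holds all per-pass state, so each pass starts from scratch."""

    def __init__(self, loader):
        self.batch_size = loader.batch_size
        self.sampler = iter(range(loader.num_examples))
        self.num_remaining = loader.num_examples

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_remaining == 0:
            raise StopIteration
        n = min(self.num_remaining, self.batch_size)
        self.num_remaining -= n
        return [next(self.sampler) for _ in range(n)]


class ReiterableLoader:
    """Stores configuration only; iteration state lives in _LoaderIter."""

    def __init__(self, num_examples, batch_size):
        self.num_examples = num_examples
        self.batch_size = batch_size

    def __iter__(self):
        return _LoaderIter(self)  # a fresh iterator for every `for` loop


loader = ReiterableLoader(num_examples=5, batch_size=2)
first = [batch for batch in loader]
second = [batch for batch in loader]
print(first)            # [[0, 1], [2, 3], [4]]
print(first == second)  # True
```

Writing `__iter__` as a generator function would have the same effect, as the quoted PyTorch review comment notes. Each call returns a new generator with its own local state.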


## Multiple processes (out-of-order)

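Implementing the multiprocess version is trickier. In this first version, batches are returned in whatever order the workers finish them, which need not match the order of the dataset.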

The idea is to create multiple child processes, each running a worker loop. A worker gets a batch of indices from an input queue, looks up the training example for each index in the dataset, and collates those examples into a batch. It then puts the batch on an output queue. It exits only when it gets an exit signal (here, `None`) from the input queue instead of a batch of indices.

At each iteration, the parent process puts a batch of indices on the input queue (provided there are training examples remaining) and then gets a batch from the output queue to return at that iteration.

The parent increments a counter by 1 each time it puts a batch of indices on the input queue. It decrements the counter by 1 each time it gets a batch from the output queue. While training examples remain, every iteration does one put and one get, so the counter ends each iteration where it started. Once the training examples are exhausted, the parent stops putting, and the counter ends each iteration 1 lower than it started.

Before creating the child processes, the parent puts up to `num_workers` batches of indices on the input queue (at least one for a nonempty dataset). This makes the counter positive at the start of the first iteration. Eventually the training examples run out, and the counter decreases by 1 per iteration until it reaches 0. At that point every batch has been returned, so the parent sends one exit signal to each worker and raises StopIteration.
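The counter argument can be checked without any processes. The hypothetical `simulate_counter` below replays only the parent's bookkeeping. It assumes that every batch of indices put on the input queue eventually comes back on the output queue, and it records the counter after each iteration:

```python
def simulate_counter(num_batches, num_workers):
    """Replay the parent's put/get bookkeeping; return (batches yielded, counter trace)."""
    remaining = num_batches  # batches of indices not yet put on the input queue
    counter = 0              # batches put but not yet taken from the output queue

    # Prefill: one put per worker, as long as batches remain.
    for _ in range(num_workers):
        if remaining > 0:
            remaining -= 1
            counter += 1

    yielded = 0
    trace = []
    while counter > 0:       # counter == 0 means every batch has been returned
        if remaining > 0:    # put the next batch of indices, if any remain
            remaining -= 1
            counter += 1
        counter -= 1         # get one finished batch from the output queue
        yielded += 1
        trace.append(counter)
    return yielded, trace


print(simulate_counter(num_batches=5, num_workers=2))  # (5, [2, 2, 2, 1, 0])
```

With 5 batches and 2 workers, the counter stays at 2 while batches remain to be sent. It then drops to 1 and 0, and exactly 5 batches are yielded.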

In [5]:
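# The standard `multiprocessing` module can fail with an AttributeError when the
# worker function is defined interactively (e.g., in a notebook) and processes are
# started with "spawn". The `multiprocess` fork serializes with `dill`, which avoids this.
# https://stackoverflow.com/questions/41385708/multiprocessing-example-giving-attributeerror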
!pip install multiprocess



In [6]:
import numpy as np
import multiprocess as multiprocessing

def _process_batch_idxs(batch_idxs, dataset, collate_fn):
    return collate_fn([dataset[idx] for idx in batch_idxs])

def _worker(dataset, input_queue, output_queue, collate_fn):
    while True:
        batch_idxs = input_queue.get()
        
        if batch_idxs is None:
            break
            
        batch = _process_batch_idxs(batch_idxs, dataset, collate_fn)
        
        output_queue.put(batch)

class OutOfOrderDataLoader:

    def __init__(self, dataset, batch_size, shuffle, collate_fn, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.num_workers = num_workers

        num_examples = len(self.dataset)
        self.num_remaining_examples = num_examples
        
        if shuffle:
            self.sampler = iter(np.random.permutation(num_examples))
        else:
            self.sampler = iter(range(num_examples))

        self.unfinished_tasks = 0
        self.all_tasks_done = False
        self.workers = []
            
        if num_workers > 0:
            self.input_queue = multiprocessing.Queue()
            self.output_queue = multiprocessing.Queue()
            
            # We start by putting at least 1 batch of indices on
            # the `input_queue`. This increments `unfinished_tasks`.
            # We also want to avoid idle workers by adding multiple
            # batches to the input queue before the workers start.
            # Adding some multiple of `num_workers` would also
            # work.
            for _ in range(num_workers):
                self._put_batch_idxs()
                
            for _ in range(num_workers):
                w = multiprocessing.Process(
                    target=_worker,
                    args=(
                        self.dataset,
                        self.input_queue,
                        self.output_queue,
                        self.collate_fn),
                    daemon=True)
                w.start()
                self.workers.append(w)

    def __iter__(self):
        return self
    
    def _next_batch_idxs(self):
        n = min(self.num_remaining_examples, self.batch_size)
        batch_idxs = [next(self.sampler) for _ in range(n)]
        self.num_remaining_examples -= len(batch_idxs)
        return batch_idxs
    
    def _put_batch_idxs(self):
        if self.num_remaining_examples == 0:
            return
        batch_idxs = self._next_batch_idxs()
        self.input_queue.put(batch_idxs)
        self.unfinished_tasks += 1
    
    def __next__(self):
        if self.num_workers == 0:
            if self.num_remaining_examples == 0:
                raise StopIteration
            batch_idxs = self._next_batch_idxs()
            return _process_batch_idxs(batch_idxs, self.dataset, self.collate_fn)
            
        if self.unfinished_tasks == 0:
            self._join()
            raise StopIteration
            
        self._put_batch_idxs()
        batch = self.output_queue.get()
        self.unfinished_tasks -= 1
        
        return batch
    
    def _join(self):
        if self.all_tasks_done:
            return
        for _ in range(len(self.workers)):
            self.input_queue.put(None)
        for w in self.workers:
            w.join()
        self.all_tasks_done = True
            
    def __del__(self):
        self._join()

In [7]:
dataloader = OutOfOrderDataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn, num_workers=4)
for i, (X, y) in enumerate(dataloader):
    print(i)

0
1
2
3
4
5
6
7
8
9
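The printed numbers come from `enumerate`, so they don't reveal the order in which batches arrive. To see why that order can differ from the dataset's, here is a sketch of the same protocol: an input queue, an output queue, and a `None` exit signal. It uses threads instead of processes only to keep the example self-contained. `run_workers` and its per-task delays are hypothetical:

```python
import queue
import threading
import time

def run_workers(delays, num_workers):
    """Return task indices in the order the workers finish them."""
    input_queue = queue.Queue()
    output_queue = queue.Queue()

    def worker():
        while True:
            task = input_queue.get()
            if task is None:      # exit signal
                break
            idx, delay = task
            time.sleep(delay)     # simulate uneven per-batch work
            output_queue.put(idx)

    for task in enumerate(delays):  # each task is an (idx, delay) pair
        input_queue.put(task)
    threads = [threading.Thread(target=worker, daemon=True) for _ in range(num_workers)]
    for t in threads:
        t.start()
    finished = [output_queue.get() for _ in delays]
    for _ in threads:
        input_queue.put(None)     # one exit signal per worker
    for t in threads:
        t.join()
    return finished


order = run_workers([0.3, 0.0, 0.0, 0.0], num_workers=2)
print(order)
```

Task 0 sleeps longer than the others, so it will typically come back last even though it was queued first. Recovering the dataset's order requires tagging each batch with its position.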


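## Multiple processes (in-order)

The out-of-order loader returns batches in whatever order the workers finish them. To preserve the dataset's order, the parent tags each batch of indices with a sequence number, `send_idx`, before putting it on the input queue, and each worker returns `(idx, batch)`. The parent keeps a counter, `recv_idx`, for the next batch it should return. When a batch arrives early, the parent stores it in a dictionary, `cache`, keyed by its index, and keeps reading from the output queue until batch `recv_idx` arrives. On later calls, it checks `cache` first.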

In [8]:
import numpy as np
import time
import multiprocess as multiprocessing


class Dataset:
    
    def __init__(self, features, labels):
        self.examples = list(zip(features, labels))
       
    def __getitem__(self, i):
        time.sleep(0.1) # simulate work
        return self.examples[i]

    def __len__(self):
        return len(self.examples)


def collate_fn(examples):
    """[(x1,y1),...,(xn,yn)] -> X, y"""
    z = list(zip(*examples))
    data = np.stack(z[0], axis=0)
    target = np.stack(z[1], axis=0)
    return (data, target)


def _process_batch_idxs(batch_idxs, dataset, collate_fn):
    return collate_fn([dataset[idx] for idx in batch_idxs])


def _worker(dataset, input_queue, output_queue, collate_fn):
    while True:
        r = input_queue.get()

        if r is None:
            break

        idx, batch_idxs = r

        batch = _process_batch_idxs(batch_idxs, dataset, collate_fn)
        output_queue.put((idx, batch))


class DataLoader:

    def __init__(self, dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.num_workers = num_workers

        n_examples = len(self.dataset)

        if shuffle:
            self.sampler = iter(np.random.permutation(n_examples))
        else:
            self.sampler = iter(range(n_examples))

        self.n_remaining_examples = n_examples
        self.unfinished_tasks = 0
        self.all_tasks_done = False
        self.workers = []

        if num_workers > 0:
            self.input_queue = multiprocessing.Queue()
            self.output_queue = multiprocessing.Queue()

            # Every time we put an item in `input_queue`,
            # we increment `send_idx`. This tracks the 
            # order in which we put items in the `input_queue`
            # in the main process.
            self.send_idx = 0

            # Every time we return a batch, we increment
            # `recv_idx`.
            self.recv_idx = 0

            self.cache = {}

            for _ in range(self.num_workers):
                self._put_batch_idxs()

            for _ in range(self.num_workers):
                w = multiprocessing.Process(
                    target=_worker,
                    args=(self.dataset, self.input_queue, self.output_queue, self.collate_fn),
                    daemon=True)
                w.start()
                self.workers.append(w)

    def _next_batch_idxs(self):
        n = min(self.n_remaining_examples, self.batch_size)
        batch_idxs = [next(self.sampler) for _ in range(n)]
        self.n_remaining_examples -= len(batch_idxs)
        return batch_idxs

    def _put_batch_idxs(self):
        if self.n_remaining_examples == 0:
            return
        batch_idxs = self._next_batch_idxs()
        self.input_queue.put((self.send_idx, batch_idxs))
        self.unfinished_tasks += 1
        self.send_idx += 1

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_workers == 0:
            if self.n_remaining_examples == 0:
                raise StopIteration
            batch_idxs = self._next_batch_idxs()
            return _process_batch_idxs(batch_idxs, self.dataset, self.collate_fn)

        if self.recv_idx in self.cache:
            batch = self.cache.pop(self.recv_idx)
            self.recv_idx += 1
            return batch

        # Check this only after checking the cache,
        # to make sure the cache is empty.
        if self.unfinished_tasks == 0:
            self._join_workers()
            raise StopIteration

        while True:
            self._put_batch_idxs()
            idx, batch = self.output_queue.get()
            self.unfinished_tasks -= 1

            if idx != self.recv_idx:
                self.cache[idx] = batch
                continue

            self.recv_idx += 1
            return batch

    def _join_workers(self):
        if self.all_tasks_done:
            return
        self.all_tasks_done = True
        for _ in range(len(self.workers)):
            self.input_queue.put(None)
        for w in self.workers:
            w.join()

    def __del__(self):
        if self.all_tasks_done:
            return
        self._join_workers()
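The core of the in-order version is the reordering buffer in `__next__`. Isolated from the queues and processes, it amounts to the hypothetical generator `in_order` below. It takes `(idx, item)` pairs in arrival order and yields items in `idx` order, buffering any that arrive early:

```python
def in_order(results):
    """Yield items from (idx, item) pairs in idx order, buffering early arrivals."""
    cache = {}
    recv_idx = 0  # index of the next item to yield
    for idx, item in results:
        cache[idx] = item
        # Flush every buffered item that is now next in line.
        while recv_idx in cache:
            yield cache.pop(recv_idx)
            recv_idx += 1


arrivals = [(2, "c"), (0, "a"), (3, "d"), (1, "b")]
print(list(in_order(arrivals)))  # ['a', 'b', 'c', 'd']
```

Unlike `DataLoader.__next__`, this sketch puts every arrival in the buffer before flushing, but it yields items in the same order. In both versions, the buffer holds every batch that finished ahead of the one being waited for, so a single slow batch can make it grow.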

## Tests

In [9]:
import unittest
import math
import sys
import traceback

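class ErrorDataset:
    # Like PyTorch's Dataset base class, __getitem__ is left unimplemented
    # and raises NotImplementedError, which `_test_error` expects.

    def __init__(self, size):
        self.size = size

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        return self.size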

class TestDataLoader(unittest.TestCase):

    def setUp(self):
        n = 20
        np.random.seed(0)
        self.data = np.random.randn(n, 2, 3, 5)
        self.labels = np.array(np.random.permutation(n // 2).tolist() * 2)
        self.dataset = Dataset(self.data, self.labels)

    def _test_sequential(self, loader):
        batch_size = loader.batch_size
        i = 0
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            np.testing.assert_almost_equal(sample, self.data[idx:idx+batch_size])
            np.testing.assert_almost_equal(target, self.labels[idx:idx+batch_size])
        self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size))

    def _test_shuffle(self, loader):
        found_data = {i: 0 for i in range(self.data.shape[0])}
        found_labels = {i: 0 for i in range(self.labels.shape[0])}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if (data_point == sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i+1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i+1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset)-1) / batch_size))

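    def _test_error(self, loader):
        # Only meaningful with num_workers=0: this implementation does not
        # forward exceptions raised inside worker processes.
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)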
            except NotImplementedError:
                msg = "".join(traceback.format_exception(*sys.exc_info()))
                self.assertTrue("collate_fn" in msg)
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                    math.ceil(float(len(loader.dataset))/loader.batch_size))
                return

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    def test_sequential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def test_partial_workers(self):
        "check that workers exit even if the iterator is not exhausted"
        loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4))
        workers = loader.workers
        for i, sample in enumerate(loader):
            if i == 3:
                break
        del loader
        for w in workers:
            w.join(1.0)  # timeout of one second
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)

In [10]:
unittest.main(argv=[''], verbosity=2, exit=False)

test_partial_workers (__main__.TestDataLoader)
check that workers exit even if the iterator is not exhausted ... ok
test_sequential (__main__.TestDataLoader) ... ok
test_sequential_batch (__main__.TestDataLoader) ... ok
test_sequential_batch_workers (__main__.TestDataLoader) ... ok
test_sequential_workers (__main__.TestDataLoader) ... ok
test_shuffle (__main__.TestDataLoader) ... ok
test_shuffle_batch (__main__.TestDataLoader) ... ok
test_shuffle_batch_workers (__main__.TestDataLoader) ... ok
test_shuffle_workers (__main__.TestDataLoader) ... ok

----------------------------------------------------------------------
Ran 9 tests in 11.336s

OK


<unittest.main.TestProgram at 0x7f79500ff370>

## Sources

* "History for pytorch/torch/utils/data/dataloader.py" [(link)](https://github.com/pytorch/pytorch/commits/master?after=7dd7dde0332c6582082c9a5475d25668652db83d+139&branch=master&path%5B%5D=torch&path%5B%5D=utils&path%5B%5D=data&path%5B%5D=dataloader.py&qualified_name=refs%2Fheads%2Fmaster)
* "Add multiprocess data loader + improvements to torch.utils.data"
    * https://github.com/pytorch/pytorch/pull/44
    * https://github.com/pytorch/pytorch/commit/a1f5fe6a8f47ddb3d79c8492e248762883e80214
    * https://github.com/pytorch/pytorch/blob/a1f5fe6a8f47ddb3d79c8492e248762883e80214/torch/utils/data/dataloader.py
    * https://github.com/pytorch/pytorch/blob/a1f5fe6a8f47ddb3d79c8492e248762883e80214/test/test_utils.py
* "Make DataLoader preserve the ordering of the dataset"
    * https://github.com/pytorch/pytorch/pull/135
    * https://github.com/pytorch/pytorch/commit/6db721b5dda11638aa2eaf6aaea3af341274ef21
    * https://github.com/pytorch/pytorch/blob/6db721b5dda11638aa2eaf6aaea3af341274ef21/torch/utils/data/dataloader.py
    * https://github.com/pytorch/pytorch/blob/6db721b5dda11638aa2eaf6aaea3af341274ef21/test/test_dataloader.py
