In [1]:
from time import perf_counter, sleep, time

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision import transforms

# Dataset
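A PyTorch `DataLoader` expects a `Dataset`. The default (map-style) `Dataset` behaves like a map from indices to samples, which allows random access through `__getitem__` and `__len__`. There is also `IterableDataset`, which streams samples through `__iter__` and can therefore be backed by a generator.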

For more information, check out the [PyTorch documentation](https://pytorch.org/docs/stable/data.html#dataset-types).

In [2]:
class MyDataset(Dataset):
    """Toy map-style dataset returning ``(x, -x)`` pairs.

    Sample ``i`` is the integer ``i`` stored as a ``(1, 1, 1)`` numpy array and
    converted with ``transforms.ToTensor``; the target is its negation.
    ``__getitem__`` sleeps to simulate per-sample preprocessing cost.

    Args:
        size: number of samples (default 44).
        delay: seconds slept in every ``__getitem__`` call (default 0.5).
    """

    def __init__(self, size=44, delay=0.5):
        # Zero-argument super() really initialises the parent class; the
        # former ``super(MyDataset).__init__()`` only created an unbound super
        # object and never called the parent initialiser.
        super().__init__()
        self.data = np.arange(size).reshape(size, 1, 1, 1)
        self.delay = delay
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        # small data-prep. simulation delay
        sleep(self.delay)
        sample = self.data[index]
        return self.transform(sample), self.transform(-sample)

    def __len__(self):
        return len(self.data)

In [3]:
def _gen(data):
    for d in data:
        # small data-prep. simulation delay
        sleep(0.5)
        yield d, -d

class MyIterableDataset(IterableDataset):
    """Iterable-style dataset streaming ``(x, -x)`` pairs for ``x`` in ``0..size-1``.

    Each ``x`` is a ``(1, 1, 1)`` tensor produced lazily by ``_gen``. Inside
    ``DataLoader`` worker processes, every worker gets a disjoint strided
    shard (worker ``k`` of ``n`` yields ``k, k + n, k + 2n, ...``), so samples
    are not duplicated across workers.

    Args:
        size: number of samples (default 44).
    """

    def __init__(self, size=44):
        super().__init__()
        self.size = size

    def _shard(self):
        """Return this process's share of the samples as a ``(k, 1, 1, 1)`` tensor."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Loading in the main process: this iterator owns every sample.
            start, step = 0, 1
        else:
            start, step = worker_info.id, worker_info.num_workers
        shard = np.arange(start, self.size, step).reshape(-1, 1, 1, 1)
        return torch.from_numpy(shard)

    def __iter__(self):
        return _gen(self._shard())

## DataLoader
We can now use `DataLoader` to get batches from our dataset. By default, the `DataLoader` loads data in the main process and returns all data points in batches. The last batch is smaller if the dataset size is not divisible by `batch_size` (pass `drop_last=True` to discard it).

### Map Dataset

In [4]:
# Map-style dataset: ds[i] returns the pair (x_i, -x_i).
ds = MyDataset()

In [5]:
# Default settings: loading happens in the main process, in index order.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
)

In [6]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([0, 1, 2, 3, 4, 5, 6, 7])
y: tensor([ 0, -1, -2, -3, -4, -5, -6, -7])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8,  9, 10, 11, 12, 13, 14, 15])
y: tensor([ -8,  -9, -10, -11, -12, -13, -14, -15])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([16, 17, 18, 19, 20, 21, 22, 23])
y: tensor([-16, -17, -18, -19, -20, -21, -22, -23])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24, 25, 26, 27, 28, 29, 30, 31])
y: tensor([-24, -25, -26, -27, -28, -29, -30, -31])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1]

### Iterable Dataset

In [7]:
# Iterable-style dataset: samples are streamed through __iter__, not indexed.
ds = MyIterableDataset()

In [8]:
# The loader simply groups whatever the dataset's iterator yields into batches of 8.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
)

In [9]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([0, 1, 2, 3, 4, 5, 6, 7])
y: tensor([ 0, -1, -2, -3, -4, -5, -6, -7])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8,  9, 10, 11, 12, 13, 14, 15])
y: tensor([ -8,  -9, -10, -11, -12, -13, -14, -15])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([16, 17, 18, 19, 20, 21, 22, 23])
y: tensor([-16, -17, -18, -19, -20, -21, -22, -23])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24, 25, 26, 27, 28, 29, 30, 31])
y: tensor([-24, -25, -26, -27, -28, -29, -30, -31])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1]

## Parallel Loading
We can speed up data loading with `num_workers`. With `num_workers=1`, loading runs in a separate worker process; the main process then mostly just waits for data, so only a negligible speed-up can be observed. With `num_workers > 1`, data is loaded in parallel. More precisely, `num_workers` worker processes are started, and each worker assembles complete batches on its own, keeping up to `prefetch_factor` (default 2) batches ready until the main process consumes them.

For more information consult the [documentation](https://pytorch.org/docs/stable/data.html#multi-process-data-loading).

### Map Dataset

In [10]:
# Fresh map-style dataset for the multi-process loading experiments.
ds = MyDataset()

In [11]:
# One background worker process prepares the batches.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=1,
)

In [12]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([0, 1, 2, 3, 4, 5, 6, 7])
y: tensor([ 0, -1, -2, -3, -4, -5, -6, -7])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8,  9, 10, 11, 12, 13, 14, 15])
y: tensor([ -8,  -9, -10, -11, -12, -13, -14, -15])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([16, 17, 18, 19, 20, 21, 22, 23])
y: tensor([-16, -17, -18, -19, -20, -21, -22, -23])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24, 25, 26, 27, 28, 29, 30, 31])
y: tensor([-24, -25, -26, -27, -28, -29, -30, -31])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1]

In [13]:
# Four worker processes prepare batches in parallel.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=4,
)

In [14]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([0, 1, 2, 3, 4, 5, 6, 7])
y: tensor([ 0, -1, -2, -3, -4, -5, -6, -7])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8,  9, 10, 11, 12, 13, 14, 15])
y: tensor([ -8,  -9, -10, -11, -12, -13, -14, -15])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([16, 17, 18, 19, 20, 21, 22, 23])
y: tensor([-16, -17, -18, -19, -20, -21, -22, -23])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24, 25, 26, 27, 28, 29, 30, 31])
y: tensor([-24, -25, -26, -27, -28, -29, -30, -31])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1]

### Iterable Dataset

In [15]:
# Fresh iterable-style dataset for the multi-process loading experiments.
ds = MyIterableDataset()

In [16]:
# One background worker process iterates the dataset.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=1,
)

In [17]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([0, 1, 2, 3, 4, 5, 6, 7])
y: tensor([ 0, -1, -2, -3, -4, -5, -6, -7])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8,  9, 10, 11, 12, 13, 14, 15])
y: tensor([ -8,  -9, -10, -11, -12, -13, -14, -15])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([16, 17, 18, 19, 20, 21, 22, 23])
y: tensor([-16, -17, -18, -19, -20, -21, -22, -23])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24, 25, 26, 27, 28, 29, 30, 31])
y: tensor([-24, -25, -26, -27, -28, -29, -30, -31])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1]

In [18]:
# Four workers each iterate their own copy of the dataset and build batches
# from what their copy yields.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=4,
)

In [19]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 0,  4,  8, 12, 16, 20, 24, 28])
y: tensor([  0,  -4,  -8, -12, -16, -20, -24, -28])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 1,  5,  9, 13, 17, 21, 25, 29])
y: tensor([ -1,  -5,  -9, -13, -17, -21, -25, -29])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 2,  6, 10, 14, 18, 22, 26, 30])
y: tensor([ -2,  -6, -10, -14, -18, -22, -26, -30])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 3,  7, 11, 15, 19, 23, 27, 31])
y: tensor([ -3,  -7, -11, -15, -19, -23, -27, -31])
-----------------------------------------------
torch.Size([3, 1, 1, 1]) torch.S

## Shuffle
It is usually a good idea to reshuffle the data every epoch. `DataLoader` takes care of this as well: with `shuffle=True`, a new random order is drawn each time the loader is iterated.

### Map Dataset

In [20]:
# Fresh map-style dataset for the shuffling experiment.
ds = MyDataset()

In [21]:
# Parallel loading with the sample order reshuffled on every pass.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=4,
    shuffle=True,
)

In [22]:
# Run one full epoch over `dl`, simulating 0.2 s of work per batch, and
# report the total elapsed time.
start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([37, 43, 15, 12, 42, 41, 16, 17])
y: tensor([-37, -43, -15, -12, -42, -41, -16, -17])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 8, 25,  5,  2,  6, 36, 21,  3])
y: tensor([ -8, -25,  -5,  -2,  -6, -36, -21,  -3])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([31, 14, 33, 13, 26, 35, 18, 22])
y: tensor([-31, -14, -33, -13, -26, -35, -18, -22])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([24,  1, 40,  7, 38, 30, 39, 20])
y: tensor([-24,  -1, -40,  -7, -38, -30, -39, -20])
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.S

### Iterable Dataset
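`shuffle=True` is not supported for an `IterableDataset`: the `DataLoader` raises a `ValueError`. If shuffling is possible at all for your stream, it has to be implemented inside the iterator itself (e.g. with a shuffle buffer).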

## CUDA
If you want to use the GPU, just move each batch to the GPU with `.cuda()` (or `.to(device)`).

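__Note:__ Host-to-GPU copies can be faster from [pinned (page-locked) memory](https://pytorch.org/docs/stable/data.html#memory-pinning): set `pin_memory=True` on the `DataLoader` and copy with `non_blocking=True` to make the transfer asynchronous.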

### Map Dataset

In [23]:
# Fresh map-style dataset for the GPU-transfer experiment.
ds = MyDataset()

In [24]:
# Same parallel, shuffled loader, now returning batches in pinned host memory.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
)

In [25]:
# Run one epoch, moving every batch to the GPU, and report the elapsed time.
# Fall back to the CPU so the cell also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    # The loader uses pin_memory=True, so non_blocking=True allows the
    # host-to-device copy to run asynchronously.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([23,  4, 34,  1, 43, 14, 13, 31], device='cuda:0')
y: tensor([-23,  -4, -34,  -1, -43, -14, -13, -31], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 0,  8, 29, 41, 22, 40,  9,  2], device='cuda:0')
y: tensor([  0,  -8, -29, -41, -22, -40,  -9,  -2], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([19, 39, 12, 42,  5, 18, 16, 36], device='cuda:0')
y: tensor([-19, -39, -12, -42,  -5, -18, -16, -36], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([25, 17, 10, 26, 11, 20, 21, 15], device='cuda:0')
y: tensor([-2

## Augmentation

We can apply transformations to the data on the fly.
For images also look into [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html) which provides many augmentation methods.

### Map Dataset

In [26]:
def gaussian_noise(data, std=1.0):
    """Return ``data`` plus i.i.d. Gaussian noise with standard deviation ``std``.

    The noise is drawn with ``torch.randn`` in the default floating dtype, so
    an integer tensor input is promoted to floating point in the result.

    Args:
        data: tensor to perturb (any shape, including 0-dim).
        std: noise standard deviation (default 1.0, i.e. plain ``randn``).
    """
    # Pass the shape as one argument so 0-dim tensors work too; unpacking an
    # empty shape would call torch.randn() with no size at all.
    return data + std * torch.randn(data.shape)

class MyDataset(Dataset):
    """Map-style dataset of ``(x, -x)`` pairs with on-the-fly input augmentation.

    Redefines ``MyDataset`` from earlier in the notebook; instances created
    before this cell ran keep using the old class. Sample ``i`` is the integer
    ``i`` stored as a ``(1, 1, 1)`` numpy array. Inputs go through
    ``ToTensor`` followed by ``gaussian_noise``. Targets ``-x`` are only
    converted with ``ToTensor``.

    Args:
        size: number of samples (default 44).
        delay: seconds slept in every ``__getitem__`` call (default 0.5).
    """

    def __init__(self, size=44, delay=0.5):
        # Zero-argument super() really initialises the parent class; the
        # former ``super(MyDataset).__init__()`` only created an unbound super
        # object and never called the parent initialiser.
        super().__init__()
        self.data = np.arange(size).reshape(size, 1, 1, 1)
        self.delay = delay
        self.transform_x = transforms.Compose([
            transforms.ToTensor(),
            gaussian_noise
        ])
        self.transform_y = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        # small data-prep. simulation delay
        sleep(self.delay)
        sample = self.data[index]
        return self.transform_x(sample), self.transform_y(-sample)

    def __len__(self):
        return len(self.data)

In [27]:
# Dataset with Gaussian-noise augmentation on the inputs (class redefined in the cell above).
ds = MyDataset()

In [28]:
# Parallel, shuffled loader with pinned host memory for the augmented dataset.
dl = DataLoader(
    dataset=ds,
    batch_size=8,
    num_workers=4,
    shuffle=True,
    pin_memory=True,
)

In [29]:
# Run one epoch, moving every batch to the GPU, and report the elapsed time.
# Fall back to the CPU so the cell also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

start_time = perf_counter()  # monotonic clock intended for measuring intervals
for x, y in dl:
    # Processing time of one batch
    sleep(0.2)
    # The loader uses pin_memory=True, so non_blocking=True allows the
    # host-to-device copy to run asynchronously.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    print(x.shape, y.shape)
    print('-----------------------------------------------')
    print('x: {}'.format(x[:,0,0,0]))
    print('y: {}'.format(y[:,0,0,0]))
    print('-----------------------------------------------')

end_time = perf_counter()
print('elapsed: {:.2f} s'.format(end_time - start_time))

torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([38.9348,  3.1185, 12.3086, 26.7476, 34.9752, 33.4604, 21.5414, 24.7577],
       device='cuda:0')
y: tensor([-38,  -2, -13, -27, -35, -32, -21, -23], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([20.1446, 36.4844, 43.4108, 29.8149, 28.4759,  6.6531, 39.4273, 11.8008],
       device='cuda:0')
y: tensor([-20, -34, -42, -31, -30,  -7, -41, -12], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.Size([8, 1, 1, 1])
-----------------------------------------------
x: tensor([ 4.8565, -0.7846, 19.9808,  0.2828, 22.1500, 26.4676, 20.8791, 29.4759],
       device='cuda:0')
y: tensor([ -4,  -1, -19,   0, -24, -25, -22, -28], device='cuda:0')
-----------------------------------------------
torch.Size([8, 1, 1, 1]) torch.