In [1]:
import torch, torchvision
import matplotlib.pyplot as plt

In [2]:
# download the MNIST train/test splits into ./data and expose their raw tensors
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)
x_train, y_train = mnist_train.data, mnist_train.targets
x_test, y_test = mnist_test.data, mnist_test.targets

In [3]:
# one line per split: (inputs shape) (targets shape)
for split_x, split_y in ((x_train, y_train), (x_test, y_test)):
    print(split_x.shape, split_y.shape)

torch.Size([60000, 28, 28]) torch.Size([60000])
torch.Size([10000, 28, 28]) torch.Size([10000])


### Basic dataset object

- From https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class, to define a custom dataset object in PyTorch, one simply inherits from `torch.utils.data.Dataset` and overrides the `__len__` and `__getitem__` methods.
- The code below provides a skeleton framework of what a basic dataset object should look like.
- **Note:** `__getitem__` is the special method behind indexing: it is what lets you fetch items via `dataset[5]` rather than through an explicit getter call such as `dataset.getitem(5)`

In [4]:
class Dataset(torch.utils.data.Dataset):
    """Minimal map-style dataset that pairs inputs `x` with targets `y`.

    Subclasses `torch.utils.data.Dataset` and overrides `__len__` and `__getitem__`.
    Indexing is delegated directly to `x` and `y`, so `ds[i]` returns one
    `(x[i], y[i])` sample, and `ds[a:b]` returns an `(x[a:b], y[a:b])` pair
    whenever `x` and `y` support slicing (e.g. tensors).
    """
    def __init__(self, x, y): self.x, self.y = x, y
    # the dataset length is taken from the inputs; `x` and `y` are assumed to be equally long
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]

In [5]:
ds = Dataset(x=x_train, y=y_train)
first_sample = ds[0]
type(first_sample)

tuple

In [6]:
xs_5, ys_5 = ds[:5]  # slicing the dataset returns an (inputs, targets) pair
print(xs_5.shape, ys_5.shape)

torch.Size([5, 28, 28]) torch.Size([5])


### Basic dataloader object

- with the dataset object defined previously, we can already iterate through our data samples in a plain `for` loop
- but doing this by hand is clunky (e.g. we have to batch the samples ourselves), and a bare loop:
    - does not provide functionalities like **shuffling of data**; and
    - does not load data using multiprocessing
- the code below shows the skeleton of a very basic dataloader class, which takes care of the batching

In [7]:
class BasicDataLoader():
    """Iterate over `ds` in consecutive, unshuffled batches of `bs` samples.

    Each batch is produced by slicing the dataset (`ds[i:i+bs]`), so `ds` must
    support slice indexing; the final batch may be shorter than `bs`.
    """
    def __init__(self, ds, bs):
        self.ds = ds
        self.bs = bs

    def __iter__(self):
        starts = range(0, len(self.ds), self.bs)
        yield from (self.ds[start:start + self.bs] for start in starts)

In [9]:
bs = 64
dl = BasicDataLoader(ds, bs)
# next(iter(...)) pulls just the first batch out of the loader
batch_iter = iter(dl)
xb, yb = next(batch_iter)
print(xb.shape, yb.shape)

torch.Size([64, 28, 28]) torch.Size([64])


### Random sampler

- below, we add code for a basic random sampler

In [10]:
class Sampler():
    """Yield mini-batches of dataset indices, optionally in shuffled order.

    Args:
        ds: dataset to sample from; only `len(ds)` is used.
        bs: number of indices per batch (the last batch may be smaller).
        shuffle: if True, draw a fresh random permutation on every pass;
            otherwise yield indices in order 0..n-1.
        generator: optional `torch.Generator` forwarded to `torch.randperm` so
            shuffled passes can be made reproducible; None (the default) keeps
            the previous behaviour of using torch's default RNG.
    """
    def __init__(self, ds, bs, shuffle=False, generator=None):
        self.n, self.bs, self.shuffle = len(ds), bs, shuffle
        self.generator = generator

    def __iter__(self):
        # a new permutation is drawn on every iteration, so each pass gets a different order
        self.idxs = (torch.randperm(self.n, generator=self.generator) if self.shuffle
                     else torch.arange(self.n))
        # slice the (possibly shuffled) index tensor into batches of size `bs`
        for i in range(0, self.n, self.bs): yield self.idxs[i:i+self.bs]

In [11]:
# a tiny 10-sample dataset keeps the shuffled index batches easy to read
first10_x, first10_y = ds[:10]
small_ds = Dataset(first10_x, first10_y)
s = Sampler(small_ds, bs=3, shuffle=True)

In [12]:
list(s)

[tensor([1, 7, 6]), tensor([9, 3, 8]), tensor([2, 5, 0]), tensor([4])]

- so all the sampler really does is yield batches of (optionally shuffled) indices
- we also need a way to gather the samples at those indices into batch tensors, which is the job of the collate function

In [13]:
def collate(b):
    """Turn a list of `(x, y)` samples into one `(x_batch, y_batch)` tuple.

    Each field is stacked along a new leading (batch) dimension with `torch.stack`.
    """
    inputs, targets = zip(*b)
    x_batch = torch.stack(inputs)
    y_batch = torch.stack(targets)
    return x_batch, y_batch

- let's modify our basic DataLoader to use the sampler (for shuffling) and the collate function (for batching)

In [14]:
class BasicDataLoader():
    """Batch a dataset using an index sampler and a collate function.

    `sampler` yields batches of indices; for each one, the individual samples
    are fetched from `ds` and merged into a single batch by `collate_fn`.
    """
    def __init__(self, ds, sampler, collate_fn=collate):
        self.ds = ds
        self.sampler = sampler
        self.collate_fn = collate_fn

    def __iter__(self):
        for idx_batch in self.sampler:
            samples = [self.ds[j] for j in idx_batch]
            yield self.collate_fn(samples)

In [15]:
s = Sampler(ds, bs=64, shuffle=True)
dl = BasicDataLoader(ds, sampler=s, collate_fn=collate)
# grab a single shuffled, collated batch
first_batch = next(iter(dl))
x, y = first_batch
print(x.shape, y.shape)

torch.Size([64, 28, 28]) torch.Size([64])


### PyTorch DataLoader

- the PyTorch `DataLoader` is superior to our basic dataloader: it provides various sampling strategies, as well as multi-processing capabilities for loading the data

In [28]:
import os
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler

In [29]:
bs = 64
s = RandomSampler(ds)
# os.cpu_count() returns None when the CPU count cannot be determined; passing None as
# num_workers breaks DataLoader, so fall back to loading in the main process (0 workers).
# NOTE(review): on platforms that start workers with "spawn" (e.g. Windows/macOS), classes
# defined inside the notebook (like `Dataset`) may fail to pickle — set this to 0 if it errors.
n_workers = os.cpu_count() or 0
dl = DataLoader(ds, bs, sampler=s, num_workers=n_workers)
db = next(iter(dl))
print(db[0].shape, db[1].shape)

torch.Size([64, 28, 28]) torch.Size([64])


### Validation dataset

- almost all the time, you also want a validation dataset to evaluate your model on and check that it is not overfitting (during the training loop, which will be covered later)
- note that the batch size of the validation dataloader is set to twice the training batch size: the validation loop does not compute gradients (there is no backward pass), so it needs roughly half the memory per batch, which lets us be more aggressive with the validation batch size

In [30]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Build the `(train_dl, valid_dl)` DataLoader pair.

    The training loader shuffles and uses batch size `bs`; the validation loader
    keeps dataset order and uses `2*bs`. Extra keyword arguments (e.g.
    `num_workers`) are forwarded to both loaders.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs*2, **kwargs)
    return train_dl, valid_dl

In [None]:
# MNIST is loaded here only as train/test splits, so the held-out test split serves
# as the validation set (without this, x_valid/y_valid are undefined -> NameError).
x_valid, y_valid = x_test, y_test
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)