# Datasets & DataLoaders

* We want our dataset code to be decoupled from our model training code for better readability and modularity.
* PyTorch allows using preloaded datasets as well as your own data.
* Data primitives:
  * `torch.utils.data.Dataset`: stores the samples and their corresponding labels.
  * `torch.utils.data.DataLoader`: wraps an iterable around the `Dataset` to enable easy access to the samples.
* There are a number of preloaded datasets (e.g. `FashionMNIST`) that subclass `torch.utils.data.Dataset` and implement functions specific to the particular data. More details: [Image Datasets](https://pytorch.org/vision/stable/datasets.html), [Audio Datasets](https://pytorch.org/audio/stable/datasets.html), [Text Datasets](https://pytorch.org/text/stable/datasets.html).
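The following minimal sketch shows how the two primitives fit together. It uses `torch.utils.data.TensorDataset`, a built-in `Dataset` that wraps tensors, on ten made-up samples instead of a real dataset. The feature and label values are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten made-up samples: each feature is a length-3 vector, each label is 0 or 1.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# Dataset: stores the samples and labels; indexing returns one (feature, label) pair.
dataset = TensorDataset(features, labels)
x0, y0 = dataset[0]
print(x0, y0)  # tensor([0., 1., 2.]) tensor(0)

# DataLoader: wraps the Dataset in an iterable that yields batches of samples.
loader = DataLoader(dataset, batch_size=4)
batch_sizes = [batch_x.shape[0] for batch_x, batch_y in loader]
print(batch_sizes)  # [4, 4, 2]: the last batch holds the 2 leftover samples
```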

## Loading a Dataset

## Creating a Custom Dataset

A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`.

### `__init__`

The `__init__` function is run once when instantiating the `Dataset` object.
In the example below, it reads the annotations file into `self.img_labels` and stores the directory containing the images and both transforms (`transform` for the images, `target_transform` for the labels).

```python
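# Method of a torch.utils.data.Dataset subclass (requires `import pandas as pd`).
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)  # column 0: file name, column 1: label
    self.img_dir = img_dir
    self.transform = transform                # applied to each image
    self.target_transform = target_transform  # applied to each label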
```
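`__init__` and `__getitem__` assume the annotations file is a CSV with the image file name in the first column and the label in the second. The sketch below shows that layout with made-up file names, read from an in-memory string instead of a file. One assumption to note is the header row. `pd.read_csv` treats the first line as column names by default, so a file without a header needs `header=None`. Otherwise its first sample would silently disappear.

```python
import io

import pandas as pd

# Hypothetical annotations file: file name in column 0, integer label in column 1.
csv_text = """file,label
tshirt1.jpg,0
tshirt2.jpg,0
ankleboot999.jpg,9
"""

img_labels = pd.read_csv(io.StringIO(csv_text))
print(len(img_labels))        # 3 samples (the header row is not counted)
print(img_labels.iloc[2, 0])  # ankleboot999.jpg  (what __getitem__ joins with img_dir)
print(img_labels.iloc[2, 1])  # 9                 (what __getitem__ returns as the label)
```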

### `__len__`

The `__len__` function returns the number of samples in the dataset.

```python
def __len__(self):
    return len(self.img_labels)
```

### `__getitem__`

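The `__getitem__` function loads and returns a sample from the dataset at the given index `idx`. Based on the index, it locates the image file on disk, reads it into a tensor with `read_image`, looks up the corresponding label in `self.img_labels`, applies the transforms (if any), and returns the image and label as a tuple.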

```python
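# Method of the same Dataset subclass (requires `import os` and
# `from torchvision.io import read_image`).
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # column 0: file name
    image = read_image(img_path)          # uint8 tensor of shape (C, H, W)
    label = self.img_labels.iloc[idx, 1]  # column 1: label
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label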
```

## Preparing data for training with `DataLoader`

While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

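`DataLoader` is an iterable that handles this complexity for us behind a simple API: `batch_size` sets the minibatch size, `shuffle=True` reshuffles the data every epoch, and `num_workers` loads data in parallel worker processes.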

```python
from torch.utils.data import DataLoader

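# Reshuffle the training data every epoch; the evaluation order does not matter, so keep it fixed.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)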
```
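When the dataset size is not a multiple of `batch_size`, the last batch is smaller. For example, 60,000 training images with `batch_size=64` give 938 batches, and the last one holds 32 images. The sketch below demonstrates this on a hypothetical 10-sample dataset. It also shows `drop_last=True`, which discards the incomplete batch, and `num_workers`, which controls the multiprocessing mentioned above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 10 samples, so batch_size=4 does not divide it evenly.
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1), torch.arange(10))

loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(len(loader))  # 3 batches
batch_sizes = [len(y) for _, y in loader]
print(batch_sizes)  # [4, 4, 2]: shuffling changes which samples share a batch, not the sizes
# Every sample appears exactly once per pass over the loader.
seen = sorted(torch.cat([y for _, y in loader]).tolist())

# drop_last=True discards the final incomplete batch.
loader_drop = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
print(len(loader_drop))  # 2

# num_workers > 0 loads batches in that many worker subprocesses once iteration starts;
# the default, 0, loads them in the main process. Constructing the loader starts no workers.
parallel_loader = DataLoader(dataset, batch_size=4, num_workers=2)
print(parallel_loader.num_workers)  # 2
```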

## Iterate through the `DataLoader`

Each iteration below returns a batch of `train_features` and `train_labels`, containing 64 features and 64 labels respectively (`batch_size=64`; the final batch is smaller when the dataset size is not a multiple of 64). Because we specified `shuffle=True`, the data is reshuffled after we iterate over all batches, so each epoch sees a new order (for finer-grained control over the data loading order, take a look at Samplers).

```python
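import matplotlib.pyplot as plt

# Display one image and its label from the first batch.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")  # FashionMNIST: torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")     # torch.Size([64])
img = train_features[0].squeeze()  # drop the channel dimension: (1, 28, 28) -> (28, 28)
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")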
```
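The sketch below illustrates the per-epoch reshuffling and the sampler alternative on a hypothetical 8-sample dataset whose samples are just their own indices. The seeded `generator` is an optional extra. It makes the sequence of permutations reproducible, and each new pass over the loader still draws a fresh order. `sampler` and `shuffle=True` are mutually exclusive: `DataLoader` raises a `ValueError` if both are given.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

# Hypothetical dataset whose samples are the indices 0..7.
dataset = TensorDataset(torch.arange(8))

# shuffle=True draws a new permutation each time the loader is iterated (each epoch).
loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    generator=torch.Generator().manual_seed(0))
epoch_orders = []
for epoch in range(2):
    for (batch,) in loader:  # a single batch holding all 8 samples
        epoch_orders.append(batch.tolist())
print(epoch_orders)

# Finer-grained control: pass an explicit sampler instead of shuffle.
# SequentialSampler yields the indices 0, 1, 2, ... in order.
ordered = DataLoader(dataset, batch_size=8, sampler=SequentialSampler(dataset))
(ordered_batch,) = next(iter(ordered))
print(ordered_batch.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
```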