# 02 - Datasets

For a supervised AI task, we need a dataset that maps inputs to labeled outputs. We can then train a model on this dataset to predict results on new, unseen data.

We will use PyTorch to work on this dataset.

(If you want a comprehensive, executable reference for PyTorch, you can go to the [PyTorch Tutorial](A0-PyTorch%20Tutorial.ipynb) notebook.)

### Understanding the Fashion MNIST Dataset

The Fashion MNIST dataset is a collection of 70,000 grayscale images (60,000 for training and 10,000 for testing) covering 10 fashion categories, each image 28x28 pixels in size. It is designed as a drop-in replacement for the classic MNIST digit dataset, and it poses a more challenging classification problem because several clothing categories look alike.

![](../media/datasets/FashionMNIST.png)

Each image in the dataset corresponds to a label from 0 to 9, representing the ten categories:

| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |

In this tutorial, we are primarily using `torchvision` to access the Fashion MNIST dataset and apply transformations to the images.

### Loading a Dataset

We need to import some libraries: `torch` and `torchvision` for PyTorch and its datasets, and `matplotlib` for plotting figures. These were installed for you with Docker.

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

Download the Fashion MNIST dataset if it hasn't been downloaded already, and load it as separate training and test sets. We don't use a dev (validation) set, because we won't be tuning hyperparameters.

In [None]:
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

In [None]:
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

### Iterating and Visualizing the Dataset

The dataset contains 10 classes, so this is a 10-class classification problem. Let's plot some of its images.

In [None]:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

### Preparing your data for training with DataLoaders

To simplify loading data from the dataset and feeding it to the processing stage, we use a `DataLoader`. This is where we specify the batch size. Smaller batches help fit a portion of the dataset into limited hardware (for example, a graphics card with limited VRAM). Here the batch size is 64; it is a hyperparameter.

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [None]:
# Iterate through the DataLoader

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}: {labels_map[int(label)]}")
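To see how a batch size of 64 splits a dataset into batches, here is a small sketch on synthetic data. The `fake_*` names and the choice of 1,000 blank images are assumptions made for illustration, so that the cell runs quickly and without the real dataset.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in: 1,000 blank 1x28x28 "images" with random labels 0-9.
fake_images = torch.zeros(1000, 1, 28, 28)
fake_labels = torch.randint(0, 10, (1000,))
fake_dataset = TensorDataset(fake_images, fake_labels)

fake_dataloader = DataLoader(fake_dataset, batch_size=64, shuffle=True)

# Collect the number of samples in every batch of one full pass (one epoch).
batch_sizes = [features.shape[0] for features, _ in fake_dataloader]

print(len(fake_dataloader))  # batches per epoch: ceil(1000 / 64)
print(batch_sizes[-1])       # the last batch holds the leftover samples

# The same arithmetic for the 60,000 Fashion MNIST training images.
print(math.ceil(60000 / 64))
```

The last batch is smaller because 1,000 is not a multiple of 64. `DataLoader` also accepts `drop_last=True` if a partial final batch is not wanted.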

### Iterate

Remember that you can execute a cell multiple times? Because the training dataloader shuffles the data, each run of the cell above draws a new batch and displays its first image and label.
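Re-running the cell calls `next(iter(train_dataloader))` again, which starts a fresh shuffled pass over the data. The sketch below shows the same effect on a toy dataset of the integers 0 to 9; the `toy_*` names are hypothetical and exist only for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset whose samples are simply the integers 0..9.
toy_dataset = TensorDataset(torch.arange(10))
toy_dataloader = DataLoader(toy_dataset, batch_size=4, shuffle=True)

# Each call to iter() starts a new pass with a new random order,
# so the first batch usually differs between the two calls.
first_pass = next(iter(toy_dataloader))[0].tolist()
second_pass = next(iter(toy_dataloader))[0].tolist()

print(first_pass)
print(second_pass)
```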

### Shapes

Most AI models expect a specific input shape. In this case, `[64, 1, 28, 28]` means that the batch holds 64 images, each image has one channel (grayscale; color images usually have 3), and each image is 28 pixels high by 28 pixels wide.
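The shape described above can be explored with a blank tensor of the same size. This is a minimal sketch: `batch` and `color_batch` are stand-ins filled with zeros, not real data from the dataloader.

```python
import torch

# Stand-in for one batch of features with shape [64, 1, 28, 28].
batch = torch.zeros(64, 1, 28, 28)

# Unpack the four dimensions: batch size, channels, height, width.
batch_size, channels, height, width = batch.shape
print(batch_size, channels, height, width)

# Indexing selects one image; squeeze() drops the single channel dimension,
# which is what plt.imshow needs for a grayscale image.
single_image = batch[0]
plottable_image = single_image.squeeze()
print(single_image.shape)
print(plottable_image.shape)

# For comparison, a batch of color (RGB) images would have 3 channels.
color_batch = torch.zeros(64, 3, 28, 28)
print(color_batch.shape)
```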

**Next Notebook: [03-Neural Networks](03-Neural%20Networks.ipynb)**