# 3.5. The Image Classification Dataset

In [1]:
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms

## 3.5.1. Reading the Dataset

In [5]:
# `ToTensor` turns each PIL image into a 32-bit floating point tensor and
# divides the pixel values by 255, so every value lies between 0 and 1.
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True
)

In [6]:
# Number of examples in the training split and in the test split.
len(mnist_train), len(mnist_test)

(60000, 10000)

In [16]:
# Shape of element 0 (the transformed image) of the first training example.
mnist_train[0][0].shape

torch.Size([1, 28, 28])

In [17]:
def get_fashion_mnist_labels(labels):  #@save
    """Map numeric Fashion-MNIST class indices to their text names.

    Each element of `labels` is converted with `int()` and used as an index
    into the ten class names; the names are returned as a list in the same
    order as the input.
    """
    class_names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    names = []
    for label in labels:
        names.append(class_names[int(label)])
    return names

## 3.5.2. Reading a Minibatch

In [18]:
# Number of examples per minibatch, passed to the data loader built in this cell.
batch_size = 256

def get_dataloader_workers():  #@save
    """Return the number of processes used to read the data (always 4)."""
    num_workers = 4
    return num_workers

# Shuffled minibatch iterator over the training set; the number of reader
# processes comes from `get_dataloader_workers()`.
train_iter = data.DataLoader(
    dataset=mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=get_dataloader_workers(),
)

## 3.5.3. Putting All Things Together

In [19]:
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Create the Fashion-MNIST training and test data loaders.

    Both splits are requested from `torchvision.datasets.FashionMNIST` with
    `root="../data"` and `download=True`. Every image is converted with
    `transforms.ToTensor()`; if `resize` is truthy, `transforms.Resize(resize)`
    is applied first.

    Returns a `(train_loader, test_loader)` tuple of `data.DataLoader` objects
    that use `batch_size`; only the training loader shuffles.
    """
    to_tensor = transforms.ToTensor()
    # Resizing, when requested, has to run before the tensor conversion.
    steps = [transforms.Resize(resize), to_tensor] if resize else [to_tensor]
    pipeline = transforms.Compose(steps)

    train_set = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=pipeline, download=True)
    test_set = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=pipeline, download=True)

    train_loader = data.DataLoader(
        train_set, batch_size, shuffle=True, num_workers=get_dataloader_workers())
    test_loader = data.DataLoader(
        test_set, batch_size, shuffle=False, num_workers=get_dataloader_workers())
    return train_loader, test_loader

In [20]:
# Build loaders with 32-example minibatches and `resize=64`, then inspect the
# shapes and dtypes of the first training minibatch only.
train_iter, test_iter = load_data_fashion_mnist(batch_size=32, resize=64)
X, y = next(iter(train_iter))
print(X.shape, X.dtype, y.shape, y.dtype)

torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
