# 4.2 The Image Classification Dataset

## 1. Does reducing the `batch_size` (for instance, to 1) affect the reading performance?

```python
#----- Import Modules -----#
import time
import torch
import torchvision
from torchvision import transforms
from d2l import torch as d2l

#----- Define FashionMNIST DataModule -----#
class FashionMNIST(d2l.DataModule):
    """Fashion-MNIST wrapped as a d2l DataModule."""
    def __init__(self, batch_size=64, resize=(28, 28)):
        # d2l.DataModule.__init__ sets self.root ('../data') and self.num_workers
        super().__init__()
        self.save_hyperparameters()
        # Resize each PIL image, then convert it to a float tensor in [0, 1]
        trans = transforms.Compose([
            transforms.Resize(resize),
            transforms.ToTensor()
        ])
        # Download (if needed) and load the training and validation splits
        self.train = torchvision.datasets.FashionMNIST(
            root=self.root, train=True, transform=trans, download=True)
        self.val = torchvision.datasets.FashionMNIST(
            root=self.root, train=False, transform=trans, download=True)

    def text_labels(self, indices):
        """Map numeric class indices to their text labels."""
        labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                  'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [labels[int(i)] for i in indices]

    def get_dataloader(self, train):
        data = self.train if train else self.val
        # num_workers is inherited from d2l.DataModule (default 4, not 0)
        return torch.utils.data.DataLoader(
            data, batch_size=self.batch_size, shuffle=train,
            num_workers=self.num_workers)

    def visualize(self, batch, nrows=1, ncols=8, labels=None):
        X, y = batch
        if not labels:
            labels = self.text_labels(y)
        # Drop the channel dimension: (B, 1, H, W) -> (B, H, W)
        d2l.show_images(X.squeeze(1), nrows, ncols, titles=labels)

#----- Main -----#
# Time one full pass over the training set for each batch size
for batch_size in (1, 64):
    data = FashionMNIST(batch_size=batch_size, resize=(32, 32))
    start_time = time.perf_counter()
    for X, y in data.train_dataloader():
        continue
    print(f"Time(BatchSize={batch_size}): {time.perf_counter() - start_time:.2f} sec")
```

```text
Time(BatchSize=1): 16.98 sec
Time(BatchSize=64): 5.53 sec
```


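As we can see, reducing `batch_size` from 64 to 1 makes one pass over the 60,000 training images roughly three times slower (16.98 sec vs. 5.53 sec). The per-image work (reading the image, `Resize`, `ToTensor`) is the same in both runs, but the per-batch overhead (collating samples into a batch tensor, one iteration of the Python loop, and, when `num_workers > 0`, sending the batch from a worker process to the main process) is paid 60,000 times with `batch_size=1` instead of 938 times with `batch_size=64`.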
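To check that the gap comes from per-batch overhead rather than from disk reads or the `Resize`/`ToTensor` transforms, one can repeat the measurement on data that already lives in memory as tensors. This is a minimal sketch under stated assumptions: `time_one_epoch` is a hypothetical helper, the 6,000 random 1x32x32 tensors are synthetic stand-ins for Fashion-MNIST images, and `num_workers=0` keeps all work in the main process.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_one_epoch(dataset, batch_size, num_workers=0):
    """Iterate once over `dataset`; return (elapsed seconds, number of batches)."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers)
    start = time.perf_counter()
    n_batches = 0
    for _ in loader:  # each step fetches batch_size samples and collates them
        n_batches += 1
    return time.perf_counter() - start, n_batches


# Synthetic stand-in (assumption): tensors already in memory, so there is
# no disk I/O, PIL decoding, or transform cost in this measurement.
dataset = TensorDataset(torch.rand(6000, 1, 32, 32),
                        torch.randint(0, 10, (6000,)))
results = {bs: time_one_epoch(dataset, bs) for bs in (1, 64, 256)}
for bs, (sec, n) in results.items():
    print(f"batch_size={bs:>3}: {n:>4} batches, {sec:.3f} sec")
```

The batch counts (6000, 94, and 24 for this dataset size) show how many times the per-batch cost is paid; if that cost dominates, the `batch_size=1` run should stay clearly slower even without any image decoding. Exact timings depend on the machine.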

## 2. The data iterator performance is important. Do you think the current implementation is fast enough? Explore various options to improve it. Use a system profiler to find out where the bottlenecks are.
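One way to start on this exercise is Python's built-in `cProfile`, which records where the main process spends its time during one pass over a loader. The sketch below assumes a few things: `profile_loader` is a hypothetical helper, and the synthetic `TensorDataset` only stands in for Fashion-MNIST so the snippet runs without a download. Because `cProfile` only sees the current process, the loader should use `num_workers=0`; otherwise image decoding and transforms run in worker processes and do not appear in the report.

```python
import cProfile
import io
import pstats

import torch
from torch.utils.data import DataLoader, TensorDataset


def profile_loader(loader, top=10):
    """Profile one full pass over `loader`; return (batch count, report text)."""
    profiler = cProfile.Profile()
    profiler.enable()
    n_batches = 0
    for _ in loader:
        n_batches += 1
    profiler.disable()
    buf = io.StringIO()
    # Sort by cumulative time so expensive call chains appear first
    pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(top)
    return n_batches, buf.getvalue()


# Synthetic stand-in for Fashion-MNIST (assumption: random data, no download)
X = torch.rand(2000, 1, 32, 32)
y = torch.randint(0, 10, (2000,))
loader = DataLoader(TensorDataset(X, y), batch_size=1, num_workers=0)
n, report = profile_loader(loader)
print(report)
```

To profile the real pipeline, one could pass `torch.utils.data.DataLoader(data.train, batch_size=64, num_workers=0)` built from the `FashionMNIST` class above. If `Resize` and `ToTensor` dominate the report, candidate improvements include raising `num_workers`, `pin_memory=True` when training on a GPU, `persistent_workers=True` to avoid re-spawning workers every epoch, or applying the transforms once and caching the results as tensors; whether each one helps has to be measured on the target machine.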

## 3. Check out the framework's online API documentation. Which other datasets are available?