# Project 1: Preparing an image dataset for model training

**Goal:**

- Using PyTorch, implement a [map-style Dataset class](https://pytorch.org/docs/stable/data.html#map-style-datasets).
- Implement a [data loader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) that feeds data directly into model training.
- Quality-check the data served through your Dataset class.

**Acceptance criteria:**

- The Dataset class should implement the `__getitem__` and `__len__` methods. `__getitem__` should return one training example as an `(image, label)` tuple.
- The data loader should be iterable and return batches of examples.
- In a notebook, present example images and labels, and verify that their shapes are correct.
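
The map-style protocol in the criteria above comes down to two methods. Here is a minimal sketch, using a hypothetical `ToyDataset` with made-up in-memory data (nested lists stand in for images; none of these names come from the project). In PyTorch you would typically subclass `torch.utils.data.Dataset`, which this sketch leaves out to stay dependency-free.

```python
class ToyDataset:
    """Hypothetical map-style dataset that keeps its examples in memory."""

    def __init__(self, images, labels):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels

    def __getitem__(self, index):
        # Return one training example as an (image, label) tuple.
        return self.images[index], self.labels[index]

    def __len__(self):
        # Number of examples in the dataset.
        return len(self.labels)


# Made-up data: three 2x2 single-channel "images" with integer labels.
toy_images = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]
toy_labels = [0, 1, 0]
toy_ds = ToyDataset(toy_images, toy_labels)

first_image, first_label = toy_ds[0]
print(len(toy_ds), first_label)
```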

## Step 1: Create a Dataset class

In a practical project, the dataset is usually custom rather than a standard public dataset, so you need to write your own dataset class. To start, let us use the raw CIFAR-10 data as an example. First, download the data:


In [5]:
!wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
!tar -xzf cifar-10-python.tar.gz

Inspect the downloaded raw data and folder structure:

In [12]:
!ls -lht cifar-10-batches-py/

total 363752
-rw-r--r--  1 zzsi  staff    88B Jun  4  2009 readme.html
-rw-r--r--  1 zzsi  staff   158B Mar 30  2009 batches.meta
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 data_batch_4
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 data_batch_1
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 data_batch_5
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 data_batch_2
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 data_batch_3
-rw-r--r--  1 zzsi  staff    30M Mar 30  2009 test_batch


The archive contains the files `data_batch_1`, `data_batch_2`, ..., `data_batch_5`, as well as `test_batch`. Each of these files is a Python "pickled" object.

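When loaded with the `unpickle` function provided below, each batch file yields a dictionary with the following elements. Because `unpickle` passes `encoding='bytes'`, the dictionary keys are byte strings, for example `b'data'` and `b'labels'`.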

*data* -- a 10000x3072 numpy array of uint8s. Each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 the green, and the final 1024 the blue. The image is stored in row-major order, so that the first 32 entries of the array are the red channel values of the first row of the image.

*labels* -- a list of 10000 numbers in the range 0-9. 
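
The layout described above means that a single row can be reshaped directly into a channels-first image. Here is a minimal sketch on a synthetic row (not real CIFAR-10 data), assuming NumPy is available; `fake_row` and the other variable names are illustrative only.

```python
import numpy as np

# Synthetic stand-in for one row of the data array: 3072 uint8 values.
fake_row = (np.arange(3072) % 256).astype(np.uint8)

# The row holds 1024 red, then 1024 green, then 1024 blue values, each plane
# in row-major order, so a plain reshape yields (channels, height, width).
image_chw = fake_row.reshape(3, 32, 32)

# Unpacking along the first axis gives the three 32x32 colour planes.
red, green, blue = image_chw
print(image_chw.shape, image_chw.dtype)
```

The first row of the `red` plane equals the first 32 entries of the row, which matches the description of the storage order.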

**Your task**: Fill in the missing code in the next cell, so that it can pass the test that follows.

In [20]:
# TODO: Please add your code here.

import pickle


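def unpickle(file):
    """Load a pickled CIFAR-10 batch file and return its contents as a dictionary."""
    with open(file, 'rb') as fo:
        # Avoid naming the result `dict`, which would shadow the built-in type.
        batch = pickle.load(fo, encoding='bytes')
    return batch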


class CustomDataset:
    def __getitem__(self, index):
        ...
    
    def __len__(self):
        ...

In [21]:
# This is the test code for the class CustomDataset.

def test_custom_dataset():
    ds = CustomDataset()
    print(f"Dataset has {len(ds)} examples")
    first_example = ds[0]
    print(f"The first example is a {type(first_example)} with {len(first_example)} elements")
    image, label = first_example
    print(f"The first example has an image of shape {image.shape} and label {label}")

test_custom_dataset()

If it works, congratulations! You have completed the first step of the project.

## Step 2: Create a data loader

A data loader is a class that can be used to load batches of data. It is a simple wrapper around a dataset that provides a way to iterate over the dataset, returning a batch of data at each iteration.
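
To make the idea of batching concrete, here is a plain-Python sketch of what "returning a batch of data at each iteration" means. The function `iterate_in_batches` and the toy data are hypothetical and only illustrate the concept. A real data loader does more, such as collating examples into tensors, optional shuffling, and loading in worker processes.

```python
def iterate_in_batches(dataset, batch_size):
    """Yield lists of consecutive examples from a map-style dataset."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


# Any object with __getitem__ and __len__ works; a list of
# (image, label) pairs stands in for a dataset here.
toy_examples = [(f"image_{i}", i % 2) for i in range(10)]

batch_sizes = [len(batch) for batch in iterate_in_batches(toy_examples, 4)]
print(batch_sizes)  # [4, 4, 2]
```

Note that the last batch is smaller when the dataset size is not a multiple of the batch size.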

In PyTorch, we can use the [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) class to load the data.
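
As an illustration of the `DataLoader` API, the sketch below wraps a toy `TensorDataset` built from made-up tensors rather than the project's `CustomDataset`. It assumes PyTorch is installed; the tensor contents and variable names are placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up data: 20 blank channels-first "images" with labels 0-9 repeated.
toy_images = torch.zeros(20, 3, 32, 32, dtype=torch.uint8)
toy_labels = torch.arange(20) % 10
toy_dataset = TensorDataset(toy_images, toy_labels)

# batch_size sets how many examples are stacked into each batch.
toy_loader = DataLoader(toy_dataset, batch_size=8, shuffle=False)

batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in toy_loader]
print(batch_shapes)
```

With 20 examples and a batch size of 8, the final batch holds only 4 examples. Passing `drop_last=True` would discard that incomplete batch.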

**Your task**: Fill in the missing code in the next cell, so that it can pass the test that follows.

In [None]:
# TODO: fill in the code to create a data loader with a batch size of 8.

batch_size = 8
data_loader = ...

In [None]:
def test_data_loader():
    for batch in data_loader:
        images, labels = batch
        assert images.shape[0] == batch_size, f"Expecting a batch of {batch_size} images"
        assert len(labels) == batch_size, f"Expecting a batch of {batch_size} labels"
        break


test_data_loader()

## Step 3: Visualize the data

In the next cells, display several images from the dataset, each with its label, and visually check that every image matches its label. A helper function is provided.

In [None]:
from matplotlib import pyplot as plt


def display_image_with_label(im, label):
    plt.imshow(im)
    plt.title(f"Label: {label}")
    plt.show()

# TODO: add code to display at least 3 images from CustomDataset.
# You can use one cell for each image.
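
One detail to watch when displaying images: `plt.imshow` expects colour images shaped (height, width, 3), so a channels-first image needs its axes reordered first. Here is a minimal sketch with a synthetic image, assuming NumPy arrays; the helper name `to_channels_last` is hypothetical.

```python
import numpy as np


def to_channels_last(image_chw):
    """Convert a (channels, height, width) array to (height, width, channels)."""
    return np.transpose(image_chw, (1, 2, 0))


# Synthetic channels-first image: pure red, 32x32, uint8.
fake_chw = np.zeros((3, 32, 32), dtype=np.uint8)
fake_chw[0] = 255

fake_hwc = to_channels_last(fake_chw)
print(fake_hwc.shape)  # (32, 32, 3)
```

For a PyTorch tensor, `tensor.permute(1, 2, 0)` performs the same reordering.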

## Step 4 (optional): Writing unit tests for your dataset class

To make sure your dataset is feeding the right data to the model, visualization is helpful, but unit tests can make certain checks more efficient.

You may have noticed that the previous steps already used tests such as `test_custom_dataset` to check your implementation.

The following types of unit tests can be written against a dataset class:

- Tests that check the functionality of the class (e.g. calling `len(dataset)` does not raise an error).
- Tests that check the type and shape of the data returned by the class.
- Tests that check the content of the data returned by the class.
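
The three kinds of tests listed above might look like the following sketch. It runs against a hypothetical stand-in dataset (`make_toy_dataset`, a list of random arrays with labels), not the project's `CustomDataset`, and assumes NumPy is available.

```python
import numpy as np


def make_toy_dataset(n=5):
    """Hypothetical stand-in for a dataset: a list of (image, label) pairs."""
    rng = np.random.default_rng(0)
    return [
        (rng.integers(0, 256, size=(3, 32, 32), dtype=np.uint8), i % 10)
        for i in range(n)
    ]


def test_len_is_positive():
    # Functionality: len() works and the dataset is not empty.
    assert len(make_toy_dataset()) > 0


def test_example_type_and_shape():
    # Type and shape: an example is an (image, label) pair of the expected kind.
    image, label = make_toy_dataset()[0]
    assert isinstance(image, np.ndarray)
    assert image.ndim == 3
    assert isinstance(label, int)


def test_labels_are_valid_classes():
    # Content: every label is one of the ten class indices 0-9.
    assert all(0 <= label <= 9 for _, label in make_toy_dataset())


all_tests = (
    test_len_is_positive,
    test_example_type_and_shape,
    test_labels_are_valid_classes,
)
for test in all_tests:
    test()
```

In a notebook, calling each test function directly as shown is enough. In a code repository, a test runner such as pytest can collect functions named `test_*` automatically.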

Below is the skeleton of a unit test that checks whether the image shape is "channels-first" (3, height, width) rather than "channels-last" (height, width, 3). PyTorch's image layers expect channels-first input by default, while TensorFlow defaults to channels-last.

**Your task**: Fill in the missing code in the unit test. Make the necessary changes to your dataset class so that it passes the test. Can you think of other tests that answer questions about the dataset? For example, can you write a test about the maximum pixel intensity of this dataset?



In [None]:
def test_custom_dataset_returns_images_with_channels_first():
    """Check the shape of the image, especially the channels axis.
    """
    # TODO: fill in.
    pass


test_custom_dataset_returns_images_with_channels_first()

Congratulations! You have completed the first project.