# 1. Inspecting your data

If you want to solve a problem with machine learning, you'll need three things:

1. A good dataset;
2. A good model;
3. A good optimization algorithm.

If any of these three is not good enough, your trained model won't be good either. Many novice machine learning practitioners mainly focus on the second requirement, i.e., the model. While a good model certainly is a vital aspect of a successful machine learning pipeline, if your data sucks, it doesn't really matter how advanced your model is. **The model can only be as good as the data.**

Therefore, your first step in a machine learning project should always be to **inspect your data**. And that's exactly what we'll do now!

## 1.1 Inspecting the filetree

Before you can start inspecting the data itself, you need to know how and where your data is stored.

For this example, download the [Gen 1 Pokemon Dataset from Kaggle](https://www.kaggle.com/datasets/echometerhhwl/pokemon-gen-1-38914), extract it and move it to the parent directory of this notebook. Rename the folder from `archive` to `PokemonGen1`.

We'll now inspect what's inside the `PokemonGen1` directory. For this, you can use the [`pathlib`](https://docs.python.org/3/library/pathlib.html) module from the Python standard library. A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) object is an abstract representation of a path (e.g., to a file or directory) in your operating system.

Let's create a `Path` object for the `PokemonGen1` directory.

In [None]:
# ... WRITE YOUR CODE HERE ... #

To inspect the contents of the directory, we can employ the [`glob()` method of the `Path` class](https://docs.python.org/3/library/pathlib.html#pathlib.Path.glob). This method expects a *pattern* and yields all paths that match the given pattern. In this pattern, an asterisk (`*`) is interpreted as a wildcard. So, to list all entries (files and subdirectories) in the `PokemonGen1` directory, we can pass in the pattern `*`:

In [None]:
# ... WRITE YOUR CODE HERE ... #
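Since the Kaggle download may not be at hand, the sketch below builds a tiny, made-up stand-in for `PokemonGen1` in a temporary directory (the Pokémon names and file names are assumptions for illustration only) and globs it the same way. It checks that the result of `glob()` is a lazy iterator rather than a list.

```python
import tempfile
from collections.abc import Iterator
from pathlib import Path

# Hypothetical stand-in for the Kaggle download: a few empty "image" files.
TOY_FILES = {
    "Pikachu": ["0001.jpg", "0002.png", "0003.jpg"],
    "Bulbasaur": ["0001.jpg", "0002.jpg"],
    "Charmander": ["0001.png"],
}

tmp_dir = tempfile.TemporaryDirectory()  # keep a reference so the folder survives
dataset_dir = Path(tmp_dir.name) / "PokemonGen1"
for name, file_names in TOY_FILES.items():
    pokemon_dir = dataset_dir / "data" / name
    pokemon_dir.mkdir(parents=True)
    for file_name in file_names:
        (pokemon_dir / file_name).touch()

# glob() is lazy: it returns an iterator, not a list.
entries = dataset_dir.glob("*")
print(entries)
print(isinstance(entries, Iterator))  # True
print(isinstance(entries, list))      # False
```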

This will give you something like `<generator object Path.glob at 0x............>` (the exact text may differ between Python versions). Indeed, the `glob()` method returns a Python [generator](https://wiki.python.org/moin/Generators) (or, in some Python versions, a similar lazy iterator), not a list. A generator is an object you can iterate over (just like a list) but that (unlike a list) cannot be indexed. A generator can only give you the *next* iteration item.

> The advantage of generators is that they don't need to store all iteration items up-front. If the iteration items are large, or if there is a large number of items, a generator can save large amounts of memory.

To get the next item in a generator, you can use [the built-in Python function `next()`](https://docs.python.org/3/library/functions.html#next):

In [None]:
# ... WRITE YOUR CODE HERE ... #

Once you have iterated over all items, the generator is *exhausted* (empty), and it will raise a [`StopIteration`](https://docs.python.org/3/library/exceptions.html#StopIteration) exception if you call `next()` on it again.

In [None]:
# ... WRITE YOUR CODE HERE ... #
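A minimal sketch of this behavior, assuming a temporary directory that (like `PokemonGen1`) contains a single `data` subdirectory. It also shows that `next()` accepts an optional default, which is returned instead of raising `StopIteration`.

```python
import tempfile
from pathlib import Path

tmp_dir = tempfile.TemporaryDirectory()
dataset_dir = Path(tmp_dir.name) / "PokemonGen1"
(dataset_dir / "data").mkdir(parents=True)  # a single subdirectory, as in the dataset

entries = dataset_dir.glob("*")
first = next(entries)
print(first.name)  # data

try:
    next(entries)  # nothing left to yield
except StopIteration:
    print("The generator is exhausted.")

# With a default value, next() returns it instead of raising StopIteration.
print(next(dataset_dir.glob("*.csv"), None))  # None
```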

The `StopIteration` is already raised by our second `next()` call, so we know that the `PokemonGen1` directory contains only a single entry, which is a subdirectory.

Of course, calling `next()` over and over on a generator that contains a lot of items is cumbersome. Luckily, we can also *iterate* over a generator with a simple `for` loop.

> ⚠️ **Only use new generators in `for` loops**
>
> You can think of iterating over a generator with a `for` loop as simply calling `next()` over and over, passing in the generator as an argument and stopping once a `StopIteration` exception is raised. If you have already called `next()` on the generator object, the `for` loop will start wherever the generator left off. **A for loop does not *rewind* a generator before it starts iterating!** (Btw, there is no such thing as *rewinding* a generator. Once an iteration item is returned, the generator forgets about it.) Therefore, to avoid subtle bugs, it is important to only use freshly created generators in `for` loops.

In [None]:
# ... WRITE YOUR CODE HERE ... #
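To see why only fresh generators belong in `for` loops, the sketch below uses a small hand-written generator instead of `glob()` so the item order is deterministic (the name `count_up_to` is just for illustration).

```python
def count_up_to(n):
    """A tiny generator that yields 0, 1, ..., n - 1."""
    for i in range(n):
        yield i

gen = count_up_to(4)
print(next(gen))  # 0 (consumed before the loop starts)

seen = []
for item in gen:  # continues where the generator left off: 1, 2, 3
    seen.append(item)
print(seen)  # [1, 2, 3]

# The generator is now exhausted, so a second pass yields nothing.
print(list(gen))  # []
```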

We can also collect all iteration items of a generator in a list by passing the generator to the [built-in function](https://docs.python.org/3/library/functions.html) [`list()`](https://docs.python.org/3/library/functions.html#func-list)...

In [None]:
# ... WRITE YOUR CODE HERE ... #

...or by employing a [list comprehension](https://www.w3schools.com/python/python_lists_comprehension.asp):

In [None]:
# ... WRITE YOUR CODE HERE ... #
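Both options side by side, sketched on a made-up temporary `PokemonGen1/data` folder. The list comprehension has the advantage that it can transform each item on the fly (here, keeping only the name). Note that the order in which `glob()` yields paths is not guaranteed, so wrap it in `sorted()` if you need a reproducible order.

```python
import tempfile
from pathlib import Path

tmp_dir = tempfile.TemporaryDirectory()
dataset_dir = Path(tmp_dir.name) / "PokemonGen1"
(dataset_dir / "data").mkdir(parents=True)

# Option 1: let list() drain the generator.
entries = list(dataset_dir.glob("*"))

# Option 2: a list comprehension, which can also transform each item.
entry_names = [path.name for path in dataset_dir.glob("*")]

print(entries)      # e.g. [PosixPath('/tmp/.../PokemonGen1/data')] on Linux/macOS
print(entry_names)  # ['data']
```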

As we saw from the previous cells, the `PokemonGen1` directory only contains a single directory, i.e., `PokemonGen1/data`. Let's inspect this directory as well:

In [None]:
# ... WRITE YOUR CODE HERE ... #

We can see that `PokemonGen1/data` contains a large number of subdirectories that correspond to different Pokémon names. How many Pokémon does our dataset contain?

In [None]:
# ... WRITE YOUR CODE HERE ... #

We can also count the number of files in one of the subdirectories:

In [None]:
# ... WRITE YOUR CODE HERE ... #
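Both counts can be sketched on a toy copy of the dataset layout (the names and file counts below are made up). Counting subdirectories of `data` gives the number of Pokémon; `sum(1 for _ in ...)` counts the files in one subdirectory without building a list first.

```python
import tempfile
from pathlib import Path

# Hypothetical toy copy of PokemonGen1/data (empty files, made-up counts).
TOY_FILES = {"Pikachu": 3, "Bulbasaur": 2, "Charmander": 1}
tmp_dir = tempfile.TemporaryDirectory()
data_dir = Path(tmp_dir.name) / "PokemonGen1" / "data"
for name, n_files in TOY_FILES.items():
    (data_dir / name).mkdir(parents=True)
    for i in range(n_files):
        (data_dir / name / f"{i:04d}.jpg").touch()

# Number of Pokémon = number of subdirectories of data/.
pokemon_dirs = [p for p in data_dir.glob("*") if p.is_dir()]
print(len(pokemon_dirs))  # 3

# Number of files in one subdirectory, counted lazily.
n_pikachu = sum(1 for _ in (data_dir / "Pikachu").glob("*"))
print(n_pikachu)  # 3
```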

As these file paths are also represented by `pathlib.Path` objects, we can easily extract useful properties of the image paths, such as their name, suffix (extension), stem and parent path:

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

The `stem` is the path's name without its suffix:

In [None]:
# ... WRITE YOUR CODE HERE ... #

The `parent` attribute gives the path of the parent directory:

In [None]:
# ... WRITE YOUR CODE HERE ... #

Note that we can get the Pokémon's name from the name of the parent directory.

In [None]:
# ... WRITE YOUR CODE HERE ... #
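These attributes can be tried out without the dataset on disk by using `PurePosixPath`, which only manipulates the path string and never touches the filesystem. The file name below is made up; the `Path` objects returned by `glob()` expose the same attributes.

```python
from pathlib import PurePosixPath

# A made-up path with the dataset's layout: data/<Pokémon name>/<file>
image_path = PurePosixPath("PokemonGen1/data/Pikachu/00000042.jpg")

print(image_path.name)         # 00000042.jpg
print(image_path.suffix)       # .jpg
print(image_path.stem)         # 00000042
print(image_path.parent)       # PokemonGen1/data/Pikachu
print(image_path.parent.name)  # Pikachu -> the label
```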

We can also inspect the file extensions of all these files:

In [None]:
# ... WRITE YOUR CODE HERE ... #

Or, with a set comprehension:

In [None]:
# ... WRITE YOUR CODE HERE ... #
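The loop and the set comprehension give the same result; here is a sketch on a few made-up paths standing in for the globbed files. Since `suffix` preserves case, this check would also reveal inconsistent extensions such as `.JPG` next to `.jpg`.

```python
from pathlib import PurePosixPath

# Hypothetical image paths standing in for the files in data/*/*
image_paths = [
    PurePosixPath("PokemonGen1/data/Pikachu/0001.jpg"),
    PurePosixPath("PokemonGen1/data/Pikachu/0002.png"),
    PurePosixPath("PokemonGen1/data/Bulbasaur/0001.jpg"),
]

# Loop version
suffixes = set()
for path in image_paths:
    suffixes.add(path.suffix)

# Set comprehension version: same result in one line
suffixes_comp = {path.suffix for path in image_paths}

print(sorted(suffixes_comp))  # ['.jpg', '.png']
```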

From the above cells, we know that the subdirectories contain files with `.jpg` and `.png` extensions, i.e., images. In total, globbing the files in the subdirectories of `PokemonGen1/data` gives us **35 626 images**.

## 1.2 Representing the dataset as a `DataFrame`

We're starting to get a feel for how the files in the dataset are structured. The folder `PokemonGen1` contains a subdirectory `data`, and this directory contains multiple subdirectories, each of which corresponds to a Pokémon. The Pokémon directories contain files that have either a `.jpg` or a `.png` extension (i.e., images).

With `pathlib`'s `glob()`, we can explore our dataset in a rudimentary way. To understand our data more deeply, we can represent our dataset as a [`DataFrame` object from the `pandas` library](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html). A DataFrame is a tabular data structure consisting of rows and columns, much like an Excel sheet.

We want to create a DataFrame that contains two columns: `image` and `label`. The `image` column contains the path to an image of a Pokémon, and the `label` column contains the corresponding name of that Pokémon.

You can construct a `DataFrame` by passing in a list of dictionaries. Each dictionary in the list corresponds to a row in the DataFrame: the dictionary keys are the column names, and the dictionary values are the contents of the cells in that row.

For example:

In [None]:
# ... WRITE YOUR CODE HERE ... #
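A minimal example with two made-up rows: every dictionary becomes one row, and its keys become the column names.

```python
import pandas as pd

# Each dictionary is one row; keys become column names.
rows = [
    {"image": "PokemonGen1/data/Pikachu/0001.jpg", "label": "Pikachu"},
    {"image": "PokemonGen1/data/Bulbasaur/0001.jpg", "label": "Bulbasaur"},
]
df_example = pd.DataFrame(rows)
print(df_example)
```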

We can iterate over all available data and store the image paths along with the Pokémon names.

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

Of course, you can also create the same DataFrame with a list comprehension:

In [None]:
# ... WRITE YOUR CODE HERE ... #
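A sketch of the list-comprehension version on a toy copy of the dataset layout (names and counts are made up). The pattern `*/*` matches every file one level below `data/`, i.e. `data/<name>/<file>`, and the label is taken from the parent directory's name. The `sorted()` call is an optional addition that makes the row order reproducible.

```python
import tempfile
from pathlib import Path

import pandas as pd

TOY_FILES = {"Pikachu": 3, "Bulbasaur": 2, "Charmander": 1}  # hypothetical counts
tmp_dir = tempfile.TemporaryDirectory()
data_dir = Path(tmp_dir.name) / "PokemonGen1" / "data"
for name, n_files in TOY_FILES.items():
    (data_dir / name).mkdir(parents=True)
    for i in range(n_files):
        (data_dir / name / f"{i:04d}.jpg").touch()

# One row per image file: the path and the name of its parent directory.
df = pd.DataFrame(
    [{"image": path, "label": path.parent.name} for path in sorted(data_dir.glob("*/*"))]
)
print(df)
```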

## 1.3 Visualizing data imbalance

A common problem in machine learning is *data imbalance*: some classes have many more examples than others. Such an imbalance can cause a model to perform better on the majority classes than on the minority classes.

To visualize the number of images per class, we can use the plotting library [Matplotlib](https://matplotlib.org/).

In Matplotlib, plots are drawn on [`Figure`s](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure). A `Figure` contains one or more [`Axes`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes), each of which is an area that contains an actual plot.

To create a `Figure` with a single `Axes`, you can call [`matplotlib.pyplot.subplots()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html#matplotlib.pyplot.subplots) without any arguments.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Now, we'll use the [`DataFrame.groupby()` method](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.groupby.html) to group the rows by `label` and apply [the built-in `len()` function](https://docs.python.org/3/library/functions.html#len) on each group to get the number of images per label.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Let's put this in a variable called `count_per_label`:

In [None]:
# ... WRITE YOUR CODE HERE ... #
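A sketch on a made-up DataFrame with the same columns. Selecting the `image` column before `apply(len)` avoids a deprecation warning that recent pandas versions may emit when `apply()` also operates on the grouping column; the built-in `size()` method gives the same counts directly.

```python
import pandas as pd

# Hypothetical toy DataFrame with the same columns as ours.
df = pd.DataFrame({
    "image": [f"img_{i}.jpg" for i in range(6)],
    "label": ["Pikachu", "Pikachu", "Pikachu", "Bulbasaur", "Bulbasaur", "Charmander"],
})

# Apply len() to each group of rows (groups are sorted by label).
count_per_label = df.groupby("label")["image"].apply(len)
print(count_per_label)

# Equivalent counts with size().
print(df.groupby("label").size())
```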

Now we want to draw a bar chart in our `Axes`, with the labels as x-values and the counts as bar heights.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Whoops, that looks pretty cluttered. Let's try with a larger figure.

In [None]:
# ... WRITE YOUR CODE HERE ... #

The bars are much clearer now, but the tick labels are still overlapping... We can *rotate* the tick labels with [`Axes.tick_params()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.tick_params.html):

In [None]:
# ... WRITE YOUR CODE HERE ... #

That's much more readable! To get a real sense of the data imbalance, it's best to **sort the labels in descending order of image count**. You can do this by calling [`sort_values(ascending=False)`](https://pandas.pydata.org/docs/reference/api/pandas.Series.sort_values.html) on `count_per_label` (by default, `sort_values()` sorts in ascending order).

In [None]:
# ... WRITE YOUR CODE HERE ... #

Let's see what we get now...

In [None]:
# ... WRITE YOUR CODE HERE ... #

Let's put everything together.

In [None]:
# ... WRITE YOUR CODE HERE ... #
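Put together, the steps could look like the sketch below, again on a made-up DataFrame. The figure size is an assumption; with the real dataset you may want a wider figure.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical toy DataFrame with the same columns as ours.
df = pd.DataFrame({
    "image": [f"img_{i}.jpg" for i in range(6)],
    "label": ["Pikachu", "Pikachu", "Pikachu", "Bulbasaur", "Bulbasaur", "Charmander"],
})

# Count images per label, largest class first.
count_per_label = df.groupby("label").size().sort_values(ascending=False)

fig, ax = plt.subplots(figsize=(20, 5))
ax.bar(count_per_label.index, count_per_label.values)
ax.tick_params(axis="x", labelrotation=90)  # vertical tick labels avoid overlap
ax.set_ylabel("Number of images")
fig.tight_layout()
```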

Now, you should be careful not to draw too many conclusions from this graph alone. It is not necessarily true that our model won't work on minority classes. However, the data imbalance we observe is something to take into account when evaluating a trained model, and, if necessary, it might push us toward data sampling strategies that counter the imbalance.

To get an even better intuition of our data, let's take a look at some images!

## 1.4 Inspecting images and labels
### 1.4.1 Visualizing a single image

Let's visualize the first image in our dataset.

To obtain the elements of a column in a DataFrame, you can pass in the column name between square brackets.

In [None]:
# ... WRITE YOUR CODE HERE ... #

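The returned object is a pandas [`Series`](https://pandas.pydata.org/docs/reference/api/pandas.Series.html) object, which is somewhat comparable to a list. To get the value of the `image` column at row `0`, you can index with `[0]`. Note that square-bracket indexing on a `Series` looks up the *index label* `0`; this coincides with the first row here because our DataFrame has the default integer index. To select by position regardless of the index, use `.iloc[0]`.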

In [None]:
# ... WRITE YOUR CODE HERE ... #
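The difference between label-based and position-based indexing becomes visible once the rows are shuffled, for example by `DataFrame.sample()` later on. A sketch with made-up values:

```python
import pandas as pd

df = pd.DataFrame({
    "image": ["a.jpg", "b.jpg", "c.jpg"],
    "label": ["Pikachu", "Bulbasaur", "Charmander"],
})

images = df["image"]          # a pandas Series
print(type(images).__name__)  # Series
print(images[0])              # a.jpg (index label 0, which is also the first row)

# After shuffling, the index labels are no longer 0, 1, 2, ... in order.
shuffled = df.sample(frac=1, random_state=0)
print(shuffled["image"].iloc[0])  # the first row by position
print(shuffled["image"].loc[0])   # the row whose index label is 0 -> a.jpg
```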

Now, we can use [`PIL.Image.open()`](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open) to open the image.

In [None]:
# ... WRITE YOUR CODE HERE ... #
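A self-contained sketch of what `Image.open()` returns, using a small solid-color image written to an in-memory buffer instead of a dataset file (the size and color are arbitrary). PIL reports `size` as (width, height). If some PNG files in the dataset turn out to be palette (`P`) or `RGBA` images, `image.convert("RGB")` gives every image three channels.

```python
import io

from PIL import Image

# Write a made-up RGB image to memory so Image.open() has something to read.
buffer = io.BytesIO()
Image.new("RGB", (64, 48), color=(255, 204, 0)).save(buffer, format="PNG")
buffer.seek(0)

image = Image.open(buffer)   # works the same with a path such as df["image"][0]
print(type(image).__name__)  # PngImageFile
print(image.size)            # (64, 48) -> (width, height)
print(image.mode)            # RGB
```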

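As you can see, `Image.open()` gives us a `PIL.Image` object, which a notebook displays as the image itself. Its `size` attribute gives the width and height of the image (in that order), and its `mode` attribute the pixel format (e.g., `RGB`). Note that PyTorch will later use the layout $\text{Channels}\times\text{Height}\times\text{Width}$, or simply $\text{C}\times\text{H}\times\text{W}$. To visualize the image, we can use Matplotlib.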

Let's create another `Figure` with a single `Axes`.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Now, we can draw our image on the `Axes` with [`Axes.imshow()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.imshow.html).

In [None]:
# ... WRITE YOUR CODE HERE ... #

To remove the ticks and tick labels, we can use [`Axes.tick_params()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.tick_params.html).

In [None]:
# ... WRITE YOUR CODE HERE ... #

And to remove the black border lines around the image, we make the [*spines* of the `Axes`](https://matplotlib.org/stable/api/spines_api.html) invisible.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Putting it all together, we have:

In [None]:
# ... WRITE YOUR CODE HERE ... #
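A sketch of the combined steps, using a generated solid-color image in place of `Image.open(...)` so it runs without the dataset. `imshow()` accepts PIL images directly. As a shortcut, `ax.axis("off")` hides the ticks, tick labels and spines in a single call.

```python
import matplotlib.pyplot as plt
from PIL import Image

image = Image.new("RGB", (64, 48), color=(255, 204, 0))  # stand-in for Image.open(...)

fig, ax = plt.subplots()
ax.imshow(image)
# Hide tick marks and tick labels on both axes.
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
# Hide the four border lines (spines) around the image.
for spine in ax.spines.values():
    spine.set_visible(False)
```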

### 1.4.2 Visualizing multiple images

Instead of viewing images one by one, we can save some time by visualizing a **grid** of images. To get the first 10 image paths in the DataFrame, we can run the following:

In [None]:
# ... WRITE YOUR CODE HERE ... #

To visualize the image paths in this `Series` object, we'll need **multiple `Axes`** in our `Figure`. You can pass a number of rows (`nrows`) and a number of columns (`ncols`) to [`matplotlib.pyplot.subplots()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html#matplotlib.pyplot.subplots).

For example, to create a `Figure` with 10 `Axes` in a 2 x 5 grid, you can do:

In [None]:
# ... WRITE YOUR CODE HERE ... #

The returned `axes` is a $2\times 5$ [NumPy array](https://numpy.org/doc/stable/reference/generated/numpy.array.html) of `Axes` objects:

In [None]:
# ... WRITE YOUR CODE HERE ... #

We now want to call `imshow()` on each of these `Axes` objects, each time with another image we read in with `Image.open()`. To avoid nasty index computations, we can simply flatten our `axes` array into a flat NumPy array of $10$ `Axes` objects:

In [None]:
# ... WRITE YOUR CODE HERE ... #

Now, we can iterate jointly over these 10 `Axes` objects and the first 10 image paths in our DataFrame. Such a joint iteration is easily done with [Python's built-in function `zip()`](https://docs.python.org/3/library/functions.html#zip).

In the loop itself, we can reuse our code to visualize a single image.

In [None]:
# ... WRITE YOUR CODE HERE ... #

You can call [`Figure.tight_layout()`](https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure.tight_layout) to automatically adjust the padding between and around the `Axes`, which typically removes superfluous whitespace.

In [None]:
# ... WRITE YOUR CODE HERE ... #
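The grid logic in one sketch, with ten generated solid-color images standing in for `Image.open(path)` on the first 10 image paths (the colors and figure size are arbitrary). `zip()` stops at the shorter of its inputs, so a mismatch between the number of `Axes` and images will not raise an error.

```python
import matplotlib.pyplot as plt
from PIL import Image

# Made-up stand-ins for Image.open(path) on the first 10 image paths.
images = [Image.new("RGB", (32, 32), color=(25 * i, 100, 200)) for i in range(10)]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 4))
print(axes.shape)  # (2, 5)

for ax, image in zip(axes.flatten(), images):
    ax.imshow(image)
    ax.axis("off")  # hides ticks, tick labels and spines in one call
fig.tight_layout()
```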

We can also draw random samples from our DataFrame:

In [None]:
# ... WRITE YOUR CODE HERE ... #

Let's take the `image` column from such a sampled DataFrame and visualize those images:

In [None]:
# ... WRITE YOUR CODE HERE ... #

### 1.4.3 Adding labels

To make the grid more informative, we can put the corresponding label on top of each image. To access the label of an image in the iteration, we'll need to iterate over entire rows of the DataFrame, as each row contains information on both the image and the corresponding label. To iterate over the rows of a DataFrame, you can use [`DataFrame.iterrows()`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.iterrows.html):

In [None]:
# ... WRITE YOUR CODE HERE ... #

We can now access both the image path and the label of each row in the sampled DataFrame:

In [None]:
# ... WRITE YOUR CODE HERE ... #

Now, to put the Pokémon's name on top of each image, we can use [`Axes.set_title()`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_title.html).

In [None]:
# ... WRITE YOUR CODE HERE ... #
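A sketch that combines sampling, `iterrows()` and `set_title()`. The DataFrame is made up and placeholder images replace `Image.open(row["image"])`; `random_state` is optional and only makes the sample reproducible. `iterrows()` yields `(index, row)` pairs, hence the `(_, row)` unpacking.

```python
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

# Hypothetical DataFrame; in the notebook, "image" holds real file paths.
df = pd.DataFrame({
    "image": [f"img_{i}.jpg" for i in range(20)],
    "label": ["Pikachu", "Bulbasaur", "Charmander", "Squirtle"] * 5,
})
sampled = df.sample(n=10, random_state=42)

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 4))
for ax, (_, row) in zip(axes.flatten(), sampled.iterrows()):
    image = Image.new("RGB", (32, 32))  # placeholder for Image.open(row["image"])
    ax.imshow(image)
    ax.set_title(row["label"])
    ax.axis("off")
fig.tight_layout()
```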

### 1.4.4 Visualizing images with a particular label

Another interesting inspection is to visualize some random images of a **particular label**. To get all DataFrame rows that belong to a certain label, you can use the following code:

In [None]:
# ... WRITE YOUR CODE HERE ... #
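The usual approach is a boolean mask, sketched here on a made-up DataFrame. Because some classes are small, `sample()` would raise a `ValueError` if asked for more rows than exist (unless `replace=True`), so the sketch caps `n` at the number of available rows.

```python
import pandas as pd

df = pd.DataFrame({
    "image": [f"img_{i}.jpg" for i in range(6)],
    "label": ["Pikachu", "Bulbasaur", "Pikachu", "Charmander", "Pikachu", "Bulbasaur"],
})

# Boolean mask: True for the rows whose label equals "Pikachu".
pikachu_rows = df[df["label"] == "Pikachu"]
print(pikachu_rows)

# Cap the sample size so minority classes don't raise a ValueError.
sample = pikachu_rows.sample(n=min(10, len(pikachu_rows)), random_state=0)
```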

Now, we can just sample from these rows and run the same code as before.

In [None]:
# ... WRITE YOUR CODE HERE ... #

## 1.5 Inspecting data transforms

### 1.5.1 Transform images to the same size

When visualizing the images in the dataset, you might have noticed that the images come in all kinds of sizes. When training a neural network, we'll want to create *batches* of images, and such a batch can only be created from images of the **same size**. Thus, before we can start training, we'll need a way to give each image the same size.

The [`torchvision` library](https://pytorch.org/vision/stable/index.html) contains some handy tools to help us with that. More specifically, we can make use of the transforms in [`torchvision.transforms.v2`](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html). For example, with [`v2.Resize`](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize), we can resize an image to a fixed size.

In [None]:
# ... WRITE YOUR CODE HERE ... #

To apply this resize transform to an image, we simply call the transform with the image as an argument:

In [None]:
# ... WRITE YOUR CODE HERE ... #

As you can see, `Resize()` has resized our image such that its shorter side became equal to `224`, while the longer side was scaled proportionally to preserve the aspect ratio. You can also pass in a `(height, width)` tuple to `Resize()`, but this might change the aspect ratio of your image:

In [None]:
# ... WRITE YOUR CODE HERE ... #

What if we want all images to be of size $224\times 224$, but don't want our aspect ratio to drastically change? Well, we can resize the shortest side of the image to a fixed size and then **crop** out the center square.

In [None]:
# ... WRITE YOUR CODE HERE ... #
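The arithmetic behind "resize the shorter side, then crop the center" can be worked out by hand. The helpers below are hypothetical pure-Python sketches, not torchvision code, and the 640×480 input is a made-up example. The longer side is truncated to an integer here; torchvision's exact rounding may differ by a pixel.

```python
def resize_shorter_side(width: int, height: int, size: int) -> tuple[int, int]:
    """Output (width, height) when the shorter side is scaled to `size` (sketch)."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width: int, height: int, crop: int) -> tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop square (sketch)."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(640, 480, 224)
print((w, h))                      # (298, 224)
print(center_crop_box(w, h, 224))  # (37, 0, 261, 224)
```

So a 640×480 image keeps its aspect ratio during the resize, and the crop then discards 37 pixel columns on each side.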

An easier way to apply a chain of transforms to an image is with [`v2.Compose`](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose):

In [None]:
# ... WRITE YOUR CODE HERE ... #
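Conceptually, chaining transforms just means feeding the output of one transform into the next. The hypothetical `ComposeSketch` class below illustrates that idea with plain functions; it is not torchvision's implementation.

```python
from typing import Any, Callable, Sequence


class ComposeSketch:
    """Hypothetical, minimal illustration of the idea behind v2.Compose."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]):
        self.transforms = list(transforms)

    def __call__(self, x: Any) -> Any:
        # Apply the transforms one after another, in list order.
        for transform in self.transforms:
            x = transform(x)
        return x


pipeline = ComposeSketch([lambda x: x + 1, lambda x: x * 10])
print(pipeline(2))  # (2 + 1) * 10 = 30
```

The order matters: resizing before center-cropping is not the same as cropping first.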

We can also apply this transform to multiple images:

In [None]:
# ... WRITE YOUR CODE HERE ... #

### 1.5.2 Data augmentation

Apart from ensuring that all images can be put in a batch, another important use of data transforms is to **augment the dataset**. The idea of *data augmentation* is to apply image transforms that our model should be invariant to. If a certain image is somewhat zoomed in, its aspect ratio has changed slightly, or it is rotated by 10°, for example, we still want our model to recognize the Pokémon.

Let's first put our image visualization logic inside a utility function.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Next, we sample some images from the DataFrame and show them without any transforms.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Let's now play around with some combinations of transforms and visualize their effect on the images. See [this page](https://pytorch.org/vision/stable/transforms.html#v2-api-reference-recommended) for an overview of the available image transforms.

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #

In [None]:
# ... WRITE YOUR CODE HERE ... #
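To make the "random" part of augmentation concrete, here is a hypothetical NumPy sketch of a random horizontal flip on an H×W×C array (torchvision offers this as `v2.RandomHorizontalFlip(p=0.5)`). With probability `p` the columns are mirrored; otherwise the image passes through unchanged, so each epoch sees a slightly different version of the data.

```python
import numpy as np


def random_horizontal_flip(image: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sketch: mirror the columns of an H x W x C array with probability p."""
    if rng.random() < p:
        return image[:, ::-1, :]
    return image


rng = np.random.default_rng(0)
image = np.arange(2 * 3 * 1).reshape(2, 3, 1)  # made-up 2 x 3 single-channel image

always = random_horizontal_flip(image, p=1.0, rng=rng)  # columns reversed
never = random_horizontal_flip(image, p=0.0, rng=rng)   # unchanged
print(always[..., 0])
```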

### 1.5.3 Data type conversion and normalization

Our images are [`PIL.Image` instances](https://pillow.readthedocs.io/en/stable/reference/Image.html). Our neural network, however, will expect `torch.Tensor` objects with `torch.float32` numbers that (roughly) have zero mean and unit variance.

To convert a `PIL.Image` to a PyTorch tensor, you can use [`v2.ToImage()`](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ToImage.html). The name of this transform might be somewhat confusing at first sight, but it is named so because it converts the input into a [`torchvision.tv_tensors.Image`](https://pytorch.org/vision/main/generated/torchvision.tv_tensors.Image.html#torchvision.tv_tensors.Image) instance, which is a subclass of [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor) for images.

In [None]:
# ... WRITE YOUR CODE HERE ... #

As you can see, the data stored in our `PIL.Image` consists of 8-bit unsigned integers (0 to 255). To convert these to `torch.float32` numbers, we can use the [`v2.ToDtype()`](https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype) transform.

In [None]:
# ... WRITE YOUR CODE HERE ... #

But as you can see, these numbers are still between 0 and 255. If we pass `scale=True`, the numbers will be scaled between `0.0` and `1.0`:

In [None]:
# ... WRITE YOUR CODE HERE ... #
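For `uint8` input, the effect of scaling amounts to dividing by 255, the largest value an 8-bit unsigned integer can hold. A NumPy sketch of the arithmetic on a made-up 1×2 RGB image (stored H×W×C, as PIL does):

```python
import numpy as np

# A made-up 1 x 2 RGB image stored as 8-bit unsigned integers.
pixels_uint8 = np.array([[[0, 128, 255], [255, 255, 255]]], dtype=np.uint8)

# Changing the dtype alone keeps the 0-255 range ...
as_float = pixels_uint8.astype(np.float32)
# ... while scaling divides by 255, so the range becomes [0.0, 1.0].
scaled = as_float / 255.0
print(scaled)
```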

Now, with [`v2.Normalize()`](https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.Normalize.html#torchvision.transforms.v2.Normalize), we can normalize our pixel values so that they'll approximately have zero mean and unit variance. `Normalize` expects two arguments: `mean` (the per-channel mean to subtract) and `std` (the per-channel standard deviation to divide by). We will use `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`. These values are the per-channel (R, G, B) mean and standard deviation of [the ImageNet dataset](https://www.image-net.org/index.php), which is often used to pretrain neural networks.

In [None]:
# ... WRITE YOUR CODE HERE ... #
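Per channel $c$, normalization computes $(x - \text{mean}_c) / \text{std}_c$. The NumPy sketch below applies this to a random C×H×W array (made-up data, not a real image) and shows that a pure white pixel ends up well above $1.0$.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # per channel (R, G, B)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Made-up float image in C x H x W layout with values in [0, 1).
image = np.random.default_rng(0).random((3, 2, 2), dtype=np.float32)

# Reshaping to (3, 1, 1) lets NumPy broadcast one mean/std per channel.
normalized = (image - mean[:, None, None]) / std[:, None, None]

# A pure white pixel (1.0 in every channel) after normalization:
print((1.0 - mean) / std)  # roughly [2.25, 2.43, 2.64]
```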

We can again chain these transforms with `v2.Compose()`:

In [None]:
# ... WRITE YOUR CODE HERE ... #

And we can prepend them with other transforms:

In [None]:
# ... WRITE YOUR CODE HERE ... #

For some transforms, you might save a small amount of processing time by moving `v2.ToImage()` to the beginning of the transforms list, as the subsequent transforms will then be applied directly to PyTorch tensors instead of PIL images. Therefore, it is customary to use `v2.ToImage()` as the first transform.

In [None]:
# ... WRITE YOUR CODE HERE ... #

Note that these normalized tensors cannot really be visualized directly, as the data range extends beyond $[0.0, 1.0]$ and the pixels will need to be clipped. These values won't be clipped for the neural network, however, so a visualization will not really be representative of what the neural network will actually receive.

In [None]:
# ... WRITE YOUR CODE HERE ... #
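If you still want to look at the *content* of a normalized tensor, one option (a sketch, not the notebook's prescribed method) is to undo the normalization with $x = x_{\text{norm}} \cdot \text{std} + \text{mean}$, clip to $[0, 1]$, and move the channel axis last, since Matplotlib expects H×W×C. Keep in mind that this shows the original pixels, not the value range the network receives. The data below is random and made up.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Made-up image in C x H x W layout, normalized as v2.Normalize would do.
original = np.random.default_rng(1).random((3, 4, 4), dtype=np.float32)
normalized = (original - mean[:, None, None]) / std[:, None, None]

# Undo the normalization, then clip to the displayable range [0, 1].
restored = np.clip(normalized * std[:, None, None] + mean[:, None, None], 0.0, 1.0)

# Move channels last (C x H x W -> H x W x C) before calling imshow().
display_ready = np.transpose(restored, (1, 2, 0))
print(display_ready.shape)  # (4, 4, 3)
```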