# Objective

The objective of this notebook is to offer a first introduction to image classification problems.

For this, we will use `PyTorch`, a very popular deep learning library, and a dataset of fruits and vegetables.

In [None]:
!pip install -q eccd_datasets pygradus

In [None]:
# Download the original labels used when training a ResNet
import json

!wget https://files.fast.ai/models/imagenet_class_index.json -O resnet_labels.json
with open("resnet_labels.json", "r") as fh:
    data = json.load(fh)

# Map each integer class index to its human-readable label
resnet_labels = {}
for k, v in data.items():
    resnet_labels[int(k)] = v[1]
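
The label file maps each class index, stored as a string, to a pair made of a WordNet identifier and a readable name. The following is a minimal sketch of that conversion on a three-entry excerpt written inline for illustration; the names `sample_json` and `labels` are assumptions and are not part of the notebook.

```python
import json

# Illustrative excerpt with the same layout as the downloaded file:
# {"<class index>": ["<WordNet id>", "<readable name>"]}
sample_json = """
{
  "0": ["n01440764", "tench"],
  "1": ["n01443537", "goldfish"],
  "2": ["n01484850", "great_white_shark"]
}
"""

data = json.loads(sample_json)

# Keep only the readable name and convert the string keys to integers,
# so that an integer class index can be used directly as a key.
labels = {int(index): name for index, (wordnet_id, name) in data.items()}
print(labels[1])
```

Converting the keys to `int` matters because a model returns class indices as integers, not as strings.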

In [None]:
STUDENT_NAME = "Nombre Apellido"
COURSE_NAME = "eccd-oct23"
EXERCISE_NAME = "image-classification"

In [None]:
import io
import torch
import pandas as pd
import torchvision.transforms as transf
import numpy as np
import matplotlib.pyplot as plt

from torchvision import models
from pathlib import Path
from PIL import Image
from eccd_datasets import load_images

from pygradus import create_exercise, check_solution

torch.manual_seed(42)

# Exploring the dataset

First, we invite you to go to the dataset folder and explore the content and structure of the project.

The dataset used in this notebook is a subset of the dataset located [here](https://github.com/marcusklasson/GroceryStoreDataset).

Once that is done, we can start looking at what is included in the dataset.

In [None]:
df_images = load_images()
df_images.head()

### Looking at the images

We can use the `PIL` library to look at the images

In [None]:
def load_image_data(image_data):  
    return Image.open(io.BytesIO(image_data))

In [None]:
image = load_image_data(df_images.iloc[0]["image_data"])
image

### Images as matrices

We can also look at the matrix representation of each image using NumPy.

In [None]:
I = np.array(image)
print("Image shape", I.shape)
print(f"Image range in each coordinate: [{I.min()}, {I.max()}]")
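
To make the array layout concrete, here is a small sketch on a hand-built image. The array `toy` is hypothetical and only assumes the usual (height, width, channels) ordering with 8-bit values that `np.array` produces for an RGB `PIL` image.

```python
import numpy as np

# Hypothetical RGB image: 2 pixels high, 3 pixels wide, 3 colour channels.
toy = np.zeros((2, 3, 3), dtype=np.uint8)
toy[0, 0] = [255, 0, 0]  # top-left pixel is pure red
toy[1, 2] = [0, 0, 255]  # bottom-right pixel is pure blue

print(toy.shape)  # (height, width, channels)

# Slicing the last axis selects one colour channel as a 2D matrix.
red_channel = toy[:, :, 0]
print(red_channel)
```

Index 0 of the last axis is red, 1 is green and 2 is blue, which is why setting `[:, :, 0]` to zero removes all red from an image.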

And we can modify the image manually by changing the values of the matrix

In [None]:
new_I = I.copy()
new_I[:, :, 0] = 0 # Killing the red channel

In [None]:
plt.imshow(new_I.astype(int))

# PyTorch Transformations

In the same way that we preprocess tabular data with Standard and MinMax scalers, we need to preprocess image data before feeding it to a model.

We will now explore some of the most commonly used transformations.

## Resizing

In [None]:
transformation = transf.Resize((100, 100))

In [None]:
resized_image = transformation.forward(image)

In [None]:
plt.imshow(resized_image)
plt.title(f"New shape: {np.array(resized_image).shape}")

## Center Crop

Implement a transformation that crops the center of the image (hint: there is a transformation that does that).

In [None]:
type(image)

In [None]:
def center_crop_transformation(image, size: int) -> np.ndarray:
    """
    This function uses a PyTorch transformation to
    crop the center of the image and returns the
    result as a NumPy array.
    """
    # Write your code here


In [None]:
answer_center_crop = center_crop_transformation(image, 150)
plt.imshow(answer_center_crop)

In [None]:
assert answer_center_crop.shape == (150, 150, 3)

## RandomResizedCrop

In [None]:
print("Image original size: ", np.array(image).shape)
fig, ax = plt.subplots(1, 6, figsize=(20, 4))
for i, size in enumerate([50, 100, 150, 200, 300, 500]):
  
    transformation = transf.RandomResizedCrop(size)

    crp_img = transformation.forward(image)
    ax[i].imshow(np.array(crp_img))
    ax[i].set_title(np.array(crp_img).shape)

## Random Horizontal Flip

In [None]:
transformation = transf.RandomHorizontalFlip()

maybe_flipped = [transformation.forward(image) for _ in range(5)]

plt.imshow(np.hstack([np.array(img) for img in maybe_flipped]))

## Normalization

The same way we normalize columns for tabular data, here we normalize each image according to the mean and standard deviation of each colour channel.
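
As a minimal sketch of what per-channel normalization computes, the example below uses plain NumPy on a random batch. The batch is synthetic and the scaling to [0, 1] is an assumption that mirrors what `ToTensor` does to 8-bit images.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch of 4 RGB images of 8x8 pixels with values in [0, 255].
batch = rng.integers(0, 256, size=(4, 8, 8, 3)).astype(np.float64)

# Rescale to [0, 1] first, then compute one mean and one std per channel
# by averaging over the image, height and width axes.
scaled = batch / 255.0
mean = scaled.mean(axis=(0, 1, 2))
std = scaled.std(axis=(0, 1, 2))

# Broadcasting applies each channel's statistics to every pixel.
normalized = (scaled - mean) / std
print(mean.shape, normalized.shape)
```

After this step every channel has mean 0 and standard deviation 1 over the batch, which is the image analogue of standard scaling a column.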

In [None]:
two_images = [load_image_data(row["image_data"]) for _, row in df_images.iloc[:2].iterrows()]

In [None]:
two_image_dataset = np.array([np.array(img) for img in two_images])
print(two_image_dataset.shape)

plt.imshow(np.hstack([np.array(img) for img in two_images]))

In [None]:
np.mean(two_image_dataset, axis=(0, 1, 2))

In [None]:
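# ToTensor rescales pixel values from [0, 255] to [0, 1], so the statistics
# passed to Normalize must be computed on the same [0, 1] scale.
transformation = transf.Compose([
    transf.ToTensor(),
    transf.Normalize(
        np.mean(two_image_dataset / 255, axis=(0, 1, 2)),
        np.std(two_image_dataset / 255, axis=(0, 1, 2)),
    ),
])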

In [None]:
transformed_two_image_dataset = [transformation(img) for img in two_image_dataset]

In [None]:
transformed_two_image_dataset

# Using ImageNet

Since training a large neural network requires lots of data and computing power, often we download a pre-trained neural network, which we can later fine-tune.

Here, we will download a ResNet-18 network pre-trained on ImageNet.

Remember that since the network is already trained with a specific dataset, when evaluating new images we need to transform them using the same transformations used for training. In particular, that includes using the same normalization.

In [None]:
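# `pretrained=True` is deprecated since torchvision 0.13; the `weights` argument
# below loads the same ImageNet weights (on older versions, keep `pretrained=True`).
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)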

In [None]:
resnet

The `resnet_labels` dictionary loaded at the top of the notebook maps the ResNet integer labels to the actual category names.

In [None]:
def predict_using_resnet(image):
    """
    This function uses the resnet as is to
    predict the label of an image.
    Remember to apply the correct transformations
    to the image before feeding it to the network.
    
    The following link might be useful: https://pytorch.org/hub/pytorch_vision_resnet/
    """
    
    # Write your code here


In [None]:
img1 = (
    load_image_data(
        df_images[
            df_images["coarse_cat"] == "Apple"
        ]
        .iloc[0]
        ["image_data"]
    )

)
img1

In [None]:
img2 = (
    load_image_data(
        df_images[
            df_images["coarse_cat"] == "Orange"
        ]
        .iloc[0]
        ["image_data"]
    )

)
img2

In [None]:
img3 = (
    load_image_data(
        df_images[
            df_images["coarse_cat"] == "Pear"
        ]
        .iloc[0]
        ["image_data"]
    )

)
img3

In [None]:
pred1 = predict_using_resnet(img1)

In [None]:
assert pred1 == "lemon"

In [None]:
pred2 = predict_using_resnet(img2)
print(pred2)

In [None]:
pred3 = predict_using_resnet(img3)
print(pred3)

In [None]:

proposed_solution = {
    "attempt": {
        "course_name": COURSE_NAME,
        "exercise_name": EXERCISE_NAME,
        "username": STUDENT_NAME,
    },
    "task_attempts": [
        {
            "name": "center_crop",
            "answer": np.array_str(answer_center_crop),
        },
        {
            "name": "resnet_pred2",
            "answer": pred2,
        },
        {
            "name": "resnet_pred3",
            "answer": pred3,
        },
    ],
}
check_solution(proposed_solution)