# Image Classification

<div class="alert alert-info">
The <b>goal</b> of this notebook is to show you how you can <b>train your own model</b> to <b>classify images of your chosen categories</b>.
</div>

## Preparation of the programming environment

First things first, let's import a few libraries that we need for our analysis.

In [None]:
#@title Set up Google Colab by running this cell {display-mode: "form"}
import sys
if 'google.colab' in sys.modules:
    # Clone GitHub repository
    !git clone https://github.com/miykael/amld20_classification.git
        
    # Copy files required to run the code
    !cp -r 'amld20_classification/download' 'amld20_classification/utils.py' .
    
    # Install packages via pip
    !pip install -r "amld20_classification/colab-requirements.txt"
    
    # Restart Runtime
    import os
    os.kill(os.getpid(), 9)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk')

import numpy as np
import pandas as pd
from utils import *

# 1. Data Preparation

## 1.1. Define Class Labels

In [None]:
# List of class labels used for the classification
class_labels = [
    'brown bear',
    'polar bear',
    'giant panda',
    'red panda',
    'lion',
    'tiger',
    'racoon',
    'red fox',
]

<div class="alert alert-success">
<b>Note</b>: Feel free to change the list of class labels above as you like. Between 4 and 8 classes is recommended.
    
<b>However</b>, for a smooth workshop experience, I recommend keeping the labels that are already listed and adding at most 4 more.
</div>

## 1.2. Collect Dataset

We will create the dataset for this classification ourselves, using Google's image search. The quickest way to do so is with the Python package [Google Images Download](https://github.com/hardikvasa/google-images-download).

In [None]:
# Collect the dataset and load the images
imgs_raw = collect_images(class_labels, suffix='photo,close up,portrait')

<div class="alert alert-success">
<b>Note</b>: The parameter <code>suffix='photo,close up,portrait'</code> in the previous function is used to expand our image search requests with additional terms. As it is specified now, the code will look for images for "brown bear photo", "brown bear close up" and "brown bear portrait".
</div>
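
To make the query expansion concrete, here is a minimal sketch of how labels and suffix terms could be combined. The helper `expand_queries` is hypothetical and is not taken from `utils.py`; how `collect_images` builds its requests internally is an assumption.

```python
# Hypothetical helper for illustration only; not part of utils.py
def expand_queries(labels, suffix):
    """Combine every class label with every comma-separated suffix term."""
    terms = [term.strip() for term in suffix.split(',') if term.strip()]
    return [f'{label} {term}' for label in labels for term in terms]


queries = expand_queries(['brown bear', 'polar bear'], 'photo,close up,portrait')
for query in queries:
    print(query)
```

With two labels and three suffix terms, this yields six search queries, so the number of requests grows with both lists.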

In [None]:
# Let's take a look at the data we've collected
plot_images(imgs_raw, n_col=8, n_row=3)

## 1.3. Clean Dataset

In [None]:
# Remove duplicates from the dataset
imgs_unique = remove_duplicates(imgs_raw)

In [None]:
# Inspect images with their RGB color distributions
plot_images(imgs_unique, n_col=8, n_row=3, show_histogram=True)

<div class="alert alert-success">
<b>Note</b>: For this visualization, we set the parameter <code>show_histogram=True</code>, so that the color distributions for red, green and blue are plotted as well. Keep in mind that enabling this parameter slightly increases the plotting time.
</div>
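
As a rough idea of what such a color distribution contains, the sketch below computes one histogram per color channel with NumPy. The function `rgb_histograms`, the bin count of 32 and the random test image are assumptions for illustration; `plot_images` may compute its histograms differently.

```python
import numpy as np


def rgb_histograms(img, bins=32):
    """Return a (3, bins) array of pixel counts, one row per color channel."""
    return np.stack([
        np.histogram(img[..., channel], bins=bins, range=(0, 255))[0]
        for channel in range(3)
    ])


# Random stand-in for a 64x64 RGB image (not real data)
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

hist = rgb_histograms(img)
print(hist.shape)  # (3, 32)
```

Each row sums to the number of pixels in the image, since every pixel contributes exactly one count per channel.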

In [None]:
# Remove outliers
imgs_clean, imgs_outlier = remove_outliers(imgs_unique)

In [None]:
# Plot clean images
plot_images(imgs_clean, n_col=8, n_row=3)

In [None]:
# Plot outlier images
plot_images(imgs_outlier, n_col=8, n_row=3)

## 1.4. Finalize Dataset

In [None]:
# Create dataset
X_pixel, y_pixel, metainfo = create_dataset(imgs_clean, class_labels, img_dim=64)

# 2. Data Exploration

## 2.1. Data Description

In [None]:
# How many images per class do we have?
plot_class_distribution(y_pixel, metainfo)

In [None]:
# What does the average image of each class look like?
plot_class_average(X_pixel, y_pixel, metainfo)

In [None]:
# What does the average RGB color profile look like per class?
plot_class_RGB(X_pixel, y_pixel, metainfo)

## 2.2. Feature Engineering

In [None]:
# Extract RGB color profiles for each image individually
X_rgb, y_rgb = extract_RGB_features(X_pixel, y_pixel)

In [None]:
# Extract features according to MobileNet (Neural Network)
X_nn, y_nn = extract_neural_network_features()

## 2.3. Recap before Modeling

Each image from our original cleaned dataset is now represented in three different ways:
- `X_pixel`: In its original pixel format.
- `X_rgb`: Represented by its RGB color profile only.
- `X_nn`: Represented by its MobileNet (Neural Network) features.

To better understand what this means, let's plot an image in these three representations.
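
Before plotting, it can also help to compare how many features each representation carries. The sketch below is not taken from `utils.py`: it assumes images of 64x64 pixels (matching `img_dim=64`), a flattened pixel vector, and an RGB profile built from 32 histogram bins per channel. The MobileNet features are left out because they require a pretrained network.

```python
import numpy as np

# Random stand-in for one cleaned image at img_dim=64 (not real data)
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)

# Pixel representation: every pixel value becomes one feature
pixel_features = img.reshape(-1)

# RGB profile: one histogram per channel, concatenated (bins=32 is assumed)
bins = 32
rgb_features = np.concatenate([
    np.histogram(img[..., channel], bins=bins, range=(0, 255))[0]
    for channel in range(3)
])

print(pixel_features.shape, rgb_features.shape)  # (12288,) (96,)
```

Under these assumptions the RGB profile is far more compact than the raw pixels, but it discards all spatial information about where the colors appear.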

In [None]:
plot_recap(X_pixel, X_rgb, X_nn)

# 3. Modeling and Analysis

## 3.1. Fit Model to Training Data

In [None]:
model_pixel = model_fit(X_pixel, y_pixel)

In [None]:
model_rgb = model_fit(X_rgb, y_rgb)

In [None]:
model_nn = model_fit(X_nn, y_nn)

## 3.2. Check Model Performance on Test Data

In [None]:
check_model_performance(model_pixel, metainfo)

In [None]:
check_model_performance(model_rgb, metainfo)

In [None]:
check_model_performance(model_nn, metainfo)

# 4. Communication & Reporting

## 4.1. Investigation of Model Performance

In [None]:
print_class_labels(metainfo)

In [None]:
# Choose which model to investigate: 'model_pixel', 'model_rgb' or 'model_nn'
model = model_nn

In [None]:
# Plot correct predictions
investigate_predictions(model, metainfo, show_correct=True, imgs=model_pixel)

In [None]:
# Plot wrong predictions
investigate_predictions(model, metainfo, show_correct=False, imgs=model_pixel)

## 4.2. Try out Model Predictions on New Images

In [None]:
# Specify the URL path to an image that you would like to classify
img_url = 'https://d33wubrfki0l68.cloudfront.net/067d6b185769f031404e927b1b70de6d5ece3e0a/d82f4/michael.1164620e.jpg'

In [None]:
predict_new_image(img_url, model_nn, metainfo)