# The Iris Classification Challenge

##### Based on a notebook by [Randal S. Olson](http://www.randalolson.com/)

#### Goal: Work through a simple machine learning example from start to finish along a typical data analysis pipeline. 

We will apply machine learning to a classification task and use a slightly modified version of the famous [Iris data set](https://archive.ics.uci.edu/ml/datasets/Iris). These data were collected by botanist Edgar Anderson and popularized by the statistician Ronald Fisher in his classic 1936 paper, "The use of multiple measurements in taxonomic problems". Anderson measured the properties of three different species of iris flowers. The data set contains four features from each sample: the length and width of the petals and sepals, in centimeters. 

<img src="images/iris_petal_sepal.png" style="width: 300px;"/>

**Please note**: This exercise is based on a Jupyter notebook, an interactive environment for writing and running code, and is running in Python. To get familiar with working in Jupyter notebooks, see our JupyterLab Tutorial. For a short introduction to the basics of programming in Python, see the "Introduction to Python" notebook.

<mark>Yellow highlights indicate a small exercise or task for you to try out.</mark>

<div class="alert alert-block alert-info">A cell like this indicates a question you need to answer for this Challenge on the U4I platform.</div>

## Table of contents

0. [Introduction](#Introduction)

1. [Step 1: Frame the problem](#Step-1:-Frame-the-problem) 

2. [Step 2: Check the data](#Step-2:-Check-the-data)

3. [Step 3: Tidy the data](#Step-3:-Tidy-the-data)

4. [Step 4: Exploratory analysis](#Step-4:-Exploratory-analysis)

5. [Step 5: Classification](#Step-5:-Classification)

6. [Step 6: Evaluation with cross-validation](#Step-6:-Evaluation-with-cross-validation)

7. [Sources](#Sources)

## Introduction

[[ go back to the top ]](#Table-of-contents)

Let's pretend we would like to create a smartphone app that automatically identifies species of flowers from pictures taken on the smartphone. 

Our task is to create a demo machine learning model that **takes four measurements from flowers** (petal length, petal width, sepal length, and sepal width) and **identifies the species based on these measurements** alone. 

To develop the demo, we've been given a [data set](iris-data.csv) which includes measurements for three types of *Iris* flowers: Iris versicolor, Iris setosa, and Iris virginica.

<img src="images/iris-types.png" style="width: 800px;"/>

## Step 1: Frame the problem 

[[ go back to the top ]](#Table-of-contents)

The first step to any data analysis project is to define the question or problem we're trying to solve, and to define a measure (or set of measures) for our success at solving that task. Here, we're trying to classify the species (i.e., class) of Iris flowers based on four measurements: sepal length, sepal width, petal length, and petal width. To quantify how well our model is performing, we look at the fraction of correctly classified flowers and aim to achieve at least 90% accuracy. 
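
As a minimal sketch of this success measure, the fraction of correctly classified flowers can be computed as shown below. The helper `accuracy` and the example labels are hypothetical and only serve as an illustration; later in the notebook, scikit-learn computes this score for us.

```python
def accuracy(true_labels, predicted_labels):
    """Return the fraction of predictions that match the true labels."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("Both label lists must have the same length")
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)


# Hypothetical labels, made up purely for illustration
true_labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica", "Iris-virginica"]
predicted_labels = ["Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-virginica"]

example_accuracy = accuracy(true_labels, predicted_labels)
print(example_accuracy)  # 3 of the 4 predictions are correct
```

With this measure, our 90% target means that at least 9 out of every 10 flowers must receive the correct species label.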

When dealing with data, we need to always consider the experimental design, and what we can expect to answer with the available data. The data we're using come from hand-measurements of 50 randomly-sampled flowers of each species using a standardized methodology. Our measurements only include three types of *Iris* flowers. The model trained on this data set will thus also only work for Iris flowers, and we will need more data to create a general flower classifier.

## Step 2: Check the data

[[ go back to the top ]](#Table-of-contents)

The next step is to look closer at the data we're working with. Spotting potential errors in the data set early can save a lot of time during our later analysis. 

<div class="alert alert-block alert-info">Question 1: What might be one reason for spotting potential errors in the data before using it to create a model?</div>

<mark>Double-click on the "iris-data.csv" file to view the data in a separate tab.</mark>

Let's read the data into a DataFrame object and display the first five rows. A DataFrame object is a 2-dimensional, labeled data structure with columns that can be of different types; you can think of it as a spreadsheet.

<mark>Remember to press *Shift+Enter* to run each code cell.</mark>

In [None]:
# The first line below loads the pandas package, which provides the DataFrame structure; 
# an overview of the packages used in this notebook is provided in the "Sources" section
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

# Load data from file
iris_data = pd.read_csv('iris-data.csv')

# Display the first five rows
iris_data.head()

The first row in the data file contains the column headers, which indicate which measurement each column represents and the unit of measurement. Each row below the header row represents an entry for a flower: four measurements and one class, which tells us the species of the flower.

One of the first things we should look for is **missing data.** By looking at the "iris-data.csv" file, we observe that missing measurements are denoted with 'NA' in the spreadsheet. We can tell pandas to automatically interpret 'NA' entries as missing values when reading the file:

In [None]:
# Note: this line produces no output
iris_data = pd.read_csv('iris-data.csv', na_values=['NA'])
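
To check how many entries were read as missing, one option is to count them per column. The sketch below uses a hypothetical three-row excerpt (the values are made up, the column names follow the data file) so that it runs without the file; the same `isnull().sum()` call should work on `iris_data`.

```python
import io
import pandas as pd

# Hypothetical three-row excerpt; the real file has many more rows
csv_text = """sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,class
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,NA,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
"""

example_data = pd.read_csv(io.StringIO(csv_text), na_values=['NA'])

# isnull() marks missing cells; summing the marks counts them per column
missing_per_column = example_data.isnull().sum()
print(missing_per_column)
```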

Next, it's always a good idea to look at the **distribution** of the data — especially to identify potential **outliers**. Let's start by printing out some summary statistics about the numerical attributes of the data set:

In [None]:
# Display summary statistics of the data set
iris_data.describe()

We can get some useful pieces of information from this table. For example, we observe that five `petal_width_cm` entries are missing. However, identifying outliers and errors in a large table of numbers is difficult; some **visualization** of the data would be helpful. 

For this, let's set up the notebook so we can plot inside of it.

In [None]:
# This line tells the notebook to show plots inside of the notebook
%matplotlib inline

# Load the plotting libraries
import matplotlib.pyplot as plt
import seaborn as sb

A quick way to obtain an overview of the data is to plot **histograms** for the numerical attributes. A histogram shows the number of data points (on the vertical axis) for which a selected attribute falls in a given value range (on the horizontal axis). Histograms reveal the distribution of data and can indicate the presence of outliers or of multiple classes in the data. Here is the histogram of the petal length attribute:

In [None]:
iris_data['petal_length_cm'].hist()
plt.xlabel('petal_length_cm')
plt.ylabel('number of samples')
;
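
The counting behind a histogram can be sketched with NumPy's `histogram` function. The petal lengths below are hypothetical, and 5 bins are chosen only to keep the output short; this illustrates the binning idea and is not the exact computation of the plot above.

```python
import numpy as np

# Hypothetical petal lengths in cm, made up for illustration
petal_lengths = np.array([1.3, 1.4, 1.5, 1.4, 4.5, 4.7, 5.1, 5.9, 6.0])

# Split the value range into 5 equal-width bins and count the points in each
counts, bin_edges = np.histogram(petal_lengths, bins=5)

for count, low, high in zip(counts, bin_edges[:-1], bin_edges[1:]):
    print(f"{low:.2f} to {high:.2f} cm: {count} samples")
```

The two empty bins in the middle are the kind of gap that hints at more than one class in the data.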

Another useful way to visualize data is the **scatterplot**. A scatterplot displays the relationship between two numerical attributes by showing the data as a collection of points; the x- and y-coordinates of each point are the values of the two selected attributes for that point. In the following plot, we use petal length and petal width as attributes, and additionally color each point to indicate the class it belongs to. You might already observe an issue with our data set in the chart below:

In [None]:
sb.scatterplot(x='petal_length_cm', y='petal_width_cm', data=iris_data, hue='class')
;

We can combine histograms and scatterplots to create a **scatterplot matrix**. Scatterplot matrices plot the histograms of each column along the diagonal, and a matrix of scatterplots for the combination of any two attributes. They make for an efficient tool to look for errors in our data across multiple features.

In [None]:
# Note: Plotting the scatterplot matrix will take a moment. 
# For plotting, we temporarily ignore the rows with 'NA' values 

sb.pairplot(iris_data.dropna())
;

The scatterplot matrix becomes even more informative if we again use colors to indicate the classes, allowing us to see some further issues with the data set:

In [None]:
# Note: Plotting the scatterplot matrix will take a moment. 

sb.pairplot(iris_data.dropna(),hue='class')
;

Note that in this chart, the diagonal shows so-called "density plots" to visualize the distributions of each feature (e.g., sepal width) for all classes simultaneously. A density plot is a variation of a histogram that uses a statistical smoothing technique to estimate a smooth shape of the distribution.

<mark>Please pause here and take a moment to analyze the charts.</mark>

<div class="alert alert-block alert-info">Question 2: Which potential issues do you notice in the data? Write down your observations before moving on to the next step.</div>

## Step 3: Tidy the data

[[ go back to the top ]](#Table-of-contents)

So far, we could observe several issues with our data set. Issues include the wrong number of classes, potentially erroneous outliers in the measurements, and missing values. In all these cases, we need to figure out what to do. Let's walk through the issues one by one and fix them before proceeding with the analysis. 

##### 1. There are five classes when there should only be three.

Some classes are obviously mislabeled: Some `Iris-versicolor` entries lack the `Iris-` prefix, while the other extraneous class, `Iris-setossa`, is simply a typo. We fix these errors in the DataFrame:

In [None]:
# Renaming the mislabeled classes
iris_data.loc[iris_data['class'] == 'versicolor', 'class'] = 'Iris-versicolor'
iris_data.loc[iris_data['class'] == 'Iris-setossa', 'class'] = 'Iris-setosa'

# Checking that only the correct three class types remain
iris_data['class'].unique()
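
As an alternative sketch, mislabeled classes can be found with `value_counts()` and corrected in one step with a mapping passed to `replace()`. The labels below are hypothetical and only mimic the two kinds of mislabeling described above.

```python
import pandas as pd

# Hypothetical labels that mimic the two kinds of mislabeling described above
labels = pd.Series(['Iris-setosa', 'Iris-setossa', 'versicolor',
                    'Iris-versicolor', 'Iris-virginica'])

# value_counts() lists every distinct label with its frequency
print(labels.value_counts())

# Map each wrong spelling to its correct label; other values stay unchanged
corrections = {'versicolor': 'Iris-versicolor', 'Iris-setossa': 'Iris-setosa'}
fixed_labels = labels.replace(corrections)
print(sorted(fixed_labels.unique()))
```

Keeping the corrections in a single dictionary also documents every renaming in one place.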

##### 2. There are some potentially erroneous outliers. 

Fixing outliers can be tricky, because it's rarely clear whether the outlier was caused by measurement error, recording the data in improper units, or if the outlier is a real anomaly. If we decide to exclude any data, we need to make sure to document what data we excluded and provide solid reasoning for excluding that data. 

Here we observe some clear outliers in the measurements that may be erroneous: one `sepal_width_cm` entry for `Iris-setosa` falls well outside its normal range; and several `sepal_length_cm` entries for `Iris-versicolor` are near-zero: 

In [None]:
plt.figure(figsize=(10,6))

plt.subplot(2,2,1)
# Plot the histogram of `sepal_width_cm` measurements for `Iris-setosa`
iris_data.loc[iris_data['class'] == 'Iris-setosa', 'sepal_width_cm'].hist()
plt.xlabel('sepal_width_cm for Iris-setosa')
plt.ylabel('number of samples')
;

plt.subplot(2,2,2)
# Plot the histogram of `sepal_length_cm` measurements for `Iris-versicolor`
iris_data.loc[iris_data['class'] == 'Iris-versicolor', 'sepal_length_cm'].hist()
plt.xlabel('sepal_length_cm for Iris-versicolor')
plt.ylabel('number of samples')
;
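
Besides histograms, a per-class summary of the minimum and maximum can help to spot implausible values. The following is a small sketch on hypothetical measurements, in which one `Iris-versicolor` sepal length is suspiciously small.

```python
import pandas as pd

# Hypothetical measurements; one versicolor sepal length is suspiciously small
example_data = pd.DataFrame({
    'class': ['Iris-setosa', 'Iris-setosa', 'Iris-versicolor', 'Iris-versicolor'],
    'sepal_length_cm': [5.1, 4.9, 6.4, 0.055],
})

# Minimum and maximum per class make implausible values easier to spot
ranges = example_data.groupby('class')['sepal_length_cm'].agg(['min', 'max'])
print(ranges)
```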

Let's fix these outliers. In the case of the one anomalous `sepal_width_cm` entry for `Iris-setosa`, we decide to remove that specific entry:

In [None]:
# We drop 'Iris-setosa' rows with a sepal width below 2.5 cm and keep all other rows
iris_data = iris_data.loc[(iris_data['class'] != 'Iris-setosa') | (iris_data['sepal_width_cm'] >= 2.5)]

In the case of the `sepal_length_cm` entries for `Iris-versicolor` that are near-zero, we take a look at the suspect entries and notice that they seem to be off by two orders of magnitude, as if they had been recorded in meters instead of centimeters:

In [None]:
# Display the Iris-versicolor rows with a sepal length below 1 cm
iris_data.loc[(iris_data['class'] == 'Iris-versicolor') &
              (iris_data['sepal_length_cm'] < 1.0)]

We decide to adjust the entries based on this assumption, but keep track of our choice to be able to revise it later if necessary:

In [None]:
# We assume near-zero sepal_length_cm entries for Iris versicolor were recorded in meters instead of centimeters
# and convert them to centimeters by multiplying by 100
iris_data.loc[(iris_data['class'] == 'Iris-versicolor') &
              (iris_data['sepal_length_cm'] < 1.0),
              'sepal_length_cm'] *= 100.0

##### 3. There are some rows with missing values.

Let's take a look at these rows:

In [None]:
# Display all rows which contain missing values
iris_data.loc[(iris_data['sepal_length_cm'].isnull()) |
              (iris_data['sepal_width_cm'].isnull()) |
              (iris_data['petal_length_cm'].isnull()) |
              (iris_data['petal_width_cm'].isnull())]

It would not be ideal to exclude these rows considering they all belong to the `Iris-setosa` class. Since it seems like the missing data is systematic — all of the missing values are in the same column for the same *Iris* type — this error could potentially bias our analysis.

One way to deal with missing data is **mean imputation**: If we know that the values for a measurement fall in a certain range, we can fill in empty values with the average of that measurement. Besides the mean, using the median or zero can be reasonable choices. Let's use the mean here.

In [None]:
# Calculate mean petal width for Iris setosa
average_petal_width = iris_data.loc[iris_data['class'] == 'Iris-setosa', 'petal_width_cm'].mean()

# Fill missing Iris setosa petal width values with the mean value
iris_data.loc[(iris_data['class'] == 'Iris-setosa') &
              (iris_data['petal_width_cm'].isnull()),
              'petal_width_cm'] = average_petal_width
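
If missing values occurred in more than one class, the same idea could be applied to every class at once. The sketch below, on hypothetical data, fills each missing petal width with the mean of its own class using `groupby` and `transform`; it is one possible generalization, not the approach taken in this notebook.

```python
import numpy as np
import pandas as pd

# Hypothetical data with one missing petal width in each of two classes
example_data = pd.DataFrame({
    'class': ['Iris-setosa', 'Iris-setosa', 'Iris-setosa',
              'Iris-virginica', 'Iris-virginica'],
    'petal_width_cm': [0.2, np.nan, 0.4, 2.0, np.nan],
})

# Mean of the non-missing values within each class, aligned to every row
class_means = example_data.groupby('class')['petal_width_cm'].transform('mean')

# Replace only the missing entries with the mean of their own class
example_data['petal_width_cm'] = example_data['petal_width_cm'].fillna(class_means)
print(example_data)
```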

Now that we've cleaned the data, we don't want to repeat this process every time we work with the data set. Let's save the tidied data file *as a separate file* and work directly with that data file from now on.

In [None]:
# Write cleaned data set to new csv (comma-separated values) file
iris_data.to_csv('iris-data-clean.csv', index=False)

# Load cleaned data into new DataFrame object
iris_data_clean = pd.read_csv('iris-data-clean.csv')

Before we go on, let's summarize some general takeaways related to tidying the data:

* Make sure the data is labeled properly
* Check that the data falls within the expected range, using domain knowledge whenever possible to define that expected range
* Deal with missing data: replace it if you can; otherwise drop it
* Never "fix" the data manually but instead use code as a record of how you modified the data

## Step 4: Exploratory analysis

[[ go back to the top ]](#Table-of-contents)

Exploratory analysis is the step where we start delving deeper into the data set beyond the outliers and errors to answer questions such as:

* How is the data distributed?
* Are there any correlations in the data?
* Are there any factors that explain these correlations?

Let's return to the scatterplot matrix that we used earlier and take a look at the cleaned data set:

In [None]:
sb.pairplot(iris_data_clean)
;

Most of the measurements look roughly normally distributed (bell-shaped), which is good news if we plan on using any modeling methods that assume normally distributed data. But there seem to be some clusters in the petal measurements, which we suspect are related to the different Iris species. Let's color code the data by the class again:

In [None]:
sb.pairplot(iris_data_clean, hue='class')
;

Indeed, the grouping of the petal measurements is related to the different species. This is actually great news for our classification task since it means that the petal measurements will facilitate distinguishing between the Iris types. 

<div class="alert alert-block alert-info">Question 3: Do you notice anything that might make it difficult to distinguish between the three Iris species? Hint: Notice where the groups of datapoints are in the scatter plots. </div>

There are also correlations between petal length and petal width, as well as between sepal length and sepal width. These correlations suggest that longer flower petals (or sepals) also tend to be wider, which seems realistic. 
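
Correlations like these can also be quantified. As a small sketch on hypothetical petal measurements, the `corr()` method of a DataFrame returns the Pearson correlation coefficient for every pair of numerical columns; values close to 1 indicate a strong positive linear relationship.

```python
import pandas as pd

# Hypothetical petal measurements in cm, made up for illustration
example_data = pd.DataFrame({
    'petal_length_cm': [1.4, 1.5, 4.5, 4.7, 5.9, 6.0],
    'petal_width_cm': [0.2, 0.2, 1.5, 1.4, 2.1, 2.5],
})

# Pearson correlation coefficients between all pairs of numerical columns
correlations = example_data.corr()
print(correlations)
```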

Now we have a better understanding of the data we are dealing with. Let's finally get to some modeling!

## Step 5: Classification

[[ go back to the top ]](#Table-of-contents)

You might be surprised that it took us so long to get to the actual modeling step. In general, data preparation tends to be by far the [most time-consuming](https://www.forbes.com/sites/gilpress/2016/03/23/data-preparation-most-time-consuming-least-enjoyable-data-science-task-survey-says/) step in data science, and it is a vital one: If we had jumped straight to the modeling, we would have created a faulty classification model. Remember to always check your data first: **Bad data leads to bad models.** 

Before we select and train a classification model, we need to **split the data into training and testing sets:**
- A **training set** is a random subset of the data that we use to train our models.
- A **testing set** is a random subset of the data (disjoint from the training set) that we use to validate our models on data they have not seen before.
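
To illustrate the idea, a random split can be sketched in plain Python: shuffle the row indices, then cut the shuffled list in two. The helper `simple_train_test_split` is hypothetical and written only for this illustration; in the notebook we rely on scikit-learn for the actual split.

```python
import random


def simple_train_test_split(items, test_fraction, seed):
    """Shuffle a copy of items and cut it into a training and a testing part."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


# Row indices 0..19 stand in for 20 hypothetical flowers
train_rows, test_rows = simple_train_test_split(range(20), test_fraction=0.25, seed=1)
print(len(train_rows), len(test_rows))
```

Because every row ends up in exactly one of the two parts, the model is never tested on a flower it has seen during training.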

Let's set up our data first.

In [None]:
# We're using all four measurements as inputs to the model
# Note that scikit-learn expects each entry to be a list of values, e.g.,
# [ [val1, val2, val3],
#   [val1, val2, val3],
#   ... ]
# such that our input data set is represented as a list of lists

# We can extract the data in this format from the DataFrame like this:
all_inputs = iris_data_clean[['sepal_length_cm', 'sepal_width_cm',
                             'petal_length_cm', 'petal_width_cm']].values
# Make sure to not mix up the order of the entries

# Similarly, we can extract the class labels
all_labels = iris_data_clean['class'].values

In [None]:
# Let's check that the representation is correct for an arbitrary input data point n
n = 42

print(all_inputs[n],all_labels[n])
iris_data_clean.loc[[n]]


Looks fine! Now our data is ready to be split. We split the data randomly using the function `train_test_split` from the **scikit-learn** library, the essential machine learning package in Python. We can specify what fraction of the data should be used for the test set using the parameter `test_size`:

In [None]:
from sklearn.model_selection import train_test_split

(training_inputs,
 testing_inputs,
 training_classes,
 testing_classes) = train_test_split(all_inputs, all_labels, test_size=0.25, random_state=1)

With our data split, we can start fitting models to our data. We will use a **decision tree classifier**: In their simplest form, decision tree classifiers ask a series of Yes/No questions about the data — each time getting closer to finding out the class of each entry — until they either classify the data set perfectly or simply can't differentiate a set of entries. 
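
To make the idea of a series of Yes/No questions concrete, here is a hand-written toy tree. Both thresholds are illustrative assumptions and were not learned from our data; a trained classifier chooses its own questions and thresholds from the training set.

```python
def classify_flower(petal_length_cm, petal_width_cm):
    """Toy decision tree; both thresholds are illustrative assumptions."""
    # Question 1: is the petal short?
    if petal_length_cm < 2.5:
        return 'Iris-setosa'
    # Question 2: is the petal narrow?
    if petal_width_cm < 1.75:
        return 'Iris-versicolor'
    return 'Iris-virginica'


print(classify_flower(1.4, 0.2))
print(classify_flower(4.5, 1.4))
print(classify_flower(6.0, 2.5))
```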


Note: There are several [parameters](http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html) that we can tune for decision tree classifiers, but for now we use a basic decision tree classifier:

In [None]:
from sklearn.tree import DecisionTreeClassifier

# Create the classifier
decision_tree_classifier = DecisionTreeClassifier()

# Train the classifier on the training set
decision_tree_classifier.fit(training_inputs, training_classes)

# Validate the classifier on the testing set using classification accuracy
decision_tree_classifier.score(testing_inputs, testing_classes)

That was easy: Our model achieves **97% classification accuracy** without much effort. However, there's a catch: The split of the data set was done randomly; depending on how our training and testing sets are sampled, our model can achieve anywhere from 80% to 100% accuracy:

In [None]:
model_accuracies = []

# Split data set randomly 1000 times, and train and test the model
for repetition in range(1000):
    (training_inputs,
     testing_inputs,
     training_classes,
     testing_classes) = train_test_split(all_inputs, all_labels, test_size=0.25)
    
    decision_tree_classifier = DecisionTreeClassifier()
    decision_tree_classifier.fit(training_inputs, training_classes)
    classifier_accuracy = decision_tree_classifier.score(testing_inputs, testing_classes)
    model_accuracies.append(classifier_accuracy)

# Display the accuracy distribution obtained    
plt.hist(model_accuracies)
plt.xlabel('model accuracy')
plt.ylabel('number of repetitions')
left, right = plt.xlim() 
;

It's obviously a problem that our model performs quite differently depending on the subset of the data it's trained on. A better way to evaluate the model is to use **cross-validation**.

## Step 6: Evaluation with cross-validation

[[ go back to the top ]](#Table-of-contents)

A better approach to evaluate our model is to perform a ***k*-fold cross-validation**, which works as follows: Split the original data set into *k* subsets called *folds*; use one of the folds (the "test data") for evaluation, and train on the rest of the folds (the "training data"). Repeat *k* times such that each fold is used as the testing set exactly once:

<img src="images/Diagram_K-fold_cross_validation.png" style="width: 600px;"/>
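
The fold construction described above can be sketched in plain Python. The helper `k_fold_indices` is hypothetical and, for simplicity, does not shuffle the data; it only shows how each sample lands in the test fold exactly once.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_indices, test_indices) for each of the k folds."""
    indices = list(range(n_samples))
    # Spread the remainder so that fold sizes differ by at most one
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_indices = indices[start:start + size]
        train_indices = indices[:start] + indices[start + size:]
        yield train_indices, test_indices
        start += size


folds = list(k_fold_indices(n_samples=10, k=3))
for train_indices, test_indices in folds:
    print("train:", train_indices, "test:", test_indices)
```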

10-fold cross-validation is a common choice. We can perform 10-fold cross-validation on our model with the code below. The result is an array containing the 10 evaluation scores:

In [None]:
from sklearn.model_selection import cross_val_score
import numpy as np

decision_tree_classifier = DecisionTreeClassifier()
k = 10

# cross_val_score returns a NumPy array of the scores, which we can visualize
# to get a reasonable estimate of our classifier's performance
cv_scores = cross_val_score(decision_tree_classifier, all_inputs, all_labels, cv=k)

np.set_printoptions(precision=4)
print(cv_scores)

Note that when given a classifier and an integer `cv`, the `cross_val_score` function performs a **stratified *k*-fold cross-validation**. Stratified *k*-fold keeps the class proportions roughly the same across all of the folds, which helps ensure that the testing is performed on a representative subset of the data set.
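
The effect of stratification can be checked on a small example. The sketch below uses scikit-learn's `StratifiedKFold` directly on hypothetical labels (6 flowers per class, with placeholder features) and counts the classes in each test fold.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical labels: 6 flowers of each class; the features are placeholders
example_labels = np.array(['Iris-setosa'] * 6 + ['Iris-versicolor'] * 6
                          + ['Iris-virginica'] * 6)
example_inputs = np.zeros((len(example_labels), 4))

fold_class_counts = []
splitter = StratifiedKFold(n_splits=3)
for train_idx, test_idx in splitter.split(example_inputs, example_labels):
    # Count how many flowers of each class ended up in this test fold
    fold_class_counts.append(Counter(example_labels[test_idx].tolist()))

print(fold_class_counts)
```

Each test fold contains the same number of flowers from every class, mirroring the class balance of the full example.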

Now we have a more consistent rating of our classifier's general classification accuracy:

In [None]:
plt.hist(cv_scores)
plt.title('Average score: {:.3f}'.format(np.mean(cv_scores)))
plt.xlabel('model accuracy')
plt.ylabel('number of folds')
plt.xlim(left, right) 
;

Alright! We finally have our demo classifier in a complete and reproducible machine learning pipeline. We've met the success criteria that we set from the beginning (>90% accuracy), and our pipeline is flexible enough to handle new inputs or flowers when that data set is ready.

<div class="alert alert-block alert-info">Question 4: Would you expect the classification accuracy to deteriorate considerably if we did not correct for the outliers in the data? Why / why not?</div>

## Sources

[[ go back to the top ]](#Table-of-contents)

Based on a notebook by [Randal S. Olson](http://www.randalolson.com/). 

Sources for pictures:
* iris-types.png: https://gadictos.com/iris-data-classification-using-neural-net/
* iris_petal_sepal.png: https://holgerbrandl.github.io/kotlin4ds_kotlin_night_frankfurt/krangl_example_report.html
* Diagram_K-fold_cross_validation.png: https://en.wikipedia.org/wiki/Cross-validation_(statistics)

### Python packages used

This notebook uses several standard Python packages. These are:

* **pandas**: Provides the "DataFrame" structure to store data in memory and work with it easily and efficiently. DataFrame is a 2-dimensional labeled data structure with columns of potentially different types; you can think of it like a spreadsheet.
* **matplotlib**: Basic plotting library in Python; most other Python plotting libraries are built on top of it.
* **Seaborn**: Advanced statistical plotting library.
* **scikit-learn**: The essential machine learning package in Python.
* **NumPy**: Provides a fast numerical array structure and helper functions.