# Decision Tree Classifier for IRIS dataset
Use a decision tree to classify iris flowers from the IRIS dataset. This example is based on https://www.youtube.com/playlist?list=PLOU2XLYxmsIIuiBfYad6rFYQU_jL2ryal

## Load IRIS dataset and extract a small test set
Load the IRIS dataset. SKLearn ships with this dataset built in, so we can load it directly.
The dataset contains three types of flowers, with 50 samples of each. The samples are stored in order of type, so taking rows 0, 50 and 100 gives us one flower of each type for the test set. The remaining 147 rows form the training set.

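```python
from sklearn.datasets import load_iris

import numpy as np

iris = load_iris()
# Rows are grouped by type (50 each), so rows 0, 50 and 100
# give one sample of each class.
test_idx = [0, 50, 100]

# Training set: every row except the three test rows
train_features = np.delete(iris.data, test_idx, axis=0)
train_target = np.delete(iris.target, test_idx)

# Test set: only the three held-out rows
test_features = iris.data[test_idx]
test_target = iris.target[test_idx]
```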
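
A quick check of the claim that the rows are grouped by type, using only NumPy. The block reloads the dataset so it runs on its own; the printed values follow from the layout described above (50 samples per class, stored in class order).

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Number of samples per class
class_counts = np.bincount(iris.target)
print(class_counts)  # [50 50 50]

# Labels never decrease, so each block of 50 rows is a single class
is_sorted = bool(np.all(np.diff(iris.target) >= 0))
print(is_sorted)  # True

# Rows 0, 50 and 100 therefore hold one sample of each class
print(iris.target[[0, 50, 100]])  # [0 1 2]
print(iris.target_names[iris.target[[0, 50, 100]]])  # ['setosa' 'versicolor' 'virginica']
```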

## Visualize the dataset

```python
import pandas as pd

df = pd.DataFrame(train_features, columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
df['target'] = train_target

df.head()
```

|   | sepal_length | sepal_width | petal_length | petal_width | target |
|---|---|---|---|---|---|
| 0 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 1 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 2 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 3 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
| 4 | 5.4 | 3.9 | 1.7 | 0.4 | 0 |
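
The first five rows all belong to class 0, so `df.head()` says little about how the features differ between types. A minimal sketch of a per-class summary with pandas `groupby`, rebuilding the same training DataFrame so it runs on its own; the `species` column is an addition for readability and is not part of the original notebook.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()
test_idx = [0, 50, 100]

# Rebuild the training DataFrame exactly as in the cell above
train_features = np.delete(iris.data, test_idx, axis=0)
train_target = np.delete(iris.target, test_idx)
feature_cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
df = pd.DataFrame(train_features, columns=feature_cols)
df['target'] = train_target

# Map each numeric label to its species name
df['species'] = iris.target_names[df['target'].to_numpy()]

# Samples per species: one of each went to the test set
counts = df['species'].value_counts()
print(counts)

# Mean of every feature per species
means = df.groupby('species')[feature_cols].mean()
print(means.round(2))
```

The petal measurements differ much more between species than the sepal measurements do, which is why a tree can separate the classes with only a few splits.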


## Create and train DecisionTreeClassifier from SKLearn
Use the DecisionTreeClassifier from SKLearn and train it using the training set.

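```python
from sklearn import tree

X = train_features
y = train_target

# Default hyperparameters: max_depth=None, so nodes keep splitting
# until the leaves are pure (or too small to split further)
clf = tree.DecisionTreeClassifier()
clf.fit(X, y)
```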

```
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')
```
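
The repr above comes from an older scikit-learn release; recent versions no longer have the `min_impurity_split` and `presort` parameters. A minimal sketch for inspecting the fitted tree's size and which features it relies on. Passing `random_state=0` is an addition: scikit-learn randomly permutes the features at each split, so without a fixed seed the tree can differ between runs when several splits are equally good.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
test_idx = [0, 50, 100]
X = np.delete(iris.data, test_idx, axis=0)
y = np.delete(iris.target, test_idx)

# random_state fixes the feature permutation so the tree is reproducible
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

depth = clf.get_depth()        # length of the longest root-to-leaf path
n_leaves = clf.get_n_leaves()  # number of terminal nodes
print(depth, n_leaves)

# Impurity-based importance of each feature; the values sum to 1
for name, importance in zip(iris.feature_names, clf.feature_importances_):
    print(f"{name}: {importance:.3f}")
```

With three classes, the tree needs at least two levels of splits and at least three leaves, so any depth or leaf count below that would indicate a problem with the data.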

## Validate it against our test set
Compare the predicted classes with the true labels of the three rows we set aside for testing.

```python
print(clf.predict(test_features))  # predicted classes
print(test_target)                 # true classes
```

```
[0 1 2]
[0 1 2]
```
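
Three test rows say little about how well the tree generalizes: a single mistake would drop accuracy by a third. A minimal sketch of two extra checks that are not part of the original notebook: `accuracy_score` on the three held-out rows, and 5-fold cross-validation with `cross_val_score` over the full dataset, which evaluates every sample exactly once. The `random_state=0` seed is an assumption added for reproducibility.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score

iris = load_iris()
test_idx = [0, 50, 100]
train_features = np.delete(iris.data, test_idx, axis=0)
train_target = np.delete(iris.target, test_idx)
test_features = iris.data[test_idx]
test_target = iris.target[test_idx]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(train_features, train_target)

# Predictions on the held-out rows, shown as species names
predictions = clf.predict(test_features)
print(iris.target_names[predictions])
test_accuracy = accuracy_score(test_target, predictions)
print(test_accuracy)

# For a classifier, cv=5 uses stratified folds, so every fold keeps
# the equal class balance even though the rows are sorted by class
scores = cross_val_score(tree.DecisionTreeClassifier(random_state=0),
                         iris.data, iris.target, cv=5)
print(scores.mean(), scores.std())
```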


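## Visualize the generated decision tree

This step needs the `pydotplus` package and the Graphviz `dot` executable, which `pydotplus` calls to render the image.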

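```python
from ipywidgets import Image

import pydotplus
from sklearn.tree import export_graphviz

# With out_file=None, export_graphviz returns the DOT source as a string,
# so no StringIO buffer is needed (sklearn.externals.six, which provided
# the old StringIO import, was removed in scikit-learn 0.23)
dot_data = export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,     # color each node by its majority class
    rounded=True,    # draw nodes as rounded boxes
    impurity=False,  # hide the Gini impurity in each node
)

# Render the DOT source to PNG bytes and display them as a widget
graph = pydotplus.graph_from_dot_data(dot_data)
Image(value=graph.create(format='png'))
```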

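
If Graphviz is not installed, scikit-learn (0.21 and later) can also print the tree as plain text with `sklearn.tree.export_text`. A minimal sketch that refits the same tree, with `random_state=0` added as an assumption so the printed tree is reproducible.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import export_text

iris = load_iris()
test_idx = [0, 50, 100]
X = np.delete(iris.data, test_idx, axis=0)
y = np.delete(iris.target, test_idx)

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

# One line per node: split nodes show "<feature> <= <threshold>",
# leaves show "class: <label>"
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

The text form carries the same split thresholds as the Graphviz image, which makes it easy to diff between runs or paste into a log.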