# A Decision Tree Classification Example

There is a famous small data set from 1936 called the *Iris* flower data set, or Fisher's *Iris* data set, that is often used as a classification example. It contains only 150 rows of data: 50 samples from each of three species of *Iris*. There are four features (variables/columns), all flower measurements: sepal length, sepal width, petal length, and petal width. We'll use this classic example to see how you could use a decision tree to make a multi-class classification decision.
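As a quick check of that description, here is a minimal sketch that counts the rows, features, and samples per species. It assumes scikit-learn's bundled copy of the data (`sklearn.datasets.load_iris`) instead of the seaborn copy loaded below, so no download is needed. Note that its column names differ from seaborn's (for example `'sepal length (cm)'` rather than `'sepal_length'`).

```python
from sklearn.datasets import load_iris

# Assumption: scikit-learn's bundled Iris stands in for seaborn's copy here.
bunch = load_iris(as_frame=True)
features = bunch.data  # 150 x 4 DataFrame of flower measurements

# Map the integer target codes (0, 1, 2) to species names as plain Python strings.
code_to_name = {i: str(name) for i, name in enumerate(bunch.target_names)}
species = bunch.target.map(code_to_name)

n_rows, n_features = features.shape
class_counts = species.value_counts().sort_index()

print(n_rows, n_features)             # 150 4
print(class_counts.index.tolist())    # ['setosa', 'versicolor', 'virginica']
print(class_counts.tolist())          # [50, 50, 50]
```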

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

from sklearn.tree import DecisionTreeClassifier, plot_tree

from sklearn.metrics import classification_report
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Get the data set
iris = sns.load_dataset('iris')

In [None]:
# Look at it
iris.info()

In [None]:
# Take a peek
iris.head()

In [None]:
# Sample it
iris.sample(5)

In [None]:
sns.countplot(x='species', data=iris)

## Create $X$ and $y$ and Split Data

We now create our $X$ and $y$ variables. We are trying to predict `species`, so that is our $y$ variable. We will then split the data into training and test sets.
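Passing `stratify=y` to `train_test_split` keeps the class proportions the same in both sets. With 50 flowers per species and `test_size=0.3`, the 45-row test set should hold 15 of each species and the 105-row training set 35 of each. The sketch below checks this; it assumes scikit-learn's bundled `load_iris` in place of the seaborn data, and the `_demo`/`_tr`/`_te` names are illustrative so they do not overwrite the notebook's variables.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: scikit-learn's bundled Iris stands in for the seaborn copy.
X_demo, y_demo = load_iris(return_X_y=True, as_frame=True)

# Same settings as the notebook: 30% held out, stratified on the target.
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3,
                                          stratify=y_demo, random_state=42)

print(len(X_tr), len(X_te))                        # 105 45
print(y_tr.value_counts().sort_index().tolist())   # [35, 35, 35]
print(y_te.value_counts().sort_index().tolist())   # [15, 15, 15]
```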

In [None]:
# X is everything but species
X = iris.drop(columns=['species'])

# y is species
y = iris.species

In [None]:
# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.3,
                                                    stratify=y,
                                                    random_state=42)

X_train.info()

In [None]:
# Plot y_test to make sure stratification occurred
sns.countplot(x=y_test)

In [None]:
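# Create and fit DecisionTreeClassifier on training data
# random_state fixes how ties between equally good splits are broken, so the tree is reproducible
dt = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)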

In [None]:
# Plot the decision tree
fig = plt.figure(figsize=(25, 20))
# class_names must follow the order of dt.classes_ (sorted labels), not value_counts()
plot_tree(dt, feature_names=list(X.columns), filled=True,
          rounded=True, class_names=list(dt.classes_));
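Each box in the plotted tree reports a `gini` value, the Gini impurity of the training samples that reach that node: $G = 1 - \sum_k p_k^2$, where $p_k$ is the share of class $k$ at the node. Because the stratified training set holds 35 flowers of each species, the root should show $1 - 3(1/3)^2 \approx 0.667$. The sketch below computes this by hand and compares it with the fitted tree's stored value (`tree_.impurity`). It assumes scikit-learn's bundled `load_iris`, and `gini` is a hypothetical helper written for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions (hypothetical helper)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

# Assumption: scikit-learn's bundled Iris stands in for the seaborn copy.
X_demo, y_demo = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3,
                                          stratify=y_demo, random_state=42)
tree = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)

root_gini = gini(y_tr)                   # 35 of each class at the root
print(round(root_gini, 3))               # 0.667
print(round(tree.tree_.impurity[0], 3))  # 0.667
```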

In [None]:
# Print the classification report for test set
print(classification_report(y_test, dt.predict(X_test)))

In [None]:
# Plot the confusion matrix
ConfusionMatrixDisplay.from_estimator(dt, X_test, y_test, cmap='binary')
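The per-class precision and recall in the classification report can be read straight off the confusion matrix, whose rows are true classes and columns are predicted classes: precision divides each diagonal entry by its column total, and recall divides it by its row total. This sketch checks that arithmetic against `precision_score` and `recall_score`; it assumes scikit-learn's bundled `load_iris` and uses illustrative variable names.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Assumption: scikit-learn's bundled Iris stands in for the seaborn copy.
X_demo, y_demo = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3,
                                          stratify=y_demo, random_state=42)
y_pred = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr).predict(X_te)

cm = confusion_matrix(y_te, y_pred)   # rows: true class, columns: predicted class
correct = np.diag(cm)
precision = correct / cm.sum(axis=0)  # correct / everything predicted as that class
recall = correct / cm.sum(axis=1)     # correct / everything truly in that class

print(np.allclose(precision, precision_score(y_te, y_pred, average=None)))  # True
print(np.allclose(recall, recall_score(y_te, y_pred, average=None)))        # True
```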

## Important Features

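One of the nice things about decision trees is that you can get a sense of each feature's (variable's) importance by seeing how close to the root node it is used for a split. Features that split near the root node (top) generally matter more to the model, because early splits act on the most samples. Additionally, you can pull out the fitted model's `feature_importances_` attribute to see the importances explicitly or plot them. In scikit-learn these are impurity-based: each feature's total weighted reduction in Gini impurity across all of its splits, normalized so that the importances sum to 1.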
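To make the impurity-based definition concrete, this sketch rebuilds `feature_importances_` from the fitted tree's internal arrays (`tree_.feature`, `tree_.impurity`, `tree_.weighted_n_node_samples`, `tree_.children_left`, `tree_.children_right`): for every internal node it adds the node's weighted impurity minus its children's weighted impurities to the split feature, then normalizes. It also prints the root split's feature. It assumes scikit-learn's bundled `load_iris`, whose feature names differ from seaborn's column names.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: scikit-learn's bundled Iris stands in for the seaborn copy.
data = load_iris()
X_tr, X_te, y_tr, y_te = train_test_split(data.data, data.target, test_size=0.3,
                                          stratify=data.target, random_state=42)
clf = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)
t = clf.tree_

# Node 0 is the root; tree_.feature holds the index of the feature it splits on.
print("root split on:", data.feature_names[t.feature[0]])

w = t.weighted_n_node_samples
importances = np.zeros(X_tr.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaves have no children; scikit-learn marks them with -1
        continue
    # Weighted impurity decrease produced by this node's split
    decrease = (w[node] * t.impurity[node]
                - w[left] * t.impurity[left]
                - w[right] * t.impurity[right])
    importances[t.feature[node]] += decrease

importances /= importances.sum()  # normalize so the importances sum to 1

print(np.allclose(importances, clf.feature_importances_))  # True
```

The root split is usually, though not always, the largest single contributor, which is why "closer to the root" works as a rule of thumb rather than a guarantee.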

Let's try it.

In [None]:
# Just see what it looks like
dt.feature_importances_

In [None]:
# Now, create a nice bar plot with most important feature on top
importance = pd.DataFrame({'Importance': dt.feature_importances_ * 100},
                          index=X.columns)
# barh draws the first row at the bottom, so sort ascending to put the largest on top
importance.sort_values('Importance', ascending=True).plot(kind='barh', legend=False)
plt.xlabel('Variable Importance (%)')

**&copy; 2023 - Present: Matthew D. Dean, Ph.D.   
Clinical Associate Professor of Business Analytics at William \& Mary.**