# Decision tree
- Iris dataset
- 2 features
- Multiclass

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from PlotFunction import plot_decision_surface_train_test
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

In [None]:
import os

# Folder (inside the current working directory) that receives the figures
# written by the plotting cells further down.
imagePath = os.path.join(os.getcwd(), 'Images')
# exist_ok=True makes the cell safe to re-run and avoids the race between a
# separate "does it exist?" check and the directory creation.
os.makedirs(imagePath, exist_ok=True)

In [None]:
# Read data: load the Iris dataset through scikit-learn's `datasets` module.
# The cells below read its `data`, `target`, `feature_names` and
# `target_names` attributes.
iris = datasets.load_iris()

In [None]:
# Keep only feature columns 2 and 3 (the last two columns of the feature
# matrix) so the model works with two features; labels are taken unchanged.
feature_cols = slice(2, 4)
X = iris.data[:, feature_cols]
y = iris.target

In [None]:
# Split data into training and testing data: 30% is held out for testing,
# the shuffle is seeded, and stratifying on y is requested for the split.
split_options = {"test_size": 0.3, "random_state": 1, "stratify": y}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

In [None]:
# Standardization: the scaler is fitted on the training split only, and that
# same fitted transform is then applied to both the training and test splits.
sc = StandardScaler().fit(X_train)
X_train_std, X_test_std = (sc.transform(split) for split in (X_train, X_test))

In [None]:
# Name of the hyper-parameter preset used for this run; it is also embedded
# in the output file names further down.
param = "ex2"

# Available presets: every combination of split criterion (gini / entropy)
# and maximum depth (4 / 8), all with min_samples_split=2.
paramSetAll = dict(
    ex1=dict(criterion="gini", max_depth=4, min_samples_split=2),
    ex2=dict(criterion="entropy", max_depth=4, min_samples_split=2),
    ex3=dict(criterion="gini", max_depth=8, min_samples_split=2),
    ex4=dict(criterion="entropy", max_depth=8, min_samples_split=2),
)

# Keyword arguments handed to the classifier constructor.
paramValue = paramSetAll[param]

In [None]:
# Creating objects: build the decision tree from the selected preset.
# A fixed random_state is supplied because scikit-learn permutes the features
# at every split; without a seed, equally good splits can be resolved
# differently from run to run, so the fitted tree (and its plots) may change.
# The seed is placed first so that a preset defining its own random_state
# overrides it instead of raising a duplicate-keyword error.
tree_model = DecisionTreeClassifier(**{"random_state": 1, **paramValue})

In [None]:
# Training: fit the tree on the standardized training features and labels.
tree_model.fit(X=X_train_std, y=y_train)

In [None]:
# Prediction: class labels predicted for the standardized test features.
y_pred = tree_model.predict(X=X_test_std)

In [None]:
# Misclassification from the test samples: flag every position where the
# prediction differs from the true label, then count the flags.
mismatch_mask = y_test != y_pred
sumMiss = mismatch_mask.sum()

In [None]:
# Accuracy score from the test samples (true labels vs. predicted labels).
accuracyScore = accuracy_score(y_true=y_test, y_pred=y_pred)

In [None]:
# Report both test-set metrics, one per line.
for report_line in (
    f"Misclassified examples: {sumMiss}",
    f"Accuracy score: {accuracyScore}",
):
    print(report_line)

In [None]:
# Output path for the decision-surface figure; the preset name is part of the
# file name so different presets do not overwrite each other.
filenamePNG = f"Images/T41_DT_{param}.png"

# Project helper from PlotFunction; it receives both standardized splits,
# their labels, the fitted model and the target file name.
plot_decision_surface_train_test(
    X_train_std,
    X_test_std,
    y_train,
    y_test,
    tree_model,
    filename=filenamePNG,
)

In [None]:
# Create label names: the model was trained on standardized features, so the
# "(cm)" unit in the dataset's feature names is replaced by "(scaled)".
fn = [st.replace("(cm)", "(scaled)") for st in iris.feature_names]
print(fn)

# Visualization: Plot tree
fig, ax = plt.subplots(1, figsize=(5, 5))
tree.plot_tree(
    tree_model,
    # Same column range (2:4) that was used to select the features for X.
    feature_names=fn[2:4],
    # Passed as a plain list rather than the raw dataset attribute.
    # NOTE(review): scikit-learn 1.3.0 reportedly rejected non-list
    # class_names - confirm against the installed version.
    class_names=list(iris.target_names),
    filled=True,
    # Draw on the axes created above instead of relying on pyplot's implicit
    # "current axes".
    ax=ax,
)
# The preset name is part of the file name, one PDF per preset.
filenamePDF = "Images/T41_tree_visualize_1_" + param + ".pdf"
fig.savefig(filenamePDF)