Hyperparameter Testing of Decision Tree

Import libraries

In [1]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

Load dataset



In [2]:
iris = load_iris()
X, y = iris.data, iris.target

Train-test split

In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

Define different hyperparameter settings

In [4]:
params = [
    {"max_depth": 2, "min_samples_split": 2, "min_samples_leaf": 1},
    {"max_depth": 4, "min_samples_split": 2, "min_samples_leaf": 1},
    {"max_depth": 6, "min_samples_split": 5, "min_samples_leaf": 2},
    {"max_depth": None, "min_samples_split": 2, "min_samples_leaf": 1}, # fully grown
]

Train and evaluate models

In [11]:
results = []
for p in params:
    clf = DecisionTreeClassifier(
        criterion="gini",
        max_depth=p["max_depth"],
        min_samples_split=p["min_samples_split"],
        min_samples_leaf=p["min_samples_leaf"],
        random_state=42
    )
    clf.fit(X_train, y_train)

    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))

    results.append({
        "max_depth": p["max_depth"],
        "min_samples_split": p["min_samples_split"],
        "min_samples_leaf": p["min_samples_leaf"],
        "Train Accuracy": train_acc,
        "Test Accuracy": test_acc
    })

Show results as table

In [12]:
df_results = pd.DataFrame(results)
print(df_results)

   max_depth  min_samples_split  min_samples_leaf  Train Accuracy  \
0        2.0                  2                 1        0.942857   
1        4.0                  2                 1        0.971429   
2        6.0                  5                 2        0.971429   
3        NaN                  2                 1        1.000000   

   Test Accuracy  
0       0.977778  
1       1.000000  
2       1.000000  
3       1.000000  
