In [18]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score
from sklearn.metrics import accuracy_score
from HHCART import HouseHolderCART  # oblique tree classifier
from segmentor import Gini, TotalSegmentor  # module to determine splits
import numpy as np
import itertools, time

In [16]:
# Example data: the Iris dataset, with 20% of the samples held out for testing.
X, y = load_iris(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1
)

# Build the HHCART classifier.
#   max_depth:   maximum depth of the decision tree
#   min_samples: minimum allowed number of samples in a terminal node
sgmtr = TotalSegmentor()
HHTree = HouseHolderCART(
    impurity=Gini(),
    segmentor=sgmtr,
    max_depth=5,
    min_samples=4,
)

# Fit the tree on the training split.
HHTree.fit(x_train, y_train)

# Report accuracy on both the training split and the held-out test split.
train_predictions = HHTree.predict(x_train)
train_score = accuracy_score(y_train, train_predictions)
test_predictions = HHTree.predict(x_test)
test_score = accuracy_score(y_test, test_predictions)
print(f"train accuracy: {train_score:.00%}")
print(f"test accuracy: {test_score:.00%}")

train accuracy: 99%
test accuracy: 93%


In [25]:
# Tune an HHCART hyperparameter with cross-validation on the training split.
# Here we tune min_samples, the minimum allowed number of samples in a
# terminal node; every candidate is scored on the same stratified folds.
min_sample_leaves = [2, 4, 6, 8, 10]
scores_by_leaves = np.zeros(len(min_sample_leaves))  # mean cv score per candidate value

# NOTE(review): train_ratio / test_ratio are not read anywhere in this cell;
# kept only in case other cells rely on them.
train_ratio = 0.8
test_ratio = 0.2

# 5-fold cv that preserves the class distribution. The splitter is built once:
# with a fixed integer random_state it yields the same folds on every call, so
# all candidates are compared on identical splits.
CV_SEED = 0
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=CV_SEED)

for index, candidate_min_samples in enumerate(min_sample_leaves):
    # A fresh segmentor and tree per candidate so no state is shared between runs.
    sgmtr = TotalSegmentor()
    HHTree = HouseHolderCART(impurity=Gini(), segmentor=sgmtr, max_depth=5,
                             min_samples=candidate_min_samples)
    scores = cross_val_score(HHTree, x_train, y_train, cv=cv)
    scores_by_leaves[index] = np.mean(scores)

# np.argmax returns the first maximum, so ties resolve to the smallest candidate.
best_min_sample_leaves = min_sample_leaves[np.argmax(scores_by_leaves)]
print(f"min_sample_leaves value with the highest cv score: {best_min_sample_leaves}")

min_sample_leaves value with the highest cv score: 6
