In this code, we load the well-known Iris dataset, which contains 150 samples from three iris species: Setosa, Versicolor, and Virginica (50 of each). Each sample is described by four features: sepal length, sepal width, petal length, and petal width, all in centimeters. We then train a Decision Tree Classifier. We limit it to a maximum depth of 3 to reduce the risk of overfitting and to keep its rules short enough to read, and we fix random_state so the fitted tree is reproducible. Once trained, the classifier can predict the species of new iris samples from the relationships it learned in the data.

In [5]:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


# Load the Iris dataset
data = load_iris()
X, y = data.data, data.target
feature_names = data.feature_names
target_names = data.target_names

# Train a Decision Tree Classifier
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X, y)

DecisionTreeClassifier(max_depth=3, random_state=42)
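The cell above fits the tree on all 150 samples, so it does not show how well the depth limit actually guards against overfitting. As a minimal sketch (not part of the original notebook), one way to check is to hold out part of the data with scikit-learn's train_test_split, score the tree on samples it never saw, and then call predict on new measurements. The 30% test size and the variable names here are illustrative assumptions.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X, y = data.data, data.target

# Hold out 30% of the samples, keeping the class proportions equal in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# Accuracy on samples the tree never saw during training
test_accuracy = clf.score(X_test, y_test)
print(f"Held-out accuracy: {test_accuracy:.3f}")

# Predict the species of two measurements
# (sepal length, sepal width, petal length, petal width, all in cm)
new_samples = [[5.1, 3.5, 1.4, 0.2], [6.3, 3.3, 6.0, 2.5]]
predicted = [data.target_names[i] for i in clf.predict(new_samples)]
print(predicted)
```

A held-out score close to the training score suggests the depth-3 limit keeps the tree from memorizing the training set; a large gap would point the other way.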

In this code, we use the TreeSplanerClassifier class to translate the decision rules of the trained tree into natural language. We first initialize TreeSplanerClassifier with the classifier and pass the feature and target names, so that the generated text can refer to named measurements and species. The decision_tree_to_text() method then produces a readable description of the decision rules, listing the conditions under which each class is predicted. Next, build_text_prediction() explains, for specific sample inputs, the decision path each one follows and the resulting prediction. Finally, branch_impurity() reports impurity metrics for each decision path. The lower the impurity at a node, the more homogeneous the training samples that reach it, and the more certain the classification there. Together, these methods make the inner workings of the decision tree more transparent and easier to interpret.
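TreeSplanerClassifier's own implementation is not shown here. However, scikit-learn stores the information its three methods describe in the fitted clf.tree_ object, as arrays indexed by node id: feature, threshold, children_left, children_right, n_node_samples, value and impurity. The sketch below illustrates reading that structure directly. It is not TreeSplaner's code, and the helper names is_leaf and leaf_rules are hypothetical. One detail matters for any rule extractor. scikit-learn marks leaves with -1 in children_left and children_right and stores the placeholder -2 in feature and threshold, so leaves must be detected and skipped rather than turned into conditions. Conditions such as "petal length (cm) <= -2.0" in the rules printed by the next cell look like that placeholder being read as a split, since index -2 of the four feature names is petal length (cm).

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Refit the same tree as in the notebook so this block runs on its own
data = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(data.data, data.target)
tree = clf.tree_


def is_leaf(node):
    # scikit-learn stores -1 in both child arrays for leaf nodes
    return tree.children_left[node] == tree.children_right[node]


def leaf_rules(node=0, conditions=()):
    """Yield (conditions, class name, sample count, Gini impurity) for each leaf under node."""
    if is_leaf(node):
        # value[node][0] holds per-class weights; the largest one is the predicted class
        class_name = str(data.target_names[tree.value[node][0].argmax()])
        yield list(conditions), class_name, int(tree.n_node_samples[node]), float(tree.impurity[node])
        return
    # Only internal nodes carry a real feature index and threshold
    name = data.feature_names[tree.feature[node]]
    threshold = tree.threshold[node]
    # Samples with feature <= threshold go to the left child, the rest to the right
    yield from leaf_rules(tree.children_left[node], conditions + (f"{name} <= {threshold:.2f}",))
    yield from leaf_rules(tree.children_right[node], conditions + (f"{name} > {threshold:.2f}",))


rules = list(leaf_rules())
for conds, cls, n_samples, gini in rules:
    print(f"If {' and '.join(conds)} then {cls} ({n_samples} samples, gini={gini:.3f})")

# Route of a single sample: node ids from the root down to its leaf
sample = [[6.3, 3.3, 6.0, 2.5]]
visited = clf.decision_path(sample).indices
leaf = clf.apply(sample)[0]
print("visited nodes:", visited.tolist(), "leaf gini:", round(float(tree.impurity[leaf]), 3))
```

Each printed rule pairs a leaf's full condition path with its majority class, sample count and Gini impurity, which roughly corresponds to what decision_tree_to_text() and branch_impurity() describe. decision_path() and apply() give the per-sample route that build_text_prediction() explains in prose.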

In [6]:
from tree_splaner.tree_classifier_explainer import TreeSplanerClassifier

# Wrap the trained classifier in TreeSplanerClassifier, which translates a fitted
# decision tree's structure and predictions into natural language.
# Feature and target names let the generated text use readable labels.
tree_splaner = TreeSplanerClassifier(clf, feature_names=feature_names, target_names=target_names)

# Describe the tree's decision rules in natural language,
# including the conditions under which each class is predicted
decision_text = tree_splaner.decision_tree_to_text()
print("Decision Tree Rules in Natural Language:\n", decision_text)

# Explain the predictions for specific samples, including the path each one takes
# through the tree (these are rows 0 and 100 of the dataset: a Setosa and a Virginica flower)
sample_predictions = tree_splaner.build_text_prediction(samples=[[5.1, 3.5, 1.4, 0.2], [6.3, 3.3, 6.0, 2.5]])
print("\nSample Predictions:\n", sample_predictions)

# Report impurity for each decision path; lower impurity means a purer branch
# and therefore a more certain classification at that node
branch_impurity_info = tree_splaner.branch_impurity()
print("\nBranch Impurity Information:\n", branch_impurity_info)


Decision Tree Rules in Natural Language:
 If  (petal length (cm) <= 2.45 and petal length (cm) <= -2.0 then class is setosa  with probability of 50.0 ) or (petal length (cm) > 2.45 and petal width (cm) > 1.75 and petal length (cm) <= 4.95 and petal length (cm) <= -2.0 then class is versicolor  with probability of 47.0 ) or (petal length (cm) > 2.45 and petal width (cm) > 1.75 and petal length (cm) <= 4.95 and petal length (cm) > -2.0 then class is virginica  with probability of 4.0 ) or (petal length (cm) > 2.45 and petal width (cm) > 1.75 and petal length (cm) > 4.85 and petal length (cm) <= -2.0 then class is virginica  with probability of 2.0 ) or (petal length (cm) > 2.45 and petal width (cm) > 1.75 and petal length (cm) > 4.85 and petal length (cm) > -2.0 then class is virginica  with probability of 43.0 )

Sample Predictions:
 ['petal length (cm) <= 2.45 and petal length (cm) <= -2.0 therefore the predicted class is setosa with probability of 1.0', 'petal length (cm) <= 2.45 and 