In [None]:
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, classification_report

In [None]:
# Load the Iris Bunch once. `data` is kept as a name because later cells read
# `data.feature_names`. Features, labels and label names are unpacked in one step.
data = load_iris()
X, y, class_names = data.data, data.target, data.target_names

In [None]:
# Hold out 20% of the samples for testing. stratify=y makes the class proportions
# in both splits match the full dataset as closely as possible, so the small test
# set is not accidentally skewed toward one class. random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Sanity check: every sample lands in exactly one of the two splits.
assert len(X_train) + len(X_test) == len(X)

In [None]:
# A shallow tree (max_depth=3) stays small enough to read when plotted and is
# less prone to overfitting than an unrestricted tree.
# criterion='gini' is the scikit-learn default; 'entropy' is an alternative to compare.
# random_state is fixed so the fitted tree is reproducible across runs.
clf = DecisionTreeClassifier(
    max_depth=3,
    criterion='gini',
    random_state=42,
).fit(X_train, y_train)  # fit() returns the estimator itself
clf

In [None]:
# Predict class labels for the held-out test features.
y_pred = clf.predict(X=X_test)

In [None]:
# Overall accuracy (shown as a percentage), then per-class precision/recall/F1.
test_accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {test_accuracy * 100:.2f}%")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))

In [None]:
# Draw the fitted tree on an explicit Axes so the figure size, the tree and the
# title are all tied to one object, rather than relying on pyplot's implicit
# "current axes" state.
fig, ax = plt.subplots(figsize=(12, 8))
plot_tree(
    clf,
    feature_names=data.feature_names,
    class_names=class_names,
    filled=True,   # colour nodes by their majority class
    rounded=True,
    ax=ax,
)
ax.set_title("Decision Tree Visualization")
plt.show()

In [None]:
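# HOW TO READ THE TREE:
# 1. Start at the top box (the root node). Its first line is a split test
#    (e.g., "petal length (cm) <= 2.45").
# 2. If the test is True, follow the left branch; if False, follow the right branch.
# 3. 'gini' is the Gini impurity at that node (0 means every sample there belongs to one class).
# 4. 'samples' is the number of training samples that reach that node.
# 5. 'value' shows how many of those training samples belong to each class.
# 6. 'class' is the majority class at that node, i.e. the prediction if you stop there.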