# 🌲 Decision Tree Classifier with Post-Pruning (Cost Complexity Pruning)

In [None]:
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Load scikit-learn's bundled breast-cancer dataset and wrap the raw arrays
# in pandas containers: a feature DataFrame (named columns) and a label Series.
data = load_breast_cancer()
X = pd.DataFrame(data=data.data, columns=data.feature_names)
y = pd.Series(data=data.target)

In [None]:
# Hold out 30% of the rows as a test set; the fixed seed makes the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

In [None]:
# Compute the cost-complexity pruning path of a fully grown tree.
# cost_complexity_pruning_path() clones and fits its own copy of the
# estimator, so the template model does not need to be fitted beforehand.
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
# Drop the largest effective alpha: it prunes the tree down to a single root
# node, which is not a useful candidate model.
ccp_alphas = path.ccp_alphas[:-1]

# Fit one pruned tree per candidate alpha (same seed as the template model).
# fit() returns the estimator itself, so the models can be built in one pass.
clfs = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    for alpha in ccp_alphas
]

# Accuracy plots: how train/test accuracy change as pruning strength grows.
train_scores = [model.score(X_train, y_train) for model in clfs]
test_scores = [model.score(X_test, y_test) for model in clfs]

plt.figure(figsize=(10, 6))
plt.plot(ccp_alphas, train_scores, marker='o', label='Train Accuracy', drawstyle="steps-post")
plt.plot(ccp_alphas, test_scores, marker='o', label='Test Accuracy', drawstyle="steps-post")
plt.xlabel("ccp_alpha")
plt.ylabel("Accuracy")
plt.title("Accuracy vs ccp_alpha")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Best model: choose ccp_alpha by 5-fold cross-validation on the TRAINING set.
# Selecting the alpha with the highest test accuracy would leak the test set
# into model selection and make the accuracy reported below optimistic; the
# test set is used exactly once, for the final evaluation.
cv_scores = [
    cross_val_score(
        DecisionTreeClassifier(random_state=0, ccp_alpha=alpha),
        X_train,
        y_train,
        cv=5,
    ).mean()
    for alpha in ccp_alphas
]
best_cv = max(cv_scores)
# Among equally good candidates prefer the largest alpha, i.e. the simplest
# (most heavily pruned) tree.
best_index = max(i for i, score in enumerate(cv_scores) if score == best_cv)
best_alpha = ccp_alphas[best_index]
# clfs[i] was trained with ccp_alphas[i] on the full training set, so it can be
# evaluated directly without refitting.
best_clf = clfs[best_index]
y_pred_post = best_clf.predict(X_test)

print(f"Best ccp_alpha: {best_alpha}")
print(f"Mean CV accuracy: {best_cv:.4f}")
print("Accuracy:", accuracy_score(y_test, y_pred_post))
print("Classification Report:\n", classification_report(y_test, y_pred_post))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred_post))

In [None]:
# Visualize post-pruned tree.
# feature_names / target_names are NumPy arrays; plot_tree documents these
# parameters as lists of str (and some scikit-learn releases reject ndarrays),
# so convert them explicitly.
plt.figure(figsize=(16, 8))
plot_tree(
    best_clf,
    filled=True,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
)
plt.title(f"Decision Tree (Post-Pruned, ccp_alpha={best_alpha:.5f})")
plt.show()