# Decision Tree Classifier Demo
This notebook walks through training a CART-style decision tree classifier on the Iris dataset.

## Intuition
A decision tree partitions the feature space with axis-aligned splits. At each node it chooses the feature and threshold that most reduce impurity (Gini or entropy) in the resulting child nodes. Constraints such as a maximum depth or a minimum number of samples per node help limit overfitting, although they do not rule it out.
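
To make the split search concrete, here is a minimal pure-Python sketch of how one feature's threshold could be chosen by minimizing weighted Gini impurity. The helper names (`gini`, `best_threshold`) and the toy measurements are hypothetical and only for illustration; scikit-learn's actual implementation is optimized and differs in detail.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def best_threshold(values, labels):
    """Return (threshold, weighted_gini) for the best split `value <= threshold`.

    Candidate thresholds are midpoints between consecutive distinct values.
    If no split improves on the unsplit node, the threshold is None.
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best = (None, gini(list(labels)))
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # no threshold separates identical values
        left = [lab for _, lab in pairs[:i]]
        right = [lab for _, lab in pairs[i:]]
        # Impurity of the children, weighted by how many samples each receives
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best[1]:
            best = ((lo + hi) / 2, score)
    return best


# Hypothetical toy data: six petal lengths from two classes
petal_lengths = [1.4, 1.3, 1.5, 4.7, 4.5, 5.1]
species = [0, 0, 0, 1, 1, 1]
threshold, impurity = best_threshold(petal_lengths, species)
print(threshold, impurity)
```

A full tree builder would repeat this search over every feature at every node and recurse into the two children until a stopping rule applies.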


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, ConfusionMatrixDisplay
from sklearn.decomposition import PCA

sns.set_theme(style="whitegrid")


In [None]:
iris = load_iris(as_frame=True)
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

tree = DecisionTreeClassifier(criterion="gini", max_depth=None, random_state=42)
tree.fit(X_train, y_train)
y_pred = tree.predict(X_test)
print(classification_report(y_test, y_pred, target_names=target_names))


### Gini vs Entropy
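- **Gini impurity** is the probability that a sample drawn at random from a node would be mislabeled if it were assigned a class at random according to the node's class distribution.
- **Entropy** measures the uncertainty of a node's class distribution; information gain is the reduction in entropy that a split achieves.

Both criteria typically yield similar trees; Gini is slightly cheaper to compute because it avoids logarithms.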
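
As a small worked example of the two criteria, the sketch below computes both from per-class sample counts using the standard definitions, Gini = 1 - sum(p_k^2) and entropy = sum(p_k * log2(1 / p_k)). The function names and the example counts are illustrative assumptions, not part of scikit-learn.

```python
import math


def gini(counts):
    """Gini impurity of a node, given per-class sample counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def entropy(counts):
    """Shannon entropy in bits of a node, given per-class sample counts."""
    total = sum(counts)
    # Classes with zero samples contribute nothing and are skipped
    return sum((c / total) * math.log2(total / c) for c in counts if c > 0)


# A pure node, a two-class 50/50 node, and a three-class uniform node
for counts in ([50, 0, 0], [25, 25, 0], [40, 40, 40]):
    print(counts, round(gini(counts), 3), round(entropy(counts), 3))
```

Both measures are zero for a pure node and largest when the classes are evenly mixed, which is why they tend to rank candidate splits similarly.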


In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
plot_tree(tree, feature_names=feature_names, class_names=target_names, filled=True, ax=ax)
plt.title("Decision Tree Structure")
plt.show()


### Overfitting Demonstration
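An unrestricted tree (`max_depth=None`) keeps splitting until its leaves are pure, which lets it memorize the training data. Try adjusting `max_depth` (e.g., 2 vs None) and compare performance. Shallower trees often generalize better, but a tree that is too shallow can underfit, and on a small, clean dataset like Iris the difference may be slight.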


In [None]:
shallow_tree = DecisionTreeClassifier(max_depth=2, random_state=42)
shallow_tree.fit(X_train, y_train)
print("Depth 2 accuracy:", shallow_tree.score(X_test, y_test))
print("Full depth accuracy:", tree.score(X_test, y_test))
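
Overfitting shows up as a gap between training and test accuracy rather than in test accuracy alone. The sketch below is one way to look at that gap across a few depths; it rebuilds the same split as above so it can run on its own, and the depth values and the `results` dictionary are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Map each max_depth setting to (train accuracy, test accuracy)
results = {}
for depth in (1, 2, 3, None):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    results[depth] = (model.score(X_train, y_train), model.score(X_test, y_test))
    print(f"max_depth={depth}: train={results[depth][0]:.3f}, test={results[depth][1]:.3f}")
```

With only 30 test samples, each misclassified flower shifts test accuracy by about 3.3 percentage points, so small differences between depths should be read with caution.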


### Feature Importance
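A fitted tree scores each feature by the total impurity reduction of the splits that use it, weighted by the number of samples reaching each split and normalized so the scores sum to 1. Features the tree never splits on receive an importance of 0.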


In [None]:
importances = pd.Series(tree.feature_importances_, index=feature_names)
importances.sort_values().plot(kind="barh", figsize=(6, 4), title="Feature Importance")
plt.show()
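
To show where these numbers come from, the following sketch recomputes the importances by hand from the fitted tree's `tree_` arrays and compares them with `feature_importances_`. It fits its own tree on the full dataset so it runs standalone; the variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=42).fit(X, y)
t = clf.tree_

manual = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaves have no children and make no split
        continue
    # Sample-weighted impurity of the node minus that of its two children
    decrease = (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right]
    )
    manual[t.feature[node]] += decrease

manual /= manual.sum()  # normalize so the importances sum to 1
print(np.allclose(manual, clf.feature_importances_))
```

Because the scores are derived from the training data alone, they describe how the tree was built rather than how much each feature contributes to accuracy on unseen data.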


### PCA Visualization
To view the four-dimensional data in 2D, project it onto its first two principal components with PCA and color each point by class.


In [None]:
pca = PCA(n_components=2, random_state=42)
components = pca.fit_transform(X)
plt.figure(figsize=(7,5))
scatter = plt.scatter(components[:,0], components[:,1], c=y, cmap="viridis", edgecolor='k')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title('PCA Scatter of Iris')
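# Label the legend with species names instead of the integer class codes
handles, _ = scatter.legend_elements()
plt.legend(handles, target_names, title="Classes")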
plt.show()
