# 决策树分类：训练与可视化 (Decision Tree Classification)

本 notebook 系统性地介绍决策树分类器的训练、可视化与评估，涵盖以下核心内容：

1. **CART 算法原理**：二叉树结构与分裂准则
2. **决策边界可视化**：直观理解轴平行分裂特性
3. **分裂准则对比**：Gini 不纯度 vs 信息熵
4. **正则化与过拟合**：超参数对模型复杂度的影响
5. **模型评估**：混淆矩阵、分类报告与交叉验证

---

## 核心知识点

- **CART 算法**：Classification And Regression Trees，始终构建二叉树
- **分裂准则**：Gini 不纯度 $G(t) = 1 - \sum_k p_k^2$，信息熵 $H(t) = -\sum_k p_k \log_2 p_k$
- **轴平行分裂**：每次分裂只对单个特征设定阈值，因此决策边界由与坐标轴平行的线段构成
- **高方差模型**：对训练数据敏感，需要正则化或集成方法

## 1. 环境配置

In [None]:
# Standard library
import warnings
warnings.filterwarnings('ignore')  # silence all warnings to keep notebook output clean

# Numerics
import numpy as np

# Plotting
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Apply the style sheet FIRST: plt.style.use() overwrites every rcParam the
# style defines, so the overrides below must come after it to take effect.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'DejaVu Sans']  # CJK-capable fonts first
plt.rcParams['axes.unicode_minus'] = False  # ASCII '-' for minus; CJK fonts may lack the Unicode minus glyph

# scikit-learn modules
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.datasets import load_iris, make_moons, make_circles
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Random seed shared by the whole notebook
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)  # seed NumPy's global RNG for reproducibility

print("环境配置完成")

## 2. 数据准备：Iris 数据集

使用经典的 Iris（鸢尾花）数据集：
- 150 个样本，3 个类别
- 4 个特征：花萼长度/宽度、花瓣长度/宽度
- 为便于可视化，我们选取花瓣长度和宽度两个特征

In [None]:
# Load Iris and keep only the petal length/width columns (indices 2 and 3)
# so the feature space is 2-D and can be plotted directly.
iris = load_iris()
petal_cols = slice(2, None)
X, y = iris.data[:, petal_cols], iris.target

feature_names = iris.feature_names[petal_cols]
class_names = iris.target_names

for line in (
    f"数据集形状: X={X.shape}, y={y.shape}",
    f"特征名称: {feature_names}",
    f"类别名称: {class_names}",
    f"类别分布: {np.bincount(y)}",
):
    print(line)

In [None]:
# Scatter the raw petal features, one color/marker per class, to inspect class separation
fig, ax = plt.subplots(figsize=(10, 6))

colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']
markers = ['o', 's', '^']

for label, (point_color, point_marker, cls_name) in enumerate(zip(colors, markers, class_names)):
    in_class = y == label
    ax.scatter(
        X[in_class, 0], X[in_class, 1],
        c=point_color, marker=point_marker, s=50, label=cls_name,
        edgecolors='black', linewidths=0.5,
    )

ax.set(xlabel=feature_names[0], ylabel=feature_names[1], title='Iris Dataset: Petal Features')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Stratified 80/20 split: class proportions are preserved in both subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=RANDOM_STATE,
)

for split_name, split_X in (("训练集", X_train), ("测试集", X_test)):
    print(f"{split_name}: {len(split_X)} 样本")

## 3. 训练决策树分类器

In [None]:
# Fit a CART tree limited to depth 3 with the Gini criterion on the training split
tree_clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=RANDOM_STATE)
tree_clf.fit(X_train, y_train)

# Summarize the fitted model's size and input/output dimensions
model_summary = {
    "树深度": tree_clf.get_depth(),
    "叶子节点数": tree_clf.get_n_leaves(),
    "特征数": tree_clf.n_features_in_,
    "类别数": tree_clf.n_classes_,
}
print("决策树分类器训练完成")
for key, val in model_summary.items():
    print(f"  {key}: {val}")

## 4. 决策树结构可视化

决策树的可解释性是其最大优势之一。下面展示两种可视化方式：
1. 图形化树结构
2. 文本规则表示

In [None]:
# Graphical view of the fitted tree; filled=True paints each node by its majority class
fig, ax = plt.subplots(figsize=(16, 10))
plot_tree(tree_clf, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, fontsize=11, ax=ax)
ax.set_title('Decision Tree Structure (max_depth=3)', fontsize=14)
plt.tight_layout()
plt.show()

# Legend explaining the fields printed inside each node
node_legend = [
    "  - gini: Gini 不纯度，越小越纯",
    "  - samples: 落入该节点的样本数",
    "  - value: 各类别的样本数 [setosa, versicolor, virginica]",
    "  - class: 该节点的预测类别（多数类）",
]
print("\n节点说明:")
print("\n".join(node_legend))
print("  - class: 该节点的预测类别（多数类）")

In [None]:
# Same tree rendered as nested if/else text rules
tree_rules = export_text(tree_clf, feature_names=[*feature_names])
print("决策规则 (文本格式):")
print(tree_rules)

## 5. 决策边界可视化

决策树的一个重要特性是**轴平行分裂**：所有决策边界都与坐标轴平行。
这使得决策树难以学习斜线或曲线形式的真实决策边界。

In [None]:
def plot_decision_boundary(clf, X, y, ax, title, feature_names, class_names, h=0.02):
    """
    Draw a fitted classifier's decision regions over a 2-D feature space.

    The classifier is evaluated on a regular grid spanning the data with a
    0.5 margin on every side. Each grid cell is filled with the color of its
    predicted class, class boundaries are outlined in black, and the samples
    are overlaid as a scatter plot.

    Parameters
    ----------
    clf : fitted classifier with a ``predict`` method. Its predictions are
        assumed to be integer labels 0..K-1, the same encoding as ``y``.
    X : array of shape (n_samples, 2)
        Feature matrix; columns 0 and 1 become the x and y axes.
    y : array of shape (n_samples,)
        Integer class labels 0..K-1 used to color the scatter points.
    ax : matplotlib Axes to draw on.
    title : str
        Axes title.
    feature_names : sequence of str
        The first two entries become the axis labels.
    class_names : sequence of str
        Indexed by class label; its length K sets the number of color bands.
    h : float, default 0.02
        Grid step. Smaller values give smoother boundaries but the number of
        ``predict`` evaluations grows as 1/h**2.
    """
    n_classes = len(class_names)

    # Evaluation grid covering the data plus a 0.5 margin
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Predict every grid point, then restore the grid shape for contouring
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    region_palette = ['#FFAAAA', '#AAFFAA', '#AAAAFF']
    point_palette = ['#FF6B6B', '#4ECDC4', '#45B7D1']
    # Cycle the palettes so that more than three classes still get a color
    cmap_light = ListedColormap(
        [region_palette[k % len(region_palette)] for k in range(n_classes)]
    )

    # Band edges at k - 0.5 give every integer label k its own fill band, so
    # class k always maps to color k instead of depending on automatic levels.
    ax.contourf(xx, yy, Z, levels=np.arange(n_classes + 1) - 0.5,
                alpha=0.4, cmap=cmap_light)
    if n_classes > 1:
        # Boundary lines sit halfway between consecutive labels
        ax.contour(xx, yy, Z, levels=np.arange(n_classes - 1) + 0.5,
                   colors='black', linewidths=0.5, alpha=0.5)

    # Overlay the samples, one color per class
    for label, name in enumerate(class_names):
        in_class = y == label
        ax.scatter(X[in_class, 0], X[in_class, 1],
                   c=point_palette[label % len(point_palette)], s=30,
                   label=name, edgecolors='black', linewidths=0.5)

    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    ax.set_title(title)
    ax.legend(loc='upper left', fontsize=8)

In [None]:
# Decision boundaries for growing depth limits (None = no depth limit)
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
depths = [1, 2, 3, None]

for depth, panel in zip(depths, axes.flat):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=RANDOM_STATE)
    clf.fit(X_train, y_train)

    train_acc, test_acc = (
        accuracy_score(labels, clf.predict(features))
        for features, labels in ((X_train, y_train), (X_test, y_test))
    )

    # f-string formatting renders None as 'None', which is the label wanted here
    title = f'max_depth={depth}\nTrain Acc={train_acc:.3f}, Test Acc={test_acc:.3f}'
    plot_decision_boundary(clf, X, y, panel, title, feature_names, class_names)

plt.suptitle('Decision Boundary Evolution with Increasing Depth', fontsize=14)
plt.tight_layout()
plt.show()

## 6. 分裂准则对比：Gini vs Entropy

两种常用的分裂准则：

**Gini 不纯度**：
$$G(t) = 1 - \sum_{k=1}^{K} p_k^2$$

**信息熵**：
$$H(t) = -\sum_{k=1}^{K} p_k \log_2 p_k$$

实践中两者的效果通常非常接近；由于 Gini 不需要计算对数、速度略快，它是 scikit-learn 的默认准则（`criterion='gini'`）。

In [None]:
# Same depth limit, two split criteria: compare the resulting decision boundaries
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for panel, crit in zip(axes, ('gini', 'entropy')):
    model = DecisionTreeClassifier(max_depth=3, criterion=crit, random_state=RANDOM_STATE)
    model.fit(X_train, y_train)

    accs = {
        split: accuracy_score(labels, model.predict(features))
        for split, features, labels in (("train", X_train, y_train), ("test", X_test, y_test))
    }
    title = (f'Criterion: {crit.upper()}\n'
             f'Train Acc={accs["train"]:.3f}, Test Acc={accs["test"]:.3f}')
    plot_decision_boundary(model, X, y, panel, title, feature_names, class_names)

plt.suptitle('Gini vs Entropy: Decision Boundary Comparison', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Binary-case impurity as a function of the class proportion p.
# The endpoints 0 and 1 are excluded so that log2 never sees 0.
p = np.linspace(0.001, 0.999, 100)
q = 1 - p

gini = 2 * p * q
entropy = -(p * np.log2(p) + q * np.log2(q))

# (values, format string, line width, legend label), plotted in this order
curves = [
    (gini, 'b-', 2, 'Gini: 2p(1-p)'),
    (entropy, 'r-', 2, 'Entropy: -p·log₂(p) - (1-p)·log₂(1-p)'),
    (entropy / 2, 'r--', 1.5, 'Entropy / 2 (scaled)'),
]

fig, ax = plt.subplots(figsize=(10, 6))
for values, fmt, width, text in curves:
    ax.plot(p, values, fmt, linewidth=width, label=text)

ax.axvline(x=0.5, color='gray', linestyle=':', alpha=0.7)
ax.set(xlabel='Class proportion p',
       ylabel='Impurity',
       title='Comparison of Gini Impurity and Entropy (Binary Classification)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

observations = (
    "观察：",
    "  - 两种指标在 p=0 和 p=1 时都为 0（纯净）",
    "  - 两种指标在 p=0.5 时都达到最大值（最不纯净）",
    "  - Entropy 缩放后与 Gini 形状非常接近",
)
print(*observations, sep="\n")

## 7. 概率预测

决策树不仅可以预测类别，还可以输出各类别的概率：
样本的预测概率等于它所落入的叶子节点中各类别训练样本所占的比例。

In [None]:
# Class probabilities for a few hand-picked (petal length, petal width) points
test_samples = np.array([
    [5.0, 1.5],  # likely versicolor
    [1.5, 0.5],  # likely setosa
    [6.0, 2.2],  # likely virginica
])

# Predict all samples at once; row k of each result belongs to test_samples[k]
probas = tree_clf.predict_proba(test_samples)
preds = tree_clf.predict(test_samples)

header = f"{'Sample':<20} {'setosa':>10} {'versicolor':>12} {'virginica':>12} {'Prediction':>12}"
print("概率预测示例:")
print(header)
print("-" * 70)

for sample, proba, pred in zip(test_samples, probas, preds):
    print(f"{str(sample):<20} {proba[0]:>10.3f} {proba[1]:>12.3f} {proba[2]:>12.3f} {class_names[pred]:>12}")

## 8. 模型评估

In [None]:
# Evaluate the depth-3 tree on the held-out test set
y_pred = tree_clf.predict(X_test)

print("分类报告:")
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion matrix: rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, y_pred)
n_rows, n_cols = cm.shape

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(im, ax=ax)

ax.set(xticks=np.arange(n_cols),
       yticks=np.arange(n_rows),
       xticklabels=class_names,
       yticklabels=class_names,
       ylabel='True label',
       xlabel='Predicted label',
       title='Confusion Matrix')

# Annotate every cell; white text on cells darker than half the max count
thresh = cm.max() / 2.
for (row, col), count in np.ndenumerate(cm):
    ax.text(col, row, format(count, 'd'),
            ha="center", va="center",
            color="white" if count > thresh else "black")

plt.tight_layout()
plt.show()

## 9. 超参数调优

通过网格搜索找到最佳超参数组合。

In [None]:
# Exhaustive search over depth, split size, leaf size and criterion,
# scored by 5-fold cross-validated accuracy on the training set only
param_grid = dict(
    max_depth=[2, 3, 4, 5, None],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
    criterion=['gini', 'entropy'],
)

grid_search = GridSearchCV(
    estimator=DecisionTreeClassifier(random_state=RANDOM_STATE),
    param_grid=param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
)
grid_search.fit(X_train, y_train)

# best_estimator_ is refit on the full training set (refit=True by default);
# it is scored once on the untouched test set
best_tree = grid_search.best_estimator_
test_acc = accuracy_score(y_test, best_tree.predict(X_test))

print("网格搜索结果:")
print(f"  最佳参数: {grid_search.best_params_}")
print(f"  最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"  测试集准确率: {test_acc:.4f}")

## 10. 决策树的局限性：旋转问题

由于决策树只能产生轴平行的决策边界，对于需要斜线分割的数据表现不佳。

In [None]:
# Build data whose true boundary is the vertical line x1 = 0
np.random.seed(RANDOM_STATE)
n = 100

# Original data: separable by a single axis-parallel split
X_simple = np.random.randn(n, 2)
y_simple = (X_simple[:, 0] > 0).astype(int)

# Rotate by 45° counter-clockwise. Samples are stored as rows, so the standard
# rotation matrix R must be applied as X @ R.T; X @ R would rotate by -45°.
angle = np.pi / 4
rotation_matrix = np.array([
    [np.cos(angle), -np.sin(angle)],
    [np.sin(angle), np.cos(angle)]
])
X_rotated = X_simple @ rotation_matrix.T

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, X_data, title in zip(axes, [X_simple, X_rotated], ['Original Data', 'Rotated Data (45°)']):
    # Hold out 30% of the points. The observations printed below make a claim
    # about generalization, so accuracy is measured on unseen data.
    X_fit, X_eval, y_fit, y_eval = train_test_split(
        X_data, y_simple, test_size=0.3, random_state=RANDOM_STATE, stratify=y_simple
    )
    clf = DecisionTreeClassifier(max_depth=5, random_state=RANDOM_STATE)
    clf.fit(X_fit, y_fit)
    holdout_acc = accuracy_score(y_eval, clf.predict(X_eval))

    # Decision regions on a grid covering all points (0.5 margin, step h)
    h = 0.02
    x_min, x_max = X_data[:, 0].min() - 0.5, X_data[:, 0].max() + 0.5
    y_min, y_max = X_data[:, 1].min() - 0.5, X_data[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    ax.contourf(xx, yy, Z, alpha=0.3, cmap='RdBu')
    # All samples (train and held-out), colored by their true label
    ax.scatter(X_data[y_simple == 0, 0], X_data[y_simple == 0, 1], c='red', s=30, label='Class 0')
    ax.scatter(X_data[y_simple == 1, 0], X_data[y_simple == 1, 1], c='blue', s=30, label='Class 1')
    ax.set_title(f'{title}\nDepth={clf.get_depth()}, Leaves={clf.get_n_leaves()}, '
                 f'Test Acc={holdout_acc:.3f}')
    ax.legend()

plt.suptitle('Decision Tree Rotation Problem: Axis-Parallel Splits', fontsize=14)
plt.tight_layout()
plt.show()

print("观察：")
print("  - 原始数据只需要一条垂直线即可分割")
print("  - 旋转后的数据需要多条阶梯状边界近似斜线")
print("  - 这导致模型复杂度增加，泛化能力通常下降（对比标题中的 Test Acc）")

## 11. 特征重要性

In [None]:
# Refit on all four Iris features so their importances can be compared
X_full = iris.data
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(
    X_full, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
)

tree_full = DecisionTreeClassifier(max_depth=4, random_state=RANDOM_STATE)
tree_full.fit(X_train_full, y_train_full)

# Impurity-based importances, ordered from most to least important
importances = tree_full.feature_importances_
indices = importances.argsort()[::-1]
ranked_names = [iris.feature_names[k] for k in indices]
ranked_scores = importances[indices]
positions = np.arange(len(importances))

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(positions, ranked_scores, color='steelblue')
ax.set_yticks(positions)
ax.set_yticklabels(ranked_names)
ax.set_xlabel('Feature Importance')
ax.set_title('Decision Tree: Feature Importance (Iris Dataset)')
ax.invert_yaxis()  # most important feature on top
plt.tight_layout()
plt.show()

print("特征重要性排名:")
for rank, (fname, score) in enumerate(zip(ranked_names, ranked_scores), start=1):
    print(f"  {rank}. {fname}: {score:.4f}")

## 12. 总结

### 决策树分类器优点

1. **可解释性强**：可以可视化树结构，理解决策过程
2. **无需特征缩放**：对特征的量纲不敏感
3. **能处理非线性关系**：通过多层分裂捕获复杂模式
4. **支持多分类**：天然支持多类别问题

### 决策树分类器缺点

1. **高方差**：对训练数据敏感，容易过拟合
2. **轴平行限制**：只能产生与坐标轴平行的决策边界
3. **不稳定**：数据小变化可能导致树结构大变化

### 实践建议

- 优先使用正则化参数（`max_depth`, `min_samples_leaf`）
- 使用交叉验证选择超参数
- 考虑使用集成方法（随机森林、梯度提升）提升性能

In [None]:
# ============================================================
# Unit tests: sanity checks for the decision-tree workflow
# ============================================================

def run_tests():
    """Run basic functional checks and report each one as [PASS]/[FAIL].

    Reads the notebook globals ``X_train``, ``y_train``, ``X_test`` and
    ``RANDOM_STATE``. Every check fits its own classifier, so a failure in one
    check cannot cascade into the checks that follow it.

    Raises
    ------
    AssertionError
        After all checks have run and been reported, if any of them failed,
        so an automated notebook run can detect the failure.
    """
    print("="*50)
    print("运行单元测试...")
    print("="*50)

    def fit_tree(**params):
        # Fresh depth-3 tree fitted on the training split; kwargs override defaults
        settings = {'max_depth': 3, 'random_state': RANDOM_STATE, **params}
        return DecisionTreeClassifier(**settings).fit(X_train, y_train)

    def check_training():
        assert fit_tree().get_depth() <= 3

    def check_predict():
        pred = fit_tree().predict(X_test)
        assert pred.shape == (len(X_test),)
        assert np.isin(pred, [0, 1, 2]).all()

    def check_predict_proba():
        proba = fit_tree().predict_proba(X_test)
        assert proba.shape == (len(X_test), 3)
        assert np.allclose(proba.sum(axis=1), 1.0)

    def check_feature_importances():
        imp = fit_tree().feature_importances_
        assert len(imp) == X_train.shape[1]
        assert abs(imp.sum() - 1.0) < 1e-6

    def check_cross_validation():
        # cross_val_score clones the estimator per fold, so an unfitted one suffices
        estimator = DecisionTreeClassifier(max_depth=3, random_state=RANDOM_STATE)
        scores = cross_val_score(estimator, X_train, y_train, cv=3)
        assert len(scores) == 3
        assert all(0 <= s <= 1 for s in scores)

    def check_criteria():
        for criterion in ('gini', 'entropy'):
            assert fit_tree(criterion=criterion).get_depth() <= 3

    # (label, success message, check function)
    checks = [
        ("测试 1", "模型训练成功", check_training),
        ("测试 2", "预测输出正确", check_predict),
        ("测试 3", "概率预测正确", check_predict_proba),
        ("测试 4", "特征重要性正确", check_feature_importances),
        ("测试 5", "交叉验证成功", check_cross_validation),
        ("测试 6", "Gini/Entropy 准则正常", check_criteria),
    ]

    failed = []
    for label, success_msg, check in checks:
        try:
            check()
        except Exception as e:  # any error counts as a failed check; keep running the rest
            failed.append(label)
            # A bare `assert` has an empty message, so include the exception type
            print(f"[FAIL] {label}: {type(e).__name__}: {e}")
        else:
            print(f"[PASS] {label}: {success_msg}")

    print("="*50)
    print("测试完成!")
    print("="*50)

    if failed:
        raise AssertionError(f"{len(failed)}/{len(checks)} 项测试失败: {', '.join(failed)}")


run_tests()