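# Decision Trees

This file implements decision tree algorithms: the tree data structure, tree construction (ID3, C4.5, CART), and pruning (pre-pruning and post-pruning).

Theory: to be added.

Import libraries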

In [1]:
import utils
import importlib
import numpy as np
from Tree.ID3Tree import ID3Tree
from Tree.C45Tree import C45Tree
from Tree.CartClassificationTree import CartClassificationTree
from Tree.DecisionTreeVisualizer import DecisionTreeVisualizer
from Tree.CartRegressionTree import CartRegressionTree
from Tree.Prune.Pruning import pessimistic_prune
from Tree.Prune.Pruning import cost_complexity_prune

importlib.reload(utils)

<module 'utils' from 'E:\\studyfile\\机器学习与模式识别\\code\\algorithm\\3_LinearDis\\DecisionTree\\utils.py'>

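### Loading the datasets
The datasets mix continuous and discrete variables, so each feature must be labeled with its type for the decision tree to handle it correctly.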
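
As an illustration of the typing step, here is a minimal sketch of one rule that is consistent with the feature type dictionaries printed in this section: numeric columns are labeled `continuous`, everything else `discrete`. The function name `infer_feature_types` and the row layout are assumptions made for this example; the actual logic inside `utils` is not shown in this file and may differ.

```python
from numbers import Real

def infer_feature_types(rows, feature_names):
    """Label each column 'continuous' if all its values are numeric, else 'discrete'.

    Hypothetical helper for illustration; not the loader used in utils.
    """
    types = {}
    for j, name in enumerate(feature_names):
        numeric = all(
            isinstance(row[j], Real) and not isinstance(row[j], bool)
            for row in rows
        )
        types[name] = "continuous" if numeric else "discrete"
    return types

# Toy rows shaped like a few Titanic columns.
rows = [(3, "male", 22.0, "S"), (1, "female", 38.0, "C")]
types = infer_feature_types(rows, ["Pclass", "Sex", "Age", "Embarked"])
for feature, feature_type in types.items():
    print(f"{feature}: {feature_type}")
```

Under this rule an integer-coded categorical column such as `Pclass` ends up `continuous`, which is how it is labeled in the Titanic output.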

**Car Evaluation dataset**: all features are discrete; used for classification

In [2]:
cars, cars_feature_types_dict = utils.load_car_evaluation_dataset()

In [3]:
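print("---- First five rows of the Car Evaluation dataset ----")
print(cars.head())

print("\n---- Feature type dictionary ----")
for feature, feature_type in cars_feature_types_dict.items():
    print(f"{feature}: {feature_type}")

---- First five rows of the Car Evaluation dataset ----
   buying  maint  doors  persons  lug_boot  safety target
0  buying  maint  doors  persons  lug_boot  safety  label
1   vhigh  vhigh      2        2     small     low  unacc
2   vhigh  vhigh      2        2     small     med  unacc
3   vhigh  vhigh      2        2     small    high  unacc
4   vhigh  vhigh      2        2       med     low  unacc

---- Feature type dictionary ----
buying: discrete
maint: discrete
doors: discrete
persons: discrete
lug_boot: discrete
safety: discrete

Note: row 0 of the output repeats the column names (`buying`, `maint`, ..., `label`), which suggests the file's header line is being read as a data row. The loader should probably skip it.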


**Iris dataset**: all features are continuous; used for classification

In [4]:
iris, iris_feature_types_dict = utils.load_iris_dataset()

In [5]:
print("---- First five rows of the Iris dataset ----")
print(iris.head())

print("\n---- Feature type dictionary ----")
for feature, feature_type in iris_feature_types_dict.items():
    print(f"{feature}: {feature_type}")

---- First five rows of the Iris dataset ----
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \
0                5.1               3.5                1.4               0.2   
1                4.9               3.0                1.4               0.2   
2                4.7               3.2                1.3               0.2   
3                4.6               3.1                1.5               0.2   
4                5.0               3.6                1.4               0.2   

   target  
0       0  
1       0  
2       0  
3       0  
4       0  

---- Feature type dictionary ----
sepal length (cm): continuous
sepal width (cm): continuous
petal length (cm): continuous
petal width (cm): continuous


**Titanic dataset**: a mix of continuous and discrete features; used for classification

In [6]:
titanic, titanic_feature_types_dict = utils.load_titanic_dataset()

In [7]:
print("---- First five rows of the Titanic dataset ----")
print(titanic.head())

print("\n---- Feature type dictionary ----")
for feature, feature_type in titanic_feature_types_dict.items():
    print(f"{feature}: {feature_type}")

---- First five rows of the Titanic dataset ----
   Pclass     Sex   Age  SibSp  Parch     Fare Embarked  Target
0       3    male  22.0      1      0   7.2500        S       0
1       1  female  38.0      1      0  71.2833        C       1
2       3  female  26.0      0      0   7.9250        S       1
3       1  female  35.0      1      0  53.1000        S       1
4       3    male  35.0      0      0   8.0500        S       0

---- Feature type dictionary ----
Pclass: continuous
Sex: discrete
Age: continuous
SibSp: continuous
Parch: continuous
Fare: continuous
Embarked: discrete


**Boston Housing dataset**: continuous features only; used for regression

In [8]:
boston, boston_feature_types_dict = utils.load_boston_dataset()

In [9]:
print("---- First five rows of the Boston Housing dataset ----")
print(boston.head())

print("\n---- Feature type dictionary ----")
for feature, feature_type in boston_feature_types_dict.items():
    print(f"{feature}: {feature_type}")

---- First five rows of the Boston Housing dataset ----
      CRIM    ZN  INDUS  CHAS    NOX     RM   AGE     DIS  RAD    TAX  \
0  0.00632  18.0   2.31     0  0.538  6.575  65.2  4.0900    1  296.0   
1  0.02731   0.0   7.07     0  0.469  6.421  78.9  4.9671    2  242.0   
2  0.02729   0.0   7.07     0  0.469  7.185  61.1  4.9671    2  242.0   
3  0.03237   0.0   2.18     0  0.458  6.998  45.8  6.0622    3  222.0   
4  0.06905   0.0   2.18     0  0.458  7.147  54.2  6.0622    3  222.0   

   PTRATIO       B  LSTAT  Target  
0     15.3  396.90   4.98    24.0  
1     17.8  396.90   9.14    21.6  
2     17.8  392.83   4.03    34.7  
3     18.7  394.63   2.94    33.4  
4     18.7  396.90   5.33    36.2  

---- Feature type dictionary ----
CRIM: continuous
ZN: continuous
INDUS: continuous
CHAS: continuous
NOX: continuous
RM: continuous
AGE: continuous
DIS: continuous
RAD: continuous
TAX: continuous
PTRATIO: continuous
B: continuous
LSTAT: continuous


**Train/test split**: split the dataset 80%/20% into training and test sets

In [10]:
# Split the dataset: 80% for training, 20% for testing
X_train, X_test, y_train, y_test = utils.partition_dataset(titanic, 0.8)

In [11]:
# Print the sizes of the resulting splits
print(f"X_train (training features) shape: {X_train.shape}")
print(f"X_test (test features) shape: {X_test.shape}")
print(f"y_train (training labels) shape: {y_train.shape}")
print(f"y_test (test labels) shape: {y_test.shape}")

X_train (training features) shape: (712, 7)
X_test (test features) shape: (179, 7)
y_train (training labels) shape: (712,)
y_test (test labels) shape: (179,)
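
The shapes above (712 and 179 out of 891 rows) are what a truncating 80% split produces. The sketch below shows one way such a split could work on plain Python rows. `partition_rows`, the shuffle, the `seed` parameter, and the assumption that the target is the last column are all choices made for this example; whether `utils.partition_dataset` shuffles or how it locates the target is not shown in this file.

```python
import random

def partition_rows(rows, train_ratio, seed=0):
    """Shuffle rows, then split into (X_train, X_test, y_train, y_test).

    Assumes the target is the last element of every row (hypothetical layout).
    """
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)      # reproducible shuffle
    n_train = int(len(rows) * train_ratio)    # truncates: 891 * 0.8 -> 712
    train = [rows[i] for i in indices[:n_train]]
    test = [rows[i] for i in indices[n_train:]]
    X_train = [row[:-1] for row in train]
    X_test = [row[:-1] for row in test]
    y_train = [row[-1] for row in train]
    y_test = [row[-1] for row in test]
    return X_train, X_test, y_train, y_test

# Toy stand-in with the Titanic shape: 891 rows, 7 features + 1 target.
rows = [(i, 0, 0, 0, 0, 0, 0, i % 2) for i in range(891)]
X_train, X_test, y_train, y_test = partition_rows(rows, 0.8)
print(len(X_train), len(X_test))  # 712 179
```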


### Building decision trees

$\mathrm{ID3}$ algorithm

The original ID3 algorithm supports only discrete features and has no pruning step.
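
ID3 picks, at each node, the discrete feature with the highest information gain, i.e. the largest drop in label entropy after grouping samples by the feature's values. The sketch below computes that quantity on toy data; the helper names and the sample values are made up for illustration and are not taken from `Tree.ID3Tree`.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())

def information_gain(feature_values, labels):
    """Entropy of the labels minus the weighted entropy after splitting by value."""
    n = len(labels)
    groups = {}
    for value, label in zip(feature_values, labels):
        groups.setdefault(value, []).append(label)
    remainder = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - remainder

labels = ["unacc", "unacc", "acc", "acc"]
informative = ["low", "low", "high", "high"]   # separates the classes perfectly
useless = ["a", "b", "a", "b"]                 # tells us nothing about the class

gain_informative = information_gain(informative, labels)
gain_useless = information_gain(useless, labels)
print(gain_informative, gain_useless)
```

The informative feature recovers the full 1 bit of label entropy, while the useless one has zero gain, so ID3 would split on the former.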

Tested on the Car Evaluation dataset.

In [2]:
# 1. Load the dataset
cars, cars_feature_types_dict = utils.load_car_evaluation_dataset()

# 2. Split into training and test sets (80% / 20%)
X_train, X_test, y_train, y_test = utils.partition_dataset(cars, 0.8)

# 3. Convert the data to NumPy arrays
# (can be skipped if partition_dataset already returns ndarrays)
if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values

# 4. Create and train the ID3 decision tree
tree = ID3Tree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, cars_feature_types_dict)

# 5. Predict on the test set
y_pred = tree.predict(X_test)

# 6. Compute classification accuracy
accuracy = np.mean(y_pred == y_test)
print(f"Test Accuracy: {accuracy:.4f}")

# 7. Render the decision tree structure with Graphviz
print("\nDecision Tree Structure:\n")
viz = DecisionTreeVisualizer()
viz.show(tree, filename="car_tree")

Test Accuracy: 0.9364

Decision Tree Structure:

✅ Graphviz rendering complete: car_tree.png


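$\mathrm{C4.5}$ algorithm

C4.5 selects splits by information gain ratio, which reduces ID3's bias toward features that take many distinct (sparsely populated) values.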
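
To see why the ratio helps, the sketch below compares an ID-like feature (a unique value per sample) with a binary feature. Both have the same information gain, but dividing by the split information (the entropy of the feature's own value distribution) penalizes the ID-like one. The helper names and data are illustrative assumptions, not the code in `Tree.C45Tree`.

```python
import math
from collections import Counter

def entropy(values):
    """Shannon entropy (in bits) of a list of values."""
    n = len(values)
    return -sum((c / n) * math.log2(c / n) for c in Counter(values).values())

def gain_ratio(feature_values, labels):
    """Information gain divided by split information; 0 if the feature is constant."""
    n = len(labels)
    groups = {}
    for value, label in zip(feature_values, labels):
        groups.setdefault(value, []).append(label)
    gain = entropy(labels) - sum(len(g) / n * entropy(g) for g in groups.values())
    split_info = entropy(feature_values)
    return gain / split_info if split_info > 0 else 0.0

labels = [0, 0, 1, 1]
id_like = ["a", "b", "c", "d"]   # one value per sample: gain 1.0, split info 2.0
binary = ["x", "x", "y", "y"]    # two balanced values: gain 1.0, split info 1.0

ratio_id_like = gain_ratio(id_like, labels)
ratio_binary = gain_ratio(binary, labels)
print(ratio_id_like, ratio_binary)  # 0.5 1.0
```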

Evaluated on the Iris and Titanic datasets.

Iris dataset

In [2]:
iris, iris_feature_types_dict = utils.load_iris_dataset()

X_train, X_test, y_train, y_test = utils.partition_dataset(iris, 0.8)

if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values
    
tree = C45Tree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, iris_feature_types_dict)

y_pred = tree.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Test Accuracy: {accuracy:.4f}")
print("\nDecision Tree Structure:\n")
viz = DecisionTreeVisualizer()
viz.show(tree, filename="iris_tree")

Test Accuracy: 1.0000

Decision Tree Structure:

✅ Graphviz rendering complete: iris_tree.png


Titanic dataset

In [3]:
titanic, titanic_feature_types_dict = utils.load_titanic_dataset()

X_train, X_test, y_train, y_test = utils.partition_dataset(titanic, 0.8)

if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values
    
tree = C45Tree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, titanic_feature_types_dict)

y_pred = tree.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Test Accuracy: {accuracy:.4f}")
print("\nDecision Tree Structure:\n")
viz = DecisionTreeVisualizer()
viz.show(tree, filename="titanic_tree")

Test Accuracy: 0.7989

Decision Tree Structure:

✅ Graphviz rendering complete: titanic_tree.png


$\mathrm{CART}$ algorithm

CART comes in two variants: regression trees and classification trees.

CART classification
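
CART classification trees are usually described as binary trees that choose the split minimizing the weighted Gini impurity of the two children. The sketch below searches for the best threshold on a single continuous feature, trying the midpoints between consecutive distinct values. The function names and the toy measurements are assumptions for illustration; `CartClassificationTree` may organize its search differently.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_threshold(values, labels):
    """Return (threshold, weighted_gini) of the best binary split value <= threshold."""
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best = (None, float("inf"))
    for i in range(1, n):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold fits between equal values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
        if weighted < best[1]:
            best = (threshold, weighted)
    return best

# Toy petal-length-like values for two classes.
values = [1.4, 1.3, 1.5, 4.7, 4.5, 5.1]
labels = [0, 0, 0, 1, 1, 1]
threshold, impurity = best_threshold(values, labels)
print(threshold, impurity)  # 3.0 0.0
```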

Iris dataset

In [2]:
iris, iris_feature_types_dict = utils.load_iris_dataset()

X_train, X_test, y_train, y_test = utils.partition_dataset(iris, 0.8)

if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values
    
tree = CartClassificationTree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, iris_feature_types_dict)

y_pred = tree.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Test Accuracy: {accuracy:.4f}")
print("\nDecision Tree Structure:\n")
viz = DecisionTreeVisualizer()
viz.show(tree, filename="cart_iris_tree")

Test Accuracy: 1.0000

Decision Tree Structure:

✅ Graphviz rendering complete: cart_iris_tree.png


Titanic dataset

In [2]:
titanic, titanic_feature_types_dict = utils.load_titanic_dataset()

X_train, X_test, y_train, y_test = utils.partition_dataset(titanic, 0.8)

if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values
    
tree = CartClassificationTree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, titanic_feature_types_dict)

y_pred = tree.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Test Accuracy: {accuracy:.4f}")
print("\nDecision Tree Structure:\n")
viz = DecisionTreeVisualizer()
viz.show(tree, filename="cart_titanic_tree")

Test Accuracy: 0.7989

Decision Tree Structure:

✅ Graphviz rendering complete: cart_titanic_tree.png


CART regression
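
For regression, the usual CART criterion replaces Gini impurity with squared error: each leaf predicts the mean of its targets, and a split is chosen to minimize the summed squared error of the two children. The following is a minimal sketch of that search on one feature; the names and data are made up for this example and do not come from `CartRegressionTree`.

```python
def sse(targets):
    """Sum of squared errors around the mean (the cost of predicting the mean)."""
    if not targets:
        return 0.0
    mean = sum(targets) / len(targets)
    return sum((t - mean) ** 2 for t in targets)

def best_regression_split(x, y):
    """Return (threshold, cost) minimizing SSE(left) + SSE(right) for x <= threshold."""
    pairs = sorted(zip(x, y))
    best_threshold, best_cost = None, sse(list(y))  # baseline: no split
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold fits between equal values
        left = [t for _, t in pairs[:i]]
        right = [t for _, t in pairs[i:]]
        cost = sse(left) + sse(right)
        if cost < best_cost:
            best_threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
            best_cost = cost
    return best_threshold, best_cost

x = [1, 2, 3, 10, 11, 12]
y = [5.0, 5.0, 5.0, 20.0, 20.0, 20.0]
threshold, cost = best_regression_split(x, y)
print(threshold, cost)  # 6.5 0.0
```

A parameter like `min_gain` in the cell below would plausibly act on the difference between the baseline cost and the best split cost, though its exact meaning in this codebase is not shown here.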

In [4]:
# 1. Load the Boston Housing dataset
boston, boston_feature_types_dict = utils.load_boston_dataset()

# 2. Split into training and test sets (80% train, 20% test)
X_train, X_test, y_train, y_test = utils.partition_dataset(boston, 0.8)

# 3. Convert DataFrames/Series to ndarrays if needed
if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values

# 4. Create and train the CART regression tree
tree = CartRegressionTree(max_depth=5, min_samples_split=2, min_gain=1e-6)
tree.fit(X_train, y_train, boston_feature_types_dict)

# 5. Predict on the test set
y_pred = tree.predict(X_test)

# 6. Compute regression metrics (MSE, RMSE, R²)
mse = np.mean((y_pred - y_test) ** 2)
rmse = np.sqrt(mse)
r2 = 1 - np.sum((y_test - y_pred)**2) / np.sum((y_test - np.mean(y_test))**2)

print(f"Test MSE: {mse:.4f}")
print(f"Test RMSE: {rmse:.4f}")
print(f"Test R² Score: {r2:.4f}")

# 7. Render the decision tree structure with Graphviz
viz = DecisionTreeVisualizer()
viz.show(tree, filename="cart_boston_tree")

Test MSE: 8.4233
Test RMSE: 2.9023
Test R² Score: 0.8851

✅ Graphviz rendering complete: cart_boston_tree.png


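### Pruning

TODO: this section appears to contain bugs and needs rework, preferably after working through the math behind each pruning method.

$\mathrm{Pessimistic \ Pruning}$

Applied here to the CART classification tree.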
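
For reference while debugging, here is a sketch of the textbook pessimistic error pruning test (Quinlan's formulation), which needs only training-set error counts. A continuity correction of 0.5 is added per leaf, and a subtree is replaced by a leaf when the leaf's corrected error is within one standard error of the subtree's corrected error. This is the standard criterion written out under that assumption; it is not a description of what `pessimistic_prune` currently does, and the function name and numbers are illustrative.

```python
import math

def should_prune(node_errors, subtree_errors, n_leaves, n_samples):
    """Pessimistic pruning test for one internal node.

    node_errors:    training errors if the node were collapsed into a leaf
    subtree_errors: total training errors over the subtree's leaves
    n_leaves:       number of leaves in the subtree
    n_samples:      training samples reaching the node
    """
    leaf_estimate = node_errors + 0.5                  # corrected error as a leaf
    subtree_estimate = subtree_errors + 0.5 * n_leaves # corrected error as a subtree
    std_error = math.sqrt(
        subtree_estimate * (n_samples - subtree_estimate) / n_samples
    )
    return leaf_estimate <= subtree_estimate + std_error

# The subtree barely helps (12 vs 10 errors with 4 leaves): prune it.
prune_weak = should_prune(node_errors=12, subtree_errors=10, n_leaves=4, n_samples=100)
# The subtree helps a lot (40 vs 5 errors with 3 leaves): keep it.
prune_strong = should_prune(node_errors=40, subtree_errors=5, n_leaves=3, n_samples=100)
print(prune_weak, prune_strong)  # True False
```

Whatever the criterion, a collapsed node should become a leaf that predicts the majority class of its training samples, so test accuracy after pruning should stay well above zero on a dataset like Titanic.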

In [2]:
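# Load the Titanic dataset
titanic, titanic_feature_types_dict = utils.load_titanic_dataset()

# Split into training and test sets (80% / 20%)
X_train, X_test, y_train, y_test = utils.partition_dataset(titanic, 0.8)

# Make sure everything is a NumPy array
if not isinstance(X_train, np.ndarray):
    X_train, X_test = X_train.values, X_test.values
if not isinstance(y_train, np.ndarray):
    y_train, y_test = y_train.values, y_test.values

# Train the CART classification tree
tree = CartClassificationTree(max_depth=5, min_samples_split=2)
tree.fit(X_train, y_train, titanic_feature_types_dict)

# Evaluate before pruning
y_pred_before = tree.predict(X_test)
acc_before = np.mean(y_pred_before == y_test)
print(f"Test accuracy before pruning: {acc_before:.4f}")

# Run pessimistic pruning (uses training-set statistics only)
print("\nRunning pessimistic pruning...")
pruned_tree = pessimistic_prune(tree)

# Evaluate after pruning
y_pred_after = pruned_tree.predict(X_test)
acc_after = np.mean(y_pred_after == y_test)
print(f"\nTest accuracy after pruning: {acc_after:.4f}")

# Print the before/after comparison
print("\nPruning comparison:")
print(f"  - Accuracy before pruning: {acc_before:.4f}")
print(f"  - Accuracy after pruning: {acc_after:.4f}")
if acc_after >= acc_before:
    print("  Generalization improved or unchanged")
else:
    print("  Accuracy dropped after pruning, but the model is simpler")

# Visualize the pruned decision tree
print("\nPruned decision tree structure:")
viz = DecisionTreeVisualizer()
viz.show(pruned_tree, filename="cart_titanic_pruned_tree")

Test accuracy before pruning: 0.7989

Running pessimistic pruning...

Test accuracy after pruning: 0.0000

Pruning comparison:
  - Accuracy before pruning: 0.7989
  - Accuracy after pruning: 0.0000
  Accuracy dropped after pruning, but the model is simpler

Pruned decision tree structure:

✅ Graphviz rendering complete: cart_titanic_pruned_tree.png

Note: an accuracy of 0.0000 means that no test prediction matches its label. Even a tree collapsed to a single majority-class leaf would score far higher, so this result most likely points to a bug in `pessimistic_prune` or in how the pruned tree's leaves produce predictions, not to over-pruning.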


$\mathrm{Cost\text{-}Complexity \ Pruning}$

Applied here to the CART regression tree.
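
In the standard formulation, cost-complexity pruning scores each internal node $t$ by its effective alpha, $g(t) = \frac{R(t) - R(T_t)}{|T_t| - 1}$, where $R(t)$ is the node's cost as a leaf, $R(T_t)$ is the total cost of the leaves of its subtree, and $|T_t|$ is the number of those leaves. The node with the smallest $g(t)$ is the "weakest link" and is pruned first; repeating this yields a nested sequence of subtrees, one of which is then selected on validation data. The sketch below computes $g(t)$ on a hand-built tree. The dict-based node layout and the `sse` field are assumptions for this example and do not mirror the node class used by `cost_complexity_prune`.

```python
def collect_leaves(node):
    """Return all leaf nodes under (and including) node."""
    children = node.get("children", [])
    if not children:
        return [node]
    found = []
    for child in children:
        found.extend(collect_leaves(child))
    return found

def effective_alpha(node):
    """g(t) = (R(t) - R(T_t)) / (|T_t| - 1) for an internal node."""
    subtree_leaves = collect_leaves(node)
    subtree_cost = sum(leaf["sse"] for leaf in subtree_leaves)
    return (node["sse"] - subtree_cost) / (len(subtree_leaves) - 1)

# Hand-built tree: "sse" is each node's training cost if it were a leaf.
right = {"sse": 30.0, "children": [{"sse": 14.0}, {"sse": 15.0}]}
root = {"sse": 100.0, "children": [{"sse": 10.0}, right]}

alpha_right = effective_alpha(right)  # (30 - 29) / (2 - 1)
alpha_root = effective_alpha(root)    # (100 - 39) / (3 - 1)
print(alpha_right, alpha_root)  # 1.0 30.5
```

Here `right` is the weakest link, so it would be collapsed first. Note that $g(t)$ scales with the cost measure, so a threshold that is tiny relative to the costs in the tree would leave every node unpruned.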

In [3]:
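# Load the Boston Housing dataset
boston, boston_feature_types_dict = utils.load_boston_dataset()

# Split into training and test sets (80% train, 20% test)
X_train, X_test, y_train, y_test = utils.partition_dataset(boston, 0.8)

# Create and train the CART regression tree
tree = CartRegressionTree(max_depth=6, min_samples_split=4, min_gain=1e-6)
tree.fit(X_train, y_train, boston_feature_types_dict)

# Test-set performance before pruning
y_pred_before = tree.predict(X_test)
mse_before = np.mean((y_pred_before - y_test) ** 2)
rmse_before = np.sqrt(mse_before)
r2_before = 1 - np.sum((y_test - y_pred_before)**2) / np.sum((y_test - np.mean(y_test))**2)

print("Performance before pruning:")
print(f"  MSE  = {mse_before:.4f}")
print(f"  RMSE = {rmse_before:.4f}")
print(f"  R²   = {r2_before:.4f}")

# Run cost-complexity pruning.
# Caveat: the test set doubles as the validation set for choosing the subtree,
# so the post-pruning metrics below are optimistic; a separate validation split
# would give an unbiased estimate.
alpha = 0.01  # pruning strength (tunable)
print("\nRunning cost-complexity pruning...")
pruned_tree = cost_complexity_prune(tree, alpha, X_test, y_test)

# Test-set performance after pruning
y_pred_after = pruned_tree.predict(X_test)
mse_after = np.mean((y_pred_after - y_test) ** 2)
rmse_after = np.sqrt(mse_after)
r2_after = 1 - np.sum((y_test - y_pred_after)**2) / np.sum((y_test - np.mean(y_test))**2)

print("\nPerformance after pruning:")
print(f"  MSE  = {mse_after:.4f}")
print(f"  RMSE = {rmse_after:.4f}")
print(f"  R²   = {r2_after:.4f}")

# Print the before/after comparison
print("\nPruning comparison:")
print(f"  - R² before pruning = {r2_before:.4f}")
print(f"  - R² after pruning = {r2_after:.4f}")
if r2_after >= r2_before:
    print("  → Generalization improved or unchanged")
else:
    print("  → Performance dropped after pruning, but the model is simpler")

# Visualize the pruned decision tree
viz = DecisionTreeVisualizer()
viz.show(pruned_tree, filename="cart_boston_pruned_tree")

Performance before pruning:
  MSE  = 9.3186
  RMSE = 3.0526
  R²   = 0.8729

Running cost-complexity pruning...
✅ Cost-complexity pruning complete: generated 1 subtree
→ Best α = 0.0000, validation MSE = 9.3186

Performance after pruning:
  MSE  = 9.3186
  RMSE = 3.0526
  R²   = 0.8729

Pruning comparison:
  - R² before pruning = 0.8729
  - R² after pruning = 0.8729
  → Generalization improved or unchanged

✅ Graphviz rendering complete: cart_boston_pruned_tree.png

Note: only one candidate subtree was generated and the selected α is 0, so the tree was not actually pruned. The identical metrics before and after reflect that, not a successful pruning run.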
