In [13]:
# 基于iris数据集使用CART算法来创建分类树

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

# 准备数据集
iris = load_iris()

# 获取特征集和分类表示
features = iris.data
labels = iris.target

# 随机抽取33%的数据作为测试集，其余为训练集
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=0)

# 创建CART分类数
clf = DecisionTreeClassifier(criterion='gini')

# 拟合构造CART分类数
clf = clf.fit(train_features, train_labels)

# 用CART分类数做预测
test_predict = clf.predict(test_features)
print(test_predict)

# 预测结果与测试集结果做比对
score = accuracy_score(test_labels, test_predict)
print(f'test_labels:\t{test_labels}\n test_predict:\t{test_predict}')
print(f'CART分类树准确率{score}')

[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2 1 1 2 0 2 0 0 1 2 2 1 2]
test_labels:	[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 1 1 1 2 0 2 0 0 1 2 2 2 2]
 test_predict:	[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2 1 1 2 0 2 0 0 1 2 2 1 2]
CART分类树准确率0.96


In [22]:
# CART回归树预测
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
from sklearn.metrics import r2_score,mean_absolute_error, mean_squared_error
from sklearn.tree import DecisionTreeRegressor,export_graphviz
import graphviz

# 准备数据集
boston = load_boston()

# 探索数据
print(boston.feature_names)

# 获取特征集和房价
features = boston.data
prices = boston.target


# 随机抽取33%的数据作为测试集，其余作为训练集
train_features, test_features, train_price, test_price = train_test_split(features,prices,test_size=0.33)

# 创建CART回归树
dtr = DecisionTreeRegressor()

# 拟合构造CART回归树
dtr.fit(train_features, train_price)

# 预测测试集中的房价
predict_price = dtr.predict(test_features)

# 测试集的结果评价
print(f'回归树二乘偏差均值:',mean_squared_error(test_price, predict_price))
print(f'回归树绝对值偏差均值:',mean_absolute_error(test_price, predict_price))

dot_data = export_graphviz(dtr, out_file=None)
graph = graphviz.Source(dot_data)

# 生成回归树可视化
graph.render('Boston')

['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
 'B' 'LSTAT']
回归树二乘偏差均值: 16.10736526946108
回归树绝对值偏差均值: 3.00059880239521


In [21]:
# 作业

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.metrics import r2_score,mean_absolute_error, mean_squared_error
from sklearn.tree import DecisionTreeRegressor,export_graphviz
import graphviz

# 准备数据集
digits = load_digits()

# 探索数据
print(digits.keys())

#获取特征集
features = digits.data
labels = digits.target

# 随机抽取33%的数据作为测试集，其余作为训练集
train_features, test_features, train_labels, test_labels = train_test_split(features,labels,test_size=0.33,random_state=0)

# 创建CART分类树
clf = DecisionTreeClassifier(criterion='gini')

# 拟合构造CART分类树
clf = clf.fit(train_features, train_labels)

# 用CART分类树做预测
test_predict = clf.predict(test_features)

# 预测结果于测试集结果做对比
score = accuracy_score(test_labels, test_predict)
print(f'test_labels:\t{test_labels}\n test_predict:\t{test_predict}')
print(f'CART分类树准确率{score}')

dict_keys(['data', 'target', 'frame', 'feature_names', 'target_names', 'images', 'DESCR'])
test_labels:	[2 8 2 6 6 7 1 9 8 5 2 8 6 6 6 6 1 0 5 8 8 7 8 4 7 5 4 9 2 9 4 7 6 8 9 4 3
 1 0 1 8 6 7 7 1 0 7 6 2 1 9 6 7 9 0 0 5 1 6 3 0 2 3 4 1 9 2 6 9 1 8 3 5 1
 2 8 2 2 9 7 2 3 6 0 5 3 7 5 1 2 9 9 3 1 7 7 4 8 5 8 5 5 2 5 9 0 7 1 4 7 3
 4 8 9 7 9 8 2 6 5 2 5 8 4 8 7 0 6 1 5 9 9 9 5 9 9 5 7 5 6 2 8 6 9 6 1 5 1
 5 9 9 1 5 3 6 1 8 9 8 7 6 7 6 5 6 0 8 8 9 8 6 1 0 4 1 6 3 8 6 7 4 5 6 3 0
 3 3 3 0 7 7 5 7 8 0 7 8 9 6 4 5 0 1 4 6 4 3 3 0 9 5 9 2 1 4 2 1 6 8 9 2 4
 9 3 7 6 2 3 3 1 6 9 3 6 3 2 2 0 7 6 1 1 9 7 2 7 8 5 5 7 5 2 3 7 2 7 5 5 7
 0 9 1 6 5 9 7 4 3 8 0 3 6 4 6 3 2 6 8 8 8 4 6 7 5 2 4 5 3 2 4 6 9 4 5 4 3
 4 6 2 9 0 1 7 2 0 9 6 0 4 2 0 7 9 8 5 4 8 2 8 4 3 7 2 6 9 1 5 1 0 8 2 1 9
 5 6 8 2 7 2 1 5 1 6 4 5 0 9 4 1 1 7 0 8 9 0 5 4 3 8 8 6 5 3 4 4 4 8 8 7 0
 9 6 3 5 2 3 0 8 3 3 1 3 3 0 0 4 6 0 7 7 6 2 0 4 4 2 3 7 8 9 8 6 8 5 6 2 2
 3 1 7 7 8 0 3 3 2 1 5 5 9 1 3 7 0 0 7 0 4 5 9 3 3 4 3 1 8 9 8 3 6 2 1 