In [15]:
# 决策树  预测乘客是否生还
import pandas as pd
import numpy as np
from sklearn import tree

# Load the Titanic training set and show its schema and first five rows.
data = pd.read_csv("titanic/train.csv")
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would emit a stray "None" line.
data.info()
print(data[:5])

# Drop the passenger id, name and ticket number columns.
data.drop(columns=['PassengerId', 'Name', 'Ticket'], inplace=True)

# Compute evenly spaced split points for the continuous-valued columns.
cont_feat = ['Age', 'Fare']
bins = 10  # number of split points generated per feature

feat_ranges = {}
for feat in cont_feat:
    column = data[feat]
    # The column may contain missing values (NaN), hence the NaN-aware
    # min/max reducers.
    min_var, max_var = np.nanmin(column), np.nanmax(column)
    split_points = np.linspace(min_var, max_var, bins).tolist()
    feat_ranges[feat] = split_points
    # Report the feature name followed by one split point per line.
    print(feat, ':')
    print(*(f'{spt:.4f}' for spt in split_points), sep='\n')

# Convert the categorical columns to integer codes and record, per feature,
# the sorted list of codes that occur.
cat_feat = ['Sex', 'Pclass', 'SibSp', 'Parch', 'Cabin', 'Embarked']
for feat in cat_feat:
    data[feat] = data[feat].astype('category')
    print(f'{feat}:, {data[feat].cat.categories}')
    # `.cat.codes` numbers the categories 0..k-1 and encodes missing values
    # as -1, so columns with NaNs already contain the -1 sentinel here.
    data[feat] = data[feat].cat.codes.to_list()
    ranges = sorted(set(data[feat]))
    feat_ranges[feat] = ranges

# Any NaNs still left in the frame are replaced by the same -1 sentinel.
data.fillna(-1, inplace=True)
for feat in feat_ranges:
    # Make -1 a known value of every feature, but do not add it a second time
    # when the categorical encoding above has already produced it.
    if -1 not in feat_ranges[feat]:
        feat_ranges[feat] = [-1] + feat_ranges[feat]
    
# Shuffle the rows and split them into a training set and a test set.
np.random.seed(0)  # fixed seed so the shuffle is reproducible
# The first column is treated as the label, all remaining columns as features.
label_name = data.columns[0]
feat_names = data.columns[1:]

shuffled_index = np.random.permutation(data.index)
data = data.reindex(shuffled_index)
ratio = 0.8  # fraction of rows assigned to the training set
split = int(ratio * len(data))

train_part = data.iloc[:split]
test_part = data.iloc[split:]

train_x = train_part.drop(columns=[label_name]).to_numpy()
train_y = train_part[label_name].to_numpy()

test_x = test_part.drop(columns=[label_name]).to_numpy()
test_y = test_part[label_name].to_numpy()

print('训练集大小', train_x.shape[0])
print('测试集大小', test_x.shape[0])
print('特征数', train_x.shape[1])


# Two depth-limited classification trees that differ only in split criterion:
# `c45` splits on entropy (called C4.5 in this script), `cart` on Gini impurity.
c45 = tree.DecisionTreeClassifier(criterion='entropy', max_depth=6)
cart = tree.DecisionTreeClassifier(criterion='gini', max_depth=6)
# Fit in the same order as before (entropy tree first, then Gini tree).
for model in (c45, cart):
    model.fit(train_x, train_y)

# Predictions of both trees on both splits.
c45_train_pred, c45_test_pred = c45.predict(train_x), c45.predict(test_x)
cart_train_pred, cart_test_pred = cart.predict(train_x), cart.predict(test_x)

# Accuracy is the fraction of predictions that match the true labels.
c45_train_acc = (c45_train_pred == train_y).mean()
cart_train_acc = (cart_train_pred == train_y).mean()
c45_test_acc = (c45_test_pred == test_y).mean()
cart_test_acc = (cart_test_pred == test_y).mean()

print(f'训练集准确率:C4.5: {c45_train_acc}, CART: {cart_train_acc}')
print(f'测试集准确率:C4.5: {c45_test_acc}, CART: {cart_test_acc}')

from six import StringIO
import pydotplus

# Export the entropy tree as Graphviz DOT text into an in-memory buffer.
dot_data = StringIO()
tree.export_graphviz(
    c45,
    out_file=dot_data,
    # Pass a plain list of names: scikit-learn releases before 1.3 document
    # this parameter as a list of str, and `feat_names` is a pandas Index.
    feature_names=list(feat_names),
    class_names=['non-survival', 'survival'],
    filled=True,
    rounded=True,
    impurity=False
)
# Render the DOT text to a PNG image with pydotplus.
# NOTE(review): newlines are stripped before parsing, presumably to work
# around pydotplus drawing stray "\n" nodes -- confirm before removing.
graph = pydotplus.graph_from_dot_data(
    dot_data.getvalue().replace('\n', ''))
graph.write_png('tree.png')


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None
   PassengerId  Survived  Pclass  \
0            1         0       3   
1            2         1       1   
2            3         1       3   
3            4         1       1   
4            5         0       3   

                      

True