In [1]:
import pandas as pd
import numpy as np
# Path of the mushroom CSV, given relative to the working directory.
file_path = 'mushrooms_new.csv'
# Parse the CSV into the DataFrame that the following cells use as `data`.
data = pd.read_csv(filepath_or_buffer=file_path)

In [2]:
# Print a summary of the frame: row count, column names and non-null counts.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8124 entries, 0 to 8123
Data columns (total 23 columns):
class                       8124 non-null object
cap-shape                   8124 non-null object
cap-surface                 8124 non-null object
cap-color                   8124 non-null object
bruises                     8124 non-null object
odor                        8124 non-null object
gill-attachment             8124 non-null object
gill-spacing                8124 non-null object
gill-size                   8124 non-null object
gill-color                  8124 non-null object
stalk-shape                 8124 non-null object
stalk-root                  8124 non-null object
stalk-surface-above-ring    8124 non-null object
stalk-surface-below-ring    8124 non-null object
stalk-color-above-ring      8124 non-null object
stalk-color-below-ring      8124 non-null object
veil-type                   8124 non-null object
veil-color                  8124 non-null object
ring-number

In [3]:
# Count distinct values per column, smallest first (ascending is the default).
data.nunique().sort_values()

veil-type                    1
class                        2
bruises                      2
gill-attachment              2
gill-spacing                 2
gill-size                    2
stalk-shape                  2
ring-number                  3
cap-surface                  4
veil-color                   4
stalk-surface-below-ring     4
stalk-surface-above-ring     4
ring-type                    5
stalk-root                   5
cap-shape                    6
population                   6
habitat                      7
stalk-color-above-ring       9
stalk-color-below-ring       9
odor                         9
spore-print-color            9
cap-color                   10
gill-color                  12
dtype: int64

In [4]:
# Show the first rows of the frame.
data.head()

Unnamed: 0,class,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,...,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,p,x,s,n,t,p,f,c,n,k,...,s,w,w,p,w,o,p,k,s,u
1,e,x,s,y,t,a,f,c,b,k,...,s,w,w,p,w,o,p,n,n,g
2,e,b,s,w,t,l,f,c,b,n,...,s,w,w,p,w,o,p,n,n,m
3,p,x,y,w,t,p,f,c,n,n,...,s,w,w,p,w,o,p,k,s,u
4,e,x,s,g,f,n,f,w,b,k,...,s,w,w,p,w,o,e,n,a,g


In [5]:
# Binary tree that stores a feature's best split point and the node's Gini index.
class Tree:
    """One internal node of the binary decision tree.

    Attributes
    ----------
    node
        Name of the feature tested at this node.
    value
        Feature value the row is compared against (``row[node] == value``).
    yes, no
        Children followed when the comparison is true / false. Both start as
        ``None`` and are later filled with either another ``Tree`` or a leaf
        label.
    gini
        Gini index recorded for the split stored in this node.
    """

    def __init__(self, node, value, gini):
        # The split itself: which feature, and which value it is tested against.
        self.node, self.value = node, value
        # Children are attached by the tree builder after construction.
        self.yes = self.no = None
        self.gini = gini

下面计算基尼指数的过程参照《统计学习方法》中的步骤实现。

In [6]:
# Compute the (weighted) Gini index of a sample set.
def calc_gini(x, total, label=None):
    """Return the Gini impurity of ``x`` weighted by ``len(x) / total``.

    Parameters
    ----------
    x : pandas.DataFrame or pandas.Series
        The samples of one partition. A Series, or a frame with exactly one
        column, is taken to be the class labels themselves.
    total : int
        Number of samples in the parent set; used to weight this partition.
    label : hashable, optional
        Column of ``x`` holding the class labels. Only needed when ``x`` has
        several columns.

    Returns
    -------
    float
        ``(1 - sum_k p_k ** 2) * len(x) / total`` where ``p_k`` is the share
        of class ``k`` inside ``x``.

    Raises
    ------
    ValueError
        If ``x`` has several columns and no label column can be determined.
    """
    if label is not None:
        labels = x[label]
    elif getattr(x, 'ndim', 2) == 1:
        # A Series already is the label column.
        labels = x
    elif x.shape[1] == 1:
        # Single-column frame: that column can only be the labels.
        labels = x.iloc[:, 0]
    else:
        # Legacy behaviour: earlier versions read a module-level `label`.
        # Kept as a last resort so existing callers keep working.
        fallback = globals().get('label')
        if fallback is None:
            raise ValueError(
                'calc_gini() needs `label` when x has more than one column')
        labels = x[fallback]

    size = x.shape[0]
    # Sum of squared class proportions; an empty partition contributes 0.
    squared = sum((count / size) ** 2 for count in labels.value_counts())
    return (1 - squared) * size / total

# Compute the Gini index for every value of every feature.
def create_gini_dict(data, label):
    """Score every candidate binary split ``feature == value``.

    For each feature column ``c`` and each distinct value ``i`` in it, the
    rows are divided into ``data[c] == i`` and the rest, and the Gini
    impurities of the two parts are added, each weighted by its share of the
    rows.

    The label column is taken from the ``label`` argument only; nothing is
    read from module-level state.

    Parameters
    ----------
    data : pandas.DataFrame
        Training rows; every column except ``label`` is treated as a feature.
    label : hashable
        Name of the class-label column.

    Returns
    -------
    dict
        Maps ``(feature, value)`` to the Gini index of that split. Empty when
        ``data`` has no rows or no feature columns.
    """
    total = data.shape[0]
    labels = data[label]

    def _weighted_gini(subset):
        """Gini impurity of ``subset`` scaled by its share of ``total`` rows."""
        size = subset.shape[0]
        # An empty subset has no classes, so it contributes exactly 0.
        squared = sum((count / size) ** 2 for count in subset.value_counts())
        return (1 - squared) * size / total

    node_gini = dict()
    for c in data.columns:
        if c == label:
            continue
        column = data[c]
        for i in column.unique():
            # Plain boolean array, so indexing does not depend on index alignment.
            hit = (column == i).values
            node_gini[(c, i)] = (_weighted_gini(labels[~hit])
                                 + _weighted_gini(labels[hit]))
    return node_gini

# Pick the feature and split point with the smallest Gini index and split the
# current node into two children. The samples are distributed to the two
# children and the function recurses until no feature is left to split on.
def create_tree(data, label):
    """Recursively build the decision tree for ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Training rows: the ``label`` column plus the remaining feature columns.
    label : hashable
        Name of the class-label column.

    Returns
    -------
    Tree or label
        A ``Tree`` node whose ``yes`` / ``no`` children are sub-trees or leaf
        labels. When ``data`` has no feature column left, the most frequent
        label is returned directly instead of a node.

    Raises
    ------
    ValueError
        If ``data`` has no rows.
    """
    def _majority():
        # Most frequent class among the rows of the current node.
        return data[label].value_counts().idxmax()

    def _child(subset, recurse=True):
        # Empty side: nothing to learn from, fall back to the parent's majority.
        if subset.empty:
            return _majority()
        # Pure side: the only class present becomes the leaf.
        if subset[label].nunique() == 1:
            return subset[label].unique()[0]
        return create_tree(subset, label) if recurse else _majority()

    if data.shape[0] == 0:
        raise ValueError('create_tree() needs at least one training row')
    if data.shape[1] < 2:
        return _majority()

    node_gini = create_gini_dict(data, label)
    c, a = min(node_gini, key=node_gini.get)
    node = Tree(c, a, node_gini[(c, a)])

    # Rows equal to the split value are done with feature `c`, so it is dropped
    # on that side; the other side keeps it because further values remain.
    yes = data[data[c] == a].drop(c, axis=1)
    no = data[data[c] != a]

    node.yes = _child(yes)
    # If no row matched the split value (e.g. NaN, which never compares equal),
    # `no` is the whole of `data`; recursing on it would never terminate, so
    # that side is closed with a leaf instead.
    node.no = _child(no, recurse=not yes.empty)
    return node

In [7]:
# Prediction
def predict(tree, data):
    """Classify every row of ``data`` by walking the decision tree.

    Parameters
    ----------
    tree : Tree or label
        Root of the tree. Anything that is not a ``Tree`` instance is treated
        as a leaf label, so a root that is already a bare label yields that
        label for every row.
    data : pandas.DataFrame
        Rows to classify; must contain every feature column the tree tests.

    Returns
    -------
    list
        One predicted label per row, in the row order of ``data``.
    """
    res = []
    for _, row in data.iterrows():
        node = tree
        # Descend until we reach something that is not a Tree, i.e. a leaf label.
        while isinstance(node, Tree):
            node = node.yes if row[node.node] == node.value else node.no
        res.append(node)
    return res

In [8]:
# Split into training and test sets; 20% of the rows are held out for testing.
SEED = 42              # fixed so that Restart & Run All reproduces the same split
TRAIN_FRACTION = 0.8   # share of the rows used for training

# `random` is a shuffled list of row positions.
random = list(range(data.shape[0]))
# A private RandomState makes the shuffle deterministic without touching
# NumPy's global random state.
np.random.RandomState(SEED).shuffle(random)
train_size = int(data.shape[0] * TRAIN_FRACTION)
train = data.iloc[random[:train_size], :]
test = data.iloc[random[train_size:], :]
# Every row must land in exactly one of the two splits.
assert len(train) + len(test) == len(data)

In [9]:
# Evaluate the model: fit on the training split, score on the held-out split.
label = 'class'
feature = [col for col in data.columns if col not in (label,)]
tree = create_tree(train, label)
result = predict(tree, test[feature])
# Number of test rows whose predicted label equals the true label.
hits = sum(test[label] == result)
print('CART accuracy: {}'.format(hits / len(result)))

CART accuracy: 1.0
