-
Notifications
You must be signed in to change notification settings - Fork 0
/
id3.py
123 lines (115 loc) · 4.29 KB
/
id3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import matplotlib.pyplot as plt
# -*- coding: utf-8 -*-
import math
import pandas as pd
from selectTable import select_table
import pygraphviz as pgv
# Load the data set from the Excel workbook into a DataFrame.
file = pd.read_excel('watermelon20.xlsx')
# Column names of the table; createId3TreeRoot passes them to calEntropy as
# the candidate features and the script hands them to buildId3Tree.
index = file.columns.values
# Directed graph, ranked top-to-bottom, that collects the tree's nodes/edges.
G = pgv.AGraph(directed=True, rankdir="TB",
compound=True, normalize=True, encoding='UTF-8')
# Running counter that buildId3Tree uses to give every leaf node a unique name.
cnt = 0
print(index)
# Number of rows (kept as a float) and number of columns minus one.
cases_num = float(file.shape[0])
feature_num = file.shape[1]-1
print("data_num=",cases_num,"feature=",feature_num)
# NOTE(review): the four names below are not read by any code visible in this
# file (calEntropy binds its own local hd / h_feature) -- confirm whether
# another module relies on them before removing.
hd = 0.
h_feature = {}
use_feature={}
use_feature_Leaf={}
def calEntropy(confidition,feature):
    """Compute the information gain of each candidate feature on a data subset.

    Args:
        confidition: Filter string handed to ``select_table`` to pick the
            subset; a falsy value means "no filter". Per-value conditions
            are built by appending `` and <feature> = '<value>'`` to it.
        feature: Iterable of column names to score. The columns "编号" and
            "好瓜" are skipped and never appear in the result.

    Returns:
        dict mapping every scored feature name to ``H(D) - H(D | feature)``,
        with base-2 entropies over two classes: rows equal to ``('是',)``
        versus all other rows. A feature for which no value is found gets
        the full subset entropy ``H(D)`` as its score.
    """
    # NOTE(review): assumes select_table returns a list of row tuples and that
    # its default projection is the class column, so ('是',) marks a positive
    # row -- confirm against selectTable.
    data = select_table(confidition)
    cases = len(data)
    num_good = float(data.count(('是',)))
    num_bad = float(cases - num_good)

    # Entropy of the whole subset. A class without rows contributes nothing
    # (0 * log 0 is taken as 0); skipping it avoids math.log(0) raising
    # ValueError on a pure subset and ZeroDivisionError on an empty one.
    weighted_logs = 0.0
    for class_count in (num_good, num_bad):
        if class_count != 0:
            weighted_logs += math.log(class_count / cases, 2) * class_count / cases
    hd = -weighted_logs

    gains = {}
    for idx in feature:
        # The row id and the class column are not features.
        if idx == "编号" or idx == "好瓜":
            continue
        # Distinct values of this feature inside the subset.
        labels = [row[0]
                  for row in select_table(confidition, label=idx, distict=True)
                  if len(row) == 1]
        # Accumulates sum_k |D_k|/|D| * sum_c p_c * log2(p_c), i.e. the
        # negated conditional entropy.
        weighted = 0
        for label in labels:
            if not confidition:
                now_confidition = " where "
            else:
                now_confidition = confidition + " and "
            # "=" (not "==") keeps the condition in the same form that
            # buildId3Tree already sends to select_table.
            # NOTE(review): the value is interpolated unescaped; a value
            # containing a quote would break the condition string.
            now_confidition = now_confidition + " " + idx + " = '" + label + "'"
            feature_data = select_table(condition=now_confidition)
            subset_size = len(feature_data)
            if subset_size == 0 or cases == 0:
                # Nothing to weigh; also prevents a division by zero below.
                continue
            k_ = float(subset_size) / cases
            label_good = float(feature_data.count(('是',)))
            m_ = float(label_good) / subset_size  # share of positive rows
            # Only two classes exist, so add their terms directly; a class
            # with no rows is skipped because log(0) is undefined.
            hk = 0
            if label_good != 0:
                hk = m_ * math.log(m_, 2)
            if subset_size - label_good != 0:
                hk = hk + (1 - m_) * math.log((1 - m_), 2)
            weighted = weighted + k_ * hk
        # Flip the sign to obtain H(D | feature), then subtract from H(D).
        gains[idx] = hd - (weighted * -1)
    return gains
def checkSameLable(confidition):
    """Tell whether all rows selected by ``confidition`` share one class.

    Args:
        confidition: Filter string passed to ``select_table`` as its
            ``condition`` argument.

    Returns:
        ``(True, '好瓜')`` when every selected row equals ``('是',)`` (an
        empty selection also lands here), ``(True, '坏瓜')`` when no row
        does, and ``(False, None)`` when the selection is mixed.
    """
    rows = select_table(condition=confidition)
    positives = float(rows.count(('是',)))
    if positives == len(rows):
        return True, '好瓜'
    if positives == 0:
        return True, '坏瓜'
    return False, None
# Depth-first (recursive) expansion of the decision tree below ``root``.
def buildId3Tree(confidition,root,feature):
    """Grow the ID3 subtree under ``root`` and record it in the graph ``G``.

    For every distinct value of the ``root`` column inside the subset picked
    by ``confidition``, the subset is narrowed to that value and then either
    closed with a leaf node or split again on the remaining feature with the
    largest information gain (as scored by ``calEntropy``).

    Args:
        confidition: Filter string for ``select_table`` describing the path
            from the tree root to this node; a falsy value means "no filter".
        root: Column name this node splits on. It is used as the node name
            in ``G`` and must be present in ``feature``.
        feature: List of column names still available on this path. ``root``
            is removed from it in place; every child call gets its own copy.

    Raises:
        ValueError: If ``root`` is not in ``feature`` (from ``list.remove``).
    """
    global cnt
    # A feature is used at most once along a path.
    feature.remove(root)
    # Distinct values of the split column inside the current subset.
    labels = [row[0]
              for row in select_table(confidition, label=root, distict=True)
              if len(row) == 1]
    for label in labels:
        if not confidition:
            now_confidition = " where "
        else:
            now_confidition = confidition + " and "
        now_confidition = now_confidition + " " + root + " = '" + label + "'"

        # A branch whose rows all share one class becomes a leaf.
        flag, nflag = checkSameLable(now_confidition)
        h_feature = {}
        if not flag:
            h_feature = calEntropy(now_confidition, feature)
            if not h_feature:
                # No feature is left to split on: close the branch with the
                # majority class instead of letting max() fail on an empty
                # dict. A tie is resolved in favour of '好瓜'.
                rows = select_table(condition=now_confidition)
                good = rows.count(('是',))
                nflag = '好瓜' if 2 * good >= len(rows) else '坏瓜'
                flag = True
        if flag:
            # cnt makes the node name unique so leaves are never merged.
            nflag = "LeafNode:"+str(cnt)+" "+nflag
            G.add_node(nflag,fontname="SimHei")
            G.add_edge(root, nflag, label=label,fontname="SimHei",color="black", style="dashed", penwidth=1.5)
            cnt = cnt + 1
            continue

        # Split on the remaining feature with the largest information gain.
        maxidx = max(h_feature, key=h_feature.get)
        print("maxidx=",maxidx)
        # The node shows the feature name (same as the root node); the branch
        # value belongs on the edge only.
        # NOTE(review): nodes are keyed by feature name, so a feature chosen
        # in two different branches maps to one shared graph node.
        G.add_node(maxidx, fontname="SimHei")
        G.add_edge(root, maxidx, label=label, fontname="SimHei")
        # Recurse on the newly chosen feature (not on ``root`` again), with a
        # private copy of the feature list so that removals made in this
        # subtree do not hide features from sibling branches.
        buildId3Tree(now_confidition, maxidx, list(feature))
def createId3TreeRoot():
    """Pick the root feature of the tree and register it as a graph node.

    Every column in the module-level ``index`` is scored by ``calEntropy``
    on the unfiltered data; the name with the highest score is added to
    ``G`` and returned.

    Returns:
        The column name chosen as the root of the tree.
    """
    scores = calEntropy(None, index)
    # On equal scores the first name in dict order wins.
    best_feature = max(scores, key=lambda name: scores.get(name))
    G.add_node(best_feature, fontname="SimHei")
    return best_feature
# Choose the root feature, then grow the rest of the tree beneath it.
root = createId3TreeRoot()
print("root=",root)
# tolist() hands buildId3Tree a plain Python list, which it modifies with
# list.remove(); the column array in ``index`` itself is left untouched.
buildId3Tree(None, root, index.tolist())
# Lay out the graph, then render it to id3.png using the "dot" program.
G.layout()
G.draw("id3.png", prog="dot")