Skip to content

Commit f9fd2cd

Browse files
committed
test
1 parent 0b2b686 commit f9fd2cd

File tree

1 file changed

+43
-37
lines changed

1 file changed

+43
-37
lines changed

DesicionTree/DesicionTreeTest.py

Lines changed: 43 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,9 @@ class DesicionTree():
66
def __init__(self):
77
pass
88

9-
def _calcShannonEnt(self, dataSet): ## 计算数据集的熵
10-
numEntries = len(dataSet)
9+
def _calcShannonEnt(self, classList): ## 计算数据集的熵
1110
classCounts = {}
12-
for data in dataSet:
13-
currentLabel = data[-1]
11+
for currentLabel in classList:
1412
if currentLabel not in classCounts:
1513
classCounts[currentLabel] = 1
1614
else:
@@ -24,37 +22,39 @@ def _calcShannonEnt(self, dataSet): ## 计算数据集的熵
2422
'''
2523
shannonEnt = 0.0
2624
for key in classCounts:
27-
prob = classCounts[key]/float(numEntries)
25+
prob = classCounts[key]/float(len(classList))
2826
shannonEnt -= prob*math.log(prob, 2) # log base 2
2927
return shannonEnt
3028

31-
def _splitDataSet(self, dataSet, axis, value):
32-
retDataSet = []
33-
for data in dataSet:
29+
def _splitDataSet(self, dataArr, classList, axis, value):
30+
retFeatData = []
31+
retLabelData = []
32+
for data, label in zip(dataArr, classList):
3433
# print data[axis]
3534
if data[axis] == value:
36-
reduceddata = data[:axis]
37-
reduceddata.extend(data[axis+1:])
38-
retDataSet.append(reduceddata)
39-
return retDataSet
35+
reducedFeat = data[:axis]
36+
reducedFeat.extend(data[axis+1:])
37+
retFeatData.append(reducedFeat)
38+
retLabelData.append(label)
39+
return retFeatData, retLabelData
4040

41-
def _chooseBestFeatureToSplit(self, dataSet):
42-
numFeatures = len(dataSet[0])-1 # 最后一列是类标签
43-
baseEntropy = self._calcShannonEnt(dataSet)
41+
def _chooseBestFeatureToSplit(self, dataArr, classList):
42+
baseEntropy = self._calcShannonEnt(classList)
4443
bestInfoGain = 0
4544
bestFeature = -1
45+
numFeatures = len(dataArr[0])
4646
for i in range(numFeatures): # 依次迭代所有的特征
47-
featList = [data[i] for data in dataSet]
47+
featList = [data[i] for data in dataArr]
4848
values = set(featList)
4949
'''
5050
条件熵:sigma(pj*子数据集的熵)
5151
'''
5252
## 计算每个特征对数据集的条件熵
5353
newEntropy = 0.0
5454
for value in values:
55-
subDataSet = self._splitDataSet(dataSet, i, value)
56-
prob = len(subDataSet)/float(len(dataSet))
57-
newEntropy += prob*self._calcShannonEnt(subDataSet)
55+
subDataArr, subClassList = self._splitDataSet(dataArr, classList, i, value)
56+
prob = len(subClassList)/float(len(classList))
57+
newEntropy += prob*self._calcShannonEnt(subClassList)
5858
'''
5959
信息增益 = 熵-条件熵
6060
'''
@@ -66,33 +66,34 @@ def _chooseBestFeatureToSplit(self, dataSet):
6666

6767
def _majorityCnt(self, classList):
6868
classCount = {}
69-
for vote in classList:
70-
if vote not in classCount:
71-
classCount[vote] = 1
69+
for currentLabel in classList:
70+
if currentLabel not in classCount:
71+
classCount[currentLabel] = 1
7272
else:
73-
classCount[vote] += 1
74-
# if vote not in classCount:
75-
# classCount[vote] = 0
76-
# classCount[vote] += 1
73+
classCount[currentLabel] += 1
74+
# if currentLabel not in classCount:
75+
# classCount[currentLabel] = 0
76+
# classCount[currentLabel] += 1
7777
sortedClassCount = sorted(classCount.items(), key=lambda xx:xx[1], reverse=True)
7878
return sortedClassCount[0][0]
7979

80-
def fit(self, dataSet, featLabels):
81-
classList = [data[-1] for data in dataSet]
80+
def fit(self, dataArr, classList, featLabels):
8281
if classList.count(classList[0]) == len(classList):
8382
return classList[0] # 所有的类标签都相同,则返回类标签
84-
if len(dataSet[0]) == 1: # 所有的类标签不完全相同,但用完所有特征,则返回次数最多的类标签
83+
if len(dataArr[0]) == 0: # 所有的类标签不完全相同,但用完所有特征,则返回次数最多的类标签
8584
return self._majorityCnt(classList)
86-
bestFeat = self._chooseBestFeatureToSplit(dataSet)
85+
bestFeat = self._chooseBestFeatureToSplit(dataArr, classList)
8786
bestFeatLabel = featLabels[bestFeat]
8887
tree = {bestFeatLabel:{}}
8988
featLabels_copy = featLabels[:] # 这样不会改变输入的featLabels
9089
featLabels_copy.remove(bestFeatLabel)
91-
featList = [data[bestFeat] for data in dataSet]
90+
featList = [data[bestFeat] for data in dataArr]
9291
values = set(featList)
9392
for value in values:
94-
subfeatLabels_copy = featLabels_copy[:] # 列表复制,非列表引用
95-
tree[bestFeatLabel][value] = self.fit(self._splitDataSet(dataSet, bestFeat, value), subfeatLabels_copy)
93+
subFeatLabels_copy = featLabels_copy[:] # 列表复制,非列表引用
94+
subDataArr = self._splitDataSet(dataArr, classList, bestFeat, value)[0]
95+
subClassList = self._splitDataSet(dataArr, classList, bestFeat, value)[1]
96+
tree[bestFeatLabel][value] = self.fit(subDataArr, subClassList, subFeatLabels_copy)
9697
return tree
9798

9899
def predict(self, tree, featLabels, testVec):
@@ -113,14 +114,19 @@ def loadDataSet():
113114
[1, 0, 'no'],
114115
[0, 1, 'no'],
115116
[0, 1, 'no']]
117+
featData = []
118+
labelData = []
119+
for data in dataSet:
120+
featData.append(data[:-1])
121+
labelData.append(data[-1])
116122
featLabels = ['no surfacing', 'flippers'] # 特征标签
117-
return dataSet, featLabels
123+
return featData, labelData, featLabels
118124

119125
if __name__ == '__main__':
120-
myDataSet, myFeatLabels = loadDataSet()
121-
print myDataSet, myFeatLabels
126+
myFeatData, myLabelData, myFeatLabels = loadDataSet()
127+
print myFeatData, myLabelData, myFeatLabels
122128
dt = DesicionTree()
123-
myTree = dt.fit(myDataSet, myFeatLabels)
129+
myTree = dt.fit(myFeatData, myLabelData, myFeatLabels)
124130
print myTree
125131
results = dt.predict(myTree, myFeatLabels, [1, 1])
126132
print results

0 commit comments

Comments (0)