Skip to content

Commit

Permalink
More cleanup to the trees module
Browse files Browse the repository at this point in the history
  • Loading branch information
James Saryerwinnie committed Jul 6, 2011
1 parent 0f7f0b9 commit 26995a5
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions Ch03/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Decision Tree Source Code for Machine Learning in Action Ch. 3
@author: Peter Harrington
'''
from cPickle import dumps, loads
import operator
from copy import deepcopy
from collections import defaultdict
Expand Down Expand Up @@ -119,10 +120,10 @@ def choose_best_feature_to_split_on(dataset):
current_entropy = _calculate_entropy_for_split(dataset, feature_index=i)
# Calculate the info gain; ie reduction in entropy
info_gain = base_entropy - current_entropy
if (info_gain > best_info_gain): #compare this to the best gain so far
best_info_gain = info_gain #if better than current best, set to best
if (info_gain > best_info_gain):
best_info_gain = info_gain
best_feature = i
return best_feature #returns an integer
return best_feature


def _calculate_entropy_for_split(dataset, feature_index):
Expand Down Expand Up @@ -193,30 +194,28 @@ def create_tree(dataset, labels):
return tree


def classify(tree, labels, input_vector):
    """Classify ``input_vector`` by walking a nested-dict decision tree.

    ``tree`` is a dict with a single key, the feature label tested at this
    node.  Its value is a dict mapping each feature value to either a
    sub-tree (another dict of the same form) or a class label (a leaf).

    Args:
        tree: The decision tree node to start from.
        labels: Sequence of feature labels; the position of a label gives
            the index of that feature inside ``input_vector``.
        input_vector: Sequence of feature values to classify.

    Returns:
        The class label stored at the leaf that ``input_vector`` reaches.

    Raises:
        ValueError: If the node's feature label is not in ``labels``.
        KeyError: If the node has no branch for the input's feature value.
    """
    # A node holds exactly one key; next(iter(...)) fetches it without
    # building a key list and works with dict views as well as lists.
    root = next(iter(tree))
    children = tree[root]
    feature_index = labels.index(root)
    key = input_vector[feature_index]
    new_root = children[key]
    if isinstance(new_root, dict):
        # Internal node: keep descending with the same labels and input.
        return classify(new_root, labels, input_vector)
    # Leaf: the stored value is the class label itself.
    return new_root


def storeTree(inputTree, filename):
    """Pickle ``inputTree`` to the file at ``filename``.

    Args:
        inputTree: The object (decision tree) to serialize.
        filename: Path of the file to create or overwrite.
    """
    import pickle
    # Binary mode is what pickle expects; the with-block guarantees the
    # handle is closed even if pickle.dump raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def store_tree(input_tree, filename):
    """Pickle ``input_tree`` to the file at ``filename``.

    Args:
        input_tree: The object (decision tree) to serialize.
        filename: Path of the file to create or overwrite.
    """
    # dumps() returns a string and takes a protocol, not a file object, so
    # it cannot write to the file; dump() is the file-writing call.
    import pickle
    # Binary mode is what pickle expects; the with-block guarantees the
    # handle is closed even if pickle.dump raises.
    with open(filename, 'wb') as f:
        pickle.dump(input_tree, f)


def grabTree(filename):
    """Load and return the pickled object stored in ``filename``.

    Args:
        filename: Path of a file containing a pickled object.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE(security): unpickling can execute arbitrary code; only load
    # files that come from a trusted source.
    # The with-block closes the handle, which was previously left open.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
def grab_tree(filename):
    """Load and return the pickled object stored in ``filename``.

    Args:
        filename: Path of a file containing a pickled object.

    Returns:
        The unpickled object.
    """
    # Imported locally so the pickle name is always in scope here.
    import pickle
    # NOTE(security): unpickling can execute arbitrary code; only load
    # files that come from a trusted source.
    # The with-block closes the handle, which was previously left open.
    with open(filename, 'rb') as f:
        return pickle.load(f)


def demo():
Expand All @@ -241,7 +240,10 @@ def demo():
print choose_best_feature_to_split_on(data)

print "Creating decision tree:"
print create_tree(data, labels)
tree = create_tree(data, labels)
print tree
print "\nClassifying, no surfacing=true, flippers=true"
print " --> ", classify(tree, labels, [1, 1])


if __name__ == '__main__':
Expand Down

0 comments on commit 26995a5

Please sign in to comment.