Skip to content

Commit

Permalink
Rename functions to pep8 naming conventions
Browse files Browse the repository at this point in the history
  • Loading branch information
James Saryerwinnie committed Jun 30, 2011
1 parent 2b5432b commit dab2687
Showing 1 changed file with 41 additions and 33 deletions.
74 changes: 41 additions & 33 deletions Ch03/treePlotter.py
Expand Up @@ -3,82 +3,90 @@
@author: Peter Harrington
"""
import matplotlib.pyplot as plt
import matplotlib.pyplot as plot


decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
decision_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
def get_num_leafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
numLeafs += get_num_leafs(secondDict[key])
else: numLeafs +=1
return numLeafs


def getTreeDepth(myTree):
def get_tree_depth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
thisDepth = 1 + get_tree_depth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
def plot_node(nodeTxt, centerPt, parentPt, nodeType):
create_plot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
def plot_mid_text(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
create_plot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
depth = getTreeDepth(myTree)
def plot_tree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = get_num_leafs(myTree) #this determines the x width of this tree
firstStr = myTree.keys()[0] #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
cntrPt = (plot_tree.xOff + (1.0 + float(numLeafs))/2.0/plot_tree.totalW, plot_tree.yOff)
plot_mid_text(cntrPt, parentPt, nodeTxt)
plot_node(firstStr, cntrPt, parentPt, decision_node)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
plot_tree.yOff = plot_tree.yOff - 1.0/plot_tree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
plot_tree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
plot_tree.xOff = plot_tree.xOff + 1.0/plot_tree.totalW
plot_node(secondDict[key], (plot_tree.xOff, plot_tree.yOff), cntrPt, leaf_node)
plot_mid_text((plot_tree.xOff, plot_tree.yOff), cntrPt, str(key))
plot_tree.yOff = plot_tree.yOff + 1.0/plot_tree.totalD


def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
def create_plot(inTree):
fig = plot.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
create_plot.ax1 = plot.subplot(111, frameon=False, **axprops) #no ticks
plot_tree.totalW = float(get_num_leafs(inTree))
plot_tree.totalD = float(get_tree_depth(inTree))
plot_tree.xOff = -0.5/plot_tree.totalW; plot_tree.yOff = 1.0;
plot_tree(inTree, (0.5,1.0), '')
plot.show()


def retrieveTree(i):
def retrieve_tree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]


def main():
tree = retrieve_tree(1)
create_plot(tree)


if __name__ == '__main__':
main()

0 comments on commit dab2687

Please sign in to comment.