In [9]:
import csv
from collections import defaultdict
import pydotplus

In [1]:
class Tree:
    """A single node of a binary decision tree.

    Decision nodes carry a split (``col``, ``value``) and two children;
    leaf nodes carry ``results`` (label -> count) and the ``data`` rows
    that reached them.
    """

    def __init__(self, value=None, trueBranch=None, falseBranch=None,
                 results=None, col=-1, summary=None, data=None):
        # Split threshold / category and the two subtrees it leads to.
        self.value, self.trueBranch, self.falseBranch = value, trueBranch, falseBranch
        # Label counts; None marks a decision (non-leaf) node.
        self.results = results
        # Index of the column tested at this node (-1 when unset).
        self.col = col
        # Free-form statistics attached by the tree builder.
        self.summary = summary
        # Rows stored on the node (set for leaves by the builder).
        self.data = data
            

In [14]:
def calculateDiffCount(datas):
    """Tally how often each class label occurs.

    The label of a row is its last element. Returns a dict mapping
    label -> number of rows, in order of first appearance.
    """
    counts = {}
    for row in datas:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    return counts

In [3]:
def gini(rows):
    """Return the Gini impurity of ``rows``.

    Computed as 1 minus the sum of squared label shares, where the label
    of a row is its last element (see calculateDiffCount). An empty
    ``rows`` yields 1.0 because no shares are accumulated.
    """
    total = len(rows)
    label_counts = calculateDiffCount(rows)
    squared_shares = 0.0
    # Explicit accumulation keeps the floating-point rounding of a plain loop.
    for count in label_counts.values():
        share = count / total
        squared_shares += share ** 2
    return 1 - squared_shares

In [4]:
def splitDatas(rows, value, column):
    """Partition ``rows`` on one column.

    For an int/float ``value`` a row goes to the first list when
    ``row[column] >= value``; for any other ``value`` when
    ``row[column] == value``. All remaining rows go to the second list.
    Returns the tuple ``(matching_rows, other_rows)``; input order is kept.
    """
    if isinstance(value, (int, float)):
        def matches(cell):
            return cell >= value
    else:
        def matches(cell):
            return cell == value

    matched, unmatched = [], []
    for row in rows:
        bucket = matched if matches(row[column]) else unmatched
        bucket.append(row)
    return (matched, unmatched)

In [21]:
def buildDecisionTree(rows, evaluationFunction=gini):
    """Recursively grow a binary decision tree from ``rows``.

    Every row is a sequence whose last element is the class label and whose
    other elements are feature values. For each feature column and each
    distinct value in it, the rows are split with ``splitDatas`` and the
    impurity reduction (gain) is measured with ``evaluationFunction``.
    The split with the largest strictly positive gain is kept and both
    halves are built recursively; recursion stops when no split has a
    positive gain.

    Args:
        rows: list of rows (feature values followed by the label).
        evaluationFunction: callable mapping a list of rows to an impurity
            score; defaults to ``gini``.

    Returns:
        A ``Tree``. Decision nodes carry ``col``, ``value``, both branches
        and ``summary``; leaves carry ``results`` (label counts),
        ``summary`` and ``data`` (the rows that reached the leaf).
        ``summary`` holds the formatted impurity and the sample count.
        Empty ``rows`` produce an empty leaf instead of an IndexError.
    """
    # Nothing to split: there is no first row to read the column count from.
    if len(rows) == 0:
        return Tree(results={}, summary={'impurity': '%.3f' % 0.0, 'samples': '0'}, data=rows)

    currentGain = evaluationFunction(rows)
    column_length = len(rows[0])
    row_length = len(rows)
    best_gain = 0.0
    best_value = None
    best_set = None

    # The last column is the label, so only the preceding columns are tested.
    for col in range(column_length - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            # A split that leaves one side empty separates nothing (its gain
            # is zero), so skip it rather than score an empty partition.
            if not list1 or not list2:
                continue
            p = len(list1) / row_length
            gain = currentGain - p * evaluationFunction(list1) - (1-p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)
    dcY = {'impurity': '%.3f' % currentGain, 'samples': '%d' % row_length}

    # Decide whether to stop recursing: only a positive gain justifies a split.
    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value=best_value[1], trueBranch=trueBranch,
                    falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)
  

In [26]:
def prune(tree, miniGain, evaluationFunction=gini):
    """Prune a decision tree in place, bottom-up.

    Called with the root node; it recurses into non-leaf children first.
    When both children of a node are leaves, the gain of that split is
    recomputed from the leaves' stored rows. If the gain is below
    ``miniGain`` the two leaves are merged: the node becomes a leaf holding
    the combined rows and their label counts.

    Args:
        tree: node to prune (normally the root). A leaf, ``None`` or a node
            without both branches is left untouched.
        miniGain: minimum gain a split must reach to be kept.
        evaluationFunction: impurity function applied to lists of rows;
            defaults to ``gini``.

    Returns:
        None. The tree is modified in place.
    """
    # A leaf (or a missing node) has no children to inspect; the original
    # code dereferenced tree.trueBranch here and failed on a leaf root.
    if tree is None or tree.results is not None:
        return
    if tree.trueBranch is None or tree.falseBranch is None:
        return

    # Prune the subtrees first so that they may collapse into leaves.
    if tree.trueBranch.results is None:
        prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results is None:
        prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results is not None and tree.falseBranch.results is not None:
        true_data = tree.trueBranch.data
        false_data = tree.falseBranch.data
        merged = true_data + false_data
        len1 = len(true_data)
        len2 = len(false_data)
        p = float(len1) / float(len1 + len2)
        gain = evaluationFunction(merged) -\
            p * evaluationFunction(true_data) -\
            (1-p) * evaluationFunction(false_data)
        if gain < miniGain:
            # Collapse this split: the node itself becomes a leaf.
            tree.data = merged
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None
        

In [27]:
def loadCSV(file, hasHeader=None):
    """Loads a CSV file and converts all floats and ints into basic datatypes.

    Args:
        file: path of the CSV file to read.
        hasHeader: whether the first row is a header row. When left as
            ``None`` the module-level ``bHeader`` flag is used, which is
            what the function consulted before this parameter existed.

    Returns:
        A tuple ``(dcHeader, rows)``. ``dcHeader`` maps ``'Column <i>'`` to
        the header text of column ``i`` (empty when there is no header) and
        ``rows`` is a list of lists with the converted cell values.
    """
    def convertTypes(s):
        # Cells containing a '.' are tried as float, all others as int;
        # anything that fails to parse stays a (stripped) string.
        s = s.strip()
        try:
            return float(s) if '.' in s else int(s)
        except ValueError:
            return s

    if hasHeader is None:
        hasHeader = bHeader

    dcHeader = {}
    # The context manager closes the file; newline='' is what the csv
    # module asks for so that it can handle line endings itself.
    with open(file, 'rt', newline='') as csvFile:
        reader = csv.reader(csvFile)
        if hasHeader:
            # Default of [] keeps an empty file from raising StopIteration.
            lsHeader = next(reader, [])
            for i, szY in enumerate(lsHeader):
                szCol = 'Column %d' % i
                dcHeader[szCol] = str(szY)
        rows = [[convertTypes(item) for item in row] for row in reader]
    return dcHeader, rows

In [28]:
# Module-level flag read by loadCSV: True means the first CSV row is a header.
# It must be assigned before loadCSV is called.
bHeader = True
# the bigger example
# dcHeadings maps 'Column <i>' to the header text; trainingData is the list of parsed rows.
dcHeadings, trainingData = loadCSV('fishiris.csv') # demo data from matlab


In [29]:
# Grow the tree on the loaded rows using Gini impurity as the split criterion.
dt = buildDecisionTree(trainingData, evaluationFunction=gini)
# Prune in place: sibling leaves whose split gain is below 0.4 are merged into their parent.
prune(dt, 0.4)

In [30]:
# Tree defines no __repr__, so displaying the root only shows the default object repr.
dt

<__main__.Tree at 0x102e556a0>

In [32]:
def plot(decisionTree, headings=None):
    """Print an indented text rendering of a decision tree.

    Decision nodes are printed as ``<column> >= <value>?`` for int/float
    split values and ``<column> == <value>?`` otherwise, followed by their
    ``yes`` and ``no`` branches; leaves are printed as their label counts.

    Args:
        decisionTree: root node; nodes expose ``results``, ``col``,
            ``value``, ``trueBranch`` and ``falseBranch``.
        headings: optional mapping from ``'Column <i>'`` to a display name.
            When omitted, the module-level ``dcHeadings`` is used if it is
            defined; otherwise columns are shown as ``'Column <i>'``.

    Returns:
        None. The rendering is written with ``print``.
    """
    if headings is None:
        # Keep the previous behaviour of reading the notebook-level mapping,
        # without failing when it has not been defined yet.
        headings = globals().get('dcHeadings', {})

    def toString(decisionTree, indent=''):
        if decisionTree.results is not None:  # leaf node
            return str(decisionTree.results)
        szCol = 'Column %s' % decisionTree.col
        if szCol in headings:
            szCol = headings[szCol]
        if isinstance(decisionTree.value, (int, float)):
            decision = '%s >= %s?' % (szCol, decisionTree.value)
        else:
            decision = '%s == %s?' % (szCol, decisionTree.value)
        # Children are indented two tabs deeper than their parent's branches.
        trueBranch = indent + 'yes -> ' + toString(decisionTree.trueBranch, indent + '\t\t')
        falseBranch = indent + 'no  -> ' + toString(decisionTree.falseBranch, indent + '\t\t')
        return (decision + '\n' + trueBranch + '\n' + falseBranch)

    print(toString(decisionTree))

In [33]:
# plot() prints the rendering and has no return statement, so `result` is None.
result = plot(dt)

PetalLength >= 3?
yes -> PetalWidth >= 1.8?
		yes -> PetalLength >= 4.9?
				yes -> {'virginica': 43}
				no  -> SepalLength >= 6?
						yes -> {'virginica': 2}
						no  -> {'versicolor': 1}
		no  -> PetalLength >= 5?
				yes -> PetalWidth >= 1.6?
						yes -> SepalLength >= 7.2?
								yes -> {'virginica': 1}
								no  -> {'versicolor': 2}
						no  -> {'virginica': 3}
				no  -> {'virginica': 1, 'versicolor': 47}
no  -> {'setosa': 50}
