## CART算法的代码实现

In [1]:
# CART算法的代码实现
from numpy import *
def loadDataSet(fileName):
    """Load a tab-separated numeric text file into a list of float rows.

    Args:
        fileName: Path of the text file. Every non-blank line holds
            tab-separated fields that ``float`` can parse.

    Returns:
        A list with one list of floats per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to ``float``.
    """
    dataMat = []
    # The context manager closes the handle even when a conversion fails.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank (often trailing) line carries no sample; skip it
                # rather than failing on float('').
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat
def binSplitDataSet(dataSet,feature,value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    Returns a pair ``(above, rest)``: ``above`` holds the rows whose value
    in column ``feature`` is strictly greater than ``value``; ``rest`` holds
    the rows whose value is less than or equal to ``value``.
    """
    column = dataSet[:,feature]
    aboveRows = nonzero(column > value)[0]
    restRows = nonzero(column <= value)[0]
    return dataSet[aboveRows,:],dataSet[restRows,:]

In [2]:
# 4x4 identity matrix used to try out binSplitDataSet.
# asmatrix replaces the mat alias, which NumPy 2.0 removed.
testMat = asmatrix(eye(4))
testMat

matrix([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [3]:
# Split testMat on column 1 at threshold 0.5:
# mat0 gets the rows with value > 0.5, mat1 the rows with value <= 0.5.
mat0,mat1 = binSplitDataSet(testMat,1,0.5)

In [4]:
# Display the rows whose column-1 value is greater than 0.5.
mat0

matrix([[0., 1., 0., 0.]])

In [5]:
# Display the rows whose column-1 value is less than or equal to 0.5.
mat1

matrix([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])

In [6]:
def regLeaf(dataSet):
    """Return the leaf value of a regression tree: the mean of the last column."""
    targets = dataSet[:,-1]
    return mean(targets)
def regErr(dataSet):
    """Return the variance of the last column multiplied by the row count."""
    rowCount = shape(dataSet)[0]
    targets = dataSet[:,-1]
    return var(targets) * rowCount

In [7]:
def creatTree(dataSet,leafType = regLeaf,errType = regErr,ops = (1,4)):
    """Recursively build a tree by repeatedly splitting ``dataSet``.

    Args:
        dataSet: Data to split; passed to ``chooseBestSplit`` and
            ``binSplitDataSet``.
        leafType: Callable that produces a leaf value from a data subset.
        errType: Callable that scores the error of a data subset.
        ops: Stopping parameters, forwarded unchanged to ``chooseBestSplit``.

    Returns:
        The leaf value when ``chooseBestSplit`` reports no split, otherwise
        a dict with keys ``'spInd'``, ``'spVal'``, ``'left'`` and ``'right'``
        whose children are built the same way.
    """
    feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
    # Identity test: None is a singleton, and 0 is a valid feature index.
    if feat is None:
        return val
    lSet,rSet = binSplitDataSet(dataSet,feat,val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': creatTree(lSet,leafType,errType,ops),
        'right': creatTree(rSet,leafType,errType,ops),
    }

## 将CART算法用于回归

## 回归树的切分函数

In [8]:
# 回归树的切分函数

def chooseBestSplit(dataSet,leafType = regLeaf,errType = regErr,ops = (1,4)):
    """Pick the feature index and threshold that give the lowest split error.

    ``ops[0]`` is the minimum error reduction a split must achieve and
    ``ops[1]`` the minimum number of rows each side must keep.

    Returns ``(None, leafType(dataSet))`` when no split should be made
    (all targets equal, error reduction too small, or a side too small),
    otherwise ``(featureIndex, splitValue)``.
    """
    minGain,minRows = ops[0],ops[1]

    # Every target identical: nothing to gain from splitting.
    targets = dataSet[:,-1].T.tolist()[0]
    if len(set(targets)) == 1:
        return None,leafType(dataSet)

    _,nCols = shape(dataSet)
    baseErr = errType(dataSet)
    lowestErr,bestIndex,bestValue = inf,0,0

    # Try every distinct value of every feature column (last column is the target).
    for featIndex in range(nCols - 1):
        candidates = set((dataSet[:,featIndex].T.A.tolist())[0])
        for splitVal in candidates:
            upper,lower = binSplitDataSet(dataSet,featIndex,splitVal)
            tooSmall = (shape(upper)[0] < minRows) or (shape(lower)[0] < minRows)
            if tooSmall:
                continue
            candErr = errType(upper) + errType(lower)
            if candErr < lowestErr:
                lowestErr = candErr
                bestIndex = featIndex
                bestValue = splitVal

    # Improvement below the tolerance: make a leaf instead.
    if (baseErr - lowestErr) < minGain:
        return None,leafType(dataSet)

    # Re-check the size constraint on the chosen split.
    upper,lower = binSplitDataSet(dataSet,bestIndex,bestValue)
    if (shape(upper)[0] < minRows) or (shape(lower)[0] < minRows):
        return None,leafType(dataSet)
    return bestIndex,bestValue

In [9]:
from numpy import *

In [10]:
# Read ex00.txt into a list of float rows.
myDat = loadDataSet('ex00.txt')

In [11]:
# Wrap the row lists in a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
myMat = asmatrix(myDat)

In [12]:
# Build a tree with the default leaf/error functions and ops=(1,4).
creatTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

In [13]:
# Load ex0.txt and build a tree with the default settings.
# asmatrix replaces the mat alias that NumPy 2.0 removed.
myDat1 = loadDataSet('ex0.txt')
myMat1 = asmatrix(myDat1)
creatTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

## 预剪枝

In [14]:
# Load ex2.txt as a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
myDat2 = loadDataSet('ex2.txt')
myMat2 = asmatrix(myDat2)

In [15]:
# creatTree(myMat2)
# Pre-pruning: with ops[0] (tolS) set to 10000, chooseBestSplit returns a
# leaf unless a split lowers the error by at least 10000.
creatTree(myMat2,ops=(10000,4))

{'spInd': 0,
 'spVal': 0.499171,
 'left': 101.35815937735848,
 'right': -2.637719329787234}

## 回归树剪枝函数

In [16]:
# 回归树剪枝函数
def isTree(obj):
    """Return True when ``obj`` is a tree node (a dict), False for a leaf.

    ``isinstance`` is used instead of comparing the type's name, so dict
    subclasses count as nodes and unrelated classes that merely happen to
    be named ``dict`` do not.
    """
    return isinstance(obj, dict)
def getMean(tree):
    """Collapse ``tree`` into one number by averaging its two children.

    Child subtrees are collapsed first, and each is replaced in ``tree``
    by its collapsed value, so the input dict is modified in place.
    """
    for side in ('right', 'left'):
        child = tree[side]
        if isTree(child):
            tree[side] = getMean(child)
    return (tree['left'] + tree['right']) / 2.0
#树的后剪枝
def prune(tree, testData):
    """Post-prune ``tree`` using ``testData``.

    Subtrees are pruned first. A node whose two children are both leaves
    is replaced by the average of the leaves when that lowers the squared
    error on ``testData``. Returns the (possibly collapsed) tree; child
    entries of ``tree`` are updated in place.
    """
    # No test rows reach this node: collapse it to its mean.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Prune the children on their share of the test data.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)

    # A subtree is still attached: this node cannot be merged.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    # Both children are leaves: compare the error with and without merging.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    leftErr = sum(power(lSet[:, -1] - tree['left'], 2))
    rightErr = sum(power(rSet[:, -1] - tree['right'], 2))
    errorNoMerge = leftErr + rightErr
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree

In [17]:
# Build a tree from myMat2 with tolS=0 and tolN=1 (ops[0] and ops[1]).
myTree = creatTree(myMat2,ops = (0,1))

In [18]:
# Read ex2test.txt into a list of float rows.
myDatTest = loadDataSet('ex2test.txt')

In [19]:
# Wrap the test rows in a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
myMat2Test = asmatrix(myDatTest)

In [20]:
# Post-prune myTree against the test matrix; prune prints "merging" for
# every pair of leaves it collapses.
prune(myTree,myMat2Test)

merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging


{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
      'left': {'spInd': 0,
       'spVal': 0.960398,
       'left': 112.386764,
       'right': 123.559747},
      'right': 135.837013},
     'right': 111.2013225}},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.763328,
     'left': {'spInd': 0,
      'spVal': 0.769043,
      'left': {'spInd': 0,
       'spVal': 0.790312,
       'left': {'spInd': 0,
        'spVal': 0.806158,
        'left': {'spInd': 0,
         'spVal': 0.815215,
         'left': {'spInd': 0,
          'spVal': 0.833026,
          'left': {'spInd': 0,
           'spVal': 0.841547,
           'left': {'spInd': 0,
            'spVal': 0.841625,
            'left': {'spInd': 0,
            

## 模型树的叶节点生成函数

In [21]:

# 模型树的叶节点生成函数
def linearSolve(dataSet):
    """Fit ordinary least squares to ``dataSet``.

    Args:
        dataSet: 2-D matrix whose last column is the target and whose
            other columns are the features.

    Returns:
        ``(ws, X, Y)``: the weight column vector, the design matrix (a
        leading column of ones followed by the feature columns) and the
        target column.

    Raises:
        NameError: If ``X.T * X`` has a zero determinant and so cannot be
            inverted.
    """
    m,n = shape(dataSet)
    # asmatrix replaces the mat alias, which NumPy 2.0 removed.
    X = asmatrix(ones((m,n)))
    # Column 0 stays all ones (intercept); the features fill the rest.
    X[:,1:n] = dataSet[:,0:n-1]
    Y = dataSet[:,-1]
    xTx = X.T * X
    if linalg.det(xTx) == 0.0:
        raise NameError('This materix is singular, cannot do inverse,try increasing the second value of ops')
    ws = xTx.I *(X.T * Y)
    return ws,X,Y
def modelLeaf(dataSet):
    """Return the leaf of a model tree: the linear weights fitted to ``dataSet``."""
    fit = linearSolve(dataSet)
    weights,_,_ = fit
    return weights
def modelErr(dataSet):
    """Return the sum of squared residuals of a linear fit to ``dataSet``."""
    weights,design,targets = linearSolve(dataSet)
    residuals = targets - design * weights
    # Sum of squared differences between targets and fitted values.
    return sum(power(residuals,2))


In [22]:
# Load exp2.txt as a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
myMat2 = asmatrix(loadDataSet('exp2.txt'))

In [23]:
# Build a model tree: leaves hold linear weights (modelLeaf), splits are
# scored by modelErr, and ops is (1,10).
creatTree(myMat2,modelLeaf,modelErr,(1,10))

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[1.69855694e-03],
         [1.19647739e+01]]), 'right': matrix([[3.46877936],
         [1.18521743]])}

## 用树回归进行预测的代码

In [24]:
# 用树回归进行预测的代码
def regTreeEval(model,inDat):
    """Evaluate a regression-tree leaf: the leaf value itself, as a float.

    ``inDat`` is not used; it is accepted so the signature matches
    ``modelTreeEval``.
    """
    prediction = float(model)
    return prediction
def modelTreeEval(model,inDat):
    """Evaluate a model-tree leaf on one sample.

    Args:
        model: Weight column vector of the leaf, intercept first.
        inDat: 2-D single-row input holding the feature values.

    Returns:
        The linear prediction as a Python float.
    """
    n = shape(inDat)[1]
    # asmatrix replaces the mat alias, which NumPy 2.0 removed.
    X = asmatrix(ones((1,n+1)))
    # Column 0 stays 1 for the intercept; the features follow.
    X[:,1:n+1] = inDat
    # Pull the single element out explicitly: float() on an array with
    # ndim > 0 is a deprecated conversion in recent NumPy releases.
    return float((X * model)[0,0])
def treeForeCast(tree,inData,modelEval = regTreeEval):
    """Predict the target for one sample by walking ``tree``.

    Args:
        tree: A leaf value, or a node dict with ``'spInd'``, ``'spVal'``,
            ``'left'`` and ``'right'``.
        inData: Feature values of one sample (a single-row matrix or a
            flat sequence).
        modelEval: Callable ``(leaf, inData)`` that evaluates a leaf.

    Returns:
        Whatever ``modelEval`` returns for the leaf the sample lands in.
    """
    if not isTree(tree):
        return modelEval(tree,inData)
    # Flatten before indexing: inData[spInd] on a single-row matrix would
    # select a row rather than the feature column, which only works for
    # one-feature data.
    featureValue = asarray(inData).ravel()[tree['spInd']]
    branch = 'left' if featureValue > tree['spVal'] else 'right'
    # The recursive call handles a leaf child through the check above.
    return treeForeCast(tree[branch],inData,modelEval)
def creatForeCast(tree,testData,modelEval = regTreeEval):
    """Predict the target for every entry of ``testData``.

    Args:
        tree: Tree (or leaf) accepted by ``treeForeCast``.
        testData: Indexable collection of samples; ``testData[i]`` is
            wrapped in a matrix before being passed on.
        modelEval: Leaf evaluator forwarded to ``treeForeCast``.

    Returns:
        An ``(m, 1)`` matrix of predictions, ``m = len(testData)``.
    """
    m = len(testData)
    # asmatrix replaces the mat alias, which NumPy 2.0 removed.
    yHat = asmatrix(zeros((m,1)))
    for i in range(m):
        yHat[i,0] = treeForeCast(tree,asmatrix(testData[i]),modelEval)
    return yHat

## 创建一棵回归树

In [25]:
# Training data as a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
trainMat = asmatrix(loadDataSet('bikeSpeedVsIq_train.txt'))

In [26]:
# Test data as a matrix; asmatrix replaces the mat alias that
# NumPy 2.0 removed.
testMat = asmatrix(loadDataSet('bikeSpeedVsIq_test.txt'))

In [27]:
# Regression tree on the training matrix with tolS=1 and tolN=20.
myTree = creatTree(trainMat,ops = (1,20))

In [28]:
# Predict from column 0 of the test matrix using the default regTreeEval.
yHat = creatForeCast(myTree,testMat[:,0])

In [29]:
# Correlation coefficient between the predictions and column 1 of the
# test matrix (off-diagonal entry of the correlation matrix).
corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]

0.9640852318222141

## 创建一棵模型树

In [30]:
# Model tree on the same training matrix: linear leaves, tolS=1, tolN=20.
myTree = creatTree(trainMat,modelLeaf,modelErr,(1,20))

In [31]:
# Predict from column 0 of the test matrix, evaluating leaves with
# modelTreeEval.
yHat = creatForeCast(myTree,testMat[:,0],modelTreeEval)

In [32]:
# Correlation coefficient between the model-tree predictions and column 1
# of the test matrix.
corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]

0.9760412191380593

## 用普通线性回归作对比，计算相关系数（corrcoef 返回的是 R，而非 R^2）

In [33]:
# Fit one linear model to the whole training matrix; ws holds the
# intercept followed by the feature weight.
ws,X,Y = linearSolve(trainMat)

In [34]:
# Overwrite yHat row by row with the linear prediction
# ws[0,0] + ws[1,0] * x, where x is column 0 of the test matrix.
for i in range(shape(testMat)[0]):
    yHat[i] = testMat[i,0] * ws[1,0] + ws[0,0]

In [35]:
# Correlation coefficient between the linear-model predictions and
# column 1 of the test matrix.
corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]

0.9434684235674763

## 使用Python的Tkinter库创建GUI

In [37]:
from numpy import *
from tkinter import *

import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

#
def reDraw(tolS, tolN) :
    """Rebuild the tree with the given thresholds and redraw the plot.

    Reads the figure, canvas and data from attributes stored on the
    function itself (``reDraw.f``, ``reDraw.canvas``, ``reDraw.rawDat``,
    ``reDraw.testDat``) and the checkbox state from ``chkBtnVar``.

    Args:
        tolS: Minimum error reduction for a split (``ops[0]``).
        tolN: Minimum rows per side of a split (``ops[1]``); raised to 2
            when a model tree is requested with a smaller value.
    """
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get() :
        # Model tree selected: enforce at least 2 rows per side.
        if tolN < 2 : tolN = 2
        myTree = creatTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = creatForeCast(myTree, reDraw.testDat, modelTreeEval)
    else :
        myTree = creatTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = creatForeCast(myTree, reDraw.testDat)
    # scatter needs arrays, so the matrix columns are converted with .A
    reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    # FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; draw() is
    # the supported way to render the figure.
    reDraw.canvas.draw()

#
def getInputs() :
    """Read ``tolN`` and ``tolS`` from the two entry widgets.

    A value that does not parse is replaced by its default (10 for
    ``tolN``, 1.0 for ``tolS``); a hint is printed and the entry widget
    is reset to the default text.

    Returns:
        ``(tolN, tolS)`` as ``(int, float)``.
    """
    # Catch only the parse failure, so unrelated errors (including
    # KeyboardInterrupt) are not swallowed.
    try :
        tolN = int(tolNentry.get())
    except ValueError :
        tolN = 10
        print ("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try :
        tolS = float(tolSentry.get())
    except ValueError :
        tolS = 1.0
        print ("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

# 
def drawNewTree() :
    """Button callback: read the entry values and redraw with them."""
    # getInputs returns (tolN, tolS); reDraw expects (tolS, tolN).
    minRows, minGain = getInputs()
    reDraw(minGain, minRows)

root = Tk()

Label(root, text='Plot Place Holder').grid(row=0, columnspan=3)

Label(root, text='tolN').grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
Label(root, text='tolS').grid(row=2, column=0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
# Clicking the "ReDraw" button calls drawNewTree()
Button(root, text='ReDraw', command=drawNewTree).grid(row=1, column=2, rowspan=3)

chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text='Model Tree', variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# The figure, canvas and data are stored as attributes of reDraw so the
# callback can reach them.
reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# FigureCanvasTkAgg.show() was removed in Matplotlib 3.0; use draw().
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

# asmatrix replaces the mat alias, which NumPy 2.0 removed.
reDraw.rawDat = asmatrix(loadDataSet('sine.txt'))
# Evaluation points: column-0 minimum up to its maximum in steps of 0.01.
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)

reDraw(1.0, 10)

root.mainloop()

