In [2]:
class treeNode():
    """Binary tree node describing a single split.

    Note: the tree-building code in this notebook represents trees as nested
    dicts; this class is a plain record type kept alongside it.

    Attributes:
        featureToSplitOn: index of the feature used for the split.
        valueOfSplit: threshold value of the split.
        rightBranch: right child (sub-tree or leaf value).
        leftBranch: left child (sub-tree or leaf value).
    """

    def __init__(self, feat, val, right, left):
        # Store on the instance; plain local assignments would be discarded
        # as soon as __init__ returns.
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.rightBranch = right
        self.leftBranch = left

In [3]:
import numpy as np

def loadDataSet(fileName):
    """Load a tab-delimited file of numbers into a list of float rows.

    Args:
        fileName: path of a text file with one sample per line and
            tab-separated fields.

    Returns:
        A list of lists of floats, one inner list per non-blank line.

    Raises:
        ValueError: if a field on a non-blank line cannot be parsed as float.
    """
    dataMat = []
    with open(fileName) as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (typically a trailing newline at end of file)
                # would otherwise crash on float('').
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

In [4]:
def binSplitDataSet(dataSet, feature, value):
    """Split a data set into two parts by thresholding one feature.

    Args:
        dataSet: 2-D array-like (list of rows, ndarray or matrix).
        feature: column index of the feature to compare.
        value: threshold value.

    Returns:
        (mat0, mat1): two np.matrix objects; mat0 holds the rows whose
        feature is > value, mat1 the rows whose feature is <= value.
    """
    dataArr = np.array(dataSet)
    column = dataArr[:, feature]
    # np.mat was removed from the NumPy namespace in 2.0; np.asmatrix is the
    # same function under its surviving name.
    mat0 = np.asmatrix(dataArr[column > value])
    mat1 = np.asmatrix(dataArr[column <= value])
    return mat0, mat1

In [5]:
def regLeaf(dataSet):
    """Leaf model for a regression tree: the mean of the last column."""
    targets = dataSet[:, -1]
    return np.mean(targets)

def regErr(dataSet):
    """Total squared error of the last column: its variance times the row count."""
    targets = dataSet[:, -1]
    sampleCount = dataSet.shape[0]
    return np.var(targets) * sampleCount


In [6]:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Find the binary split that minimises the summed error of both halves.

    Args:
        dataSet: np.matrix whose last column is the target.
        leafType: callable building a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: (tolS, tolN) - minimum error reduction required to split and
            minimum number of rows allowed on each side of a split.

    Returns:
        (featureIndex, splitValue) for the best split, or
        (None, leafType(dataSet)) when no split should be made.
    """
    minErrDrop, minRows = ops[0], ops[1]

    def tooSmall(part0, part1):
        # A split is rejected when either side has fewer than minRows rows.
        return part0.shape[0] < minRows or part1.shape[0] < minRows

    # Stop if all target values are equal.
    targetValues = dataSet[:, -1].T.tolist()[0]
    if len(set(targetValues)) == 1:
        return None, leafType(dataSet)

    numCols = dataSet.shape[1]
    baseErr = errType(dataSet)
    lowestErr = np.inf
    chosenFeat = 0
    chosenVal = 0
    for featIndex in range(numCols - 1):
        candidates = set(dataSet[:, featIndex].flatten().tolist()[0])
        for splitVal in candidates:
            part0, part1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if tooSmall(part0, part1):
                continue
            splitErr = errType(part0) + errType(part1)
            if splitErr < lowestErr:
                chosenFeat, chosenVal, lowestErr = featIndex, splitVal, splitErr

    # Stop if the error reduction is too small.
    if (baseErr - lowestErr) < minErrDrop:
        return None, leafType(dataSet)

    # Stop if the chosen split leaves a side that is too small.
    part0, part1 = binSplitDataSet(dataSet, chosenFeat, chosenVal)
    if tooSmall(part0, part1):
        return None, leafType(dataSet)
    return chosenFeat, chosenVal
        

In [7]:
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree as nested dicts.

    Args:
        dataSet: np.matrix whose last column is the target.
        leafType: callable building a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: (tolS, tolN) stopping parameters forwarded to chooseBestSplit.

    Returns:
        A leaf value when no split is made, otherwise a dict with keys
        'spInd' (feature index), 'spVal' (threshold), 'left' (sub-tree for
        rows with feature > spVal) and 'right' (sub-tree for the rest).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: `== None` would go through __eq__, which is not a
    # reliable None check for arbitrary objects.
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

In [8]:
# Build a regression tree on ex00.txt with the default stopping parameters.
myDat = loadDataSet('ex00.txt')
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat = np.asmatrix(myDat)
createTree(myMat)

{'spInd': 0,
 'spVal': 0.48813,
 'left': 1.0180967672413792,
 'right': -0.04465028571428572}

In [9]:
# Build a regression tree on ex0.txt with the default stopping parameters.
myDat1 = loadDataSet('ex0.txt')
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat1 = np.asmatrix(myDat1)
createTree(myMat1)

{'spInd': 1,
 'spVal': 0.39435,
 'left': {'spInd': 1,
  'spVal': 0.582002,
  'left': {'spInd': 1,
   'spVal': 0.797583,
   'left': 3.9871632,
   'right': 2.9836209534883724},
  'right': 1.980035071428571},
 'right': {'spInd': 1,
  'spVal': 0.197834,
  'left': 1.0289583666666666,
  'right': -0.023838155555555553}}

In [10]:
# Pre-pruning
# ops=(0,1) sets tolS=0 and tolN=1, i.e. the loosest stopping criteria.
bigTree = createTree(myMat, ops=(0,1))

In [11]:
# Build a regression tree on ex2.txt with the default stopping parameters.
myDat2 = loadDataSet('ex2.txt')
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat2 = np.asmatrix(myDat2)
createTree(myMat2)

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.958512,
    'left': 105.24862350000001,
    'right': 112.42895575000001},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.790312,
     'left': {'spInd': 0,
      'spVal': 0.833026,
      'left': {'spInd': 0,
       'spVal': 0.944221,
       'left': 87.3103875,
       'right': {'spInd': 0,
        'spVal': 0.85497,
        'left': {'spInd': 0,
         'spVal': 0.910975,
         'left': 96.452867,
         'right': {'spInd': 0,
          'spVal': 0.892999,
          'left': 104.825409,
          'right': {'spInd': 0,
           'spVal': 0.872883,
           'left': 95.181793,
           'right': 102.25234449999999}}},
        'right': 95.27584316666666}},
      'right': {'spInd': 0,
       'spVal': 0.811602,
       'left': 81.110152,
       'right': 88.78449880000001}},
     'right': 102.

In [12]:
# Same data, but require an error reduction of at least tolS=1000 before splitting.
createTree(myMat2, ops=(1000, 4))

{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': 108.838789625,
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': 95.7366680212766,
    'right': 78.08564325}},
  'right': 107.68699163829788},
 'right': -2.637719329787234}

In [13]:
# Post-pruning

def isTree(obj):
    """Return True if obj is an internal tree node (a dict) rather than a leaf."""
    # isinstance is the idiomatic type check and also accepts dict subclasses.
    return isinstance(obj, dict)

def getMean(tree):
    """Collapse a tree to a single value by recursively averaging its branches.

    Each internal node is reduced to the mean of its (collapsed) left and
    right children. The input tree is left unmodified.

    Args:
        tree: dict node with 'left' and 'right' entries.

    Returns:
        (left + right) / 2.0 computed over the collapsed children.
    """
    # Work on local values instead of writing the collapsed means back into
    # the caller's tree, so that computing a mean has no side effects.
    right = tree['right']
    if isTree(right):
        right = getMean(right)
    left = tree['left']
    if isTree(left):
        left = getMean(left)
    return (left + right) / 2.0

In [14]:
def prune(tree, testData):
    """Post-prune a regression tree against held-out data.

    Sub-trees are pruned first (in place); then, if both children are leaves,
    they are merged into their average whenever that gives a lower squared
    error on testData.

    Args:
        tree: dict node produced by createTree.
        testData: np.matrix of test rows reaching this node.

    Returns:
        The (possibly modified) tree, or a single leaf value if merged.
    """
    # With no test data, collapse the whole sub-tree to its mean.
    if testData.shape[0] == 0:
        return getMean(tree)

    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # Merging is only considered once both children are leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    leftLeaf = tree['left']
    rightLeaf = tree['right']
    leftErr = np.sum(np.power(lSet[:, -1] - leftLeaf, 2))
    rightErr = np.sum(np.power(rSet[:, -1] - rightLeaf, 2))
    errorNoMerge = leftErr + rightErr
    treeMean = (leftLeaf + rightLeaf) / 2.0
    errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("Merging")
        return treeMean
    return tree

In [15]:
# Grow a tree with the loosest stopping criteria (tolS=0, tolN=1).
myTree = createTree(myMat2, ops=(0,1))

In [16]:
# Load the held-out data used for post-pruning.
myDatTest = loadDataSet('ex2test.txt')
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat2Test = np.asmatrix(myDatTest)

In [17]:
prune(myTree, myMat2Test) # however, the tree is still large after pruning

Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging
Merging


{'spInd': 0,
 'spVal': 0.499171,
 'left': {'spInd': 0,
  'spVal': 0.729397,
  'left': {'spInd': 0,
   'spVal': 0.952833,
   'left': {'spInd': 0,
    'spVal': 0.965969,
    'left': 92.5239915,
    'right': {'spInd': 0,
     'spVal': 0.956951,
     'left': {'spInd': 0,
      'spVal': 0.958512,
      'left': {'spInd': 0,
       'spVal': 0.960398,
       'left': 112.386764,
       'right': 123.559747},
      'right': 135.837013},
     'right': 111.2013225}},
   'right': {'spInd': 0,
    'spVal': 0.759504,
    'left': {'spInd': 0,
     'spVal': 0.763328,
     'left': {'spInd': 0,
      'spVal': 0.769043,
      'left': {'spInd': 0,
       'spVal': 0.790312,
       'left': {'spInd': 0,
        'spVal': 0.806158,
        'left': {'spInd': 0,
         'spVal': 0.815215,
         'left': {'spInd': 0,
          'spVal': 0.833026,
          'left': {'spInd': 0,
           'spVal': 0.841547,
           'left': {'spInd': 0,
            'spVal': 0.841625,
            'left': {'spInd': 0,
            

In [18]:
# Leaf-generation functions for model trees

def linearSolve(dataSet):
    """Fit a linear regression with intercept via the normal equations.

    Args:
        dataSet: np.matrix whose last column is the target and whose other
            columns are the features.

    Returns:
        (ws, X, Y): the weight column vector (intercept first), the design
        matrix with a leading column of ones, and the target column.

    Raises:
        NameError: if X.T * X has a determinant of exactly 0.0.
    """
    m,n = dataSet.shape
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    X = np.asmatrix(np.ones((m, n)))
    # Column 0 stays all ones (intercept); the features fill columns 1..n-1.
    X[:,1:n] = dataSet[:,0:n-1]
    Y = dataSet[:,-1]
    xTx = X.T * X 
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
                try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws, X, Y 

def modelLeaf(dataSet):
    """Leaf model for a model tree: the linear-regression weights of dataSet."""
    weights, _, _ = linearSolve(dataSet)
    return weights

def modelErr(dataSet):
    """Error of a model-tree node: residual sum of squares of a linear fit."""
    weights, features, targets = linearSolve(dataSet)
    residuals = targets - features * weights
    return np.sum(np.power(residuals, 2))

In [19]:
# Load exp2.txt (note: this rebinds myMat2, previously the ex2.txt data).
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat2 = np.asmatrix(loadDataSet('exp2.txt'))

In [20]:
# Model tree: leaves are linear-regression weights (modelLeaf) and splits are
# scored by the residual sum of squares (modelErr); tolS=1, tolN=10.
createTree(myMat2, modelLeaf, modelErr, (1, 10))

{'spInd': 0,
 'spVal': 0.285477,
 'left': matrix([[1.69855694e-03],
         [1.19647739e+01]]),
 'right': matrix([[3.46877936],
         [1.18521743]])}

In [21]:
# Code for making predictions with tree regression

def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value itself, as a float.

    inDat is not used here; modelTreeEval takes the same two arguments.
    """
    prediction = float(model)
    return prediction

def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf on one sample.

    Args:
        model: weight column vector with the intercept first.
        inDat: 1 x n row of feature values.

    Returns:
        The linear prediction [1, inDat] * model as a Python float.
    """
    n = inDat.shape[1]
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    X = np.asmatrix(np.ones((1, n+1)))
    # Column 0 stays 1.0 for the intercept term.
    X[:, 1:n+1] = inDat
    # Index the single element explicitly: float() on a 1x1 matrix relies on
    # an array-to-scalar conversion that NumPy has deprecated.
    return float((X * model)[0, 0])

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target for one sample by walking the tree.

    Args:
        tree: dict node produced by createTree, or a leaf.
        inData: feature values of one sample (1 x n matrix, 1-D array or list).
        modelEval: callable (leaf, inData) -> prediction.

    Returns:
        The prediction of the leaf the sample falls into.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Look the feature up in the flattened sample. Indexing a 1 x n matrix
    # with a single integer selects a row, so inData[spInd] only worked for
    # spInd == 0 and raised IndexError for any other feature.
    featVal = np.asarray(inData).ravel()[tree['spInd']]
    # 'left' holds the rows with feature > spVal (see binSplitDataSet).
    branch = tree['left'] if featVal > tree['spVal'] else tree['right']
    # The recursive call evaluates the leaf directly when branch is not a tree.
    return treeForeCast(branch, inData, modelEval)

def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict the target for every row of testData.

    Args:
        tree: tree produced by createTree.
        testData: indexable collection of samples (features only).
        modelEval: leaf evaluator forwarded to treeForeCast.

    Returns:
        An m x 1 np.matrix of predictions, one per sample.
    """
    m = len(testData)
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    yHat = np.asmatrix(np.zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, np.asmatrix(testData[i]), modelEval)
    return yHat

In [22]:
# Load the train/test data (np.asmatrix replaces np.mat, removed in NumPy 2.0).
trainMat = np.asmatrix(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = np.asmatrix(loadDataSet('bikeSpeedVsIq_test.txt'))
# Regression tree
myTree = createTree(trainMat, ops=(1,20))
yHat = createForeCast(myTree, testMat[:,0])
# Correlation between predictions and the true targets.
np.corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]

0.9640852318222141

In [23]:
# Model tree
myTree = createTree(trainMat, modelLeaf, modelErr, ops=(1,20))
yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)
# Correlation between predictions and the true targets.
np.corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]

0.9760412191380593

In [24]:
# Fit a single linear regression on the whole training set for comparison.
ws, X, Y = linearSolve(trainMat)
ws

matrix([[37.58916794],
        [ 6.18978355]])

In [25]:
# Linear-regression predictions for the test set: slope * x + intercept.
# Build a fresh yHat instead of overwriting, row by row, the matrix left over
# from the model-tree cell; this cell no longer depends on that earlier state.
yHat = testMat[:, 0] * ws[1, 0] + ws[0, 0]

In [26]:
# Linear regression result
np.corrcoef(yHat, testMat[:,1], rowvar=0)[0, 1]

0.9434684235674763

# 使用Tkinter创建GUI

In [35]:
import tkinter
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS,tolN):
    """Rebuild the tree with the given parameters and redraw the figure.

    Uses state stored as attributes on the function itself (reDraw.f,
    reDraw.canvas, reDraw.rawDat, reDraw.testDat) and the module-level
    chkBtnVar check-box variable to choose between a model tree and a
    regression tree.

    Args:
        tolS: minimum error reduction required for a split.
        tolN: minimum number of rows on each side of a split.
    """
    reDraw.f.clf()        # clear the figure
    reDraw.a = reDraw.f.add_subplot(111)
    useModelTree = chkBtnVar.get()
    if useModelTree:
        # Model trees are built with tolN of at least 2.
        tolN = max(tolN, 2)
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    xs = reDraw.rawDat[:,0].tolist()
    ys = reDraw.rawDat[:,1].tolist()
    reDraw.a.scatter(xs, ys, s=5)  # use scatter for data set
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # use plot for yHat
    reDraw.canvas.draw()
    
def getInputs():
    """Read tolN and tolS from the Tk entry boxes.

    If an entry cannot be parsed, a message is printed, the entry box is
    reset to its default text and the default value is used (tolN=10,
    tolS=1.0).

    Returns:
        (tolN, tolS): an int and a float.
    """
    # Catch only the parse failure; a bare `except:` would also swallow
    # KeyboardInterrupt and unrelated errors such as widget failures.
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print ("enter Integer for tolN")
        tolNentry.delete(0, tkinter.END)
        tolNentry.insert(0,'10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print ("enter Float for tolS")
        tolSentry.delete(0, tkinter.END)
        tolSentry.insert(0,'1.0')
    return tolN,tolS

def drawNewTree():
    """Button callback: read the entry boxes and redraw with those values."""
    # getInputs returns (tolN, tolS); reDraw expects (tolS, tolN).
    minRows, minErrDrop = getInputs()
    reDraw(minErrDrop, minRows)

In [36]:
# Build the Tk window: a matplotlib canvas on top, parameter entries and
# controls underneath.
root=tkinter.Tk()

reDraw.f = Figure(figsize=(5,4), dpi=100) #create canvas
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

tkinter.Label(root, text="tolN").grid(row=1, column=0)
tolNentry = tkinter.Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
tkinter.Label(root, text="tolS").grid(row=2, column=0)
tolSentry = tkinter.Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
tkinter.Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
chkBtnVar = tkinter.IntVar()
chkBtn = tkinter.Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
reDraw.rawDat = np.asmatrix(loadDataSet('sine.txt'))
# Use the matrix's own min()/max(), which return scalars; the builtin
# min()/max() iterate the column row by row and return 1x1 matrices.
reDraw.testDat = np.arange(reDraw.rawDat[:,0].min(), reDraw.rawDat[:,0].max(), 0.01)
reDraw(1.0, 10)

# Blocks until the window is closed.
root.mainloop()