# Chapter 9 Tree Regression

## 9.2 连续和离散型特征的树的构建

In [None]:
# -*- coding: utf-8 -*-
import numpy as np

In [None]:
# This cell makes the code below runnable; these functions are defined again in section 9.3, so ignore this cell for now.
def regLeaf(dataSet):
    """Return the leaf value of a regression tree: the mean of the last column."""
    targets = dataSet[:, -1]
    return np.mean(targets)

def regErr(dataSet):
    """Return the total squared error of the last column.

    Computed as the variance of the last column times the number of rows.
    """
    rowCount = np.shape(dataSet)[0]
    return np.var(dataSet[:, -1]) * rowCount

In [None]:
def loadDataSet(fileName):
    '''
    Read a TAB-separated text file and return its rows as lists of floats.

    Each non-blank line is stripped, split on TAB, and every field is
    converted with float(). Blank lines are skipped so that a trailing empty
    line does not abort the load. The file is closed when reading finishes,
    even if a field fails to convert.

    Returns a list of rows, each row being a list of float.
    Raises ValueError if a field cannot be parsed as a float.
    '''
    dataMat = []
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """Split the rows of dataSet in two by thresholding one column.

    Returns (mat0, mat1): mat0 holds the rows whose `feature` column is
    strictly greater than `value`, mat1 the rows where it is less than or
    equal to `value`.
    """
    column = dataSet[:, feature]
    aboveRows = np.nonzero(column > value)[0]
    belowRows = np.nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[belowRows, :]

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree of nested dicts from dataSet.

    chooseBestSplit picks the split; when it reports no feature (None) the
    value it returns is used directly as a leaf. Otherwise the node is a dict
    with keys 'spInd' (feature index), 'spVal' (threshold), 'left' (subtree
    for rows with feature > threshold) and 'right' (subtree for the rest).

    leafType, errType and ops are passed through unchanged to
    chooseBestSplit and to the recursive calls.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity check: feature index 0 is a valid split and must not be
    # confused with "no split".
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

In [None]:
# Quick check of binSplitDataSet on a 4x4 identity matrix; print or display
# mat0 and mat1 to inspect the two halves.
# np.asmatrix is used because np.mat was removed in NumPy 2.0.
testMat = np.asmatrix(np.eye(4))
mat0, mat1 = binSplitDataSet(testMat, 1, 0.5)

## 9.3 将CART算法用于回归

In [None]:
def regLeaf(dataSet):
    """Leaf model for regression trees: the mean of the target (last) column."""
    lastColumn = dataSet[:, -1]
    return np.mean(lastColumn)

def regErr(dataSet):
    """Error measure for regression trees: variance of the last column times row count."""
    sampleCount = np.shape(dataSet)[0]
    lastColumn = dataSet[:, -1]
    return np.var(lastColumn) * sampleCount

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Choose the best feature and splitting value for dataSet.

    ops[0] is the minimum error reduction a split must achieve and ops[1] the
    minimum number of rows each side of a split must keep.

    Returns (featureIndex, splitValue) for the split with the lowest summed
    errType over both halves, or (None, leafType(dataSet)) when the node
    should become a leaf: all targets are identical, the error reduction is
    below ops[0], or a side of the chosen split has fewer than ops[1] rows.
    """
    minGain = ops[0]
    minRows = ops[1]

    # Nothing to split when every target value is the same.
    targetValues = dataSet[:, -1].T.tolist()[0]
    if len(set(targetValues)) == 1:
        return None, leafType(dataSet)

    columnCount = np.shape(dataSet)[1]
    baseError = errType(dataSet)
    lowestError, chosenFeat, chosenVal = np.inf, 0, 0

    # Exhaustively try every distinct value of every feature column
    # (the last column is the target and is excluded).
    for feat in range(columnCount - 1):
        candidates = set(dataSet[:, feat].T.tolist()[0])
        for candidate in candidates:
            upper, lower = binSplitDataSet(dataSet, feat, candidate)
            if min(np.shape(upper)[0], np.shape(lower)[0]) < minRows:
                continue
            splitError = errType(upper) + errType(lower)
            if splitError < lowestError:
                lowestError, chosenFeat, chosenVal = splitError, feat, candidate

    # Give up when the improvement does not reach the tolerance.
    if (baseError - lowestError) < minGain:
        return None, leafType(dataSet)

    # Re-check the size constraint on the split that was finally chosen.
    upper, lower = binSplitDataSet(dataSet, chosenFeat, chosenVal)
    if np.shape(upper)[0] < minRows or np.shape(lower)[0] < minRows:
        return None, leafType(dataSet)
    return chosenFeat, chosenVal

In [None]:
# Build a regression tree with the default ops on ex00.txt.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myDat = loadDataSet('./data/ex00.txt')
createTree(np.asmatrix(myDat))

In [None]:
# Build a regression tree with the default ops on ex0.txt.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myDat1 = loadDataSet('./data/ex0.txt')
myMat1 = np.asmatrix(myDat1)
createTree(myMat1)

## 9.4 Tree Pruning

In [None]:
# ops=(0,1): no minimum error reduction and a minimum of one row per side,
# i.e. the least restrictive pre-pruning settings.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
createTree(np.asmatrix(myDat), ops=(0,1))

In [None]:
# Build a regression tree with the default ops on ex2.txt.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myDat2 = loadDataSet('./data/ex2.txt')
myMat2 = np.asmatrix(myDat2)
createTree(myMat2)

In [None]:
def isTree(obj):
    """Return True when obj is an internal tree node (a dict), False for a leaf.

    Uses isinstance so that dict subclasses are also recognised as nodes.
    """
    return isinstance(obj, dict)

def getMean(tree):
    """Collapse a subtree into one value: the average of its two branches.

    Any branch that is itself a tree is first collapsed recursively, and the
    result is written back into `tree`, so the argument is modified in place.
    """
    for side in ('right', 'left'):
        branch = tree[side]
        if isTree(branch):
            tree[side] = getMean(branch)
    return (tree['left'] + tree['right']) / 2.0

def prune(tree, testData):
    """Post-prune a regression tree against testData.

    Subtrees are pruned first, with the test rows routed by each node's
    split. When both children of a node are leaves, the squared error on
    testData of keeping the split is compared with that of replacing the
    node by the average of its two leaves; the node is merged (and "merging"
    is printed) when merging gives the smaller error. With no test rows the
    subtree is collapsed via getMean.

    `tree` is modified in place. Returns the pruned tree, or a single value
    when the whole node was merged.
    """
    if np.shape(testData)[0] == 0:
        return getMean(tree)

    # Route the test rows down only if there is a subtree to descend into.
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)

    # A child is still a subtree after pruning: nothing to merge here.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    leftError = np.sum(np.power(lSet[:, -1] - tree['left'], 2))
    rightError = np.sum(np.power(rSet[:, -1] - tree['right'], 2))
    errorNoMerge = leftError + rightError
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree

In [None]:
# Grow a tree with the least restrictive ops, then post-prune it.
myTree = createTree(myMat2, ops=(0,1))
# Load the test data.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myDatTest = loadDataSet('./data/ex2test.txt')
myMat2Test = np.asmatrix(myDatTest)
prune(myTree, myMat2Test)

## 9.5 Model Tree

In [None]:
def linearSolve(dataSet):
    """Fit ordinary least squares to dataSet via the normal equations.

    The last column of dataSet is the target Y; the remaining columns are the
    features. X is the feature block with a leading column of ones for the
    intercept, so ws[0] is the intercept.

    Returns (ws, X, Y) with ws = (X^T X)^-1 X^T Y.
    Raises NameError when X^T X has a determinant of exactly 0.0.
    """
    # np.asmatrix is used because np.mat was removed in NumPy 2.0; it also
    # lets plain 2-D arrays be passed in.
    dataSet = np.asmatrix(dataSet)
    m, n = np.shape(dataSet)
    X = np.asmatrix(np.ones((m, n)))
    X[:, 1:n] = dataSet[:, 0:n-1]   # column 0 stays all ones (intercept)
    Y = dataSet[:, -1]
    xTx = X.T * X
    if np.linalg.det(xTx) == 0.0:
        raise NameError("This matrix is singular, cannot do inverse,\ntry increasing the second value of ops")
    ws = xTx.I * (X.T * Y)
    return ws, X, Y
def modelLeaf(dataSet):
    """Leaf model for model trees: the linear coefficients fitted by linearSolve."""
    coefficients = linearSolve(dataSet)[0]
    return coefficients

def modelErr(dataSet):
    """Error measure for model trees: sum of squared residuals of the linear fit."""
    coefficients, design, targets = linearSolve(dataSet)
    residuals = targets - design * coefficients
    return np.sum(np.power(residuals, 2))

In [None]:
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
myMat2 = np.asmatrix(loadDataSet('./data/exp2.txt'))

In [None]:
# Build a model tree: linear-model leaves (modelLeaf) scored by modelErr,
# with ops=(1,10).
createTree(myMat2, modelLeaf, modelErr, (1,10))

In [None]:
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the prediction is the leaf value itself.

    inDat is not used; it is accepted so the signature matches modelTreeEval.
    """
    prediction = float(model)
    return prediction

def modelTreeEval(model, inDat):
    """Evaluate a model-tree leaf: apply the leaf's linear model to inDat.

    inDat is a single row of n feature values (shape (1, n)); model is the
    column of n+1 coefficients with the intercept first. A leading 1 is
    prepended to the row before multiplying.

    Returns the prediction as a Python float.
    """
    n = np.shape(inDat)[1]
    # np.asmatrix is used because np.mat was removed in NumPy 2.0.
    X = np.asmatrix(np.ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    # Read the single element explicitly: calling float() on a 1x1 matrix
    # relies on array-to-scalar conversion that NumPy has deprecated.
    return float((X * model)[0, 0])

def treeForecast(tree, inData, modelEval=regTreeEval):
    """Predict the target for one sample by walking the tree.

    At each node the sample's value for feature tree['spInd'] is compared
    with tree['spVal']: greater goes to 'left', otherwise 'right'. When a
    leaf is reached, modelEval(leaf, inData) produces the prediction.

    inData is one sample; it may be a 1xn matrix, a 1-D array or a list.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1xn matrix, inData[k] selects row k, not
    # feature k, which fails for any split on a feature other than the first.
    featureValue = np.asarray(inData).ravel()[tree['spInd']]
    branch = tree['left'] if featureValue > tree['spVal'] else tree['right']
    # The recursive call handles both cases: a subtree is descended into and
    # a leaf is evaluated by the guard at the top.
    return treeForecast(branch, inData, modelEval)

def createForecast(tree, testData, modelEval=regTreeEval):
    """Predict the target for every sample in testData.

    Each testData[i] is converted to a matrix and passed to treeForecast.
    Returns an (m, 1) matrix of predictions, m being len(testData).
    """
    m = len(testData)
    # np.asmatrix is used because np.mat was removed in NumPy 2.0.
    yHat = np.asmatrix(np.zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForecast(tree, np.asmatrix(testData[i]), modelEval)
    return yHat

In [None]:
# Regression tree on the bike-speed-vs-IQ data: fit on the training file,
# then correlate the predictions with the true test targets.
# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
trainMat = np.asmatrix(loadDataSet('./data/bikeSpeedVsIq_train.txt'))
testMat = np.asmatrix(loadDataSet('./data/bikeSpeedVsIq_test.txt'))
myTree = createTree(trainMat, ops=(1,20))
yHat = createForecast(myTree, testMat[:,0])
np.corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

In [None]:
# Same comparison with a model tree (linear-model leaves); the last line is
# the correlation between predictions and the true test targets.
myModelTree = createTree(trainMat, modelLeaf, modelErr, ops=(1,20))
yModelTreeHat = createForecast(myModelTree, testMat[:,0], modelTreeEval)
np.corrcoef(yModelTreeHat, testMat[:,1], rowvar=0)[0,1]

In [None]:
# Baseline: a single linear regression over the whole training set.
# ws[0] is the intercept, ws[1] the slope.
ws,X,Y = linearSolve(trainMat)
ws

In [None]:
# Predict with the plain linear model (slope * x + intercept) for every test
# row at once; this replaces the per-row loop and avoids np.mat, which was
# removed in NumPy 2.0. The result is still an (m, 1) matrix.
yHat = testMat[:, 0] * ws[1, 0] + ws[0, 0]
np.corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]

## 9.7 使用Python的Tkinter库创建GUI

In [None]:
import tkinter as tk

In [None]:
import matplotlib

matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS, tolN):
    """Rebuild the tree with the given tolerances and redraw the plot.

    State is kept as attributes on the function itself: reDraw.f (figure),
    reDraw.a (axes), reDraw.canvas, reDraw.rawDat (training data) and
    reDraw.testDat (x values to predict at). When the "Model Tree" checkbox
    (chkBtnVar) is set, a model tree is built and tolN is raised to at least
    2; otherwise a regression tree is built.
    """
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)

    if chkBtnVar.get():
        if tolN < 2:
            tolN = 2
        myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS, tolN))
        yHat = createForecast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForecast(myTree, reDraw.testDat)

    # Raw samples as points, the tree's prediction as a line.
    reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)

    # FigureCanvasTkAgg.show() was deprecated and later removed from
    # matplotlib; draw() is the replacement.
    reDraw.canvas.draw()
    
    
def getInput():
    """Read tolN and tolS from their entry widgets.

    tolN is parsed as an int and tolS as a float. When a field does not
    parse, its default (10 for tolN, 1.0 for tolS) is used, a hint is
    printed, and that same field is reset to show the default.

    Returns (tolN, tolS).
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, tk.END)
        tolNentry.insert(0, "10")
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        # Reset the tolS field itself; the tolN field must stay untouched.
        tolSentry.delete(0, tk.END)
        tolSentry.insert(0, "1.0")
    return tolN, tolS

def drawNewTree():
    """Button callback: read the current tolerances and redraw the plot."""
    minSamples, minErrorDrop = getInput()
    reDraw(tolS=minErrorDrop, tolN=minSamples)

If Python crashes when you run the code below, try reinstalling matplotlib. :)

In [None]:
# Build the Tk window: plot canvas on row 0, tolN/tolS entries and the
# ReDraw button below it, then the "Model Tree" checkbox.
root = tk.Tk()

#tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)

reDraw.f = Figure(figsize=(5,4), dpi=100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
# FigureCanvasTkAgg.show() was deprecated and later removed from matplotlib;
# draw() is the replacement.
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)


tk.Label(root, text="tolN").grid(row=1, column=0)
tolNentry = tk.Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0, '10')
tk.Label(root, text="tolS").grid(row=2, column=0)
tolSentry = tk.Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0, '1.0')
tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3)

chkBtnVar = tk.IntVar()
chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)

# np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
reDraw.rawDat = np.asmatrix(loadDataSet('./data/sine.txt'))
# Predict on an evenly spaced grid (step 0.01) over the range of the first column.
reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01)

reDraw(1.0, 10)

root.mainloop()