# 自定义回归决策树
算法介绍：CART回归树将样本空间细分为若干个子空间，每个子空间内样本输出$y$（连续值）的均值即为该子空间的预测值。因此当输入$X$为一维时，预测结果可表示为阶梯函数。  
评估方式采用**平方误差**：其中$y_i$为某个数据集中样本的输出，$c$为该数据集上输出向量$y$的均值。
$$
err = \sum(y_i - c)^2
$$

算法过程：  
输入：训练数据集$D$；  
输出：回归树$f(x)$  
在训练数据集所在的输入空间中，递归地将每个区域划分为两个子区域并决定每个子区域上的输出值，构建二叉决策树：  
（1）选择最优切分变量$j$与切分点$s$，求解
$$
\min_{j,s}[\min_{c_1}\sum_{x_i\in R_1(j,s)}(y_i - c_1)^2 + \min_{c_2}\sum_{x_i\in R_2(j, s)}(y_i-c_2)^2]
$$
遍历变量$j$，对固定的切分变量扫描切分点$s$，选择使得上式达到最小值的对$(j,s)$.  
（2）用选定的对$(j,s)$划分区域并决定相应的输出值：
$$
R_1(j,s) = \{x|x^{(j)} \le s \},R_2(j,s) = \{x|x^{(j)} \gt s \} \\
\hat{c}_m = \frac{1}{N_m} \sum_{x_i\in R_m(j, s)} y_i, \quad x \in R_m, \; m=1,2
$$  
（3）继续对两个子区域调用步骤（1）和（2），直至满足停止条件  
（4）将输入空间划分为$M$个区域$R_1, R_2, \dots, R_M$，生成决策树：
$$
f(x) = \sum^M_{m=1}\hat{c}_mI(x\in R_m)
$$

In [15]:
import numpy as np

## 最小二乘损失

In [24]:
def err(dataSet):
    """Return the total squared error of the target column around its mean.

    Args:
        dataSet: 2-D NumPy array whose last column holds the target values.

    Returns:
        The population variance of the last column multiplied by the number
        of rows, i.e. the sum of squared deviations from the column mean.
    """
    # Use the array's own ``shape``: the file only imports ``numpy as np``,
    # so a bare ``shape(...)`` call is not in scope.
    return np.var(dataSet[:, -1]) * dataSet.shape[0]

## 划分数据集

In [17]:
def splitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` by comparing one column to a threshold.

    Args:
        dataSet: 2-D NumPy array holding the current samples.
        feature: Index of the column used for the comparison.
        value: Threshold to compare the column against.

    Returns:
        A pair ``(dataSet1, dataSet2)``: the rows whose ``feature`` column is
        ``<= value`` and the rows whose ``feature`` column is ``> value``.
    """
    column = dataSet[:, feature]
    # Two independent masks: a row failing both comparisons lands in neither part.
    goes_left = column <= value
    goes_right = column > value
    return dataSet[goes_left], dataSet[goes_right]

## 选择最好的特征用于划分数据集

In [50]:
def chooseBestFeature(dataSet, min_sample=4, epsilon=0.5):
    """Find the (feature, threshold) split with the smallest summed squared error.

    Args:
        dataSet: 2-D NumPy array; every column except the last is a feature
            and the last column is the regression target.
        min_sample: Minimum number of rows that each side of a split must keep.
        epsilon: Minimum reduction in squared error required to accept a
            split. Larger values stop splitting earlier and therefore yield
            a shallower tree.

    Returns:
        ``(bestColumn, bestValue)`` for the best admissible split, or
        ``(None, leaf_value)`` when the node should become a leaf, where
        ``leaf_value`` is the mean of the target column.
    """
    targets = dataSet[:, -1]
    leafValue = np.mean(targets)

    # All targets identical: nothing can be gained by splitting.
    if len(np.unique(targets)) == 1:
        return None, leafValue

    sErr = err(dataSet)  # squared error before splitting
    minErr = np.inf
    bestColumn = None
    bestValue = None

    for feature in range(dataSet.shape[1] - 1):
        # Repeated values of a feature produce the same partition, so each
        # distinct threshold is evaluated once, in first-occurrence order
        # (ties are therefore resolved exactly as a row-by-row scan would).
        tried = set()
        for value in dataSet[:, feature]:
            if value in tried:
                continue
            tried.add(value)
            dataSet1, dataSet2 = splitDataSet(dataSet, feature, value)
            # Reject splits that leave too few rows on either side.
            if len(dataSet1) < min_sample or len(dataSet2) < min_sample:
                continue
            nowErr = err(dataSet1) + err(dataSet2)
            if nowErr < minErr:
                minErr, bestColumn, bestValue = nowErr, feature, value

    # No admissible split, or the improvement is too small: make a leaf.
    # Any recorded split already satisfies ``min_sample``, so no re-check
    # of the partition sizes is needed here.
    if bestColumn is None or (sErr - minErr) < epsilon:
        return None, leafValue

    return bestColumn, bestValue

## 创建回归树

In [51]:
def createTree(dataSet, min_sample=4, epsilon=0.5):
    """Recursively build a regression tree from ``dataSet``.

    Args:
        dataSet: 2-D NumPy array with the target in the last column.
        min_sample: Forwarded to ``chooseBestFeature``; minimum rows per side
            of a split.
        epsilon: Forwarded to ``chooseBestFeature``; minimum error reduction
            needed to accept a split.

    Returns:
        Either a leaf value (returned unchanged from ``chooseBestFeature``)
        or a dict with keys ``'spCol'`` (split column), ``'spVal'`` (split
        threshold), ``'left'`` and ``'right'`` (subtrees or leaf values).
    """
    bestColumn, bestValue = chooseBestFeature(dataSet, min_sample, epsilon)
    # ``None`` marks a leaf; column index 0 is a valid split, so the check
    # must be an identity test rather than a truthiness/equality test.
    if bestColumn is None:
        return bestValue

    lSet, rSet = splitDataSet(dataSet, bestColumn, bestValue)
    return {
        'spCol': bestColumn,
        'spVal': bestValue,
        'left': createTree(lSet, min_sample, epsilon),
        'right': createTree(rSet, min_sample, epsilon),
    }

## 剪枝

In [52]:
def prune(tree, testData):
    """Post-prune ``tree`` bottom-up using held-out ``testData``.

    Two sibling leaves are replaced by their average when that single value
    gives a smaller squared error on ``testData`` than the two leaves do.
    The tree dict is modified in place; the (possibly collapsed) node is
    also returned so the caller can re-attach it.

    Args:
        tree: Tree node as a dict with keys ``'spCol'``, ``'spVal'``,
            ``'left'`` and ``'right'``.
        testData: 2-D NumPy array with the target in the last column.

    Returns:
        The pruned node: either the same dict or a single leaf value.
    """
    # No test rows reach this node: collapse it to the mean of its leaves.
    if np.shape(testData)[0] == 0:
        return getMean(tree)

    lSet, rSet = splitDataSet(testData, tree['spCol'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # Only a node with two leaf children is a candidate for merging; any
    # other node must still be returned, otherwise the parent would store
    # ``None`` in its place.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    merged = (tree['left'] + tree['right']) / 2.0
    # Errors are measured against what the tree would actually predict.
    errNoMerge = (np.sum((lSet[:, -1] - tree['left']) ** 2)
                  + np.sum((rSet[:, -1] - tree['right']) ** 2))
    errMerge = np.sum((testData[:, -1] - merged) ** 2)
    if errMerge < errNoMerge:
        print('merging')
        return merged
    return tree
        
def isTree(obj):
    """Return True when ``obj`` is an internal tree node (a dict), else False.

    ``isinstance`` is used instead of comparing the type's name so that dict
    subclasses are recognised and unrelated classes that merely happen to be
    called ``dict`` are not.
    """
    return isinstance(obj, dict)

def getMean(obj):
    """Collapse a (sub)tree into one value by averaging its branches.

    Each internal node is replaced by the plain average of its left and
    right values, applied recursively from the leaves upward. Child entries
    of ``obj`` are overwritten in place with their collapsed values.

    Args:
        obj: A tree node (dict with ``'left'`` and ``'right'`` keys) or a
            leaf value.

    Returns:
        The collapsed value; a leaf is returned unchanged.
    """
    # Same node test as ``isTree``, written inline so that this function
    # has no dependency on other definitions.
    if not isinstance(obj, dict):
        return obj
    obj['left'] = getMean(obj['left'])
    obj['right'] = getMean(obj['right'])
    return (obj['left'] + obj['right']) / 2.0

## 预测

In [53]:
def predictSample(Tree, testData):
    """Predict the target for a single sample by walking the tree.

    Args:
        Tree: A tree node (dict with ``'spCol'``, ``'spVal'``, ``'left'``,
            ``'right'``) or a leaf value.
        testData: One sample, either as a 1-D sequence of feature values or
            as a 2-D array/matrix with a single row.

    Returns:
        The leaf value reached by the sample, as a ``float``.
    """
    # Flatten once so both a 1-D row and a 1 x n row can be indexed by column.
    row = np.asarray(testData).ravel()
    node = Tree
    while isinstance(node, dict):
        # Training sends rows with ``feature <= spVal`` to the left branch,
        # so prediction must use the same inclusive comparison.
        if row[node['spCol']] <= node['spVal']:
            node = node['left']
        else:
            node = node['right']
    return float(node)
        
def predict(Tree, testData):
    """Predict the target for every row of ``testData``.

    Args:
        Tree: Tree node dict or leaf value, passed on to ``predictSample``.
        testData: 2-D array-like with one sample per row; a 1-D input is
            treated as a single sample.

    Returns:
        A NumPy array of shape ``(m, 1)`` holding one prediction per row.
    """
    samples = np.atleast_2d(np.asarray(testData))
    m = samples.shape[0]
    # Only ``numpy as np`` is imported in this file, so the array is built
    # through the ``np`` namespace instead of bare ``mat``/``zeros`` names.
    y_pred = np.zeros((m, 1))

    for i in range(m):
        # Slice (not index) so each sample stays a 1 x n row.
        y_pred[i, 0] = predictSample(Tree, samples[i:i + 1])

    return y_pred

## 运行

In [54]:
# 导入数据集
def loadData(filaName):
    """Read a tab-separated numeric text file into a list of float rows.

    Args:
        filaName: Path of the file to read.

    Returns:
        A list with one entry per non-blank line; each entry is the list of
        that line's tab-separated fields converted to ``float``.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    dataSet = []
    # The context manager closes the file even if a conversion fails.
    with open(filaName) as fr:
        for line in fr:
            stripped = line.strip()
            # A blank line would otherwise yield [''] and crash in float('').
            if not stripped:
                continue
            dataSet.append([float(item) for item in stripped.split('\t')])
    return dataSet

In [60]:
from sklearn.datasets import load_iris
# Join the iris feature matrix and its label vector into one array,
# with the labels appended as the last column.
iris = load_iris()
x, y = iris.data, iris.target
dataSet = np.column_stack((x, y))
print(dataSet)

[[5.1 3.5 1.4 0.2 0. ]
 [4.9 3.  1.4 0.2 0. ]
 [4.7 3.2 1.3 0.2 0. ]
 [4.6 3.1 1.5 0.2 0. ]
 [5.  3.6 1.4 0.2 0. ]
 [5.4 3.9 1.7 0.4 0. ]
 [4.6 3.4 1.4 0.3 0. ]
 [5.  3.4 1.5 0.2 0. ]
 [4.4 2.9 1.4 0.2 0. ]
 [4.9 3.1 1.5 0.1 0. ]
 [5.4 3.7 1.5 0.2 0. ]
 [4.8 3.4 1.6 0.2 0. ]
 [4.8 3.  1.4 0.1 0. ]
 [4.3 3.  1.1 0.1 0. ]
 [5.8 4.  1.2 0.2 0. ]
 [5.7 4.4 1.5 0.4 0. ]
 [5.4 3.9 1.3 0.4 0. ]
 [5.1 3.5 1.4 0.3 0. ]
 [5.7 3.8 1.7 0.3 0. ]
 [5.1 3.8 1.5 0.3 0. ]
 [5.4 3.4 1.7 0.2 0. ]
 [5.1 3.7 1.5 0.4 0. ]
 [4.6 3.6 1.  0.2 0. ]
 [5.1 3.3 1.7 0.5 0. ]
 [4.8 3.4 1.9 0.2 0. ]
 [5.  3.  1.6 0.2 0. ]
 [5.  3.4 1.6 0.4 0. ]
 [5.2 3.5 1.5 0.2 0. ]
 [5.2 3.4 1.4 0.2 0. ]
 [4.7 3.2 1.6 0.2 0. ]
 [4.8 3.1 1.6 0.2 0. ]
 [5.4 3.4 1.5 0.4 0. ]
 [5.2 4.1 1.5 0.1 0. ]
 [5.5 4.2 1.4 0.2 0. ]
 [4.9 3.1 1.5 0.2 0. ]
 [5.  3.2 1.2 0.2 0. ]
 [5.5 3.5 1.3 0.2 0. ]
 [4.9 3.6 1.4 0.1 0. ]
 [4.4 3.  1.3 0.2 0. ]
 [5.1 3.4 1.5 0.2 0. ]
 [5.  3.5 1.3 0.3 0. ]
 [4.5 2.3 1.3 0.3 0. ]
 [4.4 3.2 1.3 0.2 0. ]
 [5.  3.5 1

In [62]:
# Grow a regression tree on the full dataset and show its nested-dict form.
mytree = createTree(dataSet)
print('mytree\n', mytree)

mytree
 {'spCol': 2, 'spVal': 1.9, 'left': 0.0, 'right': {'spCol': 3, 'spVal': 1.7, 'left': {'spCol': 2, 'spVal': 4.9, 'left': 1.0208333333333333, 'right': 1.6666666666666667}, 'right': 1.9782608695652173}}


## 封装

In [63]:
import numpy as np
from sklearn.datasets import load_iris

In [71]:
class SimpleRegressionTree():
    """Least-squares regression tree stored as nested dicts.

    Internal nodes are dicts with the keys ``'spCol'`` (split column),
    ``'spVal'`` (split threshold), ``'left'`` and ``'right'``; leaves are
    plain numbers. Data sets are 2-D NumPy arrays whose last column is the
    regression target.
    """

    def __init__(self):
        # Tree produced by the most recent ``createTree`` call (None before that).
        self.tree = None

    def err(self, dataSet):
        """Return the sum of squared deviations of the target column from its mean."""
        return np.var(dataSet[:, -1]) * dataSet.shape[0]

    def splitDataSet(self, dataSet, feature, value):
        """Split ``dataSet`` on one column.

        Args:
            dataSet: Current 2-D data array.
            feature: Index of the column to split on.
            value: Split threshold.

        Returns:
            ``(dataSet1, dataSet2)``: rows with ``feature <= value`` and rows
            with ``feature > value``.
        """
        dataSet1 = dataSet[dataSet[:, feature] <= value]  # left part
        dataSet2 = dataSet[dataSet[:, feature] > value]  # right part
        return dataSet1, dataSet2

    def chooseBestFeature(self, dataSet, min_sample=4, epsilon=0.5):
        """Find the (feature, threshold) split with the smallest squared error.

        Args:
            dataSet: Current 2-D data array (target in the last column).
            min_sample: Minimum number of rows each side of a split must keep.
            epsilon: Minimum error reduction required to accept a split.
                Larger values stop splitting earlier, giving a shallower tree.

        Returns:
            ``(bestColumn, bestValue)`` for the best admissible split, or
            ``(None, leaf_value)`` when the node should become a leaf, where
            ``leaf_value`` is the mean of the target column.
        """
        targets = dataSet[:, -1]
        leafValue = np.mean(targets)

        # All targets identical: splitting cannot reduce the error.
        if len(np.unique(targets)) == 1:
            return None, leafValue

        sErr = self.err(dataSet)  # squared error before splitting
        minErr = np.inf
        bestColumn = None
        bestValue = None

        for feature in range(dataSet.shape[1] - 1):
            # Evaluate each distinct threshold once, in first-occurrence
            # order, because repeated values give identical partitions.
            tried = set()
            for value in dataSet[:, feature]:
                if value in tried:
                    continue
                tried.add(value)
                dataSet1, dataSet2 = self.splitDataSet(dataSet, feature, value)
                # Skip splits that leave too few rows on either side.
                if len(dataSet1) < min_sample or len(dataSet2) < min_sample:
                    continue
                nowErr = self.err(dataSet1) + self.err(dataSet2)
                if nowErr < minErr:
                    minErr, bestColumn, bestValue = nowErr, feature, value

        # No admissible split, or too little improvement: make a leaf.
        if bestColumn is None or (sErr - minErr) < epsilon:
            return None, leafValue

        return bestColumn, bestValue

    def createTree(self, dataSet, min_sample=4, epsilon=0.5):
        """Build a tree from ``dataSet``, remember it, and return it.

        Args:
            dataSet: 2-D data array with the target in the last column.
            min_sample: Forwarded to ``chooseBestFeature``.
            epsilon: Forwarded to ``chooseBestFeature``.

        Returns:
            The tree (nested dict) or a single leaf value.
        """
        self.tree = self._build(dataSet, min_sample, epsilon)
        return self.tree

    def _build(self, dataSet, min_sample, epsilon):
        """Recursive worker for ``createTree``; does not touch ``self.tree``."""
        bestColumn, bestValue = self.chooseBestFeature(dataSet, min_sample, epsilon)
        # ``None`` marks a leaf; column 0 is a valid split, hence ``is None``.
        if bestColumn is None:
            return bestValue
        lSet, rSet = self.splitDataSet(dataSet, bestColumn, bestValue)
        return {
            'spCol': bestColumn,
            'spVal': bestValue,
            'left': self._build(lSet, min_sample, epsilon),
            'right': self._build(rSet, min_sample, epsilon),
        }

    def prune(self, tree, testData):
        """Post-prune ``tree`` bottom-up using held-out ``testData``.

        Two sibling leaves are replaced by their average when that single
        value gives a smaller squared error on ``testData`` than the two
        leaves do. ``tree`` is modified in place and the (possibly
        collapsed) node is returned; ``self.tree`` is not reassigned.

        Args:
            tree: Tree node dict.
            testData: 2-D data array with the target in the last column.

        Returns:
            The pruned node: the same dict or a single leaf value.
        """
        # No test rows reach this node: collapse it to the mean of its leaves.
        if np.shape(testData)[0] == 0:
            return self.getMean(tree)

        lSet, rSet = self.splitDataSet(testData, tree['spCol'], tree['spVal'])
        if self.isTree(tree['left']):
            tree['left'] = self.prune(tree['left'], lSet)
        if self.isTree(tree['right']):
            tree['right'] = self.prune(tree['right'], rSet)

        # Nodes that still have a subtree child cannot be merged, but must
        # be returned so the parent does not store ``None`` in their place.
        if self.isTree(tree['left']) or self.isTree(tree['right']):
            return tree

        merged = (tree['left'] + tree['right']) / 2.0
        # Errors are measured against what the tree would actually predict.
        errNoMerge = (np.sum((lSet[:, -1] - tree['left']) ** 2)
                      + np.sum((rSet[:, -1] - tree['right']) ** 2))
        errMerge = np.sum((testData[:, -1] - merged) ** 2)
        if errMerge < errNoMerge:
            print('merging')
            return merged
        return tree

    def isTree(self, obj):
        """Return True when ``obj`` is an internal node (a dict)."""
        return isinstance(obj, dict)

    def getMean(self, obj):
        """Collapse a (sub)tree to one value by recursively averaging its branches.

        Child entries of ``obj`` are overwritten in place with their
        collapsed values; a leaf is returned unchanged.
        """
        if not self.isTree(obj):
            return obj
        obj['left'] = self.getMean(obj['left'])
        obj['right'] = self.getMean(obj['right'])
        return (obj['left'] + obj['right']) / 2.0

    def loadData(self, filaName):
        """Read a tab-separated numeric text file into a list of float rows.

        Blank lines are skipped. Raises ``ValueError`` if a field cannot be
        converted to ``float``.
        """
        dataSet = []
        # The context manager closes the file even if a conversion fails.
        with open(filaName) as fr:
            for line in fr:
                stripped = line.strip()
                if not stripped:
                    continue
                dataSet.append([float(item) for item in stripped.split('\t')])
        return dataSet

    def getTree(self):
        """Return the tree built by the last ``createTree`` call, or None."""
        return self.tree

In [73]:
# Join the iris feature matrix and its label vector into one array,
# with the labels appended as the last column.
iris = load_iris()
x, y = iris.data, iris.target
dataSet = np.column_stack((x, y))
print(dataSet)

# Build the tree through the class-based interface and show the result.
mytree = SimpleRegressionTree()
tree = mytree.createTree(dataSet)
print('mytree\n', tree)

[[5.1 3.5 1.4 0.2 0. ]
 [4.9 3.  1.4 0.2 0. ]
 [4.7 3.2 1.3 0.2 0. ]
 [4.6 3.1 1.5 0.2 0. ]
 [5.  3.6 1.4 0.2 0. ]
 [5.4 3.9 1.7 0.4 0. ]
 [4.6 3.4 1.4 0.3 0. ]
 [5.  3.4 1.5 0.2 0. ]
 [4.4 2.9 1.4 0.2 0. ]
 [4.9 3.1 1.5 0.1 0. ]
 [5.4 3.7 1.5 0.2 0. ]
 [4.8 3.4 1.6 0.2 0. ]
 [4.8 3.  1.4 0.1 0. ]
 [4.3 3.  1.1 0.1 0. ]
 [5.8 4.  1.2 0.2 0. ]
 [5.7 4.4 1.5 0.4 0. ]
 [5.4 3.9 1.3 0.4 0. ]
 [5.1 3.5 1.4 0.3 0. ]
 [5.7 3.8 1.7 0.3 0. ]
 [5.1 3.8 1.5 0.3 0. ]
 [5.4 3.4 1.7 0.2 0. ]
 [5.1 3.7 1.5 0.4 0. ]
 [4.6 3.6 1.  0.2 0. ]
 [5.1 3.3 1.7 0.5 0. ]
 [4.8 3.4 1.9 0.2 0. ]
 [5.  3.  1.6 0.2 0. ]
 [5.  3.4 1.6 0.4 0. ]
 [5.2 3.5 1.5 0.2 0. ]
 [5.2 3.4 1.4 0.2 0. ]
 [4.7 3.2 1.6 0.2 0. ]
 [4.8 3.1 1.6 0.2 0. ]
 [5.4 3.4 1.5 0.4 0. ]
 [5.2 4.1 1.5 0.1 0. ]
 [5.5 4.2 1.4 0.2 0. ]
 [4.9 3.1 1.5 0.2 0. ]
 [5.  3.2 1.2 0.2 0. ]
 [5.5 3.5 1.3 0.2 0. ]
 [4.9 3.6 1.4 0.1 0. ]
 [4.4 3.  1.3 0.2 0. ]
 [5.1 3.4 1.5 0.2 0. ]
 [5.  3.5 1.3 0.3 0. ]
 [4.5 2.3 1.3 0.3 0. ]
 [4.4 3.2 1.3 0.2 0. ]
 [5.  3.5 1