In [41]:
import numpy as np
import pandas as pd

In [42]:
# Toy regression set: feature values 1..10 and the observed target for each,
# both shaped as single-column matrices.
x = np.arange(1, 11)[:, np.newaxis]
y = np.array([5.56, 5.7, 5.91, 6.4, 6.8, 7.05, 8.9, 8.7, 9.0, 9.05])[:, np.newaxis]
# Side-by-side layout: column 0 is the feature, column 1 is the target.
data = np.hstack((x, y))
data

array([[ 1.  ,  5.56],
       [ 2.  ,  5.7 ],
       [ 3.  ,  5.91],
       [ 4.  ,  6.4 ],
       [ 5.  ,  6.8 ],
       [ 6.  ,  7.05],
       [ 7.  ,  8.9 ],
       [ 8.  ,  8.7 ],
       [ 9.  ,  9.  ],
       [10.  ,  9.05]])

In [43]:
def split_record(mat, axis=0, target=1):
    """Score every candidate split point of one feature (c1, c2 and loss).

    Candidate split points are the midpoints between consecutive distinct
    values of the feature column. For each candidate ``s`` the rows are
    divided into ``feature <= s`` and ``feature > s``; ``c1`` and ``c2`` are
    the mean target values of the two parts, and ``loss`` is the summed
    squared deviation of each part's targets from its own mean.

    Parameters
    ----------
    mat : 2-D array-like
        One sample per row.
    axis : int, default 0
        Column index of the feature to split on.
    target : int, default 1
        Column index of the target value.

    Returns
    -------
    numpy.ndarray of shape (n_points, 4)
        Rows of ``[point, c1, c2, loss]`` in increasing order of split point.
        The shape is ``(0, 4)`` when there are fewer than two distinct
        feature values, so column indexing on the result stays valid.
    """
    mat = np.asarray(mat)
    if mat.size == 0:
        return np.empty((0, 4))

    feature = mat[:, axis]
    labels = mat[:, target]

    # Work from the sorted distinct feature values: the midpoints are then
    # meaningful thresholds even if the rows are unsorted, and repeated values
    # no longer produce a threshold with an empty right-hand side.
    levels = np.unique(feature)
    point_list = (levels[:-1] + levels[1:]) / 2

    record = np.empty((len(point_list), 4))
    for idx, point in enumerate(point_list):
        # Partition on the same column the split points came from
        # (not on a fixed column 0).
        left = labels[feature <= point]
        right = labels[feature > point]
        c1 = left.mean()
        c2 = right.mean()
        the_loss = np.sum((left - c1) ** 2) + np.sum((right - c2) ** 2)
        record[idx] = (point, c1, c2, the_loss)

    return record
        

In [44]:
# Row labels for the summary table: split point, c1, c2, loss.
columns = ['切分点', 'c1', 'c2', 'loss']
# Score every candidate split of the full dataset on feature column 0.
first_record = split_record(data, axis=0)
# Show one column per candidate split and one row per quantity.
pd.DataFrame(first_record, columns=columns).T

Unnamed: 0,0,1,2,3,4,5,6,7,8
切分点,1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5
c1,5.56,5.63,5.723333,5.8925,6.074,6.236667,6.617143,6.8775,7.113333
c2,7.501111,7.72625,7.985714,8.25,8.54,8.9125,8.916667,9.025,9.05
loss,15.723089,12.083388,8.365638,5.775475,3.91132,1.930008,8.00981,11.7354,15.7386


In [45]:
def find_best_split(mat, pre_record, axis=0):
    """Divide ``mat`` into two regions (r1, r2) at the lowest-loss split point.

    ``pre_record`` holds one row per candidate split, with the split point in
    column 0 and its loss in column 3. The first row whose loss equals the
    minimum loss is selected and its split point is printed.

    Returns a tuple ``(r1, r2)``: the rows of ``mat`` whose ``axis`` column is
    ``<=`` the chosen split point, and the rows where it is ``>``.
    """
    losses = pre_record[:, 3]
    lowest = np.min(losses)
    # Keep every row that attains the minimum, then take the first one.
    winners = pre_record[losses == lowest]
    best_point = winners[0][0]
    print('最佳切分点为:', best_point)

    goes_left = mat[:, axis] <= best_point
    goes_right = mat[:, axis] > best_point
    return mat[goes_left], mat[goes_right]

In [46]:
# First split (root of the binary tree): partition the full dataset at the
# lowest-loss split point.
left1, right1 = find_best_split(data, first_record)

# Print both regions, separated by a blank line.
print(left1, right1, sep='\n\n')

最佳切分点为: 6.5
[[1.   5.56]
 [2.   5.7 ]
 [3.   5.91]
 [4.   6.4 ]
 [5.   6.8 ]
 [6.   7.05]]

[[ 7.    8.9 ]
 [ 8.    8.7 ]
 [ 9.    9.  ]
 [10.    9.05]]


In [47]:
# Score every candidate split inside the left region of the first split.
second_record = split_record(left1, axis=0)
# Same layout as the first table: one column per candidate split.
pd.DataFrame(second_record, columns=columns).T

Unnamed: 0,0,1,2,3,4
切分点,1.5,2.5,3.5,4.5,5.5
c1,5.56,5.63,5.723333,5.8925,6.074
c2,6.372,6.54,6.75,6.925,7.05
loss,1.30868,0.754,0.277067,0.436725,1.06432


In [48]:
# Second split: partition the left region at its lowest-loss split point.
left2, right2 = find_best_split(left1, second_record)

# Print both sub-regions, separated by a blank line.
print(left2, right2, sep='\n\n')

最佳切分点为: 3.5
[[1.   5.56]
 [2.   5.7 ]
 [3.   5.91]]

[[4.   6.4 ]
 [5.   6.8 ]
 [6.   7.05]]
