In [97]:
import scipy.stats as st
from sklearn import linear_model
import statistics
import numpy as np
import math
import random

In [98]:
class Tree:
    def __init__(self):
        self.leaves = None
        self.data = None

In [99]:
def createTree(X, y):
    t = Tree()
    t.data = {'X': X, 'y': y}
    
    return t
# 2.2.2 Stopping rule
def shouldSplit(t, params):

    samples = [{'X':[],'y':[]} for i in range(params['V'])]
    
    data_indexes = list(range(len(t.data['X'])))
    random.shuffle(data_indexes)
    
    i = 0
    for i in range(0, len(data_indexes)):
        samples[i % params['V']]['X'].append(t.data['X'][data_indexes[i]])
        samples[i % params['V']]['y'].append(t.data['y'][data_indexes[i]])
    
    print(samples)
    
    return False
        
# 2.2.1 Split selection
def splitNode(t, params):
    
    min_p_value = float('inf')
    
    for k in range(0, len(t.data['X'][0])):
        
        X_k = [t.data['X'][i][k] for i in range(0, len(t.data['X']))]

        class1, class2 = {'X':[], 'y':[]}, {'X':[], 'y':[]}

        #Split selection (1. Model selection)

        slope, intercept,_,_,_ = st.linregress(X_k, t.data['y'])

        #Split selection (2. Residuals)

        for i in range(0, len(t.data['X'])):
            if t.data['y'][i] >= slope * X_k[i] + intercept:
                class1['X'].append(X_k[i])
                class1['y'].append(t.data['y'][i])

            else:
                class2['X'].append(X_k[i])
                class2['y'].append(t.data['y'][i])

        #Split selection (3. Tests for means and variances)

        #(a) Test for difference in means
        test_mean_diff = ( statistics.mean(class1['X']) - statistics.mean(class2['X']) ) / \
                         ( statistics.pstdev(X_k) * math.sqrt(1/len(class1['X']) + 1/len(class2['X'])))

        #(b) Test for difference in variances (Levene, 1960)
        test_var_diff = ( statistics.mean([abs(x - statistics.mean(class1['X'])) for x in class1['X']]) - \
                          statistics.mean([abs(x - statistics.mean(class2['X'])) for x in class2['X']]) ) / \
                        ( statistics.pstdev([abs(x - statistics.mean(X_k)) for x in X_k]) * math.sqrt(1/len(class1['X']) + 1/len(class2['X'])))

        prob_mean_diff = st.t.sf(test_mean_diff, len(t.data['X']) - 2)
        prob_var_diff = st.t.sf(test_var_diff, len(t.data['X']) - 2)

        if prob_mean_diff < min_p_value:
            min_p_value = prob_mean_diff
            k_0 = k
            a = ( statistics.mean(class1['X']) + statistics.mean(class2['X']) ) / 2
        elif prob_var_diff < min_p_value:
            min_p_value = prob_var_diff
            k_0 = k
            a = ( statistics.mean(class1['X']) + statistics.mean(class2['X']) ) / 2
    
    # 2.2.2 Stopping rule
    if not shouldSplit(t, params):
            return False

    #Split selection (4. Variable selection)
    
    left_data = {'X':[], 'y':[]}
    right_data = {'X':[], 'y':[]}
    
    for i in range(0, len(t.data['X'])):
        if t.data['X'][i][k_0] < a:
            left_data['X'].append(t.data['X'][i])
            left_data['y'].append(t.data['y'][i])
        else:
            right_data['X'].append(t.data['X'][i])
            right_data['y'].append(t.data['y'][i])

    splitNode(t.left)
    splitNode(t.right)
    
def fit(X, y, params):
    t = createTree(X, y)
    splitNode(t, params)
    

In [96]:
def f(x1, x2):
    return math.exp(-(x1**2 + x2**2)/2) + 0.2 * st.norm.rvs(0, 1)

X, y = [], []
for _ in range(0, 100):
    x1 = random.random() # x in [0, 1]
    x2 = random.random()
    X.append(np.array([x1,x2]))
    y.append(f(x1,x2))

params = {'f': 0.2,'eta': 0.2,'V': 10,'MINDAT': 3}

fit(X, y, params)

[{'X': [array([0.7815648 , 0.00371351]), array([0.17926235, 0.99304209]), array([0.1946865 , 0.90482227]), array([0.22327127, 0.19681879]), array([0.84303808, 0.58851401]), array([0.58253105, 0.33414342]), array([0.39355976, 0.86957523]), array([0.32154089, 0.39699326]), array([0.66593304, 0.47520731]), array([0.52065876, 0.22346995])], 'y': [0.7897716601106217, 0.186384063087026, 0.3667958672527481, 1.2932760386660127, 0.9097086563172203, 0.8780789968324257, 0.4681683973364131, 1.3496522445516697, 0.5288902627065577, 0.6472539819439996]}, {'X': [array([0.16569516, 0.88157335]), array([0.27466567, 0.80607721]), array([0.02952511, 0.58663867]), array([0.1305847 , 0.46300814]), array([0.16818252, 0.16900663]), array([0.88284107, 0.46386829]), array([0.3567027 , 0.38562896]), array([0.87121935, 0.71930824]), array([0.86969435, 0.88313527]), array([0.53646971, 0.82183353])], 'y': [0.8080101671211241, 0.9587957794907018, 0.7778415064540916, 0.7733549295097484, 1.0458377081787464, 0.74014097