In [1]:
import sys
from equadratures import polytree
import random
import numpy as np
import scipy.stats as st
import time
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.lines as mlines
%matplotlib inline
from copy import deepcopy

In [2]:
def f(x, noise):
    
    if noise:
        noise = 0.1 * st.norm.rvs(0, 1)
    else:
        noise = 0
    if x < 0.5:
        return 25*(x-0.25)**2 - 1.0625 + noise
    elif x > 0.5:
        return 25*(x-0.75)**2 - 1.0625 + noise
    
def sample():
    X, y = [], []
    for i in range(100):
        x = random.random()    
        X.append(np.array([x]))
        y.append(np.array([f(x, True)]))
    return np.array(X), np.array(y)

X, y = sample()
y = np.reshape(y, (y.shape[0], 1))

tree = polytree.PolyTree(min_samples_leaf=20, logging=True)
#tree = polytree.PolyTree(search="uniform", min_samples_leaf=20, logging=True)
start = time.time()
tree.fit(X, y)
duration = time.time() - start

In [3]:
tree.log

[{'event': 'try_split',
  'data': {'j_feature': 0,
   'threshold': 0.30536724988884856,
   'loss': 29.173221338828025,
   'poly_left': <equadratures.poly.Poly at 0x7fc92681ae80>,
   'poly_right': <equadratures.poly.Poly at 0x7fc926823640>}},
 {'event': 'try_split',
  'data': {'j_feature': 0,
   'threshold': 0.3138853392879769,
   'loss': 28.549496897169906,
   'poly_left': <equadratures.poly.Poly at 0x7fc926823760>,
   'poly_right': <equadratures.poly.Poly at 0x7fc9268370d0>}},
 {'event': 'try_split',
  'data': {'j_feature': 0,
   'threshold': 0.31513066305143156,
   'loss': 28.11544952889523,
   'poly_left': <equadratures.poly.Poly at 0x7fc926837490>,
   'poly_right': <equadratures.poly.Poly at 0x7fc9268376d0>}},
 {'event': 'try_split',
  'data': {'j_feature': 0,
   'threshold': 0.3157364684495968,
   'loss': 27.390030610190415,
   'poly_left': <equadratures.poly.Poly at 0x7fc92683e100>,
   'poly_right': <equadratures.poly.Poly at 0x7fc92683e3d0>}},
 {'event': 'try_split',
  'data': {

In [None]:
fig, ax = plt.subplots(figsize=(10,10))
ax.set(xlim=(0, 1), ylim=(-1.5, 1))
ax.scatter(X, y, s=100)

try_line = []
try_line.append(ax.axhline(1))

best_line = []
best_line.append(ax.axhline(1))

fixed_line = []
fixed_line.append(ax.axhline(1))

left_try_poly, = ax.plot([], [], lw=2, color="red")
right_try_poly, = ax.plot([], [], lw=2, color="red")

left_best_poly, = ax.plot([], [], lw=2, color="green")
right_best_poly, = ax.plot([], [], lw=2, color="green")

left_fixed_poly, = ax.plot([], [], lw=2, color="black")
right_fixed_poly, = ax.plot([], [], lw=2,color="black")

fixed_polys = []
splits = [0, 1]
dirs = []

def boundary(threshold, direction):
    if direction == "left":
        left = max([split for split in splits if split < threshold])
        return left
    elif direction == "right":
        right = min([split for split in splits if split > threshold])
        return right
n = 0
def animate(log):
    print(splits)
    global n
    if log["event"] == "UP":
        try:fixed_polys.pop()
        except:pass
        try:splits.pop()
        except:pass
        try:best_line[-1].remove()
        except:pass
        try:dirs.pop()
        except:pass
        
    if log["event"] == "DOWN":
        splits.append(log["data"]["threshold"])

        fixed_line.append(ax.axvline(log["data"]["threshold"], color='black'))
        
        if log["data"]["direction"] == "LEFT":
            fixed_polys.append(ax.plot(right_fixed_poly.get_xdata(),right_fixed_poly.get_ydata(),color="black"))
            right_fixed_poly.set_data(deepcopy(right_best_poly.get_xdata()), deepcopy(right_best_poly.get_ydata()))
            dirs.append("left")
            
        if log["data"]["direction"] == "RIGHT":
            fixed_polys.append(ax.plot(left_fixed_poly.get_xdata(),left_fixed_poly.get_ydata(),color="black"))
            left_fixed_poly.set_data(deepcopy(left_best_poly.get_xdata()), deepcopy(left_best_poly.get_ydata()))
            dirs.append("right")
            
    if log["event"] == "try_split":
        try_line[-1].remove()
        try_line.append(ax.axvline(log["data"]["threshold"], color='red'))
 
        x_test = np.reshape(np.linspace(boundary(log["data"]["threshold"], "left"), log["data"]["threshold"], 100), (100, 1))
        left_try_poly.set_data(x_test, log["data"]["poly_left"].get_polyfit(x_test))
                        
        x_test = np.reshape(np.linspace(log["data"]["threshold"], boundary(log["data"]["threshold"], "right"), 100), (100, 1))
        right_try_poly.set_data(x_test, log["data"]["poly_right"].get_polyfit(x_test))
        
        n+=1
    if log["event"] == "best_split":
        try:best_line[-1].remove()
        except:pass
        
        best_line.append(ax.axvline(log["data"]["threshold"], color='green'))
         
        x_test = np.reshape(np.linspace(boundary(log["data"]["threshold"], "left"), log["data"]["threshold"], 100), (100, 1))
        left_best_poly.set_data(x_test, log["data"]["poly_left"].get_polyfit(x_test))
                        
        x_test = np.reshape(np.linspace(log["data"]["threshold"], boundary(log["data"]["threshold"], "right"), 100), (100, 1))
        right_best_poly.set_data(x_test, log["data"]["poly_right"].get_polyfit(x_test))
        
        left_try_poly.set_data([],[])
        right_try_poly.set_data([],[])
        
        n+=1
    ax.set_title('Polynomials fit: ' + str(n))
    print(n)

anim = FuncAnimation(fig, animate, interval=duration * 1000 / len(tree.log) * 20, frames=tree.log)
HTML(anim.to_html5_video())

[0, 1]
1
[0, 1]
2
[0, 1]
3
[0, 1]
4
[0, 1]
5
[0, 1]
6
[0, 1]
7
[0, 1]
8
[0, 1]
9
[0, 1]
10
[0, 1]
11
[0, 1]
12
[0, 1]
13
[0, 1]
14
[0, 1]
15
[0, 1]
16
[0, 1]
17
[0, 1]
18
[0, 1]
19
[0, 1]
20
[0, 1]
21
[0, 1]
22
[0, 1]
23
[0, 1]
24
[0, 1]
25
[0, 1]
26
[0, 1]
27
[0, 1]
28
[0, 1]
29
[0, 1]
30
[0, 1]
31
[0, 1]
32
[0, 1]
33
[0, 1]
34
[0, 1]
35
[0, 1]
36
[0, 1]
37
[0, 1]
38
[0, 1]
39
[0, 1]
40
[0, 1]
41
[0, 1]
42
[0, 1]
43
[0, 1]
44
[0, 1]
45
[0, 1]
46
[0, 1]
47
[0, 1]
48
[0, 1]
49
[0, 1]
50
[0, 1]
51
[0, 1]
52
[0, 1]
53
[0, 1]
54
[0, 1]
55
[0, 1]
56
[0, 1]
57
[0, 1]
58
[0, 1]
59
[0, 1]
60
[0, 1]
61
[0, 1]
62
[0, 1]
62
[0, 1, 0.5976878471383624]
63
[0, 1, 0.5976878471383624]
64
[0, 1, 0.5976878471383624]
65
[0, 1, 0.5976878471383624]
66
[0, 1, 0.5976878471383624]
67
[0, 1, 0.5976878471383624]
68
[0, 1, 0.5976878471383624]
69
[0, 1, 0.5976878471383624]
70
[0, 1, 0.5976878471383624]
71
[0, 1, 0.5976878471383624]
71
[0, 1, 0.5976878471383624, 0.35511062813889005]
71
[0, 1, 0.5976878471383624]
7

In [None]:
tree.get_graphviz(['x', 'y'])