In [1]:
%matplotlib inline

In [2]:
import sys
from equadratures import polytree
import random
import numpy as np
import scipy.stats as st
import time
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import matplotlib.lines as mlines

In [8]:
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error

def f(x1,x2):
    noise = 0.1 * st.norm.rvs(0, 1)
    return np.exp(-(x1**2 + x2**2) + noise)

def sample():
    X, y = [], []
    for i in range(150):
        x1, x2 = random.random(), random.random()        
        X.append(np.array([x1, x2]))
        y.append(np.array(f(x1, x2)))
    return np.array(X), np.array(y)

X, y = sample()
y = np.reshape(y, (y.shape[0], 1))
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.9)

for n in range(10, 50, 10):
    tree = polytree.PolyTree(min_samples_leaf=n)
    tree.fit(X_train, y_train)
    print("min_samples_leaf: " + str(n) + " MSE:" + str(mean_absolute_error(y_test, tree.predict(X_test))))

min_samples_leaf: 10 MSE:0.6483228344799066
min_samples_leaf: 20 MSE:0.1943821063280319
min_samples_leaf: 30 MSE:0.12258101863426786
min_samples_leaf: 40 MSE:0.060855628455348726


In [None]:
def f(x1,x2):
    noise = 0.1 * st.norm.rvs(0, 1)
    return np.exp(-(x1**2 + x2**2) + noise)

def sample():
    X, y = [], []
    for i in range(150):
        x1, x2 = random.random(), random.random()        
        X.append(np.array([x1, x2]))
        y.append(np.array(f(x1, x2)))
    return np.array(X), np.array(y)

X, y = sample()
y = np.reshape(y, (y.shape[0], 1))

print(X.shape,y.shape)
#tree = polytree.PolyTree(min_samples_leaf=20, logging=True)
tree = polytree.PolyTree(search='uniform', samples=25,min_samples_leaf=20, logging=True)

start = time.time()
tree.fit(X, y)
duration = time.time() - start

In [None]:
tree.get_graphviz(['x1','x2'])

In [None]:
fig, ax = plt.subplots(figsize=(10,10))
ax.set(xlim=(0, 1), ylim=(0, 1))
ax.scatter([x[0] for x in X], [x[1] for x in X], s=100, c=y)

try_line = []
try_line.append(ax.axhline(1))

best_line = []
best_line.append(ax.axhline(1))

fixed_line = []
fixed_line.append(ax.axhline(1))

tree_pos = []

def get_boundaries(j_feature, threshold):
    j_min, j_max = [0, 0], [1, 1]
    for node in tree_pos:
        if node["direction"] == "LEFT" and node["j_feature"] == 0:
            j_max[1] = node["threshold"]
        elif node["direction"] == "LEFT" and node["j_feature"] == 1:
            j_max[0] = node["threshold"]
        elif node["direction"] == "RIGHT" and node["j_feature"] == 0:
            j_min[1] = node["threshold"]
        elif node["direction"] == "RIGHT" and node["j_feature"] == 1:
            j_min[0] = node["threshold"]
    return j_min, j_max

n = 0
def animate(log):
    global n
    
    if log["event"] == "UP":
        try:tree_pos.pop()
        except:pass
        try:best_line[-1].remove()
        except:pass
        
    if log["event"] == "DOWN":
        tree_pos.append(log["data"])

        j_min, j_max = get_boundaries(log["data"]["j_feature"], log["data"]["threshold"])
        if log["data"]["j_feature"] == 0:
            fixed_line.append(ax.axvline(log["data"]["threshold"], ymin = j_min[0], ymax = j_max[0], color='black'))
        else:
            fixed_line.append(ax.axhline(log["data"]["threshold"], xmin = j_min[1], xmax = j_max[1], color='black'))
          
        try:best_line[-1].remove()
        except:pass
        try:try_line[-1].remove()
        except:pass
        
    if log["event"] == "try_split":
        j_min, j_max = get_boundaries(log["data"]["j_feature"], log["data"]["threshold"])
        if log["data"]["j_feature"] == 0:
            try:try_line[-1].remove()
            except:pass
            try_line.append(ax.axvline(log["data"]["threshold"], ymin = j_min[0], ymax = j_max[0], color='red'))
        else:
            try:try_line[-1].remove()
            except:pass
            try_line.append(ax.axhline(log["data"]["threshold"], xmin = j_min[1], xmax = j_max[1], color='red'))
        n+=1
    if log["event"] == "best_split":
        j_min, j_max = get_boundaries(log["data"]["j_feature"], log["data"]["threshold"])
        if log["data"]["j_feature"] == 0:
            try:best_line[-1].remove()
            except:pass
            best_line.append(ax.axvline(log["data"]["threshold"], ymin = j_min[0], ymax = j_max[0], color='green'))
        else:
            try:best_line[-1].remove()
            except:pass
            best_line.append(ax.axhline(log["data"]["threshold"], xmin = j_min[1], xmax = j_max[1], color='green'))
        n+=1
    ax.set_title('Polynomials fit: ' + str(n))

anim = FuncAnimation(fig, animate, interval=duration * 10000 / len(tree.log), frames=tree.log)
HTML(anim.to_html5_video())

In [None]:
for log in tree.log:
    if log["event"] == "DOWN" or log["event"] == "UP":
        print(log)

In [None]:
tree.log
