In [1]:
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn import datasets
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import plotly.graph_objects as go

In [2]:
from tree import DecisionTree

# Classification

In [3]:
# Load the digits dataset once; return_X_y=True yields (data, target) directly
# instead of building the full dataset bunch twice.
digits_data, digits_target = load_digits(return_X_y=True)
# Turn the label vector into a column of shape (n_samples, 1).
# NOTE(review): presumably DecisionTree.fit wants 2-D targets for classification -- confirm.
digits_target = digits_target[:, None]
# Hold out 30% of the samples for evaluation; the seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    digits_data, digits_target, test_size=0.3, random_state=42
)

In [4]:
# Classification tree that uses the entropy criterion, trained on the digits training split.
tree_entr = DecisionTree(
    criterion_name='entropy',
)
tree_entr.fit(X_train, y_train)

In [5]:
# Same training data as the entropy tree, but using the gini criterion.
tree_gini = DecisionTree(
    criterion_name='gini',
)
tree_gini.fit(X_train, y_train)

In [6]:
# Preview the predicted class probabilities for the first two test samples only.
entr_proba = tree_entr.predict_proba(X_test)
entr_proba[:2]

[array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])]

In [7]:
# Test-set accuracy of the entropy-criterion tree.
entr_pred = tree_entr.predict(X_test)
accuracy_score(y_true=y_test, y_pred=entr_pred)

0.8481481481481481

In [8]:
# Test-set accuracy of the gini-criterion tree.
gini_pred = tree_gini.predict(X_test)
accuracy_score(y_true=y_test, y_pred=gini_pred)

0.8611111111111112

# Regression

In [9]:
# Synthetic regression problem: 500 samples with 2 features.
# random_state is fixed so the generated data -- and therefore every metric
# computed in the cells below -- is the same on each Restart & Run All.
X, y = datasets.make_regression(n_samples=500, n_features=2, random_state=42)
# Hold out 30% of the samples for evaluation (seeded split).
# Note: this rebinds X_train/X_test/y_train/y_test from the classification section.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

In [10]:
# Regression tree that uses the variance criterion, trained on the synthetic training split.
tree_variance = DecisionTree(
    criterion_name='variance',
)
tree_variance.fit(X_train, y_train)

In [11]:
# Regression tree that uses the 'mad_median' criterion, trained on the same split.
tree_median = DecisionTree(
    criterion_name='mad_median',
)
tree_median.fit(X_train, y_train)

In [12]:
# Test-set mean squared error of the variance-criterion tree.
variance_pred = tree_variance.predict(X_test)
mean_squared_error(y_true=y_test, y_pred=variance_pred)

23.10837901081343

In [13]:
# Test-set mean squared error of the mad_median-criterion tree.
median_pred = tree_median.predict(X_test)
mean_squared_error(y_true=y_test, y_pred=median_pred)

24.479695645382755

Compare with a constant baseline that always predicts the mean of `y_test`

In [14]:
# Constant baseline: every prediction equals the mean of the test targets.
# The mean is computed once and broadcast, rather than being recomputed for
# each element as a list comprehension would do.
# NOTE: the (misspelled) name `contstant` is kept because the next cell reads it.
contstant = np.full(len(y_test), np.mean(y_test))

In [15]:
# Mean squared error of the constant baseline, for comparison with the trees above.
mean_squared_error(y_true=y_test, y_pred=contstant)

1153.2962384618274

Visualization: fit a depth-3 regression tree to a scaled sine wave and plot its predictions against the original signal

In [16]:
# 500 evenly spaced sample points on [0, 5*pi].
t = np.linspace(start=0, stop=5 * np.pi, num=500)
# Sine wave scaled to an amplitude of 100, and the matching unit-amplitude cosine.
X = 100 * np.sin(t)
y = np.cos(t)

In [17]:
# Depth-limited (max_depth=3) variance tree for the visualization.
# Note: this rebinds `tree_variance`, replacing the model fitted in the regression section.
tree_variance = DecisionTree(
    criterion_name='variance',
    max_depth=3,
)
# The scaled sine signal X is passed as both the feature and the target;
# the cosine `y` from the previous cell is not used here.
# NOTE(review): X is 1-D at this point -- confirm DecisionTree.fit accepts that shape.
tree_variance.fit(X, X)

In [18]:
# Overlay the original signal and the tree's predictions on the same time axis.
# Traces are named and the layout is titled so the figure is readable on its own
# (unnamed traces show up in the legend as "trace 0" / "trace 1").
fig = go.Figure()
fig.add_trace(go.Scatter(x=t, y=X,
                    mode='lines',
                    name='signal: 100*sin(t)'))
fig.add_trace(go.Scatter(x=t, y=tree_variance.predict(X),
                    mode='lines',
                    name='DecisionTree prediction'))
fig.update_layout(
    title='Decision tree prediction vs. original signal',
    xaxis_title='t',
    yaxis_title='value',
)
fig.show()