In [1]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.datasets import make_regression
from regression_tree import MyTreeReg

In [2]:
from sklearn.datasets import load_diabetes

data = load_diabetes(as_frame=True)
X, y = data['data'], data['target']
X.head()

Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6
0,0.038076,0.05068,0.061696,0.021872,-0.044223,-0.034821,-0.043401,-0.002592,0.019907,-0.017646
1,-0.001882,-0.044642,-0.051474,-0.026328,-0.008449,-0.019163,0.074412,-0.039493,-0.068332,-0.092204
2,0.085299,0.05068,0.044451,-0.00567,-0.045599,-0.034194,-0.032356,-0.002592,0.002861,-0.02593
3,-0.089063,-0.044642,-0.011595,-0.036656,0.012191,0.024991,-0.036038,0.034309,0.022688,-0.009362
4,0.005383,-0.044642,-0.036385,0.021872,0.003935,0.015596,0.008142,-0.002592,-0.031988,-0.046641


In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=42, random_state=42)
X_train.head()

Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6
284,0.041708,0.05068,-0.022373,0.028758,-0.066239,-0.045155,-0.061809,-0.002592,0.002861,-0.054925
402,0.110727,0.05068,-0.033151,-0.022885,-0.004321,0.020293,-0.061809,0.07121,0.015568,0.044485
199,0.041708,-0.044642,-0.045007,0.034508,0.043837,-0.015719,0.037595,-0.014401,0.089897,0.007207
82,-0.016412,-0.044642,-0.035307,-0.026328,0.03283,0.017162,0.100183,-0.039493,-0.070209,-0.079778
77,-0.096328,-0.044642,-0.036385,-0.074527,-0.03872,-0.027618,0.015505,-0.039493,-0.074093,-0.001078


In [4]:
X_train.shape, X_test.shape

((400, 10), (42, 10))

In [5]:
tree = MyTreeReg()
tree.fit(X_train, y_train)

In [6]:
print(tree)
tree.print_tree()

MyTreeReg class: max_depth=5, min_samples_split=2, max_leaves=20
s5 -0.0037611760063045703
 bmi 0.0061888847138220964
  s3 0.02102781591949656
   s1 0.06310082451524143
    sex 0.003019241116414738
     leaf_left 119.0
     leaf_right 93.92682926829268
    age 0.05623859868852012
     leaf_left 253.0
     leaf_right 230.0
   sex 0.003019241116414738
    bp 0.025315236489885963
     leaf_left 87.38888888888889
     leaf_right 142.5
    age 0.005383060374248236
     leaf_left 55.8
     leaf_right 78.33333333333333
  age -0.07998159322470814
   age -0.0890629393522567
    leaf_left 302.0
    leaf_right 246.0
   bp 0.009822407098564287
    bmi 0.059001676333981276
     leaf_left 151.88235294117646
     leaf_right 96.16666666666667
    bmi 0.028284032228378497
     leaf_left 132.2
     leaf_right 215.4
 bmi 0.06870198499890848
  bmi -0.021834229207078688
   s2 -0.02855779360190825
    bp 0.03908664039328301
     leaf_left 159.625
     leaf_right 252.0
    s2 -0.002253322811587326
     leaf_

In [7]:
preds = tree.predict(X_test)

In [8]:
mean_absolute_error(y_test, preds)

42.930222811798615

In [9]:
mean_squared_error(y_test, preds)**0.5

54.77098101370871

In [10]:
tree2 = MyTreeReg(3, 30, 10)
tree2.fit(X_train, y_train)

In [11]:
tree2.print_tree()

s5 -0.0037611760063045703
 bmi 0.0061888847138220964
  s3 0.02102781591949656
   leaf_left 109.33333333333333
   leaf_right 83.56
  age -0.07998159322470814
   leaf_left 274.0
   leaf_right 157.21052631578948
 bmi 0.06870198499890848
  bmi -0.021834229207078688
   leaf_left 138.8918918918919
   leaf_right 190.55
  leaf_right 270.4074074074074


In [12]:
preds_2 = tree2.predict(X_test)

In [13]:
mean_absolute_error(y_test, preds_2)

41.810197277477975

In [14]:
tree3 = MyTreeReg(10, 2, 40)
tree3.fit(X_train, y_train)
preds_3 = tree3.predict(X_test)

In [15]:
mean_squared_error(y_test, preds_3)**0.5

65.44923100551165

In [16]:
X, y = make_regression(n_samples=150, n_features=14, n_informative=10, noise=15, random_state=42)
X = pd.DataFrame(X).round(2)
y = pd.Series(y)
X.columns = [f'col_{col}' for col in X.columns]
X_test = X.sample(20)

In [17]:
y_test = y[X_test.index]

In [18]:
tree = MyTreeReg(5, 15, 20, bins = 7)
tree.fit(X, y)
tree.print_tree()

col_11 -0.9114285714285715
 col_10 0.0942857142857143
  leaf_left -213.21706206832062
  leaf_right -114.23666850402158
 col_10 -1.262857142857143
  leaf_left -160.2893959978475
  col_11 0.4942857142857142
   col_12 0.6028571428571428
    col_4 0.5271428571428571
     leaf_left -42.11027810423582
     leaf_right 66.31174937849201
    col_8 0.5885714285714285
     leaf_left 79.71214800066767
     leaf_right 227.0213965449283
   col_4 0.5271428571428571
    col_1 0.0942857142857143
     leaf_left 44.946034007632214
     leaf_right 148.69139038116185
    leaf_right 207.32104127877886


In [19]:
preds = tree.predict(X_test)

In [20]:
mean_absolute_error(y_test, preds)

84.61864755960113

In [21]:
tree.fi

{'col_0': 0,
 'col_1': 442.0045976631684,
 'col_2': 0,
 'col_3': 0,
 'col_4': 1652.2415884979,
 'col_5': 0,
 'col_6': 0,
 'col_7': 0,
 'col_8': 566.0873401741238,
 'col_9': 0,
 'col_10': 3955.779477492326,
 'col_11': 7184.79557754227,
 'col_12': 1568.536772487773,
 'col_13': 0}