# SLU13 - Tree-based models: Examples

In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

from sklearn.datasets import load_boston
from sklearn.ensemble import (
    RandomForestClassifier,
    RandomForestRegressor,
    GradientBoostingClassifier,
    GradientBoostingRegressor,
)
from sklearn.metrics import mean_squared_error
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
)

from utils.utils import *

# Decision trees

## Classification

In [9]:
# Build the toy dataset and separate the features from the target.
# NOTE(review): make_data, separate_target_variable and
# process_categorical_features come from `from utils.utils import *`; their
# exact behaviour is not visible in this notebook - confirm in utils/utils.py.
data = make_data()
X_raw, y = separate_target_variable(data)

# Keep the unprocessed features under their own name instead of overwriting them.
X = process_categorical_features(X_raw)

# Seed the tree: scikit-learn permutes the candidate features at every split,
# so ties between equally good splits are otherwise broken differently per run.
dtc = DecisionTreeClassifier(random_state=0)
dtc.fit(X, y)

# Store the in-sample predictions so they can be compared with `y`, and check
# that there is exactly one prediction per training row.
y_pred = dtc.predict(X)
assert len(y_pred) == len(y)

# True labels, rendered as the cell output.
y

0     0
1     0
2     1
3     1
4     1
5     0
6     1
7     0
8     1
9     1
10    1
11    1
12    1
13    0
Name: Class, dtype: int64

## Regression

In [10]:
def prepare_boston():
    """Load the Boston housing data as pandas objects.

    Returns:
        tuple: ``(features, target)`` where ``features`` is a DataFrame built
        from the dataset's ``data`` with ``feature_names`` as column labels,
        and ``target`` is a Series of the dataset's ``target`` named 'price'.

    NOTE(review): ``load_boston`` was deprecated in scikit-learn 1.0 and
    removed in 1.2 - confirm the scikit-learn version this notebook is pinned to.
    """
    dataset = load_boston()

    features = pd.DataFrame(dataset.data, columns=dataset.feature_names)
    target = pd.Series(dataset.target, name='price')

    return features, target


# Regression data: the trailing underscore keeps these names apart from the
# classification X / y defined in the earlier cell.
X_, y_ = prepare_boston()
# Bare last expression so the notebook renders the target Series as the output.
y_

0      24.0
1      21.6
2      34.7
3      33.4
4      36.2
       ... 
501    22.4
502    20.6
503    23.9
504    22.0
505    11.9
Name: price, Length: 506, dtype: float64

In [4]:
# Fit a regression tree on the Boston data and show its in-sample predictions.
# random_state pins how ties between equally good splits are broken, so a
# fresh kernel rebuilds the same tree.
# NOTE(review): the variable keeps the name `dtc` although it holds a regressor.
dtc = DecisionTreeRegressor(random_state=0)
dtc.fit(X_, y_)
dtc.predict(X_)

array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
       18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
       15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
       13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
       21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
       35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
       19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
       20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
       23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
       33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
       21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
       20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
       23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
       15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21

# Random Forests

## Classification

In [5]:
# Fit a random forest on the classification data and show its in-sample
# predictions. The forest is stochastic (bootstrap samples, random feature
# subsets), so random_state is fixed to make reruns give the same model.
# NOTE(review): the variable keeps the name `dtc` although it holds a forest.
dtc = RandomForestClassifier(random_state=0)
dtc.fit(X, y)
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [6]:
# Fit a random forest on the Boston data and show its in-sample predictions.
# The forest is stochastic (bootstrap samples, random feature subsets), so
# random_state is fixed to make reruns give the same numbers.
# NOTE(review): the variable keeps the name `dtc` although it holds a forest.
dtc = RandomForestRegressor(random_state=0)
dtc.fit(X_, y_)
dtc.predict(X_)

array([25.872, 21.893, 34.602, 33.903, 35.432, 27.242, 21.566, 23.388,
       17.063, 19.819, 17.82 , 19.214, 21.475, 20.003, 18.844, 19.912,
       22.114, 17.659, 19.708, 18.977, 14.129, 18.883, 15.4  , 14.815,
       15.99 , 14.454, 16.961, 14.995, 19.153, 21.546, 13.828, 16.321,
       14.516, 13.618, 13.746, 19.624, 20.078, 20.922, 23.171, 29.954,
       34.517, 28.274, 24.965, 24.629, 21.353, 19.361, 19.917, 17.363,
       15.848, 19.418, 19.914, 21.197, 25.022, 22.413, 18.808, 34.864,
       24.21 , 31.301, 23.092, 19.923, 18.829, 17.165, 22.767, 25.515,
       32.82 , 23.766, 19.66 , 21.565, 17.952, 21.077, 24.029, 21.687,
       22.754, 23.399, 24.184, 22.205, 20.165, 20.933, 21.118, 20.329,
       27.835, 24.778, 24.34 , 23.057, 23.261, 26.772, 21.997, 22.185,
       26.463, 29.795, 22.486, 22.051, 22.964, 24.878, 20.988, 27.955,
       22.163, 40.328, 43.812, 32.611, 26.521, 26.449, 18.494, 19.544,
       19.976, 19.43 , 18.981, 20.068, 19.868, 19.145, 21.471, 23.857,
      

# Gradient Boosting

## Classification

In [7]:
# Fit gradient boosting on the classification data and show its in-sample
# predictions. random_state fixes the estimator's internal random number
# generator so reruns give the same model.
# NOTE(review): the variable keeps the name `dtc` although it holds a
# boosting model.
dtc = GradientBoostingClassifier(random_state=0)
dtc.fit(X, y)
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [8]:
# Fit gradient boosting on the Boston data and show its in-sample predictions.
# random_state fixes the estimator's internal random number generator so
# reruns give the same numbers.
# NOTE(review): the variable keeps the name `dtc` although it holds a
# boosting model.
dtc = GradientBoostingRegressor(random_state=0)
dtc.fit(X_, y_)
dtc.predict(X_)

array([25.90772604, 21.96320179, 33.92712155, 34.14528061, 35.41267912,
       26.7925396 , 21.48031031, 20.87839556, 16.95411564, 18.45898255,
       18.05928146, 20.04582877, 19.88575493, 20.39575276, 18.96852027,
       20.21179657, 21.76179638, 16.96912497, 19.2506871 , 19.01636451,
       14.00404763, 18.24207805, 16.1851528 , 14.57787808, 15.77917671,
       14.86551421, 16.81521596, 14.71617291, 18.54626475, 20.6745727 ,
       13.39158783, 18.14465195, 13.44652826, 15.1383587 , 14.41160215,
       21.05815993, 21.42388009, 22.12656447, 23.40601618, 29.19035841,
       34.2785307 , 28.48271366, 24.45460595, 24.43134015, 22.55323935,
       20.688097  , 20.8094598 , 18.07873592, 16.27250646, 18.89252819,
       20.53847579, 21.69114043, 25.03984638, 21.75334966, 17.27619882,
       34.80045814, 24.13821604, 31.24080817, 22.9320532 , 20.75925478,
       18.78591467, 17.35922753, 22.55596532, 24.23089689, 32.40378007,
       24.75067937, 19.7671051 , 20.89998955, 19.02139353, 21.23