In [410]:
import numpy as np
import importlib
import RegressionTree as rt
import pandas as pd
from ISLP import load_data
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier

In [374]:
Hitters = load_data('Hitters')
Hitters = Hitters.drop(['League', 'Division', 'NewLeague'], axis=1)
Hitters = Hitters.dropna()
Hitters.columns

Index(['AtBat', 'Hits', 'HmRun', 'Runs', 'RBI', 'Walks', 'Years', 'CAtBat',
       'CHits', 'CHmRun', 'CRuns', 'CRBI', 'CWalks', 'PutOuts', 'Assists',
       'Errors', 'Salary'],
      dtype='object')

In [147]:
Hitters = Hitters[['Salary', 'CHits', 'Hits', 'CAtBat']]

In [377]:
X = Hitters.drop(labels=['Salary'], axis=1)
Y = Hitters['Salary']

In [378]:
train_count = int(X.shape[0] * 0.8)
X_train, X_test = X.head(train_count), X.tail(X.shape[0] - train_count)
Y_train, Y_test = Y.head(train_count), Y.tail(X.shape[0] - train_count)

# implemented

In [401]:
importlib.reload(rt)
regression_tree = rt.RegressionTree(40)
regression_tree.fit(X_train, Y_train)

In [402]:
def print_tree(node: rt.Node, lev: int = 0) -> str:

    res = ''

    if node is None:
        return res

    sep = '___ '

    if node.compare_function is not None:
        res = res + f'{sep * lev}{node.compare_function.label}'
        res = res + '\n' if res[-1] != '\n' else res

        res = res + print_tree(node.right, lev+1)
        res = res + '\n' if res[-1] != '\n' else res

        res = res + f'{sep*lev}else'
        res = res + '\n' if res[-1] != '\n' else res

        res = res + print_tree(node.left, lev+1)
        res = res + '\n' if res[-1] != '\n' else res

    else:
        res = res + f'{sep * lev}{round(node.mean, 2)}'
        res = res + '\n' if res[-1] != '\n' else res
    
    return res
    

In [403]:
print(print_tree(regression_tree.root))

CHits >= 452
___ Walks >= 52
___ ___ RBI >= 81
___ ___ ___ 1311.91
___ ___ else
___ ___ ___ 827.13
___ else
___ ___ AtBat >= 424
___ ___ ___ PutOuts >= 88
___ ___ ___ ___ 722.9
___ ___ ___ else
___ ___ ___ ___ 1053.33
___ ___ else
___ ___ ___ CHmRun >= 194
___ ___ ___ ___ 810.0
___ ___ ___ else
___ ___ ___ ___ 476.99
else
___ RBI >= 8
___ ___ CRBI >= 117
___ ___ ___ CWalks >= 122
___ ___ ___ ___ 420.96
___ ___ ___ else
___ ___ ___ ___ 271.34
___ ___ else
___ ___ ___ CRuns >= 102
___ ___ ___ ___ Assists >= 317
___ ___ ___ ___ ___ 160.0
___ ___ ___ ___ else
___ ___ ___ ___ ___ 226.5
___ ___ ___ else
___ ___ ___ ___ CRBI >= 72
___ ___ ___ ___ ___ 154.4
___ ___ ___ ___ else
___ ___ ___ ___ ___ 97.11
___ else
___ ___ 2127.33



In [405]:
Y_test_pred = regression_tree.pred(X_test.to_numpy())

In [407]:
np.mean(np.abs(Y_test - Y_test_pred))

164.30809765897973

# library

In [414]:
library_tree = DecisionTreeClassifier()
library_tree.fit(X_train.to_numpy(), Y_train.to_numpy().astype(np.uint32))

In [415]:
Y_library = library_tree.predict(X_test.to_numpy())

In [416]:
np.mean(np.abs(Y_test - Y_library))

177.97483018867928

# Simple test

In [322]:
X_simple = pd.DataFrame(np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape((9, 1)))
Y_simple = pd.Series([10, 10, 10, 10, 20, 20, 8, 8, 8])

In [323]:
Y_simple

0    10
1    10
2    10
3    10
4    20
5    20
6     8
7     8
8     8
dtype: int64

In [344]:
importlib.reload(rt)
simple_tree = rt.RegressionTree(1)
simple_tree.fit(X_simple, Y_simple)

In [345]:
print(print_tree(simple_tree.root))

0 >= 7
___ 8.0
else
___ 0 >= 5
___ ___ 20.0
___ else
___ ___ 10.0



In [346]:
Y_simple_pred = simple_tree.pred(X_simple.to_numpy())
print(Y_simple_pred)

[10. 10. 10. 10. 20. 20.  8.  8.  8.]
