## Imports and configurations

In [22]:
import datetime
from math import sqrt

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.preprocessing import StandardScaler

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeRegressor, export_text
from sklearn.tree import plot_tree

from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

from plotly import express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


In [2]:
# plot figure size configuration
# plt.rcParams['figure.figsize'] = [25, 5]
sns.set_palette('muted')

In [3]:
data_path = "https://storage.googleapis.com/biosense-ml-data/insurance_clean.csv"
data_path_no_outliers = "https://storage.googleapis.com/biosense-ml-data/insurance_clean_no_outliers.csv"

## Read the data

In [4]:
column_definitions = {
    'age': np.int8,
    'sex': 'category',
    'bmi': np.float32,
    'children': np.int8,
    'smoker': bool,
    'region': 'category',
    'charges': np.float32
}

In [5]:
orig_df = pd.read_csv(data_path, dtype=column_definitions)


### Prepare features and labels - X, y

In [6]:
df = orig_df.copy()

## Convert categories from string to number

In [7]:
df = pd.get_dummies(df, columns=['sex', 'region'], drop_first=True)

In [8]:
df

Unnamed: 0,age,bmi,children,smoker,charges,sex_male,region_northwest,region_southeast,region_southwest
0,19,27.900000,0,True,16884.923828,False,False,False,True
1,18,33.770000,1,False,1725.552246,True,False,True,False
2,28,33.000000,3,False,4449.461914,True,False,True,False
3,33,22.705000,0,False,21984.470703,True,True,False,False
4,32,28.879999,0,False,3866.855225,True,True,False,False
...,...,...,...,...,...,...,...,...,...
1334,50,30.969999,3,False,10600.547852,True,True,False,False
1335,18,31.920000,0,False,2205.980713,False,False,False,False
1336,18,36.849998,0,False,1629.833496,False,False,True,False
1337,21,25.799999,0,False,2007.944946,False,False,False,True


In [9]:
X = df.drop('charges', axis=1)
y = df['charges']

### Train / test split

In [10]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


## Decision Tree - Create and train

In [38]:
tree_model = DecisionTreeRegressor()


In [39]:
tree_model.fit(X_train, y_train)

## Interpret the tree

In [None]:
plt.rcParams["figure.figsize"] = (120,70)
plot_tree(tree_model, feature_names=X_train.columns, max_depth=2, filled=True, rounded=True)

In [14]:
tree_model.get_depth()

20

In [15]:
tree_model.feature_importances_

array([0.13217762, 0.22827641, 0.01729372, 0.60029404, 0.00485714,
       0.00916737, 0.00237621, 0.00555748])

In [16]:
tree_model.tree_.n_leaves

np.int64(1070)

In [17]:
tree_model.decision_path(X_test[:1]).toarray()

array([[1, 1, 0, ..., 0, 0, 0]])

### Lets evaluate the tree

In [18]:
train_mae = mean_absolute_error(y_train, tree_model.predict(X_train))
test_mae = mean_absolute_error(y_test, tree_model.predict(X_test))

print(train_mae, test_mae)

11.672471446260795 3000.716619235366


## Tree pruning

### Pre-Pruning techniques

Tree pruning is a technique used in decision tree models to reduce the size of the tree and avoid overfitting. Decision trees, by nature, are prone to overfitting because they can grow deep and complex, capturing noise and small patterns in the training data that do not generalize well to unseen data. Pruning helps mitigate this by simplifying the tree.

**Stopping Criteria:**
* Maximum depth of the tree (max_depth): Limits the depth of the tree.
* Minimum samples per leaf (min_samples_leaf): Ensures a minimum number of samples in each leaf.
* Minimum samples to split (min_samples_split): Requires a minimum number of samples to perform a split.
* Maximum number of nodes (max_leaf_nodes): Limits the number of leaf nodes in the tree.

In [43]:
tree_model = DecisionTreeRegressor(max_depth=5)
tree_model.fit(X_train, y_train)

train_mae = mean_absolute_error(y_train, tree_model.predict(X_train))
test_mae = mean_absolute_error(y_test, tree_model.predict(X_test))

print(train_mae, test_mae)

2387.3963606294815 2522.8764560200166


### Post-Pruning

In [44]:
path = tree_model.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

In [45]:
# Create Plotly figure
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=ccp_alphas[:-1],
    y=impurities[:-1],
    mode='lines+markers',
    line_shape='hv',  # steps-post equivalent
    marker=dict(size=8),
    name='Impurity'
))

fig.update_layout(
    title="Total Impurity vs Effective Alpha for Training Set",
    xaxis_title="Effective Alpha",
    yaxis_title="Total Impurity of Leaves",
    template="plotly_white"
)

fig.show()

### Can we improve the overfitting even more?