In [11]:
%load_ext lab_black

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


In [12]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn import tree
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

from DecisionTreeFunctions import *

In [13]:
diabetes = load_diabetes()
df = pd.DataFrame(
    data=np.c_[diabetes["data"], diabetes["target"]],
    columns=diabetes["feature_names"] + ["target"],
)

In [14]:
X = df.iloc[:, :-1].to_numpy(dtype="float32")
y = df.iloc[:, -1:].to_numpy().flatten()

In [15]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

### Model Training: Random Forest

In [57]:
rf = RandomForestRegressor(n_estimators=3)

In [58]:
rf.fit(X_train, y_train)

In [59]:
rf_estimators = rf.estimators_

In [188]:
leaf_nodes = []
y_values_leaves = []
X_values_leaves = []
X_ids_leaves = []
weights_leaves = (
    []
)  # 3-Dimensional with 1. Number of tree, 2. Number of leafs per tree, 3. Length of y_train
for tree in rf_estimators:
    a, b, c, d, e = calc_obs_distribution(tree, X_train, y_train)
    leaf_nodes.append(a)
    y_values_leaves.append(b)
    X_values_leaves.append(c)
    X_ids_leaves.append(d)
    weights_leaves.append(e)

In [195]:
def calc_weights_rf(rf, X_test, y_train, leaf_nodes_trees, weights_leaves_trees):
    '''
    Method to calculate the mean prediction and weights of a random forest
    
    Input: 
        param rf: Fully fitted random Forest
        param X_test: OOS test data
        param y_train: Data used to train the RF
        param leaf_nodes_trees: 3-Dimensional: 1. number_trees, 2. number_leaf_nodes, 
                                3. array with leaf node indexes
        param weights_leaves_trees: 3-Dimensional: 1. number_trees, 2. number_leaf_nodes
                                    3. array with weights of individual leaf_nodes
    
    Output: 
        weights_all: list of length X_test with weights used to calculate mean prediction
        mean_preds: List of mean predictions
    '''
    # Calculate index of the leaf that each sample is predicted as in all trees
    X_test_id_leaves = []  # dim: num_trees x len_X_test
    for tree in rf.estimators_:  # iterate number of tree times
        X_test_id_leaves.append(tree.apply(X_test))

    weights_all = []
    mean_preds = []
    for i in range(len(X_test)):  # iterate number of X_test times
        weight_k = np.zeros(y_train.shape)
        for j in range(len(X_test_id_leaves)):  # iterate number of trees times
            X_id = X_test_id_leaves[j][i]
            index = np.where(leaf_nodes_trees[j] == X_id)[0][
                0
            ]  # Calculate index of test
            weight_k = weight_k + weights_leaves_trees[j][index]
        weight = weight_k / len(X_test_id_leaves)
        weights_all.append(weight)
        mean_preds.append(np.dot(weight, y_train))

    return weights_all, mean_preds

In [196]:
weights, mean_preds = calc_weights_rf(rf, X_test, y_train, leaf_nodes, weights_leaves)

In [198]:
mean_preds[0:10]

[248.3333333333333,
 100.66666666666666,
 89.0,
 121.22222222222221,
 232.40476190476187,
 194.57142857142856,
 240.99999999999997,
 126.88888888888889,
 87.66666666666666,
 262.66666666666663]

In [107]:
rf.predict(X_test)

array([223.        , 105.66666667,  80.66666667, 110.33333333,
       239.        , 156.66666667, 216.66666667, 130.        ,
        87.66666667, 257.33333333, 175.66666667, 229.        ,
       110.        , 163.33333333,  64.33333333, 192.        ,
        81.66666667,  93.33333333, 122.        ,  77.33333333,
        81.66666667, 103.33333333, 164.        , 118.33333333,
        90.        , 227.33333333, 184.66666667, 107.66666667,
       215.33333333, 144.33333333, 236.        , 113.        ,
       113.        , 117.33333333, 146.        , 151.66666667,
       246.33333333, 117.33333333, 224.66666667, 197.66666667,
       134.66666667, 217.66666667, 193.33333333, 264.        ,
       211.        , 112.        ,  94.66666667,  88.        ,
       127.33333333,  67.66666667, 113.33333333, 174.33333333,
       193.66666667,  88.33333333, 222.66666667, 252.66666667,
       131.66666667,  94.        , 253.33333333, 188.33333333,
        69.66666667, 155.66666667, 116.66666667, 234.33