In [1]:
# Dataset
from sklearn.datasets import load_diabetes

# Basics
import pandas as pd
import numpy as np

# Plotting
import matplotlib.pyplot as plt

# Model
from sklearn.ensemble import RandomForestRegressor

# Helpful:
from sklearn.model_selection import train_test_split

import sys
sys.path.append("/home/dchen/Random_Forest_Weights/")
# Now, you can import your modules
from src_rf.methods.calc_mean import *
from src_rf.methods.calc_weights import *
from src_rf.methods.calc_dist import *

### 1. Loading Data & Train_test_split

In [2]:
diabetes = load_diabetes()
df_diabetes = pd.DataFrame(
    data=np.c_[diabetes["data"], diabetes["target"]],
    columns=diabetes["feature_names"] + ["target"],
)

In [3]:
X = df_diabetes.iloc[:, :-1].to_numpy(dtype="float32")
y = df_diabetes.iloc[:, -1:].to_numpy().flatten()

In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

### 2. Random Forest

In [5]:
bootstrap = True
max_sample = 0.7

In [6]:
rf = RandomForestRegressor(bootstrap=bootstrap, max_samples=max_sample, random_state=42)

In [7]:
rf_no = RandomForestRegressor(random_state=42, bootstrap= False)

In [8]:
rf.fit(X_train, y_train)
rf_no.fit(X_train, y_train)

RandomForestRegressor(bootstrap=False, random_state=42)

### 3. Calculate Weights Random Forest

In [9]:
rf_weights = calc_weights_rf(rf, X_train, X_test, bootstrap, max_sample)


100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 58.56it/s]


In [10]:
rf_no_weights = calc_weights_rf(rf_no, X_train, X_test, False, None)

100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.70it/s]


### 4. Calc Mean with weights and normal:

In [11]:
rf_mean_weights = calc_mean_rf(rf_weights, y_train)
rf_mean_normal = rf.predict(X_test)

In [12]:
rf_no_mean_weights = calc_mean_rf(rf_no_weights, y_train)
rf_no_mean_normal = rf_no.predict(X_test)

In [13]:
# Are the two the same?
sum(np.round(rf_mean_weights, 10) == np.round(rf_mean_normal, 10))

89

In [14]:
# Are the two the same?
sum(np.round(rf_no_mean_weights, 10) == np.round(rf_no_mean_normal, 10))

89

In [15]:
rf_no_mean_normal

array([195.8 , 163.82, 256.03, 211.6 , 134.  , 288.54, 131.35,  58.3 ,
       275.14, 201.56,  99.78, 178.87,  50.02,  91.85,  62.42, 143.76,
       156.57, 133.1 , 102.51, 136.26, 146.93, 102.2 ,  66.88, 184.08,
        66.88,  99.91, 164.6 ,  56.09, 277.15,  62.82,  66.88, 176.79,
       135.  , 125.92,  91.36, 101.98, 147.34, 170.2 , 140.26, 167.72,
        80.36, 236.24, 135.77,  80.78,  57.2 , 201.73,  96.08, 158.32,
        79.22, 185.93,  57.87, 294.89,  93.18,  52.86, 251.3 ,  85.93,
       252.2 , 142.22,  89.43, 141.51, 267.66, 112.99,  60.24, 151.24,
        93.32,  93.6 , 138.6 , 140.24, 270.6 ,  95.47, 251.19, 141.96,
       161.41, 169.04,  83.72, 266.26, 112.83, 148.75, 147.95, 147.2 ,
        93.19, 154.29, 273.53,  72.46, 158.16, 179.68, 216.16,  94.4 ,
       161.63])

In [16]:
rf_no_mean_weights

array([195.8 , 163.82, 256.03, 211.6 , 134.  , 288.54, 131.35,  58.3 ,
       275.14, 201.56,  99.78, 178.87,  50.02,  91.85,  62.42, 143.76,
       156.57, 133.1 , 102.51, 136.26, 146.93, 102.2 ,  66.88, 184.08,
        66.88,  99.91, 164.6 ,  56.09, 277.15,  62.82,  66.88, 176.79,
       135.  , 125.92,  91.36, 101.98, 147.34, 170.2 , 140.26, 167.72,
        80.36, 236.24, 135.77,  80.78,  57.2 , 201.73,  96.08, 158.32,
        79.22, 185.93,  57.87, 294.89,  93.18,  52.86, 251.3 ,  85.93,
       252.2 , 142.22,  89.43, 141.51, 267.66, 112.99,  60.24, 151.24,
        93.32,  93.6 , 138.6 , 140.24, 270.6 ,  95.47, 251.19, 141.96,
       161.41, 169.04,  83.72, 266.26, 112.83, 148.75, 147.95, 147.2 ,
        93.19, 154.29, 273.53,  72.46, 158.16, 179.68, 216.16,  94.4 ,
       161.63])

### 5. Calc Quantile Random Forest

In [17]:
rf_cdfs = calc_dist_rf(rf_weights, y_train)

In [18]:
len(rf_cdfs)

89

In [19]:
rf_median = calc_quantile_rf(rf_cdfs, 0.5, y_train)

In [23]:
np.array(rf_median).shape

(89,)

In [24]:
X_test.shape

(89, 10)