In [1]:
# Dataset
from sklearn.datasets import load_diabetes

# Basics
import pandas as pd
import numpy as np

# Plotting
import matplotlib.pyplot as plt

# Model
from sklearn.ensemble import RandomForestRegressor

# Helpful:
from sklearn.model_selection import train_test_split

import sys
# NOTE(review): this is an absolute, machine-specific path, so the notebook only
# runs where this directory exists. Consider deriving it from the repository
# location or installing `src_rf` as a package instead.
sys.path.append("/home/dchen/Random_Forest_Weights/")
# With the directory above on sys.path, the project-local `src_rf` package can be imported.
# The helpers used later in this notebook (calc_mean_rf, calc_weights_rf, calc_dist_rf,
# calc_quantile_rf) are not defined here, so they are presumably provided by these
# wildcard imports. NOTE(review): explicit imports would make their origin visible.
from src_rf.methods.calc_mean import *
from src_rf.methods.calc_weights import *
from src_rf.methods.calc_dist import *

### 1. Loading Data & Train/Test Split

In [2]:
# Load the scikit-learn diabetes dataset and place the feature matrix and the
# target side by side in one DataFrame, with "target" as the last column.
diabetes = load_diabetes()
df_diabetes = pd.DataFrame(
    np.column_stack((diabetes["data"], diabetes["target"])),
    columns=[*diabetes["feature_names"], "target"],
)

In [3]:
# Target vector: the last column ("target") as an independent 1-D array.
y = df_diabetes.iloc[:, -1].to_numpy(copy=True)
# Feature matrix: every column except the last one, cast to float32.
X = df_diabetes.iloc[:, :-1].to_numpy(dtype=np.float32)

In [4]:
# Hold out 20% of the samples for testing. The split is seeded so that
# Restart & Run All reproduces the same partition, and therefore the same
# downstream weights, means and quantiles.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

### 2. Random Forest

In [5]:
# Settings shared by the bootstrapped forest and the weight calculation below.
bootstrap: bool = True    # forwarded as RandomForestRegressor(bootstrap=...)
max_sample: float = 0.7   # forwarded as RandomForestRegressor(max_samples=...)

In [6]:
# Bootstrapped forest configured from the `bootstrap` / `max_sample` settings; seeded for reproducibility.
rf = RandomForestRegressor(
    random_state=42,
    bootstrap=bootstrap,
    max_samples=max_sample,
)

In [7]:
# Comparison forest with bootstrapping disabled; same seed as the bootstrapped one.
rf_no = RandomForestRegressor(
    bootstrap=False,
    random_state=42,
)

In [8]:
# Fit both forests on the same training split. The second call is the cell's
# last expression, so its estimator repr is what the cell displays.
rf.fit(X=X_train, y=y_train)
rf_no.fit(X=X_train, y=y_train)

RandomForestRegressor(bootstrap=False, random_state=42)

### 3. Calculate Random Forest Weights

In [9]:
# Compute the weights for the bootstrapped forest `rf` with the project helper
# `calc_weights_rf`, passing the same `bootstrap` / `max_sample` values the forest
# was constructed with.
# NOTE(review): presumably yields one set of training-sample weights per test point — confirm in src_rf.
rf_weights = calc_weights_rf(rf, X_train, X_test, bootstrap, max_sample)


100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 53.43it/s]


In [10]:
# Same weight computation for the non-bootstrapped forest `rf_no`. The literals
# `False` and `None` occupy the positions where the bootstrapped call passes
# `bootstrap` and `max_sample`.
rf_no_weights = calc_weights_rf(rf_no, X_train, X_test, False, None)

100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 33.30it/s]


### 4. Calculate the Mean: Weight-Based vs. Standard Prediction

In [11]:
# Reference: the bootstrapped forest's own predictions on the test set.
rf_mean_normal = rf.predict(X=X_test)
# Candidate: mean obtained from the forest weights and the training targets
# through the project helper `calc_mean_rf`.
rf_mean_weights = calc_mean_rf(rf_weights, y_train)

In [12]:
# Reference: the non-bootstrapped forest's own predictions on the test set.
rf_no_mean_normal = rf_no.predict(X=X_test)
# Candidate: mean obtained from the non-bootstrapped forest's weights and the
# training targets through the project helper `calc_mean_rf`.
rf_no_mean_weights = calc_mean_rf(rf_no_weights, y_train)

In [13]:
# Are the two the same? Count the test points where the weight-based mean matches
# rf.predict. np.isclose with an explicit absolute tolerance replaces the
# round-then-compare check, which can report a spurious mismatch when two nearly
# identical floats fall on opposite sides of a rounding boundary.
np.isclose(rf_mean_weights, rf_mean_normal, rtol=0.0, atol=1e-10).sum()

89

In [14]:
# Are the two the same? Count the test points where the weight-based mean matches
# rf_no.predict. np.isclose with an explicit absolute tolerance replaces the
# round-then-compare check, which can report a spurious mismatch when two nearly
# identical floats fall on opposite sides of a rounding boundary.
np.isclose(rf_no_mean_weights, rf_no_mean_normal, rtol=0.0, atol=1e-10).sum()

89

In [15]:
# Display the non-bootstrapped forest's own predictions for a visual comparison
# with the weight-based means shown in the following cell.
rf_no_mean_normal

array([243.94,  85.34,  99.  , 202.16,  31.  , 239.38,  79.86, 202.16,
       219.25, 134.52,  86.79,  94.13, 135.05, 302.  ,  99.  ,  85.68,
       127.91,  52.28, 199.64, 152.65, 201.32, 192.94, 208.04, 136.24,
       119.19, 145.28,  69.52,  99.52, 289.84, 112.53,  65.  ,  52.  ,
       168.12, 135.44, 135.32,  54.96,  68.74, 124.46, 232.52, 282.8 ,
        69.24, 147.5 , 138.33,  72.29, 121.18,  51.68, 286.32, 222.24,
       197.96, 135.08,  92.67, 182.42,  94.41, 125.5 , 118.  , 184.75,
        92.15,  99.76, 213.24,  97.26,  78.5 , 152.  , 277.54, 239.45,
        93.97,  97.98,  68.14,  46.22, 235.6 ,  89.84, 151.  , 195.2 ,
       174.47,  52.  , 220.25, 241.76,  76.29, 287.  ,  91.  , 200.  ,
       102.  , 142.9 , 207.29, 324.72, 149.7 ,  84.95, 276.14, 134.84,
       201.6 ])

In [16]:
# Display the weight-based means of the non-bootstrapped forest; they should match
# the rf_no.predict values displayed in the preceding cell.
rf_no_mean_weights

array([243.94,  85.34,  99.  , 202.16,  31.  , 239.38,  79.86, 202.16,
       219.25, 134.52,  86.79,  94.13, 135.05, 302.  ,  99.  ,  85.68,
       127.91,  52.28, 199.64, 152.65, 201.32, 192.94, 208.04, 136.24,
       119.19, 145.28,  69.52,  99.52, 289.84, 112.53,  65.  ,  52.  ,
       168.12, 135.44, 135.32,  54.96,  68.74, 124.46, 232.52, 282.8 ,
        69.24, 147.5 , 138.33,  72.29, 121.18,  51.68, 286.32, 222.24,
       197.96, 135.08,  92.67, 182.42,  94.41, 125.5 , 118.  , 184.75,
        92.15,  99.76, 213.24,  97.26,  78.5 , 152.  , 277.54, 239.45,
        93.97,  97.98,  68.14,  46.22, 235.6 ,  89.84, 151.  , 195.2 ,
       174.47,  52.  , 220.25, 241.76,  76.29, 287.  ,  91.  , 200.  ,
       102.  , 142.9 , 207.29, 324.72, 149.7 ,  84.95, 276.14, 134.84,
       201.6 ])

### 5. Calculate Quantiles with the Random Forest

In [17]:
# Build distributions from the bootstrapped forest's weights and the training
# targets via the project helper `calc_dist_rf`.
# NOTE(review): presumably one CDF per test sample — confirm in src_rf.
rf_cdfs = calc_dist_rf(rf_weights, y_train)

In [18]:
# Sanity check: number of entries returned by calc_dist_rf.
len(rf_cdfs)

89

In [19]:
# Query the distributions at 0.5 with the project helper `calc_quantile_rf`.
# NOTE(review): the name `rf_median` suggests 0.5 is the quantile level (the median) —
# confirm against calc_quantile_rf's signature.
rf_median = calc_quantile_rf(rf_cdfs, 0.5, y_train)

In [20]:
# Preview only the first ten entries rather than dumping the whole result.
rf_median[:10]

[263.109375,
 83.3125,
 208.0000000000002,
 177.9999999999998,
 88.00000000000082,
 190.21875,
 90.99999999999976,
 264.0000000000013,
 242.45703125,
 161.0625]