In [2]:
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from nldg.new.utils import gen_data, max_mse, min_xplvar
from nldg.new.rf import RF4DG, MaggingRF_PB
from adaXT.random_forest import RandomForest

In [3]:
# Hyperparameters shared by every forest fitted in this notebook.
min_samples_leaf, n_estimators, random_state = 10, 50, 42

# One training frame and two test frames.
dtr, dts1, dts2 = gen_data(n_train=1000, n_test=500)

# Features are all columns except the environment label 'E' and the response 'Y'.
Xtr, Xts1, Xts2 = [
    np.array(frame.drop(columns=['E', 'Y'])) for frame in (dtr, dts1, dts2)
]
Ytr, Yts1, Yts2 = [np.array(frame['Y']) for frame in (dtr, dts1, dts2)]
# Environment labels are only needed for the training data.
Etr = np.array(dtr['E'])

# Test set used by every evaluation cell below.
# Switch to Xts1 and Yts1 to be in the convex hull.
X_test, Y_test = Xts2, Yts2

### Default: standard random forest on the pooled training data

In [4]:
# Baseline: a standard random forest trained on all environments pooled together.
# scikit-learn's fit() returns the estimator itself, so construction and fitting chain.
rf = RandomForestRegressor(
    n_estimators=n_estimators,
    min_samples_leaf=min_samples_leaf,
    random_state=random_state,
).fit(Xtr, Ytr)
# In-sample predictions, reused by the per-environment diagnostics.
fitted = rf.predict(Xtr)

In [5]:
# Feature importances of the pooled forest, one entry per input column.
rf.feature_importances_

array([0.44989299, 0.55010701])

In [6]:
# Training MSE of the pooled forest within each environment (printed by verbose=True);
# the cell value is the largest of them, i.e. the worst-case environment.
max_mse(Ytr, fitted, Etr, verbose=True)

Environment 0 MSE: 6.0536531345197755
Environment 1 MSE: 17.645077340239204
Environment 2 MSE: 38.36054161024032


np.float64(38.36054161024032)

In [7]:
# Explained variance of the pooled forest within each environment (printed by verbose=True);
# the cell value is the smallest of them.
min_xplvar(Ytr, fitted, Etr, verbose=True)

Environment 0 explained variance: 9.827856042660768
Environment 1 explained variance: 17.959868197923008
Environment 2 explained variance: -4.412149939505497


np.float64(-4.412149939505497)

In [9]:
# Out-of-sample error of the pooled forest on the selected test set.
preds = rf.predict(X_test)
mean_squared_error(y_true=Y_test, y_pred=preds)

8.021324716941832

### Maximin

In [19]:
# Forest grown with the 'maximin' criterion; fitting also takes the
# environment labels Etr in addition to the features and response.
rf_maximin = RF4DG(
    criterion='maximin',
    n_estimators=n_estimators,
    min_samples_leaf=min_samples_leaf,
    parallel=True,
    random_state=random_state,
)
rf_maximin.fit(Xtr, Ytr, Etr)
# In-sample predictions, reused by the per-environment diagnostics.
fitted_maximin = rf_maximin.predict(Xtr)

100%|██████████| 50/50 [00:00<00:00, 53.21it/s]


In [20]:
# Per-environment training MSE of the maximin forest; the cell value is the largest one.
max_mse(Ytr, fitted_maximin, Etr, verbose=True)

Environment 0 MSE: 13.107976412951944
Environment 1 MSE: 31.45390718213243
Environment 2 MSE: 31.969858315979202


np.float64(31.969858315979202)

In [21]:
# Per-environment explained variance of the maximin forest; the cell value is the smallest one.
min_xplvar(Ytr, fitted_maximin, Etr, verbose=True)

Environment 0 explained variance: 2.7735327642286
Environment 1 explained variance: 4.151038356029783
Environment 2 explained variance: 1.9785333547556192


np.float64(1.9785333547556192)

In [22]:
# Out-of-sample error of the maximin forest on the selected test set.
preds_maximin = rf_maximin.predict(X_test)
mean_squared_error(y_true=Y_test, y_pred=preds_maximin)

5.748851601657065

### Maximin - adaXT

In [23]:
# adaXT forest of type "MaximinRegression"; as with the other
# environment-aware models, fit() is given the environment labels Etr.
rf_adaxt = RandomForest(
    "MaximinRegression",
    n_estimators=n_estimators,
    min_samples_leaf=min_samples_leaf,
    seed=random_state,
)
rf_adaxt.fit(Xtr, Ytr, Etr)
# In-sample predictions, reused by the per-environment diagnostics.
fitted_adaxt = rf_adaxt.predict(Xtr)

In [24]:
# Per-environment training MSE of the adaXT maximin forest; the cell value is the largest one.
max_mse(Ytr, fitted_adaxt, Etr, verbose=True)

Environment 0 MSE: 14.191454635911825
Environment 1 MSE: 33.294863803020675
Environment 2 MSE: 32.52241194004936


np.float64(33.294863803020675)

In [25]:
# Per-environment explained variance of the adaXT maximin forest; the cell value is the smallest one.
min_xplvar(Ytr, fitted_adaxt, Etr, verbose=True)

Environment 0 explained variance: 1.6900545412687187
Environment 1 explained variance: 2.3100817351415373
Environment 2 explained variance: 1.4259797306854622


np.float64(1.4259797306854622)

In [26]:
# Out-of-sample error of the adaXT maximin forest on the selected test set.
preds_adaxt = rf_adaxt.predict(X_test)
mean_squared_error(y_true=Y_test, y_pred=preds_adaxt)

5.865319125629842

### Maximin - Global - adaXT

In [27]:
# Same adaXT setup as the "MaximinRegression" forest, but with the
# "MaximinRegression_Global" forest type.
rf_adaxt_global = RandomForest(
    "MaximinRegression_Global",
    n_estimators=n_estimators,
    min_samples_leaf=min_samples_leaf,
    seed=random_state,
)
rf_adaxt_global.fit(Xtr, Ytr, Etr)
# In-sample predictions, reused by the per-environment diagnostics.
fitted_adaxt_global = rf_adaxt_global.predict(Xtr)

In [28]:
# Per-environment training MSE of the adaXT global maximin forest; the cell value is the largest one.
max_mse(Ytr, fitted_adaxt_global, Etr, verbose=True)

Environment 0 MSE: 14.365651340868284
Environment 1 MSE: 33.07413714886074
Environment 2 MSE: 32.809451026864686


np.float64(33.07413714886074)

In [29]:
# Per-environment explained variance of the adaXT global maximin forest; the cell value is the smallest one.
min_xplvar(Ytr, fitted_adaxt_global, Etr, verbose=True)

Environment 0 explained variance: 1.515857836312259
Environment 1 explained variance: 2.5308083893014697
Environment 2 explained variance: 1.1389406438701357


np.float64(1.1389406438701357)

In [30]:
# Out-of-sample error of the adaXT global maximin forest on the selected test set.
preds_adaxt_global = rf_adaxt_global.predict(X_test)
mean_squared_error(y_true=Y_test, y_pred=preds_adaxt_global)

5.921439033128586

### Magging

In [31]:
# Magging forest using the 'sklearn' backend.
rf_magging = MaggingRF_PB(
    n_estimators=n_estimators,
    min_samples_leaf=min_samples_leaf,
    random_state=random_state,
    backend='sklearn',
)
# A single call returns a pair, unpacked here as the training-set fits
# and the predictions for X_test.
fitted_magging, preds_magging = rf_magging.fit_predict_magging(
    Xtr, Ytr, Etr, X_test
)
# Weights reported by the fitted magging model (displayed in a later cell).
wmag = rf_magging.get_weights()

In [32]:
# Per-environment training MSE of the magging fit; the cell value is the largest one.
max_mse(Ytr, fitted_magging, Etr, verbose=True)

Environment 0 MSE: 14.735047423507314
Environment 1 MSE: 34.59357589558412
Environment 2 MSE: 32.63989748823563


np.float64(34.59357589558412)

In [33]:
# Per-environment explained variance of the magging fit; the cell value is the smallest one.
min_xplvar(Ytr, fitted_magging, Etr, verbose=True)

Environment 0 explained variance: 1.1464617536732291
Environment 1 explained variance: 1.0113696425780958
Environment 2 explained variance: 1.3084941824991887


np.float64(1.0113696425780958)

In [36]:
# Out-of-sample error of the magging predictions on the selected test set.
mean_squared_error(y_true=Y_test, y_pred=preds_magging)

6.470894602826004

In [35]:
# Weights returned by rf_magging.get_weights(); three values here,
# presumably one per training environment.
wmag

array([2.81683987e-18, 4.93589477e-01, 5.06410523e-01])

### RF on each environment + equal weights

In [41]:
# Baseline: fit one standard random forest per training environment and
# combine their predictions with equal weights 1 / n_envs.
envs = np.unique(Etr)  # computed once, reused for the count and the loop
n_envs = len(envs)
winit = np.full(n_envs, 1 / n_envs)

# Collect one prediction vector per environment-specific forest. Separate
# list names keep preds_envs / fitted_envs bound to arrays only.
preds_cols = []
fitted_cols = []
for env in envs:
    env_mask = Etr == env  # rows belonging to this environment
    Xtr_e = Xtr[env_mask]
    Ytr_e = Ytr[env_mask]
    rfm = RandomForestRegressor(
        n_estimators=n_estimators,
        min_samples_leaf=min_samples_leaf,
        random_state=random_state,
    )
    rfm.fit(Xtr_e, Ytr_e)
    # Each forest predicts on the test set and on the FULL training set,
    # not only on the rows of its own environment.
    preds_cols.append(rfm.predict(X_test))
    fitted_cols.append(rfm.predict(Xtr))

# One column per environment-specific forest.
preds_envs = np.column_stack(preds_cols)
fitted_envs = np.column_stack(fitted_cols)

# Equal-weight average across the per-environment forests.
preds_magging_wdef = np.dot(winit, preds_envs.T)
fitted_magging_wdef = np.dot(winit, fitted_envs.T)

In [42]:
# Per-environment training MSE of the equal-weight ensemble; the cell value is the largest one.
max_mse(Ytr, fitted_magging_wdef, Etr, verbose=True)

Environment 0 MSE: 6.974183806082349
Environment 1 MSE: 21.07533639413647
Environment 2 MSE: 49.25600111353233


np.float64(49.25600111353233)

In [43]:
# Per-environment explained variance of the equal-weight ensemble; the cell value is the smallest one.
min_xplvar(Ytr, fitted_magging_wdef, Etr, verbose=True)

Environment 0 explained variance: 8.907325371098194
Environment 1 explained variance: 14.529609144025741
Environment 2 explained variance: -15.307609442797506


np.float64(-15.307609442797506)

In [44]:
# Out-of-sample error of the equal-weight ensemble on the selected test set.
mean_squared_error(y_true=Y_test, y_pred=preds_magging_wdef)

7.55324796247885