# A simple example of how to train provably robust boosted trees.

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../")
import numpy as np
import data
from tree_ensemble import TreeEnsemble


In [8]:
# Experiment setup: the training objective, the number of boosting rounds, and the data.
model = 'robust_bound'  # robust tree ensemble
n_trees = 20  # total number of trees in the ensemble
X_train, y_train, X_test, y_test, eps = data.all_datasets_dict['diabetes']()

# Hyperparameters of the ensemble. n_trials_coord is set to the number of
# columns of X_train (presumably so every feature is tried per split -- confirm
# against TreeEnsemble).
hparams = dict(
    weak_learner='tree',
    n_trials_coord=X_train.shape[1],
    lr=1.0,
    min_samples_split=5,
    min_samples_leaf=10,
    max_depth=3,
)
ensemble = TreeEnsemble(**hparams)

# Progress line printed once per boosting round.
report = 'Iteration: {}, test error: {:.2%}, upper bound on robust test error: {:.2%}'

# Per-example weights (one per training row); they start at one and are
# recomputed at the end of every iteration.
gamma = np.ones(X_train.shape[0])
for i in range(1, n_trees + 1):
    # Boosting step: fit one more tree to minimize the robust loss of the whole
    # ensemble, append it, then prune the tree that was just added.
    weak_learner = ensemble.fit_tree(X_train, y_train, gamma, model, eps, depth=1)
    ensemble.add_weak_learner(weak_learner)
    ensemble.prune_last_tree(X_train, y_train, eps, model)

    # Per-example weights for the next iteration: exp of the negated certified
    # bound evaluated on the training set.
    train_bound = ensemble.certify_treewise_bound(X_train, y_train, eps)
    gamma = np.exp(-train_bound)

    # Track generalization and robustness on the test set: a negative y * f(x)
    # counts as a test error, a negative certified bound counts towards the
    # upper bound on the robust test error.
    yf_test = y_test * ensemble.predict(X_test)
    min_yf_test = ensemble.certify_treewise_bound(X_test, y_test, eps)
    test_error = np.mean(yf_test < 0.0)
    robust_error_bound = np.mean(min_yf_test < 0.0)
    print(report.format(i, test_error, robust_error_bound))
    

Iteration: 1, test error: 24.68%, upper bound on robust test error: 32.47%
Iteration: 2, test error: 23.38%, upper bound on robust test error: 32.47%
Iteration: 3, test error: 23.38%, upper bound on robust test error: 32.47%
Iteration: 4, test error: 23.38%, upper bound on robust test error: 32.47%
Iteration: 5, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 6, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 7, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 8, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 9, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 10, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 11, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 12, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 13, test error: 24.03%, upper bound on robust test error: 33.12%
Iteration: 14, test e