In [1]:
#----------------------------------------------------------------------
# Purpose: Train a GBM model, then fetch one of its trees and traverse it.
#----------------------------------------------------------------------

In [2]:
import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator

versionFromGradle='3.20.0',projectVersion='3.20.0.99999',branch='rel-wright',lastCommitHash='17cd2095ef4547f12c3efc122ceba7132a8a8f56',gitDescribe='jenkins-3.20.0.7-6-g17cd2095ef-dirty',compiledOn='2018-09-11 15:07:39',compiledBy='pavel'


In [3]:
# Connect to an H2O cluster. As the output below shows, one was already running
# at http://localhost:54321, so this run attached to it.
h2o.init()

Checking whether there is an H2O instance running at http://localhost:54321. connected.
versionFromGradle='3.20.0',projectVersion='3.20.0.99999',branch='rel-wright',lastCommitHash='17cd2095ef4547f12c3efc122ceba7132a8a8f56',gitDescribe='jenkins-3.20.0.7-6-g17cd2095ef-dirty',compiledOn='2018-09-11 15:07:39',compiledBy='pavel'


0,1
H2O cluster uptime:,29 secs
H2O cluster timezone:,Europe/Prague
H2O data parsing timezone:,UTC
H2O cluster version:,3.20.0.99999
H2O cluster version age:,21 hours and 17 minutes
H2O cluster name:,pavel
H2O cluster total nodes:,1
H2O cluster free memory:,3.452 Gb
H2O cluster total cores:,12
H2O cluster allowed cores:,12


In [4]:
from h2o.utils.shared_utils import _locate # private function. used to find files within h2o git project directory.

# Relative path of the airlines sample inside an h2o-3 checkout. The public URL points
# at a copy of the same file, so the notebook can also be re-run outside the repository.
AIRLINES_PATH = "smalldata/airlines/allyears2k_headers.zip"
AIRLINES_PUBLIC_URL = "https://s3.amazonaws.com/h2o-public-test-data/" + AIRLINES_PATH

try:
    airlines_source = _locate(AIRLINES_PATH)
except ValueError:
    # NOTE(review): assumes _locate raises ValueError when the file is not found in any
    # parent directory -- confirm against h2o.utils.shared_utils.
    airlines_source = AIRLINES_PUBLIC_URL

air = h2o.import_file(airlines_source)

Parse progress: |█████████████████████████████████████████████████████████| 100%


In [5]:
# Deliberately tiny GBM: only 3 trees, each capped at depth 3, so every tree is small
# enough to inspect and traverse by hand in the cells below.
gbm_air_model = H2OGradientBoostingEstimator(
    ntrees=3,
    max_depth=3,
)

In [6]:
# Fit the model on the airlines frame: three predictor columns explain IsDepDelayed.
gbm_air_model.train(
    x=["Origin", "Distance", "UniqueCarrier"],
    y="IsDepDelayed",
    training_frame=air,
)

gbm Model Build progress: |███████████████████████████████████████████████| 100%


In [7]:
# Fetch the very first tree (index 0 of the 3 trees trained above) for response class "NO".
# NOTE(review): with a two-level response H2O appears to build one tree per iteration;
# confirm that "NO" must be the response's first level.
from h2o.tree import H2OTree, H2ONode
tree = H2OTree(model=gbm_air_model, tree_number=0, tree_class="NO")

In [8]:
# Display the number of nodes in the tree, counting split nodes and leaves alike
# (15 here, 8 of which are leaves).
len(tree)

15

In [9]:
# Print a readable description of the root node: its ID, its child node IDs and the
# split it performs (the column plus the categorical levels sent to each side).
tree.root_node.show()

Node ID 0 
Left child node ID = 1
Right child node ID = 2

Splits on column Origin
  - Categorical levels going to the left node: ['ABE', 'ABQ', 'ACY', 'AUS', 'AVP', 'BHM', 'BIL', 'BNA', 'BOI', 'BOS', 'BUF', 'BUR', 'CAE', 'CHS', 'CLE', 'CLT', 'COS', 'CRP', 'CRW', 'DCA', 'DEN', 'DSM', 'DTW', 'EGE', 'EWR', 'EYW', 'GEG', 'GNV', 'GSO', 'HNL', 'HRL', 'IAD', 'IAH', 'JAX', 'JFK', 'KOA', 'LAN', 'LBB', 'LIH', 'MAF', 'MDT', 'MEM', 'MHT', 'MKE', 'MLB', 'MRY', 'MSP', 'MSY', 'MYR', 'OAK', 'OGG', 'OKC', 'PHF', 'PHX', 'PWM', 'RDU', 'SAN', 'SAV', 'SBN', 'SDF', 'SJC', 'SLC', 'SMF', 'STL', 'STT', 'TLH', 'TRI', 'TUL', 'TUS', 'TYS']
  - Categorical levels going to the right node: ['ALB', 'AMA', 'ANC', 'ATL', 'BDL', 'BGM', 'BTV', 'BWI', 'CHO', 'CMH', 'CVG', 'DAL', 'DAY', 'DFW', 'ELP', 'ERI', 'FLL', 'GRR', 'HOU', 'HPN', 'ICT', 'IND', 'ISP', 'JAN', 'LAS', 'LAX', 'LEX', 'LGA', 'LIT', 'LYH', 'MCI', 'MCO', 'MDW', 'MFR', 'MIA', 'OMA', 'ONT', 'ORD', 'ORF', 'PBI', 'PDX', 'PHL', 'PIT', 'PSP', 'PVD', 'RIC', 'RNO', '

In [10]:
# Descend root -> left -> right -> right. In this depth-3 tree that path ends at a leaf,
# whose predicted value is printed. The printed ID is H2O's internal node ID, not a
# position: the tree has only 15 nodes, yet this leaf reports ID 18.
(
    tree.root_node
    .left_child
    .right_child
    .right_child
    .show()
)

Leaf node ID 18. Predicted value at leaf node is 0.044361357 



In [11]:
# Raw split attributes of the root node, one per line. For this categorical split on
# Origin the two level lists are filled in and threshold is nan. na_direction is
# presumably the side that missing values are routed to -- confirm in the H2OTree docs.
print(
    tree.root_node.split_feature,
    tree.root_node.left_levels,
    tree.root_node.right_levels,
    tree.root_node.threshold,
    tree.root_node.na_direction,
    sep="\n",
)

Origin
['ABE', 'ABQ', 'ACY', 'AUS', 'AVP', 'BHM', 'BIL', 'BNA', 'BOI', 'BOS', 'BUF', 'BUR', 'CAE', 'CHS', 'CLE', 'CLT', 'COS', 'CRP', 'CRW', 'DCA', 'DEN', 'DSM', 'DTW', 'EGE', 'EWR', 'EYW', 'GEG', 'GNV', 'GSO', 'HNL', 'HRL', 'IAD', 'IAH', 'JAX', 'JFK', 'KOA', 'LAN', 'LBB', 'LIH', 'MAF', 'MDT', 'MEM', 'MHT', 'MKE', 'MLB', 'MRY', 'MSP', 'MSY', 'MYR', 'OAK', 'OGG', 'OKC', 'PHF', 'PHX', 'PWM', 'RDU', 'SAN', 'SAV', 'SBN', 'SDF', 'SJC', 'SLC', 'SMF', 'STL', 'STT', 'TLH', 'TRI', 'TUL', 'TUS', 'TYS']
['ALB', 'AMA', 'ANC', 'ATL', 'BDL', 'BGM', 'BTV', 'BWI', 'CHO', 'CMH', 'CVG', 'DAL', 'DAY', 'DFW', 'ELP', 'ERI', 'FLL', 'GRR', 'HOU', 'HPN', 'ICT', 'IND', 'ISP', 'JAN', 'LAS', 'LAX', 'LEX', 'LGA', 'LIT', 'LYH', 'MCI', 'MCO', 'MDW', 'MFR', 'MIA', 'OMA', 'ONT', 'ORD', 'ORF', 'PBI', 'PDX', 'PHL', 'PIT', 'PSP', 'PVD', 'RIC', 'RNO', 'ROA', 'ROC', 'RSW', 'SAT', 'SCK', 'SEA', 'SFO', 'SJU', 'SNA', 'SRQ', 'STX', 'SWF', 'SYR', 'TPA', 'UCA']
nan
RIGHT


In [12]:
# Print some of the raw per-node arrays available on the tree, one per line. Each list
# is indexed by node position. A -1 child and a None feature mark a leaf, and the
# thresholds are nan for every node here.
print(
    tree.left_children,
    tree.right_children,
    tree.thresholds,
    tree.features,
    sep="\n",
)

[1, 3, 5, 7, 9, 11, 13, -1, -1, -1, -1, -1, -1, -1, -1]
[2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1]
[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]
['Origin', 'Origin', 'UniqueCarrier', 'UniqueCarrier', 'UniqueCarrier', 'UniqueCarrier', 'Origin', None, None, None, None, None, None, None, None]
