-
Notifications
You must be signed in to change notification settings - Fork 2k
/
runit_GBMGrid_airlines.R
72 lines (62 loc) · 3.24 KB
/
runit_GBMGrid_airlines.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Switch to the directory of the script named by the "f" command-line
# argument, so the relative path below resolves from the test's own location.
setwd(normalizePath(dirname(R.utils::commandArgs(asValues=TRUE)$"f")))
# Load the shared R test setup script. Presumably this is what provides
# doTest(), locate() and expect_model_param() used below -- TODO confirm.
source("../../../scripts/h2o-r-test-setup.R")
# GBM grid-search test on the airlines small dataset.
#
# Runs three grid searches over the same hyper-parameter space
# (ntrees x max_depth x learn_rate) and checks the number of models built:
#   1. no search criteria  -> one model per hyper-parameter combination;
#   2. RandomDiscrete with max_models -> exactly max_models models;
#   3. RandomDiscrete with metric-based stopping -> fewer models than the
#      full hyper-parameter space.
# The models of the third grid are then used as base models of a stacked
# ensemble, which is used to score the training frame.
#
# Takes no arguments. Failures are signalled through the expect_* calls.
gbm.grid.test <- function() {
  air.hex <- h2o.uploadFile(locate("smalldata/airlines/allyears2k_headers.zip"), destination_frame = "air.hex")
  print(summary(air.hex))

  myX <- c("DayofMonth", "DayOfWeek")

  # Specify grid hyper parameters
  ntrees_opts <- c(5, 10, 15)
  max_depth_opts <- c(2, 3, 4)
  learn_rate_opts <- c(0.1, 0.2)
  size_of_hyper_space <- length(ntrees_opts) * length(max_depth_opts) * length(learn_rate_opts)

  hyper_params <- list(ntrees = ntrees_opts, max_depth = max_depth_opts, learn_rate = learn_rate_opts)

  # 1. Search without explicit criteria: expect every combination to be built.
  air.grid <- h2o.grid("gbm", y = "IsDepDelayed", x = myX,
                       distribution = "bernoulli",
                       training_frame = air.hex,
                       hyper_params = hyper_params)
  print(air.grid)
  expect_equal(length(air.grid@model_ids), size_of_hyper_space)

  # Get models
  grid_models <- lapply(air.grid@model_ids, h2o.getModel)

  # Check expected number of models
  expect_equal(length(grid_models), size_of_hyper_space)

  # Check that the models were built with the requested hyper-parameter values
  # (expect_model_param is a harness helper defined outside this file).
  expect_model_param(grid_models, "ntrees", ntrees_opts)
  expect_model_param(grid_models, "max_depth", max_depth_opts)
  expect_model_param(grid_models, "learn_rate", learn_rate_opts)

  #
  # 2. test random/max_models search criterion: max_models
  max_models <- 5
  search_criteria <- list(strategy = "RandomDiscrete", max_models = max_models, seed = 1234)
  air.grid <- h2o.grid("gbm", y = "IsDepDelayed", x = myX,
                       distribution = "bernoulli",
                       training_frame = air.hex,
                       hyper_params = hyper_params,
                       search_criteria = search_criteria)
  print(air.grid)
  expect_equal(length(air.grid@model_ids), max_models)

  # 3. test random/max_models search criterion: asymptotic
  search_criteria <- list(strategy = "RandomDiscrete", stopping_metric = "AUTO", stopping_tolerance = 0.01, stopping_rounds = 3, seed = 1234)
  # Cross-validation predictions are kept; presumably so that the stacked
  # ensemble below can be trained on these models -- see the H2O Stacked
  # Ensembles documentation.
  air.grid <- h2o.grid("gbm", y = "IsDepDelayed", x = myX,
                       distribution = "bernoulli",
                       training_frame = air.hex,
                       hyper_params = hyper_params,
                       search_criteria = search_criteria,
                       nfolds = 5, fold_assignment = 'Modulo',
                       keep_cross_validation_predictions = TRUE,
                       seed = 5678)
  print(air.grid)
  # The stopping criteria should end the search before the space is exhausted.
  expect_true(length(air.grid@model_ids) < size_of_hyper_space)

  # stacker.grid <- h2o.grid("stackedensemble", y = "IsDepDelayed", x = myX,
  # training_frame = air.hex,
  # model_id = "my_ensemble",
  # base_models = air.grid@model_ids)
  stacker <- h2o.stackedEnsemble(x = myX, y = "IsDepDelayed", training_frame = air.hex,
                                 model_id = "my_ensemble",
                                 base_models = air.grid@model_ids)

  predictions <- h2o.predict(stacker, air.hex) # training data
  print("predictions for ensemble are in: ")
  print(h2o.getId(predictions))
}
# Hand the test function to doTest() together with a human-readable test name.
# NOTE(review): doTest() is not defined in this file; presumably it comes from
# the setup script sourced at the top -- confirm.
doTest("GBM Grid Test: Airlines Smalldata", gbm.grid.test)