-
Notifications
You must be signed in to change notification settings - Fork 2k
/
pyunit_mean_residual_devianceGBM.py
27 lines (24 loc) · 1.29 KB
/
pyunit_mean_residual_devianceGBM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import sys
sys.path.insert(1,"../../../")
import h2o
def gbm_mean_residual_deviance(ip, port):
    """Check that a GBM reports a float mean residual deviance for every split.

    Loads ``smalldata/junit/cars_20mpg.csv``, splits its rows using a random
    column drawn from the first column's ``runif()`` (draws > 0.2 go to
    training, draws <= 0.2 go to validation), fits ``h2o.gbm`` on five
    numeric predictors against ``economy`` with ``nfolds=3``, and asserts
    that the training, validation and cross-validation mean residual
    deviances are each a ``float``.

    ``ip`` and ``port`` are not referenced in the body; presumably they are
    supplied by the ``h2o.run_test`` harness (verify against the harness).
    """
    data_path = h2o.locate("smalldata/junit/cars_20mpg.csv")
    cars = h2o.import_frame(path=data_path)

    # One random draw per row; thresholding it yields the train/valid split.
    draw = cars[0].runif()
    training_rows = cars[draw > 0.2]
    validation_rows = cars[draw <= 0.2]

    features = ["displacement", "power", "weight", "acceleration", "year"]
    target = "economy"

    model = h2o.gbm(x=training_rows[features],
                    y=training_rows[target],
                    validation_x=validation_rows[features],
                    validation_y=validation_rows[target],
                    nfolds=3)

    deviances = model.mean_residual_deviance(train=True, valid=True, xval=True)

    # An ordered tuple (not a dict) keeps the checks running train -> valid -> xval.
    checks = (("train", "training"),
              ("valid", "validation"),
              ("xval", "cross-validation"))
    for key, label in checks:
        value = deviances[key]
        assert isinstance(value, float), \
            "Expected {0} mean residual deviance to be a float, but got {1}".format(label, type(value))
if __name__ == "__main__":
    # Delegate to h2o's test runner, forwarding the raw command-line arguments.
    cli_args = sys.argv
    h2o.run_test(cli_args, gbm_mean_residual_deviance)