-
Notifications
You must be signed in to change notification settings - Fork 2k
/
pyunit_imbalancedRF.py
51 lines (35 loc) · 1.45 KB
/
pyunit_imbalancedRF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from __future__ import print_function
from builtins import range
import sys
sys.path.insert(1,"../../../")
import h2o
from tests import pyunit_utils
from h2o.estimators.random_forest import H2ORandomForestEstimator
def _fit_rf_and_score(frame, balance_classes, seed=123):
    """Train a small random forest on ``frame`` and return its performance on ``frame``.

    Builds a 10-tree, 3-fold ``H2ORandomForestEstimator`` using columns 0-53 as
    predictors and column 54 as the response. The caller is expected to have
    already converted column 54 with ``asfactor()``.

    :param frame: H2OFrame used both for training and for scoring.
    :param balance_classes: forwarded unchanged to the estimator.
    :param seed: RNG seed passed to the estimator; shared by every model built
        here so that comparisons between them are intended to be repeatable.
    :returns: the model performance object computed on ``frame`` itself.
    """
    model = H2ORandomForestEstimator(ntrees=10, balance_classes=balance_classes,
                                     seed=seed, nfolds=3)
    model.train(x=list(range(54)), y=54, training_frame=frame)
    perf = model.model_performance(frame)
    perf.show()
    return perf


def imbalanced():
    """Check that ``balance_classes=True`` does not noticeably hurt a hard minority class.

    Loads the covtype 20k sample and turns column 54 into a categorical response.
    It then trains two random forests, one with ``balance_classes=False`` and one
    with ``balance_classes=True``, and compares the confusion-matrix error reported
    for class 6 in each. The assertion fails if the unbalanced model's error is
    below 90% of the balanced model's error.
    """
    covtype = h2o.import_file(path=pyunit_utils.locate("smalldata/covtype/covtype.20k.data"))
    covtype[54] = covtype[54].asfactor()

    # Both forests go through the same helper with the same seed, so the
    # relative comparison below does not hinge on an unseeded baseline.
    imbalanced_perf = _fit_rf_and_score(covtype, balance_classes=False)
    balanced_perf = _fit_rf_and_score(covtype, balance_classes=True)

    # Compare error for class 6 (difficult minority). Row 5 is the sixth
    # response level. Column 7 is taken to be that row's error cell, which
    # presumes the per-class count columns come first -- NOTE(review): confirm
    # against the confusion-matrix table layout if the response domain changes.
    class_6_row, error_col = 5, 7
    class_6_err_imbalanced = imbalanced_perf.confusion_matrix().cell_values[class_6_row][error_col]
    class_6_err_balanced = balanced_perf.confusion_matrix().cell_values[class_6_row][error_col]

    print("--------------------")
    print("")
    print("class_6_err_imbalanced")
    print(class_6_err_imbalanced)
    print("")
    print("class_6_err_balanced")
    print(class_6_err_balanced)
    print("")
    print("--------------------")

    # Fails when the unbalanced error is below 90% of the balanced error, i.e.
    # when balancing raises class 6's error by more than about 11% (1 / 0.9).
    # The measured values go into the message so a failure is self-explaining.
    assert class_6_err_imbalanced >= 0.9 * class_6_err_balanced, \
        "balance_classes makes it at least 10%% worse! (imbalanced=%s, balanced=%s)" % (
            class_6_err_imbalanced, class_6_err_balanced)
# When this file is imported by a test runner, run the check immediately.
# When it is executed directly, hand the check to the pyunit standalone harness.
if __name__ != "__main__":
    imbalanced()
else:
    pyunit_utils.standalone_test(imbalanced)