-
Notifications
You must be signed in to change notification settings - Fork 2k
/
XValPredictionsCheck.java
128 lines (118 loc) · 4.51 KB
/
XValPredictionsCheck.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package hex;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.util.ArrayUtils;
import static org.junit.Assert.assertEquals;
/**
* This test is intended to corroborate the documented description of cross-validated
* predictions as a result of model building. These datasets have identifiers of the form
* *_cv_1, *_cv_2, ..., *_cv_n
*
* This test builds GBM, DRF, GLM, and DL models with a randomized fold column, and it
* checks that each *_cv_n contains predictions consistent with the fold column on the
* original frame.
*/
public class XValPredictionsCheck extends TestUtil {
  @BeforeClass() public static void setup() { stall_till_cloudsize(1); }

  /**
   * Trains one GBM, DRF, GLM and DeepLearning model on the iris data set using an
   * explicit fold column, keeping the cross-validation predictions, and hands each
   * trained model to {@link #checkModel(Model, Vec, int)} for verification and cleanup.
   */
  @Test public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    try {
      // Load data and append a generated fold-assignment column named "foldId"
      tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
      Frame foldId = new Frame(new String[]{"foldId"}, new Vec[]{AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789)});
      tfr.add(foldId);
      DKV.put(tfr);

      // GBM
      GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
      parms._train = tfr._key;
      parms._response_column = "class";
      parms._ntrees = 1;
      parms._max_depth = 1;
      parms._fold_column = "foldId";
      parms._distribution = DistributionFamily.multinomial;
      parms._keep_cross_validation_predictions=true;
      GBM job = new GBM(parms);
      GBMModel gbm = job.trainModel().get();
      checkModel(gbm, foldId.anyVec(),3);

      // DRF
      DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
      parmsDRF._train = tfr._key;
      parmsDRF._response_column = "class";
      parmsDRF._ntrees = 1;
      parmsDRF._max_depth = 1;
      parmsDRF._fold_column = "foldId";
      parmsDRF._distribution = DistributionFamily.multinomial;
      parmsDRF._keep_cross_validation_predictions=true;
      DRF drfJob = new DRF(parmsDRF);
      DRFModel drf = drfJob.trainModel().get();
      checkModel(drf, foldId.anyVec(),3);

      // GLM (numeric response, so a single prediction column is checked)
      GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
      parmsGLM._train = tfr._key;
      parmsGLM._response_column = "sepal_len";
      parmsGLM._fold_column = "foldId";
      parmsGLM._keep_cross_validation_predictions=true;
      GLM glmJob = new GLM(parmsGLM);
      GLMModel glm = glmJob.trainModel().get();
      checkModel(glm, foldId.anyVec(),1);

      // DL
      DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
      parmsDL._train = tfr._key;
      parmsDL._response_column = "class";
      parmsDL._hidden = new int[]{1};
      parmsDL._epochs = 1;
      parmsDL._fold_column = "foldId";
      parmsDL._keep_cross_validation_predictions=true;
      DeepLearning dlJob = new DeepLearning(parmsDL);
      DeepLearningModel dl = dlJob.trainModel().get();
      checkModel(dl, foldId.anyVec(),3);
    } finally {
      if (tfr != null) tfr.remove();
    }
  }

  /**
   * Verifies the per-fold cross-validation prediction frames of a trained model and
   * then cleans up: the model, its cross-validation models, every per-fold prediction
   * frame and the holdout-predictions frame key are all deleted/removed here.
   *
   * For the n-th prediction frame (n starting at 0, in the order of
   * {@code _cross_validation_predictions}), every row whose fold id differs from n
   * must hold 0 in all checked prediction columns.
   *
   * The checks do not rely on the Java {@code assert} keyword, so they are enforced
   * even when the JVM is started without {@code -ea}.
   *
   * @param m      trained model built with {@code _keep_cross_validation_predictions=true}
   * @param foldId fold-assignment vec of the training frame
   * @param nclass number of prediction columns to check; when 1 the vec returned by
   *               {@code preds.anyVec()} is used, otherwise the columns selected by
   *               {@code ArrayUtils.range(1, nclass)} (presumably the per-class
   *               probability columns following the label column - TODO confirm)
   */
  void checkModel(Model m, Vec foldId, int nclass) {
    if(!(m instanceof DRFModel)) // DRF does out-of-bag instead of true training, nobs might be different
      assertEquals(m._output._training_metrics._nobs,m._output._cross_validation_metrics._nobs);
    m.delete();
    m.deleteCrossValidationModels();
    Key[] xvalKeys = m._output._cross_validation_predictions;
    Key xvalKey = m._output._cross_validation_holdout_predictions_frame_id;
    // Single-element array so the anonymous MRTask below can read the current fold index
    final int[] id = new int[1];
    for(Key k: xvalKeys) {
      Frame preds = DKV.getGet(k);
      assertEquals("cross-validation prediction frame must have one row per training row",
              foldId.length(), preds.numRows());
      // vecs[0] is the fold column, vecs[1..nclass] are the prediction columns to check
      Vec[] vecs = new Vec[nclass+1];
      vecs[0] = foldId;
      if( nclass==1 ) vecs[1] = preds.anyVec();
      else
        System.arraycopy(preds.vecs(ArrayUtils.range(1, nclass)), 0, vecs, 1, nclass);
      new MRTask() {
        @Override public void map(Chunk[] cs) {
          Chunk foldId = cs[0];
          for(int r=0;r<cs[0]._len; ++r)
            if( foldId.at8(r) != id[0] )
              for(int i=1; i<cs.length;++i)
                if( cs[i].atd(r)!=0 ) // no prediction expected for this row!
                  throw new AssertionError("Unexpected non-zero prediction " + cs[i].atd(r)
                          + " in checked column " + i + " at chunk-local row " + r
                          + ": row belongs to fold " + foldId.at8(r)
                          + " but the prediction frame is for fold " + id[0]);
        }
      }.doAll(vecs);
      id[0]++;
      preds.delete();
    }
    xvalKey.remove();
  }
}