Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PUBDEV-6376: StackedEnsemble prediction fails when applied to a test dataset without response column #3382

Merged
merged 4 commits into from
Mar 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions h2o-algos/src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -1424,17 +1424,17 @@ public void generate(JCodeSB out) {



private GLMScore makeScoringTask(Frame adaptFrm, boolean generatePredictions, Job j){
private GLMScore makeScoringTask(Frame adaptFrm, boolean generatePredictions, Job j, boolean computeMetrics){
sebhrusen marked this conversation as resolved.
Show resolved Hide resolved
int responseId = adaptFrm.find(_output.responseName());
if(responseId > -1 && adaptFrm.vec(responseId).isBad()) { // remove inserted invalid response
adaptFrm = new Frame(adaptFrm.names(),adaptFrm.vecs());
adaptFrm.remove(responseId);
}
// Build up the names & domains.
final boolean computeMetrics = adaptFrm.vec(_output.responseName()) != null && !adaptFrm.vec(_output.responseName()).isBad();
String [] domain = _output.nclasses()<=1 ? null : !computeMetrics ? _output._domains[_output._domains.length-1] : adaptFrm.lastVec().domain();
final boolean detectedComputeMetrics = computeMetrics && (adaptFrm.vec(_output.responseName()) != null && !adaptFrm.vec(_output.responseName()).isBad());
String [] domain = _output.nclasses()<=1 ? null : !detectedComputeMetrics ? _output._domains[_output._domains.length-1] : adaptFrm.lastVec().domain();
// Score the dataset, building the class distribution & predictions
return new GLMScore(j, this, _output._dinfo.scoringInfo(_output._names,adaptFrm),domain,computeMetrics, generatePredictions);
return new GLMScore(j, this, _output._dinfo.scoringInfo(_output._names,adaptFrm),domain,detectedComputeMetrics, generatePredictions);
}
/** Score an already adapted frame. Returns a new Frame with new result
* vectors, all in the DKV. Caller responsible for deleting. Input is
Expand All @@ -1449,7 +1449,7 @@ private GLMScore makeScoringTask(Frame adaptFrm, boolean generatePredictions, Jo
protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
String [] names = makeScoringNames();
String [][] domains = new String[names.length][];
GLMScore gs = makeScoringTask(adaptFrm,true,j);// doAll(names.length,Vec.T_NUM,adaptFrm);
GLMScore gs = makeScoringTask(adaptFrm,true,j, computeMetrics);// doAll(names.length,Vec.T_NUM,adaptFrm);
assert gs._dinfo._valid:"_valid flag should be set on data info when doing scoring";
gs.doAll(names.length,Vec.T_NUM,gs._dinfo._adaptedFrame);
if (gs._computeMetrics)
Expand All @@ -1469,7 +1469,7 @@ protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_ke
*/
@Override
protected ModelMetrics.MetricBuilder scoreMetrics(Frame adaptFrm) {
GLMScore gs = makeScoringTask(adaptFrm,false,null);// doAll(names.length,Vec.T_NUM,adaptFrm);
GLMScore gs = makeScoringTask(adaptFrm,false,null, true);// doAll(names.length,Vec.T_NUM,adaptFrm);
assert gs._dinfo._valid:"_valid flag should be set on data info when doing scoring";
return gs.doAll(gs._dinfo._adaptedFrame)._mb;
}
Expand Down
Loading