Skip to content

Commit

Permalink
[PUBDEV-7914] grid with failing CV models would hang indefinitely (#5183)
Browse files Browse the repository at this point in the history
  • Loading branch information
Honza Sterba committed Dec 12, 2020
1 parent 19bfcb9 commit 12f6e81
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 55 deletions.
78 changes: 46 additions & 32 deletions h2o-core/src/main/java/hex/ModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
*/
abstract public class ModelBuilder<M extends Model<M,P,O>, P extends Model.Parameters, O extends Model.Output> extends Iced {

private ModelBuilderListener _modelBuilderListener;

public void setModelBuilderListener(final ModelBuilderListener modelBuilderListener) {
this._modelBuilderListener = modelBuilderListener;
}

/**
 * Supplies the {@code ToEigenVec} for this builder.
 *
 * @return always {@code null} in this implementation
 */
public ToEigenVec getToEigenVec() {
  return null;
}
/**
 * Tells whether reordering applies for this builder.
 *
 * @param v not consulted by this implementation
 * @return {@code true} only when the categorical encoding configured in {@code _parms}
 *         reports {@code needsResponse()} and {@code isSupervised()} holds
 */
public boolean shouldReorder(Vec v) {
  // Keep the original evaluation order: isSupervised() is only asked when the encoding needs the response.
  if (!_parms._categorical_encoding.needsResponse()) {
    return false;
  }
  return isSupervised();
}

Expand Down Expand Up @@ -231,6 +225,13 @@ abstract protected class Driver extends H2O.H2OCountedCompleter<Driver> {

/** Creates a driver through the superclass no-argument constructor. */
protected Driver() {
  super();
}

/**
 * Creates a driver for the given completer.
 *
 * @param completer handed unchanged to the superclass constructor
 */
protected Driver(H2O.H2OCountedCompleter completer) {
  super(completer);
}

// Listener notified when this driver finishes; null means nobody is notified.
private ModelBuilderListener _callback;

/**
 * Registers the listener this driver notifies on completion: {@code onModelSuccess} is
 * invoked from {@code onCompletion} and {@code onModelFailure} from
 * {@code onExceptionalCompletion}, each only when the callback is non-null.
 *
 * @param callback listener to notify, or {@code null} for none
 */
public void setCallback(ModelBuilderListener callback) {
  _callback = callback;
}

// Pull the boilerplate out of the computeImpl(), so the algo writer doesn't need to worry about the following:
// 1) Scope (unless they want to keep data, then they must call Scope.untrack(Key<Vec>[]))
// 2) Train/Valid frame locking and unlocking
Expand All @@ -253,16 +254,16 @@ public void compute2() {
@Override
public void onCompletion(CountedCompleter caller) {
setFinalState();
if (_modelBuilderListener != null) {
_modelBuilderListener.onModelSuccess(_result.get());
if (_callback != null) {
_callback.onModelSuccess(_result.get());
}
}

@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
setFinalState();
if (_modelBuilderListener != null) {
_modelBuilderListener.onModelFailure(ex, _parms);
if (_callback != null) {
_callback.onModelFailure(ex, _parms);
}
return true;
}
Expand All @@ -287,7 +288,6 @@ private void setFinalState() {
if (res != null && res._output != null) {
res._output._job = _job;
res._output.stopClock();
// res.unlock(_job == null ? null : _job._key, false); // last resort: dirty way to force unlock to be able to reacquire lock
res.write_lock(_job);
res.update(_job);
res.unlock(_job);
Expand Down Expand Up @@ -357,31 +357,45 @@ public void run() {

/**
 * Launches training of a Model, based on its parameters, without a completion listener.
 * Equivalent to calling {@code trainModel(null)}.
 *
 * @return whatever {@code trainModel(null)} returns
 */
final public Job<M> trainModel() {
return trainModel(null);
}

final public Job<M> trainModel(final ModelBuilderListener callback) {
if (error_count() > 0)
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
startClock();
if( !nFoldCV() )
return _job.start(trainModelImpl(), _parms.progressUnits(), _parms._max_runtime_secs);

// cross-validation needs to be forked off to allow continuous (non-blocking) progress bar
return _job.start(new H2O.H2OCountedCompleter() {
@Override
public void compute2() {
computeCrossValidation();
tryComplete();
}
@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
Log.warn("Model training job "+_job._description+" completed with exception: "+ex);
try {
Keyed.remove(_job._result); //ensure there's no incomplete model left for manipulation after crash or cancellation
} catch (Exception logged) {
Log.warn("Exception thrown when removing result from job "+ _job._description, logged);
if (!nFoldCV()) {
Driver driver = trainModelImpl();
driver.setCallback(callback);
return _job.start(driver, _parms.progressUnits(), _parms._max_runtime_secs);
} else {
// cross-validation needs to be forked off to allow continuous (non-blocking) progress bar
return _job.start(new H2O.H2OCountedCompleter() {
@Override
public void compute2() {
computeCrossValidation();
tryComplete();
}

@Override
public void onCompletion(CountedCompleter caller) {
if (callback != null) callback.onModelSuccess(_result.get());
}
return true;
}
},
(nFoldWork()+1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs);

@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
Log.warn("Model training job " + _job._description + " completed with exception: " + ex);
if (callback != null) callback.onModelFailure(ex, _parms);
try {
Keyed.remove(_job._result); // ensure there's no incomplete model left for manipulation after crash or cancellation
} catch (Exception logged) {
Log.warn("Exception thrown when removing result from job " + _job._description, logged);
}
return true;
}
},
(nFoldWork() + 1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs);
}
}

/**
Expand Down
45 changes: 22 additions & 23 deletions h2o-core/src/main/java/hex/ParallelModelBuilder.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hex;

import jsr166y.ForkJoinTask;
import org.apache.log4j.Logger;
import water.Iced;
import water.util.IcedAtomicInt;

Expand All @@ -15,12 +16,15 @@
* released the barrier inside.
*/
public class ParallelModelBuilder extends ForkJoinTask<ParallelModelBuilder> {

private static final Logger LOG = Logger.getLogger(ParallelModelBuilder.class);

/**
 * Callback through which a {@link ParallelModelBuilder} reports the outcome of each model build.
 */
public static abstract class ParallelModelBuilderCallback<D extends ParallelModelBuilderCallback> extends Iced<D> {

/**
 * Invoked when a model build finishes successfully.
 *
 * @param model the model handed over by the builder's success notification
 * @param parallelModelBuilder the {@link ParallelModelBuilder} that ran the build
 */
public abstract void onBuildSuccess(final Model model, final ParallelModelBuilder parallelModelBuilder);

/**
 * Invoked when a model build fails.
 *
 * @param modelBuildFailure carries the failure cause and the parameters of the failed model
 * @param parallelModelBuilder the {@link ParallelModelBuilder} that ran the build
 */
public abstract void onBuildFailure(final ModelBuildFailure modelBuildFailure, final ParallelModelBuilder parallelModelBuilder);

}

private final transient ParallelModelBuilderCallback _callback;
Expand All @@ -40,39 +44,31 @@ public ParallelModelBuilder(final ParallelModelBuilderCallback callback) {
* @param modelBuilders A {@link Collection} of {@link ModelBuilder} instances to execute in parallel.
*/
public void run(final Collection<ModelBuilder> modelBuilders) {
for (final ModelBuilder modelBuilder : modelBuilders) {
_modelInProgressCounter.incrementAndGet();

// Set the callbacks
modelBuilder.setModelBuilderListener(_parallelModelBuiltListener);
modelBuilder.trainModel();
}
if (LOG.isTraceEnabled()) LOG.trace("run with " + modelBuilders.size() + " models");
for (final ModelBuilder modelBuilder : modelBuilders) {
_modelInProgressCounter.incrementAndGet();
modelBuilder.trainModel(_parallelModelBuiltListener);
}
}


private class ParallelModelBuiltListener extends ModelBuilderListener<ParallelModelBuiltListener> {

@Override
public void onModelSuccess(Model model) {
if (! model._parms._is_cv_model) {
try {
_callback.onBuildSuccess(model, ParallelModelBuilder.this);
} finally {
_modelInProgressCounter.decrementAndGet();
}
try {
_callback.onBuildSuccess(model, ParallelModelBuilder.this);
} finally {
attemptComplete();
}
}

@Override
public void onModelFailure(Throwable cause, Model.Parameters parameters) {
if (! parameters._is_cv_model) {
try {
final ModelBuildFailure modelBuildFailure = new ModelBuildFailure(cause, parameters);
_callback.onBuildFailure(modelBuildFailure, ParallelModelBuilder.this);
} finally {
_modelInProgressCounter.decrementAndGet();
}
try {
final ModelBuildFailure modelBuildFailure = new ModelBuildFailure(cause, parameters);
_callback.onBuildFailure(modelBuildFailure, ParallelModelBuilder.this);
} finally {
attemptComplete();
}
}
Expand All @@ -99,9 +95,12 @@ public Model.Parameters getParameters() {
}
}

private void attemptComplete(){
if(_modelInProgressCounter.get() != 0) return;
complete(this);
private void attemptComplete() {
int modelsInProgress = _modelInProgressCounter.decrementAndGet();
if (LOG.isTraceEnabled()) LOG.trace("Completed a model, left in progress: " + modelsInProgress);
if (modelsInProgress == 0) {
complete(this);
}
}


Expand Down
38 changes: 38 additions & 0 deletions h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys
import os
import random

sys.path.insert(1, os.path.join("..", "..", ".."))
import h2o
from tests import pyunit_utils
from h2o.grid.grid_search import H2OGridSearch
from h2o.estimators.gbm import H2OGradientBoostingEstimator


def grid_parallel():
    """Train a parallel GBM grid with a random fold column and check how many models result.

    Scenario from PUBDEV-7914: a grid in which some cross-validated models fail
    must still finish instead of hanging.
    """
    iris = h2o.import_file(path=pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))

    # One random fold id in 0..4 per training row, attached as an extra column.
    fold_rows = []
    for _ in range(iris.nrow):
        fold_rows.append([random.randint(0, 4)])
    folds = h2o.H2OFrame(fold_rows)
    folds.set_names(["fold_assignment"])
    iris = iris.cbind(folds)

    hyper_parameters = dict(ntrees=[1, 3, 5], min_rows=[1, 10, 100])
    print("GBM grid with the following hyper_parameters:", hyper_parameters)

    gs = H2OGridSearch(H2OGradientBoostingEstimator, hyper_params=hyper_parameters, parallelism=4)
    gs.train(x=list(range(4)), y=4, training_frame=iris, fold_column="fold_assignment")

    assert gs is not None
    # only six models are trained, since CV is not possible with min_rows=100
    assert len(gs.model_ids) == 6


# Call the test directly when this module is imported; go through the
# pyunit standalone harness when it is executed as a script.
if __name__ != "__main__":
    grid_parallel()
else:
    pyunit_utils.standalone_test(grid_parallel)

0 comments on commit 12f6e81

Please sign in to comment.