Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PUBDEV-7914] grid with failing CV models would hang indefinitely #5183

Merged
merged 3 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 46 additions & 32 deletions h2o-core/src/main/java/hex/ModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
*/
abstract public class ModelBuilder<M extends Model<M,P,O>, P extends Model.Parameters, O extends Model.Output> extends Iced {

private ModelBuilderListener _modelBuilderListener;

public void setModelBuilderListener(final ModelBuilderListener modelBuilderListener) {
this._modelBuilderListener = modelBuilderListener;
}

public ToEigenVec getToEigenVec() { return null; }
public boolean shouldReorder(Vec v) { return _parms._categorical_encoding.needsResponse() && isSupervised(); }

Expand Down Expand Up @@ -231,6 +225,13 @@ abstract protected class Driver extends H2O.H2OCountedCompleter<Driver> {

protected Driver(){ super(); }
protected Driver(H2O.H2OCountedCompleter completer){ super(completer); }

private ModelBuilderListener _callback;

/**
 * Registers the listener that this driver notifies when it finishes.
 * The listener is consulted in {@code onCompletion} and
 * {@code onExceptionalCompletion}; a {@code null} value means nobody is notified.
 *
 * @param callback listener to notify, or {@code null} for none
 */
public void setCallback(ModelBuilderListener callback) {
  _callback = callback;
}

// Pull the boilerplate out of the computeImpl(), so the algo writer doesn't need to worry about the following:
// 1) Scope (unless they want to keep data, then they must call Scope.untrack(Key<Vec>[]))
// 2) Train/Valid frame locking and unlocking
Expand All @@ -253,16 +254,16 @@ public void compute2() {
@Override
public void onCompletion(CountedCompleter caller) {
setFinalState();
if (_modelBuilderListener != null) {
_modelBuilderListener.onModelSuccess(_result.get());
if (_callback != null) {
_callback.onModelSuccess(_result.get());
}
}

@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
setFinalState();
if (_modelBuilderListener != null) {
_modelBuilderListener.onModelFailure(ex, _parms);
if (_callback != null) {
_callback.onModelFailure(ex, _parms);
}
return true;
}
Expand All @@ -287,7 +288,6 @@ private void setFinalState() {
if (res != null && res._output != null) {
res._output._job = _job;
res._output.stopClock();
// res.unlock(_job == null ? null : _job._key, false); // last resort: dirty way to force unlock to be able to reacquire lock
res.write_lock(_job);
res.update(_job);
res.unlock(_job);
Expand Down Expand Up @@ -357,31 +357,45 @@ public void run() {

/** Method to launch training of a Model, based on its parameters. */
public final Job<M> trainModel() {
  // No listener requested: delegate to the callback-aware overload with a null callback.
  return trainModel(null);
}

final public Job<M> trainModel(final ModelBuilderListener callback) {
if (error_count() > 0)
throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
startClock();
if( !nFoldCV() )
return _job.start(trainModelImpl(), _parms.progressUnits(), _parms._max_runtime_secs);

// cross-validation needs to be forked off to allow continuous (non-blocking) progress bar
return _job.start(new H2O.H2OCountedCompleter() {
@Override
public void compute2() {
computeCrossValidation();
tryComplete();
}
@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
Log.warn("Model training job "+_job._description+" completed with exception: "+ex);
try {
Keyed.remove(_job._result); //ensure there's no incomplete model left for manipulation after crash or cancellation
} catch (Exception logged) {
Log.warn("Exception thrown when removing result from job "+ _job._description, logged);
if (!nFoldCV()) {
Driver driver = trainModelImpl();
driver.setCallback(callback);
return _job.start(driver, _parms.progressUnits(), _parms._max_runtime_secs);
} else {
// cross-validation needs to be forked off to allow continuous (non-blocking) progress bar
return _job.start(new H2O.H2OCountedCompleter() {
@Override
public void compute2() {
computeCrossValidation();
tryComplete();
}

@Override
public void onCompletion(CountedCompleter caller) {
if (callback != null) callback.onModelSuccess(_result.get());
}
return true;
}
},
(nFoldWork()+1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs);

@Override
public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
Log.warn("Model training job " + _job._description + " completed with exception: " + ex);
if (callback != null) callback.onModelFailure(ex, _parms);
try {
Keyed.remove(_job._result); // ensure there's no incomplete model left for manipulation after crash or cancellation
} catch (Exception logged) {
Log.warn("Exception thrown when removing result from job " + _job._description, logged);
}
return true;
}
},
(nFoldWork() + 1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs);
}
}

/**
Expand Down
45 changes: 22 additions & 23 deletions h2o-core/src/main/java/hex/ParallelModelBuilder.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hex;

import jsr166y.ForkJoinTask;
import org.apache.log4j.Logger;
import water.Iced;
import water.util.IcedAtomicInt;

Expand All @@ -15,12 +16,15 @@
* released the barrier inside.
*/
public class ParallelModelBuilder extends ForkJoinTask<ParallelModelBuilder> {

private static final Logger LOG = Logger.getLogger(ParallelModelBuilder.class);

/**
 * Callback through which {@link ParallelModelBuilder} reports the outcome of each
 * model build it was asked to run.
 */
public static abstract class ParallelModelBuilderCallback<D extends ParallelModelBuilderCallback> extends Iced<D> {

/**
 * Invoked when a model has been built successfully.
 *
 * @param model                the model that was built
 * @param parallelModelBuilder the {@link ParallelModelBuilder} reporting the success
 */
public abstract void onBuildSuccess(final Model model, final ParallelModelBuilder parallelModelBuilder);

/**
 * Invoked when a model build has failed.
 *
 * @param modelBuildFailure    the failure, created from the cause and the parameters of the failed build
 * @param parallelModelBuilder the {@link ParallelModelBuilder} reporting the failure
 */
public abstract void onBuildFailure(final ModelBuildFailure modelBuildFailure, final ParallelModelBuilder parallelModelBuilder);

}

private final transient ParallelModelBuilderCallback _callback;
Expand All @@ -40,39 +44,31 @@ public ParallelModelBuilder(final ParallelModelBuilderCallback callback) {
* @param modelBuilders An {@link Collection} of {@link ModelBuilder} to execute in parallel.
*/
public void run(final Collection<ModelBuilder> modelBuilders) {
for (final ModelBuilder modelBuilder : modelBuilders) {
_modelInProgressCounter.incrementAndGet();

// Set the callbacks
modelBuilder.setModelBuilderListener(_parallelModelBuiltListener);
modelBuilder.trainModel();
}
if (LOG.isTraceEnabled()) LOG.trace("run with " + modelBuilders.size() + " models");
for (final ModelBuilder modelBuilder : modelBuilders) {
_modelInProgressCounter.incrementAndGet();
modelBuilder.trainModel(_parallelModelBuiltListener);
}
}


private class ParallelModelBuiltListener extends ModelBuilderListener<ParallelModelBuiltListener> {

@Override
public void onModelSuccess(Model model) {
if (! model._parms._is_cv_model) {
try {
_callback.onBuildSuccess(model, ParallelModelBuilder.this);
} finally {
_modelInProgressCounter.decrementAndGet();
}
try {
_callback.onBuildSuccess(model, ParallelModelBuilder.this);
} finally {
attemptComplete();
}
}

@Override
public void onModelFailure(Throwable cause, Model.Parameters parameters) {
if (! parameters._is_cv_model) {
try {
final ModelBuildFailure modelBuildFailure = new ModelBuildFailure(cause, parameters);
_callback.onBuildFailure(modelBuildFailure, ParallelModelBuilder.this);
} finally {
_modelInProgressCounter.decrementAndGet();
}
try {
final ModelBuildFailure modelBuildFailure = new ModelBuildFailure(cause, parameters);
_callback.onBuildFailure(modelBuildFailure, ParallelModelBuilder.this);
} finally {
attemptComplete();
}
}
Expand All @@ -99,9 +95,12 @@ public Model.Parameters getParameters() {
}
}

private void attemptComplete(){
if(_modelInProgressCounter.get() != 0) return;
complete(this);
/**
 * Records that one model build has finished (successfully or not) by decrementing
 * the in-progress counter, and completes this task once the counter reaches zero.
 */
private void attemptComplete() {
  final int remaining = _modelInProgressCounter.decrementAndGet();
  if (LOG.isTraceEnabled()) {
    LOG.trace("Completed a model, left in progress: " + remaining);
  }
  if (remaining != 0) {
    // Other builds are still running; the last one to finish completes the task.
    return;
  }
  complete(this);
}


Expand Down
38 changes: 38 additions & 0 deletions h2o-py/tests/testdir_algos/grid/pyunit_grid_parallel_cv_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys
import os
import random

sys.path.insert(1, os.path.join("..", "..", ".."))
import h2o
from tests import pyunit_utils
from h2o.grid.grid_search import H2OGridSearch
from h2o.estimators.gbm import H2OGradientBoostingEstimator


def grid_parallel():
    """Train a parallel GBM grid with a fold column and check how many models come back.

    The iris frame is extended with a random fold assignment (values 0-4), then a
    3x3 grid over ``ntrees`` and ``min_rows`` is trained with ``parallelism=4``.
    The test expects the call to return and the grid to contain six models.
    """
    iris = h2o.import_file(path=pyunit_utils.locate("smalldata/iris/iris_wheader.csv"))

    # One single-column row per training row, each holding a random fold index.
    fold_rows = []
    for _ in range(iris.nrow):
        fold_rows.append([random.randint(0, 4)])
    folds = h2o.H2OFrame(fold_rows)
    folds.set_names(["fold_assignment"])
    iris = iris.cbind(folds)

    hyper_parameters = {
        "ntrees": [1, 3, 5],
        "min_rows": [1, 10, 100],
    }
    print("GBM grid with the following hyper_parameters:", hyper_parameters)

    gs = H2OGridSearch(H2OGradientBoostingEstimator, hyper_params=hyper_parameters, parallelism=4)
    predictors = list(range(4))
    gs.train(x=predictors, y=4, training_frame=iris, fold_column="fold_assignment")

    assert gs is not None
    # only six models are trained, since CV is not possible with min_rows=100
    trained_count = len(gs.model_ids)
    assert trained_count == 6


# Executed as a script: run through the pyunit standalone harness.
# Otherwise (module is imported): run the test directly.
if __name__ == "__main__":
    pyunit_utils.standalone_test(grid_parallel)
else:
    grid_parallel()