diff --git a/h2o-core/src/main/java/hex/ModelBuilder.java b/h2o-core/src/main/java/hex/ModelBuilder.java index ea55402c0ef0..b0b7a2567b5d 100644 --- a/h2o-core/src/main/java/hex/ModelBuilder.java +++ b/h2o-core/src/main/java/hex/ModelBuilder.java @@ -21,12 +21,6 @@ */ abstract public class ModelBuilder, P extends Model.Parameters, O extends Model.Output> extends Iced { - private ModelBuilderListener _modelBuilderListener; - - public void setModelBuilderListener(final ModelBuilderListener modelBuilderListener) { - this._modelBuilderListener = modelBuilderListener; - } - public ToEigenVec getToEigenVec() { return null; } public boolean shouldReorder(Vec v) { return _parms._categorical_encoding.needsResponse() && isSupervised(); } @@ -231,6 +225,13 @@ abstract protected class Driver extends H2O.H2OCountedCompleter { protected Driver(){ super(); } protected Driver(H2O.H2OCountedCompleter completer){ super(completer); } + + private ModelBuilderListener _callback; + + public void setCallback(ModelBuilderListener callback) { + this._callback = callback; + } + // Pull the boilerplate out of the computeImpl(), so the algo writer doesn't need to worry about the following: // 1) Scope (unless they want to keep data, then they must call Scope.untrack(Key[])) // 2) Train/Valid frame locking and unlocking @@ -253,16 +254,16 @@ public void compute2() { @Override public void onCompletion(CountedCompleter caller) { setFinalState(); - if (_modelBuilderListener != null) { - _modelBuilderListener.onModelSuccess(_result.get()); + if (_callback != null) { + _callback.onModelSuccess(_result.get()); } } @Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { setFinalState(); - if (_modelBuilderListener != null) { - _modelBuilderListener.onModelFailure(ex, _parms); + if (_callback != null) { + _callback.onModelFailure(ex, _parms); } return true; } @@ -287,7 +288,6 @@ private void setFinalState() { if (res != null && res._output != null) { res._output._job = _job; res._output.stopClock(); -// res.unlock(_job == null ? null : _job._key, false); // last resort: dirty way to force unlock to be able to reacquire lock res.write_lock(_job); res.update(_job); res.unlock(_job); @@ -357,31 +357,45 @@ public void run() { /** Method to launch training of a Model, based on its parameters. */ final public Job trainModel() { + return trainModel(null); + } + + final public Job trainModel(final ModelBuilderListener callback) { if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this); startClock(); - if( !nFoldCV() ) - return _job.start(trainModelImpl(), _parms.progressUnits(), _parms._max_runtime_secs); - - // cross-validation needs to be forked off to allow continuous (non-blocking) progress bar - return _job.start(new H2O.H2OCountedCompleter() { - @Override - public void compute2() { - computeCrossValidation(); - tryComplete(); - } - @Override - public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { - Log.warn("Model training job "+_job._description+" completed with exception: "+ex); - try { - Keyed.remove(_job._result); //ensure there's no incomplete model left for manipulation after crash or cancellation - } catch (Exception logged) { - Log.warn("Exception thrown when removing result from job "+ _job._description, logged); + if (!nFoldCV()) { + Driver driver = trainModelImpl(); + driver.setCallback(callback); + return _job.start(driver, _parms.progressUnits(), _parms._max_runtime_secs); + } else { + // cross-validation needs to be forked off to allow continuous (non-blocking) progress bar + return _job.start(new H2O.H2OCountedCompleter() { + @Override + public void compute2() { + computeCrossValidation(); + tryComplete(); + } + + @Override + public void onCompletion(CountedCompleter caller) { + if (callback != null) callback.onModelSuccess(_result.get()); } - return true; - } - }, - (nFoldWork()+1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs); + + @Override + public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) { + Log.warn("Model training job " + _job._description + " completed with exception: " + ex); + if (callback != null) callback.onModelFailure(ex, _parms); + try { + Keyed.remove(_job._result); // ensure there's no incomplete model left for manipulation after crash or cancellation + } catch (Exception logged) { + Log.warn("Exception thrown when removing result from job " + _job._description, logged); + } + return true; + } + }, + (nFoldWork() + 1/*main model*/) * _parms.progressUnits(), _parms._max_runtime_secs); + } } /** diff --git a/h2o-core/src/main/java/hex/ParallelModelBuilder.java b/h2o-core/src/main/java/hex/ParallelModelBuilder.java index 9528d727a6e7..98bb175a165e 100644 --- a/h2o-core/src/main/java/hex/ParallelModelBuilder.java +++ b/h2o-core/src/main/java/hex/ParallelModelBuilder.java @@ -1,6 +1,7 @@ package hex; import jsr166y.ForkJoinTask; +import org.apache.log4j.Logger; import water.Iced; import water.util.IcedAtomicInt; @@ -15,12 +16,15 @@ * released the barrier inside. */ public class ParallelModelBuilder extends ForkJoinTask { + + private static final Logger LOG = Logger.getLogger(ParallelModelBuilder.class); public static abstract class ParallelModelBuilderCallback extends Iced { public abstract void onBuildSuccess(final Model model, final ParallelModelBuilder parallelModelBuilder); public abstract void onBuildFailure(final ModelBuildFailure modelBuildFailure, final ParallelModelBuilder parallelModelBuilder); + } private final transient ParallelModelBuilderCallback _callback; @@ -40,13 +44,11 @@ public ParallelModelBuilder(final ParallelModelBuilderCallback callback) { * @param modelBuilders An {@link Collection} of {@link ModelBuilder} to execute in parallel. */ public void run(final Collection modelBuilders) { - for (final ModelBuilder modelBuilder : modelBuilders) { - _modelInProgressCounter.incrementAndGet(); - - // Set the callbacks - modelBuilder.setModelBuilderListener(_parallelModelBuiltListener); - modelBuilder.trainModel(); - } + if (LOG.isTraceEnabled()) LOG.trace("run with " + modelBuilders.size() + " models"); + for (final ModelBuilder modelBuilder : modelBuilders) { + _modelInProgressCounter.incrementAndGet(); + modelBuilder.trainModel(_parallelModelBuiltListener); + } } @@ -54,25 +56,19 @@ private class ParallelModelBuiltListener extends ModelBuilderListener