diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc
index 66c0b43f50..c0812d01c4 100644
--- a/docs/CHANGELOG.asciidoc
+++ b/docs/CHANGELOG.asciidoc
@@ -35,6 +35,11 @@
 * Upgrade PyTorch to version 1.11. (See {ml-pull}2233[#2233], {ml-pull}2235[#2235]
   and {ml-pull}2238[#2238].)
 
+=== Bug Fixes
+
+* Correct the logic for restarting hyperparameter fine-tuning from a failover when
+  training classification and regression models. (See {ml-pull}2251[#2251].)
+
 == {es} version 8.2.0
 
 === Enhancements
diff --git a/include/maths/analytics/CBoostedTreeHyperparameters.h b/include/maths/analytics/CBoostedTreeHyperparameters.h
index afa7aebc7d..b70bed2d49 100644
--- a/include/maths/analytics/CBoostedTreeHyperparameters.h
+++ b/include/maths/analytics/CBoostedTreeHyperparameters.h
@@ -440,6 +440,7 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     static const std::string SOFT_TREE_DEPTH_LIMIT_TAG;
     static const std::string SOFT_TREE_DEPTH_TOLERANCE_TAG;
     static const std::string STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG;
+    static const std::string STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG;
     static const std::string TRAIN_FRACTION_PER_FOLD_TAG;
     static const std::string TREE_SIZE_PENALTY_MULTIPLIER_TAG;
 
@@ -580,7 +581,10 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     void startSearch();
 
     //! Check if the search for the best hyperparameter values has finished.
-    bool searchNotFinished() const { return m_CurrentRound < m_NumberRounds; }
+    bool searchNotFinished() const {
+        return m_StoppedHyperparameterOptimizationEarly == false &&
+               m_CurrentRound < m_NumberRounds;
+    }
 
     //! Start a new round of hyperparameter search.
     void startNextSearchRound() { ++m_CurrentRound; }
@@ -701,6 +705,7 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     //@ \name Hyperparameter Optimisation
     //@{
     bool m_StopHyperparameterOptimizationEarly{true};
+    bool m_StoppedHyperparameterOptimizationEarly{false};
    bool m_ScalingDisabled{false};
    std::size_t m_MaximumOptimisationRoundsPerHyperparameter{2};
    TOptionalSize m_BayesianOptimisationRestarts;
diff --git a/lib/maths/analytics/CBoostedTreeHyperparameters.cc b/lib/maths/analytics/CBoostedTreeHyperparameters.cc
index a1bccaf44d..0559a73970 100644
--- a/lib/maths/analytics/CBoostedTreeHyperparameters.cc
+++ b/lib/maths/analytics/CBoostedTreeHyperparameters.cc
@@ -389,6 +389,7 @@ bool CBoostedTreeHyperparameters::selectNext(const TMeanVarAccumulator& testLoss
         parameters = minBoundary + parameters.cwiseProduct(maxBoundary - minBoundary);
     } else if (m_StopHyperparameterOptimizationEarly &&
                m_BayesianOptimization->anovaTotalCoefficientOfVariation() < 1e-3) {
+        m_StoppedHyperparameterOptimizationEarly = true;
         return false;
     } else {
         std::tie(parameters, std::ignore) =
@@ -626,6 +627,8 @@ void CBoostedTreeHyperparameters::acceptPersistInserter(core::CStatePersistInser
                                  m_SoftTreeDepthTolerance, inserter);
     core::CPersistUtils::persist(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                                  m_StopHyperparameterOptimizationEarly, inserter);
+    core::CPersistUtils::persist(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                                 m_StoppedHyperparameterOptimizationEarly, inserter);
     core::CPersistUtils::persist(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                                  m_TreeSizePenaltyMultiplier, inserter);
     // m_TunableHyperparameters is not persisted explicitly, it is re-generated
@@ -683,6 +686,9 @@ bool CBoostedTreeHyperparameters::acceptRestoreTraverser(core::CStateRestoreTrav
         RESTORE(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                 core::CPersistUtils::restore(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                                              m_StopHyperparameterOptimizationEarly, traverser))
+        RESTORE(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                core::CPersistUtils::restore(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                                             m_StoppedHyperparameterOptimizationEarly, traverser))
         RESTORE(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                 core::CPersistUtils::restore(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                                              m_TreeSizePenaltyMultiplier, traverser))
@@ -858,6 +864,7 @@ const std::string CBoostedTreeHyperparameters::NUMBER_ROUNDS_TAG{"number_rounds"
 const std::string CBoostedTreeHyperparameters::SOFT_TREE_DEPTH_LIMIT_TAG{"soft_tree_depth_limit"};
 const std::string CBoostedTreeHyperparameters::SOFT_TREE_DEPTH_TOLERANCE_TAG{"soft_tree_depth_tolerance"};
 const std::string CBoostedTreeHyperparameters::STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG{"stop_hyperparameter_optimization_early"};
+const std::string CBoostedTreeHyperparameters::STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG{"stopped_hyperparameter_optimization_early"};
 const std::string CBoostedTreeHyperparameters::TREE_SIZE_PENALTY_MULTIPLIER_TAG{"tree_size_penalty_multiplier"};
 // clang-format on
 }
diff --git a/lib/maths/analytics/CBoostedTreeImpl.cc b/lib/maths/analytics/CBoostedTreeImpl.cc
index ae27dfd18f..7af36bff96 100644
--- a/lib/maths/analytics/CBoostedTreeImpl.cc
+++ b/lib/maths/analytics/CBoostedTreeImpl.cc
@@ -204,11 +204,10 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
     this->checkTrainInvariants(frame);
 
-    if (m_Loss->isRegression()) {
-        m_Instrumentation->type(CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression);
-    } else {
-        m_Instrumentation->type(CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);
-    }
+    m_Instrumentation->type(
+        m_Loss->isRegression()
+            ? CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression
+            : CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);
 
     LOG_TRACE(<< "Main training loop...");
@@ -308,11 +307,10 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
         m_Hyperparameters.restoreBest();
         m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
         m_Hyperparameters.captureScale();
+        this->startProgressMonitoringFinalTrain();
         this->scaleRegularizationMultipliers(this->allTrainingRowsMask().manhattan() /
                                              this->meanNumberTrainingRowsPerFold());
 
-        this->startProgressMonitoringFinalTrain();
-
         // Reinitialize random number generator for reproducible results.
         m_Rng.seed(m_Seed);
@@ -321,6 +319,8 @@
                            .s_Forest;
         this->recordState(recordTrainStateCallback);
+    } else {
+        this->skipProgressMonitoringFinalTrain();
     }
     m_Instrumentation->iteration(m_Hyperparameters.currentRound());
     m_Instrumentation->flush(TRAIN_FINAL_FOREST);
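For readers following the fix: the crux is that `searchNotFinished()` previously consulted only the round counter, so a job restored after a failover would resume fine-tuning even though Bayesian optimization had already stopped early. The patch records that outcome in `m_StoppedHyperparameterOptimizationEarly`, persists and restores it alongside the other hyperparameter state, and folds it into the predicate. Below is a minimal, self-contained sketch of the same idea; the `SearchState` class and its stream-based persist/restore are hypothetical stand-ins for ml-cpp's `CStatePersistInserter`/`CStateRestoreTraverser` machinery, not the real API.

```cpp
// Minimal sketch, assuming a hypothetical SearchState class with stream-based
// persistence. It shows why the "stopped early" outcome must be part of the
// persisted state: without it, a search restored after failover looks
// unfinished and fine tuning resumes past its convergence point.
#include <cstddef>
#include <iostream>
#include <sstream>

class SearchState {
public:
    bool searchNotFinished() const {
        // The fixed predicate: once early stopping has fired, the search
        // stays finished even though unused rounds remain.
        return m_StoppedEarly == false && m_CurrentRound < m_NumberRounds;
    }

    void runOneRound() {
        ++m_CurrentRound;
        // Stand-in for the real convergence test, i.e. the ANOVA total
        // coefficient of variation dropping below 1e-3.
        if (m_CurrentRound == 3) {
            m_StoppedEarly = true;
        }
    }

    // Persist and restore every field that determines searchNotFinished().
    void persist(std::ostream& out) const {
        out << m_CurrentRound << ' ' << m_NumberRounds << ' ' << m_StoppedEarly;
    }
    void restore(std::istream& in) {
        in >> m_CurrentRound >> m_NumberRounds >> m_StoppedEarly;
    }

private:
    std::size_t m_CurrentRound{0};
    std::size_t m_NumberRounds{10};
    bool m_StoppedEarly{false};
};

int main() {
    SearchState state;
    while (state.searchNotFinished()) {
        state.runOneRound(); // converges and stops early at round 3
    }

    // Simulate a failover: snapshot the state, then restore it elsewhere.
    std::stringstream snapshot;
    state.persist(snapshot);
    SearchState restored;
    restored.restore(snapshot);

    // Prints "finished". If m_StoppedEarly were not persisted, the restored
    // object would report "resume" and re-run rounds 4..10 of fine tuning.
    std::cout << (restored.searchNotFinished() ? "resume" : "finished") << '\n';
    return 0;
}
```

Note also that adding a new tag rather than repurposing an existing one keeps the change backward compatible: state persisted before the fix simply lacks `stopped_hyperparameter_optimization_early`, and the member's in-class default of `false` reproduces the old behaviour on restore.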