docs/CHANGELOG.asciidoc (5 changes: 5 additions & 0 deletions)

@@ -35,6 +35,11 @@
 * Upgrade PyTorch to version 1.11. (See {ml-pull}2233[#2233], {ml-pull}2235[#2235]
 and {ml-pull}2238[#2238].)
 
+=== Bug Fixes
+
+* Correct logic for restart from failover fine tuning hyperparameters for training
+classification and regression models. (See {ml-pull}2251[#2251].)
+
 == {es} version 8.2.0
 
 === Enhancements

include/maths/analytics/CBoostedTreeHyperparameters.h (7 changes: 6 additions & 1 deletion)

@@ -440,6 +440,7 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     static const std::string SOFT_TREE_DEPTH_LIMIT_TAG;
     static const std::string SOFT_TREE_DEPTH_TOLERANCE_TAG;
     static const std::string STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG;
+    static const std::string STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG;
     static const std::string TRAIN_FRACTION_PER_FOLD_TAG;
     static const std::string TREE_SIZE_PENALTY_MULTIPLIER_TAG;
 
@@ -580,7 +581,10 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     void startSearch();
 
     //! Check if the search for the best hyperparameter values has finished.
-    bool searchNotFinished() const { return m_CurrentRound < m_NumberRounds; }
+    bool searchNotFinished() const {
+        return m_StoppedHyperparameterOptimizationEarly == false &&
+               m_CurrentRound < m_NumberRounds;
+    }
 
     //! Start a new round of hyperparameter search.
     void startNextSearchRound() { ++m_CurrentRound; }
@@ -701,6 +705,7 @@ class MATHS_ANALYTICS_EXPORT CBoostedTreeHyperparameters {
     //! \name Hyperparameter Optimisation
     //@{
     bool m_StopHyperparameterOptimizationEarly{true};
+    bool m_StoppedHyperparameterOptimizationEarly{false};
    bool m_ScalingDisabled{false};
    std::size_t m_MaximumOptimisationRoundsPerHyperparameter{2};
    TOptionalSize m_BayesianOptimisationRestarts;
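
This header change is the heart of the fix: the hyperparameter search can now be over for two reasons, either the round budget is exhausted or Bayesian optimisation stopped early, and the second reason is recorded in a member that gets persisted (see the .cc diff below). A minimal, self-contained sketch of the loop contract this enforces; CSearchState and the main() driver are hypothetical stand-ins, only the two members and searchNotFinished() mirror the diff:

    #include <cstddef>
    #include <iostream>

    // Hypothetical stand-in for CBoostedTreeHyperparameters: only the two
    // members and searchNotFinished() mirror the diff above.
    struct CSearchState {
        bool m_StoppedHyperparameterOptimizationEarly{false};
        std::size_t m_CurrentRound{0};
        std::size_t m_NumberRounds{10};

        bool searchNotFinished() const {
            // After the fix an early stop is "sticky": once recorded (and
            // persisted) it ends the search even if rounds remain, e.g. after
            // the job is restored on another node following a failover.
            return m_StoppedHyperparameterOptimizationEarly == false &&
                   m_CurrentRound < m_NumberRounds;
        }
        void startNextSearchRound() { ++m_CurrentRound; }
    };

    int main() {
        CSearchState state;
        state.m_CurrentRound = 4;                              // restored mid-search
        state.m_StoppedHyperparameterOptimizationEarly = true; // early stop was recorded
        while (state.searchNotFinished()) { // exits immediately after the fix
            state.startNextSearchRound();
        }
        std::cout << state.m_CurrentRound << '\n'; // prints 4
        return 0;
    }

Before the change, restoring a job whose search had stopped early left searchNotFinished() true (4 < 10 in this sketch), so fine-tuning could resume after a failover.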

lib/maths/analytics/CBoostedTreeHyperparameters.cc (7 changes: 7 additions & 0 deletions)

@@ -389,6 +389,7 @@ bool CBoostedTreeHyperparameters::selectNext(const TMeanVarAccumulator& testLoss
         parameters = minBoundary + parameters.cwiseProduct(maxBoundary - minBoundary);
     } else if (m_StopHyperparameterOptimizationEarly &&
                m_BayesianOptimization->anovaTotalCoefficientOfVariation() < 1e-3) {
+        m_StoppedHyperparameterOptimizationEarly = true;
         return false;
     } else {
         std::tie(parameters, std::ignore) =
@@ -626,6 +627,8 @@ void CBoostedTreeHyperparameters::acceptPersistInserter(core::CStatePersistInser
                                  m_SoftTreeDepthTolerance, inserter);
     core::CPersistUtils::persist(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                                  m_StopHyperparameterOptimizationEarly, inserter);
+    core::CPersistUtils::persist(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                                 m_StoppedHyperparameterOptimizationEarly, inserter);
     core::CPersistUtils::persist(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                                  m_TreeSizePenaltyMultiplier, inserter);
     // m_TunableHyperparameters is not persisted explicitly, it is re-generated
@@ -683,6 +686,9 @@ bool CBoostedTreeHyperparameters::acceptRestoreTraverser(core::CStateRestoreTrav
         RESTORE(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                 core::CPersistUtils::restore(STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
                                              m_StopHyperparameterOptimizationEarly, traverser))
+        RESTORE(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                core::CPersistUtils::restore(STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG,
+                                             m_StoppedHyperparameterOptimizationEarly, traverser))
         RESTORE(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                 core::CPersistUtils::restore(TREE_SIZE_PENALTY_MULTIPLIER_TAG,
                                              m_TreeSizePenaltyMultiplier, traverser))
@@ -858,6 +864,7 @@ const std::string CBoostedTreeHyperparameters::NUMBER_ROUNDS_TAG{"number_rounds"
 const std::string CBoostedTreeHyperparameters::SOFT_TREE_DEPTH_LIMIT_TAG{"soft_tree_depth_limit"};
 const std::string CBoostedTreeHyperparameters::SOFT_TREE_DEPTH_TOLERANCE_TAG{"soft_tree_depth_tolerance"};
 const std::string CBoostedTreeHyperparameters::STOP_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG{"stop_hyperparameter_optimization_early"};
+const std::string CBoostedTreeHyperparameters::STOPPED_HYPERPARAMETER_OPTIMIZATION_EARLY_TAG{"stopped_hyperparameter_optimization_early"};
 const std::string CBoostedTreeHyperparameters::TREE_SIZE_PENALTY_MULTIPLIER_TAG{"tree_size_penalty_multiplier"};
 // clang-format on
 }
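
In selectNext() the new flag is set at the point the early-exit path is taken (ANOVA total coefficient of variation below 1e-3), and the persist/restore pair round-trips it under the new stopped_hyperparameter_optimization_early tag; that round trip is what makes the early stop survive a failover. A sketch of the idea using a plain string map in place of core::CPersistUtils and the inserter/traverser plumbing (persist, restore, and TStateMap here are simplified stand-ins, not the real API):

    #include <iostream>
    #include <map>
    #include <string>

    // Simplified stand-in for the tag/value state document that
    // acceptPersistInserter() writes and acceptRestoreTraverser() reads.
    using TStateMap = std::map<std::string, std::string>;

    const std::string STOPPED_TAG{"stopped_hyperparameter_optimization_early"};

    void persist(const std::string& tag, bool value, TStateMap& state) {
        state[tag] = value ? "true" : "false";
    }

    bool restore(const std::string& tag, bool& value, const TStateMap& state) {
        auto itr = state.find(tag);
        if (itr == state.end()) {
            // State written before this fix lacks the tag; the member keeps
            // its default (false), so old state restores as "not stopped".
            return false;
        }
        value = (itr->second == "true");
        return true;
    }

    int main() {
        // Node A: Bayesian optimisation stops early and the flag is persisted.
        TStateMap state;
        persist(STOPPED_TAG, true, state);

        // Node B (after failover): restoring the flag makes searchNotFinished()
        // false, so the fine-tuning loop is not re-entered.
        bool stoppedEarly{false};
        restore(STOPPED_TAG, stoppedEarly, state);
        std::cout << std::boolalpha << stoppedEarly << '\n'; // true
        return 0;
    }

Because the member defaults to false, state persisted by older versions (without the new tag) still restores cleanly; it just cannot benefit from the fix.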

lib/maths/analytics/CBoostedTreeImpl.cc (14 changes: 7 additions & 7 deletions)

@@ -204,11 +204,10 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
 
     this->checkTrainInvariants(frame);
 
-    if (m_Loss->isRegression()) {
-        m_Instrumentation->type(CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression);
-    } else {
-        m_Instrumentation->type(CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);
-    }
+    m_Instrumentation->type(
+        m_Loss->isRegression()
+            ? CDataFrameTrainBoostedTreeInstrumentationInterface::E_Regression
+            : CDataFrameTrainBoostedTreeInstrumentationInterface::E_Classification);
 
     LOG_TRACE(<< "Main training loop...");
 
@@ -308,11 +307,10 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
         m_Hyperparameters.restoreBest();
         m_Hyperparameters.recordHyperparameters(*m_Instrumentation);
         m_Hyperparameters.captureScale();
+        this->startProgressMonitoringFinalTrain();
         this->scaleRegularizationMultipliers(this->allTrainingRowsMask().manhattan() /
                                              this->meanNumberTrainingRowsPerFold());
 
-        this->startProgressMonitoringFinalTrain();
-
         // Reinitialize random number generator for reproducible results.
         m_Rng.seed(m_Seed);
 
@@ -321,6 +319,8 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
                           .s_Forest;
 
         this->recordState(recordTrainStateCallback);
+    } else {
+        this->skipProgressMonitoringFinalTrain();
     }
     m_Instrumentation->iteration(m_Hyperparameters.currentRound());
     m_Instrumentation->flush(TRAIN_FINAL_FOREST);
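
The CBoostedTreeImpl.cc changes are cosmetic and behavioural respectively: the if/else collapses into an equivalent ternary, while progress monitoring for the final train now starts inside the branch that actually retrains and is explicitly skipped otherwise. A runnable sketch of that control flow; CTrainer, its boolean, and the console output are hypothetical, only the two monitoring calls and the if/else shape come from the diff:

    #include <iostream>

    // Hypothetical skeleton of the reordered control flow in
    // CBoostedTreeImpl::train(): the start/skip monitoring pair is from the
    // diff, everything else is elided or invented for the sketch.
    class CTrainer {
    public:
        void train(bool needFinalTrain) {
            if (needFinalTrain) {
                // In the diff this call moved ahead of
                // scaleRegularizationMultipliers(), so monitoring covers the
                // whole final-train phase.
                this->startProgressMonitoringFinalTrain();
                // ... scale regularisation multipliers, reseed RNG, retrain ...
            } else {
                // New branch: mark the phase as skipped so overall progress
                // is still accounted for when no final training pass runs.
                this->skipProgressMonitoringFinalTrain();
            }
        }

    private:
        void startProgressMonitoringFinalTrain() {
            std::cout << "final train: monitoring started\n";
        }
        void skipProgressMonitoringFinalTrain() {
            std::cout << "final train: skipped\n";
        }
    };

    int main() {
        CTrainer trainer;
        trainer.train(false); // e.g. the best forest is already trained
        return 0;
    }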