Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
* Adjust the syscall filter to allow mremap and avoid spurious audit logging.
(See {ml-pull}1819[#1819].)

=== Bug Fixes

* Ensure the same hyperparameters are chosen if classification or regression training
is stopped and restarted, for example, if the node fails. (See {ml-pull}1848[#1848].)

== {es} version 7.12.1

=== Enhancements
Expand Down
4 changes: 2 additions & 2 deletions include/maths/CBoostedTreeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ class MATHS_EXPORT CBoostedTreeImpl final {
std::size_t m_NumberTopShapValues = 0;
TTreeShapFeatureImportanceUPtr m_TreeShap;
TAnalysisInstrumentationPtr m_Instrumentation;
mutable TMeanAccumulator m_ForestSizeAccumulator;
mutable TMeanAccumulator m_MeanLossAccumulator;
TMeanAccumulator m_MeanForestSizeAccumulator;
TMeanAccumulator m_MeanLossAccumulator;
THyperparametersVec m_TunableHyperparameters;
TDoubleVecVec m_HyperparameterSamples;
bool m_StopHyperparameterOptimizationEarly = true;
Expand Down
27 changes: 18 additions & 9 deletions lib/maths/CBoostedTreeImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,
std::tie(lossMoments, maximumNumberTrees, numberNodes) =
this->crossValidateForest(frame);

m_MeanLossAccumulator.add(CBasicStatistics::mean(lossMoments));

this->captureBestHyperparameters(lossMoments, maximumNumberTrees, numberNodes);

if (this->selectNextHyperparameters(lossMoments, *m_BayesianOptimization) == false) {
Expand Down Expand Up @@ -545,7 +543,7 @@ CBoostedTreeImpl::crossValidateForest(core::CDataFrame& frame) {
TMeanVarAccumulator lossMoments;
TDoubleVec numberTrees;
numberTrees.reserve(m_NumberFolds);
TMeanAccumulator forestSizeAccumulator;
TMeanAccumulator meanForestSizeAccumulator;

while (folds.size() > 0 && stopCrossValidationEarly(lossMoments) == false) {
std::size_t fold{folds.back()};
Expand All @@ -560,7 +558,7 @@ CBoostedTreeImpl::crossValidateForest(core::CDataFrame& frame) {
lossMoments.add(loss);
m_FoldRoundTestLosses[fold][m_CurrentRound] = loss;
numberTrees.push_back(static_cast<double>(forest.size()));
forestSizeAccumulator.add(numberForestNodes(forest));
meanForestSizeAccumulator.add(numberForestNodes(forest));
m_Instrumentation->lossValues(fold, std::move(lossValues));
}
m_TrainingProgress.increment(m_MaximumNumberTrees * folds.size());
Expand All @@ -569,12 +567,15 @@ CBoostedTreeImpl::crossValidateForest(core::CDataFrame& frame) {
std::sort(numberTrees.begin(), numberTrees.end());
std::size_t medianNumberTrees{
static_cast<std::size_t>(CBasicStatistics::median(numberTrees))};
double meanForestSize{CBasicStatistics::mean(forestSizeAccumulator)};
double meanForestSize{CBasicStatistics::mean(meanForestSizeAccumulator)};
lossMoments = this->correctTestLossMoments(std::move(folds), lossMoments);
LOG_TRACE(<< "test mean loss = " << CBasicStatistics::mean(lossMoments)
<< ", sigma = " << std::sqrt(CBasicStatistics::mean(lossMoments))
<< ", mean number nodes in forest = " << meanForestSize);

m_MeanForestSizeAccumulator += meanForestSizeAccumulator;
m_MeanLossAccumulator.add(CBasicStatistics::mean(lossMoments));

return {lossMoments, medianNumberTrees, meanForestSize};
}

Expand Down Expand Up @@ -697,9 +698,6 @@ CBoostedTreeImpl::trainForest(core::CDataFrame& frame,

forest.resize(stoppingCondition.bestSize());

// record forest size as the number of nodes
m_ForestSizeAccumulator.add(numberForestNodes(forest));

LOG_TRACE(<< "Trained one forest");

return {forest, stoppingCondition.bestLoss(), std::move(losses)};
Expand Down Expand Up @@ -1428,7 +1426,7 @@ void CBoostedTreeImpl::captureBestHyperparameters(const TMeanVarAccumulator& los

// Add 0.01 * "forest number nodes" * E[GP] / "average forest number nodes" to meanLoss.
double modelSizeDifferentiator{0.01 * numberNodes /
CBasicStatistics::mean(m_ForestSizeAccumulator) *
CBasicStatistics::mean(m_MeanForestSizeAccumulator) *
CBasicStatistics::mean(m_MeanLossAccumulator)};
double loss{lossAtNSigma(1.0, lossMoments) + modelSizeDifferentiator};
if (loss < m_BestForestTestLoss) {
Expand Down Expand Up @@ -1628,6 +1626,8 @@ const std::string MAXIMUM_NUMBER_TREES_OVERRIDE_TAG{"maximum_number_trees_overri
const std::string MAXIMUM_NUMBER_TREES_TAG{"maximum_number_trees"};
const std::string MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG{
"maximum_optimisation_rounds_per_hyperparameter"};
const std::string MEAN_FOREST_SIZE_ACCUMULATOR_TAG{"mean_forest_size"};
const std::string MEAN_LOSS_ACCUMULATOR_TAG{"mean_loss"};
const std::string MISSING_FEATURE_ROW_MASKS_TAG{"missing_feature_row_masks"};
const std::string NUMBER_FOLDS_TAG{"number_folds"};
const std::string NUMBER_FOLDS_OVERRIDE_TAG{"number_folds_override"};
Expand Down Expand Up @@ -1707,6 +1707,9 @@ void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& insert
core::CPersistUtils::persist(MAXIMUM_NUMBER_TREES_TAG, m_MaximumNumberTrees, inserter);
core::CPersistUtils::persist(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
m_MaximumNumberTreesOverride, inserter);
core::CPersistUtils::persist(MEAN_FOREST_SIZE_ACCUMULATOR_TAG,
m_MeanForestSizeAccumulator, inserter);
core::CPersistUtils::persist(MEAN_LOSS_ACCUMULATOR_TAG, m_MeanLossAccumulator, inserter);
core::CPersistUtils::persist(MISSING_FEATURE_ROW_MASKS_TAG,
m_MissingFeatureRowMasks, inserter);
core::CPersistUtils::persist(NUMBER_FOLDS_TAG, m_NumberFolds, inserter);
Expand Down Expand Up @@ -1823,6 +1826,12 @@ bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& trav
RESTORE(MAXIMUM_NUMBER_TREES_TAG,
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_TAG,
m_MaximumNumberTrees, traverser))
RESTORE(MEAN_FOREST_SIZE_ACCUMULATOR_TAG,
core::CPersistUtils::restore(MEAN_FOREST_SIZE_ACCUMULATOR_TAG,
m_MeanForestSizeAccumulator, traverser))
RESTORE(MEAN_LOSS_ACCUMULATOR_TAG,
core::CPersistUtils::restore(MEAN_LOSS_ACCUMULATOR_TAG,
m_MeanLossAccumulator, traverser))
RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
m_MissingFeatureRowMasks, traverser))
Expand Down