[ML] Early stopping in the line searches to compute initial regulariser values #903

Merged: 12 commits, Dec 17, 2019
4 changes: 2 additions & 2 deletions docs/CHANGELOG.asciidoc
@@ -57,8 +57,8 @@ is no longer decreasing. (See {ml-pull}875[#875].)
* Emit `prediction_field_name` in ml results using the type provided as
`prediction_field_type` parameter. (See {ml-pull}877[#877].)
* Improve performance updating quantile estimates. (See {ml-pull}881[#881].)
-* Migrate to BO for initial hyperparameter value line searches and stop early if the
-expected improvement is too small. (See {ml-pull}903[#903].)
+* Migrate to use Bayesian Optimisation for initial hyperparameter value line searches and
+stop early if the expected improvement is too small. (See {ml-pull}903[#903].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
@@ -107,7 +107,7 @@ struct SFixture {
test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
"regression", "c5", s_Rows, 5, 8000000, 0, 0, {"c1"}, s_Alpha,
s_Lambda, s_Gamma, s_SoftTreeDepthLimit, s_SoftTreeDepthTolerance,
-s_Eta, s_MaximumNumberTrees, s_FeatureBagFraction, s_ShapValues),
+s_Eta, s_MaximumNumberTrees, s_FeatureBagFraction, shapValues),
outputWriterFactory};
TStrVec fieldNames{"c1", "c2", "c3", "c4", "c5", ".", "."};
TStrVec fieldValues{"", "", "", "", "", "0", ""};
22 changes: 14 additions & 8 deletions lib/maths/CBoostedTreeFactory.cc
@@ -707,13 +707,19 @@ CBoostedTreeFactory::testLossLineSearch(core::CDataFrame& frame,
double returnedIntervalLeftEndOffset,
double returnedIntervalRightEndOffset) const {

-// This uses a quadratic approximation to the test loss function w.r.t.
-// the scaled regularization hyperparameter from which it estimates the
-// minimum error point in the interval we search here. Separately, it
-// examines size of the residual errors w.r.t. to the variation in the
-// best fit curve over the interval. We truncate the interval the main
-// hyperparameter optimisation loop searches if we determine there is a
-// low chance of missing the best solution by doing so.
+// This has the following steps:
+// 1. Coarse search the interval [intervalLeftEnd, intervalRightEnd] using
+// fixed steps,
+// 2. Fine tune, via Bayesian Optimisation targeting expected improvement,
+// and stop if the expected improvement is small compared to the current
+// minimum test loss,
+// 3. Calculate the parameter interval which gives the lowest test losses,
+// 4. Fit an OLS quadratic approximation to the test losses in the interval
+// from step 3 and use it to estimate the best parameter value,
+// 5. Compare the size of the residual errors w.r.t. the OLS curve from
+// step 4 with its variation over the interval from step 3 and truncate
+// the returned interval if we can determine there is a low chance of
+// missing the best solution by doing so.
Comment on lines +710 to +722
nice! 👍
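
For readers of this thread, the early stopping test in step 2 of the new comment can be sketched roughly as below. This is not the code from the PR: `expectedImprovement`, `shouldStopEarly` and the `relativeTolerance` parameter are hypothetical names, and the sketch assumes the surrogate model gives a Gaussian prediction (mean and standard deviation) of the test loss at a candidate parameter value. The closed form used is the standard expected improvement for minimisation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical helper (not the ml-cpp API): expected improvement over the
// current minimum test loss at a candidate point whose predicted loss is
// assumed Gaussian with the given mean and standard deviation.
double expectedImprovement(double currentMinimum, double mean, double sd) {
    assert(sd >= 0.0);
    if (sd == 0.0) {
        // No uncertainty, so the improvement is known exactly.
        return std::max(currentMinimum - mean, 0.0);
    }
    const double pi{3.14159265358979323846};
    double z{(currentMinimum - mean) / sd};
    double pdf{std::exp(-0.5 * z * z) / std::sqrt(2.0 * pi)};
    double cdf{0.5 * std::erfc(-z / std::sqrt(2.0))};
    return (currentMinimum - mean) * cdf + sd * pdf;
}

// Hypothetical stopping rule: stop the line search when the expected
// improvement is small relative to the current minimum test loss. The
// tolerance is an assumed parameter, not a value taken from the PR.
bool shouldStopEarly(double ei, double currentMinimum, double relativeTolerance) {
    return ei < relativeTolerance * std::fabs(currentMinimum);
}
```

A caller would evaluate `expectedImprovement` at the point proposed by the optimiser and break out of the fine tuning loop once `shouldStopEarly` returns true. Measuring the improvement relative to the current minimum keeps the rule independent of the scale of the loss.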


using TMeanVarAccumulator = CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
using TMinAccumulator = CBasicStatistics::SMin<double>::TAccumulator;
@@ -776,7 +782,7 @@ CBoostedTreeFactory::testLossLineSearch(core::CDataFrame& frame,
return TOptionalVector{};
}

-// Find the smallest test losses and the corresponding regularizer span.
+// Find the smallest test losses and the corresponding regularizer interval.
auto minimumTestLosses = CBasicStatistics::orderStatisticsAccumulator<TDoubleDoublePr>(
minNumberTestLosses - 1, COrderings::SSecondLess{});
minimumTestLosses.add(testLosses);
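
As a rough, standalone illustration of steps 3 and 4 from the comment on `testLossLineSearch`, the sketch below picks the k smallest test losses, takes the parameter interval they span, and fits an OLS quadratic to estimate the best parameter value. It deliberately uses `std::partial_sort` and Cramer's rule in place of `CBasicStatistics::orderStatisticsAccumulator` and whatever solver the library uses. All function names here are hypothetical, and `TDoubleDoublePr` is assumed to be a (parameter, test loss) pair.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Assumed layout: first = regulariser value, second = test loss.
using TDoubleDoublePr = std::pair<double, double>;
using TDoubleDoublePrVec = std::vector<TDoubleDoublePr>;
using TMatrix = std::array<std::array<double, 3>, 3>;

// Hypothetical stand-in for the order statistics accumulator: the k
// (parameter, loss) pairs with the smallest losses.
TDoubleDoublePrVec smallestTestLosses(TDoubleDoublePrVec losses, std::size_t k) {
    k = std::min(k, losses.size());
    std::partial_sort(losses.begin(), losses.begin() + k, losses.end(),
                      [](const TDoubleDoublePr& lhs, const TDoubleDoublePr& rhs) {
                          return lhs.second < rhs.second;
                      });
    losses.resize(k);
    return losses;
}

// The parameter interval spanned by the selected points (step 3).
TDoubleDoublePr parameterInterval(const TDoubleDoublePrVec& points) {
    assert(!points.empty());
    auto range = std::minmax_element(
        points.begin(), points.end(),
        [](const TDoubleDoublePr& lhs, const TDoubleDoublePr& rhs) {
            return lhs.first < rhs.first;
        });
    return {range.first->first, range.second->first};
}

double det3(const TMatrix& m) {
    return m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) -
           m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) +
           m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
}

// OLS fit of loss = c[0] + c[1] * x + c[2] * x^2 via the normal equations,
// solved with Cramer's rule (step 4). Needs three distinct x values.
std::array<double, 3> fitQuadratic(const TDoubleDoublePrVec& points) {
    assert(points.size() >= 3);
    double s[5] = {0.0, 0.0, 0.0, 0.0, 0.0}; // sums of x^i
    double t[3] = {0.0, 0.0, 0.0};           // sums of x^i * loss
    for (const auto& point : points) {
        double power{1.0};
        for (int i = 0; i < 5; ++i) {
            s[i] += power;
            if (i < 3) {
                t[i] += power * point.second;
            }
            power *= point.first;
        }
    }
    TMatrix a{};
    for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 3; ++j) {
            a[i][j] = s[i + j];
        }
    }
    double determinant{det3(a)};
    assert(determinant != 0.0);
    std::array<double, 3> c{};
    for (int j = 0; j < 3; ++j) {
        TMatrix aj{a};
        for (int i = 0; i < 3; ++i) {
            aj[i][j] = t[i];
        }
        c[j] = det3(aj) / determinant;
    }
    return c;
}

// Minimiser of the fitted curve restricted to [a, b].
double estimateMinimiser(const std::array<double, 3>& c, double a, double b) {
    auto f = [&c](double x) { return c[0] + c[1] * x + c[2] * x * x; };
    if (c[2] > 0.0) {
        return std::min(std::max(-c[1] / (2.0 * c[2]), a), b);
    }
    // Not convex: the minimum on the interval is at an end point.
    return f(a) <= f(b) ? a : b;
}
```

For example, with losses sampled from (x - 2)^2 + 1 at x = 0, 1, ..., 5 and k = 3, the selected interval is [1, 3] and the estimated minimiser is 2. Clamping to the interval keeps the estimate inside the region the losses actually support.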