[ML] Stop cross-validation early if the parameters have high predicted test loss #915

Merged: 13 commits, Jan 10, 2020
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
@@ -64,6 +64,8 @@ is no longer decreasing. (See {ml-pull}875[#875].)
* Improve performance updating quantile estimates. (See {ml-pull}881[#881].)
* Migrate to use Bayesian Optimisation for initial hyperparameter value line searches and
stop early if the expected improvement is too small. (See {ml-pull}903[#903].)
* Stop cross-validation early if the predicted test loss has a small chance of being
smaller than for the best parameter values found so far. (See {ml-pull}915[#915].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
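The behaviour this changelog entry describes: during the hyperparameter search, cross-validation trains fold by fold, and once a few folds have been evaluated the implementation can predict the final mean test loss and skip the remaining folds when the candidate is unlikely to beat the best parameters found so far. A minimal sketch of such a decision rule, assuming a normal approximation for the predicted loss; the function name, the 0.01 threshold, and the use of std::erfc are illustrative, not the library's actual code:

#include <cmath>

// Sketch only: decide whether to abandon the remaining folds for the
// current hyperparameter candidate. predictedLossMean/Variance would be
// derived from the folds evaluated so far; bestLossSoFar is the best mean
// test loss over all completed rounds.
bool stopCrossValidationEarly(double predictedLossMean,
                              double predictedLossVariance,
                              double bestLossSoFar,
                              double significance = 0.01) {
    if (predictedLossVariance <= 0.0) {
        return predictedLossMean >= bestLossSoFar;
    }
    // P(final loss < best so far) under a normal approximation:
    // Phi(z) = 0.5 * erfc(-z / sqrt(2)).
    double z = (bestLossSoFar - predictedLossMean) / std::sqrt(predictedLossVariance);
    double chanceOfImprovement = 0.5 * std::erfc(-z / std::sqrt(2.0));
    return chanceOfImprovement < significance;
}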
1 change: 1 addition & 0 deletions include/api/CDataFrameTrainBoostedTreeRunner.h
@@ -44,6 +44,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeRunner : public CDataFrameAnalysisRunner
static const std::string MAXIMUM_NUMBER_TREES;
static const std::string FEATURE_BAG_FRACTION;
static const std::string NUMBER_FOLDS;
static const std::string STOP_CROSS_VALIDATION_EARLY;
static const std::string NUMBER_ROUNDS_PER_HYPERPARAMETER;
static const std::string BAYESIAN_OPTIMISATION_RESTARTS;
static const std::string TOP_FEATURE_IMPORTANCE_VALUES;
9 changes: 7 additions & 2 deletions include/maths/CBoostedTreeFactory.h
@@ -58,8 +58,10 @@ class MATHS_EXPORT CBoostedTreeFactory final {
CBoostedTreeFactory& minimumFrequencyToOneHotEncode(double frequency);
//! Set the number of folds to use for estimating the generalisation error.
CBoostedTreeFactory& numberFolds(std::size_t numberFolds);
//! Stratify the cross validation we do for regression.
//! Stratify the cross-validation we do for regression.
CBoostedTreeFactory& stratifyRegressionCrossValidation(bool stratify);
//! Stop cross-validation early if the test loss is not promising.
CBoostedTreeFactory& stopCrossValidationEarly(bool stopEarly);
//! The number of rows per feature to sample in the initial downsample.
CBoostedTreeFactory& initialDownsampleRowsPerFeature(double rowsPerFeature);
//! Set the sum of leaf depth penalties multiplier.
@@ -133,6 +135,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
//! Compute the row masks for the missing values for each feature.
void initializeMissingFeatureMasks(const core::CDataFrame& frame) const;

//! Set up the number of folds we'll use for cross-validation.
void initializeNumberFolds(core::CDataFrame& frame) const;

//! Set up cross validation.
void initializeCrossValidation(core::CDataFrame& frame) const;

@@ -187,7 +192,7 @@
void resumeRestoredTrainingProgressMonitoring();

//! The maximum number of trees to use in the hyperparameter optimisation loop.
std::size_t mainLoopMaximumNumberTrees() const;
std::size_t mainLoopMaximumNumberTrees(double eta) const;

static void noopRecordProgress(double);
static void noopRecordMemoryUsage(std::int64_t);
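The other factory change worth noting is that mainLoopMaximumNumberTrees now takes eta: a smaller learn rate shrinks each tree's contribution, so the tree budget has to grow for the forest to reach the same total shrinkage. A sketch of the kind of scaling this implies; the constant and exact form are assumptions, not the library's formula:

#include <cmath>
#include <cstddef>

// Illustrative only: total shrinkage is roughly numberTrees * eta, so the
// budget scales inversely with the learn rate.
std::size_t mainLoopMaximumNumberTrees(double eta) {
    return static_cast<std::size_t>(std::ceil(10.0 / eta));
}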
27 changes: 23 additions & 4 deletions include/maths/CBoostedTreeImpl.h
@@ -53,9 +53,11 @@ class MATHS_EXPORT CBoostedTreeImpl final {
using TMeanAccumulator = CBasicStatistics::SSampleMean<double>::TAccumulator;
using TMeanVarAccumulator = CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
using TMeanVarAccumulatorSizePr = std::pair<TMeanVarAccumulator, std::size_t>;
using TMeanVarAccumulatorVec = std::vector<TMeanVarAccumulator>;
using TBayesinOptimizationUPtr = std::unique_ptr<maths::CBayesianOptimisation>;
using TNodeVec = CBoostedTree::TNodeVec;
using TNodeVecVec = CBoostedTree::TNodeVecVec;
using TLossFunctionUPtr = CBoostedTree::TLossFunctionUPtr;
using TProgressCallback = CBoostedTree::TProgressCallback;
using TMemoryUsageCallback = CBoostedTree::TMemoryUsageCallback;
using TTrainingStateCallback = CBoostedTree::TTrainingStateCallback;
@@ -68,7 +70,7 @@ class MATHS_EXPORT CBoostedTreeImpl final {
static const double MINIMUM_RELATIVE_GAIN_PER_SPLIT;

public:
CBoostedTreeImpl(std::size_t numberThreads, CBoostedTree::TLossFunctionUPtr loss);
CBoostedTreeImpl(std::size_t numberThreads, TLossFunctionUPtr loss);

~CBoostedTreeImpl();

@@ -152,6 +154,8 @@
private:
using TSizeDoublePr = std::pair<std::size_t, double>;
using TDoubleDoublePr = std::pair<double, double>;
using TOptionalDoubleVec = std::vector<TOptionalDouble>;
using TOptionalDoubleVecVec = std::vector<TOptionalDoubleVec>;
using TOptionalSize = boost::optional<std::size_t>;
using TImmutableRadixSetVec = std::vector<core::CImmutableRadixSet<double>>;
using TVector = CDenseVector<double>;
@@ -416,6 +420,9 @@
TDoubleDoublePr gainAndCurvatureAtPercentile(double percentile,
const TNodeVecVec& forest) const;

//! Presize the collection to hold the per fold test errors.
void initializePerFoldTestLosses();

//! Train the forest and compute loss moments on each fold.
TMeanVarAccumulatorSizePr crossValidateForest(core::CDataFrame& frame,
const TMemoryUsageCallback& recordMemoryUsage);
@@ -447,6 +454,16 @@
const std::size_t maximumTreeSize,
const TMemoryUsageCallback& recordMemoryUsage) const;

//! Compute the minimum mean test loss per fold for any round.
double minimumTestLoss() const;

//! Estimate the loss we'll get including the missing folds.
TMeanVarAccumulator correctTestLossMoments(const TSizeVec& missing,
TMeanVarAccumulator lossMoments) const;

//! Estimate test losses for the \p missing folds.
TMeanVarAccumulatorVec estimateMissingTestLosses(const TSizeVec& missing) const;

//! Get the number of features including category encoding.
std::size_t numberFeatures() const;

@@ -503,8 +520,7 @@
std::size_t maximumTreeSize(std::size_t numberRows) const;

//! Restore \p loss function pointer from the \p traverser.
static bool restoreLoss(CBoostedTree::TLossFunctionUPtr& loss,
core::CStateRestoreTraverser& traverser);
static bool restoreLoss(TLossFunctionUPtr& loss, core::CStateRestoreTraverser& traverser);

//! Record the training state using the \p recordTrainState callback function
void recordState(const TTrainingStateCallback& recordTrainState) const;
@@ -513,10 +529,12 @@
mutable CPRNG::CXorOShiro128Plus m_Rng;
std::size_t m_NumberThreads;
std::size_t m_DependentVariable = std::numeric_limits<std::size_t>::max();
CBoostedTree::TLossFunctionUPtr m_Loss;
TLossFunctionUPtr m_Loss;
bool m_StopCrossValidationEarly = true;
TRegularizationOverride m_RegularizationOverride;
TOptionalDouble m_DownsampleFactorOverride;
TOptionalDouble m_EtaOverride;
TOptionalSize m_NumberFoldsOverride;
TOptionalSize m_MaximumNumberTreesOverride;
TOptionalDouble m_FeatureBagFractionOverride;
TRegularization m_Regularization;
@@ -537,6 +555,7 @@
TPackedBitVectorVec m_TrainingRowMasks;
TPackedBitVectorVec m_TestingRowMasks;
double m_BestForestTestLoss = INF;
TOptionalDoubleVecVec m_FoldRoundTestLosses;
CBoostedTreeHyperparameters m_BestHyperparameters;
TNodeVecVec m_BestForest;
TBayesinOptimizationUPtr m_BayesianOptimization;
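The three new private members above work together: minimumTestLoss supplies a baseline, estimateMissingTestLosses produces loss moments for folds skipped in earlier rounds, and correctTestLossMoments merges those estimates with the measured moments so rounds that evaluated different numbers of folds stay comparable. A self-contained sketch of the moment combination, using plain doubles in place of TMeanVarAccumulator; the equal per-fold weighting is an assumption:

#include <cstddef>
#include <vector>

// Minimal stand-in for TMeanVarAccumulator.
struct SMeanVar {
    double s_Mean = 0.0;
    double s_Variance = 0.0;
};

// Sketch: merge moments measured on numberMeasured folds with per-fold
// estimates for the missing folds, treating each fold as an equally
// weighted group and using the identity
//   var = sum_i w_i * (var_i + mean_i^2) - mean^2.
SMeanVar correctTestLossMoments(const SMeanVar& measured,
                                std::size_t numberMeasured,
                                const std::vector<SMeanVar>& estimatedMissing) {
    double total = static_cast<double>(numberMeasured + estimatedMissing.size());
    double mean = measured.s_Mean * numberMeasured / total;
    double secondMoment = (measured.s_Variance + measured.s_Mean * measured.s_Mean) *
                          numberMeasured / total;
    for (const auto& estimate : estimatedMissing) {
        mean += estimate.s_Mean / total;
        secondMoment += (estimate.s_Variance + estimate.s_Mean * estimate.s_Mean) / total;
    }
    return {mean, secondMoment - mean * mean};
}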
14 changes: 6 additions & 8 deletions include/maths/CTreeShapFeatureImportance.h
@@ -63,8 +63,7 @@ class MATHS_EXPORT CTreeShapFeatureImportance {
struct SPath {
explicit SPath(std::size_t length)
: s_FractionOnes(length), s_FractionZeros(length),
s_FeatureIndex(length, -1), s_Scale(length), s_NextIndex(0),
s_MaxLength(length) {}
s_FeatureIndex(length, -1), s_Scale(length), s_MaxLength(length) {}

void extend(int featureIndex, double fractionZero, double fractionOne) {
if (s_NextIndex < s_MaxLength) {
@@ -81,7 +80,7 @@
}

void reduce(std::size_t pathIndex) {
for (std::size_t i = pathIndex; i < this->depth(); ++i) {
for (int i = static_cast<int>(pathIndex); i < this->depth(); ++i) {
s_FeatureIndex[i] = s_FeatureIndex[i + 1];
s_FractionZeros[i] = s_FractionZeros[i + 1];
s_FractionOnes[i] = s_FractionOnes[i + 1];
@@ -107,10 +106,10 @@
double scale(std::size_t pathIndex) const { return s_Scale[pathIndex]; }

//! Current depth in the tree
int depth() const { return static_cast<int>(s_NextIndex) - 1; };
int depth() const { return static_cast<int>(s_NextIndex) - 1; }

//! Get next index.
size_t nextIndex() const { return s_NextIndex; }
std::size_t nextIndex() const { return s_NextIndex; }

//! Set next index.
void nextIndex(std::size_t nextIndex) { s_NextIndex = nextIndex; }
@@ -119,9 +118,8 @@
TDoubleVec s_FractionZeros;
TIntVec s_FeatureIndex;
TDoubleVec s_Scale;
std::size_t s_NextIndex;

std::size_t s_MaxLength;
std::size_t s_NextIndex = 0;
std::size_t s_MaxLength = 0;
};

private:
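One subtle fix in this file: SPath::reduce indexed with std::size_t while depth() returns a signed int that is -1 for an empty path. In the comparison i < this->depth() the usual arithmetic conversions promote the int to unsigned, so a negative depth wraps to a huge bound instead of stopping the loop. A stand-alone demonstration of the pitfall, not library code:

#include <cstddef>
#include <iostream>

int main() {
    int depth = -1;    // what depth() returns when s_NextIndex == 0
    std::size_t u = 0; // the old unsigned loop index
    int s = 0;         // the new signed loop index
    // The unsigned comparison wraps -1 to SIZE_MAX, so the old loop
    // condition held and the body could run out of bounds.
    std::cout << (u < static_cast<std::size_t>(depth)) << '\n'; // prints 1
    // A signed index compares as intended.
    std::cout << (s < depth) << '\n'; // prints 0
    return 0;
}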
6 changes: 5 additions & 1 deletion lib/api/CDataFrameTrainBoostedTreeRunner.cc
@@ -51,6 +51,8 @@ const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeRunner::paramete
theReader.addParameter(FEATURE_BAG_FRACTION,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(NUMBER_FOLDS, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(STOP_CROSS_VALIDATION_EARLY,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(NUMBER_ROUNDS_PER_HYPERPARAMETER,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(BAYESIAN_OPTIMISATION_RESTARTS,
@@ -82,6 +84,7 @@ CDataFrameTrainBoostedTreeRunner::CDataFrameTrainBoostedTreeRunner(
parameters[NUMBER_ROUNDS_PER_HYPERPARAMETER].fallback(std::size_t{0})};
std::size_t bayesianOptimisationRestarts{
parameters[BAYESIAN_OPTIMISATION_RESTARTS].fallback(std::size_t{0})};
bool stopCrossValidationEarly{parameters[STOP_CROSS_VALIDATION_EARLY].fallback(true)};
std::size_t topFeatureImportanceValues{
parameters[TOP_FEATURE_IMPORTANCE_VALUES].fallback(std::size_t{0})};

@@ -120,6 +123,7 @@
maths::CBoostedTreeFactory::constructFromParameters(this->spec().numberThreads()));

(*m_BoostedTreeFactory)
.stopCrossValidationEarly(stopCrossValidationEarly)
.progressCallback(this->progressRecorder())
.trainingStateCallback(this->statePersister())
.memoryUsageCallback(this->memoryMonitor(counter_t::E_DFTPMPeakMemoryUsage));
@@ -309,10 +313,10 @@ const std::string CDataFrameTrainBoostedTreeRunner::SOFT_TREE_DEPTH_TOLERANCE{"s
const std::string CDataFrameTrainBoostedTreeRunner::MAXIMUM_NUMBER_TREES{"maximum_number_trees"};
const std::string CDataFrameTrainBoostedTreeRunner::FEATURE_BAG_FRACTION{"feature_bag_fraction"};
const std::string CDataFrameTrainBoostedTreeRunner::NUMBER_FOLDS{"number_folds"};
const std::string CDataFrameTrainBoostedTreeRunner::STOP_CROSS_VALIDATION_EARLY{"stop_cross_validation_early"};
const std::string CDataFrameTrainBoostedTreeRunner::NUMBER_ROUNDS_PER_HYPERPARAMETER{"number_rounds_per_hyperparameter"};
const std::string CDataFrameTrainBoostedTreeRunner::BAYESIAN_OPTIMISATION_RESTARTS{"bayesian_optimisation_restarts"};
const std::string CDataFrameTrainBoostedTreeRunner::TOP_FEATURE_IMPORTANCE_VALUES{"top_feature_importance_values"};

// clang-format on
}
}
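For callers, the new parameter is optional and defaults to true, so existing configurations keep the early-stopping behaviour unless they opt out. A hypothetical analysis request that disables it; the field nesting here is assumed for illustration, only the parameter names come from this diff:

{
  "analysis": {
    "name": "regression",
    "parameters": {
      "number_folds": 5,
      "stop_cross_validation_early": false
    }
  }
}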
37 changes: 22 additions & 15 deletions lib/api/unittest/CDataFrameAnalyzerFeatureImportanceTest.cc
@@ -25,14 +25,13 @@ BOOST_AUTO_TEST_SUITE(CDataFrameAnalyzerFeatureImportanceTest)
using namespace ml;

namespace {
using TBoolVec = std::vector<bool>;
using TSizeVec = std::vector<std::size_t>;
using TRowItr = core::CDataFrame::TRowItr;
using TRowRef = core::CDataFrame::TRowRef;
using TDataFrameUPtr = std::unique_ptr<core::CDataFrame>;
using TDoubleVec = std::vector<double>;
using TStrVec = std::vector<std::string>;
using TMeanVarAccumulator = ml::maths::CBasicStatistics::SSampleMeanVar<double>::TAccumulator;
using TRowItr = core::CDataFrame::TRowItr;
using TRowRef = core::CDataFrame::TRowRef;
using TMeanAccumulator = maths::CBasicStatistics::SSampleMean<double>::TAccumulator;
using TMeanAccumulatorVec = std::vector<TMeanAccumulator>;
using TMeanVarAccumulator = maths::CBasicStatistics::SSampleMeanVar<double>::TAccumulator;

void setupLinearRegressionData(const TStrVec& fieldNames,
TStrVec& fieldValues,
@@ -228,18 +227,19 @@ BOOST_FIXTURE_TEST_CASE(testRunBoostedTreeRegressionFeatureImportanceAllShap, SFixture) {
// randomly on [-10, 10].
BOOST_TEST_REQUIRE(c1Sum > c3Sum);
BOOST_TEST_REQUIRE(c1Sum > c4Sum);
BOOST_REQUIRE_CLOSE(weights[1] / weights[2], c2Sum / c3Sum, 5.0); // ratio within 5% of ratio of coefficients
BOOST_REQUIRE_CLOSE(weights[1] / weights[2], c2Sum / c3Sum, 10.0); // ratio within 10% of ratio of coefficients
BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 5.0); // c3 and c4 within 5% of each other
// make sure the local approximation differs from the prediction always by the same bias (up to a numeric error)
BOOST_REQUIRE_SMALL(ml::maths::CBasicStatistics::variance(bias), 1e-6);
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
}

BOOST_FIXTURE_TEST_CASE(testRunBoostedTreeRegressionFeatureImportanceNoImportance, SFixture) {
// Test that feature importance calculates low SHAP values if regressors have no weight.
// We also add high noise variance.
std::size_t topShapValues{4};
auto results{runRegression(topShapValues, {10.0, 0.0, 0.0, 0.0}, 10.0)};
auto results = runRegression(topShapValues, {10.0, 0.0, 0.0, 0.0}, 10.0);

TMeanAccumulator c2Mean, c3Mean, c4Mean;
for (const auto& result : results.GetArray()) {
if (result.HasMember("row_results")) {
double c1{result["row_results"]["results"]["ml"][maths::CDataFrameRegressionModel::SHAP_PREFIX + "c1"]
@@ -252,13 +252,20 @@
.GetDouble()};
double prediction{
result["row_results"]["results"]["ml"]["target_prediction"].GetDouble()};
// c1 explain 97% of the prediction value, i.e. the difference from the prediction is less than 1%.
BOOST_REQUIRE_CLOSE(c1, prediction, 3.0);
BOOST_REQUIRE_SMALL(c2, 0.25);
BOOST_REQUIRE_SMALL(c3, 0.25);
BOOST_REQUIRE_SMALL(c4, 0.25);
// c1 explains 95% of the prediction value.
BOOST_REQUIRE_CLOSE(c1, prediction, 5.0);
BOOST_REQUIRE_SMALL(c2, 2.0);
BOOST_REQUIRE_SMALL(c3, 2.0);
BOOST_REQUIRE_SMALL(c4, 2.0);
c2Mean.add(std::fabs(c2));
c3Mean.add(std::fabs(c3));
c4Mean.add(std::fabs(c4));
}
}

BOOST_REQUIRE_SMALL(maths::CBasicStatistics::mean(c2Mean), 0.1);
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::mean(c3Mean), 0.1);
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::mean(c4Mean), 0.1);
}

BOOST_FIXTURE_TEST_CASE(testRunBoostedTreeClassificationFeatureImportanceAllShap, SFixture) {
@@ -314,7 +321,7 @@ BOOST_FIXTURE_TEST_CASE(testRunBoostedTreeClassificationFeatureImportanceAllShap, SFixture) {
BOOST_TEST_REQUIRE(c1Sum > c4Sum);
BOOST_REQUIRE_CLOSE(c3Sum, c4Sum, 40.0); // c3 and c4 within 40% of each other
// make sure the local approximation differs from the prediction always by the same bias (up to a numeric error)
BOOST_REQUIRE_SMALL(ml::maths::CBasicStatistics::variance(bias), 1e-6);
BOOST_REQUIRE_SMALL(maths::CBasicStatistics::variance(bias), 1e-6);
}

BOOST_FIXTURE_TEST_CASE(testRunBoostedTreeRegressionFeatureImportanceNoShap, SFixture) {
Expand Down