Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback Functionality #1697

Merged
merged 15 commits into from
May 18, 2021
2 changes: 2 additions & 0 deletions catboost/R-package/src/catboostr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ EXPORT_FUNCTION CatBoostFit_R(SEXP learnPoolParam, SEXP testPoolParam, SEXP fitP
nullptr,
Nothing(),
Nothing(),
Nothing(),
pools,
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand All @@ -524,6 +525,7 @@ EXPORT_FUNCTION CatBoostFit_R(SEXP learnPoolParam, SEXP testPoolParam, SEXP fitP
nullptr,
Nothing(),
Nothing(),
Nothing(),
pools,
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
1 change: 1 addition & 0 deletions catboost/cuda/train_lib/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ namespace NCatboostCuda {
TMaybe<NCB::TPrecomputedOnlineCtrData> precomputedSingleOnlineCtrDataForSingleFold,
const TLabelConverter& labelConverter,
ITrainingCallbacks* trainingCallbacks,
ICustomCallbacks* /*customCallbacks*/,
TMaybe<TFullModel*> initModel,
THolder<TLearnProgress> initLearnProgress,
NCB::TDataProviders initModelApplyCompatiblePools,
Expand Down
1 change: 1 addition & 0 deletions catboost/cuda/train_lib/ut/train_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Y_UNIT_TEST_SUITE(TrainModelTests) {
nullptr,
{},
{},
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/calc_metrics/ut/calc_metrics_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void BuildApproxAsPoolFileAndGetMetric(
nullptr,
Nothing(),
Nothing(),
Nothing(),
TDataProviders{pool, {pool}},
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ namespace NCB {
TVector<TEvalResult> evalResults(trainingData.Test.ysize());
THolder<IModelTrainer> modelTrainerHolder(TTrainerFactory::Construct(catBoostOptions.GetTaskType()));
TRestorableFastRng64 rnd(catBoostOptions.RandomSeed);
const auto defaultCustomCallbacks = MakeHolder<TCustomCallbacks>(Nothing());
modelTrainerHolder->TrainModel(
TTrainModelInternalOptions(),
catBoostOptions,
Expand All @@ -321,6 +322,7 @@ namespace NCB {
/*precomputedSingleOnlineCtrDataForSingleFold*/ Nothing(),
labelConverter,
callbacks,
defaultCustomCallbacks.Get(),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
TDataProviders(),
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/model/ut/leaf_weights_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ static TFullModel TrainModelOnPool(TDataProviderPtr pool, ETargetDimMode multicl
nullptr,
Nothing(),
Nothing(),
Nothing(),
TDataProviders{pool, {pool}},
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
2 changes: 2 additions & 0 deletions catboost/libs/model/ut/lib/model_test_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ TFullModel TrainFloatCatboostModel(int iterations, int seed) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down Expand Up @@ -417,6 +418,7 @@ TFullModel DefaultTrainCatOnlyModel(const NJson::TJsonValue& params) {
nullptr,
{},
{},
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/model/ut/model_serialization_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Y_UNIT_TEST_SUITE(TModelSerialization) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
dataProviders,
Nothing(),
&learnProgress,
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/model/ut/model_summ_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ static void AssertModelSumEqualSliced(TDataProviderPtr dataProvider, bool change
nullptr,
Nothing(),
Nothing(),
Nothing(),
dataProviders,
Nothing(),
&learnProgress,
Expand Down
2 changes: 2 additions & 0 deletions catboost/libs/model/ut/shrink_model_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Y_UNIT_TEST_SUITE(TShrinkModel) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
TDataProviders{pool, {pool}},
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand All @@ -36,6 +37,7 @@ Y_UNIT_TEST_SUITE(TShrinkModel) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
TDataProviders{pool, {pool}},
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/train_interface/catboost_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ CATBOOST_API bool TrainCatBoost(const TDataSet* trainPtr,
quantizedFeaturesInfo,
objectiveDescriptor,
evalMetricDescriptor,
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
4 changes: 4 additions & 0 deletions catboost/libs/train_lib/cross_validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ void TrainBatch(
upToIteration,
foldContext);

const auto defaultCustomCallbacks = MakeHolder<TCustomCallbacks>(Nothing());
auto foldOutputOptions = foldContext->OutputOptions;
auto trainDir = foldOutputOptions.GetTrainDir();
if (foldContext->FullModel.Defined()) {
Expand All @@ -337,6 +338,7 @@ void TrainBatch(
/*precomputedSingleOnlineCtrDataForSingleFold*/ Nothing(),
labelConverter,
cvCallbacks.Get(),
defaultCustomCallbacks.Get(),
/*initModel*/ Nothing(),
std::move(foldContext->LearnProgress),
/*initModelApplyCompatiblePools*/ TDataProviders(),
Expand Down Expand Up @@ -406,6 +408,7 @@ void Train(
foldOutputOptions.ResultModelPath = NCatboostOptions::TOption<TString>("result_model_file", "model");
}
TMetricsAndTimeLeftHistory metricsAndTimeHistory;
const auto defaultCustomCallbacks = MakeHolder<TCustomCallbacks>(Nothing());
modelTrainer->TrainModel(
internalOptions,
catboostOption,
Expand All @@ -416,6 +419,7 @@ void Train(
/*precomputedSingleOnlineCtrDataForSingleFold*/ Nothing(),
labelConverter,
trainingCallbacks,
defaultCustomCallbacks.Get(),
/*initModel*/ Nothing(),
THolder<TLearnProgress>(),
/*initModelApplyCompatiblePools*/ TDataProviders(),
Expand Down
22 changes: 20 additions & 2 deletions catboost/libs/train_lib/train_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ static void Train(
const TTrainModelInternalOptions& internalOptions,
const TTrainingDataProviders& data,
ITrainingCallbacks* trainingCallbacks,
ICustomCallbacks* customCallbacks,
TLearnContext* ctx,
TVector<TVector<TVector<double>>>* testMultiApprox // [test][dim][docIdx]
) {
Expand Down Expand Up @@ -455,7 +456,8 @@ static void Train(
break;
}

continueTraining = trainingCallbacks->IsContinueTraining(ctx->LearnProgress->MetricsAndTimeHistory);
continueTraining = trainingCallbacks->IsContinueTraining(ctx->LearnProgress->MetricsAndTimeHistory)
&& customCallbacks->AfterIteration(ctx->LearnProgress->MetricsAndTimeHistory);
}

ctx->SaveProgress(onSaveSnapshotCallback);
Expand Down Expand Up @@ -710,6 +712,15 @@ static void SaveModel(
}
}

TCustomCallbacks::TCustomCallbacks(TMaybe<TCustomCallbackDescriptor> callbackDescriptor)
: CallbackDescriptor(std::move(callbackDescriptor)) {
}
bool TCustomCallbacks::AfterIteration(const TMetricsAndTimeLeftHistory &history) {
if (!CallbackDescriptor.Empty()) {
return CallbackDescriptor->AfterIterationFunc(history, CallbackDescriptor->CustomData);
}
return true;
}

namespace {
class TCPUModelTrainer : public IModelTrainer {
Expand All @@ -724,6 +735,7 @@ namespace {
TMaybe<TPrecomputedOnlineCtrData> precomputedSingleOnlineCtrDataForSingleFold,
const TLabelConverter& labelConverter,
ITrainingCallbacks* trainingCallbacks,
ICustomCallbacks* customCallbacks,
TMaybe<TFullModel*> initModel,
THolder<TLearnProgress> initLearnProgress,
TDataProviders initModelApplyCompatiblePools,
Expand Down Expand Up @@ -837,7 +849,7 @@ namespace {
TVector<TVector<double>> oneRawValues(ctx.LearnProgress->ApproxDimension);
TVector<TVector<TVector<double>>> rawValues(trainingData.Test.size(), oneRawValues);

Train(internalOptions, trainingData, trainingCallbacks, &ctx, &rawValues);
Train(internalOptions, trainingData, trainingCallbacks, customCallbacks, &ctx, &rawValues);

if (!dstLearnProgress) {
// Save memory as it is no longer needed
Expand Down Expand Up @@ -886,6 +898,7 @@ static void TrainModel(
TQuantizedFeaturesInfoPtr quantizedFeaturesInfo,
const TMaybe<TCustomObjectiveDescriptor>& objectiveDescriptor,
const TMaybe<TCustomMetricDescriptor>& evalMetricDescriptor,
const TMaybe<TCustomCallbackDescriptor>& callbackDescriptor,
TDataProviders pools,

// can be non-empty only if there is single fold
Expand Down Expand Up @@ -1051,6 +1064,7 @@ static void TrainModel(
}

const auto defaultTrainingCallbacks = MakeHolder<ITrainingCallbacks>();
const auto customCallbacks = MakeHolder<TCustomCallbacks>(callbackDescriptor);
TTrainModelInternalOptions trainModelInternalOptions;
trainModelInternalOptions.HaveLearnFeatureInMemory = haveLearnFeaturesInMemory;
modelTrainerHolder->TrainModel(
Expand All @@ -1063,6 +1077,7 @@ static void TrainModel(
std::move(precomputedSingleOnlineCtrDataForSingleFold),
labelConverter,
defaultTrainingCallbacks.Get(),
customCallbacks.Get(),
std::move(initModel),
std::move(initLearnProgress),
needInitModelApplyCompatiblePools ? std::move(pools) : TDataProviders(),
Expand Down Expand Up @@ -1221,6 +1236,7 @@ void TrainModel(
quantizedFeaturesInfo,
/*objectiveDescriptor*/ Nothing(),
/*evalMetricDescriptor*/ Nothing(),
/*callbackDescriptor*/ Nothing(),
needPoolAfterTrain ? pools : std::move(pools),
std::move(precomputedSingleOnlineCtrDataForSingleFold),
/*initModel*/ Nothing(),
Expand Down Expand Up @@ -1493,6 +1509,7 @@ void TrainModel(
NCB::TQuantizedFeaturesInfoPtr quantizedFeaturesInfo, // can be nullptr
const TMaybe<TCustomObjectiveDescriptor>& objectiveDescriptor,
const TMaybe<TCustomMetricDescriptor>& evalMetricDescriptor,
const TMaybe<TCustomCallbackDescriptor>& callbackDescriptor,
NCB::TDataProviders pools, // not rvalue reference because Cython does not support them
TMaybe<TFullModel*> initModel,
THolder<TLearnProgress>* initLearnProgress,
Expand All @@ -1519,6 +1536,7 @@ void TrainModel(
quantizedFeaturesInfo,
objectiveDescriptor,
evalMetricDescriptor,
callbackDescriptor,
std::move(pools),
/*precomputedSingleOnlineCtrDataForSingleFold*/ Nothing(),
std::move(initModel),
Expand Down
22 changes: 22 additions & 0 deletions catboost/libs/train_lib/train_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ struct TTrainModelInternalOptions {
bool HaveLearnFeatureInMemory = true;
};

struct TCustomCallbackDescriptor {
using TAfterIteration = bool (*)(const TMetricsAndTimeLeftHistory& history, void* customData);

void* CustomData = nullptr;
TAfterIteration AfterIterationFunc = nullptr;
};

class ITrainingCallbacks {
public:
virtual bool IsContinueTraining(const TMetricsAndTimeLeftHistory& /*history*/) {
Expand All @@ -73,6 +80,19 @@ class ITrainingCallbacks {
virtual ~ITrainingCallbacks() = default;
};

class ICustomCallbacks {
public:
virtual bool AfterIteration(const TMetricsAndTimeLeftHistory& /*history*/) = 0;
virtual ~ICustomCallbacks() = default;
};

class TCustomCallbacks : public ICustomCallbacks {
public:
explicit TCustomCallbacks(TMaybe<TCustomCallbackDescriptor> callbackDescriptor);
bool AfterIteration(const TMetricsAndTimeLeftHistory& history) override;
private:
const TMaybe<TCustomCallbackDescriptor> CallbackDescriptor;
};

class IModelTrainer {
public:
Expand All @@ -88,6 +108,7 @@ class IModelTrainer {
TMaybe<NCB::TPrecomputedOnlineCtrData> precomputedSingleOnlineCtrDataForSingleFold,
const TLabelConverter& labelConverter,
ITrainingCallbacks* trainingCallbacks,
ICustomCallbacks* customCallbacks,
TMaybe<TFullModel*> initModel,
THolder<TLearnProgress> initLearnProgress, // can be nullptr, can be modified if non-nullptr

Expand Down Expand Up @@ -132,6 +153,7 @@ void TrainModel(
NCB::TQuantizedFeaturesInfoPtr quantizedFeaturesInfo, // can be nullptr
const TMaybe<TCustomObjectiveDescriptor>& objectiveDescriptor,
const TMaybe<TCustomMetricDescriptor>& evalMetricDescriptor,
const TMaybe<TCustomCallbackDescriptor>& callbackDescriptor,
NCB::TDataProviders pools, // not rvalue reference because Cython does not support them
TMaybe<TFullModel*> initModel,

Expand Down
3 changes: 3 additions & 0 deletions catboost/libs/train_lib/ut/train_model_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ Y_UNIT_TEST_SUITE(TrainModelTests) {
nullptr,
{},
{},
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down Expand Up @@ -160,6 +161,7 @@ Y_UNIT_TEST_SUITE(TrainModelTests) {
nullptr,
{},
{},
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down Expand Up @@ -238,6 +240,7 @@ Y_UNIT_TEST_SUITE(TrainModelTests) {
nullptr,
{},
{},
Nothing(),
std::move(dataProviders),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
2 changes: 2 additions & 0 deletions catboost/private/libs/algo/ut/train_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Y_UNIT_TEST_SUITE(TTrainTest) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
dataProviders,
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand All @@ -112,6 +113,7 @@ Y_UNIT_TEST_SUITE(TTrainTest) {
nullptr,
Nothing(),
Nothing(),
Nothing(),
dataProviders,
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ namespace {
internalOptions.OffsetMetricPeriodByInitModelSize = true;
outputFileOptions.SetAllowWriteFiles(false);
const auto defaultTrainingCallbacks = MakeHolder<ITrainingCallbacks>();
const auto defaultCustomCallbacks = MakeHolder<TCustomCallbacks>(Nothing());
// Training model
modelTrainerHolder->TrainModel(
internalOptions,
Expand All @@ -1043,6 +1044,7 @@ namespace {
/*precomputedSingleOnlineCtrDataForSingleFold*/ Nothing(),
labelConverter,
defaultTrainingCallbacks.Get(), // TODO(ilikepugs): MLTOOLS-3540
defaultCustomCallbacks.Get(),
/*initModel*/ Nothing(),
/*initLearnProgress*/ nullptr,
/*initModelApplyCompatiblePools*/ NCB::TDataProviders(),
Expand Down
1 change: 1 addition & 0 deletions catboost/private/libs/options/check_train_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace NJson {

struct TCustomMetricDescriptor;
struct TCustomObjectiveDescriptor;
struct TCustomCallbackDescriptor;

void CheckFitParams(const NJson::TJsonValue& plainOptions,
const TCustomObjectiveDescriptor* objectiveDescriptor = nullptr,
Expand Down
4 changes: 4 additions & 0 deletions catboost/private/libs/options/plain_options_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ void NCatboostOptions::PlainJsonToOptions(
seenKeys.insert("eval_metric");
}

if (plainOptions.Has("callbacks")) {
seenKeys.insert("callbacks");
}

if (plainOptions.Has("custom_metric") || plainOptions.Has("custom_loss")) {
const NJson::TJsonValue& metrics = plainOptions.Has("custom_metric") ? plainOptions["custom_metric"] : plainOptions["custom_loss"];
if (metrics.IsArray()) {
Expand Down
Loading