Skip to content

Commit

Permalink
Save training metrics into the model metadata (best_score_, evals_res…
Browse files Browse the repository at this point in the history
…ult_, best_iteration_ model attributes now work after model save/load). Fix #1166
  • Loading branch information
andrey-khropov committed Feb 9, 2024
1 parent a325582 commit 1df1f72
Show file tree
Hide file tree
Showing 30 changed files with 166 additions and 18 deletions.
3 changes: 2 additions & 1 deletion catboost/cuda/train_lib/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ namespace NCatboostCuda {
.WithCoreModelFrom(
modelPtr)
.WithObjectsDataFrom(trainingData.Learn->ObjectsData)
.WithFeatureEstimators(trainingData.FeatureEstimators);
.WithFeatureEstimators(trainingData.FeatureEstimators)
.WithMetrics(*metricsAndTimeHistory);

if (dstModel) {
coreModelToFullModelConverter.Do(true, dstModel, localExecutor, &targetClassifiers);
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.darwin-arm64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(catboost-libs-loggers)
target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.darwin-x86_64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(catboost-libs-loggers)
target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-aarch64-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-aarch64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-ppc64le-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-ppc64le.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-x86_64-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.linux-x86_64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-linux-headers
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.windows-x86_64-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(catboost-libs-loggers)
target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
1 change: 1 addition & 0 deletions catboost/libs/loggers/CMakeLists.windows-x86_64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_library(catboost-libs-loggers)
target_link_libraries(catboost-libs-loggers PUBLIC
contrib-libs-cxxsupp
yutil
catboost-libs-helpers
catboost-libs-logging
private-libs-options
catboost-libs-metrics
Expand Down
47 changes: 47 additions & 0 deletions catboost/libs/loggers/catboost_logger_helpers.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,60 @@
#include "catboost_logger_helpers.h"
#include "logger.h"

#include <catboost/libs/helpers/json_helpers.h>

#include <type_traits>


// Builds a TTimeInfo from profiler results by copying its three timing values.
// Note the naming difference: the profiler's CurrentTime is stored as IterationTime.
TTimeInfo::TTimeInfo(const TProfileResults& profileResults)
: IterationTime(profileResults.CurrentTime)
, PassedTime(profileResults.PassedTime)
, RemainingTime(profileResults.RemainingTime)
{
}

// Serializes the metric-related members of this history into a JSON map.
//
// Keys written: "learn_metrics_history", "test_metrics_history",
// "learn_best_error", "test_best_error" and, only when BestIteration is set,
// "best_iteration". TimeHistory is not written by this function.
NJson::TJsonValue TMetricsAndTimeLeftHistory::SaveMetrics() const {
    // Converts a single member to JSON using the TJsonFieldHelper
    // specialization that matches the member's (cv/ref-stripped) type.
    const auto toJson = [](const auto& value) {
        using TValue = std::remove_const_t<std::remove_reference_t<decltype(value)>>;
        NJson::TJsonValue serialized;
        TJsonFieldHelper<TValue>::Write(value, &serialized);
        return serialized;
    };

    NJson::TJsonValue metricsJson(NJson::JSON_MAP);

    metricsJson["learn_metrics_history"] = toJson(LearnMetricsHistory);
    metricsJson["test_metrics_history"] = toJson(TestMetricsHistory);

    // The best iteration is optional: the key is omitted when no value is set.
    if (BestIteration) {
        metricsJson["best_iteration"] = toJson(*BestIteration);
    }

    metricsJson["learn_best_error"] = toJson(LearnBestError);
    metricsJson["test_best_error"] = toJson(TestBestError);

    return metricsJson;
}

// Rebuilds a metrics history from the JSON map produced by SaveMetrics().
//
// "learn_metrics_history", "test_metrics_history", "learn_best_error" and
// "test_best_error" are looked up unconditionally with at(), i.e. they are
// treated as mandatory. "best_iteration" is optional and is only assigned when
// the key is present. TimeHistory is never touched, so it keeps the value it
// gets from default construction.
TMetricsAndTimeLeftHistory TMetricsAndTimeLeftHistory::LoadMetrics(const NJson::TJsonValue& rhs) {
    const auto& fields = rhs.GetMap();

    // Deserializes the entry stored under `key` into `*dst`, picking the
    // TJsonFieldHelper specialization from the destination's type.
    const auto readField = [&fields](TStringBuf key, auto* dst) {
        using TValue = std::remove_reference_t<decltype(*dst)>;
        TJsonFieldHelper<TValue>::Read(fields.at(key), dst);
    };

    TMetricsAndTimeLeftHistory history;

    readField("learn_metrics_history", &history.LearnMetricsHistory);
    readField("test_metrics_history", &history.TestMetricsHistory);

    const auto bestIterationIt = fields.find("best_iteration");
    if (bestIterationIt != fields.end()) {
        history.BestIteration = bestIterationIt->second.GetUIntegerSafe();
    }

    readField("learn_best_error", &history.LearnBestError);
    readField("test_best_error", &history.TestBestError);

    return history;
}


void TMetricsAndTimeLeftHistory::TryUpdateBestError(const IMetric& metric, double error, THashMap<TString, double>& bestError, bool updateBestIteration) {
TString metricDescription = metric.GetDescription();
bool shouldUpdate = false;
Expand Down
6 changes: 6 additions & 0 deletions catboost/libs/loggers/catboost_logger_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <catboost/private/libs/options/enums.h>
#include <catboost/libs/metrics/metric.h>

#include <library/cpp/json/json_value.h>

#include <util/generic/maybe.h>

class TLogger;
Expand Down Expand Up @@ -32,6 +34,10 @@ struct TMetricsAndTimeLeftHistory {

Y_SAVELOAD_DEFINE(LearnMetricsHistory, TestMetricsHistory, TimeHistory, BestIteration, LearnBestError, TestBestError);

// Serialization for model metadata without TimeHistory
NJson::TJsonValue SaveMetrics() const;
static TMetricsAndTimeLeftHistory LoadMetrics(const NJson::TJsonValue& rhs);

void AddLearnError(const IMetric& metric, double error);
void AddTestError(size_t testIdx, const IMetric& metric, double error, bool updateBestIteration);

Expand Down
2 changes: 2 additions & 0 deletions catboost/libs/train_lib/train_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ static void SaveModel(
trainingDataForCpu.Learn->ObjectsData
).WithFeatureEstimators(
trainingDataForCpu.FeatureEstimators
).WithMetrics(
ctx.LearnProgress->MetricsAndTimeHistory
);

const TVector<TTargetClassifier>* targetClassifiers = &ctx.CtrsHelper.GetTargetClassifiers();
Expand Down
16 changes: 16 additions & 0 deletions catboost/private/libs/algo/full_model_saver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,13 @@ namespace NCB {
return *this;
}

// Registers the training metrics history to be written into the model metadata by Do().
// Only the address of `metrics` is stored (no copy is made), so the referenced object
// must remain alive for as long as this converter may still read it.
// Returns *this so the call can be chained with the other With...() setters.
TCoreModelToFullModelConverter& TCoreModelToFullModelConverter::WithMetrics(
const TMetricsAndTimeLeftHistory& metrics
) {
MetricsAndTimeHistory = &metrics;
return *this;
}

void TCoreModelToFullModelConverter::Do(
bool requiresStaticCtrProvider,
TFullModel* dstModel,
Expand Down Expand Up @@ -561,6 +568,15 @@ namespace NCB {
dstModel->ModelInfo["class_params"] = ClassificationTargetHelper.Serialize();
}

{
NJson::TJsonValue trainingJson(NJson::EJsonValueType::JSON_MAP);
if (MetricsAndTimeHistory) {
trainingJson["metrics"] = MetricsAndTimeHistory->SaveMetrics();
}
dstModel->ModelInfo["training"] = WriteTJsonValue(trainingJson);
}


if (
FinalFeatureCalcerComputationMode == EFinalFeatureCalcersComputationMode::Default &&
!dstModel->ModelTrees->GetEstimatedFeatures().empty()
Expand Down
5 changes: 5 additions & 0 deletions catboost/private/libs/algo/full_model_saver.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "projection.h"

#include <catboost/libs/data/data_provider.h>
#include <catboost/libs/loggers/catboost_logger_helpers.h>
#include <catboost/libs/model/fwd.h>
#include <catboost/libs/model/ctr_data.h>
#include <catboost/libs/model/online_ctr.h>
Expand All @@ -24,6 +25,7 @@


struct TDatasetDataForFinalCtrs;
struct TMetricsAndTimeLeftHistory;

namespace NCatboostOptions {
class TCatBoostOptions;
Expand Down Expand Up @@ -79,6 +81,8 @@ namespace NCB {
TFeatureEstimatorsPtr featureEstimators
);

TCoreModelToFullModelConverter& WithMetrics(const TMetricsAndTimeLeftHistory& metrics);

void Do(
bool requiresStaticCtrProvider,
TFullModel* dstModel,
Expand Down Expand Up @@ -127,6 +131,7 @@ namespace NCB {

TFullModel* CoreModel = nullptr;
const NCB::TPerfectHashedToHashedCatValuesMap* PerfectHashedToHashedCatValuesMap = nullptr;
const TMetricsAndTimeLeftHistory* MetricsAndTimeHistory = nullptr;
TFeatureEstimatorsPtr FeatureEstimators = nullptr;

TGetBinarizedDataFunc GetBinarizedDataFunc;
Expand Down
32 changes: 32 additions & 0 deletions catboost/pytest/lib/common_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'permute_dataset_columns',
'remove_time_from_json',
'test_output_path',
'compare_with_limited_precision'
]

try:
Expand Down Expand Up @@ -425,3 +426,34 @@ def get_limited_precision_numpy_diff_tool(rtol=None, atol=None):
if atol is not None:
diff_tool += ['--atol', str(atol)]
return diff_tool


# arguments can be JSON-like simple data structures
def compare_with_limited_precision(lhs, rhs, rtol=1e-6, atol=1e-8):
    """Recursively compare two JSON-like structures, tolerating small float differences.

    Supported containers are ``dict`` and ``list``; supported leaves are
    ``numpy.ndarray``, Python/NumPy floats, and anything comparable with ``==``.

    Args:
        lhs: first value (dict, list, ndarray, float or any other object).
        rhs: second value; must have the same kind as ``lhs`` at every level.
        rtol: relative tolerance, applied to ``abs(rhs)``.
        atol: absolute tolerance.

    Returns:
        True if both structures have the same shape and all float leaves satisfy
        ``abs(lhs - rhs) <= atol + rtol * abs(rhs)``; non-float leaves must be
        equal. For floats, two NaNs are considered equal and an infinity is only
        equal to an infinity of the same sign. For non-float leaves the result
        is whatever ``lhs == rhs`` returns.
    """
    if isinstance(lhs, dict):
        if not isinstance(rhs, dict) or len(lhs) != len(rhs):
            return False
        return all(
            key in rhs and compare_with_limited_precision(value, rhs[key], rtol, atol)
            for key, value in lhs.items()
        )

    if isinstance(lhs, list):
        if not isinstance(rhs, list) or len(lhs) != len(rhs):
            return False
        return all(
            compare_with_limited_precision(lhs_item, rhs_item, rtol, atol)
            for lhs_item, rhs_item in zip(lhs, rhs)
        )

    if isinstance(lhs, np.ndarray):
        if not isinstance(rhs, np.ndarray):
            return False
        # np.allclose broadcasts its arguments, so arrays of different shapes
        # could compare equal (or raise); require identical shapes, like lists.
        if lhs.shape != rhs.shape:
            return False
        return np.allclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=True)

    if isinstance(lhs, (float, np.floating)):
        if not isinstance(rhs, (float, np.floating)):
            return False
        lhs_is_nan = bool(np.isnan(lhs))
        rhs_is_nan = bool(np.isnan(rhs))
        if lhs_is_nan or rhs_is_nan:
            # Same convention as equal_nan=True in the ndarray branch.
            return lhs_is_nan and rhs_is_nan
        if np.isinf(lhs) or np.isinf(rhs):
            # The tolerance formula degenerates for infinities
            # (rtol * inf == inf would accept anything), so compare exactly.
            return bool(lhs == rhs)
        return bool(abs(lhs - rhs) <= atol + rtol * abs(rhs))

    return lhs == rhs
4 changes: 4 additions & 0 deletions catboost/python-package/catboost/_catboost.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,7 @@ cdef extern from "catboost/python-package/catboost/helpers.h":
const TVector[TString]& sampleIdsVector,
TVector[TArrayRef[float]]* numFeaturesColumns
) except +ProcessException
TMetricsAndTimeLeftHistory GetTrainingMetrics(const TFullModel& model) except +ProcessException


cdef extern from "catboost/python-package/catboost/helpers.h":
Expand Down Expand Up @@ -5198,13 +5199,15 @@ cdef class _CatBoost:
tmp_model.Load(wrapper.Get())
self.model_blob = None
self.__model.Swap(tmp_model)
self.__metrics_history = GetTrainingMetrics(self.__model[0])

# Load a model from `model_file`, stored in the given `format`, and make it the
# current model of this object.
cpdef _load_model(self, model_file, format):
cdef TFullModel tmp_model
# Translate the textual format name into the EModelType expected by ReadModel.
cdef EModelType modelType = string_to_model_type(format)
tmp_model = ReadModel(to_arcadia_string(fspath(model_file)), modelType)
# Clear the model_blob reference before installing the newly read model.
self.model_blob = None
self.__model.Swap(tmp_model)
# Refresh the metrics history from the training metadata of the loaded model.
self.__metrics_history = GetTrainingMetrics(self.__model[0])

cpdef _save_model(self, output_file, format, export_parameters, _PoolBase pool):
cdef EModelType modelType = string_to_model_type(format)
Expand Down Expand Up @@ -5238,6 +5241,7 @@ cdef class _CatBoost:
self.model_blob = serialized_model_str
cdef TFullModel tmp_model = ReadZeroCopyModel(<char*>serialized_model_str, len(serialized_model_str))
self.__model.Swap(tmp_model)
self.__metrics_history = GetTrainingMetrics(self.__model[0])

cpdef _get_params(self):
try:
Expand Down
15 changes: 15 additions & 0 deletions catboost/python-package/catboost/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <catboost/private/libs/options/split_params.h>
#include <catboost/private/libs/target/data_providers.h>

#include <library/cpp/json/json_reader.h>

#include <util/system/guard.h>
#include <util/system/info.h>
#include <util/system/mutex.h>
Expand Down Expand Up @@ -377,3 +379,16 @@ void GetNumFeatureValuesSample(
Copy(values.begin(), values.end(), dst.begin());
}
}

// Extracts the training metrics stored in the model's metadata.
//
// The metadata entry "training" is parsed as JSON (parse errors are reported by
// ReadJsonTree, which is called with throwOnError = true) and its "metrics"
// member is decoded with TMetricsAndTimeLeftHistory::LoadMetrics.
// A default-constructed history is returned when the model has no "training"
// entry or that entry has no "metrics" member.
TMetricsAndTimeLeftHistory GetTrainingMetrics(const TFullModel& model) {
    if (!model.ModelInfo.contains("training"sv)) {
        return TMetricsAndTimeLeftHistory();
    }

    NJson::TJsonValue trainingInfo;
    ReadJsonTree(model.ModelInfo.at("training"sv), &trainingInfo, /*throwOnError*/ true);

    const auto& trainingSections = trainingInfo.GetMap();
    if (!trainingSections.contains("metrics"sv)) {
        return TMetricsAndTimeLeftHistory();
    }

    return TMetricsAndTimeLeftHistory::LoadMetrics(trainingSections.at("metrics"sv));
}
4 changes: 4 additions & 0 deletions catboost/python-package/catboost/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <catboost/private/libs/options/loss_description.h>
#include <catboost/private/libs/options/plain_options_helper.h>
#include <catboost/private/libs/target/data_providers.h>
#include <catboost/libs/loggers/catboost_logger_helpers.h>
#include <catboost/libs/train_lib/options_helper.h>

#include <library/cpp/json/json_value.h>
Expand Down Expand Up @@ -377,3 +378,6 @@ void GetNumFeatureValuesSample(
const TVector<TString>& sampleIdsVector,
TVector<TArrayRef<float>>* numFeaturesColumns
);


TMetricsAndTimeLeftHistory GetTrainingMetrics(const TFullModel& model);
Binary file not shown.

0 comments on commit 1df1f72

Please sign in to comment.