From 1eb3b44fb75bd16925d6f3326aa82edb41f20341 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 10:01:42 +0100 Subject: [PATCH 1/7] Emit predicted category using an appropriate JSON type. --- docs/CHANGELOG.asciidoc | 2 + ...ataFrameTrainBoostedTreeClassifierRunner.h | 5 ++ ...taFrameTrainBoostedTreeClassifierRunner.cc | 22 ++++++- ...ameTrainBoostedTreeClassifierRunnerTest.cc | 57 ++++++++++++++----- 4 files changed, 69 insertions(+), 17 deletions(-) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 98864fd0ec..4254e6fb21 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[ (See {ml-pull}818[#818].) * Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].) * Reduce runtime of classification and regression. (See {ml-pull}863[#863].) +* Emit `prediction_field_name` in ml results using the type of a `dependent_variable`. +(See {ml-pull}877[#877].) === Bug Fixes * Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].) diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index deed15b7cf..3f0964af98 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -44,6 +44,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final const TRowRef& row, core::CRapidJsonConcurrentLineWriter& writer) const; + //! Write the predicted category value as string, int or bool. + void writePredictedCategoryValue(const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const; + //! \return A serialisable definition of the trained classification model. TInferenceModelDefinitionUPtr inferenceModelDefinition(const TStrVec& fieldNames, @@ -55,6 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; + std::string m_DependentVariableType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index f4609d595b..6541451ca1 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -32,6 +32,7 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; +const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -47,6 +48,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); + theReader.addParameter(DEPENDENT_VARIABLE_TYPE, + CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; @@ -60,6 +63,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier : CDataFrameTrainBoostedTreeRunner{spec, parameters} { m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); + m_DependentVariableType = + parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string")); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -119,7 +124,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( writer.StartObject(); writer.Key(this->predictionFieldName()); - writer.String(categoryValues[predictedCategoryId]); + writePredictedCategoryValue(categoryValues[predictedCategoryId], writer); writer.Key(PREDICTION_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[predictedCategoryId]); writer.Key(IS_TRAINING_FIELD_NAME); @@ -135,7 +140,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) { writer.StartObject(); writer.Key(CLASS_NAME_FIELD_NAME); - writer.String(categoryValues[categoryIds[i]]); + writePredictedCategoryValue(categoryValues[categoryIds[i]], writer); writer.Key(CLASS_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[i]); writer.EndObject(); @@ -158,6 +163,19 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( columnHoldingPrediction, row, writer); } +void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( + const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const { + + if (m_DependentVariableType == "int") { + writer.Int(std::stoi(categoryValue)); + } else if (m_DependentVariableType == "bool") { + writer.Bool(std::stoi(categoryValue) == 1); + } else { + writer.String(categoryValue); + } +} + CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame, std::size_t dependentVariableColumn) const { diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index dc60bf8bbf..241bba3bc5 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]."); } -BOOST_AUTO_TEST_CASE(testWriteOneRow) { +template +void testWriteOneRow(const std::string& dependentVariableField, + const std::string& dependentVariableType, + T (rapidjson::Value::*extract)() const, + const std::vector& expectedPredictions) { // Prepare input data frame - const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"}; - const TStrVec categoricalColumns{"x1", "x2", "x5"}; + const std::string predictionField = dependentVariableField + "_prediction"; + const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField}; + const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"}; const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"}, - {"a", "b", "2.0", "2.0", "cat", "-0.5"}, - {"a", "b", "5.0", "5.0", "dog", "-0.1"}, - {"c", "d", "5.0", "5.0", "dog", "1.0"}, - {"e", "f", "5.0", "5.0", "dog", "1.5"}}; + {"a", "b", "1.0", "1.0", "cat", "-0.5"}, + {"a", "b", "5.0", "0.0", "dog", "-0.1"}, + {"c", "d", "5.0", "0.0", "dog", "1.0"}, + {"e", "f", "5.0", "0.0", "dog", "1.5"}}; std::unique_ptr frame = core::makeMainStorageDataFrame(columnNames.size()).first; frame->columnNames(columnNames); @@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { // Create classification analysis runner object const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec( - "classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0, - categoricalColumns)}; + "classification", dependentVariableField, rows.size(), + columnNames.size(), 13000000, 0, 0, categoricalColumns)}; rapidjson::Document jsonParameters; - jsonParameters.Parse("{\"dependent_variable\": \"x5\"}"); + jsonParameters.Parse("{" + " \"dependent_variable\": \"" + dependentVariableField + "\"," + " \"dependent_variable_type\": \"" + dependentVariableType + "\"" + "}"); const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters); @@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) { const auto columnHoldingDependentVariable{ - std::find(columnNames.begin(), columnNames.end(), "x5") - + std::find(columnNames.begin(), columnNames.end(), dependentVariableField) - columnNames.begin()}; const auto columnHoldingPrediction{ - std::find(columnNames.begin(), columnNames.end(), "x5_prediction") - + std::find(columnNames.begin(), columnNames.end(), predictionField) - columnNames.begin()}; for (auto row = beginRows; row != endRows; ++row) { runner.writeOneRow(*frame, columnHoldingDependentVariable, @@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { }); } // Verify results - const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"}; rapidjson::Document arrayDoc; arrayDoc.Parse(output.str().c_str()); BOOST_TEST_REQUIRE(arrayDoc.IsArray()); BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size()); + BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size()); for (std::size_t i = 0; i < arrayDoc.Size(); ++i) { BOOST_TEST_CONTEXT("Result for row " << i) { const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)]; BOOST_TEST_REQUIRE(object.IsObject()); - BOOST_TEST_REQUIRE(object.HasMember("x5_prediction")); - BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() == + BOOST_TEST_REQUIRE(object.HasMember(predictionField)); + BOOST_TEST_REQUIRE((object[predictionField].*extract)() == expectedPredictions[i]); BOOST_TEST_REQUIRE(object.HasMember("prediction_probability")); BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5); @@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { } } +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) { + testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) { + testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, + {true, true, true, false, false}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) { + testWriteOneRow("x5", "string", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) { + testWriteOneRow("x5", "", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + BOOST_AUTO_TEST_SUITE_END() From 8acd4c3d9c4a0ecd51267ccba71641e8efaa7421 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 11:59:18 +0100 Subject: [PATCH 2/7] Apply review comments --- ...taFrameTrainBoostedTreeClassifierRunner.cc | 24 +++++++++++++------ ...ameTrainBoostedTreeClassifierRunnerTest.cc | 8 +++---- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index 6541451ca1..3aaab55e32 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -33,6 +33,9 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"}; +const std::string DEPENDENT_VARIABLE_TYPE_STRING{"string"}; +const std::string DEPENDENT_VARIABLE_TYPE_INT{"int"}; +const std::string DEPENDENT_VARIABLE_TYPE_BOOL{"bool"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -64,7 +67,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); m_DependentVariableType = - parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string")); + parameters[DEPENDENT_VARIABLE_TYPE].fallback(DEPENDENT_VARIABLE_TYPE_STRING); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -167,13 +170,20 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) const { - if (m_DependentVariableType == "int") { - writer.Int(std::stoi(categoryValue)); - } else if (m_DependentVariableType == "bool") { - writer.Bool(std::stoi(categoryValue) == 1); - } else { - writer.String(categoryValue); + if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_INT) { + double doubleValue; + if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { + writer.Int64(static_cast(doubleValue)); + return; + } + } else if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_BOOL) { + double doubleValue; + if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { + writer.Bool(static_cast(doubleValue) == 1); + return; + } } + writer.String(categoryValue); } CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index 241bba3bc5..aade937d2b 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -123,21 +123,21 @@ void testWriteOneRow(const std::string& dependentVariableField, } } -BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) { +BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsInt) { testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5}); } -BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) { +BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsBool) { testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, {true, true, true, false, false}); } -BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) { +BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsString) { testWriteOneRow("x5", "string", &rapidjson::Value::GetString, {"cat", "cat", "cat", "dog", "dog"}); } -BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) { +BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableTypeIsMissing) { testWriteOneRow("x5", "", &rapidjson::Value::GetString, {"cat", "cat", "cat", "dog", "dog"}); } From 2d085dc066fef059956c92c6a7ad6f3b360bc0a9 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 12:11:12 +0100 Subject: [PATCH 3/7] Rename dependent_variable_type to prediction_field_type as that's how the field is really used in C++ code --- ...ataFrameTrainBoostedTreeClassifierRunner.h | 2 +- ...taFrameTrainBoostedTreeClassifierRunner.cc | 20 +++++++++---------- ...ameTrainBoostedTreeClassifierRunnerTest.cc | 12 +++++------ 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index 3f0964af98..ccf0ccccf1 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -59,7 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; - std::string m_DependentVariableType; + std::string m_PredictionFieldType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index 3aaab55e32..e60bcf5814 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -32,10 +32,10 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; -const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"}; -const std::string DEPENDENT_VARIABLE_TYPE_STRING{"string"}; -const std::string DEPENDENT_VARIABLE_TYPE_INT{"int"}; -const std::string DEPENDENT_VARIABLE_TYPE_BOOL{"bool"}; +const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"}; +const std::string PREDICTION_FIELD_TYPE_STRING{"string"}; +const std::string PREDICTION_FIELD_TYPE_INT{"int"}; +const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -51,7 +51,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); - theReader.addParameter(DEPENDENT_VARIABLE_TYPE, + theReader.addParameter(PREDICTION_FIELD_TYPE, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); @@ -66,8 +66,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier : CDataFrameTrainBoostedTreeRunner{spec, parameters} { m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); - m_DependentVariableType = - parameters[DEPENDENT_VARIABLE_TYPE].fallback(DEPENDENT_VARIABLE_TYPE_STRING); + m_PredictionFieldType = + parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -170,16 +170,16 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) const { - if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_INT) { + if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) { double doubleValue; if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { writer.Int64(static_cast(doubleValue)); return; } - } else if (m_DependentVariableType == DEPENDENT_VARIABLE_TYPE_BOOL) { + } else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) { double doubleValue; if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { - writer.Bool(static_cast(doubleValue) == 1); + writer.Bool(static_cast(doubleValue) == 1.0); return; } } diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index aade937d2b..1cb338d097 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -47,7 +47,7 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { template void testWriteOneRow(const std::string& dependentVariableField, - const std::string& dependentVariableType, + const std::string& predictionFieldType, T (rapidjson::Value::*extract)() const, const std::vector& expectedPredictions) { // Prepare input data frame @@ -77,7 +77,7 @@ void testWriteOneRow(const std::string& dependentVariableField, rapidjson::Document jsonParameters; jsonParameters.Parse("{" " \"dependent_variable\": \"" + dependentVariableField + "\"," - " \"dependent_variable_type\": \"" + dependentVariableType + "\"" + " \"prediction_field_type\": \"" + predictionFieldType + "\"" "}"); const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; @@ -123,21 +123,21 @@ void testWriteOneRow(const std::string& dependentVariableField, } } -BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsInt) { +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) { testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5}); } -BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsBool) { +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) { testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, {true, true, true, false, false}); } -BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableIsString) { +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) { testWriteOneRow("x5", "string", &rapidjson::Value::GetString, {"cat", "cat", "cat", "dog", "dog"}); } -BOOST_AUTO_TEST_CASE(testWriteOneRowDependentVariableTypeIsMissing) { +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) { testWriteOneRow("x5", "", &rapidjson::Value::GetString, {"cat", "cat", "cat", "dog", "dog"}); } From becdf5847d457bffd482ee88b6b7aef909fa8efe Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 13:31:08 +0100 Subject: [PATCH 4/7] Make `m_PredictionFieldType` field of an enum type --- ...ataFrameTrainBoostedTreeClassifierRunner.h | 8 ++++- ...taFrameTrainBoostedTreeClassifierRunner.cc | 33 ++++++++++++------- ...ameTrainBoostedTreeClassifierRunnerTest.cc | 12 ++++--- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index ccf0ccccf1..27bb972305 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -19,6 +19,12 @@ namespace api { class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final : public CDataFrameTrainBoostedTreeRunner { public: + enum EPredictionFieldType { + E_PredictionFieldTypeString, + E_PredictionFieldTypeInt, + E_PredictionFieldTypeBool + }; + static const CDataFrameAnalysisConfigReader& parameterReader(); //! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory. @@ -59,7 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; - std::string m_PredictionFieldType; + EPredictionFieldType m_PredictionFieldType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index e60bcf5814..5b84439878 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -33,9 +33,6 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"}; -const std::string PREDICTION_FIELD_TYPE_STRING{"string"}; -const std::string PREDICTION_FIELD_TYPE_INT{"int"}; -const std::string PREDICTION_FIELD_TYPE_BOOL{"bool"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -49,10 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"}; const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { + const std::string typeString{"string"}; + const std::string typeInt{"int"}; + const std::string typeBool{"bool"}; auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(PREDICTION_FIELD_TYPE, - CDataFrameAnalysisConfigReader::E_OptionalParameter); + CDataFrameAnalysisConfigReader::E_OptionalParameter, + {{typeString, int{E_PredictionFieldTypeString}}, + {typeInt, int{E_PredictionFieldTypeInt}}, + {typeBool, int{E_PredictionFieldTypeBool}}}); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; @@ -67,7 +70,7 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); m_PredictionFieldType = - parameters[PREDICTION_FIELD_TYPE].fallback(PREDICTION_FIELD_TYPE_STRING); + parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -170,20 +173,26 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( const std::string& categoryValue, core::CRapidJsonConcurrentLineWriter& writer) const { - if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_INT) { - double doubleValue; + double doubleValue; + switch (m_PredictionFieldType) { + case E_PredictionFieldTypeString: + writer.String(categoryValue); + break; + case E_PredictionFieldTypeInt: if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { writer.Int64(static_cast(doubleValue)); - return; + } else { + writer.String(categoryValue); } - } else if (m_PredictionFieldType == PREDICTION_FIELD_TYPE_BOOL) { - double doubleValue; + break; + case E_PredictionFieldTypeBool: if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { writer.Bool(static_cast(doubleValue) == 1.0); - return; + } else { + writer.String(categoryValue); } + break; } - writer.String(categoryValue); } CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index 1cb338d097..4dfca4c025 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -75,10 +75,14 @@ void testWriteOneRow(const std::string& dependentVariableField, "classification", dependentVariableField, rows.size(), columnNames.size(), 13000000, 0, 0, categoricalColumns)}; rapidjson::Document jsonParameters; - jsonParameters.Parse("{" - " \"dependent_variable\": \"" + dependentVariableField + "\"," - " \"prediction_field_type\": \"" + predictionFieldType + "\"" - "}"); + if (predictionFieldType.empty()) { + jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}"); + } else { + jsonParameters.Parse("{" + " \"dependent_variable\": \"" + dependentVariableField + "\"," + " \"prediction_field_type\": \"" + predictionFieldType + "\"" + "}"); + } const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters); From 1a09305cff1dbb2a67314f04b6c8e60276ec193a Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 13:35:09 +0100 Subject: [PATCH 5/7] Update changelog --- docs/CHANGELOG.asciidoc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 4254e6fb21..80166caa03 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -52,8 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[ (See {ml-pull}818[#818].) * Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].) * Reduce runtime of classification and regression. (See {ml-pull}863[#863].) -* Emit `prediction_field_name` in ml results using the type of a `dependent_variable`. -(See {ml-pull}877[#877].) +* Emit `prediction_field_name` in ml results using the type provided as +`prediction_field_type` parameter. (See {ml-pull}877[#877].) === Bug Fixes * Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].) From 8f409e27b044d9bf06b83549898f4b39b6c3ad69 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Thu, 5 Dec 2019 13:43:04 +0100 Subject: [PATCH 6/7] Apply review comment --- lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index 5b84439878..7ea53f041d 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -187,7 +187,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( break; case E_PredictionFieldTypeBool: if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { - writer.Bool(static_cast(doubleValue) == 1.0); + writer.Bool(doubleValue != 0.0); } else { writer.String(categoryValue); } From d983afbccb6b4d97120797f205768d844b784e51 Mon Sep 17 00:00:00 2001 From: Przemyslaw Witek Date: Fri, 6 Dec 2019 08:27:40 +0100 Subject: [PATCH 7/7] Apply clang-format --- .../CDataFrameTrainBoostedTreeClassifierRunnerTest.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index 4dfca4c025..41d9777597 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -79,8 +79,12 @@ void testWriteOneRow(const std::string& dependentVariableField, jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}"); } else { jsonParameters.Parse("{" - " \"dependent_variable\": \"" + dependentVariableField + "\"," - " \"prediction_field_type\": \"" + predictionFieldType + "\"" + " \"dependent_variable\": \"" + + dependentVariableField + + "\"," + " \"prediction_field_type\": \"" + + predictionFieldType + + "\"" "}"); } const auto parameters{