From 881e5d0f68357e591f8594db878867446b323ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 6 Dec 2019 11:05:16 +0100 Subject: [PATCH] Emit predicted category using an appropriate JSON type. (#877) --- docs/CHANGELOG.asciidoc | 2 + ...ataFrameTrainBoostedTreeClassifierRunner.h | 11 ++++ ...taFrameTrainBoostedTreeClassifierRunner.cc | 41 +++++++++++- ...ameTrainBoostedTreeClassifierRunnerTest.cc | 65 ++++++++++++++----- 4 files changed, 102 insertions(+), 17 deletions(-) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 98864fd0ec..80166caa03 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[ (See {ml-pull}818[#818].) * Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].) * Reduce runtime of classification and regression. (See {ml-pull}863[#863].) +* Emit `prediction_field_name` in ml results using the type provided as +`prediction_field_type` parameter. (See {ml-pull}877[#877].) === Bug Fixes * Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].) diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index deed15b7cf..27bb972305 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -19,6 +19,12 @@ namespace api { class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final : public CDataFrameTrainBoostedTreeRunner { public: + enum EPredictionFieldType { + E_PredictionFieldTypeString, + E_PredictionFieldTypeInt, + E_PredictionFieldTypeBool + }; + static const CDataFrameAnalysisConfigReader& parameterReader(); //! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory. @@ -44,6 +50,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final const TRowRef& row, core::CRapidJsonConcurrentLineWriter& writer) const; + //! Write the predicted category value as string, int or bool. + void writePredictedCategoryValue(const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const; + //! \return A serialisable definition of the trained classification model. TInferenceModelDefinitionUPtr inferenceModelDefinition(const TStrVec& fieldNames, @@ -55,6 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; + EPredictionFieldType m_PredictionFieldType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index f4609d595b..7ea53f041d 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -32,6 +32,7 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; +const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -45,8 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"}; const CDataFrameAnalysisConfigReader& CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { + const std::string typeString{"string"}; + const std::string typeInt{"int"}; + const std::string typeBool{"bool"}; auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); + theReader.addParameter(PREDICTION_FIELD_TYPE, + CDataFrameAnalysisConfigReader::E_OptionalParameter, + {{typeString, int{E_PredictionFieldTypeString}}, + {typeInt, int{E_PredictionFieldTypeInt}}, + {typeBool, int{E_PredictionFieldTypeBool}}}); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; @@ -60,6 +69,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier : CDataFrameTrainBoostedTreeRunner{spec, parameters} { m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); + m_PredictionFieldType = + parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -119,7 +130,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( writer.StartObject(); writer.Key(this->predictionFieldName()); - writer.String(categoryValues[predictedCategoryId]); + writePredictedCategoryValue(categoryValues[predictedCategoryId], writer); writer.Key(PREDICTION_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[predictedCategoryId]); writer.Key(IS_TRAINING_FIELD_NAME); @@ -135,7 +146,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) { writer.StartObject(); writer.Key(CLASS_NAME_FIELD_NAME); - writer.String(categoryValues[categoryIds[i]]); + writePredictedCategoryValue(categoryValues[categoryIds[i]], writer); writer.Key(CLASS_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[i]); writer.EndObject(); @@ -158,6 +169,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( columnHoldingPrediction, row, writer); } +void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( + const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const { + + double doubleValue; + switch (m_PredictionFieldType) { + case E_PredictionFieldTypeString: + writer.String(categoryValue); + break; + case E_PredictionFieldTypeInt: + if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { + writer.Int64(static_cast(doubleValue)); + } else { + writer.String(categoryValue); + } + break; + case E_PredictionFieldTypeBool: + if (core::CStringUtils::stringToType(categoryValue, doubleValue)) { + writer.Bool(doubleValue != 0.0); + } else { + writer.String(categoryValue); + } + break; + } +} + CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame, std::size_t dependentVariableColumn) const { diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index dc60bf8bbf..41d9777597 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]."); } -BOOST_AUTO_TEST_CASE(testWriteOneRow) { +template +void testWriteOneRow(const std::string& dependentVariableField, + const std::string& predictionFieldType, + T (rapidjson::Value::*extract)() const, + const std::vector& expectedPredictions) { // Prepare input data frame - const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"}; - const TStrVec categoricalColumns{"x1", "x2", "x5"}; + const std::string predictionField = dependentVariableField + "_prediction"; + const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField}; + const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"}; const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"}, - {"a", "b", "2.0", "2.0", "cat", "-0.5"}, - {"a", "b", "5.0", "5.0", "dog", "-0.1"}, - {"c", "d", "5.0", "5.0", "dog", "1.0"}, - {"e", "f", "5.0", "5.0", "dog", "1.5"}}; + {"a", "b", "1.0", "1.0", "cat", "-0.5"}, + {"a", "b", "5.0", "0.0", "dog", "-0.1"}, + {"c", "d", "5.0", "0.0", "dog", "1.0"}, + {"e", "f", "5.0", "0.0", "dog", "1.5"}}; std::unique_ptr frame = core::makeMainStorageDataFrame(columnNames.size()).first; frame->columnNames(columnNames); @@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { // Create classification analysis runner object const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec( - "classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0, - categoricalColumns)}; + "classification", dependentVariableField, rows.size(), + columnNames.size(), 13000000, 0, 0, categoricalColumns)}; rapidjson::Document jsonParameters; - jsonParameters.Parse("{\"dependent_variable\": \"x5\"}"); + if (predictionFieldType.empty()) { + jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}"); + } else { + jsonParameters.Parse("{" + " \"dependent_variable\": \"" + + dependentVariableField + + "\"," + " \"prediction_field_type\": \"" + + predictionFieldType + + "\"" + "}"); + } const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters); @@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) { const auto columnHoldingDependentVariable{ - std::find(columnNames.begin(), columnNames.end(), "x5") - + std::find(columnNames.begin(), columnNames.end(), dependentVariableField) - columnNames.begin()}; const auto columnHoldingPrediction{ - std::find(columnNames.begin(), columnNames.end(), "x5_prediction") - + std::find(columnNames.begin(), columnNames.end(), predictionField) - columnNames.begin()}; for (auto row = beginRows; row != endRows; ++row) { runner.writeOneRow(*frame, columnHoldingDependentVariable, @@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { }); } // Verify results - const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"}; rapidjson::Document arrayDoc; arrayDoc.Parse(output.str().c_str()); BOOST_TEST_REQUIRE(arrayDoc.IsArray()); BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size()); + BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size()); for (std::size_t i = 0; i < arrayDoc.Size(); ++i) { BOOST_TEST_CONTEXT("Result for row " << i) { const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)]; BOOST_TEST_REQUIRE(object.IsObject()); - BOOST_TEST_REQUIRE(object.HasMember("x5_prediction")); - BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() == + BOOST_TEST_REQUIRE(object.HasMember(predictionField)); + BOOST_TEST_REQUIRE((object[predictionField].*extract)() == expectedPredictions[i]); BOOST_TEST_REQUIRE(object.HasMember("prediction_probability")); BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5); @@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { } } +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) { + testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) { + testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, + {true, true, true, false, false}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) { + testWriteOneRow("x5", "string", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) { + testWriteOneRow("x5", "", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + BOOST_AUTO_TEST_SUITE_END()