Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emit predicted category using an appropriate JSON type. #877

Merged
merged 7 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
(See {ml-pull}818[#818].)
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
* Emit `prediction_field_name` in ml results using the type provided as
`prediction_field_type` parameter. (See {ml-pull}877[#877].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
Expand Down
11 changes: 11 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ namespace api {
class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
: public CDataFrameTrainBoostedTreeRunner {
public:
enum EPredictionFieldType {
E_PredictionFieldTypeString,
E_PredictionFieldTypeInt,
E_PredictionFieldTypeBool
};

static const CDataFrameAnalysisConfigReader& parameterReader();

//! This is not intended to be called directly: use CDataFrameTrainBoostedTreeClassifierRunnerFactory.
Expand All @@ -44,6 +50,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
const TRowRef& row,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! Write the predicted category value as string, int or bool.
void writePredictedCategoryValue(const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! \return A serialisable definition of the trained classification model.
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
Expand All @@ -55,6 +65,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
EPredictionFieldType m_PredictionFieldType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
41 changes: 39 additions & 2 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;

// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string PREDICTION_FIELD_TYPE{"prediction_field_type"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -45,8 +46,16 @@ const std::string CLASS_PROBABILITY_FIELD_NAME{"class_probability"};
const CDataFrameAnalysisConfigReader&
CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
const std::string typeString{"string"};
const std::string typeInt{"int"};
const std::string typeBool{"bool"};
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(PREDICTION_FIELD_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter,
{{typeString, int{E_PredictionFieldTypeString}},
{typeInt, int{E_PredictionFieldTypeInt}},
{typeBool, int{E_PredictionFieldTypeBool}}});
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
return theReader;
Expand All @@ -60,6 +69,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_PredictionFieldType =
parameters[PREDICTION_FIELD_TYPE].fallback(E_PredictionFieldTypeString);
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -119,7 +130,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(

writer.StartObject();
writer.Key(this->predictionFieldName());
writer.String(categoryValues[predictedCategoryId]);
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[predictedCategoryId]);
writer.Key(IS_TRAINING_FIELD_NAME);
Expand All @@ -135,7 +146,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(categoryValues[categoryIds[i]]);
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[i]);
writer.EndObject();
Expand All @@ -158,6 +169,32 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
columnHoldingPrediction, row, writer);
}

void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const {

double doubleValue;
switch (m_PredictionFieldType) {
case E_PredictionFieldTypeString:
writer.String(categoryValue);
break;
case E_PredictionFieldTypeInt:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Int64(static_cast<std::int64_t>(doubleValue));
} else {
writer.String(categoryValue);
}
break;
case E_PredictionFieldTypeBool:
if (core::CStringUtils::stringToType(categoryValue, doubleValue)) {
writer.Bool(doubleValue != 0.0);
} else {
writer.String(categoryValue);
}
break;
}
}

CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const {
Expand Down
65 changes: 50 additions & 15 deletions lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
}

BOOST_AUTO_TEST_CASE(testWriteOneRow) {
template<typename T>
void testWriteOneRow(const std::string& dependentVariableField,
const std::string& predictionFieldType,
T (rapidjson::Value::*extract)() const,
const std::vector<T>& expectedPredictions) {
// Prepare input data frame
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
const TStrVec categoricalColumns{"x1", "x2", "x5"};
const std::string predictionField = dependentVariableField + "_prediction";
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
{"c", "d", "5.0", "5.0", "dog", "1.0"},
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
{"c", "d", "5.0", "0.0", "dog", "1.0"},
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
std::unique_ptr<core::CDataFrame> frame =
core::makeMainStorageDataFrame(columnNames.size()).first;
frame->columnNames(columnNames);
Expand All @@ -67,10 +72,21 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

// Create classification analysis runner object
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
categoricalColumns)};
"classification", dependentVariableField, rows.size(),
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
rapidjson::Document jsonParameters;
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
if (predictionFieldType.empty()) {
jsonParameters.Parse("{\"dependent_variable\": \"" + dependentVariableField + "\"}");
} else {
jsonParameters.Parse("{"
" \"dependent_variable\": \"" +
dependentVariableField +
"\","
" \"prediction_field_type\": \"" +
predictionFieldType +
"\""
"}");
}
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
Expand All @@ -83,10 +99,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
const auto columnHoldingDependentVariable{
std::find(columnNames.begin(), columnNames.end(), "x5") -
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
columnNames.begin()};
const auto columnHoldingPrediction{
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
std::find(columnNames.begin(), columnNames.end(), predictionField) -
columnNames.begin()};
for (auto row = beginRows; row != endRows; ++row) {
runner.writeOneRow(*frame, columnHoldingDependentVariable,
Expand All @@ -95,17 +111,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
});
}
// Verify results
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
rapidjson::Document arrayDoc;
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
BOOST_TEST_CONTEXT("Result for row " << i) {
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
BOOST_TEST_REQUIRE(object.IsObject());
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
expectedPredictions[i]);
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
Expand All @@ -115,4 +131,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
}
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsInt) {
testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsBool) {
testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
{true, true, true, false, false});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsString) {
testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_CASE(testWriteOneRowPredictionFieldTypeIsMissing) {
testWriteOneRow("x5", "", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_SUITE_END()