Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emit predicted category using an appropriate JSON type. #877

Merged
merged 7 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[
(See {ml-pull}818[#818].)
* Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].)
* Reduce runtime of classification and regression. (See {ml-pull}863[#863].)
* Emit `prediction_field_name` in ml results using the type of a `dependent_variable`.
(See {ml-pull}877[#877].)

=== Bug Fixes
* Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].)
Expand Down
5 changes: 5 additions & 0 deletions include/api/CDataFrameTrainBoostedTreeClassifierRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final
const TRowRef& row,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! Write the predicted category value as string, int or bool.
void writePredictedCategoryValue(const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const;

//! \return A serialisable definition of the trained classification model.
TInferenceModelDefinitionUPtr
inferenceModelDefinition(const TStrVec& fieldNames,
Expand All @@ -55,6 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final

private:
std::size_t m_NumTopClasses;
std::string m_DependentVariableType;
};

//! \brief Makes a core::CDataFrame boosted tree classification runner.
Expand Down
22 changes: 20 additions & 2 deletions lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using TSizeVec = std::vector<std::size_t>;

// Configuration
const std::string NUM_TOP_CLASSES{"num_top_classes"};
const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"};
const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"};

// Output
Expand All @@ -47,6 +48,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() {
static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] {
auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader();
theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(DEPENDENT_VARIABLE_TYPE,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
theReader.addParameter(BALANCED_CLASS_LOSS,
CDataFrameAnalysisConfigReader::E_OptionalParameter);
return theReader;
Expand All @@ -60,6 +63,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier
: CDataFrameTrainBoostedTreeRunner{spec, parameters} {

m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0});
m_DependentVariableType =
parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string"));
this->boostedTreeFactory().balanceClassTrainingLoss(
parameters[BALANCED_CLASS_LOSS].fallback(true));

Expand Down Expand Up @@ -119,7 +124,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(

writer.StartObject();
writer.Key(this->predictionFieldName());
writer.String(categoryValues[predictedCategoryId]);
writePredictedCategoryValue(categoryValues[predictedCategoryId], writer);
writer.Key(PREDICTION_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[predictedCategoryId]);
writer.Key(IS_TRAINING_FIELD_NAME);
Expand All @@ -135,7 +140,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) {
writer.StartObject();
writer.Key(CLASS_NAME_FIELD_NAME);
writer.String(categoryValues[categoryIds[i]]);
writePredictedCategoryValue(categoryValues[categoryIds[i]], writer);
writer.Key(CLASS_PROBABILITY_FIELD_NAME);
writer.Double(probabilityOfCategory[i]);
writer.EndObject();
Expand All @@ -158,6 +163,19 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow(
columnHoldingPrediction, row, writer);
}

void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue(
const std::string& categoryValue,
core::CRapidJsonConcurrentLineWriter& writer) const {

if (m_DependentVariableType == "int") {
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
writer.Int(std::stoi(categoryValue));
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
} else if (m_DependentVariableType == "bool") {
writer.Bool(std::stoi(categoryValue) == 1);
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
} else {
writer.String(categoryValue);
}
}

CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr
CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame,
std::size_t dependentVariableColumn) const {
Expand Down
57 changes: 42 additions & 15 deletions lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) {
BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes].");
}

BOOST_AUTO_TEST_CASE(testWriteOneRow) {
template<typename T>
void testWriteOneRow(const std::string& dependentVariableField,
const std::string& dependentVariableType,
T (rapidjson::Value::*extract)() const,
const std::vector<T>& expectedPredictions) {
// Prepare input data frame
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"};
const TStrVec categoricalColumns{"x1", "x2", "x5"};
const std::string predictionField = dependentVariableField + "_prediction";
const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField};
const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"};
const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"},
{"a", "b", "2.0", "2.0", "cat", "-0.5"},
{"a", "b", "5.0", "5.0", "dog", "-0.1"},
{"c", "d", "5.0", "5.0", "dog", "1.0"},
{"e", "f", "5.0", "5.0", "dog", "1.5"}};
{"a", "b", "1.0", "1.0", "cat", "-0.5"},
{"a", "b", "5.0", "0.0", "dog", "-0.1"},
{"c", "d", "5.0", "0.0", "dog", "1.0"},
{"e", "f", "5.0", "0.0", "dog", "1.5"}};
std::unique_ptr<core::CDataFrame> frame =
core::makeMainStorageDataFrame(columnNames.size()).first;
frame->columnNames(columnNames);
Expand All @@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

// Create classification analysis runner object
const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec(
"classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0,
categoricalColumns)};
"classification", dependentVariableField, rows.size(),
columnNames.size(), 13000000, 0, 0, categoricalColumns)};
rapidjson::Document jsonParameters;
jsonParameters.Parse("{\"dependent_variable\": \"x5\"}");
jsonParameters.Parse("{"
" \"dependent_variable\": \"" + dependentVariableField + "\","
" \"dependent_variable_type\": \"" + dependentVariableType + "\""
"}");
const auto parameters{
api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)};
api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters);
Expand All @@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {

frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
const auto columnHoldingDependentVariable{
std::find(columnNames.begin(), columnNames.end(), "x5") -
std::find(columnNames.begin(), columnNames.end(), dependentVariableField) -
columnNames.begin()};
const auto columnHoldingPrediction{
std::find(columnNames.begin(), columnNames.end(), "x5_prediction") -
std::find(columnNames.begin(), columnNames.end(), predictionField) -
columnNames.begin()};
for (auto row = beginRows; row != endRows; ++row) {
runner.writeOneRow(*frame, columnHoldingDependentVariable,
Expand All @@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
});
}
// Verify results
const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"};
rapidjson::Document arrayDoc;
arrayDoc.Parse<rapidjson::kParseDefaultFlags>(output.str().c_str());
BOOST_TEST_REQUIRE(arrayDoc.IsArray());
BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size());
BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size());
for (std::size_t i = 0; i < arrayDoc.Size(); ++i) {
BOOST_TEST_CONTEXT("Result for row " << i) {
const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)];
BOOST_TEST_REQUIRE(object.IsObject());
BOOST_TEST_REQUIRE(object.HasMember("x5_prediction"));
BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() ==
BOOST_TEST_REQUIRE(object.HasMember(predictionField));
BOOST_TEST_REQUIRE((object[predictionField].*extract)() ==
expectedPredictions[i]);
BOOST_TEST_REQUIRE(object.HasMember("prediction_probability"));
BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5);
Expand All @@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) {
}
}

BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) {
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5});
}

BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) {
testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool,
{true, true, true, false, false});
}

BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) {
testWriteOneRow("x5", "string", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) {
testWriteOneRow("x5", "", &rapidjson::Value::GetString,
{"cat", "cat", "cat", "dog", "dog"});
}

BOOST_AUTO_TEST_SUITE_END()