Skip to content

Commit

Permalink
[Feature Processor] Add bind values processing to SQL processor
Browse files Browse the repository at this point in the history
Added usage of custom input processor within SQL processor.

Added validation of SQL features.

Added validation of indexed tensors to be used to validate callback
result of feature processors.

Unit tests are added in the next CL.

Bug: 1302140
Change-Id: I06d839f936b8f54363cb733b2fd287adad094d25
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3526181
Reviewed-by: Siddhartha S <ssid@chromium.org>
Commit-Queue: Hailey Wang <haileywang@google.com>
Cr-Commit-Position: refs/heads/main@{#990217}
  • Loading branch information
Hailey Wang authored and Chromium LUCI CQ committed Apr 8, 2022
1 parent 0f4141d commit 8048d82
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 25 deletions.
Expand Up @@ -48,12 +48,10 @@ uint64_t GetExpectedTensorLength(const proto::UMAFeature& feature) {

std::string FeatureToString(const proto::UMAFeature& feature) {
std::string result;
if (feature.has_type()) {
if (feature.has_type())
result = "type:" + proto::SignalType_Name(feature.type()) + ", ";
}
if (feature.has_name()) {
if (feature.has_name())
result.append("name:" + feature.name() + ", ");
}
if (feature.has_name_hash()) {
result.append(
base::StringPrintf("name_hash:0x%" PRIx64 ", ", feature.name_hash()));
Expand Down Expand Up @@ -134,22 +132,37 @@ ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature) {
return ValidationResult::kValidationSuccess;
}

ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature) {
if (feature.sql().empty())
return ValidationResult::kFeatureSqlQueryEmpty;

for (int i = 0; i < feature.bind_values_size(); ++i) {
const auto& bind_value = feature.bind_values(i);
if (!bind_value.has_value() ||
bind_value.param_type() == proto::SqlFeature::BindValue::UNKNOWN ||
ValidateMetadataCustomInput(bind_value.value()) !=
ValidationResult::kValidationSuccess) {
return ValidationResult::kFeatureBindValuesInvalid;
}
}

return ValidationResult::kValidationSuccess;
}

ValidationResult ValidateMetadataCustomInput(
const proto::CustomInput& custom_input) {
if (custom_input.fill_policy() == proto::CustomInput::UNKNOWN_FILL_POLICY) {
// If the current fill policy is not supported or not filled, we must use
// the given default value list, therefore the default value list must
// provide enough input values as specified by tensor length.
if (custom_input.tensor_length() > custom_input.default_value_size()) {
if (custom_input.tensor_length() > custom_input.default_value_size())
return ValidationResult::kCustomInputInvalid;
}
} else if (custom_input.fill_policy() ==
proto::CustomInput::FILL_PREDICTION_TIME) {
// Current time can only provide up to one input tensor value, so column
// weight must not exceed 1.
if (custom_input.tensor_length() > 1) {
if (custom_input.tensor_length() > 1)
return ValidationResult::kCustomInputInvalid;
}
}
return ValidationResult::kValidationSuccess;
}
Expand Down Expand Up @@ -189,6 +202,18 @@ ValidationResult ValidateMetadataAndFeatures(
return ValidationResult::kValidationSuccess;
}

ValidationResult ValidateIndexedTensors(
const QueryProcessor::IndexedTensors& tensor,
size_t expected_size) {
if (tensor.size() != expected_size)
return ValidationResult::kIndexedTensorsInvalid;
for (size_t i = 0; i < tensor.size(); ++i) {
if (tensor.count(i) != 1)
return ValidationResult::kIndexedTensorsInvalid;
}
return ValidationResult::kValidationSuccess;
}

ValidationResult ValidateSegmentInfoMetadataAndFeatures(
const proto::SegmentInfo& segment_info) {
auto segment_info_result = ValidateSegmentInfo(segment_info);
Expand Down Expand Up @@ -350,9 +375,8 @@ std::vector<proto::UMAFeature> GetAllUmaFeatures(
// Add training/inference inputs.
for (int i = 0; i < model_metadata.input_features_size(); ++i) {
auto feature = model_metadata.input_features(i);
if (feature.has_uma_feature()) {
if (feature.has_uma_feature())
features.push_back(feature.uma_feature());
}
}

// Add training/inference outputs.
Expand Down
Expand Up @@ -8,6 +8,7 @@
#include "base/time/time.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/database/signal_key.h"
#include "components/segmentation_platform/internal/execution/query_processor.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/proto/types.pb.h"
Expand All @@ -34,7 +35,10 @@ enum class ValidationResult {
kVersionNotSupported = 10,
kFeatureListInvalid = 11,
kCustomInputInvalid = 12,
kMaxValue = kCustomInputInvalid,
kFeatureSqlQueryEmpty = 13,
kFeatureBindValuesInvalid = 14,
kIndexedTensorsInvalid = 15,
kMaxValue = kIndexedTensorsInvalid,
};

// Whether the given SegmentInfo and its metadata is valid to be used for the
Expand All @@ -46,10 +50,14 @@ ValidationResult ValidateSegmentInfo(const proto::SegmentInfo& segment_info);
ValidationResult ValidateMetadata(
const proto::SegmentationModelMetadata& model_metadata);

// Whether the given feature metadata is valid to be used for the current
// Whether the given UMA feature metadata is valid to be used for the current
// segmentation platform.
ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature);

// Whether the given SQL feature metadata is valid to be used for the current
// segmentation platform.
ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature);

// Whether the given custom input metadata is valid to be used for the current
// segmentation platform.
ValidationResult ValidateMetadataCustomInput(
Expand All @@ -60,6 +68,12 @@ ValidationResult ValidateMetadataCustomInput(
ValidationResult ValidateMetadataAndFeatures(
const proto::SegmentationModelMetadata& model_metadata);

// Whether the given indexed tensor is valid to be used for the current
// segmentation platform.
ValidationResult ValidateIndexedTensors(
const QueryProcessor::IndexedTensors& tensor,
size_t expected_size);

// Whether the given SegmentInfo, metadata and feature metadata is valid to be
// used for the current segmentation platform.
ValidationResult ValidateSegmentInfoMetadataAndFeatures(
Expand Down
Expand Up @@ -6,6 +6,8 @@

#include "base/metrics/metrics_hashes.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/database/ukm_types.h"
#include "components/segmentation_platform/internal/execution/query_processor.h"
#include "components/segmentation_platform/internal/proto/aggregation.pb.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "testing/gtest/include/gtest/gtest.h"
Expand Down Expand Up @@ -192,6 +194,32 @@ TEST_F(MetadataUtilsTest, MetadataUmaFeatureValidation) {
}
}

TEST_F(MetadataUtilsTest, MetadataSqlFeatureValidation) {
// Sql feature with no sql query string is invalid.
proto::SqlFeature sql_feature;
EXPECT_EQ(metadata_utils::ValidationResult::kFeatureSqlQueryEmpty,
metadata_utils::ValidateMetadataSqlFeature(sql_feature));

sql_feature.set_sql("sql query");
EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess,
metadata_utils::ValidateMetadataSqlFeature(sql_feature));

// Sql feature with a bind value with no value is invalid.
auto* bind_value = sql_feature.add_bind_values();
bind_value->set_param_type(proto::SqlFeature::BindValue::BOOL);
EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid,
metadata_utils::ValidateMetadataSqlFeature(sql_feature));

bind_value->mutable_value();
EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess,
metadata_utils::ValidateMetadataSqlFeature(sql_feature));

// Sql feature with a bind value of type unknown is invalid.
bind_value->set_param_type(proto::SqlFeature::BindValue::UNKNOWN);
EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid,
metadata_utils::ValidateMetadataSqlFeature(sql_feature));
}

TEST_F(MetadataUtilsTest, MetadataCustomInputValidation) {
// Empty custom input has tensor length of 0 and result in a valid input
// tensor of length 0.
Expand Down Expand Up @@ -320,6 +348,33 @@ TEST_F(MetadataUtilsTest, ValidateMetadataAndInputFeatures) {
metadata_utils::ValidateMetadataAndFeatures(metadata));
}

TEST_F(MetadataUtilsTest, MetadataIndexedTensorsValidation) {
// Empty indexed tensors are valid.
QueryProcessor::IndexedTensors tensor;
EXPECT_EQ(
metadata_utils::ValidationResult::kValidationSuccess,
metadata_utils::ValidateIndexedTensors(tensor, /* expected_size */ 0));

// Not continuously indexed tensors are invalid.
const std::vector<ProcessedValue> value;
tensor[0] = value;
tensor[1] = value;
tensor[3] = value;
EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid,
metadata_utils::ValidateIndexedTensors(
tensor, /*expected_size*/ tensor.size()));

tensor[2] = value;
EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess,
metadata_utils::ValidateIndexedTensors(
tensor, /*expected_size*/ tensor.size()));

// The tensor size should match the expected tensor size.
EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid,
metadata_utils::ValidateIndexedTensors(
tensor, /*expected_size*/ tensor.size() - 1));
}

TEST_F(MetadataUtilsTest, ValidateSegementInfoMetadataAndFeatures) {
proto::SegmentInfo segment_info;
EXPECT_EQ(
Expand Down
Expand Up @@ -92,9 +92,10 @@ void FeatureListQueryProcessor::ProcessNextInputFeature(
processor = std::make_unique<CustomInputProcessor>(
std::move(queries), feature_processor_state->prediction_time());
} else if (input_feature.has_sql_feature()) {
std::map<QueryProcessor::FeatureIndex, proto::SqlFeature> queries = {
SqlFeatureProcessor::QueryList queries = {
{kIndexNotUsed, input_feature.sql_feature()}};
processor = std::make_unique<SqlFeatureProcessor>(queries);
processor = std::make_unique<SqlFeatureProcessor>(
std::move(queries), feature_processor_state->prediction_time());
}

auto* processor_ptr = processor.get();
Expand Down
Expand Up @@ -3,26 +3,125 @@
// found in the LICENSE file.

#include "components/segmentation_platform/internal/execution/sql_feature_processor.h"
#include <utility>

#include "base/containers/flat_map.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "components/segmentation_platform/internal/database/metadata_utils.h"
#include "components/segmentation_platform/internal/execution/custom_input_processor.h"
#include "components/segmentation_platform/internal/execution/feature_processor_state.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"

namespace segmentation_platform {

SqlFeatureProcessor::SqlFeatureProcessor(
std::map<FeatureIndex, proto::SqlFeature> queries)
: queries_(std::move(queries)) {}
SqlFeatureProcessor::SqlFeatureProcessor(QueryList&& queries,
base::Time prediction_time)
: queries_(std::move(queries)), prediction_time_(prediction_time) {}
SqlFeatureProcessor::~SqlFeatureProcessor() = default;

void SqlFeatureProcessor::Process(
std::unique_ptr<FeatureProcessorState> feature_processor_state,
QueryProcessorCallback callback) {
// TODO(haileywang): Implement usage of custom input processor for bind
// values.
queries_.clear();
DCHECK(!is_processed_);
is_processed_ = true;
callback_ = std::move(callback);
feature_processor_state_ = std::move(feature_processor_state);

// Prepare the sql queries for indexed custom inputs processing.
base::flat_map<SqlFeatureAndBindValueIndices, proto::CustomInput> bind_values;
for (const auto& query : queries_) {
const proto::SqlFeature& feature = query.second;
FeatureIndex sql_feature_index = query.first;

// Validate the proto::SqlFeature metadata.
if (metadata_utils::ValidateMetadataSqlFeature(feature) !=
metadata_utils::ValidationResult::kValidationSuccess) {
RunErrorCallback();
return;
}

// Process bind values.
// TODO(haileywang): bind_field_index is not currently being used.
for (int i = 0; i < feature.bind_values_size(); ++i) {
// The index is a pair of int constructed from:
// 1. The index of the sql query, and
// 2. The index of the bind value within the sql query.
bind_values[std::make_pair(sql_feature_index, i)] =
feature.bind_values(i).value();
}
}

// Process the indexed custom inputs
auto custom_input_processor =
std::make_unique<CustomInputProcessor>(prediction_time_);
auto* custom_input_processor_ptr = custom_input_processor.get();
custom_input_processor_ptr->ProcessIndexType<SqlFeatureAndBindValueIndices>(
std::move(bind_values), std::move(feature_processor_state_),
base::BindOnce(&SqlFeatureProcessor::OnCustomInputProcessed,
weak_ptr_factory_.GetWeakPtr(),
std::move(custom_input_processor)));
}

// TODO(haileywang): Move this structure to ukm_types.
struct SqlFeatureProcessor::CustomSqlQuery {
CustomSqlQuery() = default;
~CustomSqlQuery() = default;
std::string query;
std::vector<ProcessedValue> bind_values;
};

void SqlFeatureProcessor::OnCustomInputProcessed(
std::unique_ptr<CustomInputProcessor> custom_input_processor,
std::unique_ptr<FeatureProcessorState> feature_processor_state,
base::flat_map<SqlFeatureAndBindValueIndices, Tensor> result) {
// Validate the total number of bind values needed.
size_t total_bind_values = 0;
for (const auto& query : queries_) {
const proto::SqlFeature& feature = query.second;
total_bind_values += feature.bind_values_size();
}

if (total_bind_values != result.size()) {
RunErrorCallback();
return;
}

// Assemble the sql queries and the corresponding bind values.
for (const auto& query : queries_) {
const proto::SqlFeature& feature = query.second;
FeatureIndex sql_feature_index = query.first;

for (int i = 0; i < feature.bind_values_size(); ++i) {
int bind_value_index = i;

// Validate the result tensor.
if (result.count(std::make_pair(sql_feature_index, bind_value_index)) !=
1) {
RunErrorCallback();
return;
}

// Append query and query params to the list.
const auto& custom_input_tensors =
result[std::make_pair(sql_feature_index, bind_value_index)];
CustomSqlQuery current;
current.query = feature.sql();
current.bind_values.insert(current.bind_values.end(),
custom_input_tensors.begin(),
custom_input_tensors.end());
processed_queries_[sql_feature_index] = std::move(current);
}
}

// TODO(haileywang): Custom inputs have been processed and sql queries are
// ready to be sent to the ukm database.
}

void SqlFeatureProcessor::RunErrorCallback() {
feature_processor_state_->SetError();
base::SequencedTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::BindOnce(std::move(callback), std::move(feature_processor_state),
base::BindOnce(std::move(callback_), std::move(feature_processor_state_),
std::move(result_)));
}

Expand Down

0 comments on commit 8048d82

Please sign in to comment.