From 8048d8287eb78f5806540ac95a1e274cc9bc7984 Mon Sep 17 00:00:00 2001 From: Hailey Wang Date: Fri, 8 Apr 2022 00:59:50 +0000 Subject: [PATCH] [Feature Processor] Add bind values processing to SQL processor Added usage of custom input processor within SQL processor. Added validation of SQL features. Added validation of indexed tensors to be used to validate callback result of feature processors. Unit tests are added in the next CL. Bug: 1302140 Change-Id: I06d839f936b8f54363cb733b2fd287adad094d25 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3526181 Reviewed-by: Siddhartha S Commit-Queue: Hailey Wang Cr-Commit-Position: refs/heads/main@{#990217} --- .../internal/database/metadata_utils.cc | 44 +++++-- .../internal/database/metadata_utils.h | 18 ++- .../database/metadata_utils_unittest.cc | 55 +++++++++ .../execution/feature_list_query_processor.cc | 5 +- .../execution/sql_feature_processor.cc | 113 ++++++++++++++++-- .../execution/sql_feature_processor.h | 45 ++++++- tools/metrics/histograms/enums.xml | 3 + 7 files changed, 258 insertions(+), 25 deletions(-) diff --git a/components/segmentation_platform/internal/database/metadata_utils.cc b/components/segmentation_platform/internal/database/metadata_utils.cc index 7cd82092900cee..191485161088ed 100644 --- a/components/segmentation_platform/internal/database/metadata_utils.cc +++ b/components/segmentation_platform/internal/database/metadata_utils.cc @@ -48,12 +48,10 @@ uint64_t GetExpectedTensorLength(const proto::UMAFeature& feature) { std::string FeatureToString(const proto::UMAFeature& feature) { std::string result; - if (feature.has_type()) { + if (feature.has_type()) result = "type:" + proto::SignalType_Name(feature.type()) + ", "; - } - if (feature.has_name()) { + if (feature.has_name()) result.append("name:" + feature.name() + ", "); - } if (feature.has_name_hash()) { result.append( base::StringPrintf("name_hash:0x%" PRIx64 ", ", feature.name_hash())); @@ -134,22 +132,37 @@ ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature) { return ValidationResult::kValidationSuccess; } +ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature) { + if (feature.sql().empty()) + return ValidationResult::kFeatureSqlQueryEmpty; + + for (int i = 0; i < feature.bind_values_size(); ++i) { + const auto& bind_value = feature.bind_values(i); + if (!bind_value.has_value() || + bind_value.param_type() == proto::SqlFeature::BindValue::UNKNOWN || + ValidateMetadataCustomInput(bind_value.value()) != + ValidationResult::kValidationSuccess) { + return ValidationResult::kFeatureBindValuesInvalid; + } + } + + return ValidationResult::kValidationSuccess; +} + ValidationResult ValidateMetadataCustomInput( const proto::CustomInput& custom_input) { if (custom_input.fill_policy() == proto::CustomInput::UNKNOWN_FILL_POLICY) { // If the current fill policy is not supported or not filled, we must use // the given default value list, therefore the default value list must // provide enough input values as specified by tensor length. - if (custom_input.tensor_length() > custom_input.default_value_size()) { + if (custom_input.tensor_length() > custom_input.default_value_size()) return ValidationResult::kCustomInputInvalid; - } } else if (custom_input.fill_policy() == proto::CustomInput::FILL_PREDICTION_TIME) { // Current time can only provide up to one input tensor value, so column // weight must not exceed 1. - if (custom_input.tensor_length() > 1) { + if (custom_input.tensor_length() > 1) return ValidationResult::kCustomInputInvalid; - } } return ValidationResult::kValidationSuccess; } @@ -189,6 +202,18 @@ ValidationResult ValidateMetadataAndFeatures( return ValidationResult::kValidationSuccess; } +ValidationResult ValidateIndexedTensors( + const QueryProcessor::IndexedTensors& tensor, + size_t expected_size) { + if (tensor.size() != expected_size) + return ValidationResult::kIndexedTensorsInvalid; + for (size_t i = 0; i < tensor.size(); ++i) { + if (tensor.count(i) != 1) + return ValidationResult::kIndexedTensorsInvalid; + } + return ValidationResult::kValidationSuccess; +} + ValidationResult ValidateSegmentInfoMetadataAndFeatures( const proto::SegmentInfo& segment_info) { auto segment_info_result = ValidateSegmentInfo(segment_info); @@ -350,9 +375,8 @@ std::vector GetAllUmaFeatures( // Add training/inference inputs. for (int i = 0; i < model_metadata.input_features_size(); ++i) { auto feature = model_metadata.input_features(i); - if (feature.has_uma_feature()) { + if (feature.has_uma_feature()) features.push_back(feature.uma_feature()); - } } // Add training/inference outputs. diff --git a/components/segmentation_platform/internal/database/metadata_utils.h b/components/segmentation_platform/internal/database/metadata_utils.h index 3fcb89f9fa56ce..3aa1a71b6b143c 100644 --- a/components/segmentation_platform/internal/database/metadata_utils.h +++ b/components/segmentation_platform/internal/database/metadata_utils.h @@ -8,6 +8,7 @@ #include "base/time/time.h" #include "components/optimization_guide/proto/models.pb.h" #include "components/segmentation_platform/internal/database/signal_key.h" +#include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "components/segmentation_platform/internal/proto/model_prediction.pb.h" #include "components/segmentation_platform/internal/proto/types.pb.h" @@ -34,7 +35,10 @@ enum class ValidationResult { kVersionNotSupported = 10, kFeatureListInvalid = 11, kCustomInputInvalid = 12, - kMaxValue = kCustomInputInvalid, + kFeatureSqlQueryEmpty = 13, + kFeatureBindValuesInvalid = 14, + kIndexedTensorsInvalid = 15, + kMaxValue = kIndexedTensorsInvalid, }; // Whether the given SegmentInfo and its metadata is valid to be used for the @@ -46,10 +50,14 @@ ValidationResult ValidateSegmentInfo(const proto::SegmentInfo& segment_info); ValidationResult ValidateMetadata( const proto::SegmentationModelMetadata& model_metadata); -// Whether the given feature metadata is valid to be used for the current +// Whether the given UMA feature metadata is valid to be used for the current // segmentation platform. ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature); +// Whether the given SQL feature metadata is valid to be used for the current +// segmentation platform. +ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature); + // Whether the given custom input metadata is valid to be used for the current // segmentation platform. ValidationResult ValidateMetadataCustomInput( @@ -60,6 +68,12 @@ ValidationResult ValidateMetadataCustomInput( ValidationResult ValidateMetadataAndFeatures( const proto::SegmentationModelMetadata& model_metadata); +// Whether the given indexed tensor is valid to be used for the current +// segmentation platform. +ValidationResult ValidateIndexedTensors( + const QueryProcessor::IndexedTensors& tensor, + size_t expected_size); + // Whether the given SegmentInfo, metadata and feature metadata is valid to be // used for the current segmentation platform. ValidationResult ValidateSegmentInfoMetadataAndFeatures( diff --git a/components/segmentation_platform/internal/database/metadata_utils_unittest.cc b/components/segmentation_platform/internal/database/metadata_utils_unittest.cc index d419de8de92752..9dcd1424c9f60e 100644 --- a/components/segmentation_platform/internal/database/metadata_utils_unittest.cc +++ b/components/segmentation_platform/internal/database/metadata_utils_unittest.cc @@ -6,6 +6,8 @@ #include "base/metrics/metrics_hashes.h" #include "components/optimization_guide/proto/models.pb.h" +#include "components/segmentation_platform/internal/database/ukm_types.h" +#include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/aggregation.pb.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "testing/gtest/include/gtest/gtest.h" @@ -192,6 +194,32 @@ TEST_F(MetadataUtilsTest, MetadataUmaFeatureValidation) { } } +TEST_F(MetadataUtilsTest, MetadataSqlFeatureValidation) { + // Sql feature with no sql query string is invalid. + proto::SqlFeature sql_feature; + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureSqlQueryEmpty, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + sql_feature.set_sql("sql query"); + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + // Sql feature with a bind value with no value is invalid. + auto* bind_value = sql_feature.add_bind_values(); + bind_value->set_param_type(proto::SqlFeature::BindValue::BOOL); + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + bind_value->mutable_value(); + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + // Sql feature with a bind value of type unknown is invalid. + bind_value->set_param_type(proto::SqlFeature::BindValue::UNKNOWN); + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); +} + TEST_F(MetadataUtilsTest, MetadataCustomInputValidation) { // Empty custom input has tensor length of 0 and result in a valid input // tensor of length 0. @@ -320,6 +348,33 @@ TEST_F(MetadataUtilsTest, ValidateMetadataAndInputFeatures) { metadata_utils::ValidateMetadataAndFeatures(metadata)); } +TEST_F(MetadataUtilsTest, MetadataIndexedTensorsValidation) { + // Empty indexed tensors are valid. + QueryProcessor::IndexedTensors tensor; + EXPECT_EQ( + metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateIndexedTensors(tensor, /* expected_size */ 0)); + + // Not continuously indexed tensors are invalid. + const std::vector value; + tensor[0] = value; + tensor[1] = value; + tensor[3] = value; + EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size())); + + tensor[2] = value; + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size())); + + // The tensor size should match the expected tensor size. + EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size() - 1)); +} + TEST_F(MetadataUtilsTest, ValidateSegementInfoMetadataAndFeatures) { proto::SegmentInfo segment_info; EXPECT_EQ( diff --git a/components/segmentation_platform/internal/execution/feature_list_query_processor.cc b/components/segmentation_platform/internal/execution/feature_list_query_processor.cc index a750c13196f404..cb9271ae6f662d 100644 --- a/components/segmentation_platform/internal/execution/feature_list_query_processor.cc +++ b/components/segmentation_platform/internal/execution/feature_list_query_processor.cc @@ -92,9 +92,10 @@ void FeatureListQueryProcessor::ProcessNextInputFeature( processor = std::make_unique( std::move(queries), feature_processor_state->prediction_time()); } else if (input_feature.has_sql_feature()) { - std::map queries = { + SqlFeatureProcessor::QueryList queries = { {kIndexNotUsed, input_feature.sql_feature()}}; - processor = std::make_unique(queries); + processor = std::make_unique( + std::move(queries), feature_processor_state->prediction_time()); } auto* processor_ptr = processor.get(); diff --git a/components/segmentation_platform/internal/execution/sql_feature_processor.cc b/components/segmentation_platform/internal/execution/sql_feature_processor.cc index 14c50f5b28fa94..75682ec768b8a5 100644 --- a/components/segmentation_platform/internal/execution/sql_feature_processor.cc +++ b/components/segmentation_platform/internal/execution/sql_feature_processor.cc @@ -3,26 +3,125 @@ // found in the LICENSE file. #include "components/segmentation_platform/internal/execution/sql_feature_processor.h" +#include +#include "base/containers/flat_map.h" #include "base/threading/sequenced_task_runner_handle.h" +#include "components/segmentation_platform/internal/database/metadata_utils.h" +#include "components/segmentation_platform/internal/execution/custom_input_processor.h" #include "components/segmentation_platform/internal/execution/feature_processor_state.h" +#include "components/segmentation_platform/internal/proto/model_metadata.pb.h" namespace segmentation_platform { -SqlFeatureProcessor::SqlFeatureProcessor( - std::map queries) - : queries_(std::move(queries)) {} +SqlFeatureProcessor::SqlFeatureProcessor(QueryList&& queries, + base::Time prediction_time) + : queries_(std::move(queries)), prediction_time_(prediction_time) {} SqlFeatureProcessor::~SqlFeatureProcessor() = default; void SqlFeatureProcessor::Process( std::unique_ptr feature_processor_state, QueryProcessorCallback callback) { - // TODO(haileywang): Implement usage of custom input processor for bind - // values. - queries_.clear(); + DCHECK(!is_processed_); + is_processed_ = true; + callback_ = std::move(callback); + feature_processor_state_ = std::move(feature_processor_state); + + // Prepare the sql queries for indexed custom inputs processing. + base::flat_map bind_values; + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + FeatureIndex sql_feature_index = query.first; + + // Validate the proto::SqlFeature metadata. + if (metadata_utils::ValidateMetadataSqlFeature(feature) != + metadata_utils::ValidationResult::kValidationSuccess) { + RunErrorCallback(); + return; + } + + // Process bind values. + // TODO(haileywang): bind_field_index is not currently being used. + for (int i = 0; i < feature.bind_values_size(); ++i) { + // The index is a pair of int constructed from: + // 1. The index of the sql query, and + // 2. The index of the bind value within the sql query. + bind_values[std::make_pair(sql_feature_index, i)] = + feature.bind_values(i).value(); + } + } + + // Process the indexed custom inputs + auto custom_input_processor = + std::make_unique(prediction_time_); + auto* custom_input_processor_ptr = custom_input_processor.get(); + custom_input_processor_ptr->ProcessIndexType( + std::move(bind_values), std::move(feature_processor_state_), + base::BindOnce(&SqlFeatureProcessor::OnCustomInputProcessed, + weak_ptr_factory_.GetWeakPtr(), + std::move(custom_input_processor))); +} + +// TODO(haileywang): Move this structure to ukm_types. +struct SqlFeatureProcessor::CustomSqlQuery { + CustomSqlQuery() = default; + ~CustomSqlQuery() = default; + std::string query; + std::vector bind_values; +}; + +void SqlFeatureProcessor::OnCustomInputProcessed( + std::unique_ptr custom_input_processor, + std::unique_ptr feature_processor_state, + base::flat_map result) { + // Validate the total number of bind values needed. + size_t total_bind_values = 0; + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + total_bind_values += feature.bind_values_size(); + } + + if (total_bind_values != result.size()) { + RunErrorCallback(); + return; + } + + // Assemble the sql queries and the corresponding bind values. + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + FeatureIndex sql_feature_index = query.first; + + for (int i = 0; i < feature.bind_values_size(); ++i) { + int bind_value_index = i; + + // Validate the result tensor. + if (result.count(std::make_pair(sql_feature_index, bind_value_index)) != + 1) { + RunErrorCallback(); + return; + } + + // Append query and query params to the list. + const auto& custom_input_tensors = + result[std::make_pair(sql_feature_index, bind_value_index)]; + CustomSqlQuery current; + current.query = feature.sql(); + current.bind_values.insert(current.bind_values.end(), + custom_input_tensors.begin(), + custom_input_tensors.end()); + processed_queries_[sql_feature_index] = std::move(current); + } + } + + // TODO(haileywang): Custom inputs have been processed and sql queries are + // ready to be sent to the ukm database. +} + +void SqlFeatureProcessor::RunErrorCallback() { + feature_processor_state_->SetError(); base::SequencedTaskRunnerHandle::Get()->PostTask( FROM_HERE, - base::BindOnce(std::move(callback), std::move(feature_processor_state), + base::BindOnce(std::move(callback_), std::move(feature_processor_state_), std::move(result_))); } diff --git a/components/segmentation_platform/internal/execution/sql_feature_processor.h b/components/segmentation_platform/internal/execution/sql_feature_processor.h index e0b8a8cbd7887a..cae81e85f31ff7 100644 --- a/components/segmentation_platform/internal/execution/sql_feature_processor.h +++ b/components/segmentation_platform/internal/execution/sql_feature_processor.h @@ -5,15 +5,17 @@ #ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_SQL_FEATURE_PROCESSOR_H_ #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_SQL_FEATURE_PROCESSOR_H_ -#include #include #include #include "base/callback_forward.h" +#include "base/containers/flat_map.h" +#include "components/segmentation_platform/internal/database/ukm_types.h" #include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" namespace segmentation_platform { +class CustomInputProcessor; class FeatureProcessorState; // SqlFeatureProcessor takes a list of SqlFeature type of input, fetches samples @@ -21,8 +23,9 @@ class FeatureProcessorState; // ML model. class SqlFeatureProcessor : public QueryProcessor { public: - explicit SqlFeatureProcessor( - std::map queries); + using QueryList = base::flat_map; + + explicit SqlFeatureProcessor(QueryList&& queries, base::Time prediction_time); ~SqlFeatureProcessor() override; // QueryProcessor implementation. @@ -30,11 +33,45 @@ class SqlFeatureProcessor : public QueryProcessor { QueryProcessorCallback callback) override; private: + using SqlFeatureAndBindValueIndices = + std::pair; + + // Struct responsible for storing a sql query and its bind values. + struct CustomSqlQuery; + + // Callback method for when all relevant bind values have been processed. + void OnCustomInputProcessed( + std::unique_ptr custom_input_processor, + std::unique_ptr feature_processor_state, + base::flat_map result); + + // Helper method for setting the error state and returning result to the + // feature processor. + void RunErrorCallback(); + // List of sql features to process into input tensors. - std::map queries_; + QueryList queries_; + + // Time at which we expect the model execution to run. + const base::Time prediction_time_; + + // Temporary storage of the processing state object. + std::unique_ptr feature_processor_state_; + + // Callback for sending the resulting indexed tensors to the feature list + // processor. + QueryProcessorCallback callback_; + + bool is_processed_{false}; + + // List of sql queries and bind values ready to be sent to the ukm database + // for processing. + base::flat_map processed_queries_; // List of resulting input tensors. IndexedTensors result_; + + base::WeakPtrFactory weak_ptr_factory_{this}; }; } // namespace segmentation_platform diff --git a/tools/metrics/histograms/enums.xml b/tools/metrics/histograms/enums.xml index 8d23e7727c67f7..655a71a18e23a3 100644 --- a/tools/metrics/histograms/enums.xml +++ b/tools/metrics/histograms/enums.xml @@ -81409,6 +81409,9 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf + + +