diff --git a/components/segmentation_platform/internal/database/metadata_utils.cc b/components/segmentation_platform/internal/database/metadata_utils.cc index 7cd82092900cee..191485161088ed 100644 --- a/components/segmentation_platform/internal/database/metadata_utils.cc +++ b/components/segmentation_platform/internal/database/metadata_utils.cc @@ -48,12 +48,10 @@ uint64_t GetExpectedTensorLength(const proto::UMAFeature& feature) { std::string FeatureToString(const proto::UMAFeature& feature) { std::string result; - if (feature.has_type()) { + if (feature.has_type()) result = "type:" + proto::SignalType_Name(feature.type()) + ", "; - } - if (feature.has_name()) { + if (feature.has_name()) result.append("name:" + feature.name() + ", "); - } if (feature.has_name_hash()) { result.append( base::StringPrintf("name_hash:0x%" PRIx64 ", ", feature.name_hash())); @@ -134,22 +132,37 @@ ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature) { return ValidationResult::kValidationSuccess; } +ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature) { + if (feature.sql().empty()) + return ValidationResult::kFeatureSqlQueryEmpty; + + for (int i = 0; i < feature.bind_values_size(); ++i) { + const auto& bind_value = feature.bind_values(i); + if (!bind_value.has_value() || + bind_value.param_type() == proto::SqlFeature::BindValue::UNKNOWN || + ValidateMetadataCustomInput(bind_value.value()) != + ValidationResult::kValidationSuccess) { + return ValidationResult::kFeatureBindValuesInvalid; + } + } + + return ValidationResult::kValidationSuccess; +} + ValidationResult ValidateMetadataCustomInput( const proto::CustomInput& custom_input) { if (custom_input.fill_policy() == proto::CustomInput::UNKNOWN_FILL_POLICY) { // If the current fill policy is not supported or not filled, we must use // the given default value list, therefore the default value list must // provide enough input values as specified by tensor length. - if (custom_input.tensor_length() > custom_input.default_value_size()) { + if (custom_input.tensor_length() > custom_input.default_value_size()) return ValidationResult::kCustomInputInvalid; - } } else if (custom_input.fill_policy() == proto::CustomInput::FILL_PREDICTION_TIME) { // Current time can only provide up to one input tensor value, so column // weight must not exceed 1. - if (custom_input.tensor_length() > 1) { + if (custom_input.tensor_length() > 1) return ValidationResult::kCustomInputInvalid; - } } return ValidationResult::kValidationSuccess; } @@ -189,6 +202,18 @@ ValidationResult ValidateMetadataAndFeatures( return ValidationResult::kValidationSuccess; } +ValidationResult ValidateIndexedTensors( + const QueryProcessor::IndexedTensors& tensor, + size_t expected_size) { + if (tensor.size() != expected_size) + return ValidationResult::kIndexedTensorsInvalid; + for (size_t i = 0; i < tensor.size(); ++i) { + if (tensor.count(i) != 1) + return ValidationResult::kIndexedTensorsInvalid; + } + return ValidationResult::kValidationSuccess; +} + ValidationResult ValidateSegmentInfoMetadataAndFeatures( const proto::SegmentInfo& segment_info) { auto segment_info_result = ValidateSegmentInfo(segment_info); @@ -350,9 +375,8 @@ std::vector GetAllUmaFeatures( // Add training/inference inputs. for (int i = 0; i < model_metadata.input_features_size(); ++i) { auto feature = model_metadata.input_features(i); - if (feature.has_uma_feature()) { + if (feature.has_uma_feature()) features.push_back(feature.uma_feature()); - } } // Add training/inference outputs. diff --git a/components/segmentation_platform/internal/database/metadata_utils.h b/components/segmentation_platform/internal/database/metadata_utils.h index 3fcb89f9fa56ce..3aa1a71b6b143c 100644 --- a/components/segmentation_platform/internal/database/metadata_utils.h +++ b/components/segmentation_platform/internal/database/metadata_utils.h @@ -8,6 +8,7 @@ #include "base/time/time.h" #include "components/optimization_guide/proto/models.pb.h" #include "components/segmentation_platform/internal/database/signal_key.h" +#include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "components/segmentation_platform/internal/proto/model_prediction.pb.h" #include "components/segmentation_platform/internal/proto/types.pb.h" @@ -34,7 +35,10 @@ enum class ValidationResult { kVersionNotSupported = 10, kFeatureListInvalid = 11, kCustomInputInvalid = 12, - kMaxValue = kCustomInputInvalid, + kFeatureSqlQueryEmpty = 13, + kFeatureBindValuesInvalid = 14, + kIndexedTensorsInvalid = 15, + kMaxValue = kIndexedTensorsInvalid, }; // Whether the given SegmentInfo and its metadata is valid to be used for the @@ -46,10 +50,14 @@ ValidationResult ValidateSegmentInfo(const proto::SegmentInfo& segment_info); ValidationResult ValidateMetadata( const proto::SegmentationModelMetadata& model_metadata); -// Whether the given feature metadata is valid to be used for the current +// Whether the given UMA feature metadata is valid to be used for the current // segmentation platform. ValidationResult ValidateMetadataUmaFeature(const proto::UMAFeature& feature); +// Whether the given SQL feature metadata is valid to be used for the current +// segmentation platform. +ValidationResult ValidateMetadataSqlFeature(const proto::SqlFeature& feature); + // Whether the given custom input metadata is valid to be used for the current // segmentation platform. ValidationResult ValidateMetadataCustomInput( @@ -60,6 +68,12 @@ ValidationResult ValidateMetadataCustomInput( ValidationResult ValidateMetadataAndFeatures( const proto::SegmentationModelMetadata& model_metadata); +// Whether the given indexed tensor is valid to be used for the current +// segmentation platform. +ValidationResult ValidateIndexedTensors( + const QueryProcessor::IndexedTensors& tensor, + size_t expected_size); + // Whether the given SegmentInfo, metadata and feature metadata is valid to be // used for the current segmentation platform. ValidationResult ValidateSegmentInfoMetadataAndFeatures( diff --git a/components/segmentation_platform/internal/database/metadata_utils_unittest.cc b/components/segmentation_platform/internal/database/metadata_utils_unittest.cc index d419de8de92752..9dcd1424c9f60e 100644 --- a/components/segmentation_platform/internal/database/metadata_utils_unittest.cc +++ b/components/segmentation_platform/internal/database/metadata_utils_unittest.cc @@ -6,6 +6,8 @@ #include "base/metrics/metrics_hashes.h" #include "components/optimization_guide/proto/models.pb.h" +#include "components/segmentation_platform/internal/database/ukm_types.h" +#include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/aggregation.pb.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" #include "testing/gtest/include/gtest/gtest.h" @@ -192,6 +194,32 @@ TEST_F(MetadataUtilsTest, MetadataUmaFeatureValidation) { } } +TEST_F(MetadataUtilsTest, MetadataSqlFeatureValidation) { + // Sql feature with no sql query string is invalid. + proto::SqlFeature sql_feature; + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureSqlQueryEmpty, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + sql_feature.set_sql("sql query"); + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + // Sql feature with a bind value with no value is invalid. + auto* bind_value = sql_feature.add_bind_values(); + bind_value->set_param_type(proto::SqlFeature::BindValue::BOOL); + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + bind_value->mutable_value(); + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); + + // Sql feature with a bind value of type unknown is invalid. + bind_value->set_param_type(proto::SqlFeature::BindValue::UNKNOWN); + EXPECT_EQ(metadata_utils::ValidationResult::kFeatureBindValuesInvalid, + metadata_utils::ValidateMetadataSqlFeature(sql_feature)); +} + TEST_F(MetadataUtilsTest, MetadataCustomInputValidation) { // Empty custom input has tensor length of 0 and result in a valid input // tensor of length 0. @@ -320,6 +348,33 @@ TEST_F(MetadataUtilsTest, ValidateMetadataAndInputFeatures) { metadata_utils::ValidateMetadataAndFeatures(metadata)); } +TEST_F(MetadataUtilsTest, MetadataIndexedTensorsValidation) { + // Empty indexed tensors are valid. + QueryProcessor::IndexedTensors tensor; + EXPECT_EQ( + metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateIndexedTensors(tensor, /* expected_size */ 0)); + + // Not continuously indexed tensors are invalid. + const std::vector value; + tensor[0] = value; + tensor[1] = value; + tensor[3] = value; + EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size())); + + tensor[2] = value; + EXPECT_EQ(metadata_utils::ValidationResult::kValidationSuccess, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size())); + + // The tensor size should match the expected tensor size. + EXPECT_EQ(metadata_utils::ValidationResult::kIndexedTensorsInvalid, + metadata_utils::ValidateIndexedTensors( + tensor, /*expected_size*/ tensor.size() - 1)); +} + TEST_F(MetadataUtilsTest, ValidateSegementInfoMetadataAndFeatures) { proto::SegmentInfo segment_info; EXPECT_EQ( diff --git a/components/segmentation_platform/internal/execution/feature_list_query_processor.cc b/components/segmentation_platform/internal/execution/feature_list_query_processor.cc index a750c13196f404..cb9271ae6f662d 100644 --- a/components/segmentation_platform/internal/execution/feature_list_query_processor.cc +++ b/components/segmentation_platform/internal/execution/feature_list_query_processor.cc @@ -92,9 +92,10 @@ void FeatureListQueryProcessor::ProcessNextInputFeature( processor = std::make_unique( std::move(queries), feature_processor_state->prediction_time()); } else if (input_feature.has_sql_feature()) { - std::map queries = { + SqlFeatureProcessor::QueryList queries = { {kIndexNotUsed, input_feature.sql_feature()}}; - processor = std::make_unique(queries); + processor = std::make_unique( + std::move(queries), feature_processor_state->prediction_time()); } auto* processor_ptr = processor.get(); diff --git a/components/segmentation_platform/internal/execution/sql_feature_processor.cc b/components/segmentation_platform/internal/execution/sql_feature_processor.cc index 14c50f5b28fa94..75682ec768b8a5 100644 --- a/components/segmentation_platform/internal/execution/sql_feature_processor.cc +++ b/components/segmentation_platform/internal/execution/sql_feature_processor.cc @@ -3,26 +3,125 @@ // found in the LICENSE file. #include "components/segmentation_platform/internal/execution/sql_feature_processor.h" +#include +#include "base/containers/flat_map.h" #include "base/threading/sequenced_task_runner_handle.h" +#include "components/segmentation_platform/internal/database/metadata_utils.h" +#include "components/segmentation_platform/internal/execution/custom_input_processor.h" #include "components/segmentation_platform/internal/execution/feature_processor_state.h" +#include "components/segmentation_platform/internal/proto/model_metadata.pb.h" namespace segmentation_platform { -SqlFeatureProcessor::SqlFeatureProcessor( - std::map queries) - : queries_(std::move(queries)) {} +SqlFeatureProcessor::SqlFeatureProcessor(QueryList&& queries, + base::Time prediction_time) + : queries_(std::move(queries)), prediction_time_(prediction_time) {} SqlFeatureProcessor::~SqlFeatureProcessor() = default; void SqlFeatureProcessor::Process( std::unique_ptr feature_processor_state, QueryProcessorCallback callback) { - // TODO(haileywang): Implement usage of custom input processor for bind - // values. - queries_.clear(); + DCHECK(!is_processed_); + is_processed_ = true; + callback_ = std::move(callback); + feature_processor_state_ = std::move(feature_processor_state); + + // Prepare the sql queries for indexed custom inputs processing. + base::flat_map bind_values; + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + FeatureIndex sql_feature_index = query.first; + + // Validate the proto::SqlFeature metadata. + if (metadata_utils::ValidateMetadataSqlFeature(feature) != + metadata_utils::ValidationResult::kValidationSuccess) { + RunErrorCallback(); + return; + } + + // Process bind values. + // TODO(haileywang): bind_field_index is not currently being used. + for (int i = 0; i < feature.bind_values_size(); ++i) { + // The index is a pair of int constructed from: + // 1. The index of the sql query, and + // 2. The index of the bind value within the sql query. + bind_values[std::make_pair(sql_feature_index, i)] = + feature.bind_values(i).value(); + } + } + + // Process the indexed custom inputs + auto custom_input_processor = + std::make_unique(prediction_time_); + auto* custom_input_processor_ptr = custom_input_processor.get(); + custom_input_processor_ptr->ProcessIndexType( + std::move(bind_values), std::move(feature_processor_state_), + base::BindOnce(&SqlFeatureProcessor::OnCustomInputProcessed, + weak_ptr_factory_.GetWeakPtr(), + std::move(custom_input_processor))); +} + +// TODO(haileywang): Move this structure to ukm_types. +struct SqlFeatureProcessor::CustomSqlQuery { + CustomSqlQuery() = default; + ~CustomSqlQuery() = default; + std::string query; + std::vector bind_values; +}; + +void SqlFeatureProcessor::OnCustomInputProcessed( + std::unique_ptr custom_input_processor, + std::unique_ptr feature_processor_state, + base::flat_map result) { + // Validate the total number of bind values needed. + size_t total_bind_values = 0; + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + total_bind_values += feature.bind_values_size(); + } + + if (total_bind_values != result.size()) { + RunErrorCallback(); + return; + } + + // Assemble the sql queries and the corresponding bind values. + for (const auto& query : queries_) { + const proto::SqlFeature& feature = query.second; + FeatureIndex sql_feature_index = query.first; + + for (int i = 0; i < feature.bind_values_size(); ++i) { + int bind_value_index = i; + + // Validate the result tensor. + if (result.count(std::make_pair(sql_feature_index, bind_value_index)) != + 1) { + RunErrorCallback(); + return; + } + + // Append query and query params to the list. + const auto& custom_input_tensors = + result[std::make_pair(sql_feature_index, bind_value_index)]; + CustomSqlQuery current; + current.query = feature.sql(); + current.bind_values.insert(current.bind_values.end(), + custom_input_tensors.begin(), + custom_input_tensors.end()); + processed_queries_[sql_feature_index] = std::move(current); + } + } + + // TODO(haileywang): Custom inputs have been processed and sql queries are + // ready to be sent to the ukm database. +} + +void SqlFeatureProcessor::RunErrorCallback() { + feature_processor_state_->SetError(); base::SequencedTaskRunnerHandle::Get()->PostTask( FROM_HERE, - base::BindOnce(std::move(callback), std::move(feature_processor_state), + base::BindOnce(std::move(callback_), std::move(feature_processor_state_), std::move(result_))); } diff --git a/components/segmentation_platform/internal/execution/sql_feature_processor.h b/components/segmentation_platform/internal/execution/sql_feature_processor.h index e0b8a8cbd7887a..cae81e85f31ff7 100644 --- a/components/segmentation_platform/internal/execution/sql_feature_processor.h +++ b/components/segmentation_platform/internal/execution/sql_feature_processor.h @@ -5,15 +5,17 @@ #ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_SQL_FEATURE_PROCESSOR_H_ #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_SQL_FEATURE_PROCESSOR_H_ -#include #include #include #include "base/callback_forward.h" +#include "base/containers/flat_map.h" +#include "components/segmentation_platform/internal/database/ukm_types.h" #include "components/segmentation_platform/internal/execution/query_processor.h" #include "components/segmentation_platform/internal/proto/model_metadata.pb.h" namespace segmentation_platform { +class CustomInputProcessor; class FeatureProcessorState; // SqlFeatureProcessor takes a list of SqlFeature type of input, fetches samples @@ -21,8 +23,9 @@ class FeatureProcessorState; // ML model. class SqlFeatureProcessor : public QueryProcessor { public: - explicit SqlFeatureProcessor( - std::map queries); + using QueryList = base::flat_map; + + explicit SqlFeatureProcessor(QueryList&& queries, base::Time prediction_time); ~SqlFeatureProcessor() override; // QueryProcessor implementation. @@ -30,11 +33,45 @@ class SqlFeatureProcessor : public QueryProcessor { QueryProcessorCallback callback) override; private: + using SqlFeatureAndBindValueIndices = + std::pair; + + // Struct responsible for storing a sql query and its bind values. + struct CustomSqlQuery; + + // Callback method for when all relevant bind values have been processed. + void OnCustomInputProcessed( + std::unique_ptr custom_input_processor, + std::unique_ptr feature_processor_state, + base::flat_map result); + + // Helper method for setting the error state and returning result to the + // feature processor. + void RunErrorCallback(); + // List of sql features to process into input tensors. - std::map queries_; + QueryList queries_; + + // Time at which we expect the model execution to run. + const base::Time prediction_time_; + + // Temporary storage of the processing state object. + std::unique_ptr feature_processor_state_; + + // Callback for sending the resulting indexed tensors to the feature list + // processor. + QueryProcessorCallback callback_; + + bool is_processed_{false}; + + // List of sql queries and bind values ready to be sent to the ukm database + // for processing. + base::flat_map processed_queries_; // List of resulting input tensors. IndexedTensors result_; + + base::WeakPtrFactory weak_ptr_factory_{this}; }; } // namespace segmentation_platform diff --git a/tools/metrics/histograms/enums.xml b/tools/metrics/histograms/enums.xml index 8d23e7727c67f7..655a71a18e23a3 100644 --- a/tools/metrics/histograms/enums.xml +++ b/tools/metrics/histograms/enums.xml @@ -81409,6 +81409,9 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf + + +