Skip to content

Commit

Permalink
[Omnibox][ML] Add signal normalization.
Browse files Browse the repository at this point in the history
(cherry picked from commit 1ac78e2)

Bug: 1457678,b/287323698
Change-Id: Ifbcb7301176a128be1886265a31bdb35dc90143c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4633630
Reviewed-by: Moe Ahmadi <mahmadi@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Commit-Queue: Jun Zou <junzou@chromium.org>
Cr-Original-Commit-Position: refs/heads/main@{#1161315}
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4646671
Bot-Commit: Rubber Stamper <rubber-stamper@appspot.gserviceaccount.com>
Cr-Commit-Position: refs/branch-heads/5845@{#110}
Cr-Branched-From: 5a5dff6-refs/heads/main@{#1160321}
  • Loading branch information
Jun Zou authored and Chromium LUCI CQ committed Jun 26, 2023
1 parent af8eb34 commit 63eb96b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ AutocompleteScoringModelHandler::ExtractInputFromScoringSignals(
: kDefaultMissingValue;
}

// Normalize signal if configured.
if (scoring_signal_spec.has_norm_upper_boundary()) {
float upper_boundary = scoring_signal_spec.norm_upper_boundary();
DCHECK_GT(upper_boundary, 0);
val = std::clamp(*val, -upper_boundary, upper_boundary) / upper_boundary;
}

model_input.push_back(*val);
}
DCHECK_EQ(static_cast<size_t>(model_input.size()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ class AutocompleteScoringModelHandlerTest : public testing::Test {

TEST_F(AutocompleteScoringModelHandlerTest,
ExtractInputFromScoringSignalsTest) {
// Metadata with three scoring signal specifications.
// Metadata with scoring signal specifications.
AutocompleteScoringModelMetadata model_metadata;
*model_metadata.add_scoring_signal_specs() = CreateScoringSignalSpec(
optimization_guide::proto::SCORING_SIGNAL_TYPE_LENGTH_OF_URL);
model_metadata.mutable_scoring_signal_specs(0)->set_norm_upper_boundary(50);
// Signal with log2 transformation.
*model_metadata.add_scoring_signal_specs() = CreateScoringSignalSpec(
optimization_guide::proto::
Expand All @@ -160,20 +161,26 @@ TEST_F(AutocompleteScoringModelHandlerTest,
SCORING_SIGNAL_TYPE_ELAPSED_TIME_LAST_SHORTCUT_VISIT_SEC,
/*transformation=*/absl::nullopt,
/*min_val=*/0, /*max_val=*/absl::nullopt, /*missing_val=*/-2);
// Clamped by upper boundary.
*model_metadata.add_scoring_signal_specs() = CreateScoringSignalSpec(
optimization_guide::proto::SCORING_SIGNAL_TYPE_TYPED_COUNT);
model_metadata.mutable_scoring_signal_specs(4)->set_norm_upper_boundary(100);

// Scoring signals.
ScoringSignals scoring_signals;
scoring_signals.set_length_of_url(10);
scoring_signals.set_elapsed_time_last_visit_secs(32767);
scoring_signals.set_elapsed_time_last_shortcut_visit_sec(-200);
scoring_signals.set_typed_count(150);

const auto input_signals = model_handler_->ExtractInputFromScoringSignals(
scoring_signals, model_metadata);
ASSERT_EQ(input_signals.size(), 4u);
EXPECT_THAT(input_signals[0], 10);
ASSERT_EQ(input_signals.size(), 5u);
EXPECT_THAT(input_signals[0], 0.2); // Normalized signal.
EXPECT_THAT(input_signals[1], 15);
EXPECT_NEAR(input_signals[2], 0.3792, 0.0001);
EXPECT_THAT(input_signals[3], -2);
EXPECT_NEAR(input_signals[4], 1.0f, 0.0001); // Clamped and normalized.
}

TEST_F(AutocompleteScoringModelHandlerTest, GetBatchModelInputTest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ message ScoringSignalSpec {
optional float min_value = 4;
// Maximum value of valid signals.
optional float max_value = 5;
// The upper boundary for normalization.
optional float norm_upper_boundary = 6;
}

// The message contains a set of params to run a specific autocomplete scoring
Expand Down

0 comments on commit 63eb96b

Please sign in to comment.