Skip to content

Commit

Permalink
[Omnibox][ML] Add batch scoring for autocomplete urls.
Browse files Browse the repository at this point in the history
- Guarded by a feature param `ml_batch_url_scoring`.

Bug: b/282173802
Change-Id: I5cab1fa2df295514486517a754b0c5e519a7e5f3
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4548341
Reviewed-by: Moe Ahmadi <mahmadi@chromium.org>
Commit-Queue: Jun Zou <junzou@chromium.org>
Reviewed-by: manuk hovanesian <manukh@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1146808}
  • Loading branch information
Jun Zou authored and Chromium LUCI CQ committed May 20, 2023
1 parent 3de3e3e commit 5c84cf0
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 57 deletions.
5 changes: 5 additions & 0 deletions chrome/browser/about_flags.cc
Expand Up @@ -1527,6 +1527,9 @@ const FeatureEntry::FeatureParam kOmniboxMlUrlScoringPreserveDefault[] = {
{"MlUrlScoringRerankFinalMatchesOnly", "true"},
{"MlUrlScoringPreserveDefault", "true"},
};
const FeatureEntry::FeatureParam kOmniboxMlBatchUrlScoring[] = {
{"MlBatchUrlScoring", "true"},
};

const FeatureEntry::FeatureVariation kOmniboxMlUrlScoringVariations[] = {
{"Run the model but do not rescore or rerank the matches (counterfactual)",
Expand All @@ -1540,6 +1543,8 @@ const FeatureEntry::FeatureVariation kOmniboxMlUrlScoringVariations[] = {
"match",
kOmniboxMlUrlScoringPreserveDefault,
std::size(kOmniboxMlUrlScoringPreserveDefault), nullptr},
{"Run the model on a batch of matches", kOmniboxMlBatchUrlScoring,
std::size(kOmniboxMlBatchUrlScoring), nullptr},
};
const FeatureEntry::FeatureParam kRealboxTwoPreviousSearchRelatedSuggestions[] =
{
Expand Down
61 changes: 49 additions & 12 deletions components/omnibox/browser/autocomplete_controller.cc
Expand Up @@ -1045,11 +1045,19 @@ void AutocompleteController::UpdateResult(
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Use a WeakPtr since the model is not owned and `this` may no longer be
// alive. `SortCullAndAnnotateResult()` is called when the model is done.
RunUrlScoringModel(base::BindOnce(
&AutocompleteController::SortCullAndAnnotateResult,
weak_ptr_factory_.GetWeakPtr(), last_default_match,
last_default_associated_keyword, force_notify_default_match_changed,
default_match_to_preserve));
if (OmniboxFieldTrial::IsMlBatchUrlScoringEnabled()) {
RunBatchUrlScoringModel(base::BindOnce(
&AutocompleteController::SortCullAndAnnotateResult,
weak_ptr_factory_.GetWeakPtr(), last_default_match,
last_default_associated_keyword, force_notify_default_match_changed,
default_match_to_preserve));
} else {
RunUrlScoringModel(base::BindOnce(
&AutocompleteController::SortCullAndAnnotateResult,
weak_ptr_factory_.GetWeakPtr(), last_default_match,
last_default_associated_keyword, force_notify_default_match_changed,
default_match_to_preserve));
}
return;
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
}
Expand Down Expand Up @@ -1575,16 +1583,38 @@ void AutocompleteController::RunUrlScoringModel(
match.stripped_destination_url.spec(), barrier_callback);
}
}
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)

void AutocompleteController::CancelUrlScoringModel() {
// Try to cancel any pending requests to the scoring model and invalidate the
// WeakPtr to prevent its callbacks from being called.
scoring_model_task_tracker_.TryCancelAll();
weak_ptr_factory_.InvalidateWeakPtrs();
void AutocompleteController::RunBatchUrlScoringModel(
base::OnceClosure completion_callback) {
TRACE_EVENT0("omnibox", "AutocompleteController::RunBatchUrlScoringModel");

std::vector<const metrics::OmniboxEventProto::Suggestion::ScoringSignals*>
batch_scoring_signals;
std::vector<std::string> stripped_destination_urls;
// Run the model for the eligible matches.
for (auto& match : result_) {
if (!match.scoring_signals.has_value()) {
continue;
}
batch_scoring_signals.push_back(&match.scoring_signals.value());
stripped_destination_urls.push_back(match.stripped_destination_url.spec());
}

// If no eligible matches to score, call `completion_callback` immediately.
if (batch_scoring_signals.empty()) {
std::move(completion_callback).Run();
return;
}

provider_client_->GetAutocompleteScoringModelService()
->BatchScoreAutocompleteUrlMatches(
&scoring_model_task_tracker_, batch_scoring_signals,
stripped_destination_urls,
base::BindOnce(&AutocompleteController::OnUrlScoringModelDone,
weak_ptr_factory_.GetWeakPtr(), base::ElapsedTimer(),
std::move(completion_callback)));
}

#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
void AutocompleteController::OnUrlScoringModelDone(
const base::ElapsedTimer elapsed_timer,
base::OnceClosure completion_callback,
Expand Down Expand Up @@ -1669,3 +1699,10 @@ void AutocompleteController::OnUrlScoringModelDone(
std::move(completion_callback).Run();
}
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)

void AutocompleteController::CancelUrlScoringModel() {
// Try to cancel any pending requests to the scoring model and invalidate the
// WeakPtr to prevent its callbacks from being called.
scoring_model_task_tracker_.TryCancelAll();
weak_ptr_factory_.InvalidateWeakPtrs();
}
13 changes: 8 additions & 5 deletions components/omnibox/browser/autocomplete_controller.h
Expand Up @@ -374,13 +374,12 @@ class AutocompleteController : public AutocompleteProviderListener,
// `OnUrlScoringModelDone()` callback which is called once the model is done
// for all the eligible matches, whether successfully or not.
void RunUrlScoringModel(base::OnceClosure completion_callback);
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)

// Tries to cancel any pending requests to the scoring model and prevents
// `OnUrlScoringModelDone()` and its completion callback from being called.
void CancelUrlScoringModel();
// Runs the async batch scoring for all the eligible matches in
// `results_.matches_`. Passes `completion_callback` to
// `OnBatchUrlScoringModelDone()` callback upon finishing scoring.
void RunBatchUrlScoringModel(base::OnceClosure completion_callback);

#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Called when the async scoring model is done running for all the eligible
// matches in `results_.matches_`. Redistributes the existing relevance scores
// to the matches based on the model output (i.e. highest relevance now
Expand All @@ -392,6 +391,10 @@ class AutocompleteController : public AutocompleteProviderListener,
std::vector<AutocompleteScoringModelService::Result> results);
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)

// Tries to cancel any pending requests to the scoring model and prevents
// `OnUrlScoringModelDone()` and its completion callback from being called.
void CancelUrlScoringModel();

base::ObserverList<Observer> observers_;

// The client passed to the providers.
Expand Down
45 changes: 23 additions & 22 deletions components/omnibox/browser/autocomplete_scoring_model_service.cc
Expand Up @@ -72,33 +72,39 @@ bool AutocompleteScoringModelService::UrlScoringModelAvailable() {
void AutocompleteScoringModelService::BatchScoreAutocompleteUrlMatches(
base::CancelableTaskTracker* tracker,
const std::vector<const ScoringSignals*>& batch_scoring_signals,
const std::vector<size_t>& match_indexes,
const std::vector<GURL>& match_destination_urls,
const std::vector<std::string>& stripped_destination_urls,
BatchResultCallback batch_result_callback) {
TRACE_EVENT0(
"omnibox",
"AutocompleteScoringModelService::BatchScoreAutocompleteUrlMatches");

// Function for creating a result vector with null scores.
auto create_null_results = [&]() {
std::vector<Result> results;
for (size_t i = 0; i < batch_scoring_signals.size(); i++) {
results.emplace_back(absl::nullopt, stripped_destination_urls.at(i));
}
return results;
};

if (!UrlScoringModelAvailable()) {
std::move(batch_result_callback)
.Run(absl::nullopt, match_indexes, match_destination_urls);
std::move(batch_result_callback).Run(create_null_results());
return;
}

absl::optional<std::vector<std::vector<float>>> batch_input =
url_scoring_model_handler_->GetBatchModelInput(batch_scoring_signals);
if (!batch_input) {
std::move(batch_result_callback)
.Run(absl::nullopt, match_indexes, match_destination_urls);
std::move(batch_result_callback).Run(create_null_results());
return;
}

url_scoring_model_handler_->BatchExecuteModelWithInput(
tracker,
base::BindOnce(&AutocompleteScoringModelService::ProcessBatchModelOutput,
weak_ptr_factory_.GetWeakPtr(),
std::move(batch_result_callback), match_indexes,
match_destination_urls),
std::move(batch_result_callback),
stripped_destination_urls),
*batch_input);
}

Expand All @@ -124,23 +130,18 @@ void AutocompleteScoringModelService::ProcessModelOutput(

void AutocompleteScoringModelService::ProcessBatchModelOutput(
BatchResultCallback batch_result_callback,
const std::vector<size_t>& match_indexes,
const std::vector<GURL>& match_destination_urls,
const std::vector<
absl::optional<AutocompleteScoringModelExecutor::ModelOutput>>&
batch_model_output) {
const std::vector<std::string>& stripped_destination_urls,
const std::vector<absl::optional<ModelOutput>>& batch_model_output) {
TRACE_EVENT0("omnibox",
"AutocompleteScoringModelService::ProcessBatchModelOutput");

std::vector<absl::optional<float>> batch_output_scores;
for (const auto& output : batch_model_output) {
if (output) {
batch_output_scores.push_back(output.value()[0]);
} else {
batch_output_scores.push_back(absl::nullopt);
}
std::vector<Result> batch_results;
for (size_t i = 0; i < stripped_destination_urls.size(); i++) {
const auto& output = batch_model_output.at(i);
batch_results.emplace_back(
output ? absl::optional<float>(output->at(0)) : absl::nullopt,
stripped_destination_urls.at(i));
}

std::move(batch_result_callback)
.Run(batch_output_scores, match_indexes, match_destination_urls);
std::move(batch_result_callback).Run(std::move(batch_results));
}
32 changes: 14 additions & 18 deletions components/omnibox/browser/autocomplete_scoring_model_service.h
Expand Up @@ -24,10 +24,9 @@ class AutocompleteScoringModelService : public KeyedService {
public:
using Result = std::tuple<absl::optional<float>, std::string>;
using ResultCallback = base::OnceCallback<void(Result)>;
using BatchResultCallback = base::OnceCallback<void(
absl::optional<std::vector<absl::optional<float>>>,
std::vector<size_t>,
std::vector<GURL>)>;
using BatchResult = std::vector<Result>;
using BatchResultCallback = base::OnceCallback<void(BatchResult)>;
using ModelOutput = AutocompleteScoringModelExecutor::ModelOutput;
using ScoringSignals =
::metrics::OmniboxEventProto::Suggestion::ScoringSignals;

Expand All @@ -42,20 +41,22 @@ class AutocompleteScoringModelService : public KeyedService {
const AutocompleteScoringModelService&) = delete;

// Invokes the model to score the given `scoring_signals` and calls
// `result_callback` with an optional relevance score generated by the model.
// `result_callback` with a relevance score generated by the model.
// When the model is not available or the model input is null, calls
// `result_callback` with a null result.
void ScoreAutocompleteUrlMatch(base::CancelableTaskTracker* tracker,
const ScoringSignals& scoring_signals,
const std::string& stripped_destination_url,
ResultCallback result_callback);

// Invokes the model to scores a batch of URL candidates with their signals.
// Calls `batch_result_callback` with a batch of optional prediction scores
// from the model.
// from the model. When the model is not available or any model input is null,
// calls `result_callback` with a vector of null results.
void BatchScoreAutocompleteUrlMatches(
base::CancelableTaskTracker* tracker,
const std::vector<const ScoringSignals*>& batch_scoring_signals,
const std::vector<size_t>& match_indexes,
const std::vector<GURL>& match_destination_urls,
const std::vector<std::string>& stripped_destination_urls,
BatchResultCallback batch_result_callback);

// Returns whether the scoring model is loaded and the pointer to the
Expand All @@ -66,19 +67,14 @@ class AutocompleteScoringModelService : public KeyedService {
// Processes the model output and invokes the callback with the relevance
// score from the model output. Invokes the callback with nullopt if the model
// output is nullopt or an empty vector (which is unexpected).
void ProcessModelOutput(
ResultCallback result_callback,
const std::string& stripped_destination_url,
const absl::optional<AutocompleteScoringModelExecutor::ModelOutput>&
model_output);
void ProcessModelOutput(ResultCallback result_callback,
const std::string& stripped_destination_url,
const absl::optional<ModelOutput>& model_output);

void ProcessBatchModelOutput(
BatchResultCallback batch_result_callback,
const std::vector<size_t>& match_indexes,
const std::vector<GURL>& match_destination_urls,
const std::vector<
absl::optional<AutocompleteScoringModelExecutor::ModelOutput>>&
batch_model_output);
const std::vector<std::string>& stripped_destination_urls,
const std::vector<absl::optional<ModelOutput>>& batch_model_output);

scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner_;

Expand Down
9 changes: 9 additions & 0 deletions components/omnibox/browser/omnibox_field_trial.cc
Expand Up @@ -1053,11 +1053,17 @@ const base::FeatureParam<bool> kMlUrlScoringPreserveDefault(
"MlUrlScoringPreserveDefault",
false);

// If true, the ML model scores a batch of urls.
const base::FeatureParam<bool> kMlBatchUrlScoring(&omnibox::kMlUrlScoring,
"MlBatchUrlScoring",
false);

MLConfig::MLConfig() {
log_url_scoring_signals =
base::FeatureList::IsEnabled(omnibox::kLogUrlScoringSignals);
enable_scoring_signals_annotators = kEnableScoringSignalsAnnotators.Get();
ml_url_scoring = base::FeatureList::IsEnabled(omnibox::kMlUrlScoring);
ml_batch_url_scoring = kMlBatchUrlScoring.Get();
ml_url_scoring_counterfactual = kMlUrlScoringCounterfactual.Get();
ml_url_scoring_increase_num_candidates =
kMlUrlScoringIncreaseNumCandidates.Get();
Expand Down Expand Up @@ -1095,6 +1101,9 @@ bool IsMlUrlScoringEnabled() {
return false;
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
}
bool IsMlBatchUrlScoringEnabled() {
return IsMlUrlScoringEnabled() && GetMLConfig().ml_batch_url_scoring;
}
bool IsMlUrlScoringCounterfactual() {
return IsMlUrlScoringEnabled() && GetMLConfig().ml_url_scoring_counterfactual;
}
Expand Down
6 changes: 6 additions & 0 deletions components/omnibox/browser/omnibox_field_trial.h
Expand Up @@ -612,6 +612,9 @@ struct MLConfig {
// Equivalent to omnibox::kMlUrlScoring.
bool ml_url_scoring{false};

// If true, runs batch ML scoring of URL candidates.
bool ml_batch_url_scoring{false};

// If true, runs the ML scoring model but does not assign new relevance scores
// to the URL suggestions and does not rerank them.
// Equivalent to OmniboxFieldTrial::kMlUrlScoringCounterfactual.
Expand Down Expand Up @@ -671,6 +674,9 @@ bool AreScoringSignalsAnnotatorsEnabled();
// URL suggestions and reranks them.
bool IsMlUrlScoringEnabled();

// Whether batch ML url scoring is enabled.
bool IsMlBatchUrlScoringEnabled();

// If true, runs the ML scoring model but does not assign new relevance scores
// to URL suggestions.
bool IsMlUrlScoringCounterfactual();
Expand Down

0 comments on commit 5c84cf0

Please sign in to comment.