Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add opt guide model for speculation-rule preloading heuristics.
In this CL a new `PreloadingModelHandler` class and other required classes for implementing preloading heuristics ML model are introduced. Bug: 1471245 Change-Id: I77e6aad2ba4611e43e7d80f4cc77b427abf3758d Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4762904 Reviewed-by: Dominic Farolino <dom@chromium.org> Commit-Queue: Iman Saboori <isaboori@google.com> Reviewed-by: Ryan Sturm <ryansturm@chromium.org> Reviewed-by: Sophie Chang <sophiechang@chromium.org> Reviewed-by: Tommy Nyquist <nyquist@chromium.org> Cr-Commit-Position: refs/heads/main@{#1182546}
- Loading branch information
Showing
18 changed files
with
427 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
chrome/browser/navigation_predictor/preloading_model_executor.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_executor.h" | ||
|
||
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h" | ||
|
||
PreloadingModelExecutor::PreloadingModelExecutor() = default; | ||
PreloadingModelExecutor::~PreloadingModelExecutor() = default; | ||
|
||
bool PreloadingModelExecutor::Preprocess( | ||
const std::vector<TfLiteTensor*>& input_tensors, | ||
const std::vector<float>& input) { | ||
return tflite::task::core::PopulateTensor<float>(input, input_tensors[0]) | ||
.ok(); | ||
} | ||
|
||
absl::optional<float> PreloadingModelExecutor::Postprocess( | ||
const std::vector<const TfLiteTensor*>& output_tensors) { | ||
std::vector<float> output; | ||
if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &output) | ||
.ok()) { | ||
return absl::nullopt; | ||
} | ||
|
||
CHECK_EQ(1u, output.size()); | ||
return output[0]; | ||
} |
32 changes: 32 additions & 0 deletions
32
chrome/browser/navigation_predictor/preloading_model_executor.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_ | ||
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_ | ||
|
||
#include "components/optimization_guide/core/base_model_executor.h" | ||
|
||
// A model executor to run the history clusters module ranking model. | ||
class PreloadingModelExecutor | ||
: public optimization_guide::BaseModelExecutor<float, | ||
const std::vector<float>&> { | ||
public: | ||
using ModelInput = const std::vector<float>&; | ||
using ModelOutput = float; | ||
|
||
PreloadingModelExecutor(); | ||
~PreloadingModelExecutor() override; | ||
|
||
PreloadingModelExecutor(const PreloadingModelExecutor&) = delete; | ||
PreloadingModelExecutor& operator=(const PreloadingModelExecutor&) = delete; | ||
|
||
protected: | ||
// optimization_guide::BaseModelExecutor: | ||
bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors, | ||
ModelInput input) override; | ||
absl::optional<ModelOutput> Postprocess( | ||
const std::vector<const TfLiteTensor*>& output_tensors) override; | ||
}; | ||
|
||
#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_ |
89 changes: 89 additions & 0 deletions
89
chrome/browser/navigation_predictor/preloading_model_executor_unittest.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_executor.h" | ||
|
||
#include "base/base_paths.h" | ||
#include "base/path_service.h" | ||
#include "base/task/sequenced_task_runner.h" | ||
#include "base/test/scoped_feature_list.h" | ||
#include "base/test/task_environment.h" | ||
#include "testing/gtest/include/gtest/gtest.h" | ||
#include "third_party/blink/public/common/features.h" | ||
|
||
using ModelInput = PreloadingModelExecutor::ModelInput; | ||
using ModelOutput = PreloadingModelExecutor::ModelOutput; | ||
|
||
class PreloadingModelExecutorTest : public testing::Test { | ||
public: | ||
PreloadingModelExecutorTest() { | ||
scoped_feature_list_.InitAndEnableFeature( | ||
blink::features::kPreloadingHeuristicsMLModel); | ||
} | ||
~PreloadingModelExecutorTest() override = default; | ||
|
||
void SetUp() override { | ||
base::FilePath source_root_dir; | ||
base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir); | ||
// A model of `add` operator. | ||
model_file_path_ = source_root_dir.AppendASCII("chrome") | ||
.AppendASCII("browser") | ||
.AppendASCII("navigation_predictor") | ||
.AppendASCII("test") | ||
.AppendASCII("preloading_heuristics.tflite"); | ||
execution_task_runner_ = base::ThreadPool::CreateSequencedTaskRunner( | ||
{base::MayBlock(), base::TaskPriority::BEST_EFFORT}); | ||
model_executor_ = std::make_unique<PreloadingModelExecutor>(); | ||
model_executor_->InitializeAndMoveToExecutionThread( | ||
/*model_inference_timeout=*/absl::nullopt, | ||
optimization_guide::proto::OPTIMIZATION_TARGET_OMNIBOX_URL_SCORING, | ||
execution_task_runner_, base::SequencedTaskRunner::GetCurrentDefault()); | ||
} | ||
|
||
void TearDown() override { | ||
// Destroy model executor. | ||
execution_task_runner_->DeleteSoon(FROM_HERE, std::move(model_executor_)); | ||
RunUntilIdle(); | ||
} | ||
|
||
void RunUntilIdle() { task_environment_.RunUntilIdle(); } | ||
|
||
protected: | ||
base::test::ScopedFeatureList scoped_feature_list_; | ||
base::test::TaskEnvironment task_environment_; | ||
base::FilePath model_file_path_; | ||
scoped_refptr<base::SequencedTaskRunner> execution_task_runner_; | ||
std::unique_ptr<PreloadingModelExecutor> model_executor_; | ||
}; | ||
|
||
TEST_F(PreloadingModelExecutorTest, ExecuteModel) { | ||
// Update model file. | ||
execution_task_runner_->PostTask( | ||
FROM_HERE, | ||
base::BindOnce( | ||
&optimization_guide::ModelExecutor<ModelOutput, | ||
ModelInput>::UpdateModelFile, | ||
model_executor_->GetWeakPtrForExecutionThread(), model_file_path_)); | ||
|
||
// Execute model. | ||
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>(); | ||
base::OnceCallback<void(const absl::optional<ModelOutput>&)> | ||
execution_callback = base::BindOnce( | ||
[](base::RunLoop* run_loop, | ||
const absl::optional<ModelOutput>& output) { | ||
ASSERT_TRUE(output.has_value()); | ||
// TODO(isaboori): After the trained model is approved, use | ||
// realistic inputs and check the output value. | ||
run_loop->Quit(); | ||
}, | ||
run_loop.get()); | ||
base::TimeTicks now = base::TimeTicks::Now(); | ||
ModelInput input = std::vector<float>(/*count=*/17, /*value=*/0.0); | ||
execution_task_runner_->PostTask( | ||
FROM_HERE, base::BindOnce(&optimization_guide::ModelExecutor< | ||
ModelOutput, ModelInput>::SendForExecution, | ||
model_executor_->GetWeakPtrForExecutionThread(), | ||
std::move(execution_callback), now, input)); | ||
run_loop->Run(); | ||
} |
24 changes: 24 additions & 0 deletions
24
chrome/browser/navigation_predictor/preloading_model_handler.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_handler.h" | ||
|
||
#include "base/task/thread_pool.h" | ||
#include "chrome/browser/navigation_predictor/preloading_model_executor.h" | ||
|
||
PreloadingModelHandler::PreloadingModelHandler( | ||
optimization_guide::OptimizationGuideModelProvider* model_provider) | ||
: ModelHandler<float, const std::vector<float>&>( | ||
model_provider, | ||
base::ThreadPool::CreateSequencedTaskRunner( | ||
{base::MayBlock(), base::TaskPriority::USER_VISIBLE}), | ||
std::make_unique<PreloadingModelExecutor>(), | ||
/*model_inference_timeout=*/absl::nullopt, | ||
optimization_guide::proto::OptimizationTarget:: | ||
OPTIMIZATION_TARGET_PRELOADING_HEURISTICS, | ||
/*model_metadata=*/absl::nullopt) { | ||
SetShouldUnloadModelOnComplete(false); | ||
} | ||
|
||
PreloadingModelHandler::~PreloadingModelHandler() = default; |
25 changes: 25 additions & 0 deletions
25
chrome/browser/navigation_predictor/preloading_model_handler.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_ | ||
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_ | ||
|
||
#include "components/optimization_guide/core/model_handler.h" | ||
|
||
// Model handler used to retrieve and eventually execute the model. | ||
class PreloadingModelHandler | ||
: public optimization_guide::ModelHandler<float, | ||
const std::vector<float>&> { | ||
public: | ||
explicit PreloadingModelHandler( | ||
optimization_guide::OptimizationGuideModelProvider* model_provider); | ||
~PreloadingModelHandler() override; | ||
PreloadingModelHandler(const PreloadingModelHandler&) = delete; | ||
PreloadingModelHandler& operator=(const PreloadingModelHandler&) = delete; | ||
|
||
private: | ||
base::WeakPtrFactory<PreloadingModelHandler> weak_ptr_factory_{this}; | ||
}; | ||
|
||
#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_ |
20 changes: 20 additions & 0 deletions
20
chrome/browser/navigation_predictor/preloading_model_keyed_service.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h" | ||
|
||
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h" | ||
|
||
PreloadingModelKeyedService::PreloadingModelKeyedService( | ||
OptimizationGuideKeyedService* optimization_guide_keyed_service) { | ||
auto* model_provider = | ||
static_cast<optimization_guide::OptimizationGuideModelProvider*>( | ||
optimization_guide_keyed_service); | ||
|
||
if (model_provider) { | ||
preloading_model_handler_ = | ||
std::make_unique<PreloadingModelHandler>(model_provider); | ||
} | ||
} | ||
PreloadingModelKeyedService::~PreloadingModelKeyedService() = default; |
29 changes: 29 additions & 0 deletions
29
chrome/browser/navigation_predictor/preloading_model_keyed_service.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_ | ||
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_ | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_handler.h" | ||
#include "components/keyed_service/core/keyed_service.h" | ||
|
||
class OptimizationGuideKeyedService; | ||
|
||
class PreloadingModelKeyedService : public KeyedService { | ||
public: | ||
PreloadingModelKeyedService(const PreloadingModelKeyedService&) = delete; | ||
explicit PreloadingModelKeyedService( | ||
OptimizationGuideKeyedService* optimization_guide_keyed_service); | ||
~PreloadingModelKeyedService() override; | ||
|
||
PreloadingModelHandler* GetPreloadingModel() const { | ||
return preloading_model_handler_.get(); | ||
} | ||
|
||
private: | ||
// preloading ML model | ||
std::unique_ptr<PreloadingModelHandler> preloading_model_handler_; | ||
}; | ||
|
||
#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_ |
58 changes: 58 additions & 0 deletions
58
chrome/browser/navigation_predictor/preloading_model_keyed_service_browsertest.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright 2023 The Chromium Authors | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h" | ||
|
||
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service_factory.h" | ||
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h" | ||
#include "chrome/browser/profiles/profile.h" | ||
#include "chrome/browser/ui/browser.h" | ||
#include "chrome/test/base/in_process_browser_test.h" | ||
#include "content/public/test/browser_test.h" | ||
#include "testing/gtest/include/gtest/gtest.h" | ||
#include "third_party/blink/public/common/features.h" | ||
|
||
namespace { | ||
|
||
class PreloadingModelKeyedServiceTest | ||
: public InProcessBrowserTest, | ||
public testing::WithParamInterface<bool> { | ||
public: | ||
PreloadingModelKeyedServiceTest() { | ||
bool is_enabled = GetParam(); | ||
if (is_enabled) { | ||
scoped_feature_list_.InitAndEnableFeature( | ||
blink::features::kPreloadingHeuristicsMLModel); | ||
} | ||
} | ||
~PreloadingModelKeyedServiceTest() override = default; | ||
|
||
content::WebContents* GetWebContents() { | ||
return browser()->tab_strip_model()->GetActiveWebContents(); | ||
} | ||
|
||
private: | ||
base::test::ScopedFeatureList scoped_feature_list_; | ||
}; | ||
|
||
IN_PROC_BROWSER_TEST_P(PreloadingModelKeyedServiceTest, FeatureFlagIsWorking) { | ||
Profile* profile = | ||
Profile::FromBrowserContext(GetWebContents()->GetBrowserContext()); | ||
ASSERT_TRUE(OptimizationGuideKeyedServiceFactory::GetForProfile(profile)); | ||
|
||
PreloadingModelKeyedService* model_service = | ||
PreloadingModelKeyedServiceFactory::GetForProfile(profile); | ||
bool is_enabled = GetParam(); | ||
if (is_enabled) { | ||
EXPECT_TRUE(model_service); | ||
} else { | ||
EXPECT_FALSE(model_service); | ||
} | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(ParametrizedTests, | ||
PreloadingModelKeyedServiceTest, | ||
testing::Values(true, false)); | ||
|
||
} // namespace |
Oops, something went wrong.