Skip to content

Commit

Permalink
Add opt guide model for speculation-rule preloading heuristics.
Browse files Browse the repository at this point in the history
In this CL a new `PreloadingModelHandler` class and other required
classes for implementing preloading heuristics ML model are introduced.

Bug: 1471245
Change-Id: I77e6aad2ba4611e43e7d80f4cc77b427abf3758d
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4762904
Reviewed-by: Dominic Farolino <dom@chromium.org>
Commit-Queue: Iman Saboori <isaboori@google.com>
Reviewed-by: Ryan Sturm <ryansturm@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: Tommy Nyquist <nyquist@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1182546}
  • Loading branch information
isaboori authored and Chromium LUCI CQ committed Aug 11, 2023
1 parent bb69be7 commit 7852d02
Show file tree
Hide file tree
Showing 18 changed files with 427 additions and 1 deletion.
8 changes: 8 additions & 0 deletions chrome/browser/BUILD.gn
Expand Up @@ -868,6 +868,14 @@ static_library("browser") {
"navigation_predictor/navigation_predictor_metrics_document_data.h",
"navigation_predictor/navigation_predictor_preconnect_client.cc",
"navigation_predictor/navigation_predictor_preconnect_client.h",
"navigation_predictor/preloading_model_executor.cc",
"navigation_predictor/preloading_model_executor.h",
"navigation_predictor/preloading_model_handler.cc",
"navigation_predictor/preloading_model_handler.h",
"navigation_predictor/preloading_model_keyed_service.cc",
"navigation_predictor/preloading_model_keyed_service.h",
"navigation_predictor/preloading_model_keyed_service_factory.cc",
"navigation_predictor/preloading_model_keyed_service_factory.h",
"navigation_predictor/search_engine_preconnector.cc",
"navigation_predictor/search_engine_preconnector.h",
"net/chrome_mojo_proxy_resolver_factory.cc",
Expand Down
10 changes: 10 additions & 0 deletions chrome/browser/navigation_predictor/navigation_predictor.cc
Expand Up @@ -15,6 +15,8 @@
#include "base/system/sys_info.h"
#include "chrome/browser/navigation_predictor/navigation_predictor_keyed_service.h"
#include "chrome/browser/navigation_predictor/navigation_predictor_keyed_service_factory.h"
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h"
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service_factory.h"
#include "chrome/browser/preloading/prefetch/no_state_prefetch/no_state_prefetch_manager_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "components/no_state_prefetch/browser/no_state_prefetch_manager.h"
Expand Down Expand Up @@ -180,6 +182,14 @@ void NavigationPredictor::ReportNewAnchorElements(
kAnchorElementsParsedFromWebPage,
new_predictions);
}

PreloadingModelKeyedService* model_service =
PreloadingModelKeyedServiceFactory::GetForProfile(
Profile::FromBrowserContext(render_frame_host().GetBrowserContext()));
if (!model_service) {
return;
}
// TODO(isaboori): use the ML model to predict the next use click.
}

void NavigationPredictor::ReportAnchorElementClick(
Expand Down
29 changes: 29 additions & 0 deletions chrome/browser/navigation_predictor/preloading_model_executor.cc
@@ -0,0 +1,29 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_executor.h"

#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"

PreloadingModelExecutor::PreloadingModelExecutor() = default;
PreloadingModelExecutor::~PreloadingModelExecutor() = default;

bool PreloadingModelExecutor::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors,
const std::vector<float>& input) {
return tflite::task::core::PopulateTensor<float>(input, input_tensors[0])
.ok();
}

absl::optional<float> PreloadingModelExecutor::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) {
std::vector<float> output;
if (!tflite::task::core::PopulateVector<float>(output_tensors[0], &output)
.ok()) {
return absl::nullopt;
}

CHECK_EQ(1u, output.size());
return output[0];
}
32 changes: 32 additions & 0 deletions chrome/browser/navigation_predictor/preloading_model_executor.h
@@ -0,0 +1,32 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_

#include "components/optimization_guide/core/base_model_executor.h"

// A model executor to run the history clusters module ranking model.
class PreloadingModelExecutor
: public optimization_guide::BaseModelExecutor<float,
const std::vector<float>&> {
public:
using ModelInput = const std::vector<float>&;
using ModelOutput = float;

PreloadingModelExecutor();
~PreloadingModelExecutor() override;

PreloadingModelExecutor(const PreloadingModelExecutor&) = delete;
PreloadingModelExecutor& operator=(const PreloadingModelExecutor&) = delete;

protected:
// optimization_guide::BaseModelExecutor:
bool Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
ModelInput input) override;
absl::optional<ModelOutput> Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) override;
};

#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_EXECUTOR_H_
@@ -0,0 +1,89 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_executor.h"

#include "base/base_paths.h"
#include "base/path_service.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features.h"

using ModelInput = PreloadingModelExecutor::ModelInput;
using ModelOutput = PreloadingModelExecutor::ModelOutput;

class PreloadingModelExecutorTest : public testing::Test {
public:
PreloadingModelExecutorTest() {
scoped_feature_list_.InitAndEnableFeature(
blink::features::kPreloadingHeuristicsMLModel);
}
~PreloadingModelExecutorTest() override = default;

void SetUp() override {
base::FilePath source_root_dir;
base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
// A model of `add` operator.
model_file_path_ = source_root_dir.AppendASCII("chrome")
.AppendASCII("browser")
.AppendASCII("navigation_predictor")
.AppendASCII("test")
.AppendASCII("preloading_heuristics.tflite");
execution_task_runner_ = base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT});
model_executor_ = std::make_unique<PreloadingModelExecutor>();
model_executor_->InitializeAndMoveToExecutionThread(
/*model_inference_timeout=*/absl::nullopt,
optimization_guide::proto::OPTIMIZATION_TARGET_OMNIBOX_URL_SCORING,
execution_task_runner_, base::SequencedTaskRunner::GetCurrentDefault());
}

void TearDown() override {
// Destroy model executor.
execution_task_runner_->DeleteSoon(FROM_HERE, std::move(model_executor_));
RunUntilIdle();
}

void RunUntilIdle() { task_environment_.RunUntilIdle(); }

protected:
base::test::ScopedFeatureList scoped_feature_list_;
base::test::TaskEnvironment task_environment_;
base::FilePath model_file_path_;
scoped_refptr<base::SequencedTaskRunner> execution_task_runner_;
std::unique_ptr<PreloadingModelExecutor> model_executor_;
};

TEST_F(PreloadingModelExecutorTest, ExecuteModel) {
// Update model file.
execution_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&optimization_guide::ModelExecutor<ModelOutput,
ModelInput>::UpdateModelFile,
model_executor_->GetWeakPtrForExecutionThread(), model_file_path_));

// Execute model.
std::unique_ptr<base::RunLoop> run_loop = std::make_unique<base::RunLoop>();
base::OnceCallback<void(const absl::optional<ModelOutput>&)>
execution_callback = base::BindOnce(
[](base::RunLoop* run_loop,
const absl::optional<ModelOutput>& output) {
ASSERT_TRUE(output.has_value());
// TODO(isaboori): After the trained model is approved, use
// realistic inputs and check the output value.
run_loop->Quit();
},
run_loop.get());
base::TimeTicks now = base::TimeTicks::Now();
ModelInput input = std::vector<float>(/*count=*/17, /*value=*/0.0);
execution_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&optimization_guide::ModelExecutor<
ModelOutput, ModelInput>::SendForExecution,
model_executor_->GetWeakPtrForExecutionThread(),
std::move(execution_callback), now, input));
run_loop->Run();
}
24 changes: 24 additions & 0 deletions chrome/browser/navigation_predictor/preloading_model_handler.cc
@@ -0,0 +1,24 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_handler.h"

#include "base/task/thread_pool.h"
#include "chrome/browser/navigation_predictor/preloading_model_executor.h"

PreloadingModelHandler::PreloadingModelHandler(
optimization_guide::OptimizationGuideModelProvider* model_provider)
: ModelHandler<float, const std::vector<float>&>(
model_provider,
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::USER_VISIBLE}),
std::make_unique<PreloadingModelExecutor>(),
/*model_inference_timeout=*/absl::nullopt,
optimization_guide::proto::OptimizationTarget::
OPTIMIZATION_TARGET_PRELOADING_HEURISTICS,
/*model_metadata=*/absl::nullopt) {
SetShouldUnloadModelOnComplete(false);
}

PreloadingModelHandler::~PreloadingModelHandler() = default;
25 changes: 25 additions & 0 deletions chrome/browser/navigation_predictor/preloading_model_handler.h
@@ -0,0 +1,25 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_

#include "components/optimization_guide/core/model_handler.h"

// Model handler used to retrieve and eventually execute the model.
class PreloadingModelHandler
: public optimization_guide::ModelHandler<float,
const std::vector<float>&> {
public:
explicit PreloadingModelHandler(
optimization_guide::OptimizationGuideModelProvider* model_provider);
~PreloadingModelHandler() override;
PreloadingModelHandler(const PreloadingModelHandler&) = delete;
PreloadingModelHandler& operator=(const PreloadingModelHandler&) = delete;

private:
base::WeakPtrFactory<PreloadingModelHandler> weak_ptr_factory_{this};
};

#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_HANDLER_H_
@@ -0,0 +1,20 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h"

#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"

PreloadingModelKeyedService::PreloadingModelKeyedService(
OptimizationGuideKeyedService* optimization_guide_keyed_service) {
auto* model_provider =
static_cast<optimization_guide::OptimizationGuideModelProvider*>(
optimization_guide_keyed_service);

if (model_provider) {
preloading_model_handler_ =
std::make_unique<PreloadingModelHandler>(model_provider);
}
}
PreloadingModelKeyedService::~PreloadingModelKeyedService() = default;
@@ -0,0 +1,29 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_
#define CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_

#include "chrome/browser/navigation_predictor/preloading_model_handler.h"
#include "components/keyed_service/core/keyed_service.h"

class OptimizationGuideKeyedService;

class PreloadingModelKeyedService : public KeyedService {
public:
PreloadingModelKeyedService(const PreloadingModelKeyedService&) = delete;
explicit PreloadingModelKeyedService(
OptimizationGuideKeyedService* optimization_guide_keyed_service);
~PreloadingModelKeyedService() override;

PreloadingModelHandler* GetPreloadingModel() const {
return preloading_model_handler_.get();
}

private:
// preloading ML model
std::unique_ptr<PreloadingModelHandler> preloading_model_handler_;
};

#endif // CHROME_BROWSER_NAVIGATION_PREDICTOR_PRELOADING_MODEL_KEYED_SERVICE_H_
@@ -0,0 +1,58 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h"

#include "chrome/browser/navigation_predictor/preloading_model_keyed_service_factory.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/test/base/in_process_browser_test.h"
#include "content/public/test/browser_test.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features.h"

namespace {

class PreloadingModelKeyedServiceTest
: public InProcessBrowserTest,
public testing::WithParamInterface<bool> {
public:
PreloadingModelKeyedServiceTest() {
bool is_enabled = GetParam();
if (is_enabled) {
scoped_feature_list_.InitAndEnableFeature(
blink::features::kPreloadingHeuristicsMLModel);
}
}
~PreloadingModelKeyedServiceTest() override = default;

content::WebContents* GetWebContents() {
return browser()->tab_strip_model()->GetActiveWebContents();
}

private:
base::test::ScopedFeatureList scoped_feature_list_;
};

IN_PROC_BROWSER_TEST_P(PreloadingModelKeyedServiceTest, FeatureFlagIsWorking) {
Profile* profile =
Profile::FromBrowserContext(GetWebContents()->GetBrowserContext());
ASSERT_TRUE(OptimizationGuideKeyedServiceFactory::GetForProfile(profile));

PreloadingModelKeyedService* model_service =
PreloadingModelKeyedServiceFactory::GetForProfile(profile);
bool is_enabled = GetParam();
if (is_enabled) {
EXPECT_TRUE(model_service);
} else {
EXPECT_FALSE(model_service);
}
}

INSTANTIATE_TEST_SUITE_P(ParametrizedTests,
PreloadingModelKeyedServiceTest,
testing::Values(true, false));

} // namespace

0 comments on commit 7852d02

Please sign in to comment.