Skip to content

Commit

Permalink
Use on-device model service in opt guide
Browse files Browse the repository at this point in the history
Add a controller to use the on-device service. Also adds commandline switch for overriding model path.

Bug: b/302402576
Change-Id: Ife417eea6f3090bcf514d0a99f4e7f5e2341c1b5
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4950733
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Reviewed-by: Clark DuVall <cduvall@chromium.org>
Commit-Queue: Raj T <rajendrant@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1215846}
  • Loading branch information
rajendrant authored and Chromium LUCI CQ committed Oct 27, 2023
1 parent 6878592 commit 82e89b6
Show file tree
Hide file tree
Showing 15 changed files with 395 additions and 1 deletion.
3 changes: 3 additions & 0 deletions chrome/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,8 @@ static_library("browser") {
"optimization_guide/chrome_browser_main_extra_parts_optimization_guide.h",
"optimization_guide/chrome_hints_manager.cc",
"optimization_guide/chrome_hints_manager.h",
"optimization_guide/model_execution/chrome_on_device_model_service_controller.cc",
"optimization_guide/model_execution/chrome_on_device_model_service_controller.h",
"optimization_guide/model_validator_keyed_service.cc",
"optimization_guide/model_validator_keyed_service.h",
"optimization_guide/model_validator_keyed_service_factory.cc",
Expand Down Expand Up @@ -2545,6 +2547,7 @@ static_library("browser") {
"//services/network/public/mojom",
"//services/network/public/proto",
"//services/on_device_model/public/cpp",
"//services/on_device_model/public/mojom",
"//services/preferences/public/cpp",
"//services/preferences/public/cpp/tracked",
"//services/preferences/public/mojom",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/optimization_guide/model_execution/chrome_on_device_model_service_controller.h"

#include "content/public/browser/service_process_host.h"

namespace optimization_guide {

ChromeOnDeviceModelServiceController::ChromeOnDeviceModelServiceController() =
default;
ChromeOnDeviceModelServiceController::~ChromeOnDeviceModelServiceController() =
default;

void ChromeOnDeviceModelServiceController::LaunchService() {
content::ServiceProcessHost::Launch<
on_device_model::mojom::OnDeviceModelService>(
service_remote_.BindNewPipeAndPassReceiver(),
content::ServiceProcessHost::Options()
.WithDisplayName("On-Device Model Service")
.Pass());
}

} // namespace optimization_guide
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_OPTIMIZATION_GUIDE_MODEL_EXECUTION_CHROME_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_
#define CHROME_BROWSER_OPTIMIZATION_GUIDE_MODEL_EXECUTION_CHROME_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_

#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"

namespace optimization_guide {

class ChromeOnDeviceModelServiceController
: public OnDeviceModelServiceController {
public:
ChromeOnDeviceModelServiceController();
~ChromeOnDeviceModelServiceController() override;

private:
// OnDeviceModelServiceController implementation:
void LaunchService() override;
};

} // namespace optimization_guide

#endif // CHROME_BROWSER_OPTIMIZATION_GUIDE_MODEL_EXECUTION_CHROME_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "chrome/browser/download/background_download_service_factory.h"
#include "chrome/browser/metrics/chrome_metrics_service_accessor.h"
#include "chrome/browser/optimization_guide/chrome_hints_manager.h"
#include "chrome/browser/optimization_guide/model_execution/chrome_on_device_model_service_controller.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/profiles/profile_key.h"
Expand All @@ -30,6 +31,7 @@
#include "components/optimization_guide/core/command_line_top_host_provider.h"
#include "components/optimization_guide/core/hints_processing_util.h"
#include "components/optimization_guide/core/model_execution/model_execution_manager.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
Expand Down Expand Up @@ -268,6 +270,8 @@ void OptimizationGuideKeyedService::Initialize() {
model_execution_manager_ =
std::make_unique<optimization_guide::ModelExecutionManager>(
url_loader_factory, IdentityManagerFactory::GetForProfile(profile),
std::make_unique<
optimization_guide::ChromeOnDeviceModelServiceController>(),
optimization_guide_logger_.get());
}

Expand Down
6 changes: 6 additions & 0 deletions components/optimization_guide/core/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ static_library("core") {
"model_execution/model_execution_fetcher.h",
"model_execution/model_execution_manager.cc",
"model_execution/model_execution_manager.h",
"model_execution/on_device_model_service_controller.cc",
"model_execution/on_device_model_service_controller.h",
"model_execution/on_device_model_stream_receiver.cc",
"model_execution/on_device_model_stream_receiver.h",
"model_execution/optimization_guide_model_execution_error.cc",
"model_execution/optimization_guide_model_execution_error.h",
"model_handler.h",
Expand Down Expand Up @@ -291,6 +295,7 @@ static_library("core") {
"//services/metrics/public/cpp:metrics_cpp",
"//services/metrics/public/cpp:ukm_builders",
"//services/network/public/cpp",
"//services/on_device_model/public/mojom",
"//url:url",
]

Expand Down Expand Up @@ -419,6 +424,7 @@ source_set("unit_tests") {
"hints_processing_util_unittest.cc",
"insertion_ordered_set_unittest.cc",
"model_execution/model_execution_fetcher_unittest.cc",
"model_execution/on_device_model_service_controller_unittest.cc",
"model_handler_unittest.cc",
"model_store_metadata_entry_unittest.cc",
"model_util_unittest.cc",
Expand Down
1 change: 1 addition & 0 deletions components/optimization_guide/core/DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include_rules = [
"+crypto",
"+mojo/public/cpp",
"+services/metrics/public/cpp",
"+services/on_device_model/public",
"+third_party/mediapipe",
"+third_party/tensorflow_models/src",
"+third_party/zlib/google",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

#include "base/command_line.h"
#include "components/optimization_guide/core/model_execution/model_execution_fetcher.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/on_device_model_stream_receiver.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "net/base/url_util.h"
Expand Down Expand Up @@ -35,6 +38,8 @@ using ModelExecutionError =
ModelExecutionManager::ModelExecutionManager(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
signin::IdentityManager* identity_manager,
std::unique_ptr<OnDeviceModelServiceController>
on_device_model_service_controller,
OptimizationGuideLogger* optimization_guide_logger)
: optimization_guide_logger_(optimization_guide_logger),
model_execution_service_url_(net::AppendOrReplaceQueryParameter(
Expand All @@ -43,7 +48,18 @@ ModelExecutionManager::ModelExecutionManager(
features::GetOptimizationGuideServiceAPIKey())),
url_loader_factory_(url_loader_factory),
identity_manager_(identity_manager),
oauth_scopes_(features::GetOAuthScopesForModelExecution()) {}
oauth_scopes_(features::GetOAuthScopesForModelExecution()),
on_device_model_service_controller_(
std::move(on_device_model_service_controller)) {
auto model_path_override_switch =
switches::GetOnDeviceModelExecutionOverride();
if (model_path_override_switch) {
auto file_path = StringToFilePath(*model_path_override_switch);
if (file_path) {
on_device_model_path_ = *file_path;
}
}
}

ModelExecutionManager::~ModelExecutionManager() = default;

Expand All @@ -52,6 +68,7 @@ void ModelExecutionManager::ExecuteModel(
const google::protobuf::MessageLite& request_metadata,
OptimizationGuideModelExecutionResultCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

if (active_model_execution_fetchers_.find(feature) !=
active_model_execution_fetchers_.end()) {
std::move(callback).Run(base::unexpected(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <map>

#include "base/files/file_path.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
Expand All @@ -28,12 +29,15 @@ class IdentityManager;
namespace optimization_guide {

class ModelExecutionFetcher;
class OnDeviceModelServiceController;

class ModelExecutionManager {
public:
ModelExecutionManager(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
signin::IdentityManager* identity_manager,
std::unique_ptr<OnDeviceModelServiceController>
on_device_model_service_controller,
OptimizationGuideLogger* optimization_guide_logger);

~ModelExecutionManager();
Expand Down Expand Up @@ -73,6 +77,14 @@ class ModelExecutionManager {
// The set of OAuth scopes to use for requesting access token.
std::set<std::string> oauth_scopes_;

// Controller for the on-device service.
std::unique_ptr<OnDeviceModelServiceController>
on_device_model_service_controller_;

// The path for the on-device model. Can be empty when it was not populated
// yet. Can be overridden from command-line.
base::FilePath on_device_model_path_;

SEQUENCE_CHECKER(sequence_checker_);

// Used to get `weak_ptr_` to self.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "services/on_device_model/public/cpp/model_assets.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"

namespace optimization_guide {

OnDeviceModelServiceController::OnDeviceModelServiceController() = default;
OnDeviceModelServiceController::~OnDeviceModelServiceController() = default;

void OnDeviceModelServiceController::Init(const base::FilePath& model_path) {
CHECK(model_path_.empty());
model_path_ = model_path;
}

void OnDeviceModelServiceController::Execute(
std::string_view input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
streaming_responder) {
if (model_remote_) {
model_remote_->Execute(std::string(input), std::move(streaming_responder));
return;
}
LaunchService();
service_remote_->LoadModel(
on_device_model::LoadModelAssets(model_path_),
base::BindOnce(&OnDeviceModelServiceController::OnLoadModelResult,
weak_ptr_factory_.GetWeakPtr(), input,
std::move(streaming_responder)));
}

void OnDeviceModelServiceController::OnLoadModelResult(
std::string_view input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
streaming_responder,
on_device_model::mojom::LoadModelResultPtr result) {
if (result->is_model()) {
model_remote_ = mojo::Remote<on_device_model::mojom::OnDeviceModel>(
std::move(result->get_model()));
Execute(input, std::move(streaming_responder));
}
}

} // namespace optimization_guide
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_

#include <string_view>

#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

namespace base {
class FilePath;
} // namespace base

namespace optimization_guide {

// Controls the lifetime of the on-device model service, loading and unloading
// of the models, and executing them via the service.
//
// TODO(b/302402576): Handle unloading the model, and stopping the service. The
// StreamingResponder should notify the controller upon completion to accomplish
// this. Also handle multiple requests gracefully and fail the subsequent
// requests, while handling the first one.
class OnDeviceModelServiceController {
public:
OnDeviceModelServiceController();
virtual ~OnDeviceModelServiceController();

// Initializes the on-device model controller with the parameters, to be ready
// to load models and execute.
void Init(const base::FilePath& model_path);

// Executes the model for `input` and the response will be sent to
// `streaming_responder`. This will load the model if needed, before
// execution.
void Execute(std::string_view input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
streaming_responder);

// Launches the on-device model-service.
virtual void LaunchService() = 0;

private:
friend class ChromeOnDeviceModelServiceController;
friend class OnDeviceModelServiceControllerTest;
friend class FakeOnDeviceModelServiceController;

// Invoked at the end of model load, to continue with model execution.
void OnLoadModelResult(
std::string_view input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder>
streaming_responder,
on_device_model::mojom::LoadModelResultPtr result);

base::FilePath model_path_;
mojo::Remote<on_device_model::mojom::OnDeviceModelService> service_remote_;
mojo::Remote<on_device_model::mojom::OnDeviceModel> model_remote_;

SEQUENCE_CHECKER(sequence_checker_);

// Used to get `weak_ptr_` to self.
base::WeakPtrFactory<OnDeviceModelServiceController> weak_ptr_factory_{this};
};

} // namespace optimization_guide

#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_MODEL_SERVICE_CONTROLLER_H_

0 comments on commit 82e89b6

Please sign in to comment.