WebNN: decouple context creation from graph building.
This supports creating GPU-backed WebNN resources for graph builds.

More specifically, this changelist:
* Splits the Mojo-backed context out of MLContext into MLContextMojo, following the MLGraphMojo pattern.
* Makes createContext set up the IPC connection upfront (see the sketch below).
* Re-baselines WPT expectations, since context creation now fails when unsupported.
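
A minimal, self-contained C++ sketch of the decoupling described above. It is illustrative only (none of these class names exist in Chromium): the context owns the connection to the ML service and is created, or fails, up front; graph building then runs against an already-live context rather than creating one lazily.

```cpp
#include <iostream>
#include <memory>
#include <optional>
#include <string>

// Stand-in for the renderer<->service connection that createContext() now
// establishes before any graph is built.
class ServiceConnection {
 public:
  explicit ServiceConnection(std::string device) : device_(std::move(device)) {}
  const std::string& device() const { return device_; }

 private:
  std::string device_;
};

class Context {
 public:
  // Creation fails (returns nullopt) when the requested device is
  // unsupported, which is why the WPT expectations had to be re-baselined.
  static std::optional<Context> Create(const std::string& device) {
    if (device != "cpu" && device != "gpu") {
      return std::nullopt;
    }
    return Context(std::make_shared<ServiceConnection>(device));
  }

  // Graph building reuses the connection already owned by the context; it no
  // longer creates the context as a side effect.
  std::string BuildGraph(const std::string& graph_name) const {
    return graph_name + " built on " + connection_->device();
  }

 private:
  explicit Context(std::shared_ptr<ServiceConnection> connection)
      : connection_(std::move(connection)) {}

  std::shared_ptr<ServiceConnection> connection_;
};

int main() {
  std::optional<Context> context = Context::Create("gpu");
  if (!context) {
    std::cerr << "context creation failed\n";
    return 1;
  }
  std::cout << context->BuildGraph("conv2d") << "\n";
  return 0;
}
```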

Bug: 1273291
Change-Id: I662a7cee68ba5e32c1e60c40f164b0975e21b087
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4814277
Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1211760}
bbernhar authored and Chromium LUCI CQ committed Oct 18, 2023
1 parent 5c9b864 commit 5685851
Showing 34 changed files with 662 additions and 195 deletions.
8 changes: 6 additions & 2 deletions third_party/blink/renderer/modules/ml/BUILD.gn
@@ -71,10 +71,14 @@ blink_modules_sources("ml") {

if (!is_chromeos) {
sources += [
# MLGraphMojo is platform independent which packages the model information
# into a neutral struct type, it will be built and executed with hardware
# Platform-independent sources that package the model information into
# neutral struct types; graphs will be built and executed with the hardware-
# accelerated OS machine learning APIs in the WebNN service, which runs out
# of the renderer process.
"webnn/ml_context_mojo.cc",
"webnn/ml_context_mojo.h",
"webnn/ml_error_mojo.cc",
"webnn/ml_error_mojo.h",
"webnn/ml_graph_mojo.cc",
"webnn/ml_graph_mojo.h",
"webnn/ml_graph_type_converter.cc",
54 changes: 49 additions & 5 deletions third_party/blink/renderer/modules/ml/ml.cc
@@ -5,11 +5,18 @@
#include "third_party/blink/renderer/modules/ml/ml.h"

#include "components/ml/mojom/web_platform_model.mojom-blink.h"
#include "components/ml/webnn/features.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
#include "third_party/blink/renderer/core/dom/dom_exception.h"
#include "third_party/blink/renderer/modules/ml/buildflags.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"

#if !BUILDFLAG(IS_CHROMEOS)
#include "third_party/blink/public/common/features.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_context_mojo.h"
#endif

namespace blink {

namespace {
@@ -66,6 +73,26 @@ ScriptPromise ML::createContext(ScriptState* script_state,

auto promise = resolver->Promise();

// TODO(crbug.com/1273291): Support async context creation for all contexts.
#if BUILDFLAG(BUILD_WEBNN_WITH_XNNPACK) || BUILDFLAG(BUILD_WEBNN_ON_CROS)
if (option->devicePreference() == V8MLDevicePreference::Enum::kAuto ||
option->devicePreference() == V8MLDevicePreference::Enum::kCpu) {
auto* ml_context = MakeGarbageCollected<MLContext>(
option->devicePreference(), option->powerPreference(),
option->modelFormat(), option->numThreads(), this);
resolver->Resolve(ml_context);
return promise;
}
#endif

#if !BUILDFLAG(IS_CHROMEOS)
if (base::FeatureList::IsEnabled(
webnn::features::kEnableMachineLearningNeuralNetworkService)) {
MLContextMojo::ValidateAndCreateAsync(resolver, option, this);
return promise;
}
#endif

// Note that currently we just create the context in the renderer. In the
// future we may add the ability to query the backend about whether a context
// is supported. At that point, this function will become truly asynchronous.
@@ -86,11 +113,28 @@ MLContext* ML::createContextSync(ScriptState* script_state,
return nullptr;
}

// TODO(crbug/1405354): Query browser about whether the given context is
// supported.
return MakeGarbageCollected<MLContext>(
options->devicePreference(), options->powerPreference(),
options->modelFormat(), options->numThreads(), this);
// TODO(crbug.com/1273291): Support sync context creation for all contexts.
#if BUILDFLAG(BUILD_WEBNN_WITH_XNNPACK) || BUILDFLAG(BUILD_WEBNN_ON_CROS)
if (options->devicePreference() == V8MLDevicePreference::Enum::kAuto ||
options->devicePreference() == V8MLDevicePreference::Enum::kCpu) {
return MLContext::ValidateAndCreateSync(options, this);
}
#endif

#if !BUILDFLAG(IS_CHROMEOS)
// The runtime-enabled feature is used to disable cross-process hardware
// acceleration by default.
if (base::FeatureList::IsEnabled(
webnn::features::kEnableMachineLearningNeuralNetworkService) &&
(options->devicePreference() == V8MLDevicePreference::Enum::kAuto ||
options->devicePreference() == V8MLDevicePreference::Enum::kGpu)) {
return MLContextMojo::ValidateAndCreateSync(exception_state, options, this);
}
#endif

// TODO(crbug.com/1273291): Throw an exception once tests support all context
// types.
return MLContext::ValidateAndCreateSync(options, this);
}

void ML::EnsureModelLoaderServiceConnection(ScriptState* script_state) {
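
As a rough, standalone model of the dispatch that ML::createContext() and ML::createContextSync() perform above (illustrative names only, and ignoring the BUILD_WEBNN_WITH_XNNPACK / IS_CHROMEOS build-flag differences): "auto" and "cpu" requests are satisfied by the in-renderer context, "gpu" requests go to the out-of-process WebNN service when the runtime feature is enabled, and everything else falls back to the default renderer-side context.

```cpp
#include <iostream>

// Illustrative enums; the real code uses V8MLDevicePreference and build flags.
enum class DevicePreference { kAuto, kCpu, kGpu };
enum class Backend { kInRendererCpu, kWebNNService, kDefaultFallback };

// Simplified approximation of the routing in ML::createContextSync() above.
Backend ChooseBackend(DevicePreference preference,
                      bool service_feature_enabled) {
  if (preference == DevicePreference::kAuto ||
      preference == DevicePreference::kCpu) {
    return Backend::kInRendererCpu;
  }
  if (service_feature_enabled) {
    return Backend::kWebNNService;
  }
  return Backend::kDefaultFallback;
}

int main() {
  bool enabled = true;
  std::cout << (ChooseBackend(DevicePreference::kGpu, enabled) ==
                        Backend::kWebNNService
                    ? "gpu -> WebNN service\n"
                    : "gpu -> fallback\n");
  return 0;
}
```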
81 changes: 26 additions & 55 deletions third_party/blink/renderer/modules/ml/ml_context.cc
@@ -5,27 +5,21 @@
#include "third_party/blink/renderer/modules/ml/ml_context.h"

#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
#include "third_party/blink/renderer/core/dom/dom_exception.h"
#include "third_party/blink/renderer/modules/ml/ml.h"
#include "third_party/blink/renderer/modules/ml/ml_model_loader.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"

namespace blink {

namespace {

namespace blink_mojom = webnn::mojom::blink;

template <typename MojoResultType>
mojo::StructPtr<MojoResultType> ToError(
const blink_mojom::Error::Code& error_code,
const WTF::String& error_message) {
return MojoResultType::NewError(
blink_mojom::Error::New(error_code, error_message));
// static
MLContext* MLContext::ValidateAndCreateSync(MLContextOptions* options, ML* ml) {
return MakeGarbageCollected<MLContext>(
options->devicePreference(), options->powerPreference(),
options->modelFormat(), options->numThreads(), ml);
}

} // namespace

MLContext::MLContext(const V8MLDevicePreference device_preference,
const V8MLPowerPreference power_preference,
const V8MLModelFormat model_format,
@@ -35,8 +29,7 @@ MLContext::MLContext(const V8MLDevicePreference device_preference,
power_preference_(power_preference),
model_format_(model_format),
num_threads_(num_threads),
ml_(ml),
webnn_context_(ml->GetExecutionContext()) {}
ml_(ml) {}

MLContext::~MLContext() = default;

@@ -82,7 +75,6 @@ MLModelLoader* MLContext::GetModelLoaderForWebNN(ScriptState* script_state) {
void MLContext::Trace(Visitor* visitor) const {
visitor->Trace(ml_);
visitor->Trace(ml_model_loader_);
visitor->Trace(webnn_context_);

ScriptWrappable::Trace(visitor);
}
@@ -126,50 +118,29 @@ void MLContext::computeSync(MLGraph* graph,
graph->ComputeSync(inputs, outputs, exception_state);
}

void MLContext::CreateWebNNGraph(
ScriptState* script_state,
blink_mojom::GraphInfoPtr graph_info,
blink_mojom::WebNNContext::CreateGraphCallback callback) {
if (!webnn_context_.is_bound()) {
// Needs to create `WebNNContext` interface first.
auto options = blink_mojom::CreateContextOptions::New();
// TODO(crbug.com/1273291): Set power preference in the context option.
ml_->CreateWebNNContext(
std::move(options),
WTF::BindOnce(&MLContext::OnCreateWebNNContext, WrapPersistent(this),
WrapPersistent(script_state), std::move(graph_info),
std::move(callback)));
} else {
// Directly use `WebNNContext` to create `WebNNGraph` message pipe.
webnn_context_->CreateGraph(std::move(graph_info),
WTF::BindOnce(std::move(callback)));
}
void MLContext::CreateAsync(ScriptPromiseResolver* resolver,
MLContextOptions* options) {
CreateAsyncImpl(resolver, options);
}

void MLContext::OnCreateWebNNContext(
ScriptState* script_state,
blink_mojom::GraphInfoPtr graph_info,
blink_mojom::WebNNContext::CreateGraphCallback callback,
blink_mojom::CreateContextResultPtr result) {
if (!script_state->ContextIsValid()) {
std::move(callback).Run(ToError<blink_mojom::CreateGraphResult>(
blink_mojom::Error::Code::kUnknownError, "Invalid script state."));
return;
}

if (result->is_error()) {
std::move(callback).Run(blink_mojom::CreateGraphResult::NewError(
std::move(result->get_error())));
return;
}
MLContext* MLContext::CreateSync(MLContextOptions* options,
ExceptionState& exception_state) {
return CreateSyncImpl(options, exception_state);
}

auto* execution_context = ExecutionContext::From(script_state);
webnn_context_.Bind(
std::move(result->get_context_remote()),
execution_context->GetTaskRunner(TaskType::kInternalDefault));
void MLContext::CreateAsyncImpl(ScriptPromiseResolver* resolver,
MLContextOptions* options) {
// TODO(crbug.com/1273291): Remove when async creation gets implemented for
// all context types.
NOTIMPLEMENTED();
}

webnn_context_->CreateGraph(std::move(graph_info),
WTF::BindOnce(std::move(callback)));
MLContext* MLContext::CreateSyncImpl(MLContextOptions* options,
ExceptionState& exception_state) {
// TODO(crbug.com/1273291): Remove when sync creation gets implemented for
// all context types.
NOTIMPLEMENTED();
return nullptr;
}

} // namespace blink
46 changes: 29 additions & 17 deletions third_party/blink/renderer/modules/ml/ml_context.h
@@ -16,17 +16,21 @@
#include "third_party/blink/renderer/platform/bindings/script_wrappable.h"
#include "third_party/blink/renderer/platform/heap/member.h"
#include "third_party/blink/renderer/platform/heap/visitor.h"
#include "third_party/blink/renderer/platform/mojo/heap_mojo_remote.h"

namespace blink {

class ML;
class MLContextOptions;
class MLModelLoader;

class MODULES_EXPORT MLContext final : public ScriptWrappable {
class MODULES_EXPORT MLContext : public ScriptWrappable {
DEFINE_WRAPPERTYPEINFO();

public:
static MLContext* ValidateAndCreateSync(MLContextOptions* options, ML* ml);

// The constructor shouldn't be called directly. Callers should use the
// CreateAsync() or CreateSync() methods instead.
MLContext(const V8MLDevicePreference device_preference,
const V8MLPowerPreference power_preference,
const V8MLModelFormat model_format,
@@ -63,10 +67,29 @@ class MODULES_EXPORT MLContext final : public ScriptWrappable {
const MLNamedArrayBufferViews& outputs,
ExceptionState& exception_state);

void CreateWebNNGraph(
ScriptState* script_state,
webnn::mojom::blink::GraphInfoPtr graph_info,
webnn::mojom::blink::WebNNContext::CreateGraphCallback callback);
protected:
// Creates and initializes an MLContext object. Resolves the promise with
// the concrete object if the underlying context is created successfully.
void CreateAsync(ScriptPromiseResolver* resolver, MLContextOptions* options);

// An MLContext backend should implement this method to create and initialize
// a platform-specific context asynchronously.
virtual void CreateAsyncImpl(ScriptPromiseResolver* resolver,
MLContextOptions* options);

// CreateSync() serves the same purpose as CreateAsync(). The difference is
// that, if there are no validation errors, it calls CreateSyncImpl(),
// implemented by an MLContext backend, to initialize the context synchronously
// on the caller's thread. This method is called by ML to implement the
// ML.createContextSync() method.
MLContext* CreateSync(MLContextOptions* options,
ExceptionState& exception_state);

// An MLContext backend should implement this method to initialize the
// platform-specific context synchronously on the caller's thread.
virtual MLContext* CreateSyncImpl(MLContextOptions* options,
ExceptionState& exception_state);

private:
V8MLDevicePreference device_preference_;
@@ -77,17 +100,6 @@ class MODULES_EXPORT MLContext final : public ScriptWrappable {
Member<ML> ml_;
// WebNN uses this MLModelLoader to build a computational graph.
Member<MLModelLoader> ml_model_loader_;

// The callback of creating context called from WebNN server side.
void OnCreateWebNNContext(
ScriptState* script_state,
webnn::mojom::blink::GraphInfoPtr graph_info,
webnn::mojom::blink::WebNNContext::CreateGraphCallback callback,
webnn::mojom::blink::CreateContextResultPtr result);
// WebNN support multiple types of neural network inference hardware
// acceleration, the context of WebNN in server side is used to map different
// device and represent a state of graph execution processes.
HeapMojoRemote<webnn::mojom::blink::WebNNContext> webnn_context_;
};

} // namespace blink
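
The new protected CreateAsync()/CreateSync() entry points with virtual CreateAsyncImpl()/CreateSyncImpl() follow the non-virtual interface pattern. Below is a self-contained sketch of that pattern under assumed, illustrative names; it is not the Blink code (the real Mojo-backed backend is MLContextMojo).

```cpp
#include <functional>
#include <iostream>
#include <utility>

class BaseContext {
 public:
  virtual ~BaseContext() = default;

  // Non-virtual entry point: shared validation lives here, then creation is
  // delegated to the backend-specific implementation.
  void CreateAsync(std::function<void(bool)> on_done) {
    // (Shared option validation would happen here.)
    CreateAsyncImpl(std::move(on_done));
  }

 protected:
  // Default backend: nothing asynchronous to set up, so succeed immediately.
  virtual void CreateAsyncImpl(std::function<void(bool)> on_done) {
    on_done(true);
  }
};

// A backend that, like a Mojo-backed WebNN context, must reach a service
// before the context is usable.
class ServiceBackedContext : public BaseContext {
 protected:
  void CreateAsyncImpl(std::function<void(bool)> on_done) override {
    on_done(ConnectToService());
  }

 private:
  bool ConnectToService() { return true; }  // Stand-in for the IPC setup.
};

int main() {
  ServiceBackedContext context;
  context.CreateAsync([](bool ok) {
    std::cout << (ok ? "context ready\n" : "context creation failed\n");
  });
  return 0;
}
```

This split lets ML resolve or reject the createContext() promise once the backend-specific step completes, while staying agnostic of which backend ran.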
13 changes: 9 additions & 4 deletions third_party/blink/renderer/modules/ml/ml_model_loader_test.cc
@@ -11,6 +11,7 @@
#include "third_party/blink/renderer/bindings/core/v8/script_promise_tester.h"
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
#include "third_party/blink/renderer/bindings/core/v8/v8_dom_exception.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_data_type.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_model.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_power_preference.h"
@@ -158,10 +159,14 @@ class MLModelLoaderTest : public testing::Test {

MLModelLoader* CreateTestLoader(V8TestingScope& scope) {
ML* ml = MakeGarbageCollected<ML>(scope.GetExecutionContext());
MLContext* ml_context = MakeGarbageCollected<MLContext>(
V8MLDevicePreference(V8MLDevicePreference::Enum::kCpu),
V8MLPowerPreference(V8MLPowerPreference::Enum::kAuto),
V8MLModelFormat(V8MLModelFormat::Enum::kTflite), 1, ml);

MLContextOptions* options = MLContextOptions::Create();
options->setDevicePreference(V8MLDevicePreference::Enum::kCpu);
options->setPowerPreference(V8MLPowerPreference::Enum::kAuto);
options->setModelFormat(V8MLModelFormat::Enum::kTflite);

MLContext* ml_context = ml->createContextSync(
scope.GetScriptState(), options, scope.GetExceptionState());
return MLModelLoader::Create(scope.GetScriptState(), ml_context,
scope.GetExceptionState());
}
@@ -3,6 +3,7 @@
// found in the LICENSE file.

#include "testing/libfuzzer/proto/lpm_interface.h"
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_core.h"
#include "third_party/blink/renderer/core/frame/local_dom_window.h"
#include "third_party/blink/renderer/core/frame/settings.h"
#include "third_party/blink/renderer/core/testing/dummy_page_holder.h"
@@ -70,10 +71,15 @@ DEFINE_PROTO_FUZZER(const webnn_proto::conv2d& conv2d) {
return page_holder.release();
}();

auto* builder = CreateMLGraphBuilder(
page_holder->GetFrame().DomWindow()->GetExecutionContext());
ScriptState* script_state =
ToScriptStateForMainWorld(&page_holder->GetFrame());

DummyExceptionStateForTesting exception_state;
auto* builder = CreateMLGraphBuilder(
page_holder->GetFrame().DomWindow()->GetExecutionContext(), script_state,
exception_state);
CHECK(builder);

auto* input =
BuildInput(builder, "input", Vector<uint32_t>(conv2d.input_dimensions()),
ToV8MLOperandType(conv2d.input_type()), exception_state);
@@ -70,10 +70,15 @@ DEFINE_PROTO_FUZZER(const webnn_proto::pool2d& pool2d) {
return page_holder.release();
}();

auto* builder = CreateMLGraphBuilder(
page_holder->GetFrame().DomWindow()->GetExecutionContext());
ScriptState* script_state =
ToScriptStateForMainWorld(&page_holder->GetFrame());

DummyExceptionStateForTesting exception_state;
auto* builder = CreateMLGraphBuilder(
page_holder->GetFrame().DomWindow()->GetExecutionContext(), script_state,
exception_state);
CHECK(builder);

auto* input =
BuildInput(builder, "input", Vector<uint32_t>(pool2d.input_dimensions()),
ToV8MLOperandType(pool2d.input_type()), exception_state);