Skip to content

Commit

Permalink
[WebNN EP] Support WebNN async API with Asyncify (#19145)
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Jan 24, 2024
1 parent c456f19 commit 7252c6e
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 49 deletions.
4 changes: 0 additions & 4 deletions js/web/lib/build-def.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ interface BuildDefinitions {
/**
* defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WEBNN: boolean;
/**
* defines whether to disable the whole WebAssembly backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
* defines whether to disable proxy feature in WebAssembly backend in the build.
Expand Down
4 changes: 1 addition & 3 deletions js/web/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
if (!BUILD_DEFS.DISABLE_WEBNN) {
registerBackend('webnn', wasmBackend, 9);
}
}

Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
2 changes: 1 addition & 1 deletion js/web/lib/wasm/binding/ort-wasm.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {

_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;

_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
Expand Down Expand Up @@ -228,7 +228,7 @@ export const createSession = async(
await Promise.all(loadingPromises);
}

sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
Expand Down
7 changes: 1 addition & 6 deletions js/web/script/build.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WEBGL': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
'BUILD_DEFS.DISABLE_WEBNN': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
Expand Down Expand Up @@ -364,7 +363,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand Down Expand Up @@ -397,7 +395,7 @@ async function main() {
// ort.webgpu[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.webgpu',
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
});
// ort.wasm[.min].js
await addAllWebBuildTasks({
Expand All @@ -411,7 +409,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WASM': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
// ort.wasm-core[.min].js
Expand All @@ -421,7 +418,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
Expand All @@ -434,7 +430,6 @@ async function main() {
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
}
Expand Down
4 changes: 0 additions & 4 deletions js/web/script/test-runner-cli-args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs

const globalEnvFlags = parseGlobalEnvFlags(args);

if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}

// Options:
// --log-verbose=<...>
// --log-info=<...>
Expand Down
35 changes: 14 additions & 21 deletions onnxruntime/core/providers/webnn/builders/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
"The input of graph has unsupported type, name: ",
name, " type: ", tensor.tensor_info.data_type);
}
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers.
// Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer.
// As Wasm ArrayBuffer is not detachable.
wnn_inputs_[name].call<void>("set", view);
#else
wnn_inputs_.set(name, view);
#endif
}

#ifdef ENABLE_WEBASSEMBLY_THREADS
// This vector uses for recording output buffers from WebNN graph compution when WebAssembly
// multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView,
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory
// address is different from the non-shared one, additional memory copy is required here.
InlinedHashMap<std::string, emscripten::val> output_views;
#endif

for (const auto& output : outputs) {
const std::string& name = output.first;
const struct OnnxTensorData tensor = output.second;
Expand Down Expand Up @@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
name, " type: ", tensor.tensor_info.data_type);
}

#ifdef ENABLE_WEBASSEMBLY_THREADS
output_views.insert({name, view});
#else
wnn_outputs_.set(name, view);
#endif
}
wnn_context_.call<emscripten::val>("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_);
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer.
emscripten::val results = wnn_context_.call<emscripten::val>(
"compute", wnn_graph_, wnn_inputs_, wnn_outputs_)
.await();

// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
for (const auto& output : outputs) {
const std::string& name = output.first;
emscripten::val view = output_views.at(name);
view.call<void>("set", wnn_outputs_[name]);
view.call<void>("set", results["outputs"][name]);
}
#endif
// WebNN compute() method would return the input and output buffers via the promise
// resolution. Reuse the buffers to avoid additional allocation.
wnn_inputs_ = results["inputs"];
wnn_outputs_ = results["outputs"];

return Status::OK();
}

Expand Down
12 changes: 5 additions & 7 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
for (auto& name : output_names_) {
named_operands.set(name, wnn_operands_.at(name));
}
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("buildSync", named_operands);

emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
Expand All @@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
model->SetOutputs(std::move(output_names_));
model->SetScalarOutputs(std::move(scalar_outputs_));
model->SetInputOutputInfo(std::move(input_output_info_));
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Pre-allocate the input and output tensors for the WebNN graph
// when WebAssembly multi-threads is enabled since WebNN API only
// accepts non-shared ArrayBufferView.
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
// for inputs and outputs because they will be transferred after compute() done.
// https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
model->AllocateInputOutputBuffers();
#endif
return Status::OK();
}

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webnn/webnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (webnn_power_flags.compare("default") != 0) {
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
}
wnn_context_ = ml.call<emscripten::val>("createContextSync", context_options);

wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/wasm/js_internal_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
};

// replace the original functions with asyncified versions
Module['_OrtCreateSession'] = jsepWrapAsync(
Module['_OrtCreateSession'],
() => Module['_OrtCreateSession'],
v => Module['_OrtCreateSession'] = v);
Module['_OrtRun'] = runAsync(jsepWrapAsync(
Module['_OrtRun'],
() => Module['_OrtRun'],
Expand Down

0 comments on commit 7252c6e

Please sign in to comment.