Skip to content

Commit

Permalink
[js/webgpu] Add HardSigmoid support (#19215)
Browse files Browse the repository at this point in the history
### Description
This op is required in mobilenetv3-small-100. With this PR,
mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on
ADL.
  • Loading branch information
qjia7 committed Jan 22, 2024
1 parent e283cdb commit 2e0a388
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 3 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
Expand Down
20 changes: 20 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};

export interface HardSigmoidAttributes extends AttributeWithCacheKey {
readonly alpha: number;
readonly beta: number;
}

export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
createAttributeWithCacheKey(attributes as {
alpha: number;
beta: number;
});

export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'HardSigmoid',
a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
attributes.beta})))`,
undefined, attributes.cacheKey));
};

export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
};
Expand Down
6 changes: 3 additions & 3 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,9 @@
// // "test_hardmax_example",
// // "test_hardmax_negative_axis",
// // "test_hardmax_one_hot",
// // "test_hardsigmoid_default",
// // "test_hardsigmoid_example",
// // "test_hardsigmoid",
"test_hardsigmoid_default",
"test_hardsigmoid_example",
"test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log);

Expand Down Expand Up @@ -392,6 +393,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(13, Erf),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid),
KERNEL_CREATE_INFO(13, Sigmoid),
KERNEL_CREATE_INFO(6, HardSigmoid),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Log),
KERNEL_CREATE_INFO(13, Log),

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/js/operators/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)

JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5)
JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid)

JSEP_KERNEL_IMPL(Log, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)
Expand Down

0 comments on commit 2e0a388

Please sign in to comment.