From f73dd122b4494b9993a98961c120f36467ab2cf0 Mon Sep 17 00:00:00 2001 From: Du Li Date: Sun, 6 Jun 2021 21:16:10 -0700 Subject: [PATCH 1/3] Adding webgl shape kernel --- .../onnxjs/backends/webgl/op-resolve-rules.ts | 2 ++ js/web/lib/onnxjs/backends/webgl/ops/shape.ts | 14 ++++++++++ js/web/test/data/ops/shape.jsonc | 26 +++++++++++++++++++ 3 files changed, 42 insertions(+) create mode 100644 js/web/lib/onnxjs/backends/webgl/ops/shape.ts create mode 100644 js/web/test/data/ops/shape.jsonc diff --git a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts index 6dd11c91b9022..78460768637d0 100644 --- a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts @@ -24,6 +24,7 @@ import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPo import * as reduceOps from './ops/reduce'; import {WebGLReshape} from './ops/reshape'; import {WebGLResizePacked} from './ops/resize-packed'; +import {WebGLShape} from './ops/shape'; import {WebGLSlice, WebGLSliceV10} from './ops/slice'; import {WebGLSoftmax} from './ops/softmax'; import {WebGLSplit} from './ops/split'; @@ -89,6 +90,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Reshape', '', '5+', () => new WebGLReshape()], ['Resize', '', '10', () => new WebGLResizePacked(10)], ['Resize', '', '11+', () => new WebGLResizePacked(11)], + ['Shape', '', '1+', () => new WebGLShape()], ['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())], ['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())], ['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10 diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts new file mode 100644 index 0000000000000..22c1d08254bc1 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Shape} from '../../../ops/shape'; +import {Tensor} from '../../../tensor'; +import {WebGLInferenceHandler} from '../inference-handler'; + + +export class WebGLShape extends Shape { + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { + const shape = inputs[0].dims.slice(0); + return [new Tensor([shape.length], 'int32', undefined, undefined, new Int32Array(shape))]; + } +} diff --git a/js/web/test/data/ops/shape.jsonc b/js/web/test/data/ops/shape.jsonc new file mode 100644 index 0000000000000..eb8026134035c --- /dev/null +++ b/js/web/test/data/ops/shape.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "Reshape with '0' and '-1' in the shape tensor input", + "operator": "Shape", + "attributes": [], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 1, 1, 1], + "dims": [2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [2, 2], + "dims": [2], + "type": "int32" + } + ] + } + ] + } +] From 83818166e05305c56bede62c62aa71ded6cc70d0 Mon Sep 17 00:00:00 2001 From: Du Li Date: Mon, 7 Jun 2021 15:30:10 -0700 Subject: [PATCH 2/3] updating operators.md --- js/web/docs/operators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/docs/operators.md b/js/web/docs/operators.md index 212937df995d4..1e9e6eef87b8b 100644 --- a/js/web/docs/operators.md +++ b/js/web/docs/operators.md @@ -139,7 +139,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SequenceErase](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceErase) | | | [SequenceInsert](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceInsert) | | | [SequenceLength](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceLength) | | -| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | | +| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-1), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-13) | | [Shrink](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shrink) | | | [Sigmoid](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-13) | | [Sign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sign) | | From 45980552cbd84207547477f42d4f58c1ad8ac0ae Mon Sep 17 00:00:00 2001 From: Du Li Date: Mon, 7 Jun 2021 18:05:08 -0700 Subject: [PATCH 3/3] addring PR comments --- js/web/lib/onnxjs/backends/webgl/ops/shape.ts | 3 +-- js/web/test/data/ops/shape.jsonc | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts index 22c1d08254bc1..538cb2601602f 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -8,7 +8,6 @@ import {WebGLInferenceHandler} from '../inference-handler'; export class WebGLShape extends Shape { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { - const shape = inputs[0].dims.slice(0); - return [new Tensor([shape.length], 'int32', undefined, undefined, new Int32Array(shape))]; + return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; } } diff --git a/js/web/test/data/ops/shape.jsonc b/js/web/test/data/ops/shape.jsonc index eb8026134035c..c7595b5de1876 100644 --- a/js/web/test/data/ops/shape.jsonc +++ b/js/web/test/data/ops/shape.jsonc @@ -1,6 +1,6 @@ [ { - "name": "Reshape with '0' and '-1' in the shape tensor input", + "name": "Shape op test", "operator": "Shape", "attributes": [], "cases": [