diff --git a/js/web/docs/operators.md b/js/web/docs/operators.md index 212937df995d4..1e9e6eef87b8b 100644 --- a/js/web/docs/operators.md +++ b/js/web/docs/operators.md @@ -139,7 +139,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SequenceErase](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceErase) | | | [SequenceInsert](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceInsert) | | | [SequenceLength](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SequenceLength) | | -| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | | +| [Shape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-1), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Shape-13) | | [Shrink](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shrink) | | | [Sigmoid](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Sigmoid-13) | | [Sign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sign) | | diff --git a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts index 6dd11c91b9022..78460768637d0 100644 --- a/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts +++ b/js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts @@ -24,6 +24,7 @@ import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPo import * as reduceOps from './ops/reduce'; import {WebGLReshape} from './ops/reshape'; import {WebGLResizePacked} from './ops/resize-packed'; +import {WebGLShape} from './ops/shape'; import {WebGLSlice, WebGLSliceV10} from './ops/slice'; import {WebGLSoftmax} from './ops/softmax'; import {WebGLSplit} from './ops/split'; @@ -89,6 +90,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [ ['Reshape', '', '5+', () => new WebGLReshape()], ['Resize', '', '10', () => new WebGLResizePacked(10)], ['Resize', '', '11+', () => new WebGLResizePacked(11)], + ['Shape', '', '1+', () => new WebGLShape()], ['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())], ['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())], ['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10 diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts new file mode 100644 index 0000000000000..538cb2601602f --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {Shape} from '../../../ops/shape'; +import {Tensor} from '../../../tensor'; +import {WebGLInferenceHandler} from '../inference-handler'; + + +export class WebGLShape extends Shape { + run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { + return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; + } +} diff --git a/js/web/test/data/ops/shape.jsonc b/js/web/test/data/ops/shape.jsonc new file mode 100644 index 0000000000000..c7595b5de1876 --- /dev/null +++ b/js/web/test/data/ops/shape.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "Shape op test", + "operator": "Shape", + "attributes": [], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 1, 1, 1], + "dims": [2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [2, 2], + "dims": [2], + "type": "int32" + } + ] + } + ] + } +]