diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td index d69e215e05205..61e8ea2c9ebc8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td @@ -524,8 +524,7 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure, let description = [{ Performs a max pooling over the given input tensor. A sliding window of size given by is passed over the input tensor, with the - maximum value being placed in the - output tensor. + maximum value being placed in the output tensor. References: * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_max_pool2d @@ -574,7 +573,7 @@ def SPIRV_TosaRFFT2DOp : SPIRV_TosaOpWithComplexResult<"RFFT2D", 8, [Pure]> { Performs a batched 2D real-valued Fast Fourier Transform over the input where the input tensor consists of real values producing complex valued output. The complex output values will be split into the output_real and output_imag - tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only + tensor arguments. This operator takes advantage of Hermitian symmetry to only calculate the first half of the final output axis. Implementations may choose to skip calculation of the imaginary values at (0,0), (0,W/2), (H/2,0), and (H/2, W/2). If the calculation is skipped, the result at that location must be @@ -694,4 +693,174 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [ } +def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure, + AllTypesMatch<["input", "output"]>, + AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> { + let summary = "Computes Clamp(min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. + Maximum and minimum values are specified as values in the range of the + input type. + No zero point subtraction is done to the values, thus to clamp to the zero + point value, the zero point itself should be supplied as the minimum value. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_clamp + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_clamp + + #### Example: + ```mlir + %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = , %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32> + ``` + }]; + + let arguments = (ins + SPIRV_TosaNumericalAttr: $min_val, + SPIRV_TosaNumericalAttr: $max_val, + SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode, + SPIRV_TosaNumerical_TensorArm: $input + ); + + let results = (outs + SPIRV_TosaNumerical_TensorArm: $output + ); + + let assemblyFormat = [{ + `min_val` `=` $min_val `,` + `max_val` `=` $max_val `,` + `nan_mode` `=` $nan_mode `,` + $input + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + }]; +} + + +def SPIRV_TosaErfOp : SPIRV_TosaOpWithResult<"Erf", 11, [Pure, + AllTypesMatch<["input", "output"]>]> { + let summary = "Gauss Error Function."; + + let description = [{ + Gauss Error Function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt $ + For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_erf + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_erf + + #### Example: + ```mlir + %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32> + ``` + }]; + + let arguments = (ins + SPIRV_TosaFloat_TensorArm: $input + ); + + let results = (outs + SPIRV_TosaFloat_TensorArm: $output + ); + + let assemblyFormat = [{ + $input + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + }]; +} + + +def SPIRV_TosaSigmoidOp : SPIRV_TosaOpWithResult<"Sigmoid", 12, [Pure, + AllTypesMatch<["input", "output"]>]> { + let summary = "Sigmoid operator."; + + let description = [{ + Applies the sigmoid logistic function to each element of the input tensor: + $ sigmoid(x) = \frac{1}{1 + e^{-x}} $. + + For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_sigmoid + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_sigmoid + + #### Example: + ```mlir + %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32> + ``` + }]; + + let arguments = (ins + SPIRV_TosaFloat_TensorArm: $input + ); + + let results = (outs + SPIRV_TosaFloat_TensorArm: $output + ); + + let assemblyFormat = [{ + $input + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + }]; +} + + +def SPIRV_TosaTanhOp : SPIRV_TosaOpWithResult<"Tanh", 13, [Pure, + AllTypesMatch<["input", "output"]>]> { + let summary = "Hyperbolic Tangent operator."; + + let description = [{ + Elementwise Parameterized Hyperbolic Tangent: $ tanh(x) = \frac{1 - e^{-2x}}{1 + e^{-2x}} $. + + For quantized integer data types, the `spirv.Tosa.Table` operator should be used instead. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_tanh + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_tanh + + #### Example: + ```mlir + %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16> + ``` + }]; + + let arguments = (ins + SPIRV_TosaFloat_TensorArm: $input + ); + + let results = (outs + SPIRV_TosaFloat_TensorArm: $output + ); + + let assemblyFormat = [{ + $input + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + }]; +} + + #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td index db4ad8064fc11..5fe3bc53618f4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td @@ -23,6 +23,7 @@ def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>; def SPIRV_TensorArmAxisAttr : ConfinedAttr]>; def SPIRV_BoolConstAttr : ConfinedAttr; +def SPIRV_TosaNumericalAttr: AnyAttrOf<[I8Attr, I16Attr, I32Attr, I64Attr, F16Attr, F32Attr, BF16Attr]>; // TensorARM Types @@ -44,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4] def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>; def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>; +def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>; def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>; class Is1DTensorArmOfLength allowedLengths> : diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index 8ac1a2ea21422..ba6cf55a8fb9e 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -334,9 +334,17 @@ class FloatAttrBase : let returnType = [{ ::llvm::APFloat }]; } +def F16Attr : FloatAttrBase; def F32Attr : FloatAttrBase; def F64Attr : FloatAttrBase; +def BF16Attr : TypedAttrBase($_self)">, + CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">]>, + "16-bit bfloat attribute"> { + let returnType = [{ ::llvm::APFloat }]; +} + // An attribute backed by a string type. class StringBasedAttr : Attr { let constBuilderCall = "$_builder.getStringAttr($0)"; diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir index 56cd6d6900fdb..dd18a3a2ae788 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir @@ -410,8 +410,24 @@ spirv.ARM.Graph @matmul_invalid_input_output_element_type_combination(%arg0: !sp // spirv.TOSA.MaxPool2D //===----------------------------------------------------------------------===// -spirv.ARM.Graph @maxpool2d_int(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) { +spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi16>) { // expected-error @+1 {{op failed to verify that all of {input, output} have same element type}} %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = , %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32769x1xi16> spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Clamp +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @clamp_min_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) { + // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}} + %3 = spirv.Tosa.Clamp min_val = -102 : i16, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8> +} + +spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) { + // expected-error @+1 {{op failed to verify that all of {input, output, min_val, max_val} have same element type}} + %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i16, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8> +} diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir index 1a43e2c95c530..a9f7bc2b8ef7d 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir @@ -229,3 +229,58 @@ spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %a // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16> spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Clamp - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) { + // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Clamp - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = , %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32> + %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = , %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Erf - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32> + %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Sigmoid - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32> + %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Tanh - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) { + // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16> + %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16> +} diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir index 1d219b855bec1..9f2ff1c31cbc5 100644 --- a/mlir/test/Target/SPIRV/tosa-ops.mlir +++ b/mlir/test/Target/SPIRV/tosa-ops.mlir @@ -396,3 +396,98 @@ spirv.module Logical Vulkan requires #spirv.vce } } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Clamp - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @clamp_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @clamp_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @clamp_int, @clamp_int_arg_0, @clamp_int_res_0 + spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) { + // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = , %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Clamp - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @clamp_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @clamp_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @clamp_fp, @clamp_fp_arg_0, @clamp_fp_res_0 + spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = , %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32> + %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = , %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32> + spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Erf - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @erf_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @erf_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @erf_fp, @erf_fp_arg_0, @erf_fp_res_0 + spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32> + %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Sigmoid - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @sigmoid_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @sigmoid_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @sigmoid_fp, @sigmoid_fp_arg_0, @sigmoid_fp_res_0 + spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) { + // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32> + %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Tanh - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @tanh_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @tanh_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @tanh_fp, @tanh_fp_arg_0, @tanh_fp_res_0 + spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) { + // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16> + %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16> + } +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 9cb48934b2c10..34edd4df49d8e 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -558,9 +558,10 @@ static void emitAttributeSerialization(const Attribute &attr, os << tabs << " return failure();\n"; os << tabs << " }\n"; os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList); - } else if (llvm::is_contained( - {"SPIRV_BoolConstAttr", "SPIRV_TensorArmAxisAttr"}, - attr.getAttrDefName())) { + } else if (llvm::is_contained({"SPIRV_BoolConstAttr", + "SPIRV_TensorArmAxisAttr", + "SPIRV_TosaNumericalAttr"}, + attr.getAttrDefName())) { os << tabs << formatv( " {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n", @@ -864,7 +865,9 @@ static void emitAttributeDeserialization(const Attribute &attr, << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "TypeAttr::get(getType({2}[{3}++]))));\n", attrList, attrName, words, wordIndex); - } else if (attr.getAttrDefName() == "SPIRV_BoolConstAttr" || + } else if (llvm::is_contained( + {"SPIRV_BoolConstAttr", "SPIRV_TosaNumericalAttr"}, + attr.getAttrDefName()) || attr.getAttrDefName().contains("TensorArm")) { os << tabs << formatv("std::optional> c = "