[mlir][spirv] Add Activation operators to TOSA Extended Instruction S…#178620
Conversation
|
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir Author: Davide Grohmann (davidegrohmann) Changes…et (001000.1) This patch adds the Activation operators to the TOSA Extended Instruction Set (001000.1) to the SPIR-V dialect in MLIR. The TOSA extended instruction set provides a standardized set of machine learning operations designed to be used within The change introduces:
All these operations from TOSA 001000.1 extended instructions are introduced: Parser, printer, verifier, and round-trip tests using MLIR’s SPIR-V serialization/deserialization infrastructure are included. This work completes support for expressing TOSA extended instructions inside SPIR-V graphs in MLIR, aligning with Khronos SPIR-V TOSA specifications. Specification: Patch is 21.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178620.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index d69e215e05205..c4651f695617f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -694,4 +694,182 @@ def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOpWithResult<"TransposeConv2D", 9, [
}
+def SPIRV_TosaClampOp : SPIRV_TosaOpWithResult<"Clamp", 10, [Pure,
+ AllTypesMatch<["input", "output"]>,
+ AllElementTypesMatch<["input", "output", "min_val", "max_val"]>]> {
+ let summary = "Computes Clamp(min, max).";
+
+ let description = [{
+ Clamp to an arbitrary minimum and maximum value.
+ Maximum and minimum values are specified as values in the range of the
+ input type.
+ No zero point subtraction is done to the values, thus to clamp to the zero
+ point value, the zero point itself should be supplied as the minimum value.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_clamp
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_clamp
+
+ #### Example:
+ ```mlir
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaNumericalAttr: $min_val,
+ SPIRV_TosaNumericalAttr: $max_val,
+ SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
+ SPIRV_TosaNumerical_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaNumerical_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ `min_val` `=` $min_val `,`
+ `max_val` `=` $max_val `,`
+ `nan_mode` `=` $nan_mode `,`
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaErfOp : SPIRV_TosaOpWithResult<"Erf", 11, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Computes Gauss Error Function of input.";
+
+ let description = [{
+ Gauss Error Function: $ erf(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt $
+ For quantized integer data types, the table operator should be used instead
+ with the following definition. The ERF table has 513 entries each of
+ 16-bit precision and covering the input range -4.0 to +4.0 in steps of 1/64.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_erf
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_erf
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaSigmoidOp : SPIRV_TosaOpWithResult<"Sigmoid", 12, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Computes elementwise sigmoid of input.";
+
+ let description = [{
+ Applies the sigmoid logistic function to each element of the input tensor:
+ $ sigmoid(x) = rac{1}{1 + e^{-x}} $.
+
+ For quantized integer data types, the table operator should be used instead.
+ Each implementation may choose an appropriate table given the scale and zero
+ point of the input data. Eight or sixteen bit precision tables may be used
+ based on the input tensor to the sigmoid function.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_sigmoid
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_sigmoid
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
+def SPIRV_TosaTanhOp : SPIRV_TosaOpWithResult<"Tanh", 13, [Pure,
+ AllTypesMatch<["input", "output"]>]> {
+ let summary = "Computes elementwise Hyperbolic Tangent of input.";
+
+ let description = [{
+ Parameterized Hyperbolic Tangent: $ tanh(x) = rac{1 - e^{-2x}}{1 + e^{-2x}} $.
+
+ For quantized integer data types, the table operator should be used instead.
+ Each implementation may choose an appropriate table given the scale and zero
+ point of the input data. Eight or sixteen bit precision tables may be used
+ based on the input tensor to the tanh function.
+
+ References:
+ * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_tanh
+ * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_tanh
+
+ #### Example:
+ ```mlir
+ %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_TosaFloat_TensorArm: $input
+ );
+
+ let results = (outs
+ SPIRV_TosaFloat_TensorArm: $output
+ );
+
+ let assemblyFormat = [{
+ $input
+ attr-dict `:` type(operands) `->` type(results)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration#[{
+ ::mlir::spirv::TensorArmType getInputType() {
+ return cast<::mlir::spirv::TensorArmType>(getInput().getType());
+ }
+ }];
+}
+
+
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index db4ad8064fc11..67d94201bf713 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -23,6 +23,7 @@ def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
+def SPIRV_TosaNumericalAttr: AnyAttrOf<[I8Attr, I16Attr, I32Attr, I64Attr, F16Attr, F32Attr]>;
// TensorARM Types
@@ -44,6 +45,7 @@ def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]
def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
class Is1DTensorArmOfLength<list<int> allowedLengths> :
@@ -62,6 +64,12 @@ def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
"Attr with type = spirv::TensorArmType">;
+class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
+ AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
+ Or<!foreach(allowedlength, allowedLengths,
+ CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape()[0] == }]
+ # allowedlength>)>]>>;
+
def SPIRV_Int32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_Int32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
def SPIRV_Int32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 8ac1a2ea21422..ad6c355155161 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -334,6 +334,7 @@ class FloatAttrBase<F attrValType, string descr> :
let returnType = [{ ::llvm::APFloat }];
}
+def F16Attr : FloatAttrBase<F16, "16-bit float attribute">;
def F32Attr : FloatAttrBase<F32, "32-bit float attribute">;
def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
index 1a43e2c95c530..a9f7bc2b8ef7d 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir
@@ -229,3 +229,58 @@ spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %a
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16>
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46x50x36xf16>) -> (!spirv.arm.tensor<46x50x36xf16>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ %0 = spirv.Tosa.Tanh %arg0 : !spirv.arm.tensor<46x50x36xf16> -> !spirv.arm.tensor<46x50x36xf16>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<46x50x36xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<46x50x36xf16>
+}
diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir
index 1d219b855bec1..9f2ff1c31cbc5 100644
--- a/mlir/test/Target/SPIRV/tosa-ops.mlir
+++ b/mlir/test/Target/SPIRV/tosa-ops.mlir
@@ -396,3 +396,98 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16>
}
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-INT
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @clamp_int_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+ spirv.GlobalVariable @clamp_int_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<27x44x55xi8>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @clamp_int, @clamp_int_arg_0, @clamp_int_res_0
+ spirv.ARM.Graph @clamp_int(%arg0: !spirv.arm.tensor<27x44x55xi8>) -> (!spirv.arm.tensor<27x44x55xi8>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i8, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<27x44x55xi8>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Clamp - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @clamp_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+ spirv.GlobalVariable @clamp_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<18x5x17x6xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @clamp_fp, @clamp_fp_arg_0, @clamp_fp_res_0
+ spirv.ARM.Graph @clamp_fp(%arg0: !spirv.arm.tensor<18x5x17x6xf32>) -> (!spirv.arm.tensor<18x5x17x6xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ %3 = spirv.Tosa.Clamp min_val = -1.19339396E+38 : f32, max_val = 2.38255944E+38 : f32, nan_mode = <Ignore>, %arg0 : !spirv.arm.tensor<18x5x17x6xf32> -> !spirv.arm.tensor<18x5x17x6xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<18x5x17x6xf32>
+ spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<18x5x17x6xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Erf - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @erf_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+ spirv.GlobalVariable @erf_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<47x38x51xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @erf_fp, @erf_fp_arg_0, @erf_fp_res_0
+ spirv.ARM.Graph @erf_fp(%arg0: !spirv.arm.tensor<47x38x51xf32>) -> (!spirv.arm.tensor<47x38x51xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ %0 = spirv.Tosa.Erf %arg0 : !spirv.arm.tensor<47x38x51xf32> -> !spirv.arm.tensor<47x38x51xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<47x38x51xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<47x38x51xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Sigmoid - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @sigmoid_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+ spirv.GlobalVariable @sigmoid_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<28x43x45xf32>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @sigmoid_fp, @sigmoid_fp_arg_0, @sigmoid_fp_res_0
+ spirv.ARM.Graph @sigmoid_fp(%arg0: !spirv.arm.tensor<28x43x45xf32>) -> (!spirv.arm.tensor<28x43x45xf32>) {
+ // CHECK: {{%.*}} = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ %0 = spirv.Tosa.Sigmoid %arg0 : !spirv.arm.tensor<28x43x45xf32> -> !spirv.arm.tensor<28x43x45xf32>
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<28x43x45xf32>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<28x43x45xf32>
+ }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.TOSA.Tanh - PRO-FP
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ spirv.GlobalVariable @tanh_fp_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+ spirv.GlobalVariable @tanh_fp_res_0 bind(1, 0) : !spirv.ptr<!spirv.arm.tensor<46x50x36xf16>, UniformConstant>
+ spirv.ARM.GraphEntryPoint @tanh_fp, @tanh_fp_arg_0, @tanh_fp_res_0
+ spirv.ARM.Graph @tanh_fp(%arg0: !spirv.arm.tensor<46...
[truncated]
|
81a3b86 to
47b06d6
Compare
ghost
left a comment
There was a problem hiding this comment.
I think the PR description that becomes the commit message could be compressed. There is a lot of boilerplate at this point. I think the whole description of the change could be compressed into a single paragraph without losing much information.
kuhar
left a comment
There was a problem hiding this comment.
Code looks good, just some comments on the documentation
47b06d6 to
1c1731a
Compare
…et (001000.1)
In details the Activation operators introduced are:
spirv.Tosa.{Clamp,Erf,Sigmoid,Tanh}, along with dialect and
serialization round-trip tests.
Change-Id: I0e7bed3f5a2b0098a4d532ba1d577f1d82507aa0
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
1c1731a to
eb11391
Compare
ghost
left a comment
There was a problem hiding this comment.
Thank you for clarifying the description.
LGTM
|
@kuhar any more comments on this PR from your side? |
llvm#178620) …et (001000.1) In details the Activation operators introduced are: spirv.Tosa.{Clamp,Erf,Sigmoid,Tanh}, along with dialect and serialization round-trip tests. Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
…et (001000.1)
In details the Activation operators introduced are: spirv.Tosa.{Clamp,Erf,Sigmoid,Tanh}, along with dialect and serialization round-trip tests.