diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 21010d91dc47c..4ea6d784dd88f 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4915,6 +4915,17 @@ def SPIRV_FPFastMathModeAttr : // SPIR-V TOSA enum definitions. //===----------------------------------------------------------------------===// +// NOTE: This is an attribute in the SPIR-V *dialect* but a constant () in +// SPIR-V proper. +def SPIRV_TosaExtAccTypeAttr : SPIRV_I32EnumAttr< + "TosaExtAccType", "Tosa Ext Acculumator Type", "tosa_ext_acc_type", + [ + I32EnumAttrCase<"INT32", 1>, + I32EnumAttrCase<"FP16", 2>, + I32EnumAttrCase<"FP32", 3>, + I32EnumAttrCase<"INT48", 4>, + ]>; + // NOTE: This is an attribute in the SPIR-V *dialect* but a constant () in // SPIR-V proper. def SPIRV_TosaExtNaNPropagationModeAttr : SPIRV_I32EnumAttr< diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h index 0e1f6e79a3670..4d43c7d7066ed 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -16,6 +16,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h new file mode 100644 index 0000000000000..175edaafbb857 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.h @@ -0,0 +1,19 @@ +//===- SPIRVTosaOps.h - MLIR SPIR-V Tosa operations -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir::spirv { + +ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser, + DenseIntElementsAttr &attr); + +void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, + DenseIntElementsAttr attr); + +} // namespace mlir::spirv diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td index 6c6a318db4827..1b2f6a923d01b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td @@ -51,7 +51,7 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> { #### Example: ```mlir %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = , %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32> - %2 = spirv.Tosa.ArgMax axis = 3, nan_mode = , %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32> + %2 = spirv.Tosa.ArgMax axis = 2, nan_mode = , %arg0 : !spirv.arm.tensor<2x2x7x14xf32> -> !spirv.arm.tensor<2x2x14xi32> ``` }]; @@ -83,4 +83,290 @@ def SPIRV_TosaArgMaxOp : SPIRV_TosaOp<"ArgMax", 0, [Pure]> { }]; } + +def SPIRV_TosaConv2DOp : SPIRV_TosaOp<"Conv2D", 2, [Pure, + AllElementTypesMatch<["bias", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["weight", "weight_zp"]>]> { + let summary = "2D Convolution operator."; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight + tensor. Implementations may choose to skip calculation of multiplies in + the padding area. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv2d + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv2d + + #### Example: + ```mlir + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + ``` + }]; + + let arguments = (ins + SPIRV_Int32_1DTensorArmOfLength4Attr: $pad, + SPIRV_Int32_1DTensorArmOfLength2Attr: $stride, + SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation, + SPIRV_TosaExtAccTypeAttr: $acc_type, + SPIRV_BoolConstAttr: $local_bound, + SPIRV_TosaNumerical_TensorArm4D: $input, + SPIRV_TosaNumerical_TensorArm4D: $weight, + SPIRV_TosaNumerical_TensorArm1D: $bias, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp + ); + + let results = (outs + SPIRV_TosaNumerical_TensorArm4D: $output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `pad` `=` custom($pad) `,` + `stride` `=` custom($stride) `,` + `dilation` `=` custom($dilation) `,` + `acc_type` `=` $acc_type `,` + `local_bound` `=` $local_bound `,` + $input `,` + $weight `,` + $bias `,` + $input_zp `,` + $weight_zp + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = [{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getWeightType() { + return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); + } + ::mlir::spirv::TensorArmType getBiasType() { + return cast<::mlir::spirv::TensorArmType>(getBias().getType()); + } + ::mlir::spirv::TensorArmType getResultType() { + return cast<::mlir::spirv::TensorArmType>(getType()); + } + }]; +} + + +def SPIRV_TosaConv3DOp : SPIRV_TosaOp<"Conv3D", 3, [Pure, + AllElementTypesMatch<["bias", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["weight", "weight_zp"]>]> { + let summary = "3D Convolution operator."; + + let description = [{ + Performs a 3D convolution over the given input tensor. Implementations + may choose to skip calculation of multiplies in the padding area. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_conv3d + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_conv3d + + #### Example: + ```mlir + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32> + %7 = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32> + ``` + }]; + + let arguments = (ins + SPIRV_Int32_1DTensorArmOfLength6Attr: $pad, + SPIRV_Int32_1DTensorArmOfLength3Attr: $stride, + SPIRV_Int32_1DTensorArmOfLength3Attr: $dilation, + SPIRV_TosaExtAccTypeAttr: $acc_type, + SPIRV_BoolConstAttr: $local_bound, + SPIRV_TosaNumerical_TensorArm5D: $input, + SPIRV_TosaNumerical_TensorArm5D: $weight, + SPIRV_TosaNumerical_TensorArm1D: $bias, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp + ); + + let results = (outs + SPIRV_TosaNumerical_TensorArm5D: $output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `pad` `=` custom($pad) `,` + `stride` `=` custom($stride) `,` + `dilation` `=` custom($dilation) `,` + `acc_type` `=` $acc_type `,` + `local_bound` `=` $local_bound `,` + $input `,` + $weight `,` + $bias `,` + $input_zp `,` + $weight_zp + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = [{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getWeightType() { + return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); + } + ::mlir::spirv::TensorArmType getBiasType() { + return cast<::mlir::spirv::TensorArmType>(getBias().getType()); + } + ::mlir::spirv::TensorArmType getResultType() { + return cast<::mlir::spirv::TensorArmType>(getType()); + } + }]; +} + + +def SPIRV_TosaDepthwiseConv2DOp : SPIRV_TosaOp<"DepthwiseConv2D", 4, [Pure, + AllElementTypesMatch<["bias", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["weight", "weight_zp"]>]> { + let summary = "Depthwise 2D Convolution operator."; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor + input, using the weight tensor. Implementations may choose to skip + calculation of multiplies in the padding area. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_depthwise_conv2d + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_depthwise_conv2d + + #### Example: + ```mlir + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32> + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32> + ``` + }]; + + let arguments = (ins + SPIRV_Int32_1DTensorArmOfLength4Attr: $pad, + SPIRV_Int32_1DTensorArmOfLength2Attr: $stride, + SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation, + SPIRV_TosaExtAccTypeAttr: $acc_type, + SPIRV_BoolConstAttr: $local_bound, + SPIRV_TosaNumerical_TensorArm4D: $input, + SPIRV_TosaNumerical_TensorArm4D: $weight, + SPIRV_TosaNumerical_TensorArm1D: $bias, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp + ); + + let results = (outs + SPIRV_TosaNumerical_TensorArm4D: $output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `pad` `=` custom($pad) `,` + `stride` `=` custom($stride) `,` + `dilation` `=` custom($dilation) `,` + `acc_type` `=` $acc_type `,` + `local_bound` `=` $local_bound `,` + $input `,` + $weight `,` + $bias `,` + $input_zp `,` + $weight_zp + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = [{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getWeightType() { + return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); + } + ::mlir::spirv::TensorArmType getBiasType() { + return cast<::mlir::spirv::TensorArmType>(getBias().getType()); + } + ::mlir::spirv::TensorArmType getResultType() { + return cast<::mlir::spirv::TensorArmType>(getType()); + } + }]; +} + + +def SPIRV_TosaTransposeConv2DOp : SPIRV_TosaOp<"TransposeConv2D", 9, [Pure, + AllElementTypesMatch<["bias", "output"]>, + AllElementTypesMatch<["input", "input_zp"]>, + AllElementTypesMatch<["weight", "weight_zp"]>]> { + let summary = "Transpose 2D Convolution operator."; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the + weights tensor. Implementations may choose to skip calculation of multiplies + by zero at fractional input positions. + + References: + * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_transpose_conv2d + * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_transpose_conv2d + + #### Example: + ```mlir + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64> + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16> + ``` + }]; + + let arguments = (ins + SPIRV_Int32_1DTensorArmOfLength4Attr: $out_pad, + SPIRV_Int32_1DTensorArmOfLength2Attr: $stride, + SPIRV_TosaExtAccTypeAttr: $acc_type, + SPIRV_BoolConstAttr: $local_bound, + SPIRV_TosaNumerical_TensorArm4D: $input, + SPIRV_TosaNumerical_TensorArm4D: $weight, + SPIRV_TosaNumerical_TensorArm1D: $bias, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp, + SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp + ); + + let results = (outs + SPIRV_TosaNumerical_TensorArm4D: $output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `out_pad` `=` custom($out_pad) `,` + `stride` `=` custom($stride) `,` + `acc_type` `=` $acc_type `,` + `local_bound` `=` $local_bound `,` + $input `,` + $weight `,` + $bias `,` + $input_zp `,` + $weight_zp + attr-dict `:` type(operands) `->` type(results) + }]; + + let extraClassDeclaration = [{ + ::mlir::spirv::TensorArmType getInputType() { + return cast<::mlir::spirv::TensorArmType>(getInput().getType()); + } + ::mlir::spirv::TensorArmType getWeightType() { + return cast<::mlir::spirv::TensorArmType>(getWeight().getType()); + } + ::mlir::spirv::TensorArmType getBiasType() { + return cast<::mlir::spirv::TensorArmType>(getBias().getType()); + } + ::mlir::spirv::TensorArmType getResultType() { + return cast<::mlir::spirv::TensorArmType>(getType()); + } + }]; +} + + #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td index e731388182eb4..7e2c37f74b437 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_TOSA_TYPES #define MLIR_DIALECT_SPIRV_IR_TOSA_TYPES +include "mlir/IR/CommonAttrConstraints.td" include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>; @@ -21,6 +22,7 @@ def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>; def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>; def SPIRV_TensorArmAxisAttr : ConfinedAttr]>; +def SPIRV_BoolConstAttr : ConfinedAttr; // TensorARM Types @@ -35,7 +37,34 @@ class TensorArmRankOf allowedTypes, list ranks> [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">; +def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>; +def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>; +def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>; + def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>; def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>; +class Is1DTensorArmOfLength allowedLengths> : + And<[HasAnyRankOfPred<[1]>, + Or($_self).getShape()[0] == }] + # allowedlength>)>]>; + +class SPIRV_1DTensorArmOfLengthAndType allowedLengths, list allowedTypes> : + ContainerType, Is1DTensorArmOfLength, + "::llvm::cast<::mlir::spirv::TensorArmType>($_self).getElementType()", + "rank 1 tensorArm of length " # !interleave(allowedLengths, "/"), + "::mlir::spirv::TensorArmType">; + +def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint< + CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">, + "Attr with type = spirv::TensorArmType">; + +def SPIRV_Int32_1DTensorArmOfLength2Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; +def SPIRV_Int32_1DTensorArmOfLength3Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; +def SPIRV_Int32_1DTensorArmOfLength4Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; +def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; + +def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>; + #endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp index 4f3c91d4a1c12..1d3e1084d5a9e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp @@ -9,10 +9,9 @@ // This file defines the Tosa operations in the SPIR-V dialect. // //===----------------------------------------------------------------------===// + #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/TypeUtilities.h" +#include "llvm/Support/InterleavedRange.h" namespace mlir::spirv { @@ -20,6 +19,60 @@ namespace mlir::spirv { // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// +namespace { + +LogicalResult verifyConvOp(Operation *op, Type inputETy, Type resultETy, + TosaExtAccType accType) { + if (inputETy.isInteger() && !inputETy.isInteger(8) && + !inputETy.isInteger(16)) { + return op->emitOpError( + "input element type can only be of width 8 or 16 when integer type"); + } + + if (inputETy.isInteger(8) && !resultETy.isInteger(32)) { + return op->emitOpError("expect result type to be i32, got ") << resultETy; + } + + if (inputETy.isInteger(16) && !resultETy.isInteger(64)) { + return op->emitOpError("expect result type to be i64, got ") << resultETy; + } + + if (inputETy.isF16() && !resultETy.isF16()) { + return op->emitOpError("expect result type to be f16, got ") << resultETy; + } + + if (inputETy.isF32() && !resultETy.isF32()) { + return op->emitOpError("expect result type to be f32, got ") << resultETy; + } + + if (inputETy.isInteger(8) && accType != TosaExtAccType::INT32) { + return op->emitOpError("accumulator type for i8 tensorARM is not i32"); + } + + if (inputETy.isInteger(16) && accType != TosaExtAccType::INT48) { + return op->emitOpError("accumulator type for i16 tensorARM is not i48"); + } + + if (inputETy.isF16() && + !llvm::is_contained({TosaExtAccType::FP16, TosaExtAccType::FP32}, + accType)) { + return op->emitOpError( + "accumulator type for f16 tensorARM is not f16 or f32"); + } + + if (inputETy.isBF16() && accType != TosaExtAccType::FP32) { + return op->emitOpError("accumulator type for bf16 tensorARM is not f32"); + } + + if (inputETy.isF32() && accType != TosaExtAccType::FP32) { + return op->emitOpError("accumulator type for f32 tensorARM is not f32"); + } + + return success(); +} + +} // namespace + //===----------------------------------------------------------------------===// // spirv.TosaArgmaxOp //===----------------------------------------------------------------------===// @@ -46,4 +99,81 @@ LogicalResult TosaArgMaxOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// spirv.TosaConv2DOp +//===----------------------------------------------------------------------===// + +LogicalResult TosaConv2DOp::verify() { + Type inputETy = getInputType().getElementType(); + Type resultETy = getResultType().getElementType(); + TosaExtAccType accType = getAccType(); + return verifyConvOp(this->getOperation(), inputETy, resultETy, accType); +} + +//===----------------------------------------------------------------------===// +// spirv.TosaConv3DOp +//===----------------------------------------------------------------------===// + +LogicalResult TosaConv3DOp::verify() { + Type inputETy = getInputType().getElementType(); + Type resultETy = getResultType().getElementType(); + TosaExtAccType accType = getAccType(); + return verifyConvOp(this->getOperation(), inputETy, resultETy, accType); +} + +//===----------------------------------------------------------------------===// +// SPIRV Tosa DepthwiseConv2D Ops: +//===----------------------------------------------------------------------===// + +LogicalResult TosaDepthwiseConv2DOp::verify() { + Type inputETy = getInputType().getElementType(); + Type resultETy = getResultType().getElementType(); + TosaExtAccType accType = getAccType(); + return verifyConvOp(this->getOperation(), inputETy, resultETy, accType); +} + +//===----------------------------------------------------------------------===// +// SPIRV Tosa TransposeConv2D Ops: +//===----------------------------------------------------------------------===// + +LogicalResult TosaTransposeConv2DOp::verify() { + Type inputETy = getInputType().getElementType(); + Type resultETy = getResultType().getElementType(); + TosaExtAccType accType = getAccType(); + return verifyConvOp(this->getOperation(), inputETy, resultETy, accType); +} + +//===----------------------------------------------------------------------===// +// SPIRV Tosa Custom formatters +//===----------------------------------------------------------------------===// + +ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser, + DenseIntElementsAttr &attr) { + SmallVector elements; + auto f = [&]() { + int32_t value; + ParseResult r = parser.parseInteger(value); + elements.push_back(value); + return r; + }; + if (parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Square, f, + "parsing values in integer list attribute")) { + return failure(); + } + + auto i32Type = IntegerType::get(parser.getContext(), 32); + auto type = TensorArmType::get( + ArrayRef{static_cast(elements.size())}, i32Type); + attr = DenseIntElementsAttr::get(type, elements); + return success(); +} + +void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, + DenseIntElementsAttr attr) { + printer << llvm::interleaved_array( + llvm::map_range(attr.getValues(), + [](const APInt &a) { return a.getSExtValue(); })); +} + } // namespace mlir::spirv diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir index a6496316f9881..2099630aff0fb 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir @@ -21,3 +21,340 @@ spirv.ARM.Graph @argmax_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.ten %2 = spirv.Tosa.ArgMax axis = 4, nan_mode = , %arg0 : !spirv.arm.tensor<3x28x17x17xi8> -> !spirv.arm.tensor<3x28x17xi32> spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<3x28x17xi32> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv2D +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op expect result type to be i32, got 'i16'}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16> +} + +spirv.ARM.Graph @conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op expect result type to be i64, got 'i32'}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op expect result type to be f16, got 'f32'}} + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} + +spirv.ARM.Graph @conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op expect result type to be f32, got 'f16'}} + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}} + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}} + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @conv2d_accumulator_must_be_either_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}} + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} + + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv3D +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv3d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi32>, !spirv.arm.tensor<7x1x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7x1xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64> +} + +spirv.ARM.Graph @conv3d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi16>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op expect result type to be i32, got 'i16'}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi16> +} + +spirv.ARM.Graph @conv3d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op expect result type to be i64, got 'i32'}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi16>, !spirv.arm.tensor<7x1x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7x1xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32> +} + +spirv.ARM.Graph @conv3d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27x1xf16>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11x1xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op expect result type to be f16, got 'f32'}} + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf16>, !spirv.arm.tensor<11x1x1x27x1xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11x1xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32> +} + +spirv.ARM.Graph @conv3d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27x1xf32>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11x1xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op expect result type to be f32, got 'f16'}} + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf32>, !spirv.arm.tensor<11x1x1x27x1xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11x1xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf16> +} + +spirv.ARM.Graph @conv3d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32> +} + +spirv.ARM.Graph @conv3d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32> +} + +spirv.ARM.Graph @conv3d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}} + %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi16>, !spirv.arm.tensor<7x1x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7x1xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64> +} + +spirv.ARM.Graph @conv3d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27x1xf16>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11x1xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}} + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf16>, !spirv.arm.tensor<11x1x1x27x1xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11x1xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf16> +} + +spirv.ARM.Graph @conv3d_accumulator_must_be_either_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27x1xf32>, %arg1: !spirv.arm.tensor<11x1x1x27x1xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11x1xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}} + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27x1xf32>, !spirv.arm.tensor<11x1x1x27x1xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11x1xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.DepthwiseConv2D +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @depthwise_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op expect result type to be i32, got 'i16'}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16> +} + +spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op expect result type to be i64, got 'i32'}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op expect result type to be f16, got 'f32'}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} + +spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op expect result type to be f32, got 'f16'}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @depthwise_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @depthwise_conv2d_accumulator_must_be_either_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}} + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.TransposeConv2D +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @transpose_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32> + // expected-error @+1 {{op input element type can only be of width 8 or 16 when integer type}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op expect result type to be i32, got 'i16'}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16> +} + +spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op expect result type to be i64, got 'i32'}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op expect result type to be f16, got 'f32'}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} + +spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_f32_input(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op expect result type to be f32, got 'f16'}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @transpose_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_INT32_for_i8_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op accumulator type for i8 tensorARM is not i32}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_INT48_for_i16_input_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16> + // expected-error @+1 {{op accumulator type for i16 tensorARM is not i48}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi64> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64> +} + +spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP16_or_FP32_for_f16_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // expected-error @+1 {{op accumulator type for f16 tensorARM is not f16 or f32}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +spirv.ARM.Graph @transpose_conv2d_accumulator_must_be_either_FP32_for_f32_input_element_type(%arg0: !spirv.arm.tensor<1x34x18x27xf32>, %arg1: !spirv.arm.tensor<11x1x1x27xf32>, %arg2: !spirv.arm.tensor<11xf32>) -> (!spirv.arm.tensor<1x34x18x11xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // expected-error @+1 {{op accumulator type for f32 tensorARM is not f32}} + %7 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf32>, !spirv.arm.tensor<11x1x1x27xf32>, !spirv.arm.tensor<11xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x34x18x11xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf32> +} diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir index c9832b903b79e..45243a7553c56 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir @@ -21,3 +21,107 @@ spirv.ARM.Graph @argmax_fp(%arg0: !spirv.arm.tensor<2x2x7x14xf32>) -> (!spirv.ar // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x2x14xi32> spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<2x2x14xi32> } + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv2D - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv2d_int(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv2D - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv2d_fp(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv3D - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv3d_int(%arg0: !spirv.arm.tensor<1x9x21x14x1xi8>, %arg1: !spirv.arm.tensor<2x1x2x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x9x20x14x2xi32>) { + %5 = spirv.Constant dense<123> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<121> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32> + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x9x20x14x2xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x9x20x14x2xi32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv3D - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @conv3d_fp(%arg0: !spirv.arm.tensor<1x2x65539x1x2xf32>, %arg1: !spirv.arm.tensor<1x1x1x1x2xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x3x65540x2x1xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32> + %7 = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x65540x2x1xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x3x65540x2x1xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.DepthwiseConv2D - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @depthwiseconv2d_int(%arg0: !spirv.arm.tensor<1x4x65537x1xi8>, %arg1: !spirv.arm.tensor<1x3x1x4xi8>, %arg2: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<1x4x32762x4xi32>) { + %5 = spirv.Constant dense<58> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<-106> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32> + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x4x32762x4xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x32762x4xi32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.DepthwiseConv2D - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %arg1: !spirv.arm.tensor<1x1x3x1xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x65541x2x3xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32> + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65541x2x3xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65541x2x3xf32> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.TransposeConv2D - PRO-INT +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @transposeconv2d_int(%arg0: !spirv.arm.tensor<1x13x33x3xi16>, %arg1: !spirv.arm.tensor<11x1x3x3xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x13x35x11xi64>) { + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %5 = spirv.Constant dense<88> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64> + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x13x35x11xi64> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x13x35x11xi64> +} + +//===----------------------------------------------------------------------===// +// spirv.TOSA.TransposeConv2D - PRO-FP +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %arg1: !spirv.arm.tensor<14x1x1x13xf16>, %arg2: !spirv.arm.tensor<14xf16>) -> (!spirv.arm.tensor<10x25x65x14xf16>) { + %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16> + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16> +} diff --git a/mlir/test/Target/SPIRV/tosa-ops.mlir b/mlir/test/Target/SPIRV/tosa-ops.mlir index 8c0429bca68e4..edaa000c183a8 100644 --- a/mlir/test/Target/SPIRV/tosa-ops.mlir +++ b/mlir/test/Target/SPIRV/tosa-ops.mlir @@ -39,3 +39,187 @@ spirv.module Logical Vulkan requires #spirv.vce } } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv2D - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @conv2d_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_int_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_int_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @conv2d_int, @conv2d_int_arg_0, @conv2d_int_arg_1, @conv2d_int_arg_2, @conv2d_int_res_0 + spirv.ARM.Graph @conv2d_int(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) { + %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65536x2x7xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv2D - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @conv2d_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_fp_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_fp_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv2d_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @conv2d_fp, @conv2d_fp_arg_0, @conv2d_fp_arg_1, @conv2d_fp_arg_2, @conv2d_fp_res_0 + spirv.ARM.Graph @conv2d_fp(%arg0: !spirv.arm.tensor<1x34x18x27xf16>, %arg1: !spirv.arm.tensor<11x1x1x27xf16>, %arg2: !spirv.arm.tensor<11xf16>) -> (!spirv.arm.tensor<1x34x18x11xf16>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // CHECK: {{%.*}} = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + %7 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x34x18x27xf16>, !spirv.arm.tensor<11x1x1x27xf16>, !spirv.arm.tensor<11xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<1x34x18x11xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x34x18x11xf16> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv3D - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @conv3d_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_int_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_int_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @conv3d_int, @conv3d_int_arg_0, @conv3d_int_arg_1, @conv3d_int_arg_2, @conv3d_int_res_0 + spirv.ARM.Graph @conv3d_int(%arg0: !spirv.arm.tensor<1x9x21x14x1xi8>, %arg1: !spirv.arm.tensor<2x1x2x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x9x20x14x2xi32>) { + %5 = spirv.Constant dense<123> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<121> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32> + %7 = spirv.Tosa.Conv3D pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1], dilation = [1, 1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x9x21x14x1xi8>, !spirv.arm.tensor<2x1x2x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x9x20x14x2xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x9x20x14x2xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x9x20x14x2xi32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.Conv3D - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @conv3d_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_fp_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_fp_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @conv3d_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @conv3d_fp, @conv3d_fp_arg_0, @conv3d_fp_arg_1, @conv3d_fp_arg_2, @conv3d_fp_res_0 + spirv.ARM.Graph @conv3d_fp(%arg0: !spirv.arm.tensor<1x2x65539x1x2xf32>, %arg1: !spirv.arm.tensor<1x1x1x1x2xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x3x65540x2x1xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // CHECK: {{%.*}} = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32> + %7 = spirv.Tosa.Conv3D pad = [0, 1, 1, 0, 0, 1], stride = [1, 1, 1], dilation = [1, 1, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x2x65539x1x2xf32>, !spirv.arm.tensor<1x1x1x1x2xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x3x65540x2x1xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x3x65540x2x1xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x3x65540x2x1xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.DepthwiseConv2D - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @depthwiseconv2d_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_int_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_int_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @depthwiseconv2d_int, @depthwiseconv2d_int_arg_0, @depthwiseconv2d_int_arg_1, @depthwiseconv2d_int_arg_2, @depthwiseconv2d_int_res_0 + spirv.ARM.Graph @depthwiseconv2d_int(%arg0: !spirv.arm.tensor<1x4x65537x1xi8>, %arg1: !spirv.arm.tensor<1x3x1x4xi8>, %arg2: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<1x4x32762x4xi32>) { + %5 = spirv.Constant dense<58> : !spirv.arm.tensor<1xi8> + %6 = spirv.Constant dense<-106> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32> + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 0, 0, 0], stride = [1, 2], dilation = [7, 7], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x4x65537x1xi8>, !spirv.arm.tensor<1x3x1x4xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x32762x4xi32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x4x32762x4xi32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x4x32762x4xi32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.DepthwiseConv2D - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @depthwiseconv2d_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_fp_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_fp_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @depthwiseconv2d_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @depthwiseconv2d_fp, @depthwiseconv2d_fp_arg_0, @depthwiseconv2d_fp_arg_1, @depthwiseconv2d_fp_arg_2, @depthwiseconv2d_fp_res_0 + spirv.ARM.Graph @depthwiseconv2d_fp(%arg0: !spirv.arm.tensor<1x65540x1x3xf32>, %arg1: !spirv.arm.tensor<1x1x3x1xf32>, %arg2: !spirv.arm.tensor<1xf32>) -> (!spirv.arm.tensor<1x65541x2x3xf32>) { + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + %6 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf32> + // CHECK: {{%.*}} = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32> + %7 = spirv.Tosa.DepthwiseConv2D pad = [0, 1, 1, 1], stride = [1, 2], dilation = [1, 7], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65540x1x3xf32>, !spirv.arm.tensor<1x1x3x1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<1x65541x2x3xf32> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x65541x2x3xf32> + spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65541x2x3xf32> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.TransposeConv2D - PRO-INT +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @transposeconv2d_int_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_int_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_int_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_int_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @transposeconv2d_int, @transposeconv2d_int_arg_0, @transposeconv2d_int_arg_1, @transposeconv2d_int_arg_2, @transposeconv2d_int_res_0 + spirv.ARM.Graph @transposeconv2d_int(%arg0: !spirv.arm.tensor<1x13x33x3xi16>, %arg1: !spirv.arm.tensor<11x1x3x3xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x13x35x11xi64>) { + %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16> + %5 = spirv.Constant dense<88> : !spirv.arm.tensor<1xi8> + // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64> + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 0, 0, 0], stride = [1, 1], acc_type = , local_bound = false, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<1x13x33x3xi16>, !spirv.arm.tensor<11x1x3x3xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x13x35x11xi64> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x13x35x11xi64> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x13x35x11xi64> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.TOSA.TransposeConv2D - PRO-FP +//===----------------------------------------------------------------------===// + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan requires #spirv.vce { + spirv.GlobalVariable @transposeconv2d_fp_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_fp_arg_1 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_fp_arg_2 bind(0, 2) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @transposeconv2d_fp_res_0 bind(1, 0) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @transposeconv2d_fp, @transposeconv2d_fp_arg_0, @transposeconv2d_fp_arg_1, @transposeconv2d_fp_arg_2, @transposeconv2d_fp_res_0 + spirv.ARM.Graph @transposeconv2d_fp(%arg0: !spirv.arm.tensor<10x24x9x13xf16>, %arg1: !spirv.arm.tensor<14x1x1x13xf16>, %arg2: !spirv.arm.tensor<14xf16>) -> (!spirv.arm.tensor<10x25x65x14xf16>) { + %4 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + %5 = spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<1xf16> + // CHECK: {{%.*}} = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = , local_bound = true, %arg0, %arg1, %arg2, {{%.*}}, {{%.*}} : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16> + %6 = spirv.Tosa.TransposeConv2D out_pad = [0, 1, 0, 0], stride = [1, 8], acc_type = , local_bound = true, %arg0, %arg1, %arg2, %4, %5 : !spirv.arm.tensor<10x24x9x13xf16>, !spirv.arm.tensor<14x1x1x13xf16>, !spirv.arm.tensor<14xf16>, !spirv.arm.tensor<1xf16>, !spirv.arm.tensor<1xf16> -> !spirv.arm.tensor<10x25x65x14xf16> + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<10x25x65x14xf16> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<10x25x65x14xf16> + } +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index f3327e31aae04..0b1771ffcee71 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -501,6 +501,7 @@ constexpr llvm::StringLiteral constantIdEnumAttrs[] = { "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr", "SPIRV_MatrixLayoutAttr", + "SPIRV_TosaExtAccTypeAttr", "SPIRV_TosaExtNaNPropagationModeAttr", }; @@ -556,11 +557,18 @@ static void emitAttributeSerialization(const Attribute &attr, os << tabs << " return failure();\n"; os << tabs << " }\n"; os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList); - } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") { + } else if (llvm::is_contained( + {"SPIRV_BoolConstAttr", "SPIRV_TensorArmAxisAttr"}, + attr.getAttrDefName())) { os << tabs << formatv( " {0}.push_back(prepareConstantScalar({1}.getLoc(), attr));\n", operandList, opVar); + } else if (attr.getAttrDefName().contains("TensorArm")) { + os << tabs + << formatv(" {0}.push_back(prepareConstant({1}.getLoc(), " + "llvm::cast(attr).getType(), attr));\n", + operandList, opVar); } else { PrintFatalError( loc, @@ -855,7 +863,8 @@ static void emitAttributeDeserialization(const Attribute &attr, << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " "TypeAttr::get(getType({2}[{3}++]))));\n", attrList, attrName, words, wordIndex); - } else if (attr.getAttrDefName() == "SPIRV_TensorArmAxisAttr") { + } else if (attr.getAttrDefName() == "SPIRV_BoolConstAttr" || + attr.getAttrDefName().contains("TensorArm")) { os << tabs << formatv("std::optional> c = " "getConstant({0}[{1}++]);\n",