diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td index db91e529dc623..483ce4348971b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td @@ -285,7 +285,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe TypeImpliesAccType<"input", F32, ["FP32"]>, TypeImpliesAccType<"input", F8E4M3FN, ["FP16"]>, TypeImpliesAccType<"input", F8E5M2, ["FP16"]>, - AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>]> { + AllElementTypesMatch<["input", "input_zp", "output", "output_zp"]>, + NHWCInputOutputShapeMatch<"input", "output">, + Pool2DPadValuesLessThanKernel<"pad", "kernel">]> { let summary = "Performs average pooling on the input."; let description = [{ @@ -308,9 +310,9 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe }]; let arguments = (ins - SPIRV_I32_1DTensorArmOfLength2Attr: $kernel, - SPIRV_I32_1DTensorArmOfLength2Attr: $stride, - SPIRV_I32_1DTensorArmOfLength4Attr: $pad, + SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel, + SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride, + SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad, SPIRV_TosaExtAccTypeAttr: $acc_type, SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input, SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_1DTensorArmOfLength1: $input_zp, @@ -337,6 +339,8 @@ def SPIRV_TosaAvgPool2DOp : SPIRV_TosaOpWithResult<"AvgPool2D", 1, [NoMemoryEffe return cast<::mlir::spirv::TensorArmType>(getInput().getType()); } }]; + + let hasVerifier = 1; } @@ -619,7 +623,9 @@ def SPIRV_TosaMatMulOp : SPIRV_TosaOpWithResult<"MatMul", 6, [NoMemoryEffect, def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure, - AllElementTypesMatch<["input", "output"]>]> { + AllElementTypesMatch<["input", "output"]>, + NHWCInputOutputShapeMatch<"input", "output">, + Pool2DPadValuesLessThanKernel<"pad", "kernel">]> { let summary = "Performs max pooling on the input."; let description = [{ @@ -640,9 +646,9 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure, }]; let arguments = (ins - SPIRV_I32_1DTensorArmOfLength2Attr: $kernel, - SPIRV_I32_1DTensorArmOfLength2Attr: $stride, - SPIRV_I32_1DTensorArmOfLength4Attr: $pad, + SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $kernel, + SPIRV_PositiveInt32_1DTensorArmOfLength2Attr: $stride, + SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr: $pad, SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode, SPIRV_I8OrI16OrF16OrF32OrBF16OrFP8_TensorArm4D: $input ); @@ -665,6 +671,8 @@ def SPIRV_TosaMaxPool2DOp : SPIRV_TosaOpWithResult<"MaxPool2D", 7, [Pure, return cast<::mlir::spirv::TensorArmType>(getInput().getType()); } }]; + + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td index 5704911d7f53d..7064930c5864a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td @@ -115,6 +115,14 @@ def SPIRV_I32_1DTensorArmOfLength3Attr : ConfinedAttr def SPIRV_I32_1DTensorArmOfLength4Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; def SPIRV_I32_1DTensorArmOfLength5Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; def SPIRV_I32_1DTensorArmOfLength6Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType]>; +class IntElementsAttrAllValuesAtLeast : AttrConstraint< + CPred<"::llvm::all_of(::llvm::cast<::mlir::DenseElementsAttr>($_self).getValues<::llvm::APInt>(), " + "[](const ::llvm::APInt &value) { return value.getSExtValue() >= " # + minValue # "; })">, + "all values must be >= " # minValue>; + +def SPIRV_PositiveInt32_1DTensorArmOfLength2Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<1>]>; +def SPIRV_NonNegativeInt32_1DTensorArmOfLength4Attr : ConfinedAttr, [SPIRV_DenseElementAttrsWithTensorArmType, IntElementsAttrAllValuesAtLeast<0>]>; class Is1DTensorArmAttrOfLength allowedLengths> : AttrConstraint(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>, @@ -217,6 +225,29 @@ class ValuesIndicesShapesMatch: SameDimsOrDynamicPred ]>>; +// The tensor shapes are [N,H,W,C] where N,H,W,C are the dimension values. +class NHWCInputOutputShapeMatch: + PredOpTrait<"shapes of " # input # " and " # output # + " must satisfy [N,*,*,C] and [N,*,*,C]", + And<[ + SameDimsOrDynamicPred, + SameDimsOrDynamicPred + ]>>; + +class FetchNthIntElementsAttr : + StrFunc<"get" # snakeCaseToCamelCase.ret # "().getValues()[" # idx # "].getSExtValue()">; + +class ElementsAttrValueLessThan : + CPred.result # " < " # FetchNthIntElementsAttr.result>; + +class Pool2DPadValuesLessThanKernel : + PredOpTrait<"op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x", + And<[ElementsAttrValueLessThan, + ElementsAttrValueLessThan, + ElementsAttrValueLessThan, + ElementsAttrValueLessThan] + >>; + class TableSizeConstraint: PredOpTrait<"table must have size " # size # " if " # input # " has element type " # type.summary, Implies, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]> diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp index a0591ee31acf8..a2cc8be54e4f3 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTosaOps.cpp @@ -54,6 +54,75 @@ void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, // SPIRV Tosa Custom verifiers //===----------------------------------------------------------------------===// +namespace { + +int64_t getIntValue(DenseIntElementsAttr attr, size_t idx) { + return attr.getValues()[idx].getSExtValue(); +} + +LogicalResult verifyPool2DOutputDim(Operation *op, int64_t inputSize, + int64_t outputSize, int64_t kernelSize, + int64_t strideSize, int64_t padBefore, + int64_t padAfter, StringRef dimName, + StringRef dimAxis, StringRef padBeforeName, + StringRef padAfterName) { + if (ShapedType::isDynamic(inputSize)) + return success(); + + const int64_t numerator = inputSize + padBefore + padAfter - kernelSize; + if (numerator % strideSize != 0) + return op->emitOpError("expected input_") + << dimName << " + pad_" << padBeforeName << " + pad_" << padAfterName + << " - kernel_" << dimAxis << " to be wholly divisible by stride_" + << dimAxis << ", got (" << inputSize << " + " << padBefore << " + " + << padAfter << " - " << kernelSize << ") / " << strideSize; + + const int64_t calculatedOutput = numerator / strideSize + 1; + if (!ShapedType::isDynamic(outputSize) && outputSize != calculatedOutput) + return op->emitOpError("failed to verify that shapes of input and output " + "must satisfy [N,IH,IW,C] and [N,OH,OW,C], with " + "OH = ((IH + pad_top + pad_bottom - kernel_y) / " + "stride_y) + 1 and OW = ((IW + pad_left + " + "pad_right - kernel_x) / stride_x) + 1"); + + return success(); +} + +LogicalResult verifyPool2DOp(Operation *op, DenseIntElementsAttr kernel, + DenseIntElementsAttr stride, + DenseIntElementsAttr pad, TensorArmType inputType, + TensorArmType outputType) { + + if (!inputType.hasRank() || !outputType.hasRank()) + return success(); + + if (failed(verifyPool2DOutputDim( + op, inputType.getDimSize(1), outputType.getDimSize(1), + getIntValue(kernel, 0), getIntValue(stride, 0), getIntValue(pad, 0), + getIntValue(pad, 1), "height", "y", "top", "bottom"))) + return failure(); + + if (failed(verifyPool2DOutputDim( + op, inputType.getDimSize(2), outputType.getDimSize(2), + getIntValue(kernel, 1), getIntValue(stride, 1), getIntValue(pad, 2), + getIntValue(pad, 3), "width", "x", "left", "right"))) + return failure(); + + return success(); +} + +} // namespace + +LogicalResult TosaAvgPool2DOp::verify() { + return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(), + getInputType(), getResultType()); +} + +LogicalResult TosaMaxPool2DOp::verify() { + return verifyPool2DOp(getOperation(), getKernel(), getStride(), getPad(), + getInputType(), getResultType()); +} + LogicalResult TosaSelectOp::verify() { TensorArmType condType = getConditionType(); TensorArmType trueValType = getTrueValueType(); diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir index 78ae3b4586004..0489629b98f2f 100644 --- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir +++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir @@ -91,6 +91,38 @@ spirv.ARM.Graph @avgpool2d_accumulator_must_be_FP16_for_f8e5m2_element_type(%arg spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x2x2xf8E5M2> } +spirv.ARM.Graph @avgpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32768x1xi8>) { + %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8> + %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}} + %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = , %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x2xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x2x32768x1xi8> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<2x2x32768x1xi8> +} + +spirv.ARM.Graph @avgpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) { + %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8> + %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}} + %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], acc_type = , %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x1x1x1xi8> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x1x1x1xi8> +} + +spirv.ARM.Graph @avgpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) { + %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8> + %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}} + %6 = spirv.Tosa.AvgPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], acc_type = , %arg0, %4, %5 : !spirv.arm.tensor<1x4x4x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x1x1xi8> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x1x1xi8> +} + +spirv.ARM.Graph @avgpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32769x1xi8>) { + %4 = spirv.Constant dense<125> : !spirv.arm.tensor<1xi8> + %5 = spirv.Constant dense<-90> : !spirv.arm.tensor<1xi8> + // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}} + %6 = spirv.Tosa.AvgPool2D kernel = [3, 3], stride = [1, 2], pad = [0, 1, 0, 0], acc_type = , %arg0, %4, %5 : !spirv.arm.tensor<1x3x65537x1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x2x32769x1xi8> + spirv.ARM.GraphOutputs %6 : !spirv.arm.tensor<1x2x32769x1xi8> +} + //===----------------------------------------------------------------------===// // spirv.TOSA.Conv2D //===----------------------------------------------------------------------===// @@ -537,6 +569,30 @@ spirv.ARM.Graph @maxpool2d_input_output_different_element_types(%arg0: !spirv.ar spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32769x1xi16> } +spirv.ARM.Graph @maxpool2d_input_output_batch_or_channel_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x2xi8>) -> (!spirv.arm.tensor<2x2x32769x1xi8>) { + // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,*,*,C] and [N,*,*,C]}} + %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = , %arg0 : !spirv.arm.tensor<1x3x65537x2xi8> -> !spirv.arm.tensor<2x2x32769x1xi8> + spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<2x2x32769x1xi8> +} + +spirv.ARM.Graph @maxpool2d_input_shape_not_wholly_divisible_by_stride(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x1x1x1xi8>) { + // expected-error @+1 {{op expected input_height + pad_top + pad_bottom - kernel_y to be wholly divisible by stride_y}} + %4 = spirv.Tosa.MaxPool2D kernel = [3, 3], stride = [2, 2], pad = [0, 0, 0, 0], nan_mode = , %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x1x1x1xi8> + spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x1x1x1xi8> +} + +spirv.ARM.Graph @maxpool2d_pad_values_must_be_less_than_kernel(%arg0: !spirv.arm.tensor<1x4x4x1xi8>) -> (!spirv.arm.tensor<1x2x1x1xi8>) { + // expected-error @+1 {{op pad values must satisfy pad_top/pad_bottom < kernel_y and pad_left/pad_right < kernel_x}} + %4 = spirv.Tosa.MaxPool2D kernel = [2, 3], stride = [1, 2], pad = [2, 0, 0, 0], nan_mode = , %arg0 : !spirv.arm.tensor<1x4x4x1xi8> -> !spirv.arm.tensor<1x2x1x1xi8> + spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x1x1xi8> +} + +spirv.ARM.Graph @maxpool2d_input_output_height_width_mismatch(%arg0: !spirv.arm.tensor<1x3x65537x1xi8>) -> (!spirv.arm.tensor<1x2x32768x1xi8>) { + // expected-error @+1 {{op failed to verify that shapes of input and output must satisfy [N,IH,IW,C] and [N,OH,OW,C], with OH = ((IH + pad_top + pad_bottom - kernel_y) / stride_y) + 1 and OW = ((IW + pad_left + pad_right - kernel_x) / stride_x) + 1}} + %4 = spirv.Tosa.MaxPool2D kernel = [3, 2], stride = [1, 2], pad = [1, 0, 0, 1], nan_mode = , %arg0 : !spirv.arm.tensor<1x3x65537x1xi8> -> !spirv.arm.tensor<1x2x32768x1xi8> + spirv.ARM.GraphOutputs %4 : !spirv.arm.tensor<1x2x32768x1xi8> +} + //===----------------------------------------------------------------------===// // spirv.TOSA.TransposeConv2D //===----------------------------------------------------------------------===//