diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..4d359949096fb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -512,6 +512,18 @@ extensionComplianceMap = {
      {{Extension::bf16},
       {{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
         SpecificationVersion::V_1_0}}}}},
+    {"tosa.conv2d_block_scaled",
+     {{{Extension::mxfp_conv},
+       {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv3d",
      {{{Extension::int4},
       {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..421abc939b2e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
 def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
 def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_MXFP_CONV : I32EnumAttrCase<"mxfp_conv", 14>;
 
 def Tosa_ExtensionAttr
@@ -281,16 +283,16 @@ def Tosa_ExtensionAttr
       Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
       Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
       Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
-      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+      Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
     ]> {
   let extraClassDeclaration = [{
-    static llvm::SmallVector<Extension, 13> getAllValues() {
+    static llvm::SmallVector<Extension, 14> getAllValues() {
       return {
           Extension::int16, Extension::int4, Extension::bf16,
          Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
          Extension::variable, Extension::controlflow, Extension::doubleround,
          Extension::inexactround, Extension::dynamic, Extension::mxfp,
-          Extension::int64
+          Extension::int64, Extension::mxfp_conv
      };
    }
  }];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+  let summary = "Performs a two-dimensional convolution using block-scaled tensors.";
+
+  let description = [{
+    Performs a 2D convolution over the given input data and scales, using the
+    weight data and scales. Implementations may choose to skip calculation of
+    multiplies in the padding area.
+  }];
+
+  let arguments = (ins
+    Tosa_MXFPDataTensor4D:$input_data,
+    Tosa_MXFPScaleTensor4D:$input_scale,
+    Tosa_MXFPDataTensor4D:$weight_data,
+    Tosa_MXFPScaleTensor4D:$weight_scale,
+    Tosa_Tensor1D:$bias,
+    Rank4TosaShape:$pad,
+    Rank2TosaShape:$stride,
+    Rank2TosaShape:$dilation,
+    Tosa_BlockSizeAttr:$block_size
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  list<Availability> availability = [
+    Profile<[Tosa_PRO_FP]>,
+    Extension<[Tosa_EXT_MXFP_CONV]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
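Before the C++ plumbing below, a minimal sketch of the new op in MLIR assembly may help orient the reader; it is adapted from the ops.mlir and tosa-validation tests added at the end of this patch (the function and value names are illustrative). With IC = 64 input channels and BLOCK_SIZE_32, each scale tensor carries IC/32 = 2 scale elements per position, the weight layout is [OC, KH, KW, IC], and the bias has OC (or 1, broadcast) channels:

func.func @example_conv2d_block_scaled(%in: tensor<1x4x4x64xf4E2M1FN>, %in_s: tensor<1x4x4x2xf8E8M0FNU>, %w: tensor<8x1x1x64xf4E2M1FN>, %w_s: tensor<8x1x1x2xf8E8M0FNU>, %bias: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
  // 1x1 kernel, unit stride/dilation, no padding: output spatial dims match the input (4x4).
  %out = tosa.conv2d_block_scaled %in, %in_s, %w, %w_s, %bias, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
  return %out : tensor<1x4x4x8xf32>
}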
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..5c77bd701e416 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
     case Extension::fp8e5m2:
     case Extension::fft:
     case Extension::mxfp:
+    case Extension::mxfp_conv:
       return {Profile::pro_fp};
     case Extension::variable:
     case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..0468ca29e10ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
+def Tosa_IndexTensor1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
 def Tosa_IndexTensor2D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
   TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
 ]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+  TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+  TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
 def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
   TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
   TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..2e0a0d85d7dbe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
     return TosaSpecificationVersion(1, 0);
   case Extension::mxfp:
   case Extension::int64:
+  case Extension::mxfp_conv:
     return TosaSpecificationVersion(1, 1);
   case Extension::none:
     return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..7849ff58d5318 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling(parser, result);
+}
+
+void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
   return type.getIntOrFloatBitWidth();
 }
 
+// Update dim size if current dim is dynamic, otherwise raise an error if sizes
+// do not match.
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+                                    const int64_t newDim,
+                                    const StringRef operandName,
+                                    const StringRef dimName) {
+  if (ShapedType::isDynamic(currDim)) {
+    currDim = newDim;
+    return success();
+  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+    return op->emitOpError("expected ")
+           << dimName << " of " << operandName << " to match size " << currDim
+           << ", got " << newDim;
+  }
+  return success();
+}
+
+LogicalResult verifyConvOutputSize(
+    Operation *op, const int64_t inputSize, const int64_t kernelSize,
+    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+    const llvm::StringRef padAfterName) {
+  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
+    return success();
+
+  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+      stride);
+  if (!calculatedOutSizeMinusOne.has_value())
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+           << ") / " << stride;
+
+  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+    return op->emitOpError("calculated output ")
+           << dimName << " did not match expected: "
+           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
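As a worked instance of the ERROR_IF rule factored out above (a hedged sketch; the shapes are chosen for illustration and are not taken from the patch): with I = 5, pad = 0 + 0, K = 1, d = 1, and s = 2, idiv_check(5 - 1 + 0 + 0 - (1 - 1) * 1, 2) = 2 is exact, so the only output spatial size the verifier accepts is 2 + 1 = 3 in each dimension; with s = 1 instead, the same operands would require a 5x5 output:

// OH = OW = (5 - 1 + 0 + 0 - 0) / 2 + 1 = 3; IC = 32 needs a single scale block per position.
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<[2, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%out = tosa.conv2d_block_scaled %data, %data_s, %w, %w_s, %bias, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x5x5x32xf8E4M3FN>, tensor<1x5x5x1xf8E8M0FNU>, tensor<8x1x1x32xf8E4M3FN>, tensor<8x1x1x1xf8E8M0FNU>, tensor<8xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x3x3x8xf32>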
@@ -791,53 +849,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
       llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [&op](const int64_t inputSize, const int64_t kernelSize,
-              const int64_t outputSize, const int64_t padBefore,
-              const int64_t padAfter, const int64_t stride,
-              const int64_t dilation, const llvm::StringRef dimName,
-              const llvm::StringRef dimAxis,
-              const llvm::StringRef padBeforeName,
-              const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return op.emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return op.emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
     // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -845,14 +866,14 @@
 
     // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
     if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(0),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(0),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(1),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "width", "x", "left", "right")))
         return failure();
@@ -860,20 +881,20 @@
 
     // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
     if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(1), weightType.getDimSize(1),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(1), weightType.getDimSize(1),
               outputType.getDimSize(1), padding[0], padding[1], strides[0],
               dilations[0], "depth", "d", "front", "back")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(2), weightType.getDimSize(2),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(2), weightType.getDimSize(2),
               outputType.getDimSize(2), padding[2], padding[3], strides[1],
               dilations[1], "height", "y", "top", "bottom")))
         return failure();
 
-      if (failed(verifyOutputSize(
-              inputType.getDimSize(3), weightType.getDimSize(3),
+      if (failed(verifyConvOutputSize(
+              op, inputType.getDimSize(3), weightType.getDimSize(3),
               outputType.getDimSize(3), padding[4], padding[5], strides[2],
               dilations[2], "width", "x", "left", "right")))
         return failure();
@@ -1954,20 +1975,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
                                  "B_data")))
     return failure();
 
-  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
-                                   const StringRef operandName,
-                                   const StringRef dimName) -> LogicalResult {
-    if (ShapedType::isDynamic(currDim)) {
-      currDim = newDim;
-      return success();
-    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
-      return emitOpError("expected ")
-             << dimName << " of " << operandName << " to match size " << currDim
-             << ", got " << newDim;
-    }
-    return success();
-  };
-
   // Verify input shape compatibility
   int64_t N = ShapedType::kDynamic;
   int64_t D = ShapedType::kDynamic;
@@ -1985,32 +1992,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
 
   const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
   if (aScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
-                                     "height")))
+    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+                                     "a_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+                                     "a_scale", "height")))
       return failure();
     multiplesOfC = aScaleShape.getDimSize(2);
   }
 
   const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
   if (bDataShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
-                                     "channels")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+                                     "b_data", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+                                     "b_data", "channels")))
       return failure();
     W = bDataShape.getDimSize(1);
   }
 
   const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
   if (bScaleShape.hasRank()) {
-    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
-                                     "batch")) ||
-        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
-                                     "width")) ||
-        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
-                                     "b_scale", "C/block_size")))
+    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+                                     "b_scale", "batch")) ||
+        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+                                     "b_scale", "width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+                                     bScaleShape.getDimSize(2), "b_scale",
+                                     "C/block_size")))
       return failure();
   }
@@ -3485,6 +3493,232 @@ LogicalResult Conv2DOp::verify() {
   return success();
 }
 
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t> outShape(4, ShapedType::kDynamic);
+
+  int64_t inputWidth = ShapedType::kDynamic;
+  int64_t inputHeight = ShapedType::kDynamic;
+  int64_t weightWidth = ShapedType::kDynamic;
+  int64_t weightHeight = ShapedType::kDynamic;
+
+  // Input shape describes input width/height and batch.
+  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    outShape[0] = inputDataShape.getDimSize(0);
+    inputHeight = inputDataShape.getDimSize(1);
+    inputWidth = inputDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0])
+                      ? inputScaleShape.getDimSize(0)
+                      : outShape[0];
+    inputHeight = ShapedType::isDynamic(inputHeight)
+                      ? inputScaleShape.getDimSize(1)
+                      : inputHeight;
+    inputWidth = ShapedType::isDynamic(inputWidth)
+                     ? inputScaleShape.getDimSize(2)
+                     : inputWidth;
+  }
+
+  // Weight shapes describe the filter width/height and the output channels.
+  const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    outShape[3] = weightDataShape.getDimSize(0);
+    weightHeight = weightDataShape.getDimSize(1);
+    weightWidth = weightDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    outShape[3] = ShapedType::isDynamic(outShape[3])
+                      ? weightScaleShape.getDimSize(0)
+                      : outShape[3];
+    weightHeight = ShapedType::isDynamic(weightHeight)
+                       ? weightScaleShape.getDimSize(1)
+                       : weightHeight;
+    weightWidth = ShapedType::isDynamic(weightWidth)
+                      ? weightScaleShape.getDimSize(2)
+                      : weightWidth;
+  }
+
+  // Bias shape can describe the output channels.
+  const ShapeAdaptor biasShape(adaptor.getBias().getType());
+  if (biasShape.hasRank()) {
+    const int64_t biasSize = biasShape.getDimSize(0);
+    // Bias of size 1 may be broadcast.
+    if (biasSize != 1) {
+      outShape[3] =
+          ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
+    }
+  }
+
+  SmallVector<int64_t> padValues;
+  SmallVector<int64_t> strideValues;
+  SmallVector<int64_t> dilationValues;
+  if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
+      !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                 strideValues) ||
+      !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+                                 dilationValues)) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+    return success();
+  }
+
+  if (ShapedType::isStatic(inputHeight) &&
+      ShapedType::isStatic(weightHeight)) {
+    const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
+    const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
+    const int64_t unstridedResult = inputSize - filterSize + 1;
+    outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
+  }
+
+  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
+    const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
+    const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
+    const int64_t unstridedResult = inputSize - filterSize + 1;
+    outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+LogicalResult Conv2DBlockScaledOp::verify() {
+  if (failed(verifySameElementTypes(*this, getInputData().getType(),
+                                    getWeightData().getType())) ||
+      failed(verifySameElementTypes(*this, getInputScale().getType(),
+                                    getWeightScale().getType())) ||
+      failed(verifySameElementTypes(*this, getBias().getType(),
+                                    getOutput().getType())))
+    return failure();
+
+  // Verify input shape compatibility
+  int64_t N = ShapedType::kDynamic;
+  int64_t IH = ShapedType::kDynamic;
+  int64_t IW = ShapedType::kDynamic;
+  int64_t IC = ShapedType::kDynamic;
+  int64_t multiplesOfIC = ShapedType::kDynamic;
+  int64_t OC = ShapedType::kDynamic;
+  int64_t KH = ShapedType::kDynamic;
+  int64_t KW = ShapedType::kDynamic;
+
+  const ShapeAdaptor inputDataShape(getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    N = inputDataShape.getDimSize(0);
+    IH = inputDataShape.getDimSize(1);
+    IW = inputDataShape.getDimSize(2);
+    IC = inputDataShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor inputScaleShape(getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
+                                     "input_scale", "batch size")) ||
+        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
+                                     "input_scale", "input height")) ||
+        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
+                                     "input_scale", "input width")))
+      return failure();
+    multiplesOfIC = inputScaleShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor weightDataShape(getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    OC = weightDataShape.getDimSize(0);
+    KH = weightDataShape.getDimSize(1);
+    KW = weightDataShape.getDimSize(2);
+    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
+                                     "weight_data", "input channels")))
+      return failure();
+  }
+
+  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
+                                     "weight_scale", "output channels")) ||
+        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
+                                     "weight_scale", "kernel height")) ||
+        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
+                                     "weight_scale", "kernel width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
+                                     weightScaleShape.getDimSize(3),
+                                     "weight_scale", "input channel blocks")))
+      return failure();
+  }
+
+  // Verify IC is a multiple of block size
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
+    return emitOpError("expect IC to be a multiple of block size, got IC=")
+           << IC << ", block_size=" << blockSize;
+
+  // Verify multiplesOfIC is IC / block size
+  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
+      multiplesOfIC != IC / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 3 to equal IC/block_size (")
+           << IC << "/" << blockSize << ")"
+           << ", got " << multiplesOfIC;
+
+  // Verify pad/stride/dilation values
+  SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
+    if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
+      return emitOpError("expect all padding values to be >= 0, got ")
+             << padValues;
+  }
+
+  SmallVector<int64_t> strideValues;
+  if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
+    if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
+      return emitOpError("expect all stride values to be >= 1, got ")
+             << strideValues;
+  }
+
+  SmallVector<int64_t> dilationValues;
+  if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
+                                dilationValues)) {
+    if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
+      return emitOpError("expect all dilation values to be >= 1, got ")
+             << dilationValues;
+  }
+
+  // Verify output shape compatibility
+  const ShapeAdaptor outputShape(getOutput().getType());
+  if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
+      outputShape.hasRank()) {
+    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
+                                    padValues[0], padValues[1], strideValues[0],
+                                    dilationValues[0], "height", "y", "top",
+                                    "bottom")) ||
+        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
+                                    padValues[2], padValues[3], strideValues[1],
+                                    dilationValues[1], "width", "x", "left",
+                                    "right")))
+      return failure();
+  }
+
+  // Verify bias
+  const ShapeAdaptor biasShape(getBias().getType());
+  if (biasShape.hasRank() && outputShape.hasRank()) {
+    const int64_t biasChannels = biasShape.getDimSize(0);
+    const int64_t outputChannels =
+        outputShape.getDimSize(outputShape.getRank() - 1);
+    if (biasChannels == ShapedType::kDynamic ||
+        outputChannels == ShapedType::kDynamic)
+      // Skip the following checks if biasChannels or outputChannels is dynamic.
+      return success();
+
+    if (biasChannels != outputChannels && biasChannels != 1)
+      return emitOpError(
+                 "bias channels expected to be equal to output channels (")
+             << outputChannels << ") or 1, got " << biasChannels;
+  }
+
+  return success();
+}
+
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     Conv3DOp::Adaptor adaptor,
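A note on the bias rule enforced by the verifier above: bias channels must equal the output channel count or be exactly 1, in which case the single value is broadcast across all OC output channels (a tensor<6xf32> bias against OC = 8 is rejected, as a verifier.mlir test below exercises). A hedged MLIR sketch of the broadcast form, mirroring the static ops.mlir test added later in this patch (value names are illustrative):

// The single bias element is broadcast across all OC = 8 output channels.
%out = tosa.conv2d_block_scaled %in, %in_s, %w, %w_s, %bias1, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>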
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index ddd9c70402fdc..4b3beb87e829a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -111,6 +111,18 @@ ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
   return populateProfileInfoConv(op);
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::Conv2DBlockScaledOp op) {
+  addValue(op.getInputData());
+  addValue(op.getInputScale());
+  addValue(op.getWeightData());
+  addValue(op.getWeightScale());
+  addValue(op.getBias());
+  addValue(op.getOutput());
+  return success();
+}
+
 template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
   addValue(op.getInput1());
@@ -239,6 +251,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
     POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
     POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
     POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
+    POPULATE_PROFILE_INFO_CUSTOM(Conv2DBlockScaled)
     POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
     POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
     POPULATE_PROFILE_INFO_CUSTOM(Mul)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b54ed5585d72d..b06792ea695c1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -354,6 +354,55 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return success();
   }
 
+  LogicalResult levelCheckConv2DBlockScaled(Operation *op) {
+    auto convOp = dyn_cast<tosa::Conv2DBlockScaledOp>(op);
+    if (!convOp)
+      return success();
+
+    DenseIntElementsAttr padding;
+    if (matchPattern(convOp.getPad(), m_Constant(&padding))) {
+      const SmallVector<int64_t> padValues = convertFromIntAttr(padding, 4);
+      for (const auto p : padValues)
+        if (failed(levelCheckKernel(op, p, "pad <= MAX_KERNEL")))
+          return failure();
+    }
+
+    DenseIntElementsAttr stride;
+    if (matchPattern(convOp.getStride(), m_Constant(&stride))) {
+      const SmallVector<int64_t> strideValues = convertFromIntAttr(stride, 4);
+      for (const auto s : strideValues)
+        if (failed(levelCheckKernel(op, s, "stride <= MAX_KERNEL")))
+          return failure();
+    }
+
+    DenseIntElementsAttr dilation;
+    if (matchPattern(convOp.getDilation(), m_Constant(&dilation))) {
+      const SmallVector<int64_t> dilationValues =
+          convertFromIntAttr(dilation, 4);
+
+      int64_t KH = ShapedType::kDynamic;
+      int64_t KW = ShapedType::kDynamic;
+      const ShapeAdaptor weightDataShape(convOp.getWeightData().getType());
+      KH = weightDataShape.getDimSize(1);
+      KW = weightDataShape.getDimSize(2);
+      const ShapeAdaptor weightScaleShape(convOp.getWeightScale().getType());
+      KH = ShapedType::isDynamic(KH) ? weightScaleShape.getDimSize(1) : KH;
+      KW = ShapedType::isDynamic(KW) ? weightScaleShape.getDimSize(2) : KW;
+
+      if (!ShapedType::isDynamic(KH) &&
+          failed(levelCheckKernel(op, dilationValues[0] * KH,
+                                  "dilation_y * KH <= MAX_KERNEL")))
+        return failure();
+
+      if (!ShapedType::isDynamic(KW) &&
+          failed(levelCheckKernel(op, dilationValues[1] * KW,
+                                  "dilation_x * KW <= MAX_KERNEL")))
+        return failure();
+    }
+
+    return success();
+  }
+
   // FFT op: level check H, W in input shape [N,H,W]
   template <typename T>
   LogicalResult levelCheckFFT(Operation *op) {
@@ -654,6 +703,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
     // Tensor Operators
     CHECK_SIZES(AvgPool2d);
     CHECK_SIZES(Conv2D);
+    CHECK_SIZES(Conv2DBlockScaled);
     CHECK_SIZES(Conv3D);
     CHECK_SIZES(DepthwiseConv2D);
     CHECK_SIZES(TransposeConv2D);
@@ -722,7 +772,6 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   if (failed(levelCheckRanksAndSizes(op)))
     return failure();
 
-  // additional level checks from spec 0.70
   if (failed(levelCheckPool<tosa::AvgPool2dOp>(op)) ||
       failed(levelCheckConv<tosa::Conv2DOp>(op)) ||
       failed(levelCheckConv<tosa::Conv3DOp>(op)) ||
@@ -730,7 +779,8 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
-      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op))) {
+      failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
+      failed(levelCheckConv2DBlockScaled(op))) {
     return failure();
   }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 3d24928487ed2..89d249e42d870 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
 
 
 func.func @test_cast(%arg0: tensor) -> tensor<5xi32> {
@@ -2067,3 +2067,14 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, ten
   %0, %1 = tosa.rfft2d %arg0 : (tensor<13x8x16xbf16>) -> (tensor<13x8x9xbf16>, tensor<13x8x9xbf16>)
   return %0, %1 : tensor<13x8x9xbf16>, tensor<13x8x9xbf16>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<*xf16>) -> tensor<*xf16> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op illegal: operation operand/result data types did not align with any profile or extension, got (fp4e2m1,fp8e8m0,fp4e2m1,fp8e8m0,f16,f16), did you mean (fp4e2m1,fp8e8m0,fp4e2m1,fp8e8m0,f32,f32)? Otherwise, please refer to the 'supported data types' for 'tosa.conv2d_block_scaled' in the specification.}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf16>
+  return %3 : tensor<*xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 177192ba5440d..0158969eab6e4 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -564,7 +564,7 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
 
 // -----
 
-func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
   // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
   %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
   return %0 : tensor<4x32xf32>
@@ -572,8 +572,19 @@ func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf8E5M2>, %arg1:
 
 // -----
 
-func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>) {
   // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [mxfp] but not enabled in target}}
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op illegal: requires [mxfp_conv] but not enabled in target}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a7087647e542b..7a94f9db82d94 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1662,3 +1662,74 @@ func.func @test_cast_to_block_scaled_invalid_rank(%arg0: tensor<1x2x3x4x5x6x7x32
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x2x3x4x5x6x7x32xf32>) -> (tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<1x2x3x4x5x6x7x32xf6E2M3FN>, tensor<1x2x3x4x5x6x7x1xf8E8M0FNU>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_size(%arg0: tensor<67108864x4x4x64xf4E2M1FN>, %arg1: tensor<67108864x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<67108864x4x4x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<67108864x4x4x64xf4E2M1FN>, tensor<67108864x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<67108864x4x4x8xf32>
+  return %0 : tensor<67108864x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_dilation_y(%arg0: tensor<1x8191x8191x32xf8E4M3FN>, %arg1: tensor<1x8191x8191x1xf8E8M0FNU>, %arg2: tensor<16x1025x1024x32xf8E4M3FN>, %arg3: tensor<16x1025x1024x1xf8E8M0FNU>, %arg4: tensor<16xf32>) -> tensor<1x9x7178x16xf32> {
+  %pad = tosa.const_shape {values = dense<[10, 0, 10, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[8, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: dilation_y * KH <= MAX_KERNEL}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} :
+      (tensor<1x8191x8191x32xf8E4M3FN>, tensor<1x8191x8191x1xf8E8M0FNU>, tensor<16x1025x1024x32xf8E4M3FN>, tensor<16x1025x1024x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x9x7178x16xf32>
+  return %0 : tensor<1x9x7178x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_dilation_x(%arg0: tensor<1x8191x8191x32xf8E4M3FN>, %arg1: tensor<1x8191x8191x1xf8E8M0FNU>, %arg2: tensor<16x1024x1025x32xf8E4M3FN>, %arg3: tensor<16x1024x1025x1xf8E8M0FNU>, %arg4: tensor<16xf32>) -> tensor<1x7178x9x16xf32> {
+  %pad = tosa.const_shape {values = dense<[10, 0, 10, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: dilation_x * KW <= MAX_KERNEL}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} :
+      (tensor<1x8191x8191x32xf8E4M3FN>, tensor<1x8191x8191x1xf8E8M0FNU>, tensor<16x1024x1025x32xf8E4M3FN>, tensor<16x1024x1025x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x7178x9x16xf32>
+  return %0 : tensor<1x7178x9x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_pad_top(%arg0: tensor<1x32x32x32xf8E4M3FN>, %arg1: tensor<1x32x32x1xf8E8M0FNU>, %arg2: tensor<16x2x2x32xf8E4M3FN>, %arg3: tensor<16x2x2x1xf8E8M0FNU>, %arg4: tensor<16xf32>) -> tensor<1x8224x31x16xf32> {
+  %pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: pad <= MAX_KERNEL}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+      (tensor<1x32x32x32xf8E4M3FN>, tensor<1x32x32x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x8224x31x16xf32>
+  return %0 : tensor<1x8224x31x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_pad_right(%arg0: tensor<1x32x32x32xf8E4M3FN>, %arg1: tensor<1x32x32x1xf8E8M0FNU>, %arg2: tensor<16x2x2x32xf8E4M3FN>, %arg3: tensor<16x2x2x1xf8E8M0FNU>, %arg4: tensor<16xf32>) -> tensor<1x31x8224x16xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: pad <= MAX_KERNEL}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+      (tensor<1x32x32x32xf8E4M3FN>, tensor<1x32x32x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x31x8224x16xf32>
+  return %0 : tensor<1x31x8224x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_stride_y(%arg0: tensor<1x8194x33x32xf8E4M3FN>, %arg1: tensor<1x8194x33x1xf8E8M0FNU>, %arg2: tensor<16x2x2x32xf8E4M3FN>, %arg3: tensor<16x2x2x1xf8E8M0FNU>, %arg4: tensor<16xf32>) -> tensor<1x2x32x16xf32> {
+  %pad = tosa.const_shape {values = dense<[1, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op failed level check: stride <= MAX_KERNEL}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} :
+      (tensor<1x8194x33x32xf8E4M3FN>, tensor<1x8194x33x1xf8E8M0FNU>, tensor<16x2x2x32xf8E4M3FN>, tensor<16x2x2x1xf8E8M0FNU>, tensor<16xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x2x32x16xf32>
+  return %0 : tensor<1x2x32x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 652447bd6056e..d33d5ebbe992b 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -1381,3 +1381,23 @@ func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
   %0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
   return %0 : tensor<2x!tosa.mxint8>
 }
+
+// -----
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic
+func.func @test_conv2d_block_scaled_dynamic(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 7de7b85bcaedf..15aad410c6f44 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
 
 // -----
 func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
@@ -334,15 +334,25 @@ func.func @test_matmul_t_block_scaled(%arg0: tensor<4x8x32xf6E3M2FN>, %arg1: ten
 }
 
 // -----
-func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
+func.func @test_cast_from_block_scaled(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
   // expected-error@+1 {{'tosa.cast_from_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
   %0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
   return %0 : tensor<4x32xf32>
 }
 
 // -----
-func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
+func.func @test_cast_to_block_scaled(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>) {
   // expected-error@+1 {{'tosa.cast_to_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+func.func @test_conv2d_block_scaled(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<*xf32>) -> tensor<*xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op illegal: requires [pro_fp] but not enabled in target}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %0, %1, %2 {block_size = BLOCK_SIZE_32} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 54556a0eb08e0..b74540f060cfe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1673,3 +1673,51 @@ func.func @test_cast_to_block_scaled_dynamic_scales(%arg0: tensor<4x?xf32>) -> (
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x?xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_static
+func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
+func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor, tensor<*xf8E8M0FNU>, tensor, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
+func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<1x4x4x8xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
+func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: -> tensor<?x?x?x?xf32>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index f6b1edc21ea5a..5725acbc740ec 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,mxfp,int64,mxfp_conv" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
 
 // -----
 
@@ -140,3 +140,14 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
   %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
   return %0 : tensor<2x52x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_conv2d_block_scaled
+func.func @test_conv2d_block_scaled(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index ea64d468f151e..cb7e0f5dd62b9 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -1230,3 +1230,135 @@ func.func @test_clamp_quantized(%arg0: tensor>) -> tensor>
   return %0 : tensor>
 }
+
+// -----
+
+func.func @test_conv2d_block_scaled_data_type_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf8E4M3FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f4E2M1FN' and 'f8E4M3FN'}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf8E4M3FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_bias_output_type_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf16>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect input and output to have same element type, got 'f16' and 'f32'}}
+  %3 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = BLOCK_SIZE_32} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %3 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_padding(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all padding values to be >= 0, got 0, 0, 0, -1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_stride(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all stride values to be >= 1, got 0, 1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_dilation(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect all dilation values to be >= 1, got 1, 0}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_input_width_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x5x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected input width of input_scale to match size 4, got 5}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x5x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_kernel_height_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x2x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected kernel height of weight_scale to match size 2, got 1}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x2x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_output_shape_indivisible(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expected input_width - 1 + pad_left + pad_right - (kernel_width - 1) * dilation_x to be wholly divisible by stride_x, got (4 - 1 + 0 + 0 - (1 - 1) * 1) / 2}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_output_shape_mismatch(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op calculated output width did not match expected: calculated=4, expected=5}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_ic(%arg0: tensor<1x4x4x63xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x63xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect IC to be a multiple of block size, got IC=63, block_size=32}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x63xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x63xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_ic_multiple(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x3xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x3xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x5x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op expect scale operands dimension 2 to equal IC/block_size (64/32), got 3}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x3xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x3xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x5x8xf32>
+  return %0 : tensor<1x4x5x8xf32>
+}
+
+// -----
+
+func.func @test_conv2d_block_scaled_invalid_bias_size(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<6xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<1x4x4x8xf32> {
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.conv2d_block_scaled' op bias channels expected to be equal to output channels (8) or 1, got 6}}
+  %0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<6xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
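
Reviewer note: the negative tests above fall into two families of checks: the standard convolution output-extent rule (padding, stride, and dilation validity plus the divisibility requirement) and the block-scaling constraints (IC divisible by block_size, scale operands carrying IC/block_size elements). A minimal C++ sketch of the output-extent rule follows; the function name and signature are illustrative only, not the verifier code in TosaOps.cpp:

  #include <cstdint>
  #include <optional>

  // Expected output extent for one spatial dimension of
  // tosa.conv2d_block_scaled. Returns std::nullopt when the padded span is
  // not wholly divisible by the stride (the "output_shape_indivisible" case).
  std::optional<int64_t> expectedOutputExtent(int64_t in, int64_t kernel,
                                              int64_t padBefore, int64_t padAfter,
                                              int64_t stride, int64_t dilation) {
    int64_t span = in - 1 + padBefore + padAfter - (kernel - 1) * dilation;
    if (span < 0 || span % stride != 0)
      return std::nullopt; // verifier rejects: not wholly divisible by stride
    return span / stride + 1;
  }

Plugging in the width values from @test_conv2d_block_scaled_output_shape_indivisible (in=4, kernel=1, pad=0+0, dilation=1, stride=2) gives a span of 3, which is not divisible by 2, matching the expected error; with stride=1 the rule yields 4, which is why the valid cases produce tensor<1x4x4x8xf32>.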