diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index f5536927dc251..d3f12c34421b0 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder< (ins "::mlir::Type":$outputType, "::mlir::Value":$input, "::mlir::Value":$weight, "::mlir::Value":$bias, "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$dilation), + "::mlir::DenseI64ArrayAttr":$dilation, + "::mlir::TypeAttr":$acc_type), [{ buildConvOpWithQuantInfo($_builder, $_state, outputType, input, weight, bias, - pad, stride, dilation); + pad, stride, dilation, acc_type); }]>; // Handles tosa.transpose_conv2d which has an outpad and output shape attribute. @@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder< "::mlir::Value":$weight, "mlir::Value":$bias, "::mlir::DenseI64ArrayAttr":$outpad, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$outputShape), + "::mlir::DenseI64ArrayAttr":$outputShape, + "::mlir::TypeAttr":$acc_type), [{ buildTransConvOpWithQuantInfo($_builder, $_state, outputType, input, weight, bias, outpad, stride, - outputShape); + outputShape, acc_type); }]>; // The tosa.fully_connected op has its own builder as it does not have diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index e3c725801d162..baab3fef662e1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> { // Accumulator types. //===----------------------------------------------------------------------===// -def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; +def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>; //===----------------------------------------------------------------------===// // Operator: avg_pool2d @@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, + TypeAttrOf:$acc_type, OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, + TypeAttrOf:$acc_type, OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, + TypeAttrOf:$acc_type, OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, + TypeAttrOf:$acc_type, OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -357,6 +361,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { ); let builders = [Tosa_TransConvOpQuantInfoBuilder]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 631d3c48f2df0..a0cd2194b39e3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -210,7 +210,12 @@ template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = llvm::dyn_cast(op.getInput().getType()); - auto weightType = llvm::dyn_cast(op.getWeight().getType()); + + RankedTensorType weightType; + if constexpr (std::is_same_v) + weightType = llvm::dyn_cast(op.getFilter().getType()); + else + weightType = llvm::dyn_cast(op.getWeight().getType()); // Must be ranked tensor types if (!inputType) { @@ -218,7 +223,13 @@ static LogicalResult verifyConvOp(T op) { return failure(); } if (!weightType) { - op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight(); + if constexpr (std::is_same_v) { + op.emitOpError("expect a ranked tensor for filter, got ") + << op.getFilter(); + } else { + op.emitOpError("expect a ranked tensor for weight, got ") + << op.getWeight(); + } return failure(); } @@ -271,6 +282,38 @@ LogicalResult tosa::ConstOp::verify() { return success(); } +template +static LogicalResult verifyConvOpModes(T op) { + auto inputEType = + llvm::cast(op.getInput().getType()).getElementType(); + + if (auto quantType = + llvm::dyn_cast(inputEType)) + inputEType = quantType.getStorageType(); + + auto accType = op.getAccType(); + if (inputEType.isInteger(8) && !accType.isInteger(32)) + return op.emitOpError("accumulator type for i8 tensor is not i32"); + + if (inputEType.isInteger(16) && !accType.isInteger(48)) + return op.emitOpError("accumulator type for i16 tensor is not i48"); + + if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3()) && + !accType.isF16()) + return op.emitOpError("accumulator type for f8 tensor is not f16"); + + if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) + return op.emitOpError("accumulator type for f16 tensor is not f16/f32"); + + if (inputEType.isBF16() && !accType.isF32()) + return op.emitOpError("accumulator type for bf16 tensor is not f32"); + + if (inputEType.isF32() && !accType.isF32()) + return op.emitOpError("accumulator type for f32 tensor is not f32"); + + return success(); +} + LogicalResult tosa::ArgMaxOp::verify() { // Ensure output is of 32-bit integer const auto resultETy = llvm::cast(getType()).getElementType(); @@ -368,12 +411,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, - DenseI64ArrayAttr dilation) { + DenseI64ArrayAttr dilation, + TypeAttr accType) { result.addOperands({input, weight, bias}); result.addAttribute("pad", pad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); + result.addAttribute("acc_type", accType); auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { @@ -390,11 +435,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, static void buildTransConvOpWithQuantInfo( OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, - DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) { + DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) { result.addOperands({input, weight, bias}); result.addAttribute("out_pad", outpad); result.addAttribute("stride", stride); result.addAttribute("out_shape", outputShape); + result.addAttribute("acc_type", accType); auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); if (quantAttr) { @@ -1595,7 +1641,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( return success(); } -LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } +LogicalResult Conv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, @@ -1667,7 +1717,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( return success(); } -LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } +LogicalResult Conv3DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, @@ -1762,7 +1816,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( return success(); } -LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } +LogicalResult DepthwiseConv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, @@ -1828,6 +1886,12 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( return success(); } +LogicalResult TransposeConv2DOp::verify() { + if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed()) + return failure(); + return success(); +} + LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, IfOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 0779cdb9667a1..275c2f80f7f49 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo()); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); } else { conv2d = rewriter.create( loc, resultTy, input, reverse2, bias, rewriter.getDenseI64ArrayAttr(convPad), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1})); + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()); } rewriter.replaceOp(op, conv2d); @@ -238,7 +240,7 @@ class TransposeConvStridedConverter /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - *op.getQuantizationInfo()) + /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) .getResult(); } else { conv2d = CreateOpAndInferShape( @@ -246,7 +248,8 @@ class TransposeConvStridedConverter weight, zeroBias, /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1})) + /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccTypeAttr()) .getResult(); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index bfdc72ee07e97..453a8610e7169 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -510,7 +510,7 @@ func.func @avg_pool_dyn(%arg0: tensor) -> (tensor) func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () { // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32> // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) { - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32> return } @@ -531,7 +531,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> - %0 = tosa.conv2d %input, %weights, %bias {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> return } @@ -552,7 +552,7 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27 // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32> // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> return } @@ -571,7 +571,7 @@ func.func @conv2d_dyn(%input: tensor, %weights: tensor<28x3x3x27 // CHECK: linalg.yield %[[IN]] : f32 // CHECK: } -> tensor // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor) -> tensor - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor return } @@ -627,7 +627,7 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x // CHECK: } -> tensor<1x?x?x28xf32> // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32> return } @@ -650,7 +650,7 @@ func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3 // linalg.yield %[[ADD]] : f32 // } -> tensor - %0 = tosa.conv2d %input, %weights, %bias {dilation = array, pad = array, stride = array} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor return } @@ -662,7 +662,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C0]] // CHECK: linalg.conv_2d_nhwc_fhwc - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> return } @@ -674,7 +674,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] // CHECK: linalg.conv_2d_nhwc_fhwc_q - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> return } @@ -696,7 +696,7 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array, stride = array, dilation = array } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array, stride = array, dilation = array } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> return } @@ -712,7 +712,7 @@ func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tenso // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array, stride = array, dilation = array } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32> return } @@ -736,7 +736,7 @@ func.func @depthwise_conv_dyn(%arg0 : tensor, %arg1 : tensor<3x1x3x // CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield %[[ADD]] : f32 // CHECK: } -> tensor - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array, stride = array, dilation = array } : (tensor, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array, stride = array, dilation = array } : (tensor, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor return } @@ -758,7 +758,7 @@ func.func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3 // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32 // CHECK: linalg.yield [[ADD]] : f32 // CHECK: } -> tensor<1x5x5x33xf32> - %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array, stride = array, dilation = array } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> + %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array, stride = array, dilation = array } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32> return } @@ -786,7 +786,7 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> return } @@ -810,7 +810,7 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> return } @@ -826,7 +826,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32> // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%{{.*}} : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32> // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[CONV]] {{\[}}[0], [1], [2], [3, 4]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, dilation = array, stride = array} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, dilation = array, stride = array} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> return } @@ -850,7 +850,7 @@ func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4 // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32> - %0 = tosa.conv3d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32> return } @@ -864,7 +864,7 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32> // CHECK: %[[BROADCAST:.+]] = linalg.generic // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} - %0 = tosa.conv3d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32> return } @@ -892,7 +892,7 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5 // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32> - %0 = tosa.conv3d %input, %weights, %bias {pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> + %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 67cd01f62f0bd..8fdc03e5b1aa1 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -162,7 +162,7 @@ func.func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32 // CHECK: tosa.conv2d %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32> %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32> - %0 = tosa.conv2d %arg0, %weight, %bias {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> return %0 : tensor<4x10x10x3xf32> } @@ -173,7 +173,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf // CHECK: tosa.conv2d %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32> %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32> - %0 = tosa.conv2d %arg0, %weight, %bias {pad = array, stride = array, dilation = array} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32> + %0 = tosa.conv2d %arg0, %weight, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32> return %0 : tensor<4x10x10x1xf32> } @@ -182,7 +182,7 @@ func.func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf // CHECK-LABEL: @depthwise_conv2d_stride_2 func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: tosa.depthwise_conv2d - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -191,7 +191,7 @@ func.func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor // CHECK-LABEL: @depthwise_conv2d_weight_2x2 func.func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: tosa.depthwise_conv2d - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index cca50b25d14d6..e6b33fd8eb481 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -25,7 +25,7 @@ func.func @test_const_non_tensor_attr() { func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{expect both input and weight to be float or not together, got 'f32' and 'i8'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -34,7 +34,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{expect a ranked tensor for input, got of type 'tensor<*xi8>' at index: 0}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -43,7 +43,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } @@ -52,13 +52,101 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} + : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + return %0 : tensor<1x27x27x16xi8> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} + : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16> + return %0 : tensor<1x27x27x16xi16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E5M2>, %arg1: tensor<16x3x3x4xf8E5M2>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xf8E5M2>, tensor<16x3x3x4xf8E5M2>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf8E4M3>, %arg1: tensor<16x3x3x4xf8E4M3>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f8 tensor is not f16}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xf8E4M3>, tensor<16x3x3x4xf8E4M3>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf16>, %arg1: tensor<16x3x3x4xf16>, %arg2: tensor<16xf16>) -> tensor<1x27x27x16xf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f16 tensor is not f16/f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xf16>, tensor<16x3x3x4xf16>, tensor<16xf16>) -> tensor<1x27x27x16xf16> + return %0 : tensor<1x27x27x16xf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xbf16>, %arg1: tensor<16x3x3x4xbf16>, %arg2: tensor<16xbf16>) -> tensor<1x27x27x16xbf16> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for bf16 tensor is not f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xbf16>, tensor<16x3x3x4xbf16>, tensor<16xbf16>) -> tensor<1x27x27x16xbf16> + return %0 : tensor<1x27x27x16xbf16> +} + +// ----- + +func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op accumulator type for f32 tensor is not f32}} + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> + return %0 : tensor<1x27x27x16xf32> +} + +// ----- + +func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> { + // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} + : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8> + return %0 : tensor<1x4x8x21x34xi8> +} + +// ----- + +func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8> + return %0 : tensor<1x4x4x8xi8> +} + +// ----- + +func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> { + // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array, out_shape = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8> + return %0 : tensor<1x32x32x16xi8> +} + +// ----- + func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}} @@ -416,7 +504,7 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> { func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> { // expected-error@+1 {{'tosa.conv2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x29x0x4xf32>'}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32> return %0 : tensor<1x27x27x16xf32> } diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 529a16ca48c7e..ba8ed8a1e5f50 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -226,7 +226,7 @@ func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32 func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -235,7 +235,7 @@ func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -244,7 +244,7 @@ func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -253,7 +253,7 @@ func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -262,7 +262,7 @@ func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -271,7 +271,7 @@ func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -280,7 +280,7 @@ func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -289,7 +289,7 @@ func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -298,7 +298,7 @@ func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_d * KD <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -307,7 +307,7 @@ func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -316,7 +316,7 @@ func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -325,7 +325,7 @@ func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -334,7 +334,7 @@ func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -343,7 +343,7 @@ func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2 func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -352,7 +352,7 @@ func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -361,7 +361,7 @@ func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor< func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -370,7 +370,7 @@ func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -379,7 +379,7 @@ func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<1 func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -388,7 +388,7 @@ func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -397,7 +397,7 @@ func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> return %0 : tensor<1x1x32x32x16xf32> } @@ -406,7 +406,7 @@ func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16 func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -415,7 +415,7 @@ func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -424,7 +424,7 @@ func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -433,7 +433,7 @@ func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -442,7 +442,7 @@ func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -451,7 +451,7 @@ func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -460,7 +460,7 @@ func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -469,7 +469,7 @@ func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> return %0 : tensor<1x32x32x64xf32> } @@ -603,7 +603,7 @@ func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x8193x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KH <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x8193x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -612,7 +612,7 @@ func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x8193x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KW <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x8193x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -621,7 +621,7 @@ func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -630,7 +630,7 @@ func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: te func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -639,7 +639,7 @@ func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -648,7 +648,7 @@ func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -657,7 +657,7 @@ func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } @@ -666,7 +666,7 @@ func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: t func.func @test_transpose_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} - %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 88fa94ae90db6..2a066eb9f2770 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -54,7 +54,7 @@ func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform // ----- // CHECK-LABEL: conv2d func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } @@ -63,7 +63,7 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> { %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4> %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32> - %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> + %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8> return %3 : tensor<1x1x1x3xi8> } @@ -71,28 +71,28 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> // ----- // CHECK-LABEL: conv3d func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } // ----- // CHECK-LABEL: conv3d_with_local_bound func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } // ----- // CHECK-LABEL: depthwise_conv2d func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } // ----- // CHECK-LABEL: depthwise_conv2d_with_local_bound func.func @test_depthwise_conv2d_with_local_bound(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> return %0 : tensor<1x4x4x8xf32> } @@ -162,14 +162,14 @@ func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<1 // ----- // CHECK-LABEL: transpose_conv2d func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } // ----- // CHECK-LABEL: transpose_conv2d_with_local_bound func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> return %0 : tensor<1x32x32x16xf32> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index 82a87dbc2494c..6437f12e3ff85 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -10,9 +10,9 @@ func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform:f32 // ----- // CHECK-LABEL: test_build_mult_and_shift -func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { +func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { // CHECK: tosa.conv2d - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = array, dilation = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> - return %0 : tensor<1x32x32x16x!quant.uniform> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array, dilation = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> + return %0 : tensor<1x32x32x16x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir index d876ccfb3b911..9e53b013cb833 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -14,7 +14,7 @@ func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} // CHECK-SAME: -> tensor<4x10x10x3xf32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> return %0 : tensor<4x10x10x3xf32> } @@ -33,7 +33,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> return %0 : tensor<4x10x10x3xi32> } @@ -50,7 +50,7 @@ func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor // CHECK: return %[[VAL_6]] : tensor // CHECK: } - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor return %0 : tensor } @@ -65,6 +65,6 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array} // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant} // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> return %0 : tensor<4x12x12x3xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index 2224bf3f57b25..f35a3f8761bb5 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -18,7 +18,7 @@ func.func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1 // CHECK: %[[VAR5:.*]] = tosa.add %[[VAR3]], %[[VAR4]] // CHECK-SAME: -> tensor<4x10x10x6xf32> // CHECK: return %[[VAR5]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> return %0 : tensor<4x10x10x6xf32> } @@ -38,7 +38,7 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } @@ -55,6 +55,6 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t // CHECK: %[[reOut:.+]] = tosa.reshape %[[mul]] {new_shape = array} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[add:.+]] = tosa.add %[[reOut]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x12x12x6xf32> return %0 : tensor<4x12x12x6xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 1f2bb3fb9a365..41c146d707d65 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 // CHECK-SAME: dilation = array, pad = array, stride = array - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> return %0 : tensor<2x18x19x5xf32> } @@ -17,8 +17,8 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) { // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} - // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> + // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> return %0 : tensor<2x18x19x5xi32> } @@ -32,6 +32,7 @@ func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: // CHECK-SAME: dilation = array, pad = array, // CHECK-SAME: quantization_info = #tosa.conv_quant, stride = array} %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, @@ -60,14 +61,14 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<30xf32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array, pad = array, stride = array} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array} // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array, start = array} // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> return %1 : tensor<2x?x?x5xf32> } @@ -93,14 +94,14 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] {new_shape = array} // CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[RESHAPE_OUT_2]] {size = array, start = array} // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } @@ -129,6 +130,7 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]] %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + acc_type = i32, out_pad = array, out_shape = array, stride = array, diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index d46de740800e9..88e9d5a4654de 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -674,7 +674,7 @@ func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { // CHECK-LABEL: @conv2d_static func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor return } @@ -683,7 +683,7 @@ func.func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf // CHECK-LABEL: @conv2d_dynamic_input func.func @conv2d_dynamic_input(%input: tensor, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor return } @@ -716,7 +716,7 @@ func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { // CHECK-LABEL: @conv2d_dynamic_weight func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x?x?x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor, tensor<5xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor, tensor<5xf32>) -> tensor return } @@ -725,7 +725,7 @@ func.func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor, %weights: tensor<5x3x6x3xf32>, %bias: tensor) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor) -> tensor return } @@ -746,7 +746,7 @@ func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { // CHECK-LABEL: @conv2d_padded func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x9x11x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor return } @@ -755,7 +755,7 @@ func.func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf // CHECK-LABEL: @conv2d_dilated func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x5xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor return } @@ -764,7 +764,7 @@ func.func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x // CHECK-LABEL: @conv2d_strided func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () { // CHECK: -> tensor<1x5x7x1xf32> - %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor return } @@ -773,7 +773,7 @@ func.func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x // CHECK-LABEL: @conv3d_static func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>) -> () { // CHECK: -> tensor<2x6x4x7x5xf32> - %0 = tosa.conv3d %input, %weights, %bias {dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv3d %input, %weights, %bias {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor return } @@ -782,7 +782,7 @@ func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x // CHECK-LABEL: @conv3d_dynamic_input func.func @conv3d_dynamic_input(%arg0: tensor, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor return } @@ -791,7 +791,7 @@ func.func @conv3d_dynamic_input(%arg0: tensor, %arg1: tensor<5x3x // CHECK-LABEL: @conv3d_dynamic_weight func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x?x?x?x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor, tensor<5xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor, tensor<5xf32>) -> tensor return } @@ -800,7 +800,7 @@ func.func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor) { // CHECK: -> tensor<2x6x4x7x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor) -> tensor return } @@ -809,7 +809,7 @@ func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x // CHECK-LABEL: @conv3d_padded func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x9x11x18x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor return } @@ -818,7 +818,7 @@ func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3x // CHECK-LABEL: @conv3d_dilated func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x6x4x12x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor return } @@ -827,7 +827,7 @@ func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2 // CHECK-LABEL: @conv3d_strided func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x5x7x4x1xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor return } @@ -836,7 +836,7 @@ func.func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1 // CHECK-LABEL: @depthwise_conv2d_static func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> return } @@ -845,7 +845,7 @@ func.func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6 // CHECK-LABEL: @depthwise_conv2d_dynamic_input func.func @depthwise_conv2d_dynamic_input(%arg0: tensor, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor return } @@ -854,7 +854,7 @@ func.func @depthwise_conv2d_dynamic_input(%arg0: tensor, %arg1: ten // CHECK-LABEL: @depthwise_conv2d_dynamic_weight func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x?x?x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor, tensor<15xf32>) -> tensor<2x?x?x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor, tensor<15xf32>) -> tensor<2x?x?x15xf32> return } @@ -863,7 +863,7 @@ func.func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: te // CHECK-LABEL: @depthwise_conv2d_dynamic_bias func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor) -> tensor<2x6x4x15xf32> return } @@ -872,7 +872,7 @@ func.func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tens // CHECK-LABEL: @depthwise_conv2d_padded func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x9x11x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32> return } @@ -881,7 +881,7 @@ func.func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6 // CHECK-LABEL: @depthwise_conv2d_dilated func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) { // CHECK: -> tensor<2x6x4x15xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32> return } @@ -890,7 +890,7 @@ func.func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor< // CHECK-LABEL: @depthwise_conv2d_strided func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x5x7x1xf32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32> return } @@ -899,7 +899,7 @@ func.func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor< // CHECK-LABEL: @transpose_conv2d_out_shape func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x8x9x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32> return } @@ -908,7 +908,7 @@ func.func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor< // CHECK-LABEL: @transpose_conv2d_static func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x18x19x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -917,7 +917,7 @@ func.func @transpose_conv2d_static(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5 // CHECK-LABEL: @transpose_conv2d_static_strided func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x33x45x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -926,7 +926,7 @@ func.func @transpose_conv2d_static_strided(%arg0: tensor<2x16x14x3xf32>, %arg1: // CHECK-LABEL: @transpose_conv2d_dynamic_input func.func @transpose_conv2d_dynamic_input(%arg0: tensor, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor return } @@ -935,7 +935,7 @@ func.func @transpose_conv2d_dynamic_input(%arg0: tensor, %arg1: ten // CHECK-LABEL: @transpose_conv2d_dynamic_weights func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x?x?x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x6x4x3xf32>, tensor, tensor<5xf32>) -> tensor<2x?x?x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x6x4x3xf32>, tensor, tensor<5xf32>) -> tensor<2x?x?x5xf32> return } @@ -944,7 +944,7 @@ func.func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: t // CHECK-LABEL: @transpose_conv2d_dynamic_bias func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor) { // CHECK: -> tensor<2x8x9x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor) -> tensor<2x8x9x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor) -> tensor<2x8x9x5xf32> return } @@ -953,14 +953,14 @@ func.func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tens // CHECK-LABEL: @transpose_conv2d_padded func.func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) { // CHECK: -> tensor<2x10x13x5xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32> return } // CHECK-LABEL: @transpose_conv2d_strided func.func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) { // CHECK: -> tensor<1x13x13x1xf32> - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array, out_shape = array, stride = array} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32> return } @@ -1368,7 +1368,7 @@ func.func @test_non_tosa_consumer_still_propagates(%arg0: tensor<1x1x8xf32>, %ar func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor { // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2 // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> - %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, stride = array} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor + %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor // CHECK: tosa.max_pool2d [[CONV]] // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32> %1 = tosa.max_pool2d %0 {kernel = array, pad = array, stride = array} : (tensor) -> tensor diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index ac904c3e01c93..83db1188861ab 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -149,7 +149,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op, op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), - tosaConv2DOp.getDilationAttr()); + tosaConv2DOp.getDilationAttr(), tosaConv2DOp.getAccTypeAttr()); // Create rescale to quantized type double inputScale = inputQType.getScale();