diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 9d9a934cdfd5e..e9ad7869422c4 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -88,6 +88,8 @@ TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange); +Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType); + } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index a118ac9c4b111..c420a4c9596ff 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern { auto inputEType = llvm::cast(input.getType()).getElementType(); if (auto quantType = llvm::dyn_cast(inputEType)) { - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); } Attribute newMinValAttr, newMaxValAttr; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 65e0a59d39168..1c175f9ab0207 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -563,7 +563,7 @@ static std::optional idivCheck(const int64_t lhs, const int64_t rhs) { static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast(srcType)) - srcType = quantType.getStorageType(); + srcType = 
getStorageElementTypeFromQuantized(quantType); return srcType; } @@ -631,16 +631,16 @@ static LogicalResult verifyConvOp(T op) { bool resultIsFloat = llvm::isa(resultEType); if (auto quantType = llvm::dyn_cast(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast(weightEType)) - weightEType = quantType.getStorageType(); + weightEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast(biasEType)) - biasEType = quantType.getStorageType(); + biasEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { // for now, only enforce bias element type == result element type for @@ -709,7 +709,7 @@ LogicalResult tosa::ConstOp::verify() { if (auto result = llvm::dyn_cast( outputType.getElementType())) { - if (result.getStorageType() == attrType.getElementType()) + if (getStorageElementTypeFromQuantized(result) == attrType.getElementType()) return success(); } @@ -727,7 +727,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast(op.getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); auto accType = op.getAccType(); if (inputEType.isInteger(8) && !accType.isInteger(32)) @@ -752,7 +752,7 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast(op.getResult().getType()).getElementType(); if (auto quantType = llvm::dyn_cast(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); return success(); } @@ -1179,13 +1179,13 @@ LogicalResult tosa::ClampOp::verify() { llvm::cast(getInput().getType()).getElementType(); if (auto 
quantType = llvm::dyn_cast(inputETy)) { - inputETy = quantType.getStorageType(); + inputETy = getStorageElementTypeFromQuantized(quantType); } mlir::Type outputETy = llvm::cast(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast(outputETy)) { - outputETy = quantType.getStorageType(); + outputETy = getStorageElementTypeFromQuantized(quantType); } if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 02c86a090e6d4..c55b13dc98cc5 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -395,3 +395,16 @@ mlir::tosa::buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, maxAttr, quantBits, filterQuantDim, isSigned, narrowRange)); } + +Type mlir::tosa::getStorageElementTypeFromQuantized( + quant::QuantizedType quantType) { + auto quantEty = quantType.getStorageType(); + // The storage type does not carry signedness by itself. + // For an unsigned quantized type, build an unsigned integer of equal width. + if (!quantType.isSigned()) { + quantEty = IntegerType::get(quantEty.getContext(), + quantEty.getIntOrFloatBitWidth(), + IntegerType::Unsigned); + } + return quantEty; +} diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index fc5ea7710e2c4..84776c47b628d 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -360,6 +360,36 @@ func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tens return %1 : tensor<4xi8> } +// ----- + +// CHECK-LABEL: @clamp_twice_with_unsigned_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 230 : ui8, min_val = 10 : ui8} +func.func @clamp_twice_with_unsigned_quantized_is_single_clamp(%arg0:tensor>) -> (tensor>) { + %0 = tosa.clamp %arg0 {max_val = 240 : ui8, min_val = 10 : ui8} : (tensor>) -> tensor> + %1 = tosa.clamp %0 
{max_val = 230 : ui8, min_val = 5 : ui8} : (tensor>) -> tensor> + return %1 : tensor> +} + +// ----- + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_is_single_clamp +// CHECK: tosa.clamp %arg0 {max_val = 110 : i8, min_val = -5 : i8} +func.func @clamp_twice_with_signed_quantized_is_single_clamp(%arg0:tensor>) -> (tensor>) { + %0 = tosa.clamp %arg0 {max_val = 110 : i8, min_val = -10 : i8} : (tensor>) -> tensor> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = -5 : i8} : (tensor>) -> tensor> + return %1 : tensor> +} + +// CHECK-LABEL: @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp +// CHECK: %[[CLAMP_1:.*]] = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} +// CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 120 : i8, min_val = 60 : i8} +func.func @clamp_twice_with_signed_quantized_non_overlap_is_not_single_clamp(%arg0:tensor>) -> (tensor>) { + %0 = tosa.clamp %arg0 {max_val = 50 : i8, min_val = -10 : i8} : (tensor>) -> tensor> + %1 = tosa.clamp %0 {max_val = 120 : i8, min_val = 60 : i8} : (tensor>) -> tensor> + return %1 : tensor> +} + + // ----- // CHECK-LABEL: @concat_fold diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index a4591f7ffd393..652447bd6056e 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -279,6 +279,13 @@ func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform> } +// ----- +// CHECK-LABEL: clamp_quantized_unsigned +func.func @clamp_quantized_unsigned(%arg0:tensor>) -> (tensor>) { + %0 = tosa.clamp %arg0 {max_val = 255 : ui8, min_val = 0 : ui8} : (tensor>) -> tensor> + return %0 : tensor> +} + // ----- // CHECK-LABEL: sigmoid func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index f0ad4eb4fdb0b..88dffe7fdd2e8 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -1,13 +1,21 
@@ // RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s // ----- -// CHECK-LABEL: test_build_qtype -func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> { +// CHECK-LABEL: test_build_qtype_unsigned +func.func @test_build_qtype_unsigned(%arg0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, %arg1: tensor<1xui8>, %arg2: tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> { // CHECK: tosa.negate - %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, tensor<1xui8>, tensor<1xui8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> return %0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> } +// ----- +// CHECK-LABEL: test_build_qtype_signed +func.func @test_build_qtype_signed(%arg0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> { + // CHECK: tosa.negate + %0 = "tosa.negate"(%arg0, %arg1, %arg2) : (tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>, tensor<1xi8>, tensor<1xi8>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> + return %0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> +} + // ----- // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x34x36x16x!quant.uniform> { diff --git a/mlir/test/Dialect/Tosa/verifier.mlir 
b/mlir/test/Dialect/Tosa/verifier.mlir index 6cf76cdc7ad8e..ea64d468f151e 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -1222,3 +1222,11 @@ func.func @test_cast_to_block_scaled_data_scale_channel_mismatch(%arg0: tensor<4 %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size} : (tensor<4x32xf32>) -> (tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32xf4E2M1FN>, tensor<4x2xf8E8M0FNU> } + +// ----- + +func.func @test_clamp_quantized(%arg0:tensor>) -> (tensor>) { + // expected-error@+1 {{'tosa.clamp' op min/max attributes types are incompatible with input/output element types.}} + %0 = tosa.clamp %arg0 {max_val = 127 : i8, min_val = -128 : i8} : (tensor>) -> tensor> + return %0 : tensor> +}