diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 6d2cbb5539e14..e3cba38871909 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern { auto inputType = llvm::dyn_cast(op.getInput().getType()); auto inputElementType = inputType.getElementType(); - if (!inputType.hasStaticShape()) { - return failure(); - } - if (isa(inputElementType)) { // Unlike integer types, floating point types can represent infinity. - auto minClamp = + const auto minClamp = llvm::cast(op.getMinValAttr()).getValue(); - auto maxClamp = + const auto maxClamp = llvm::cast(op.getMaxValAttr()).getValue(); - bool isMin = minClamp.isNegInfinity(); - bool isMax = maxClamp.isInfinity(); + const bool isMin = minClamp.isNegInfinity(); + const bool isMax = maxClamp.isInfinity(); if (isMin && isMax) { rewriter.replaceOp(op, input); @@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern { return failure(); } - if (inputElementType.isUnsignedInteger()) { - int64_t minClamp = - llvm::cast(op.getMinValAttr()).getUInt(); - int64_t maxClamp = - llvm::cast(op.getMaxValAttr()).getUInt(); + // i1 types are boolean in TOSA + const bool isBoolean = inputElementType.isInteger(1); + if (inputElementType.isUnsignedInteger() || isBoolean) { + const int64_t minClamp = llvm::cast(op.getMinValAttr()) + .getValue() + .getZExtValue(); + const int64_t maxClamp = llvm::cast(op.getMaxValAttr()) + .getValue() + .getZExtValue(); - int64_t intMin = - APInt::getMinValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); - int64_t intMax = - APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getZExtValue(); + const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth(); + const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue(); + const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue(); if (minClamp <= intMin && maxClamp >= intMax) { rewriter.replaceOp(op, input); @@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern { } if (llvm::isa(inputElementType)) { - int64_t minClamp = + const int64_t minClamp = llvm::cast(op.getMinValAttr()).getInt(); - int64_t maxClamp = + const int64_t maxClamp = llvm::cast(op.getMaxValAttr()).getInt(); - int64_t intMin = - APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); - int64_t intMax = - APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth()) - .getSExtValue(); + const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth(); + const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue(); + const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue(); if (minClamp <= intMin && maxClamp >= intMax) { rewriter.replaceOp(op, input); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 6b55442a82a0a..5150ee36e9e5e 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { // ----- +// CHECK-LABEL: @clamp_boolean_is_noop +func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- + +// CHECK-LABEL: @clamp_boolean_dynamic_is_noop +func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor) -> tensor { + // CHECK: return %arg0 + // CHECK-NOT: tosa.clamp + %0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @clamp_int8_is_noop func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> { // CHECK: return %arg0