Skip to content

Commit

Permalink
[mlir][tosa] Fix clamp to restrict only within valid bitwidth range
Browse files Browse the repository at this point in the history
Its possible for the clamp to have invalid min/max values on its range. To fix
this we validate the range of the min/max and clamp to a valid range.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108256
  • Loading branch information
rsuderman committed Aug 18, 2021
1 parent 58e4e71 commit 76c9712
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
32 changes: 26 additions & 6 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Expand Up @@ -428,12 +428,32 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}

if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
rewriter);
auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
return clampHelper<mlir::CmpIOp>(loc, args[0], min, max, CmpIPredicate::slt,
rewriter);
auto intTy = elementTy.cast<IntegerType>();
int32_t min = static_cast<int32_t>(
op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
int32_t max = static_cast<int32_t>(
op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());

if (intTy.isUnsignedInteger()) {
min = std::max<int32_t>(min, 0);
max = std::min<int32_t>(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
min = std::max<int32_t>(
min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
max = std::min<int32_t>(
max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
.getSExtValue());
}

auto minVal =
rewriter.create<ConstantIntOp>(loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal =
rewriter.create<ConstantIntOp>(loc, max, intTy.getIntOrFloatBitWidth());
return clampHelper<mlir::CmpIOp>(loc, args[0], minVal, maxVal,
CmpIPredicate::slt, rewriter);
}

// tosa::ReluNOp
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Expand Up @@ -404,6 +404,31 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {

// -----

// CHECK-LABEL: @test_i8
func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK-DAG: %[[C127:.+]] = constant -127
// CHECK-DAG: %[[C126:.+]] = constant 126
// CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C127]]
// CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C127]]
// CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C126]], %arg1
// CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C126]], %[[SEL1]]
%0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

// CHECK: linalg.generic
// CHECK-DAG: %[[C128:.+]] = constant -128
// CHECK-DAG: %[[C127:.+]] = constant 127
// CHECK-DAG: %[[CMP1:.+]] = cmpi slt, %arg1, %[[C128]]
// CHECK-DAG: %[[SEL1:.+]] = select %[[CMP1]], %[[C128]]
// CHECK-DAG: %[[CMP2:.+]] = cmpi slt, %[[C127]], %arg1
// CHECK: %[[SEL2:.+]] = select %[[CMP2]], %[[C127]], %[[SEL1]]
%1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

return
}

// -----

// CHECK-LABEL: @test_bool
func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK: linalg.generic
Expand Down

0 comments on commit 76c9712

Please sign in to comment.