Skip to content

Commit

Permalink
[mlir][tosa] Fix tosa.cast semantics to perform rounding/clipping
Browse files Browse the repository at this point in the history
Rounding to integers requires rounding (for floating points) and clipping
to the min/max values of the destination range. Added this behavior and
updated tests appropriately.

Reviewed By: sjarus, silvas

Differential Revision: https://reviews.llvm.org/D102375
  • Loading branch information
rsuderman committed May 13, 2021
1 parent f690715 commit 3f8aafd
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
51 changes: 45 additions & 6 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Expand Up @@ -491,9 +491,34 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
args.front(), zero);
}

if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
mlir::None);
if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
auto zero =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto half =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.5f));

auto intMin = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));

auto intMax = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()));

auto added = rewriter.create<AddFOp>(loc, args[0], half);
auto subbed = rewriter.create<SubFOp>(loc, args[0], half);
auto negative =
rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT, args[0], zero);
auto rounded =
rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);

auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
CmpFPredicate::OLT, rewriter);

return rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
}

// Casting to boolean, integers need to only be checked as not-equal to
// zero.
Expand All @@ -508,9 +533,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
mlir::None);

if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
mlir::None);
if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
auto intMin = rewriter.create<ConstantIntOp>(
loc,
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue(),
srcTy.getIntOrFloatBitWidth());

auto intMax = rewriter.create<ConstantIntOp>(
loc,
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue(),
srcTy.getIntOrFloatBitWidth());

auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
CmpIPredicate::slt, rewriter);
return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
}
}

(void)rewriter.notifyMatchFailure(
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Expand Up @@ -213,6 +213,18 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: constant 0.000000e+00
// CHECK: constant 5.000000e-01
// CHECK: constant -2.14748365E+9
// CHECK: constant 2.14748365E+9
// CHECK: addf
// CHECK: subf
// CHECK: cmpf olt
// CHECK: select
// CHECK: cmpf olt
// CHECK: select
// CHECK: cmpf olt
// CHECK: select
// CHECK: fptosi
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>

Expand Down Expand Up @@ -358,6 +370,12 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
%18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: constant -32768
// CHECK: constant 32767
// CHECK: cmpi slt
// CHECK: select
// CHECK: cmpi slt
// CHECK: select
// CHECK: trunci
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>

Expand Down

0 comments on commit 3f8aafd

Please sign in to comment.