Skip to content

Commit

Permalink
[mlir][tosa] Use arith::maxf/arith::minf in lowering from tosa
Browse files Browse the repository at this point in the history
now that `arith` dialect has maxf/minf use it instead of cmp/select.
Also refactor clamp helpers to make them simpler.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D131426
  • Loading branch information
ThomasRaoux committed Aug 9, 2022
1 parent 474145c commit 2eb50ce
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 81 deletions.
20 changes: 9 additions & 11 deletions mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@ SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);
// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);

// Generic clamp builder: emits two compare/select pairs that bound `arg` by
// `min` and `max`. `T` is the comparison op to create (the callers shown use
// arith::CmpIOp or arith::CmpFOp) and `pred` is the "less than" predicate
// handed to both comparisons. Returns the final arith::SelectOp.
template <typename T, typename P>
arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
                            arith::ConstantOp max, P pred,
                            OpBuilder &rewriter) {
  // Lower bound: take `min` when pred(arg, min) holds, otherwise keep `arg`.
  auto belowMin = rewriter.create<T>(loc, pred, arg, min);
  auto lowerBounded = rewriter.create<arith::SelectOp>(loc, belowMin, min, arg);

  // Upper bound: the comparison is against the original `arg`, and `max` wins
  // over the lower-bounded value when pred(max, arg) holds.
  auto aboveMax = rewriter.create<T>(loc, pred, max, arg);
  return rewriter.create<arith::SelectOp>(loc, aboveMax, max, lowerBounded);
}
// Takes the parameters for a clamp and turns it into a series of ops for float
// inputs.
//
// The ops (arith.minf / arith.maxf) are created through `rewriter` at location
// `loc`, using `arg` as the value to clamp and the constants `min` and `max`
// as the bounds. Returns the value produced by the last op created.
Value clampFloatHelper(Location loc, Value arg, arith::ConstantOp min,
arith::ConstantOp max, OpBuilder &rewriter);

// Takes the parameters for a clamp and turns it into a series of ops for
// integer inputs.
//
// The ops (signed `arith.cmpi slt` comparisons feeding `arith.select`) are
// created through `rewriter` at location `loc`, using `arg` as the value to
// clamp and the constants `min` and `max` as the bounds. Returns the value
// produced by the final select.
Value clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
arith::ConstantOp max, OpBuilder &rewriter);

// Returns the values in an attribute as an array of values.
template <typename T>
Expand Down
60 changes: 19 additions & 41 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
intermediateType);
auto clamp = clampHelper<arith::CmpIOp>(
loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
auto clamp = clampIntHelper(loc, sub, min, max, rewriter);

// Truncate to the final value.
return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
Expand Down Expand Up @@ -335,9 +334,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,

// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
auto predicate = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OGT, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}

if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
Expand All @@ -348,9 +345,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,

// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
auto predicate = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}

if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
Expand Down Expand Up @@ -380,8 +375,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
arith::CmpFPredicate::OLT, rewriter);
return clampFloatHelper(loc, args[0], min, max, rewriter);
}

if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
Expand Down Expand Up @@ -409,8 +403,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
loc, min, intTy.getIntOrFloatBitWidth());
auto maxVal = rewriter.create<arith::ConstantIntOp>(
loc, max, intTy.getIntOrFloatBitWidth());
return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
arith::CmpIPredicate::slt, rewriter);
return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
}

// tosa::ReluNOp
Expand All @@ -423,17 +416,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
APFloat::rmNearestTiesToEven, &losesInfo);
auto n = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
arith::CmpFPredicate::OLT, rewriter);
return clampFloatHelper(loc, args[0], zero, n, rewriter);
}

if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
auto zero =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
rewriter);
return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
arith::CmpIPredicate::slt, rewriter);
return clampIntHelper(loc, args[0], zero, n, rewriter);
}

// tosa::SigmoidOp
Expand Down Expand Up @@ -521,8 +512,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
auto rounded =
rewriter.create<arith::SelectOp>(loc, negative, subbed, added);

auto clamped = clampHelper<arith::CmpFOp>(
loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);

return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
}
Expand Down Expand Up @@ -553,8 +543,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
.getSExtValue(),
srcTy.getIntOrFloatBitWidth());

auto clamped = clampHelper<arith::CmpIOp>(
loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
auto clamped = clampIntHelper(loc, args[0], intMin, intMax, rewriter);
return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
}
}
Expand Down Expand Up @@ -751,9 +740,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}

if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
auto predicate = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}

if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
Expand All @@ -763,9 +750,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
}

if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
auto predicate = rewriter.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OGT, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}

if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
Expand Down Expand Up @@ -1314,9 +1299,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
loc, nestedBuilder.getI32IntegerAttr(intMax));

value = clampHelper<arith::CmpIOp>(
nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
nestedBuilder);
value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
nestedBuilder);

if (outIntType.getWidth() < 32) {
value = nestedBuilder.create<arith::TruncIOp>(
Expand Down Expand Up @@ -1497,10 +1481,8 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {

// Clamp the to be within the bounds of the input image.

iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
arith::CmpIPredicate::slt, rewriter);
ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
arith::CmpIPredicate::slt, rewriter);
iy = clampIntHelper(loc, iy, hwMin, hMax, rewriter);
ix = clampIntHelper(loc, ix, hwMin, wMax, rewriter);

// Read the value from the input array.
iy =
Expand All @@ -1525,15 +1507,11 @@ class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);

y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
arith::CmpIPredicate::slt, rewriter);
y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
arith::CmpIPredicate::slt, rewriter);
y0 = clampIntHelper(loc, y0, hwMin, hMax, rewriter);
y1 = clampIntHelper(loc, y1, hwMin, hMax, rewriter);

x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
arith::CmpIPredicate::slt, rewriter);
x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
arith::CmpIPredicate::slt, rewriter);
x0 = clampIntHelper(loc, x0, hwMin, wMax, rewriter);
x1 = clampIntHelper(loc, x1, hwMin, wMax, rewriter);

y0 =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
auto max = rewriter.create<arith::ConstantIntOp>(
loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
accETy);
auto clamp = clampHelper<arith::CmpIOp>(
loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);

poolVal = clamp;
// Convert type.
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,21 @@ mlir::tosa::condenseValues(const SmallVector<Value> &values) {
condensedValues.push_back(value);
return condensedValues;
}

// Clamps the floating-point value `arg` to the closed range [min, max].
//
// Emits `arith.minf(arg, max)` to cap the value from above, then
// `arith.maxf(<capped>, min)` to raise it to the lower bound. Both ops are
// created through `rewriter` at `loc`; the result of the maxf is returned.
Value mlir::tosa::clampFloatHelper(Location loc, Value arg,
                                   arith::ConstantOp min, arith::ConstantOp max,
                                   OpBuilder &rewriter) {
  // The upper bound must be paired with minf and the lower bound with maxf.
  // Pairing them the other way round (minf with `min`, maxf with `max`)
  // produces `max` for every non-NaN input whenever min <= max.
  Value cappedAbove = rewriter.create<arith::MinFOp>(loc, arg, max);
  return rewriter.create<arith::MaxFOp>(loc, cappedAbove, min);
}

// Clamps the integer value `arg` between `min` and `max` using signed
// comparisons. Emits, in order: cmpi slt(arg, min), select, cmpi slt(max, arg),
// select, all created through `rewriter` at `loc`. Returns the final select.
Value mlir::tosa::clampIntHelper(Location loc, Value arg, arith::ConstantOp min,
                                 arith::ConstantOp max, OpBuilder &rewriter) {
  // Lower bound: substitute `min` when arg < min.
  Value belowMin =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, arg, min);
  Value lowerBounded = rewriter.create<arith::SelectOp>(loc, belowMin, min, arg);

  // Upper bound: compared against the original `arg`; `max` takes precedence
  // over the lower-bounded value when max < arg.
  Value aboveMax =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, max, arg);
  return rewriter.create<arith::SelectOp>(loc, aboveMax, max, lowerBounded);
}
48 changes: 21 additions & 27 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,11 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%13 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.cmpf
// CHECK: select
// CHECK: arith.maxf
%14 = "tosa.maximum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.cmpf
// CHECK: select
// CHECK: arith.minf
%15 = "tosa.minimum"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
Expand All @@ -216,13 +214,13 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
%17 = "tosa.floor"(%0) : (tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.cmpf
// CHECK: select
// CHECK: arith.minf
// CHECK: arith.maxf
%18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
// CHECK: arith.cmpf
// CHECK: select
// CHECK: arith.minf
// CHECK: arith.maxf
%19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32>

// CHECK: linalg.generic
Expand All @@ -241,10 +239,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: arith.subf
// CHECK: arith.cmpf olt
// CHECK: select
// CHECK: arith.cmpf olt
// CHECK: select
// CHECK: arith.cmpf olt
// CHECK: select
// CHECK: arith.minf
// CHECK: arith.maxf
// CHECK: arith.fptosi
%21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>

Expand Down Expand Up @@ -451,20 +447,22 @@ func.func @test_simple_ui8(%arg0: tensor<1xi8>) -> () {
// CHECK-LABEL: @test_i8
func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C127:.+]] = arith.constant -127
// CHECK-DAG: %[[C126:.+]] = arith.constant 126
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C127]]
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C127]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C127]]
// CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %arg1
// CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C126]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C126]], %[[SEL1]]
%0 = "tosa.clamp"(%arg0) {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.+]]: i8,
// CHECK-DAG: %[[C128:.+]] = arith.constant -128
// CHECK-DAG: %[[C127:.+]] = arith.constant 127
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %arg1, %[[C128]]
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpi slt, %[[ARG1]], %[[C128]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C128]]
// CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %arg1
// CHECK-DAG: %[[CMP2:.+]] = arith.cmpi slt, %[[C127]], %[[ARG1]]
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C127]], %[[SEL1]]
%1 = "tosa.clamp"(%arg0) {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8>

Expand All @@ -476,12 +474,11 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG1:.+]]: f16,
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-DAG: %[[C6:.+]] = arith.constant 6.0
// CHECK-DAG: %[[CMP1:.+]] = arith.cmpf olt, %arg1, %[[C0]]
// CHECK-DAG: %[[SEL1:.+]] = arith.select %[[CMP1]], %[[C0]]
// CHECK-DAG: %[[CMP2:.+]] = arith.cmpf olt, %[[C6]], %arg1
// CHECK: %[[SEL2:.+]] = arith.select %[[CMP2]], %[[C6]], %[[SEL1]]
// CHECK-DAG: %[[MIN:.+]] = arith.minf %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[MAX:.+]] = arith.maxf %[[MIN]], %[[C6]]
%0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16>

return
Expand Down Expand Up @@ -732,15 +729,13 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: arith.constant 3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK: arith.cmpf olt
// CHECK: select
// CHECK: arith.minf
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>

// CHECK: arith.constant -3.40282347E+38 : f32
// CHECK: linalg.fill
// CHECK: linalg.generic
// CHECK: arith.cmpf ogt
// CHECK: select
// CHECK: arith.maxf
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
return
}
Expand Down Expand Up @@ -803,9 +798,8 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CMIN]]{{.*}}outs(%[[INIT]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%[[FILL]] : tensor<?xf32>)
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
// CHECK: %[[CMP:.+]] = arith.cmpf ogt, %arg1, %arg2 : f32
// CHECK: %[[RES:.+]] = arith.select %[[CMP]], %arg1, %arg2 : f32
// CHECK: linalg.yield %[[RES]] : f32
// CHECK: %[[MAX:.+]] = arith.maxf %arg1, %arg2 : f32
// CHECK: linalg.yield %[[MAX]] : f32
// CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
%0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
Expand Down

0 comments on commit 2eb50ce

Please sign in to comment.