Skip to content

Commit

Permalink
[mlir][tosa] Fix tosa.mul to use tosa.apply_scale
Browse files Browse the repository at this point in the history
Multiply-shift requires wider compute types or CPU-specific code to avoid
premature truncation; tosa.apply_scale fixes this issue.

Also, TOSA's mul op supports different input / output types. Added a path that
sign-extends input values to 32-bit integers (i32) before multiplying.

Differential Revision: https://reviews.llvm.org/D99011
  • Loading branch information
rsuderman committed Mar 22, 2021
1 parent 5727df2 commit d7c44a5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 24 deletions.
39 changes: 33 additions & 6 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,39 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}

if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
auto mul =
rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
auto constant =
rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
constant);
Value a = args[0];
Value b = args[1];
auto shift =
op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
if (shift > 0) {
auto shiftConst =
rewriter.create<ConstantIntOp>(loc, shift, /*bitwidth=*/8);
if (!a.getType().isInteger(32))
a = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), a);

if (!b.getType().isInteger(32))
b = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), b);

auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter.getBoolAttr(false));

if (elementTy.isInteger(32))
return result;

return rewriter.create<TruncateIOp>(loc, elementTy, result);
}

int aWidth = a.getType().getIntOrFloatBitWidth();
int bWidth = b.getType().getIntOrFloatBitWidth();
int cWidth = resultTypes[0].getIntOrFloatBitWidth();

if (aWidth < cWidth)
a = rewriter.create<SignExtendIOp>(loc, resultTypes[0], a);
if (bWidth < cWidth)
b = rewriter.create<SignExtendIOp>(loc, resultTypes[0], b);

return rewriter.create<mlir::MulIOp>(loc, resultTypes, a, b);
}

// tosa::NegateOp
Expand Down
54 changes: 36 additions & 18 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ func @test_simple_f16(%arg0: tensor<1xf16>) -> () {

// -----

// Exercises an integer tosa.mul whose i16 operands are narrower than its i32
// result and whose shift attribute is 0. The generated linalg.generic body is
// expected to contain two sign extensions followed by an integer multiply.
// CHECK-LABEL: @test_simple_i16
func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
// CHECK: linalg.generic
// CHECK: sext
// CHECK: sext
// CHECK: muli
%0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>

return
}

// -----

// CHECK-LABEL: @test_simple_i32
func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: linalg.generic
Expand All @@ -228,82 +241,87 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: muli
%2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: constant 2
// CHECK: apply_scale
%3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: muli
%3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
%4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: and
%4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: or
%5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: xor
%6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: shift_left
%7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: shift_right_unsigned
%8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: cmpi
%9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>

// CHECK: linalg.generic
// CHECK: cmpi
%10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
%11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>

// CHECK: linalg.generic
// CHECK: select
%11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
%12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
%13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
%14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
%14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
%15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
%15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
%16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: trunci
%16 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
%17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>

// CHECK: linalg.generic
// CHECK: yield
%17 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
%18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>

// CHECK: linalg.generic
// CHECK: sexti
%18 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>

// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
%19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
%20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>

// CHECK: linalg.generic
// CHECK: sitofp
%20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
%21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>

return
}
Expand Down

0 comments on commit d7c44a5

Please sign in to comment.