Skip to content

Commit

Permalink
[mlir][tosa] Fix tosa.concat by inserting linalg.fill after linalg.init
Browse files Browse the repository at this point in the history
All linalg.init operations must be fed into a linalg operation before
subtensor. The inserted linalg.fill guarantees it executes correctly.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D101848
  • Loading branch information
rsuderman committed May 4, 2021
1 parent a018bd5 commit 1f7adf8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
7 changes: 6 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1591,9 +1591,14 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
}
sizes[axis] = resultDimSize;

Value result = rewriter.create<linalg::InitTensorOp>(
Value init = rewriter.create<linalg::InitTensorOp>(
loc, resultType.getShape(), resultType.getElementType());

Value zeroVal = rewriter.create<ConstantOp>(
loc, rewriter.getZeroAttr(resultType.getElementType()));
Value result =
rewriter.create<linalg::FillOp>(loc, init, zeroVal).getResult(0);

for (auto arg : args) {
sizes[axis] = rewriter.create<memref::DimOp>(loc, arg, axisValue);
result = rewriter.create<SubTensorInsertOp>(loc, arg, result, offsets,
Expand Down
8 changes: 6 additions & 2 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,10 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg1, [[AXIS]]
// CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM0]], [[ARG1_AXIS]]
// CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1]
// CHECK: [[CST:%.+]] = constant 0.0
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST]])
// CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[AXIS]]
// CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM0]]
// CHECK: [[ARG1_DIM0:%.+]] = memref.dim %arg1, [[AXIS]]
// CHECK: [[INSERT1:%.+]] = subtensor_insert %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
Expand All @@ -654,8 +656,10 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
// CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg0, [[AXIS]]
// CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM1]], [[ARG1_AXIS]]
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2]
// CHECK: [[CST:%.+]] = constant 0.0
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST]])
// CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
// CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
// CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM1]]
// CHECK: [[ARG1_DIM1:%.+]] = memref.dim %arg0, [[AXIS]]
// CHECK: [[INSERT1:%.+]] = subtensor_insert %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]]
Expand Down

0 comments on commit 1f7adf8

Please sign in to comment.