Skip to content

Commit

Permalink
[mlir][tosa] Fix start index in slice canonicalization
Browse files Browse the repository at this point in the history
The updated start indices weren't being used.

Differential Revision: https://reviews.llvm.org/D156567
  • Loading branch information
jpienaar committed Jul 28, 2023
1 parent bff6a42 commit 9a807b8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Expand Up @@ -406,13 +406,12 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {

if (sliceStart[axis] >= 0 &&
(sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
replaceWithSlice =
rewriter
.create<tosa::SliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input,
rewriter.getDenseI64ArrayAttr(sliceOp.getStart()),
rewriter.getDenseI64ArrayAttr(sliceSize))
.getResult();
replaceWithSlice = rewriter
.create<tosa::SliceOp>(
sliceOp.getLoc(), sliceOp.getType(), input,
rewriter.getDenseI64ArrayAttr(sliceStart),
rewriter.getDenseI64ArrayAttr(sliceSize))
.getResult();
break;
}
sliceStart[axis] -= inputType.getDimSize(axis);
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Tosa/canonicalize.mlir
Expand Up @@ -542,7 +542,7 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 :
// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array<i64: 1, 6, 12>, start = array<i64: 0, 0, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 12>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array<i64: 1, 3, 12>, start = array<i64: 1, 3, 0>}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
%0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
Expand Down

0 comments on commit 9a807b8

Please sign in to comment.