Skip to content

Commit

Permalink
[mlir][tosa] Add constant folding for tosa.slice
Browse files Browse the repository at this point in the history
If the input to a tosa.slice operation is a splat we can just replace with
another splat. If the result is a single element, replacing with a splat
is universally useful.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132499
  • Loading branch information
rsuderman committed Aug 24, 2022
1 parent ecde303 commit 89d5551
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
23 changes: 21 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,30 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();

if (!inputTy || !outputTy || inputTy != outputTy)
if (!inputTy || !outputTy)
return {};
if (inputTy.hasStaticShape())

if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput();

if (!operands[0])
return {};

auto operand = operands[0].cast<ElementsAttr>();
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}

if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
llvm::SmallVector<uint64_t> indices;
for (auto val : getStart()) {
indices.push_back(val.cast<IntegerAttr>().getInt());
}
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}

return {};
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,25 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
// CHECK: return %[[THREE]]
return %add : tensor<10xf32>
}

// -----

// CHECK-LABEL: @slice_splat
func.func @slice_splat() -> tensor<1x1x1xi32> {
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
%splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
%slice = "tosa.slice"(%splat) { size = [1, 1, 1], start = [1, 2, 3] } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32>
// CHECK: return %[[SLICE]]
return %slice : tensor<1x1x1xi32>
}

// -----

// CHECK-LABEL: @slice_singleton
func.func @slice_singleton() -> tensor<1x1xi32> {
%splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<4> : tensor<1x1xi32>}
%slice = "tosa.slice"(%splat) { size = [1, 1], start = [1, 1] } : (tensor<3x3xi32>) -> tensor<1x1xi32>
// CHECK: return %[[SLICE]]
return %slice : tensor<1x1xi32>
}

0 comments on commit 89d5551

Please sign in to comment.