Skip to content

Commit

Permalink
[mlir][tosa] Fold tosa.reshape with splat values
Browse files Browse the repository at this point in the history
Folding reshapes of splats is trivial and should be canonicalized
away.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132760
  • Loading branch information
rsuderman committed Aug 30, 2022
1 parent 221f785 commit 43e1fc5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
13 changes: 11 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,18 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();

if (!inputTy || !outputTy || inputTy != outputTy)
if (!inputTy || !outputTy)
return {};
return getInput1();

if (inputTy == outputTy)
return getInput1();

auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}

return {};
}

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,16 @@ func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {

// -----

func.func @reshape_splat() -> tensor<6x5x4xi32> {
// CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<6x5x4xi32>}
%splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
%reshape = "tosa.reshape"(%splat) { new_shape = [6, 5, 4] } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
// CHECK: return %[[SPLAT]]
return %reshape : tensor<6x5x4xi32>
}

// -----

// CHECK-LABEL: @slice_splat
func.func @slice_splat() -> tensor<1x1x1xi32> {
// CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
Expand Down

0 comments on commit 43e1fc5

Please sign in to comment.