diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 6e00e5852a526..b66ea38d03daf 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -460,7 +460,7 @@ def Shape_AnyOp : Shape_Op<"any", [NoSideEffect]> { let assemblyFormat = "$inputs attr-dict"; } -def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> { +def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> { let summary = "Return a logical AND of all witnesses"; let description = [{ Used to simplify constraints as any single failing precondition is enough @@ -485,6 +485,8 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> { let results = (outs Shape_WitnessType:$result); let assemblyFormat = "$inputs attr-dict"; + + let hasFolder = 1; } def Shape_AssumingOp : Shape_Op<"assuming", diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 26928f272f2a8..a4a8b2de59fd9 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -145,6 +145,30 @@ static void print(OpAsmPrinter &p, AssumingOp op) { p.printOptionalAttrDict(op.getAttrs()); } +//===----------------------------------------------------------------------===// +// AssumingAllOp +//===----------------------------------------------------------------------===// +OpFoldResult AssumingAllOp::fold(ArrayRef operands) { + // Iterate in reverse to first handle all constant operands. They are + // guaranteed to be the tail of the inputs because this is commutative. + for (int idx = operands.size() - 1; idx >= 0; idx--) { + Attribute a = operands[idx]; + // Cannot fold if any inputs are not constant; + if (!a) + return nullptr; + + // We do not need to keep statically known values after handling them in + // this method. + getOperation()->eraseOperand(idx); + + // Always false if any input is statically known false + if (!a.cast().getValue()) + return a; + } + // If this is reached, all inputs were statically known passing. + return BoolAttr::get(true, getContext()); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 69c312e6dad78..646700f8d6bf4 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -canonicalize <%s | FileCheck %s --dump-input=fail +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize <%s | FileCheck %s --dump-input=fail // ----- // CHECK-LABEL: func @f @@ -212,3 +212,36 @@ func @not_const(%arg0: !shape.shape) -> !shape.size { %0 = shape.get_extent %arg0, 3 return %0 : !shape.size } + +// ----- +// assuming_all with known passing witnesses can be folded +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: shape.const_witness true + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.const_witness true + %1 = shape.const_witness true + %2 = shape.const_witness true + %3 = shape.assuming_all %0, %1, %2 + "consume.witness"(%3) : (!shape.witness) -> () + return +} + +// ----- +// assuming_all should not be removed if not all witnesses are statically passing. +// +// Additionally check that the attribute is moved to the end as this op is +// commutative. +// CHECK-LABEL: func @f +func @f() { + // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source" + // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]] + // CHECK-NEXT: consume.witness + // CHECK-NEXT: return + %0 = shape.const_witness true + %1 = "test.source"() : () -> !shape.witness + %2 = shape.assuming_all %0, %1 + "consume.witness"(%2) : (!shape.witness) -> () + return +}