|
19 | 19 | #include "mlir/IR/OpDefinition.h" |
20 | 20 | #include "mlir/IR/PatternMatch.h" |
21 | 21 | #include "mlir/IR/Value.h" |
| 22 | +#include "mlir/Interfaces/CallInterfaces.h" |
| 23 | +#include "mlir/Interfaces/FunctionImplementation.h" |
22 | 24 | #include "llvm/ADT/APInt.h" |
| 25 | +#include "llvm/ADT/STLExtras.h" |
23 | 26 | #include "llvm/ADT/SmallVector.h" |
24 | 27 | #include "llvm/Support/Casting.h" |
25 | 28 | #include "llvm/Support/LogicalResult.h" |
@@ -626,3 +629,86 @@ void GambleOp::emitCNFWithoutInversion( |
626 | 629 | // out = allSet | ~orSet |
627 | 630 | circt::addOrClauses(outVar, {allSet, -orSet}, addClause); |
628 | 631 | } |
| 632 | + |
| 633 | +//===----------------------------------------------------------------------===// |
| 634 | +// CutRewritePatternOp |
| 635 | +//===----------------------------------------------------------------------===// |
| 636 | + |
| 637 | +ParseResult CutRewritePatternOp::parse(OpAsmParser &parser, |
| 638 | + OperationState &result) { |
| 639 | + |
| 640 | + SmallVector<OpAsmParser::Argument> entryArgs; |
| 641 | + SmallVector<Type> resultTypes; |
| 642 | + SmallVector<DictionaryAttr> resultAttrs; |
| 643 | + bool isVariadic = false; |
| 644 | + |
| 645 | + if (function_interface_impl::parseFunctionSignatureWithArguments( |
| 646 | + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, |
| 647 | + resultAttrs)) |
| 648 | + return failure(); |
| 649 | + |
| 650 | + auto inputTypes = llvm::map_to_vector( |
| 651 | + entryArgs, [](auto &arg) -> Type { return arg.type; }); |
| 652 | + auto functionType = |
| 653 | + parser.getBuilder().getFunctionType(inputTypes, resultTypes); |
| 654 | + |
| 655 | + result.addAttribute(getFunctionTypeAttrName(result.name), |
| 656 | + TypeAttr::get(functionType)); |
| 657 | + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
| 658 | + return failure(); |
| 659 | + |
| 660 | + return parser.parseRegion(*result.addRegion(), entryArgs, |
| 661 | + /*enableNameShadowing=*/false); |
| 662 | +} |
| 663 | + |
| 664 | +void CutRewritePatternOp::print(OpAsmPrinter &p) { |
| 665 | + auto functionType = getFunctionType(); |
| 666 | + call_interface_impl::printFunctionSignature( |
| 667 | + p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false, |
| 668 | + functionType.getResults(), /*resultAttrs=*/{}, &getBody(), |
| 669 | + /*printEmptyResult=*/false); |
| 670 | + |
| 671 | + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
| 672 | + {getFunctionTypeAttrName()}); |
| 673 | + |
| 674 | + p << ' '; |
| 675 | + p.printRegion(getBody(), /*printEntryBlockArgs=*/false, |
| 676 | + /*printBlockTerminators=*/true); |
| 677 | +} |
| 678 | + |
| 679 | +LogicalResult CutRewritePatternOp::verify() { |
| 680 | + auto functionType = getFunctionType(); |
| 681 | + |
| 682 | + if (functionType.getNumResults() != 1) |
| 683 | + return emitError() << "requires exactly one result"; |
| 684 | + |
| 685 | + for (auto type : functionType.getInputs()) |
| 686 | + if (!type.isInteger(1)) |
| 687 | + return emitError() << "argument type must be i1, but got " << type; |
| 688 | + |
| 689 | + for (auto type : functionType.getResults()) |
| 690 | + if (!type.isInteger(1)) |
| 691 | + return emitError() << "result type must be i1, but got " << type; |
| 692 | + |
| 693 | + // Check outputs. |
| 694 | + auto *terminator = this->getBody().front().getTerminator(); |
| 695 | + if (terminator->getOperands().size() != functionType.getNumResults()) |
| 696 | + return emitError() << "result type doesn't match with the terminator"; |
| 697 | + |
| 698 | + for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(), |
| 699 | + functionType.getResults())) |
| 700 | + if (rhs != lhs) |
| 701 | + return emitError() << rhs << " is expected but got " << lhs; |
| 702 | + |
| 703 | + auto blockArgs = this->getBody().front().getArguments(); |
| 704 | + if (blockArgs.size() != functionType.getNumInputs()) |
| 705 | + return emitError() << "operand type doesn't match with the block arg"; |
| 706 | + |
| 707 | + for (auto [blockArg, inputType] : |
| 708 | + llvm::zip(blockArgs, functionType.getInputs())) |
| 709 | + if (blockArg.getType() != inputType) |
| 710 | + return emitError() << inputType << " is expected but got " |
| 711 | + << blockArg.getType(); |
| 712 | + |
| 713 | + return success(); |
| 714 | +} |
0 commit comments