Skip to content

Commit 030d0a0

Browse files
committed
[Synth] Add an operation for declarative Cut rewrite pattern
1 parent 7aa41cc commit 030d0a0

5 files changed

Lines changed: 162 additions & 0 deletions

File tree

include/circt/Dialect/Synth/SynthOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H
1414
#define CIRCT_DIALECT_SYNTH_SYNTHOPS_H
1515

16+
#include "circt/Dialect/Synth/SynthAttributes.h"
1617
#include "circt/Dialect/Synth/SynthDialect.h"
1718
#include "circt/Dialect/Synth/SynthOpInterfaces.h"
1819
#include "circt/Support/LLVM.h"

include/circt/Dialect/Synth/SynthOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> {
305305
}];
306306
}
307307

308+
def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [
309+
IsolatedFromAbove,
310+
SingleBlockImplicitTerminator<"YieldOp">
311+
]> {
312+
let summary = "Declarative cut rewrite pattern";
313+
314+
let arguments = (ins
315+
TypeAttrOf<FunctionType>:$function_type,
316+
MappingCostAttr:$cost
317+
);
318+
319+
let regions = (region SizedRegion<1>:$body);
320+
let hasVerifier = 1;
321+
let hasCustomAssemblyFormat = 1;
322+
}
323+
324+
def YieldOp : SynthOp<"yield",
325+
[Pure, Terminator]> {
326+
let summary = "Yield synth operations";
327+
328+
let arguments = (ins Variadic<AnyType>:$operands);
329+
let assemblyFormat = "$operands attr-dict `:` type($operands)";
330+
}
308331

309332

310333
#endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD

lib/Dialect/Synth/SynthOps.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
#include "mlir/IR/OpDefinition.h"
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/Value.h"
22+
#include "mlir/Interfaces/CallInterfaces.h"
23+
#include "mlir/Interfaces/FunctionImplementation.h"
2224
#include "llvm/ADT/APInt.h"
25+
#include "llvm/ADT/STLExtras.h"
2326
#include "llvm/ADT/SmallVector.h"
2427
#include "llvm/Support/Casting.h"
2528
#include "llvm/Support/LogicalResult.h"
@@ -626,3 +629,86 @@ void GambleOp::emitCNFWithoutInversion(
626629
// out = allSet | ~orSet
627630
circt::addOrClauses(outVar, {allSet, -orSet}, addClause);
628631
}
632+
633+
//===----------------------------------------------------------------------===//
634+
// CutRewritePatternOp
635+
//===----------------------------------------------------------------------===//
636+
637+
ParseResult CutRewritePatternOp::parse(OpAsmParser &parser,
638+
OperationState &result) {
639+
640+
SmallVector<OpAsmParser::Argument> entryArgs;
641+
SmallVector<Type> resultTypes;
642+
SmallVector<DictionaryAttr> resultAttrs;
643+
bool isVariadic = false;
644+
645+
if (function_interface_impl::parseFunctionSignatureWithArguments(
646+
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
647+
resultAttrs))
648+
return failure();
649+
650+
auto inputTypes = llvm::map_to_vector(
651+
entryArgs, [](auto &arg) -> Type { return arg.type; });
652+
auto functionType =
653+
parser.getBuilder().getFunctionType(inputTypes, resultTypes);
654+
655+
result.addAttribute(getFunctionTypeAttrName(result.name),
656+
TypeAttr::get(functionType));
657+
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
658+
return failure();
659+
660+
return parser.parseRegion(*result.addRegion(), entryArgs,
661+
/*enableNameShadowing=*/false);
662+
}
663+
664+
void CutRewritePatternOp::print(OpAsmPrinter &p) {
665+
auto functionType = getFunctionType();
666+
call_interface_impl::printFunctionSignature(
667+
p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false,
668+
functionType.getResults(), /*resultAttrs=*/{}, &getBody(),
669+
/*printEmptyResult=*/false);
670+
671+
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
672+
{getFunctionTypeAttrName()});
673+
674+
p << ' ';
675+
p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
676+
/*printBlockTerminators=*/true);
677+
}
678+
679+
LogicalResult CutRewritePatternOp::verify() {
680+
auto functionType = getFunctionType();
681+
682+
if (functionType.getNumResults() != 1)
683+
return emitError() << "requires exactly one result";
684+
685+
for (auto type : functionType.getInputs())
686+
if (!type.isInteger(1))
687+
return emitError() << "argument type must be i1, but got " << type;
688+
689+
for (auto type : functionType.getResults())
690+
if (!type.isInteger(1))
691+
return emitError() << "result type must be i1, but got " << type;
692+
693+
// Check outputs.
694+
auto *terminator = this->getBody().front().getTerminator();
695+
if (terminator->getOperands().size() != functionType.getNumResults())
696+
return emitError() << "result type doesn't match with the terminator";
697+
698+
for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(),
699+
functionType.getResults()))
700+
if (rhs != lhs)
701+
return emitError() << rhs << " is expected but got " << lhs;
702+
703+
auto blockArgs = this->getBody().front().getArguments();
704+
if (blockArgs.size() != functionType.getNumInputs())
705+
return emitError() << "operand type doesn't match with the block arg";
706+
707+
for (auto [blockArg, inputType] :
708+
llvm::zip(blockArgs, functionType.getInputs()))
709+
if (blockArg.getType() != inputType)
710+
return emitError() << inputType << " is expected but got "
711+
<< blockArg.getType();
712+
713+
return success();
714+
}

test/Dialect/Synth/errors.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,41 @@ hw.module @test(out result : i1) {
55
%0 = synth.choice : i1
66
hw.output %0 : i1
77
}
8+
9+
// -----
10+
11+
// expected-error @below {{argument type must be i1, but got 'i2'}}
12+
synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
13+
%0 = comb.extract %a from 0 : (i2) -> i1
14+
synth.yield %0 : i1
15+
}
16+
17+
// -----
18+
19+
// expected-error @below {{result type must be i1, but got 'i2'}}
20+
synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
21+
%0 = hw.constant 0 : i2
22+
synth.yield %0 : i2
23+
}
24+
25+
// -----
26+
27+
// expected-error @below {{requires exactly one result}}
28+
synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
29+
synth.yield %a, %a : i1, i1
30+
}
31+
32+
// -----
33+
34+
// expected-error @below {{result type doesn't match with the terminator}}
35+
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
36+
"synth.yield"() : () -> ()
37+
}
38+
39+
// -----
40+
41+
// expected-error @below {{'i1' is expected but got 'i2'}}
42+
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
43+
%0 = hw.constant 0 : i2
44+
synth.yield %0 : i2
45+
}

test/Dialect/Synth/round-trip.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) {
5454
hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) {
5555
%0 = synth.gamble %x, not %y, %z : i1
5656
}
57+
58+
// CHECK-LABEL: synth.cut_rewrite_pattern
59+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1
60+
synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
61+
%0 = synth.aig.and_inv %a, not %b, %c : i1
62+
synth.yield %0 : i1
63+
}
64+
65+
// CHECK-LABEL: synth.cut_rewrite_pattern
66+
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost<area =
67+
synth.cut_rewrite_pattern (%a: i1, %b: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
68+
%0 = synth.aig.and_inv %a, %b : i1
69+
synth.yield %0 : i1
70+
}

0 commit comments

Comments
 (0)