9
9
#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
10
10
#include " mlir/Dialect/EmitC/IR/EmitC.h"
11
11
#include " mlir/IR/IRMapping.h"
12
+ #include " mlir/IR/Location.h"
12
13
#include " mlir/IR/PatternMatch.h"
14
+ #include " llvm/ADT/STLExtras.h"
13
15
14
16
namespace mlir {
15
17
namespace emitc {
@@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
24
26
Location loc = op->getLoc ();
25
27
26
28
builder.setInsertionPointAfter (op);
27
- auto expressionOp = emitc::ExpressionOp::create (builder, loc, resultType);
29
+ auto expressionOp =
30
+ emitc::ExpressionOp::create (builder, loc, resultType, op->getOperands ());
28
31
29
32
// Replace all op's uses with the new expression's result.
30
33
result.replaceAllUsesWith (expressionOp.getResult ());
31
34
32
- // Create an op to yield op's value.
33
- Region ®ion = expressionOp.getRegion ();
34
- Block &block = region.emplaceBlock ();
35
+ Block &block = expressionOp.createBody ();
36
+ IRMapping mapper;
37
+ for (auto [operand, arg] :
38
+ llvm::zip (expressionOp.getOperands (), block.getArguments ()))
39
+ mapper.map (operand, arg);
35
40
builder.setInsertionPointToEnd (&block);
36
- auto yieldOp = emitc::YieldOp::create (builder, loc, result);
37
41
38
- // Move op into the new expression.
39
- op->moveBefore (yieldOp );
42
+ Operation *rootOp = builder. clone (*op, mapper);
43
+ op->erase ( );
40
44
45
+ // Create an op to yield op's value.
46
+ emitc::YieldOp::create (builder, loc, rootOp->getResults ()[0 ]);
41
47
return expressionOp;
42
48
}
43
49
@@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
53
59
using OpRewritePattern<ExpressionOp>::OpRewritePattern;
54
60
LogicalResult matchAndRewrite (ExpressionOp expressionOp,
55
61
PatternRewriter &rewriter) const override {
56
- bool anythingFolded = false ;
57
- for (Operation &op : llvm::make_early_inc_range (
58
- expressionOp.getBody ()->without_terminator ())) {
59
- // Don't fold expressions whose result value has its address taken.
60
- auto applyOp = dyn_cast<emitc::ApplyOp>(op);
61
- if (applyOp && applyOp.getApplicableOperator () == " &" )
62
- continue ;
63
-
64
- for (Value operand : op.getOperands ()) {
65
- auto usedExpression = operand.getDefiningOp <ExpressionOp>();
66
- if (!usedExpression)
67
- continue ;
68
-
69
- // Don't fold expressions with multiple users: assume any
70
- // re-materialization was done separately.
71
- if (!usedExpression.getResult ().hasOneUse ())
72
- continue ;
73
-
74
- // Don't fold expressions with side effects.
75
- if (usedExpression.hasSideEffects ())
76
- continue ;
77
-
78
- // Fold the used expression into this expression by cloning all
79
- // instructions in the used expression just before the operation using
80
- // its value.
81
- rewriter.setInsertionPoint (&op);
82
- IRMapping mapper;
83
- for (Operation &opToClone :
84
- usedExpression.getBody ()->without_terminator ()) {
85
- Operation *clone = rewriter.clone (opToClone, mapper);
86
- mapper.map (&opToClone, clone);
87
- }
88
-
89
- Operation *expressionRoot = usedExpression.getRootOp ();
90
- Operation *clonedExpressionRootOp = mapper.lookup (expressionRoot);
91
- assert (clonedExpressionRootOp &&
92
- " Expected cloned expression root to be in mapper" );
93
- assert (clonedExpressionRootOp->getNumResults () == 1 &&
94
- " Expected cloned root to have a single result" );
95
-
96
- rewriter.replaceOp (usedExpression, clonedExpressionRootOp);
97
- anythingFolded = true ;
98
- }
62
+ Block *expressionBody = expressionOp.getBody ();
63
+ ExpressionOp usedExpression;
64
+ SetVector<Value> foldedOperands;
65
+
66
+ auto takesItsOperandsAddress = [](Operation *user) {
67
+ auto applyOp = dyn_cast<emitc::ApplyOp>(user);
68
+ return applyOp && applyOp.getApplicableOperator () == " &" ;
69
+ };
70
+
71
+ // Select as expression to fold the first operand expression that
72
+ // - doesn't have its result value's address taken,
73
+ // - has a single user: assume any re-materialization was done separately,
74
+ // - has no side effects,
75
+ // and save all other operands to be used later as operands in the folded
76
+ // expression.
77
+ for (auto [operand, arg] : llvm::zip (expressionOp.getOperands (),
78
+ expressionBody->getArguments ())) {
79
+ ExpressionOp operandExpression = operand.getDefiningOp <ExpressionOp>();
80
+ if (usedExpression || !operandExpression ||
81
+ llvm::any_of (arg.getUsers (), takesItsOperandsAddress) ||
82
+ !operandExpression.getResult ().hasOneUse () ||
83
+ operandExpression.hasSideEffects ())
84
+ foldedOperands.insert (operand);
85
+ else
86
+ usedExpression = operandExpression;
99
87
}
100
- return anythingFolded ? success () : failure ();
88
+
89
+ // If no operand expression was selected, bail out.
90
+ if (!usedExpression)
91
+ return failure ();
92
+
93
+ // Collect additional operands from the folded expression.
94
+ for (Value operand : usedExpression.getOperands ())
95
+ foldedOperands.insert (operand);
96
+
97
+ // Create a new expression to hold the folding result.
98
+ rewriter.setInsertionPointAfter (expressionOp);
99
+ auto foldedExpression = emitc::ExpressionOp::create (
100
+ rewriter, expressionOp.getLoc (), expressionOp.getResult ().getType (),
101
+ foldedOperands.getArrayRef (), expressionOp.getDoNotInline ());
102
+ Block &foldedExpressionBody = foldedExpression.createBody ();
103
+
104
+ // Map each operand of the new expression to its matching block argument.
105
+ IRMapping mapper;
106
+ for (auto [operand, arg] : llvm::zip (foldedExpression.getOperands (),
107
+ foldedExpressionBody.getArguments ()))
108
+ mapper.map (operand, arg);
109
+
110
+ // Prepare to fold the used expression and the matched expression into the
111
+ // newly created folded expression.
112
+ auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold,
113
+ bool withTerminator) {
114
+ Block *expressionToFoldBody = expressionToFold.getBody ();
115
+ for (auto [operand, arg] :
116
+ llvm::zip (expressionToFold.getOperands (),
117
+ expressionToFoldBody->getArguments ())) {
118
+ mapper.map (arg, mapper.lookup (operand));
119
+ }
120
+
121
+ for (Operation &opToClone : expressionToFoldBody->without_terminator ())
122
+ rewriter.clone (opToClone, mapper);
123
+
124
+ if (withTerminator)
125
+ rewriter.clone (*expressionToFoldBody->getTerminator (), mapper);
126
+ };
127
+ rewriter.setInsertionPointToStart (&foldedExpressionBody);
128
+
129
+ // First, fold the used expression into the new expression and map its
130
+ // result to the clone of its root operation within the new expression.
131
+ foldExpression (usedExpression, /* withTerminator=*/ false );
132
+ Operation *expressionRoot = usedExpression.getRootOp ();
133
+ Operation *clonedExpressionRootOp = mapper.lookup (expressionRoot);
134
+ assert (clonedExpressionRootOp &&
135
+ " Expected cloned expression root to be in mapper" );
136
+ assert (clonedExpressionRootOp->getNumResults () == 1 &&
137
+ " Expected cloned root to have a single result" );
138
+ mapper.map (usedExpression.getResult (),
139
+ clonedExpressionRootOp->getResults ()[0 ]);
140
+
141
+ // Now fold the matched expression into the new expression.
142
+ foldExpression (expressionOp, /* withTerminator=*/ true );
143
+
144
+ // Complete the rewrite.
145
+ rewriter.replaceOp (expressionOp, foldedExpression);
146
+ rewriter.eraseOp (usedExpression);
147
+
148
+ return success ();
101
149
}
102
150
};
103
151
0 commit comments