Skip to content

Commit 20cbb69

Browse files
committed
[mlir][emitc] Turn constant into CExpression
The `emitc.constant` op was so far left out of `emit.expression`'s as its ConstantLike trait could cause CSE to invalidate `emitc.expression` ops in two ways: - Remove the root of a constant-only expression, leaving the expression empty. - Simplify within the expression, violating the single-use requirement. The first issue was recently resolved by making `emitc.expression` isolated-from-above. The second is resolved here by limiting the single-use requirement to CExpressions with side effects, as ops with no side effects can safely be cloned as needed.
1 parent 74b7e73 commit 20cbb69

File tree

7 files changed

+123
-27
lines changed

7 files changed

+123
-27
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
402402
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
403403
}
404404

405-
def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
405+
def EmitC_ConstantOp
406+
: EmitC_Op<"constant", [ConstantLike, CExpressionInterface]> {
406407
let summary = "Constant operation";
407408
let description = [{
408409
The `emitc.constant` operation produces an SSA value equal to some constant
@@ -429,6 +430,13 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
429430

430431
let hasFolder = 1;
431432
let hasVerifier = 1;
433+
434+
let extraClassDeclaration = [{
435+
bool hasSideEffects() {
436+
// If operand is fundamental type, the operation is pure.
437+
return !isFundamentalType(getResult().getType());
438+
}
439+
}];
432440
}
433441

434442
def EmitC_DivOp : EmitC_BinaryOp<"div", []> {

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
#include "mlir/IR/DialectImplementation.h"
1515
#include "mlir/IR/Types.h"
1616
#include "mlir/Interfaces/FunctionImplementation.h"
17+
#include "mlir/Support/LLVM.h"
1718
#include "llvm/ADT/STLExtras.h"
19+
#include "llvm/ADT/SmallVector.h"
1820
#include "llvm/ADT/TypeSwitch.h"
1921
#include "llvm/Support/Casting.h"
2022

@@ -462,12 +464,34 @@ LogicalResult ExpressionOp::verify() {
462464
return emitOpError("requires yielded type to match return type");
463465

464466
for (Operation &op : region.front().without_terminator()) {
465-
if (!isa<emitc::CExpressionInterface>(op))
467+
auto expressionInterface = dyn_cast<emitc::CExpressionInterface>(op);
468+
if (!expressionInterface)
466469
return emitOpError("contains an unsupported operation");
467470
if (op.getNumResults() != 1)
468471
return emitOpError("requires exactly one result for each operation");
469-
if (!op.getResult(0).hasOneUse())
470-
return emitOpError("requires exactly one use for each operation");
472+
Value result = op.getResult(0);
473+
if (result.use_empty())
474+
return emitOpError("contains an unused operation");
475+
}
476+
477+
// Make sure any operation with side effect is only reachable once from
478+
// the root op, otherwise emission will be replicating side effects.
479+
SmallPtrSet<Operation *, 16> visited;
480+
SmallVector<Operation *> worklist;
481+
worklist.push_back(rootOp);
482+
while (!worklist.empty()) {
483+
Operation *op = worklist.back();
484+
worklist.pop_back();
485+
if (visited.contains(op)) {
486+
if (cast<CExpressionInterface>(op).hasSideEffects())
487+
return emitOpError(
488+
"requires exactly one use for operations with side effects");
489+
}
490+
visited.insert(op);
491+
for (Value operand : op->getOperands())
492+
if (Operation *def = operand.getDefiningOp()) {
493+
worklist.push_back(def);
494+
}
471495
}
472496

473497
return success();

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
9797
return op->emitError("unsupported cmp predicate");
9898
})
9999
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
100+
.Case<emitc::ConstantOp>([&](auto op) { return 17; })
100101
.Case<emitc::DivOp>([&](auto op) { return 13; })
101102
.Case<emitc::LoadOp>([&](auto op) { return 16; })
102103
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
@@ -267,8 +268,14 @@ struct CppEmitter {
267268
Operation *def = value.getDefiningOp();
268269
if (!def)
269270
return false;
271+
return isPartOfCurrentExpression(def);
272+
}
273+
274+
/// Determine whether given operation is part of the expression potentially
275+
/// being emitted.
276+
bool isPartOfCurrentExpression(Operation *def) {
270277
auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
271-
return operandExpression == emittedExpression;
278+
return operandExpression && operandExpression == emittedExpression;
272279
};
273280

274281
// Resets the value counter to 0.
@@ -408,6 +415,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
408415
Operation *operation = constantOp.getOperation();
409416
Attribute value = constantOp.getValue();
410417

418+
if (emitter.isPartOfCurrentExpression(operation))
419+
return emitter.emitAttribute(operation->getLoc(), value);
420+
411421
return printConstantOp(emitter, operation, value);
412422
}
413423

mlir/test/Dialect/EmitC/form-expressions.mlir

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
// CHECK-LABEL: func.func @single_expression(
44
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 {
5-
// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32
6-
// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_2]], %[[VAL_0]], %[[VAL_4]] : (i32, i32, i32, i32) -> i1 {
5+
// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : (i32, i32, i32) -> i1 {
6+
// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32
77
// CHECK: %[[VAL_6:.*]] = mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32
88
// CHECK: %[[VAL_7:.*]] = sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32
99
// CHECK: %[[VAL_8:.*]] = cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1
@@ -131,7 +131,10 @@ func.func @single_result_requirement() -> (i32, i32) {
131131
// CHECK-LABEL: func.func @expression_with_load(
132132
// CHECK-SAME: %[[VAL_0:.*]]: i32,
133133
// CHECK-SAME: %[[VAL_1:.*]]: !emitc.ptr<i32>) -> i1 {
134-
// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64
134+
// CHECK: %[[VAL_2:.*]] = emitc.expression : () -> i64 {
135+
// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64
136+
// CHECK: yield %[[VAL_C]] : i64
137+
// CHECK: }
135138
// CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue<i32>
136139
// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_3]] : (!emitc.lvalue<i32>) -> i32 {
137140
// CHECK: %[[VAL_5:.*]] = load %[[VAL_3]] : <i32>
@@ -162,24 +165,46 @@ func.func @expression_with_load(%arg0: i32, %arg1: !emitc.ptr<i32>) -> i1 {
162165
}
163166

164167
// CHECK-LABEL: func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 {
165-
// CHECK: %0 = "emitc.constant"() <{value = 42 : i32}> : () -> i32
166-
// CHECK: %1 = emitc.expression %arg1, %arg0, %0 : (!emitc.opaque<"T0">, i32, i32) -> i32 {
167-
// CHECK: %3 = mul %arg0, %0 : (i32, i32) -> i32
168-
// CHECK: %4 = sub %3, %arg1 : (i32, !emitc.opaque<"T0">) -> i32
168+
// CHECK: %0 = emitc.expression : () -> !emitc.opaque<"T1"> {
169+
// CHECK: %4 = "emitc.constant"() <{value = #emitc.opaque<"V">}> : () -> !emitc.opaque<"T1">
170+
// CHECK: yield %4 : !emitc.opaque<"T1">
171+
// CHECK: }
172+
// CHECK: %1 = emitc.expression %arg0, %0 : (i32, !emitc.opaque<"T1">) -> i32 {
173+
// CHECK: %4 = mul %arg0, %0 : (i32, !emitc.opaque<"T1">) -> i32
174+
// CHECK: yield %4 : i32
175+
// CHECK: }
176+
// CHECK: %2 = emitc.expression %1, %arg1 : (i32, !emitc.opaque<"T0">) -> i32 {
177+
// CHECK: %4 = sub %1, %arg1 : (i32, !emitc.opaque<"T0">) -> i32
169178
// CHECK: yield %4 : i32
170179
// CHECK: }
171-
// CHECK: %2 = emitc.expression %1, %arg2 : (i32, i32) -> i1 {
172-
// CHECK: %3 = cmp lt, %1, %arg2 : (i32, i32) -> i1
173-
// CHECK: yield %3 : i1
180+
// CHECK: %3 = emitc.expression %2, %arg2 : (i32, i32) -> i1 {
181+
// CHECK: %4 = cmp lt, %2, %arg2 : (i32, i32) -> i1
182+
// CHECK: yield %4 : i1
174183
// CHECK: }
175-
// CHECK: return %2 : i1
184+
// CHECK: return %3 : i1
176185
// CHECK: }
177186

178187

179188
func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 {
180-
%c42 = "emitc.constant"(){value = 42 : i32} : () -> i32
181-
%a = emitc.mul %arg0, %c42 : (i32, i32) -> i32
189+
%c42 = "emitc.constant"(){value = #emitc.opaque<"V">} : () -> !emitc.opaque<"T1">
190+
%a = emitc.mul %arg0, %c42 : (i32, !emitc.opaque<"T1">) -> i32
182191
%b = emitc.sub %a, %arg1 : (i32, !emitc.opaque<"T0">) -> i32
183192
%c = emitc.cmp lt, %b, %arg2 :(i32, i32) -> i1
184193
return %c : i1
185194
}
195+
196+
// CHECK-LABEL: func.func @expression_with_constant(
197+
// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 {
198+
// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (i32) -> i32 {
199+
// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32
200+
// CHECK: %[[VAL_3:.*]] = mul %[[VAL_0]], %[[VAL_2]] : (i32, i32) -> i32
201+
// CHECK: yield %[[VAL_3]] : i32
202+
// CHECK: }
203+
// CHECK: return %[[VAL_1]] : i32
204+
// CHECK: }
205+
206+
func.func @expression_with_constant(%arg0: i32) -> i32 {
207+
%c42 = "emitc.constant"(){value = 42 : i32} : () -> i32
208+
%a = emitc.mul %arg0, %c42 : (i32, i32) -> i32
209+
return %a : i32
210+
}

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func.func @test_expression_illegal_op(%arg0 : i1) -> i32 {
311311
// -----
312312

313313
func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 {
314-
// expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
314+
// expected-error @+1 {{'emitc.expression' op contains an unused operation}}
315315
%r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 {
316316
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
317317
%b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
@@ -323,12 +323,12 @@ func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 {
323323
// -----
324324

325325
func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
326-
// expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
326+
// expected-error @+1 {{'emitc.expression' op requires exactly one use for operations with side effects}}
327327
%r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 {
328-
%a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
328+
%a = emitc.call_opaque "foo"(%arg0, %arg1) : (i32, i32) -> i32
329329
%b = emitc.add %a, %arg0 : (i32, i32) -> i32
330-
%c = emitc.mul %arg1, %a : (i32, i32) -> i32
331-
emitc.yield %a : i32
330+
%c = emitc.mul %b, %a : (i32, i32) -> i32
331+
emitc.yield %c : i32
332332
}
333333
return %r : i32
334334
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4
203203
return %r : i32
204204
}
205205

206+
func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
207+
%r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 {
208+
%a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
209+
%b = emitc.add %a, %arg0 : (i32, i32) -> i32
210+
%c = emitc.mul %b, %a : (i32, i32) -> i32
211+
emitc.yield %c : i32
212+
}
213+
return %r : i32
214+
}
215+
206216
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
207217
emitc.for %i0 = %arg0 to %arg1 step %arg2 {
208218
%0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32

mlir/test/Target/Cpp/expressions.mlir

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
22
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
33

4+
// CPP-DEFAULT: bool single_expression(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
5+
// CPP-DEFAULT-NEXT: return [[VAL_1]] * 42 - [[VAL_2]] < [[VAL_3]];
6+
// CPP-DEFAULT-NEXT: }
7+
8+
// CPP-DECLTOP: bool single_expression(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
9+
// CPP-DECLTOP-NEXT: return [[VAL_1]] * 42 - [[VAL_2]] < [[VAL_3]];
10+
// CPP-DECLTOP-NEXT: }
11+
12+
func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32) -> i1 {
13+
%e = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i1 {
14+
%c42 = "emitc.constant"(){value = 42 : i32} : () -> i32
15+
%a = emitc.mul %arg0, %c42 : (i32, i32) -> i32
16+
%b = emitc.sub %a, %arg1 : (i32, i32) -> i32
17+
%c = emitc.cmp lt, %b, %arg2 :(i32, i32) -> i1
18+
emitc.yield %c : i1
19+
}
20+
return %e : i1
21+
}
22+
423
// CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
524
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
625
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
@@ -185,7 +204,7 @@ func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32
185204
}
186205

187206
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
188-
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
207+
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_1]] * [[VAL_2]] < [[VAL_4]];
189208
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
190209
// CPP-DEFAULT-NEXT: if ([[VAL_5]]) {
191210
// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]];
@@ -203,7 +222,7 @@ func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32
203222
// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]];
204223
// CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]];
205224
// CPP-DECLTOP-NEXT: int32_t [[VAL_8:v[0-9]+]];
206-
// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
225+
// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_1]] * [[VAL_2]] < [[VAL_4]];
207226
// CPP-DECLTOP-NEXT: ;
208227
// CPP-DECLTOP-NEXT: if ([[VAL_5]]) {
209228
// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]];
@@ -220,8 +239,8 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32
220239
%e = emitc.expression %arg0, %arg1, %arg2, %arg3 : (i32, i32, i32, i32) -> i1 {
221240
%a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
222241
%b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32)
223-
%c = emitc.sub %b, %arg3 : (i32, i32) -> i32
224-
%d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1
242+
%c = emitc.sub %b, %a : (i32, i32) -> i32
243+
%d = emitc.cmp lt, %c, %arg3 :(i32, i32) -> i1
225244
emitc.yield %d : i1
226245
}
227246
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.lvalue<i32>

0 commit comments

Comments
 (0)