diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index fb7a108b39fc0..721f9f6b320ad 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -402,7 +402,8 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> { let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)"; } -def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { +def EmitC_ConstantOp + : EmitC_Op<"constant", [ConstantLike, CExpressionInterface]> { let summary = "Constant operation"; let description = [{ The `emitc.constant` operation produces an SSA value equal to some constant @@ -429,6 +430,13 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> { let hasFolder = 1; let hasVerifier = 1; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + // If operand is fundamental type, the operation is pure. + return !isFundamentalType(getResult().getType()); + } + }]; } def EmitC_DivOp : EmitC_BinaryOp<"div", []> { @@ -466,8 +474,9 @@ def EmitC_ExpressionOp its single-basic-block region. The operation doesn't take any arguments. As the operation is to be emitted as a C expression, the operations within - its body must form a single Def-Use tree of emitc ops whose result is - yielded by a terminating `emitc.yield`. + its body must form a single Def-Use tree, or a DAG trivially expandable to + one, i.e. a DAG where each operation with side effects is only reachable + once from the expression root. Example: diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 00ce3b59bf870..a73470cdf76c5 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -14,7 +14,9 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" @@ -462,12 +464,34 @@ LogicalResult ExpressionOp::verify() { return emitOpError("requires yielded type to match return type"); for (Operation &op : region.front().without_terminator()) { - if (!isa(op)) + auto expressionInterface = dyn_cast(op); + if (!expressionInterface) return emitOpError("contains an unsupported operation"); if (op.getNumResults() != 1) return emitOpError("requires exactly one result for each operation"); - if (!op.getResult(0).hasOneUse()) - return emitOpError("requires exactly one use for each operation"); + Value result = op.getResult(0); + if (result.use_empty()) + return emitOpError("contains an unused operation"); + } + + // Make sure any operation with side effect is only reachable once from + // the root op, otherwise emission will be replicating side effects. + SmallPtrSet visited; + SmallVector worklist; + worklist.push_back(rootOp); + while (!worklist.empty()) { + Operation *op = worklist.back(); + worklist.pop_back(); + if (visited.contains(op)) { + if (cast(op).hasSideEffects()) + return emitOpError( + "requires exactly one use for operations with side effects"); + } + visited.insert(op); + for (Value operand : op->getOperands()) + if (Operation *def = operand.getDefiningOp()) { + worklist.push_back(def); + } } return success(); diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 570f38c60020b..a5bd80e9d6b8b 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -97,6 +97,7 @@ static FailureOr getOperatorPrecedence(Operation *operation) { return op->emitError("unsupported cmp predicate"); }) .Case([&](auto op) { return 2; }) + .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 16; }) .Case([&](auto op) { return 4; }) @@ -267,8 +268,14 @@ struct CppEmitter { Operation *def = value.getDefiningOp(); if (!def) return false; + return isPartOfCurrentExpression(def); + } + + /// Determine whether given operation is part of the expression potentially + /// being emitted. + bool isPartOfCurrentExpression(Operation *def) { auto operandExpression = dyn_cast(def->getParentOp()); - return operandExpression == emittedExpression; + return operandExpression && operandExpression == emittedExpression; }; // Resets the value counter to 0. @@ -408,6 +415,9 @@ static LogicalResult printOperation(CppEmitter &emitter, Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); + if (emitter.isPartOfCurrentExpression(operation)) + return emitter.emitAttribute(operation->getLoc(), value); + return printConstantOp(emitter, operation, value); } diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir index 67cd6fddba638..7b6723989e260 100644 --- a/mlir/test/Dialect/EmitC/form-expressions.mlir +++ b/mlir/test/Dialect/EmitC/form-expressions.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: func.func @single_expression( // CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 { -// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 -// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_2]], %[[VAL_0]], %[[VAL_4]] : (i32, i32, i32, i32) -> i1 { +// CHECK: %[[VAL_5:.*]] = emitc.expression %[[VAL_3]], %[[VAL_2]], %[[VAL_0]] : (i32, i32, i32) -> i1 { +// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 // CHECK: %[[VAL_6:.*]] = mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32 // CHECK: %[[VAL_7:.*]] = sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32 // CHECK: %[[VAL_8:.*]] = cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1 @@ -131,7 +131,10 @@ func.func @single_result_requirement() -> (i32, i32) { // CHECK-LABEL: func.func @expression_with_load( // CHECK-SAME: %[[VAL_0:.*]]: i32, // CHECK-SAME: %[[VAL_1:.*]]: !emitc.ptr) -> i1 { -// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 +// CHECK: %[[VAL_2:.*]] = emitc.expression : () -> i64 { +// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 +// CHECK: yield %[[VAL_C]] : i64 +// CHECK: } // CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue // CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_3]] : (!emitc.lvalue) -> i32 { // CHECK: %[[VAL_5:.*]] = load %[[VAL_3]] : @@ -162,24 +165,46 @@ func.func @expression_with_load(%arg0: i32, %arg1: !emitc.ptr) -> i1 { } // CHECK-LABEL: func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 { -// CHECK: %0 = "emitc.constant"() <{value = 42 : i32}> : () -> i32 -// CHECK: %1 = emitc.expression %arg1, %arg0, %0 : (!emitc.opaque<"T0">, i32, i32) -> i32 { -// CHECK: %3 = mul %arg0, %0 : (i32, i32) -> i32 -// CHECK: %4 = sub %3, %arg1 : (i32, !emitc.opaque<"T0">) -> i32 +// CHECK: %0 = emitc.expression : () -> !emitc.opaque<"T1"> { +// CHECK: %4 = "emitc.constant"() <{value = #emitc.opaque<"V">}> : () -> !emitc.opaque<"T1"> +// CHECK: yield %4 : !emitc.opaque<"T1"> +// CHECK: } +// CHECK: %1 = emitc.expression %arg0, %0 : (i32, !emitc.opaque<"T1">) -> i32 { +// CHECK: %4 = mul %arg0, %0 : (i32, !emitc.opaque<"T1">) -> i32 +// CHECK: yield %4 : i32 +// CHECK: } +// CHECK: %2 = emitc.expression %1, %arg1 : (i32, !emitc.opaque<"T0">) -> i32 { +// CHECK: %4 = sub %1, %arg1 : (i32, !emitc.opaque<"T0">) -> i32 // CHECK: yield %4 : i32 // CHECK: } -// CHECK: %2 = emitc.expression %1, %arg2 : (i32, i32) -> i1 { -// CHECK: %3 = cmp lt, %1, %arg2 : (i32, i32) -> i1 -// CHECK: yield %3 : i1 +// CHECK: %3 = emitc.expression %2, %arg2 : (i32, i32) -> i1 { +// CHECK: %4 = cmp lt, %2, %arg2 : (i32, i32) -> i1 +// CHECK: yield %4 : i1 // CHECK: } -// CHECK: return %2 : i1 +// CHECK: return %3 : i1 // CHECK: } func.func @opaque_type_expression(%arg0: i32, %arg1: !emitc.opaque<"T0">, %arg2: i32) -> i1 { - %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 - %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + %c42 = "emitc.constant"(){value = #emitc.opaque<"V">} : () -> !emitc.opaque<"T1"> + %a = emitc.mul %arg0, %c42 : (i32, !emitc.opaque<"T1">) -> i32 %b = emitc.sub %a, %arg1 : (i32, !emitc.opaque<"T0">) -> i32 %c = emitc.cmp lt, %b, %arg2 :(i32, i32) -> i1 return %c : i1 } + +// CHECK-LABEL: func.func @expression_with_constant( +// CHECK-SAME: %[[VAL_0:.*]]: i32) -> i32 { +// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (i32) -> i32 { +// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32 +// CHECK: %[[VAL_3:.*]] = mul %[[VAL_0]], %[[VAL_2]] : (i32, i32) -> i32 +// CHECK: yield %[[VAL_3]] : i32 +// CHECK: } +// CHECK: return %[[VAL_1]] : i32 +// CHECK: } + +func.func @expression_with_constant(%arg0: i32) -> i32 { + %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 + %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + return %a : i32 +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index fdfb0eb46f7c5..a97474401645c 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -311,7 +311,7 @@ func.func @test_expression_illegal_op(%arg0 : i1) -> i32 { // ----- func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 { - // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + // expected-error @+1 {{'emitc.expression' op contains an unused operation}} %r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 { %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32 %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 @@ -323,12 +323,12 @@ func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 { // ----- func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 { - // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}} + // expected-error @+1 {{'emitc.expression' op requires exactly one use for operations with side effects}} %r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 { - %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + %a = emitc.call_opaque "foo"(%arg0, %arg1) : (i32, i32) -> i32 %b = emitc.add %a, %arg0 : (i32, i32) -> i32 - %c = emitc.mul %arg1, %a : (i32, i32) -> i32 - emitc.yield %a : i32 + %c = emitc.mul %b, %a : (i32, i32) -> i32 + emitc.yield %c : i32 } return %r : i32 } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index e890f77173de7..84c9b65d775d2 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -203,6 +203,16 @@ func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4 return %r : i32 } +func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 { + %r = emitc.expression %arg0, %arg1 : (i32, i32) -> i32 { + %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32 + %b = emitc.add %a, %arg0 : (i32, i32) -> i32 + %c = emitc.mul %b, %a : (i32, i32) -> i32 + emitc.yield %c : i32 + } + return %r : i32 +} + func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) { emitc.for %i0 = %arg0 to %arg1 step %arg2 { %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32 diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 433a67ccb3f39..4281f41d0b3fb 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -1,6 +1,25 @@ // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP +// CPP-DEFAULT: bool single_expression(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return [[VAL_1]] * 42 - [[VAL_2]] < [[VAL_3]]; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: bool single_expression(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return [[VAL_1]] * 42 - [[VAL_2]] < [[VAL_3]]; +// CPP-DECLTOP-NEXT: } + +func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32) -> i1 { + %e = emitc.expression %arg0, %arg1, %arg2 : (i32, i32, i32) -> i1 { + %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32 + %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 + %b = emitc.sub %a, %arg1 : (i32, i32) -> i32 + %c = emitc.cmp lt, %b, %arg2 :(i32, i32) -> i1 + emitc.yield %c : i1 + } + return %e : i1 +} + // CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { // CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; @@ -185,7 +204,7 @@ func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 } // CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) { -// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_1]] * [[VAL_2]] < [[VAL_4]]; // CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; // CPP-DEFAULT-NEXT: if ([[VAL_5]]) { // CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]]; @@ -203,7 +222,7 @@ func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 // CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; // CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]]; // CPP-DECLTOP-NEXT: int32_t [[VAL_8:v[0-9]+]]; -// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_1]] * [[VAL_2]] < [[VAL_4]]; // CPP-DECLTOP-NEXT: ; // CPP-DECLTOP-NEXT: if ([[VAL_5]]) { // CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]]; @@ -220,8 +239,8 @@ func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 %e = emitc.expression %arg0, %arg1, %arg2, %arg3 : (i32, i32, i32, i32) -> i1 { %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32 %b = emitc.call_opaque "bar" (%a, %arg2) : (i32, i32) -> (i32) - %c = emitc.sub %b, %arg3 : (i32, i32) -> i32 - %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1 + %c = emitc.sub %b, %a : (i32, i32) -> i32 + %d = emitc.cmp lt, %c, %arg3 :(i32, i32) -> i1 emitc.yield %d : i1 } %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.lvalue