From 9268794a17c8a16aba72ab669dec468d1489082b Mon Sep 17 00:00:00 2001 From: Gil Rapaport Date: Tue, 22 Jul 2025 10:52:37 +0300 Subject: [PATCH] [mlir][emitc] Deferred emission as expressions The translator currently implements deferred emission for certain ops. Like expressions, these ops are emitted as part of their users but unlike expressions, this is mandatory. Besides complicating the code with a second inlining mechanism, deferred emission's inlining is limited as it's not recursive. This patch extends EmitC's expressions to deferred emission ops by (a) marking them as CExpressions, (b) extending expression interface to mark ops as always-inline and (c) support inlining of always-inline CExpressions even when not packed of an `emitc.expression` op, retaining current behavior. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 78 ++++- .../mlir/Dialect/EmitC/IR/EmitCInterfaces.td | 27 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 13 +- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 324 ++++++++++-------- mlir/test/Dialect/EmitC/form-expressions.mlir | 163 ++++++++- mlir/test/Target/Cpp/expressions.mlir | 159 +++++++++ mlir/test/Target/Cpp/member.mlir | 8 +- 7 files changed, 612 insertions(+), 160 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 721f9f6b320ad..ebd4850161894 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -507,7 +507,7 @@ def EmitC_ExpressionOp let arguments = (ins Variadic>:$defs, UnitAttr:$do_not_inline); - let results = (outs EmitCType:$result); + let results = (outs AnyTypeOf<[EmitCType, EmitC_LValueType]>:$result); let regions = (region SizedRegion<1>:$region); let hasVerifier = 1; @@ -873,7 +873,7 @@ def EmitC_IncludeOp let hasCustomAssemblyFormat = 1; } -def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { +def EmitC_LiteralOp : EmitC_Op<"literal", [Pure, CExpressionInterface]> { let summary = "Literal operation"; let description = [{ The `emitc.literal` operation produces an SSA value equal to some constant @@ -896,6 +896,15 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { let hasVerifier = 1; let assemblyFormat = "$value attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { @@ -1062,7 +1071,7 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { let hasVerifier = 1; } -def EmitC_MemberOp : EmitC_Op<"member"> { +def EmitC_MemberOp : EmitC_Op<"member", [CExpressionInterface]> { let summary = "Member operation"; let description = [{ With the `emitc.member` operation the member access operator `.` can be @@ -1083,9 +1092,18 @@ def EmitC_MemberOp : EmitC_Op<"member"> { EmitC_LValueOf<[EmitC_OpaqueType]>:$operand ); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>); + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } -def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> { +def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr", [CExpressionInterface]> { let summary = "Member of pointer operation"; let description = [{ With the `emitc.member_of_ptr` operation the member access operator `->` @@ -1108,6 +1126,15 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> { EmitC_LValueOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand ); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>); + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_ConditionalOp : EmitC_Op<"conditional", @@ -1277,8 +1304,10 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { let hasVerifier = 1; } -def EmitC_GetGlobalOp : EmitC_Op<"get_global", - [Pure, DeclareOpInterfaceMethods]> { +def EmitC_GetGlobalOp + : EmitC_Op<"get_global", [Pure, + DeclareOpInterfaceMethods, + CExpressionInterface]> { let summary = "Obtain access to a global variable"; let description = [{ The `emitc.get_global` operation retrieves the lvalue of a @@ -1296,6 +1325,15 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global", let arguments = (ins FlatSymbolRefAttr:$name); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { @@ -1406,7 +1444,8 @@ def EmitC_YieldOp : EmitC_Op<"yield", value is yielded. }]; - let arguments = (ins Optional:$result); + let arguments = + (ins Optional>:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let hasVerifier = 1; @@ -1477,7 +1516,7 @@ def EmitC_IfOp : EmitC_Op<"if", let hasCustomAssemblyFormat = 1; } -def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { +def EmitC_SubscriptOp : EmitC_Op<"subscript", [CExpressionInterface]> { let summary = "Subscript operation"; let description = [{ With the `emitc.subscript` operation the subscript operator `[]` can be applied @@ -1525,6 +1564,15 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { let hasVerifier = 1; let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects, @@ -1707,8 +1755,9 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { } def EmitC_GetFieldOp - : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods< - SymbolUserOpInterface>]> { + : EmitC_Op<"get_field", [Pure, + DeclareOpInterfaceMethods, + CExpressionInterface]> { let summary = "Obtain access to a field within a class instance"; let description = [{ The `emitc.get_field` operation retrieves the lvalue of a @@ -1725,6 +1774,15 @@ def EmitC_GetFieldOp let results = (outs EmitCType:$result); let assemblyFormat = "$field_name `:` type($result) attr-dict"; let hasVerifier = 1; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td index 777784e56202a..c11e017e40d0f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td @@ -21,8 +21,8 @@ def CExpressionInterface : OpInterface<"CExpressionInterface"> { }]; let cppNamespace = "::mlir::emitc"; - let methods = [ - InterfaceMethod<[{ + let methods = + [InterfaceMethod<[{ Check whether operation has side effects that may affect the expression evaluation. @@ -38,9 +38,28 @@ def CExpressionInterface : OpInterface<"CExpressionInterface"> { }; ``` }], - "bool", "hasSideEffects", (ins), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ + "bool", "hasSideEffects", (ins), /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return true; + }]>, + InterfaceMethod<[{ + Check whether operation must be inlined into all its users. + + By default operation is not marked as always inlined. + + ```c++ + class ConcreteOp ... { + public: + bool alwaysInline() { + // That way we can override the default implementation. + return true; + } + }; + ``` + }], + "bool", "alwaysInline", (ins), /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return false; }]>, ]; } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5c8564bca6f86..4468ac686aa8b 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -483,7 +483,8 @@ LogicalResult ExpressionOp::verify() { Operation *op = worklist.back(); worklist.pop_back(); if (visited.contains(op)) { - if (cast(op).hasSideEffects()) + auto cExpr = cast(op); + if (!cExpr.alwaysInline() && cExpr.hasSideEffects()) return emitOpError( "requires exactly one use for operations with side effects"); } @@ -494,6 +495,12 @@ LogicalResult ExpressionOp::verify() { } } + if (getDoNotInline() && + cast(rootOp).alwaysInline()) { + return emitOpError("root operation must be inlined but expression is marked" + " do-not-inline"); + } + return success(); } @@ -980,6 +987,10 @@ LogicalResult emitc::YieldOp::verify() { if (!result && containingOp->getNumResults() != 0) return emitOpError() << "does not yield a value to be returned by parent"; + if (result && isa(result.getType()) && + !isa(containingOp)) + return emitOpError() << "yielding lvalues is not supported for this op"; + return success(); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a5bd80e9d6b8b..7541845ff6242 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -99,13 +100,18 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Case([&](auto op) { return 2; }) .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 18; }) + .Case([&](auto op) { return 18; }) .Case([&](auto op) { return 16; }) .Case([&](auto op) { return 4; }) .Case([&](auto op) { return 15; }) .Case([&](auto op) { return 3; }) + .Case([&](auto op) { return 17; }) + .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 15; }) .Case([&](auto op) { return 15; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); @@ -173,31 +179,27 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); - /// Emits value as an operands of an operation - LogicalResult emitOperand(Value value); + /// Emits value as an operand of an operation. If \p isInBrackets is true, + /// this operand is already being emitted between some kind of brackets, so + /// there is no need to wrap it in parentheses for correct precedence. + LogicalResult emitOperand(Value value, bool isInBrackets = false); - /// Emit an expression as a C expression. - LogicalResult emitExpression(ExpressionOp expressionOp); + /// Collect all operations to emit as an expression starting at \p op, + /// recursively adding operands that shouldBeInlined. + void buildExpression(Operation *op); - /// Insert the expression representing the operation into the value cache. - void cacheDeferredOpResult(Value value, StringRef str); + /// Emit an expression as a C expression. + LogicalResult emitExpression(Operation *op); /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); + void setName(Value val, StringRef name); + /// Return the existing or a new name for a loop induction variable of an /// emitc::ForOp. StringRef getOrCreateInductionVarName(Value val); - // Returns the textual representation of a subscript operation. - std::string getSubscriptName(emitc::SubscriptOp op); - - // Returns the textual representation of a member (of object) operation. - std::string createMemberAccess(emitc::MemberOp op); - - // Returns the textual representation of a member of pointer operation. - std::string createMemberAccess(emitc::MemberOfPtrOp op); - /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); @@ -257,25 +259,24 @@ struct CppEmitter { return !fileId.empty() && file.getId() == fileId; } - /// Get expression currently being emitted. - ExpressionOp getEmittedExpression() { return emittedExpression; } + /// Is expression currently being emitted. + bool isEmittingExpression() { return !emittedExpression.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. - bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; + Operation *isPartOfCurrentExpression(Value value) { Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + if (def) + return isPartOfCurrentExpression(def) ? def : nullptr; + return nullptr; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + if (auto parentExpression = dyn_cast(def->getParentOp())) + def = parentExpression; + return emittedExpression.contains(def); }; // Resets the value counter to 0. @@ -322,7 +323,7 @@ struct CppEmitter { unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; + SmallPtrSet emittedExpression; SmallVector emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -338,19 +339,35 @@ struct CppEmitter { }; } // namespace -/// Determine whether expression \p op should be emitted in a deferred way. -static bool hasDeferredEmission(Operation *op) { - return isa_and_nonnull(op); -} - -/// Determine whether expression \p expressionOp should be emitted inline, i.e. -/// as part of its user. This function recommends inlining of any expressions -/// that can be inlined unless it is used by another expression, under the -/// assumption that any expression fusion/re-materialization was taken care of +/// Determine whether operation \p operation should be emitted inline, i.e. +/// as part of its user. +/// The operation can force inlining using its +/// CExpressionInterface::alwaysInline() method when it's not included in any +/// ExpressionOp or is some ExpressionOp's root operation. +/// Otherwise, for any ExpressionOp that can be inlined, this function +/// recommends inlining unless it is used by another expression, under the +/// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *operation) { + auto expressionInterface = dyn_cast(operation); + + // Inline if this is an always-inline CExpression not part of any + // expression. + if (expressionInterface && expressionInterface.alwaysInline()) { + assert(!isa(operation->getParentOp()) && + "Unexpectedly called on operation included in expression"); + return true; + } + + ExpressionOp expressionOp = dyn_cast(operation); + + if (!expressionOp) + return false; + + // Inline if the root operation is an always-inline CExpression. + if (cast(expressionOp.getRootOp()).alwaysInline()) + return true; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -367,17 +384,83 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { Operation *user = *result.getUsers().begin(); - // Do not inline expressions used by operations with deferred emission, since - // their translation requires the materialization of variables. - if (hasDeferredEmission(user)) - return false; - // Do not inline expressions used by other expressions or by ops with the // CExpressionInterface. If this was intended, the user could have been merged // into the expression op. return !isa(*user); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetFieldOp getFieldOp) { + if (!emitter.isPartOfCurrentExpression(getFieldOp.getOperation())) { + emitter.setName(getFieldOp.getResult(), getFieldOp.getFieldName()); + return success(); + } + + emitter.ostream() << getFieldOp.getFieldName(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp getGlobalOp) { + if (!emitter.isPartOfCurrentExpression(getGlobalOp.getOperation())) { + emitter.setName(getGlobalOp.getResult(), getGlobalOp.getName()); + return success(); + } + + emitter.ostream() << getGlobalOp.getName(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LiteralOp literalOp) { + if (!emitter.isPartOfCurrentExpression(literalOp.getOperation())) + return success(); + + emitter.ostream() << literalOp.getValue(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOp memberOp) { + if (!emitter.isPartOfCurrentExpression(memberOp.getOperation())) + return success(); + + if (failed(emitter.emitOperand(memberOp.getOperand()))) + return failure(); + emitter.ostream() << "." << memberOp.getMember(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOfPtrOp memberOfPtrOp) { + if (!emitter.isPartOfCurrentExpression(memberOfPtrOp.getOperation())) + return success(); + + if (failed(emitter.emitOperand(memberOfPtrOp.getOperand()))) + return failure(); + emitter.ostream() << "->" << memberOfPtrOp.getMember(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + if (!emitter.isPartOfCurrentExpression(subscriptOp.getOperation())) { + return success(); + } + + raw_ostream &os = emitter.ostream(); + if (failed(emitter.emitOperand(subscriptOp.getValue()))) + return failure(); + for (auto index : subscriptOp.getIndices()) { + os << "["; + if (failed(emitter.emitOperand(index, /*isInBrackets=*/true))) + return failure(); + os << "]"; + } + return success(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -437,11 +520,11 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { - OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); - - if (failed(emitter.emitVariableAssignment(result))) + if (failed(emitter.emitOperand(assignOp.getVar()))) return failure(); + emitter.ostream() << " = "; + return emitter.emitOperand(assignOp.getValue()); } @@ -1288,44 +1371,14 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, labelInScopeCount.push(0); } -std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getValue()); - for (auto index : op.getIndices()) { - ss << "[" << getOrCreateName(index) << "]"; - } - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "." << op.getMember(); - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "->" << op.getMember(); - return out; -} - -void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { - if (!valueMapper.count(value)) - valueMapper.insert(value, str.str()); +void CppEmitter::setName(Value val, StringRef name) { + assert(!valueMapper.count(val) && "Expected value not to be in mapper"); + valueMapper.insert(val, name.str()); } /// Return the existing or a new name for a Value. StringRef CppEmitter::getOrCreateName(Value val) { if (!valueMapper.count(val)) { - assert(!hasDeferredEmission(val.getDefiningOp()) && - "cacheDeferredOpResult should have been called on this value, " - "update the emitOperation function."); - valueMapper.insert(val, formatv("v{0}", ++valueCount)); } return *valueMapper.begin(val); @@ -1492,12 +1545,37 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return emitError(loc, "cannot emit attribute: ") << attr; } -LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { +void CppEmitter::buildExpression(Operation *op) { + emittedExpression.insert(op); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && shouldBeInlined(defOp)) + buildExpression(defOp); + } +} + +LogicalResult CppEmitter::emitExpression(Operation *op) { assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - Operation *rootOp = expressionOp.getRootOp(); + assert(emittedExpression.empty() && + "Expected sub-expressions set to be empty"); + + Operation *rootOp = nullptr; + + if (auto expressionOp = dyn_cast(op)) { + rootOp = expressionOp.getRootOp(); + } else { + auto expressionInterface = cast(op); + assert(expressionInterface.alwaysInline() && + "Expected an always-inline operation"); + assert(!isa(op->getParentOp()) && + "Expected operation to have no containing expression"); + rootOp = op; + } + + buildExpression(op); - emittedExpression = expressionOp; FailureOr precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1509,14 +1587,13 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; + emittedExpression.clear(); return success(); } -LogicalResult CppEmitter::emitOperand(Value value) { - if (isPartOfCurrentExpression(value)) { - Operation *def = value.getDefiningOp(); +LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { + if (Operation *def = isPartOfCurrentExpression(value)) { assert(def && "Expected operand to be defined by an operation"); FailureOr precedence = getOperatorPrecedence(def); if (failed(precedence)) @@ -1525,7 +1602,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { // Sub-expressions with equal or lower precedence need to be parenthesized, // as they might be evaluated in the wrong order depending on the shape of // the expression tree. - bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); + bool encloseInParenthesis = + !isInBrackets && precedence.value() <= getExpressionPrecedence(); if (encloseInParenthesis) os << "("; pushExpressionPrecedence(precedence.value()); @@ -1540,22 +1618,23 @@ LogicalResult CppEmitter::emitOperand(Value value) { return success(); } - auto expressionOp = value.getDefiningOp(); - if (expressionOp && shouldBeInlined(expressionOp)) - return emitExpression(expressionOp); + if (Operation *defOp = value.getDefiningOp()) { + auto expressionOp = dyn_cast(defOp); + auto expressionInterface = dyn_cast(defOp); + + if (expressionInterface) { + if (expressionInterface.alwaysInline()) + return emitExpression(defOp); + } else if (expressionOp && shouldBeInlined(expressionOp)) + return emitExpression(expressionOp); + } if (BlockArgument arg = dyn_cast(value)) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value); @@ -1564,14 +1643,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { LogicalResult CppEmitter::emitOperands(Operation &op) { return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { - // If an expression is being emitted, push lowest precedence as these - // operands are either wrapped by parenthesis. - if (getEmittedExpression()) - pushExpressionPrecedence(lowestPrecedence()); - if (failed(emitOperand(operand))) + if (failed(emitOperand(operand, /*isInBrackets=*/true))) return failure(); - if (getEmittedExpression()) - popExpressionPrecedence(); return success(); }); } @@ -1613,7 +1686,9 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, bool trailingSemicolon) { - if (hasDeferredEmission(result.getDefiningOp())) + auto expressionInterface = + dyn_cast(result.getDefiningOp()); + if (expressionInterface && expressionInterface.alwaysInline()) return success(); if (hasValueInScope(result)) { return result.getDefiningOp()->emitError( @@ -1654,7 +1729,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { // If op is being emitted as part of an expression, bail out. - if (getEmittedExpression()) + if (isEmittingExpression()) return success(); switch (op.getNumResults()) { @@ -1713,9 +1788,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp, - emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, + emitc::GetFieldOp, emitc::GetGlobalOp, emitc::GlobalOp, + emitc::IfOp, emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp, emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, - emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp, emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( @@ -1723,30 +1800,6 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // Func ops. .Case( [&](auto op) { return printOperation(*this, op); }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getFieldName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getValue()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); - return success(); - }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -1754,10 +1807,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); - if (hasDeferredEmission(&op)) + auto expressionInterface = dyn_cast(op); + if (expressionInterface && expressionInterface.alwaysInline()) return success(); - if (getEmittedExpression() || + if (isEmittingExpression() || (isa(op) && shouldBeInlined(cast(op)))) return success(); diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir index 7b6723989e260..291f8e24693b7 100644 --- a/mlir/test/Dialect/EmitC/form-expressions.mlir +++ b/mlir/test/Dialect/EmitC/form-expressions.mlir @@ -131,17 +131,14 @@ func.func @single_result_requirement() -> (i32, i32) { // CHECK-LABEL: func.func @expression_with_load( // CHECK-SAME: %[[VAL_0:.*]]: i32, // CHECK-SAME: %[[VAL_1:.*]]: !emitc.ptr) -> i1 { -// CHECK: %[[VAL_2:.*]] = emitc.expression : () -> i64 { -// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 -// CHECK: yield %[[VAL_C]] : i64 -// CHECK: } // CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue // CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_3]] : (!emitc.lvalue) -> i32 { // CHECK: %[[VAL_5:.*]] = load %[[VAL_3]] : // CHECK: yield %[[VAL_5]] : i32 // CHECK: } -// CHECK: %[[VAL_6:.*]] = emitc.subscript %[[VAL_1]]{{\[}}%[[VAL_2]]] : (!emitc.ptr, i64) -> !emitc.lvalue -// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_6]] : (!emitc.lvalue) -> i32 { +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_1]] : (!emitc.ptr) -> i32 { +// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 +// CHECK: %[[VAL_6:.*]] = subscript %[[VAL_1]]{{\[}}%[[VAL_C]]] : (!emitc.ptr, i64) -> !emitc.lvalue // CHECK: %[[VAL_8:.*]] = load %[[VAL_6]] : // CHECK: yield %[[VAL_8]] : i32 // CHECK: } @@ -208,3 +205,157 @@ func.func @expression_with_constant(%arg0: i32) -> i32 { %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 return %a : i32 } + +// CHECK-LABEL: func.func @expression_with_subscript( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.array<4x8xi32>, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: i32) -> i32 { +// CHECK: %[[VAL_0:.*]] = emitc.expression %[[ARG0]], %[[ARG2]], %[[ARG1]] : (!emitc.array<4x8xi32>, i32, i32) -> i32 { +// CHECK: %[[VAL_1:.*]] = add %[[ARG1]], %[[ARG2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_2:.*]] = mul %[[VAL_1]], %[[ARG2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_3:.*]] = subscript %[[ARG0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue +// CHECK: %[[VAL_4:.*]] = load %[[VAL_3]] : +// CHECK: yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: return %[[VAL_0]] : i32 +// CHECK: } + +func.func @expression_with_subscript(%arg0: !emitc.array<4x8xi32>, %arg1: i32, %arg2: i32) -> i32 { + %0 = emitc.add %arg1, %arg2 : (i32, i32) -> i32 + %1 = emitc.mul %0, %arg2 : (i32, i32) -> i32 + %2 = emitc.subscript %arg0[%0, %1] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue + %3 = emitc.load %2 : !emitc.lvalue + return %3 : i32 +} + +// CHECK-LABEL: func.func @member( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.opaque<"mystruct">, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> +// CHECK: emitc.assign %[[ARG0]] : !emitc.opaque<"mystruct"> to %[[VAL_0]] : > +// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>) -> !emitc.lvalue { +// CHECK: %[[VAL_2:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "a"}> : (!emitc.lvalue>) -> !emitc.lvalue +// CHECK: yield %[[VAL_2]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_1]] : +// CHECK: %[[VAL_3:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>) -> i32 { +// CHECK: %[[VAL_4:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "b"}> : (!emitc.lvalue>) -> !emitc.lvalue +// CHECK: %[[VAL_5:.*]] = load %[[VAL_4]] : +// CHECK: yield %[[VAL_5]] : i32 +// CHECK: } +// CHECK: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>) -> i32 { +// CHECK: %[[VAL_8:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "c"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_9:.*]] = subscript %[[VAL_8]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: %[[VAL_10:.*]] = load %[[VAL_9]] : +// CHECK: yield %[[VAL_10]] : i32 +// CHECK: } +// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_11:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>) -> !emitc.lvalue { +// CHECK: %[[VAL_12:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "d"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_13:.*]] = subscript %[[VAL_12]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: yield %[[VAL_13]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_11]] : +// CHECK: return +// CHECK: } + +func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { + %var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> + emitc.assign %arg0 : !emitc.opaque<"mystruct"> to %var0 : !emitc.lvalue> + + %0 = "emitc.member" (%var0) {member = "a"} : (!emitc.lvalue>) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %0 : !emitc.lvalue + + %1 = "emitc.member" (%var0) {member = "b"} : (!emitc.lvalue>) -> !emitc.lvalue + %2 = emitc.load %1 : !emitc.lvalue + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : !emitc.lvalue + + %4 = "emitc.member" (%var0) {member = "c"} : (!emitc.lvalue>) -> !emitc.array<2xi32> + %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %6 = emitc.load %5 : + emitc.assign %6 : i32 to %3 : !emitc.lvalue + + %7 = "emitc.member" (%var0) {member = "d"} : (!emitc.lvalue>) -> !emitc.array<2xi32> + %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %8 : !emitc.lvalue + + return +} + +// CHECK-LABEL: func.func @member_of_pointer( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.ptr>, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> +// CHECK: emitc.assign %[[ARG0]] : !emitc.ptr> to %[[VAL_0]] : >> +// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>>) -> !emitc.lvalue { +// CHECK: %[[VAL_2:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "a"}> : (!emitc.lvalue>>) -> !emitc.lvalue +// CHECK: yield %[[VAL_2]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_1]] : +// CHECK: %[[VAL_3:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>>) -> i32 { +// CHECK: %[[VAL_4:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "b"}> : (!emitc.lvalue>>) -> !emitc.lvalue +// CHECK: %[[VAL_5:.*]] = load %[[VAL_4]] : +// CHECK: yield %[[VAL_5]] : i32 +// CHECK: } +// CHECK: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>>) -> i32 { +// CHECK: %[[VAL_8:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "c"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_9:.*]] = subscript %[[VAL_8]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: %[[VAL_10:.*]] = load %[[VAL_9]] : +// CHECK: yield %[[VAL_10]] : i32 +// CHECK: } +// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_11:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>>) -> !emitc.lvalue { +// CHECK: %[[VAL_12:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "d"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_13:.*]] = subscript %[[VAL_12]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: yield %[[VAL_13]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_11]] : +// CHECK: return +// CHECK: } + +func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { + %var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> + emitc.assign %arg0 : !emitc.ptr> to %var0 : !emitc.lvalue>> + + %0 = "emitc.member_of_ptr" (%var0) {member = "a"} : (!emitc.lvalue>>) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %0 : !emitc.lvalue + + %1 = "emitc.member_of_ptr" (%var0) {member = "b"} : (!emitc.lvalue>>) -> !emitc.lvalue + %2 = emitc.load %1 : !emitc.lvalue + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : !emitc.lvalue + + %4 = "emitc.member_of_ptr" (%var0) {member = "c"} : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %6 = emitc.load %5 : + emitc.assign %6 : i32 to %3 : !emitc.lvalue + + %7 = "emitc.member_of_ptr" (%var0) {member = "d"} : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %8 : !emitc.lvalue + + return +} + +// CHECK-LABEL: func.func @expression_with_literal( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> f32 { +// CHECK: %[[VAL_0:.*]] = emitc.expression %[[ARG0]] : (f32) -> f32 { +// CHECK: %[[VAL_1:.*]] = literal "M_PI" : f32 +// CHECK: %[[VAL_2:.*]] = add %[[ARG0]], %[[VAL_1]] : (f32, f32) -> f32 +// CHECK: yield %[[VAL_2]] : f32 +// CHECK: } +// CHECK: return %[[VAL_0]] : f32 +// CHECK: } + +func.func @expression_with_literal(%arg0: f32) -> f32 { + %p0 = emitc.literal "M_PI" : f32 + %1 = "emitc.add" (%arg0, %p0) : (f32, f32) -> f32 + return %1 : f32 +} diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 4281f41d0b3fb..e19945085a87b 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -355,6 +355,26 @@ func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.p return %c : i1 } +// CPP-DEFAULT: int32_t expression_with_subscript(int32_t [[VAL_1:v.+]][4][8], int32_t [[VAL_2:v.+]], int32_t [[VAL_3:v.+]]) +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v.+]] = [[VAL_1]][[[VAL_2]] + [[VAL_3]]][([[VAL_2]] + [[VAL_3]]) * [[VAL_3]]]; +// CPP-DEFAULT-NEXT: return [[VAL_4]]; + +// CPP-DECLTOP: int32_t expression_with_subscript(int32_t [[VAL_1:v.+]][4][8], int32_t [[VAL_2:v.+]], int32_t [[VAL_3:v.+]]) +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v.+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]][[[VAL_2]] + [[VAL_3]]][([[VAL_2]] + [[VAL_3]]) * [[VAL_3]]]; +// CPP-DECLTOP-NEXT: return [[VAL_4]]; + +func.func @expression_with_subscript(%arg0: !emitc.array<4x8xi32>, %arg1: i32, %arg2: i32) -> i32 { + %res = emitc.expression %arg0, %arg1, %arg2 : (!emitc.array<4x8xi32>, i32, i32) -> i32 { + %0 = add %arg1, %arg2 : (i32, i32) -> i32 + %1 = mul %0, %arg2 : (i32, i32) -> i32 + %2 = subscript %arg0[%0, %1] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue + %3 = emitc.load %2 : !emitc.lvalue + yield %3 : i32 + } + return %res : i32 +} + // CPP-DEFAULT: int32_t expression_with_subscript_user(void* [[VAL_1:v.+]]) // CPP-DEFAULT-NEXT: int64_t [[VAL_2:v.+]] = 0; // CPP-DEFAULT-NEXT: int32_t* [[VAL_3:v.+]] = (int32_t*) [[VAL_1]]; @@ -458,3 +478,142 @@ emitc.func @expression_with_call_opaque_with_args_array(%0 : i32, %1 : i32) { } return } + +// CPP-DEFAULT: void member(mystruct [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: mystruct [[VAL_4:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_4]].a = [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]].b; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]] = ([[VAL_4]].c)[[[VAL_3]]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DEFAULT-NEXT: ([[VAL_4]].d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void member(mystruct [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: mystruct [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_4]].a = [[VAL_2]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_4]].b; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: [[VAL_7]] = ([[VAL_4]].c)[[[VAL_3]]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DECLTOP-NEXT: ([[VAL_4]].d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> + emitc.assign %arg0 : !emitc.opaque<"mystruct"> to %0 : > + %1 = emitc.expression %0 : (!emitc.lvalue>) -> !emitc.lvalue { + %6 = "emitc.member"(%0) <{member = "a"}> : (!emitc.lvalue>) -> !emitc.lvalue + yield %6 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %1 : + %2 = emitc.expression %0 : (!emitc.lvalue>) -> i32 { + %6 = "emitc.member"(%0) <{member = "b"}> : (!emitc.lvalue>) -> !emitc.lvalue + %7 = load %6 : + yield %7 : i32 + } + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : + %4 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>) -> i32 { + %6 = "emitc.member"(%0) <{member = "c"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %8 = load %7 : + yield %8 : i32 + } + emitc.assign %4 : i32 to %3 : + %5 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>) -> !emitc.lvalue { + %6 = "emitc.member"(%0) <{member = "d"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + yield %7 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %5 : + return +} + +// CPP-DEFAULT: void member_of_pointer(mystruct* [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: mystruct* [[VAL_4:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_4]]->a = [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]]->b; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]] = ([[VAL_4]]->c)[[[VAL_3]]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DEFAULT-NEXT: ([[VAL_4]]->d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void member_of_pointer(mystruct* [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: mystruct* [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_4]]->a = [[VAL_2]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_4]]->b; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: [[VAL_7]] = ([[VAL_4]]->c)[[[VAL_3]]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DECLTOP-NEXT: ([[VAL_4]]->d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> + emitc.assign %arg0 : !emitc.ptr> to %0 : >> + %1 = emitc.expression %0 : (!emitc.lvalue>>) -> !emitc.lvalue { + %6 = "emitc.member_of_ptr"(%0) <{member = "a"}> : (!emitc.lvalue>>) -> !emitc.lvalue + yield %6 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %1 : + %2 = emitc.expression %0 : (!emitc.lvalue>>) -> i32 { + %6 = "emitc.member_of_ptr"(%0) <{member = "b"}> : (!emitc.lvalue>>) -> !emitc.lvalue + %7 = load %6 : + yield %7 : i32 + } + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : + %4 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>>) -> i32 { + %6 = "emitc.member_of_ptr"(%0) <{member = "c"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %8 = load %7 : + yield %8 : i32 + } + emitc.assign %4 : i32 to %3 : + %5 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>>) -> !emitc.lvalue { + %6 = "emitc.member_of_ptr"(%0) <{member = "d"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + yield %7 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %5 : + return +} + +// CPP-DEFAULT: float expression_with_literal(float [[VAL_1:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return [[VAL_1]] + M_PI; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: float expression_with_literal(float [[VAL_1:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return [[VAL_1]] + M_PI; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_literal(%arg0: f32) -> f32 { + %0 = emitc.expression %arg0 : (f32) -> f32 { + %1 = literal "M_PI" : f32 + %2 = add %arg0, %1 : (f32, f32) -> f32 + yield %2 : f32 + } + return %0 : f32 +} diff --git a/mlir/test/Target/Cpp/member.mlir b/mlir/test/Target/Cpp/member.mlir index 6e0395250afbd..45b6336f63ce0 100644 --- a/mlir/test/Target/Cpp/member.mlir +++ b/mlir/test/Target/Cpp/member.mlir @@ -31,9 +31,9 @@ func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { // CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]].b; // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V3]]; -// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]].c[[[Index]]]; +// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = ([[V2]].c)[[[Index]]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V5]]; -// CPP-DEFAULT-NEXT: [[V2]].d[[[Index]]] = [[V1]]; +// CPP-DEFAULT-NEXT: ([[V2]].d)[[[Index]]] = [[V1]]; func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { @@ -67,6 +67,6 @@ func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1 // CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]]->b; // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V3]]; -// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]]->c[[[Index]]]; +// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = ([[V2]]->c)[[[Index]]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V5]]; -// CPP-DEFAULT-NEXT: [[V2]]->d[[[Index]]] = [[V1]]; +// CPP-DEFAULT-NEXT: ([[V2]]->d)[[[Index]]] = [[V1]];