diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 721f9f6b320ad..ebd4850161894 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -507,7 +507,7 @@ def EmitC_ExpressionOp let arguments = (ins Variadic>:$defs, UnitAttr:$do_not_inline); - let results = (outs EmitCType:$result); + let results = (outs AnyTypeOf<[EmitCType, EmitC_LValueType]>:$result); let regions = (region SizedRegion<1>:$region); let hasVerifier = 1; @@ -873,7 +873,7 @@ def EmitC_IncludeOp let hasCustomAssemblyFormat = 1; } -def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { +def EmitC_LiteralOp : EmitC_Op<"literal", [Pure, CExpressionInterface]> { let summary = "Literal operation"; let description = [{ The `emitc.literal` operation produces an SSA value equal to some constant @@ -896,6 +896,15 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> { let hasVerifier = 1; let assemblyFormat = "$value attr-dict `:` type($result)"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> { @@ -1062,7 +1071,7 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", []> { let hasVerifier = 1; } -def EmitC_MemberOp : EmitC_Op<"member"> { +def EmitC_MemberOp : EmitC_Op<"member", [CExpressionInterface]> { let summary = "Member operation"; let description = [{ With the `emitc.member` operation the member access operator `.` can be @@ -1083,9 +1092,18 @@ def EmitC_MemberOp : EmitC_Op<"member"> { EmitC_LValueOf<[EmitC_OpaqueType]>:$operand ); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>); + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } -def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> { +def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr", [CExpressionInterface]> { let summary = "Member of pointer operation"; let description = [{ With the `emitc.member_of_ptr` operation the member access operator `->` @@ -1108,6 +1126,15 @@ def EmitC_MemberOfPtrOp : EmitC_Op<"member_of_ptr"> { EmitC_LValueOf<[EmitC_OpaqueType,EmitC_PointerType]>:$operand ); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>); + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_ConditionalOp : EmitC_Op<"conditional", @@ -1277,8 +1304,10 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { let hasVerifier = 1; } -def EmitC_GetGlobalOp : EmitC_Op<"get_global", - [Pure, DeclareOpInterfaceMethods]> { +def EmitC_GetGlobalOp + : EmitC_Op<"get_global", [Pure, + DeclareOpInterfaceMethods, + CExpressionInterface]> { let summary = "Obtain access to a global variable"; let description = [{ The `emitc.get_global` operation retrieves the lvalue of a @@ -1296,6 +1325,15 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global", let arguments = (ins FlatSymbolRefAttr:$name); let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { @@ -1406,7 +1444,8 @@ def EmitC_YieldOp : EmitC_Op<"yield", value is yielded. }]; - let arguments = (ins Optional:$result); + let arguments = + (ins Optional>:$result); let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let hasVerifier = 1; @@ -1477,7 +1516,7 @@ def EmitC_IfOp : EmitC_Op<"if", let hasCustomAssemblyFormat = 1; } -def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { +def EmitC_SubscriptOp : EmitC_Op<"subscript", [CExpressionInterface]> { let summary = "Subscript operation"; let description = [{ With the `emitc.subscript` operation the subscript operator `[]` can be applied @@ -1525,6 +1564,15 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> { let hasVerifier = 1; let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects, @@ -1707,8 +1755,9 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { } def EmitC_GetFieldOp - : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods< - SymbolUserOpInterface>]> { + : EmitC_Op<"get_field", [Pure, + DeclareOpInterfaceMethods, + CExpressionInterface]> { let summary = "Obtain access to a field within a class instance"; let description = [{ The `emitc.get_field` operation retrieves the lvalue of a @@ -1725,6 +1774,15 @@ def EmitC_GetFieldOp let results = (outs EmitCType:$result); let assemblyFormat = "$field_name `:` type($result) attr-dict"; let hasVerifier = 1; + + let extraClassDeclaration = [{ + bool hasSideEffects() { + return false; + } + bool alwaysInline() { + return true; // C doesn't support variable references. + } + }]; } #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td index 777784e56202a..c11e017e40d0f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCInterfaces.td @@ -21,8 +21,8 @@ def CExpressionInterface : OpInterface<"CExpressionInterface"> { }]; let cppNamespace = "::mlir::emitc"; - let methods = [ - InterfaceMethod<[{ + let methods = + [InterfaceMethod<[{ Check whether operation has side effects that may affect the expression evaluation. @@ -38,9 +38,28 @@ def CExpressionInterface : OpInterface<"CExpressionInterface"> { }; ``` }], - "bool", "hasSideEffects", (ins), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ + "bool", "hasSideEffects", (ins), /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return true; + }]>, + InterfaceMethod<[{ + Check whether operation must be inlined into all its users. + + By default operation is not marked as always inlined. + + ```c++ + class ConcreteOp ... { + public: + bool alwaysInline() { + // That way we can override the default implementation. + return true; + } + }; + ``` + }], + "bool", "alwaysInline", (ins), /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return false; }]>, ]; } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5c8564bca6f86..4468ac686aa8b 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -483,7 +483,8 @@ LogicalResult ExpressionOp::verify() { Operation *op = worklist.back(); worklist.pop_back(); if (visited.contains(op)) { - if (cast(op).hasSideEffects()) + auto cExpr = cast(op); + if (!cExpr.alwaysInline() && cExpr.hasSideEffects()) return emitOpError( "requires exactly one use for operations with side effects"); } @@ -494,6 +495,12 @@ LogicalResult ExpressionOp::verify() { } } + if (getDoNotInline() && + cast(rootOp).alwaysInline()) { + return emitOpError("root operation must be inlined but expression is marked" + " do-not-inline"); + } + return success(); } @@ -980,6 +987,10 @@ LogicalResult emitc::YieldOp::verify() { if (!result && containingOp->getNumResults() != 0) return emitOpError() << "does not yield a value to be returned by parent"; + if (result && isa(result.getType()) && + !isa(containingOp)) + return emitOpError() << "yielding lvalues is not supported for this op"; + return success(); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a5bd80e9d6b8b..7541845ff6242 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -99,13 +100,18 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Case([&](auto op) { return 2; }) .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 13; }) + .Case([&](auto op) { return 18; }) + .Case([&](auto op) { return 18; }) .Case([&](auto op) { return 16; }) .Case([&](auto op) { return 4; }) .Case([&](auto op) { return 15; }) .Case([&](auto op) { return 3; }) + .Case([&](auto op) { return 17; }) + .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 12; }) + .Case([&](auto op) { return 17; }) .Case([&](auto op) { return 15; }) .Case([&](auto op) { return 15; }) .Default([](auto op) { return op->emitError("unsupported operation"); }); @@ -173,31 +179,27 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); - /// Emits value as an operands of an operation - LogicalResult emitOperand(Value value); + /// Emits value as an operand of an operation. If \p isInBrackets is true, + /// this operand is already being emitted between some kind of brackets, so + /// there is no need to wrap it in parentheses for correct precedence. + LogicalResult emitOperand(Value value, bool isInBrackets = false); - /// Emit an expression as a C expression. - LogicalResult emitExpression(ExpressionOp expressionOp); + /// Collect all operations to emit as an expression starting at \p op, + /// recursively adding operands that shouldBeInlined. + void buildExpression(Operation *op); - /// Insert the expression representing the operation into the value cache. - void cacheDeferredOpResult(Value value, StringRef str); + /// Emit an expression as a C expression. + LogicalResult emitExpression(Operation *op); /// Return the existing or a new name for a Value. StringRef getOrCreateName(Value val); + void setName(Value val, StringRef name); + /// Return the existing or a new name for a loop induction variable of an /// emitc::ForOp. StringRef getOrCreateInductionVarName(Value val); - // Returns the textual representation of a subscript operation. - std::string getSubscriptName(emitc::SubscriptOp op); - - // Returns the textual representation of a member (of object) operation. - std::string createMemberAccess(emitc::MemberOp op); - - // Returns the textual representation of a member of pointer operation. - std::string createMemberAccess(emitc::MemberOfPtrOp op); - /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); @@ -257,25 +259,24 @@ struct CppEmitter { return !fileId.empty() && file.getId() == fileId; } - /// Get expression currently being emitted. - ExpressionOp getEmittedExpression() { return emittedExpression; } + /// Is expression currently being emitted. + bool isEmittingExpression() { return !emittedExpression.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. - bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; + Operation *isPartOfCurrentExpression(Value value) { Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + if (def) + return isPartOfCurrentExpression(def) ? def : nullptr; + return nullptr; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + if (auto parentExpression = dyn_cast(def->getParentOp())) + def = parentExpression; + return emittedExpression.contains(def); }; // Resets the value counter to 0. @@ -322,7 +323,7 @@ struct CppEmitter { unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; + SmallPtrSet emittedExpression; SmallVector emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -338,19 +339,35 @@ struct CppEmitter { }; } // namespace -/// Determine whether expression \p op should be emitted in a deferred way. -static bool hasDeferredEmission(Operation *op) { - return isa_and_nonnull(op); -} - -/// Determine whether expression \p expressionOp should be emitted inline, i.e. -/// as part of its user. This function recommends inlining of any expressions -/// that can be inlined unless it is used by another expression, under the -/// assumption that any expression fusion/re-materialization was taken care of +/// Determine whether operation \p operation should be emitted inline, i.e. +/// as part of its user. +/// The operation can force inlining using its +/// CExpressionInterface::alwaysInline() method when it's not included in any +/// ExpressionOp or is some ExpressionOp's root operation. +/// Otherwise, for any ExpressionOp that can be inlined, this function +/// recommends inlining unless it is used by another expression, under the +/// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *operation) { + auto expressionInterface = dyn_cast(operation); + + // Inline if this is an always-inline CExpression not part of any + // expression. + if (expressionInterface && expressionInterface.alwaysInline()) { + assert(!isa(operation->getParentOp()) && + "Unexpectedly called on operation included in expression"); + return true; + } + + ExpressionOp expressionOp = dyn_cast(operation); + + if (!expressionOp) + return false; + + // Inline if the root operation is an always-inline CExpression. + if (cast(expressionOp.getRootOp()).alwaysInline()) + return true; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -367,17 +384,83 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { Operation *user = *result.getUsers().begin(); - // Do not inline expressions used by operations with deferred emission, since - // their translation requires the materialization of variables. - if (hasDeferredEmission(user)) - return false; - // Do not inline expressions used by other expressions or by ops with the // CExpressionInterface. If this was intended, the user could have been merged // into the expression op. return !isa(*user); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetFieldOp getFieldOp) { + if (!emitter.isPartOfCurrentExpression(getFieldOp.getOperation())) { + emitter.setName(getFieldOp.getResult(), getFieldOp.getFieldName()); + return success(); + } + + emitter.ostream() << getFieldOp.getFieldName(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::GetGlobalOp getGlobalOp) { + if (!emitter.isPartOfCurrentExpression(getGlobalOp.getOperation())) { + emitter.setName(getGlobalOp.getResult(), getGlobalOp.getName()); + return success(); + } + + emitter.ostream() << getGlobalOp.getName(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::LiteralOp literalOp) { + if (!emitter.isPartOfCurrentExpression(literalOp.getOperation())) + return success(); + + emitter.ostream() << literalOp.getValue(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOp memberOp) { + if (!emitter.isPartOfCurrentExpression(memberOp.getOperation())) + return success(); + + if (failed(emitter.emitOperand(memberOp.getOperand()))) + return failure(); + emitter.ostream() << "." << memberOp.getMember(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::MemberOfPtrOp memberOfPtrOp) { + if (!emitter.isPartOfCurrentExpression(memberOfPtrOp.getOperation())) + return success(); + + if (failed(emitter.emitOperand(memberOfPtrOp.getOperand()))) + return failure(); + emitter.ostream() << "->" << memberOfPtrOp.getMember(); + return success(); +} + +static LogicalResult printOperation(CppEmitter &emitter, + emitc::SubscriptOp subscriptOp) { + if (!emitter.isPartOfCurrentExpression(subscriptOp.getOperation())) { + return success(); + } + + raw_ostream &os = emitter.ostream(); + if (failed(emitter.emitOperand(subscriptOp.getValue()))) + return failure(); + for (auto index : subscriptOp.getIndices()) { + os << "["; + if (failed(emitter.emitOperand(index, /*isInBrackets=*/true))) + return failure(); + os << "]"; + } + return success(); +} + static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, Attribute value) { OpResult result = operation->getResult(0); @@ -437,11 +520,11 @@ static LogicalResult printOperation(CppEmitter &emitter, static LogicalResult printOperation(CppEmitter &emitter, emitc::AssignOp assignOp) { - OpResult result = assignOp.getVar().getDefiningOp()->getResult(0); - - if (failed(emitter.emitVariableAssignment(result))) + if (failed(emitter.emitOperand(assignOp.getVar()))) return failure(); + emitter.ostream() << " = "; + return emitter.emitOperand(assignOp.getValue()); } @@ -1288,44 +1371,14 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, labelInScopeCount.push(0); } -std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getValue()); - for (auto index : op.getIndices()) { - ss << "[" << getOrCreateName(index) << "]"; - } - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "." << op.getMember(); - return out; -} - -std::string CppEmitter::createMemberAccess(emitc::MemberOfPtrOp op) { - std::string out; - llvm::raw_string_ostream ss(out); - ss << getOrCreateName(op.getOperand()); - ss << "->" << op.getMember(); - return out; -} - -void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) { - if (!valueMapper.count(value)) - valueMapper.insert(value, str.str()); +void CppEmitter::setName(Value val, StringRef name) { + assert(!valueMapper.count(val) && "Expected value not to be in mapper"); + valueMapper.insert(val, name.str()); } /// Return the existing or a new name for a Value. StringRef CppEmitter::getOrCreateName(Value val) { if (!valueMapper.count(val)) { - assert(!hasDeferredEmission(val.getDefiningOp()) && - "cacheDeferredOpResult should have been called on this value, " - "update the emitOperation function."); - valueMapper.insert(val, formatv("v{0}", ++valueCount)); } return *valueMapper.begin(val); @@ -1492,12 +1545,37 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return emitError(loc, "cannot emit attribute: ") << attr; } -LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { +void CppEmitter::buildExpression(Operation *op) { + emittedExpression.insert(op); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && shouldBeInlined(defOp)) + buildExpression(defOp); + } +} + +LogicalResult CppEmitter::emitExpression(Operation *op) { assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - Operation *rootOp = expressionOp.getRootOp(); + assert(emittedExpression.empty() && + "Expected sub-expressions set to be empty"); + + Operation *rootOp = nullptr; + + if (auto expressionOp = dyn_cast(op)) { + rootOp = expressionOp.getRootOp(); + } else { + auto expressionInterface = cast(op); + assert(expressionInterface.alwaysInline() && + "Expected an always-inline operation"); + assert(!isa(op->getParentOp()) && + "Expected operation to have no containing expression"); + rootOp = op; + } + + buildExpression(op); - emittedExpression = expressionOp; FailureOr precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1509,14 +1587,13 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; + emittedExpression.clear(); return success(); } -LogicalResult CppEmitter::emitOperand(Value value) { - if (isPartOfCurrentExpression(value)) { - Operation *def = value.getDefiningOp(); +LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { + if (Operation *def = isPartOfCurrentExpression(value)) { assert(def && "Expected operand to be defined by an operation"); FailureOr precedence = getOperatorPrecedence(def); if (failed(precedence)) @@ -1525,7 +1602,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { // Sub-expressions with equal or lower precedence need to be parenthesized, // as they might be evaluated in the wrong order depending on the shape of // the expression tree. - bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); + bool encloseInParenthesis = + !isInBrackets && precedence.value() <= getExpressionPrecedence(); if (encloseInParenthesis) os << "("; pushExpressionPrecedence(precedence.value()); @@ -1540,22 +1618,23 @@ LogicalResult CppEmitter::emitOperand(Value value) { return success(); } - auto expressionOp = value.getDefiningOp(); - if (expressionOp && shouldBeInlined(expressionOp)) - return emitExpression(expressionOp); + if (Operation *defOp = value.getDefiningOp()) { + auto expressionOp = dyn_cast(defOp); + auto expressionInterface = dyn_cast(defOp); + + if (expressionInterface) { + if (expressionInterface.alwaysInline()) + return emitExpression(defOp); + } else if (expressionOp && shouldBeInlined(expressionOp)) + return emitExpression(expressionOp); + } if (BlockArgument arg = dyn_cast(value)) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value); @@ -1564,14 +1643,8 @@ LogicalResult CppEmitter::emitOperand(Value value) { LogicalResult CppEmitter::emitOperands(Operation &op) { return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { - // If an expression is being emitted, push lowest precedence as these - // operands are either wrapped by parenthesis. - if (getEmittedExpression()) - pushExpressionPrecedence(lowestPrecedence()); - if (failed(emitOperand(operand))) + if (failed(emitOperand(operand, /*isInBrackets=*/true))) return failure(); - if (getEmittedExpression()) - popExpressionPrecedence(); return success(); }); } @@ -1613,7 +1686,9 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, bool trailingSemicolon) { - if (hasDeferredEmission(result.getDefiningOp())) + auto expressionInterface = + dyn_cast(result.getDefiningOp()); + if (expressionInterface && expressionInterface.alwaysInline()) return success(); if (hasValueInScope(result)) { return result.getDefiningOp()->emitError( @@ -1654,7 +1729,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { // If op is being emitted as part of an expression, bail out. - if (getEmittedExpression()) + if (isEmittingExpression()) return success(); switch (op.getNumResults()) { @@ -1713,9 +1788,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp, - emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, + emitc::GetFieldOp, emitc::GetGlobalOp, emitc::GlobalOp, + emitc::IfOp, emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp, emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, - emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, + emitc::MemberOfPtrOp, emitc::MemberOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubscriptOp, emitc::SubOp, emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( @@ -1723,30 +1800,6 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // Func ops. .Case( [&](auto op) { return printOperation(*this, op); }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getFieldName()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), op.getValue()); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), createMemberAccess(op)); - return success(); - }) - .Case([&](auto op) { - cacheDeferredOpResult(op.getResult(), getSubscriptName(op)); - return success(); - }) .Default([&](Operation *) { return op.emitOpError("unable to find printer for op"); }); @@ -1754,10 +1807,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (failed(status)) return failure(); - if (hasDeferredEmission(&op)) + auto expressionInterface = dyn_cast(op); + if (expressionInterface && expressionInterface.alwaysInline()) return success(); - if (getEmittedExpression() || + if (isEmittingExpression() || (isa(op) && shouldBeInlined(cast(op)))) return success(); diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir index 7b6723989e260..291f8e24693b7 100644 --- a/mlir/test/Dialect/EmitC/form-expressions.mlir +++ b/mlir/test/Dialect/EmitC/form-expressions.mlir @@ -131,17 +131,14 @@ func.func @single_result_requirement() -> (i32, i32) { // CHECK-LABEL: func.func @expression_with_load( // CHECK-SAME: %[[VAL_0:.*]]: i32, // CHECK-SAME: %[[VAL_1:.*]]: !emitc.ptr) -> i1 { -// CHECK: %[[VAL_2:.*]] = emitc.expression : () -> i64 { -// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 -// CHECK: yield %[[VAL_C]] : i64 -// CHECK: } // CHECK: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"42">}> : () -> !emitc.lvalue // CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_3]] : (!emitc.lvalue) -> i32 { // CHECK: %[[VAL_5:.*]] = load %[[VAL_3]] : // CHECK: yield %[[VAL_5]] : i32 // CHECK: } -// CHECK: %[[VAL_6:.*]] = emitc.subscript %[[VAL_1]]{{\[}}%[[VAL_2]]] : (!emitc.ptr, i64) -> !emitc.lvalue -// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_6]] : (!emitc.lvalue) -> i32 { +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[VAL_1]] : (!emitc.ptr) -> i32 { +// CHECK: %[[VAL_C:.*]] = "emitc.constant"() <{value = 0 : i64}> : () -> i64 +// CHECK: %[[VAL_6:.*]] = subscript %[[VAL_1]]{{\[}}%[[VAL_C]]] : (!emitc.ptr, i64) -> !emitc.lvalue // CHECK: %[[VAL_8:.*]] = load %[[VAL_6]] : // CHECK: yield %[[VAL_8]] : i32 // CHECK: } @@ -208,3 +205,157 @@ func.func @expression_with_constant(%arg0: i32) -> i32 { %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32 return %a : i32 } + +// CHECK-LABEL: func.func @expression_with_subscript( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.array<4x8xi32>, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: i32) -> i32 { +// CHECK: %[[VAL_0:.*]] = emitc.expression %[[ARG0]], %[[ARG2]], %[[ARG1]] : (!emitc.array<4x8xi32>, i32, i32) -> i32 { +// CHECK: %[[VAL_1:.*]] = add %[[ARG1]], %[[ARG2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_2:.*]] = mul %[[VAL_1]], %[[ARG2]] : (i32, i32) -> i32 +// CHECK: %[[VAL_3:.*]] = subscript %[[ARG0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue +// CHECK: %[[VAL_4:.*]] = load %[[VAL_3]] : +// CHECK: yield %[[VAL_4]] : i32 +// CHECK: } +// CHECK: return %[[VAL_0]] : i32 +// CHECK: } + +func.func @expression_with_subscript(%arg0: !emitc.array<4x8xi32>, %arg1: i32, %arg2: i32) -> i32 { + %0 = emitc.add %arg1, %arg2 : (i32, i32) -> i32 + %1 = emitc.mul %0, %arg2 : (i32, i32) -> i32 + %2 = emitc.subscript %arg0[%0, %1] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue + %3 = emitc.load %2 : !emitc.lvalue + return %3 : i32 +} + +// CHECK-LABEL: func.func @member( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.opaque<"mystruct">, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> +// CHECK: emitc.assign %[[ARG0]] : !emitc.opaque<"mystruct"> to %[[VAL_0]] : > +// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>) -> !emitc.lvalue { +// CHECK: %[[VAL_2:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "a"}> : (!emitc.lvalue>) -> !emitc.lvalue +// CHECK: yield %[[VAL_2]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_1]] : +// CHECK: %[[VAL_3:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>) -> i32 { +// CHECK: %[[VAL_4:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "b"}> : (!emitc.lvalue>) -> !emitc.lvalue +// CHECK: %[[VAL_5:.*]] = load %[[VAL_4]] : +// CHECK: yield %[[VAL_5]] : i32 +// CHECK: } +// CHECK: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>) -> i32 { +// CHECK: %[[VAL_8:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "c"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_9:.*]] = subscript %[[VAL_8]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: %[[VAL_10:.*]] = load %[[VAL_9]] : +// CHECK: yield %[[VAL_10]] : i32 +// CHECK: } +// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_11:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>) -> !emitc.lvalue { +// CHECK: %[[VAL_12:.*]] = "emitc.member"(%[[VAL_0]]) <{member = "d"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_13:.*]] = subscript %[[VAL_12]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: yield %[[VAL_13]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_11]] : +// CHECK: return +// CHECK: } + +func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { + %var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> + emitc.assign %arg0 : !emitc.opaque<"mystruct"> to %var0 : !emitc.lvalue> + + %0 = "emitc.member" (%var0) {member = "a"} : (!emitc.lvalue>) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %0 : !emitc.lvalue + + %1 = "emitc.member" (%var0) {member = "b"} : (!emitc.lvalue>) -> !emitc.lvalue + %2 = emitc.load %1 : !emitc.lvalue + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : !emitc.lvalue + + %4 = "emitc.member" (%var0) {member = "c"} : (!emitc.lvalue>) -> !emitc.array<2xi32> + %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %6 = emitc.load %5 : + emitc.assign %6 : i32 to %3 : !emitc.lvalue + + %7 = "emitc.member" (%var0) {member = "d"} : (!emitc.lvalue>) -> !emitc.array<2xi32> + %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %8 : !emitc.lvalue + + return +} + +// CHECK-LABEL: func.func @member_of_pointer( +// CHECK-SAME: %[[ARG0:.*]]: !emitc.ptr>, +// CHECK-SAME: %[[ARG1:.*]]: i32, +// CHECK-SAME: %[[ARG2:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> +// CHECK: emitc.assign %[[ARG0]] : !emitc.ptr> to %[[VAL_0]] : >> +// CHECK: %[[VAL_1:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>>) -> !emitc.lvalue { +// CHECK: %[[VAL_2:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "a"}> : (!emitc.lvalue>>) -> !emitc.lvalue +// CHECK: yield %[[VAL_2]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_1]] : +// CHECK: %[[VAL_3:.*]] = emitc.expression %[[VAL_0]] : (!emitc.lvalue>>) -> i32 { +// CHECK: %[[VAL_4:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "b"}> : (!emitc.lvalue>>) -> !emitc.lvalue +// CHECK: %[[VAL_5:.*]] = load %[[VAL_4]] : +// CHECK: yield %[[VAL_5]] : i32 +// CHECK: } +// CHECK: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue +// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_7:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>>) -> i32 { +// CHECK: %[[VAL_8:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "c"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_9:.*]] = subscript %[[VAL_8]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: %[[VAL_10:.*]] = load %[[VAL_9]] : +// CHECK: yield %[[VAL_10]] : i32 +// CHECK: } +// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_6]] : +// CHECK: %[[VAL_11:.*]] = emitc.expression %[[ARG2]], %[[VAL_0]] : (index, !emitc.lvalue>>) -> !emitc.lvalue { +// CHECK: %[[VAL_12:.*]] = "emitc.member_of_ptr"(%[[VAL_0]]) <{member = "d"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> +// CHECK: %[[VAL_13:.*]] = subscript %[[VAL_12]]{{\[}}%[[ARG2]]] : (!emitc.array<2xi32>, index) -> !emitc.lvalue +// CHECK: yield %[[VAL_13]] : !emitc.lvalue +// CHECK: } +// CHECK: emitc.assign %[[ARG1]] : i32 to %[[VAL_11]] : +// CHECK: return +// CHECK: } + +func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { + %var0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> + emitc.assign %arg0 : !emitc.ptr> to %var0 : !emitc.lvalue>> + + %0 = "emitc.member_of_ptr" (%var0) {member = "a"} : (!emitc.lvalue>>) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %0 : !emitc.lvalue + + %1 = "emitc.member_of_ptr" (%var0) {member = "b"} : (!emitc.lvalue>>) -> !emitc.lvalue + %2 = emitc.load %1 : !emitc.lvalue + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : !emitc.lvalue + + %4 = "emitc.member_of_ptr" (%var0) {member = "c"} : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %5 = emitc.subscript %4[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %6 = emitc.load %5 : + emitc.assign %6 : i32 to %3 : !emitc.lvalue + + %7 = "emitc.member_of_ptr" (%var0) {member = "d"} : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %8 = emitc.subscript %7[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + emitc.assign %arg1 : i32 to %8 : !emitc.lvalue + + return +} + +// CHECK-LABEL: func.func @expression_with_literal( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> f32 { +// CHECK: %[[VAL_0:.*]] = emitc.expression %[[ARG0]] : (f32) -> f32 { +// CHECK: %[[VAL_1:.*]] = literal "M_PI" : f32 +// CHECK: %[[VAL_2:.*]] = add %[[ARG0]], %[[VAL_1]] : (f32, f32) -> f32 +// CHECK: yield %[[VAL_2]] : f32 +// CHECK: } +// CHECK: return %[[VAL_0]] : f32 +// CHECK: } + +func.func @expression_with_literal(%arg0: f32) -> f32 { + %p0 = emitc.literal "M_PI" : f32 + %1 = "emitc.add" (%arg0, %p0) : (f32, f32) -> f32 + return %1 : f32 +} diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir index 4281f41d0b3fb..e19945085a87b 100644 --- a/mlir/test/Target/Cpp/expressions.mlir +++ b/mlir/test/Target/Cpp/expressions.mlir @@ -355,6 +355,26 @@ func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.p return %c : i1 } +// CPP-DEFAULT: int32_t expression_with_subscript(int32_t [[VAL_1:v.+]][4][8], int32_t [[VAL_2:v.+]], int32_t [[VAL_3:v.+]]) +// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v.+]] = [[VAL_1]][[[VAL_2]] + [[VAL_3]]][([[VAL_2]] + [[VAL_3]]) * [[VAL_3]]]; +// CPP-DEFAULT-NEXT: return [[VAL_4]]; + +// CPP-DECLTOP: int32_t expression_with_subscript(int32_t [[VAL_1:v.+]][4][8], int32_t [[VAL_2:v.+]], int32_t [[VAL_3:v.+]]) +// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v.+]]; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]][[[VAL_2]] + [[VAL_3]]][([[VAL_2]] + [[VAL_3]]) * [[VAL_3]]]; +// CPP-DECLTOP-NEXT: return [[VAL_4]]; + +func.func @expression_with_subscript(%arg0: !emitc.array<4x8xi32>, %arg1: i32, %arg2: i32) -> i32 { + %res = emitc.expression %arg0, %arg1, %arg2 : (!emitc.array<4x8xi32>, i32, i32) -> i32 { + %0 = add %arg1, %arg2 : (i32, i32) -> i32 + %1 = mul %0, %arg2 : (i32, i32) -> i32 + %2 = subscript %arg0[%0, %1] : (!emitc.array<4x8xi32>, i32, i32) -> !emitc.lvalue + %3 = emitc.load %2 : !emitc.lvalue + yield %3 : i32 + } + return %res : i32 +} + // CPP-DEFAULT: int32_t expression_with_subscript_user(void* [[VAL_1:v.+]]) // CPP-DEFAULT-NEXT: int64_t [[VAL_2:v.+]] = 0; // CPP-DEFAULT-NEXT: int32_t* [[VAL_3:v.+]] = (int32_t*) [[VAL_1]]; @@ -458,3 +478,142 @@ emitc.func @expression_with_call_opaque_with_args_array(%0 : i32, %1 : i32) { } return } + +// CPP-DEFAULT: void member(mystruct [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: mystruct [[VAL_4:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_4]].a = [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]].b; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]] = ([[VAL_4]].c)[[[VAL_3]]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DEFAULT-NEXT: ([[VAL_4]].d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void member(mystruct [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: mystruct [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_4]].a = [[VAL_2]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_4]].b; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: [[VAL_7]] = ([[VAL_4]].c)[[[VAL_3]]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DECLTOP-NEXT: ([[VAL_4]].d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue> + emitc.assign %arg0 : !emitc.opaque<"mystruct"> to %0 : > + %1 = emitc.expression %0 : (!emitc.lvalue>) -> !emitc.lvalue { + %6 = "emitc.member"(%0) <{member = "a"}> : (!emitc.lvalue>) -> !emitc.lvalue + yield %6 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %1 : + %2 = emitc.expression %0 : (!emitc.lvalue>) -> i32 { + %6 = "emitc.member"(%0) <{member = "b"}> : (!emitc.lvalue>) -> !emitc.lvalue + %7 = load %6 : + yield %7 : i32 + } + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : + %4 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>) -> i32 { + %6 = "emitc.member"(%0) <{member = "c"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %8 = load %7 : + yield %8 : i32 + } + emitc.assign %4 : i32 to %3 : + %5 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>) -> !emitc.lvalue { + %6 = "emitc.member"(%0) <{member = "d"}> : (!emitc.lvalue>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + yield %7 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %5 : + return +} + +// CPP-DEFAULT: void member_of_pointer(mystruct* [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: mystruct* [[VAL_4:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DEFAULT-NEXT: [[VAL_4]]->a = [[VAL_2]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_4]]->b; +// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]] = ([[VAL_4]]->c)[[[VAL_3]]]; +// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DEFAULT-NEXT: ([[VAL_4]]->d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: void member_of_pointer(mystruct* [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], size_t [[VAL_3:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: mystruct* [[VAL_4:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]]; +// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]]; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]]; +// CPP-DECLTOP-NEXT: [[VAL_4]]->a = [[VAL_2]]; +// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_4]]->b; +// CPP-DECLTOP-NEXT: ; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_5]]; +// CPP-DECLTOP-NEXT: [[VAL_7]] = ([[VAL_4]]->c)[[[VAL_3]]]; +// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_7]]; +// CPP-DECLTOP-NEXT: ([[VAL_4]]->d)[[[VAL_3]]] = [[VAL_2]]; +// CPP-DECLTOP-NEXT: return; +// CPP-DECLTOP-NEXT: } + +func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { + %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue>> + emitc.assign %arg0 : !emitc.ptr> to %0 : >> + %1 = emitc.expression %0 : (!emitc.lvalue>>) -> !emitc.lvalue { + %6 = "emitc.member_of_ptr"(%0) <{member = "a"}> : (!emitc.lvalue>>) -> !emitc.lvalue + yield %6 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %1 : + %2 = emitc.expression %0 : (!emitc.lvalue>>) -> i32 { + %6 = "emitc.member_of_ptr"(%0) <{member = "b"}> : (!emitc.lvalue>>) -> !emitc.lvalue + %7 = load %6 : + yield %7 : i32 + } + %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue + emitc.assign %2 : i32 to %3 : + %4 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>>) -> i32 { + %6 = "emitc.member_of_ptr"(%0) <{member = "c"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + %8 = load %7 : + yield %8 : i32 + } + emitc.assign %4 : i32 to %3 : + %5 = emitc.expression %arg2, %0 : (index, !emitc.lvalue>>) -> !emitc.lvalue { + %6 = "emitc.member_of_ptr"(%0) <{member = "d"}> : (!emitc.lvalue>>) -> !emitc.array<2xi32> + %7 = subscript %6[%arg2] : (!emitc.array<2xi32>, index) -> !emitc.lvalue + yield %7 : !emitc.lvalue + } + emitc.assign %arg1 : i32 to %5 : + return +} + +// CPP-DEFAULT: float expression_with_literal(float [[VAL_1:v[0-9]+]]) { +// CPP-DEFAULT-NEXT: return [[VAL_1]] + M_PI; +// CPP-DEFAULT-NEXT: } + +// CPP-DECLTOP: float expression_with_literal(float [[VAL_1:v[0-9]+]]) { +// CPP-DECLTOP-NEXT: return [[VAL_1]] + M_PI; +// CPP-DECLTOP-NEXT: } + +func.func @expression_with_literal(%arg0: f32) -> f32 { + %0 = emitc.expression %arg0 : (f32) -> f32 { + %1 = literal "M_PI" : f32 + %2 = add %arg0, %1 : (f32, f32) -> f32 + yield %2 : f32 + } + return %0 : f32 +} diff --git a/mlir/test/Target/Cpp/member.mlir b/mlir/test/Target/Cpp/member.mlir index 6e0395250afbd..45b6336f63ce0 100644 --- a/mlir/test/Target/Cpp/member.mlir +++ b/mlir/test/Target/Cpp/member.mlir @@ -31,9 +31,9 @@ func.func @member(%arg0: !emitc.opaque<"mystruct">, %arg1: i32, %arg2: index) { // CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]].b; // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V3]]; -// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]].c[[[Index]]]; +// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = ([[V2]].c)[[[Index]]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V5]]; -// CPP-DEFAULT-NEXT: [[V2]].d[[[Index]]] = [[V1]]; +// CPP-DEFAULT-NEXT: ([[V2]].d)[[[Index]]] = [[V1]]; func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1: i32, %arg2: index) { @@ -67,6 +67,6 @@ func.func @member_of_pointer(%arg0: !emitc.ptr>, %arg1 // CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = [[V2]]->b; // CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V3]]; -// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = [[V2]]->c[[[Index]]]; +// CPP-DEFAULT-NEXT: int32_t [[V5:[^ ]*]] = ([[V2]]->c)[[[Index]]]; // CPP-DEFAULT-NEXT: [[V4]] = [[V5]]; -// CPP-DEFAULT-NEXT: [[V2]]->d[[[Index]]] = [[V1]]; +// CPP-DEFAULT-NEXT: ([[V2]]->d)[[[Index]]] = [[V1]];