diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index bcce0a5145221..3936087ec3293 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -111,6 +111,8 @@ static FailureOr getOperatorPrecedence(Operation *operation) { .Default([](auto op) { return op->emitError("unsupported operation"); }); } +static bool shouldBeInlined(Operation *op); + namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { @@ -254,24 +256,19 @@ struct CppEmitter { } /// Is expression currently being emitted. - bool isEmittingExpression() { return emittedExpression; } + bool isEmittingExpression() { return !emittedExpressionPrecedence.empty(); } /// Determine whether given value is part of the expression potentially being /// emitted. bool isPartOfCurrentExpression(Value value) { - if (!emittedExpression) - return false; Operation *def = value.getDefiningOp(); - if (!def) - return false; - return isPartOfCurrentExpression(def); + return def ? isPartOfCurrentExpression(def) : false; } /// Determine whether given operation is part of the expression potentially /// being emitted. bool isPartOfCurrentExpression(Operation *def) { - auto operandExpression = dyn_cast(def->getParentOp()); - return operandExpression && operandExpression == emittedExpression; + return isEmittingExpression() && shouldBeInlined(def); }; // Resets the value counter to 0. @@ -318,7 +315,6 @@ struct CppEmitter { unsigned int valueCount{0}; /// State of the current expression being emitted. - ExpressionOp emittedExpression; SmallVector emittedExpressionPrecedence; void pushExpressionPrecedence(int precedence) { @@ -341,12 +337,22 @@ static bool hasDeferredEmission(Operation *op) { emitc::GetFieldOp>(op); } -/// Determine whether expression \p expressionOp should be emitted inline, i.e. +/// Determine whether operation \p op should be emitted inline, i.e. /// as part of its user. This function recommends inlining of any expressions /// that can be inlined unless it is used by another expression, under the /// assumption that any expression fusion/re-materialization was taken care of /// by transformations run by the backend. -static bool shouldBeInlined(ExpressionOp expressionOp) { +static bool shouldBeInlined(Operation *op) { + // CExpression operations are inlined if and only if they reside within an + // ExpressionOp. + if (isa(op)) + return isa(op->getParentOp()); + + // Only other inlinable operation is ExpressionOp itself. + ExpressionOp expressionOp = dyn_cast(op); + if (!expressionOp) + return false; + // Do not inline if expression is marked as such. if (expressionOp.getDoNotInline()) return false; @@ -1564,7 +1570,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { "Expected precedence stack to be empty"); Operation *rootOp = expressionOp.getRootOp(); - emittedExpression = expressionOp; FailureOr precedence = getOperatorPrecedence(rootOp); if (failed(precedence)) return failure(); @@ -1576,7 +1581,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { popExpressionPrecedence(); assert(emittedExpressionPrecedence.empty() && "Expected precedence stack to be empty"); - emittedExpression = nullptr; return success(); } @@ -1617,14 +1621,8 @@ LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { // If this operand is a block argument of an expression, emit instead the // matching expression parameter. Operation *argOp = arg.getParentBlock()->getParentOp(); - if (auto expressionOp = dyn_cast(argOp)) { - // This scenario is only expected when one of the operations within the - // expression being emitted references one of the expression's block - // arguments. - assert(expressionOp == emittedExpression && - "Expected expression being emitted"); - value = expressionOp->getOperand(arg.getArgNumber()); - } + if (auto expressionOp = dyn_cast(argOp)) + return emitOperand(expressionOp->getOperand(arg.getArgNumber())); } os << getOrCreateName(value);