diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 6bd76bb1ffc4b..56f81b0bea9e2 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -173,8 +173,11 @@ struct CppEmitter { /// Emits the operands of the operation. All operands are emitted in order. LogicalResult emitOperands(Operation &op); - /// Emits value as an operands of an operation - LogicalResult emitOperand(Value value); + /// Emits value as an operand of some operation. Unless \p isInBrackets is + /// true, operands emitted as sub-expressions will be parenthesized if needed + /// in order to enforce correct evaluation based on precedence and + /// associativity. + LogicalResult emitOperand(Value value, bool isInBrackets = false); /// Emit an expression as a C expression. LogicalResult emitExpression(ExpressionOp expressionOp); @@ -1578,7 +1581,7 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { return success(); } -LogicalResult CppEmitter::emitOperand(Value value) { +LogicalResult CppEmitter::emitOperand(Value value, bool isInBrackets) { if (isPartOfCurrentExpression(value)) { Operation *def = value.getDefiningOp(); assert(def && "Expected operand to be defined by an operation"); @@ -1586,10 +1589,12 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (failed(precedence)) return failure(); - // Sub-expressions with equal or lower precedence need to be parenthesized, - // as they might be evaluated in the wrong order depending on the shape of - // the expression tree. - bool encloseInParenthesis = precedence.value() <= getExpressionPrecedence(); + // Unless already in brackets, sub-expressions with equal or lower + // precedence need to be parenthesized as they might be evaluated in the + // wrong order depending on the shape of the expression tree. + bool encloseInParenthesis = + !isInBrackets && precedence.value() <= getExpressionPrecedence(); + if (encloseInParenthesis) os << "("; pushExpressionPrecedence(precedence.value()); @@ -1628,15 +1633,9 @@ LogicalResult CppEmitter::emitOperand(Value value) { LogicalResult CppEmitter::emitOperands(Operation &op) { return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) { - // If an expression is being emitted, push lowest precedence as these - // operands are either wrapped by parenthesis. - if (getEmittedExpression()) - pushExpressionPrecedence(lowestPrecedence()); - if (failed(emitOperand(operand))) - return failure(); - if (getEmittedExpression()) - popExpressionPrecedence(); - return success(); + // Emit operand under guarantee that if it's part of an expression then it + // is being emitted within brackets. + return emitOperand(operand, /*isInBrackets=*/true); }); }