diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index a2f0a77142ad8..aaab980ac81bc 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -187,6 +187,10 @@ bool ByteCodeExprGen::VisitParenExpr(const ParenExpr *PE) { template bool ByteCodeExprGen::VisitBinaryOperator(const BinaryOperator *BO) { + // Need short-circuiting for these. + if (BO->isLogicalOp()) + return this->VisitLogicalBinOp(BO); + const Expr *LHS = BO->getLHS(); const Expr *RHS = BO->getRHS(); @@ -270,8 +274,9 @@ bool ByteCodeExprGen::VisitBinaryOperator(const BinaryOperator *BO) { return Discard(this->emitShr(*LT, *RT, BO)); case BO_Xor: return Discard(this->emitBitXor(*T, BO)); - case BO_LAnd: case BO_LOr: + case BO_LAnd: + llvm_unreachable("Already handled earlier"); default: return this->bail(BO); } @@ -329,6 +334,65 @@ bool ByteCodeExprGen::VisitPointerArithBinOp(const BinaryOperator *E) { return this->bail(E); } +template +bool ByteCodeExprGen::VisitLogicalBinOp(const BinaryOperator *E) { + assert(E->isLogicalOp()); + BinaryOperatorKind Op = E->getOpcode(); + const Expr *LHS = E->getLHS(); + const Expr *RHS = E->getRHS(); + + if (Op == BO_LOr) { + // Logical OR. Visit LHS and only evaluate RHS if LHS was FALSE. + LabelTy LabelTrue = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + if (!this->visit(LHS)) + return false; + if (!this->jumpTrue(LabelTrue)) + return false; + + if (!this->visit(RHS)) + return false; + if (!this->jump(LabelEnd)) + return false; + + this->emitLabel(LabelTrue); + this->emitConstBool(true, E); + this->fallthrough(LabelEnd); + this->emitLabel(LabelEnd); + + if (DiscardResult) + return this->emitPopBool(E); + + return true; + } + + // Logical AND. + // Visit LHS. Only visit RHS if LHS was TRUE. + LabelTy LabelFalse = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); + + if (!this->visit(LHS)) + return false; + if (!this->jumpFalse(LabelFalse)) + return false; + + if (!this->visit(RHS)) + return false; + if (!this->jump(LabelEnd)) + return false; + + this->emitLabel(LabelFalse); + this->emitConstBool(false, E); + this->fallthrough(LabelEnd); + this->emitLabel(LabelEnd); + + if (DiscardResult) + return this->emitPopBool(E); + + return true; +} + template bool ByteCodeExprGen::VisitImplicitValueInitExpr(const ImplicitValueInitExpr *E) { std::optional T = classify(E); diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h index ed33e0285a8f1..0a64ff3513dd8 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.h +++ b/clang/lib/AST/Interp/ByteCodeExprGen.h @@ -61,6 +61,7 @@ class ByteCodeExprGen : public ConstStmtVisitor, bool>, bool VisitFloatingLiteral(const FloatingLiteral *E); bool VisitParenExpr(const ParenExpr *E); bool VisitBinaryOperator(const BinaryOperator *E); + bool VisitLogicalBinOp(const BinaryOperator *E); bool VisitPointerArithBinOp(const BinaryOperator *E); bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E); bool VisitCallExpr(const CallExpr *E); diff --git a/clang/test/AST/Interp/cond.cpp b/clang/test/AST/Interp/cond.cpp index 1fc69ed333e15..8679c116f57bb 100644 --- a/clang/test/AST/Interp/cond.cpp +++ b/clang/test/AST/Interp/cond.cpp @@ -9,3 +9,29 @@ constexpr int cond_then_else(int a, int b) { return a - b; } } + +constexpr int dontCallMe(unsigned m) { + if (m == 0) return 0; + return dontCallMe(m - 2); +} + +// Can't call this because it will run into infinite recursion. +constexpr int assertNotReached() { + return dontCallMe(3); +} + +static_assert(true || true, ""); +static_assert(true || false, ""); +static_assert(false || true, ""); +static_assert(!(false || false), ""); + +static_assert(true || assertNotReached(), ""); +static_assert(true || true || true || false, ""); + +static_assert(true && true, ""); +static_assert(!(true && false), ""); +static_assert(!(false && true), ""); +static_assert(!(false && false), ""); + +static_assert(!(false && assertNotReached()), ""); +static_assert(!(true && true && true && false), "");