diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index 63c184db2be38..a2a70e8eb2044 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -394,18 +394,19 @@ bool ByteCodeExprGen::VisitLogicalBinOp(const BinaryOperator *E) { BinaryOperatorKind Op = E->getOpcode(); const Expr *LHS = E->getLHS(); const Expr *RHS = E->getRHS(); + std::optional T = classify(E->getType()); if (Op == BO_LOr) { // Logical OR. Visit LHS and only evaluate RHS if LHS was FALSE. LabelTy LabelTrue = this->getLabel(); LabelTy LabelEnd = this->getLabel(); - if (!this->visit(LHS)) + if (!this->visitBool(LHS)) return false; if (!this->jumpTrue(LabelTrue)) return false; - if (!this->visit(RHS)) + if (!this->visitBool(RHS)) return false; if (!this->jump(LabelEnd)) return false; @@ -415,35 +416,36 @@ bool ByteCodeExprGen::VisitLogicalBinOp(const BinaryOperator *E) { this->fallthrough(LabelEnd); this->emitLabel(LabelEnd); - if (DiscardResult) - return this->emitPopBool(E); - - return true; - } - - // Logical AND. - // Visit LHS. Only visit RHS if LHS was TRUE. - LabelTy LabelFalse = this->getLabel(); - LabelTy LabelEnd = this->getLabel(); + } else { + assert(Op == BO_LAnd); + // Logical AND. + // Visit LHS. Only visit RHS if LHS was TRUE. + LabelTy LabelFalse = this->getLabel(); + LabelTy LabelEnd = this->getLabel(); - if (!this->visit(LHS)) - return false; - if (!this->jumpFalse(LabelFalse)) - return false; + if (!this->visitBool(LHS)) + return false; + if (!this->jumpFalse(LabelFalse)) + return false; - if (!this->visit(RHS)) - return false; - if (!this->jump(LabelEnd)) - return false; + if (!this->visitBool(RHS)) + return false; + if (!this->jump(LabelEnd)) + return false; - this->emitLabel(LabelFalse); - this->emitConstBool(false, E); - this->fallthrough(LabelEnd); - this->emitLabel(LabelEnd); + this->emitLabel(LabelFalse); + this->emitConstBool(false, E); + this->fallthrough(LabelEnd); + this->emitLabel(LabelEnd); + } if (DiscardResult) return this->emitPopBool(E); + // For C, cast back to integer type. + assert(T); + if (T != PT_Bool) + return this->emitCast(PT_Bool, *T, E); return true; } @@ -815,17 +817,9 @@ bool ByteCodeExprGen::VisitAbstractConditionalOperator( LabelTy LabelEnd = this->getLabel(); // Label after the operator. LabelTy LabelFalse = this->getLabel(); // Label for the false expr. - if (!this->visit(Condition)) + if (!this->visitBool(Condition)) return false; - // C special case: Convert to bool because our jump ops need that. - // TODO: We probably want this to be done in visitBool(). - if (std::optional CondT = classify(Condition->getType()); - CondT && CondT != PT_Bool) { - if (!this->emitCast(*CondT, PT_Bool, E)) - return false; - } - if (!this->jumpFalse(LabelFalse)) return false; @@ -1551,9 +1545,29 @@ bool ByteCodeExprGen::visitInitializer(const Expr *E) { template bool ByteCodeExprGen::visitBool(const Expr *E) { - if (std::optional T = classify(E->getType())) - return visit(E); - return this->bail(E); + std::optional T = classify(E->getType()); + if (!T) + return false; + + if (!this->visit(E)) + return false; + + if (T == PT_Bool) + return true; + + // Convert pointers to bool. + if (T == PT_Ptr || T == PT_FnPtr) { + if (!this->emitNull(*T, E)) + return false; + return this->emitNE(*T, E); + } + + // Or Floats. + if (T == PT_Float) + return this->emitCastFloatingIntegralBool(E); + + // Or anything else we can. + return this->emitCast(*T, PT_Bool, E); } template diff --git a/clang/test/AST/Interp/c.c b/clang/test/AST/Interp/c.c index 3387ed49a0d62..f4adfd189d22e 100644 --- a/clang/test/AST/Interp/c.c +++ b/clang/test/AST/Interp/c.c @@ -9,6 +9,8 @@ _Static_assert(1, ""); _Static_assert(0 != 1, ""); _Static_assert(1.0 == 1.0, ""); // pedantic-ref-warning {{not an integer constant expression}} \ // pedantic-expected-warning {{not an integer constant expression}} +_Static_assert(1 && 1.0, ""); // pedantic-ref-warning {{not an integer constant expression}} \ + // pedantic-expected-warning {{not an integer constant expression}} _Static_assert( (5 > 4) + (3 > 2) == 2, ""); /// FIXME: Should also be rejected in the new interpreter