diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index 131525c98ca59..d13a805f0714f 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -878,19 +878,22 @@ bool ByteCodeExprGen::VisitCharacterLiteral( template bool ByteCodeExprGen::VisitFloatCompoundAssignOperator( const CompoundAssignOperator *E) { - assert(E->getType()->isFloatingType()); const Expr *LHS = E->getLHS(); const Expr *RHS = E->getRHS(); - llvm::RoundingMode RM = getRoundingMode(E); + QualType LHSType = LHS->getType(); QualType LHSComputationType = E->getComputationLHSType(); QualType ResultType = E->getComputationResultType(); std::optional LT = classify(LHSComputationType); std::optional RT = classify(ResultType); + assert(ResultType->isFloatingType()); + if (!LT || !RT) return false; + PrimType LHST = classifyPrim(LHSType); + // C++17 onwards require that we evaluate the RHS first. // Compute RHS and save it in a temporary variable so we can // load it again later. @@ -904,21 +907,19 @@ bool ByteCodeExprGen::VisitFloatCompoundAssignOperator( // First, visit LHS. if (!visit(LHS)) return false; - if (!this->emitLoad(*LT, E)) + if (!this->emitLoad(LHST, E)) return false; // If necessary, convert LHS to its computation type. - if (LHS->getType() != LHSComputationType) { - const auto *TargetSemantics = &Ctx.getFloatSemantics(LHSComputationType); - - if (!this->emitCastFP(TargetSemantics, RM, E)) - return false; - } + if (!this->emitPrimCast(LHST, classifyPrim(LHSComputationType), + LHSComputationType, E)) + return false; // Now load RHS. if (!this->emitGetLocal(*RT, TempOffset, E)) return false; + llvm::RoundingMode RM = getRoundingMode(E); switch (E->getOpcode()) { case BO_AddAssign: if (!this->emitAddf(RM, E)) @@ -940,17 +941,12 @@ bool ByteCodeExprGen::VisitFloatCompoundAssignOperator( return false; } - // If necessary, convert result to LHS's type. - if (LHS->getType() != ResultType) { - const auto *TargetSemantics = &Ctx.getFloatSemantics(LHS->getType()); - - if (!this->emitCastFP(TargetSemantics, RM, E)) - return false; - } + if (!this->emitPrimCast(classifyPrim(ResultType), LHST, LHS->getType(), E)) + return false; if (DiscardResult) - return this->emitStorePop(*LT, E); - return this->emitStore(*LT, E); + return this->emitStorePop(LHST, E); + return this->emitStore(LHST, E); } template @@ -992,14 +988,6 @@ template bool ByteCodeExprGen::VisitCompoundAssignOperator( const CompoundAssignOperator *E) { - // Handle floating point operations separately here, since they - // require special care. - if (E->getType()->isFloatingType()) - return VisitFloatCompoundAssignOperator(E); - - if (E->getType()->isPointerType()) - return VisitPointerCompoundAssignOperator(E); - const Expr *LHS = E->getLHS(); const Expr *RHS = E->getRHS(); std::optional LHSComputationT = @@ -1011,6 +999,15 @@ bool ByteCodeExprGen::VisitCompoundAssignOperator( if (!LT || !RT || !ResultT || !LHSComputationT) return false; + // Handle floating point operations separately here, since they + // require special care. + + if (ResultT == PT_Float || RT == PT_Float) + return VisitFloatCompoundAssignOperator(E); + + if (E->getType()->isPointerType()) + return VisitPointerCompoundAssignOperator(E); + assert(!E->getType()->isPointerType() && "Handled above"); assert(!E->getType()->isFloatingType() && "Handled above"); @@ -2383,6 +2380,39 @@ ByteCodeExprGen::collectBaseOffset(const RecordType *BaseType, return OffsetSum; } +/// Emit casts from a PrimType to another PrimType. +template +bool ByteCodeExprGen::emitPrimCast(PrimType FromT, PrimType ToT, + QualType ToQT, const Expr *E) { + + if (FromT == PT_Float) { + // Floating to floating. + if (ToT == PT_Float) { + const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT); + return this->emitCastFP(ToSem, getRoundingMode(E), E); + } + + // Float to integral. + if (isIntegralType(ToT) || ToT == PT_Bool) + return this->emitCastFloatingIntegral(ToT, E); + } + + if (isIntegralType(FromT) || FromT == PT_Bool) { + // Integral to integral. + if (isIntegralType(ToT) || ToT == PT_Bool) + return FromT != ToT ? this->emitCast(FromT, ToT, E) : true; + + if (ToT == PT_Float) { + // Integral to floating. + const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT); + return this->emitCastIntegralFloating(FromT, ToSem, getRoundingMode(E), + E); + } + } + + return false; +} + /// When calling this, we have a pointer of the local-to-destroy /// on the stack. /// Emit destruction of record types (or arrays of record types). diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h index dda954320cd24..9b7593ce54f9e 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.h +++ b/clang/lib/AST/Interp/ByteCodeExprGen.h @@ -265,6 +265,7 @@ class ByteCodeExprGen : public ConstStmtVisitor, bool>, return FPO.getRoundingMode(); } + bool emitPrimCast(PrimType FromT, PrimType ToT, QualType ToQT, const Expr *E); bool emitRecordDestruction(const Descriptor *Desc); bool emitDerivedToBaseCasts(const RecordType *DerivedType, const RecordType *BaseType, const Expr *E); diff --git a/clang/lib/AST/Interp/PrimType.h b/clang/lib/AST/Interp/PrimType.h index a4e2ae4355b54..7c7ee6120b89a 100644 --- a/clang/lib/AST/Interp/PrimType.h +++ b/clang/lib/AST/Interp/PrimType.h @@ -56,7 +56,7 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, return OS; } -constexpr bool isIntegralType(PrimType T) { return T <= PT_Uint64; } +constexpr bool isIntegralType(PrimType T) { return T <= PT_Bool; } /// Mapping from primitive types to their representation. template struct PrimConv; diff --git a/clang/test/AST/Interp/floats.cpp b/clang/test/AST/Interp/floats.cpp index 79e501b19a0ab..a3b058e1eafb3 100644 --- a/clang/test/AST/Interp/floats.cpp +++ b/clang/test/AST/Interp/floats.cpp @@ -102,6 +102,38 @@ namespace compound { return a[1]; } static_assert(ff() == 3, ""); + + constexpr float intPlusDouble() { + int a = 0; + a += 2.0; + + return a; + } + static_assert(intPlusDouble() == 2, ""); + + constexpr double doublePlusInt() { + double a = 0.0; + a += 2; + + return a; + } + static_assert(doublePlusInt() == 2, ""); + + constexpr float boolPlusDouble() { + bool a = 0; + a += 1.0; + + return a; + } + static_assert(boolPlusDouble(), ""); + + constexpr bool doublePlusbool() { + double a = 0.0; + a += true; + + return a; + } + static_assert(doublePlusbool() == 1.0, ""); } namespace unary {