Skip to content

Commit

Permalink
[clang][Interp] Handle mixed floating/integral compound assign operators
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D157596
  • Loading branch information
tbaederr committed Sep 5, 2023
1 parent 168e23c commit adb1fb4
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 27 deletions.
82 changes: 56 additions & 26 deletions clang/lib/AST/Interp/ByteCodeExprGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,19 +878,22 @@ bool ByteCodeExprGen<Emitter>::VisitCharacterLiteral(
template <class Emitter>
bool ByteCodeExprGen<Emitter>::VisitFloatCompoundAssignOperator(
const CompoundAssignOperator *E) {
assert(E->getType()->isFloatingType());

const Expr *LHS = E->getLHS();
const Expr *RHS = E->getRHS();
llvm::RoundingMode RM = getRoundingMode(E);
QualType LHSType = LHS->getType();
QualType LHSComputationType = E->getComputationLHSType();
QualType ResultType = E->getComputationResultType();
std::optional<PrimType> LT = classify(LHSComputationType);
std::optional<PrimType> RT = classify(ResultType);

assert(ResultType->isFloatingType());

if (!LT || !RT)
return false;

PrimType LHST = classifyPrim(LHSType);

// C++17 onwards require that we evaluate the RHS first.
// Compute RHS and save it in a temporary variable so we can
// load it again later.
Expand All @@ -904,21 +907,19 @@ bool ByteCodeExprGen<Emitter>::VisitFloatCompoundAssignOperator(
// First, visit LHS.
if (!visit(LHS))
return false;
if (!this->emitLoad(*LT, E))
if (!this->emitLoad(LHST, E))
return false;

// If necessary, convert LHS to its computation type.
if (LHS->getType() != LHSComputationType) {
const auto *TargetSemantics = &Ctx.getFloatSemantics(LHSComputationType);

if (!this->emitCastFP(TargetSemantics, RM, E))
return false;
}
if (!this->emitPrimCast(LHST, classifyPrim(LHSComputationType),
LHSComputationType, E))
return false;

// Now load RHS.
if (!this->emitGetLocal(*RT, TempOffset, E))
return false;

llvm::RoundingMode RM = getRoundingMode(E);
switch (E->getOpcode()) {
case BO_AddAssign:
if (!this->emitAddf(RM, E))
Expand All @@ -940,17 +941,12 @@ bool ByteCodeExprGen<Emitter>::VisitFloatCompoundAssignOperator(
return false;
}

// If necessary, convert result to LHS's type.
if (LHS->getType() != ResultType) {
const auto *TargetSemantics = &Ctx.getFloatSemantics(LHS->getType());

if (!this->emitCastFP(TargetSemantics, RM, E))
return false;
}
if (!this->emitPrimCast(classifyPrim(ResultType), LHST, LHS->getType(), E))
return false;

if (DiscardResult)
return this->emitStorePop(*LT, E);
return this->emitStore(*LT, E);
return this->emitStorePop(LHST, E);
return this->emitStore(LHST, E);
}

template <class Emitter>
Expand Down Expand Up @@ -992,14 +988,6 @@ template <class Emitter>
bool ByteCodeExprGen<Emitter>::VisitCompoundAssignOperator(
const CompoundAssignOperator *E) {

// Handle floating point operations separately here, since they
// require special care.
if (E->getType()->isFloatingType())
return VisitFloatCompoundAssignOperator(E);

if (E->getType()->isPointerType())
return VisitPointerCompoundAssignOperator(E);

const Expr *LHS = E->getLHS();
const Expr *RHS = E->getRHS();
std::optional<PrimType> LHSComputationT =
Expand All @@ -1011,6 +999,15 @@ bool ByteCodeExprGen<Emitter>::VisitCompoundAssignOperator(
if (!LT || !RT || !ResultT || !LHSComputationT)
return false;

// Handle floating point operations separately here, since they
// require special care.

if (ResultT == PT_Float || RT == PT_Float)
return VisitFloatCompoundAssignOperator(E);

if (E->getType()->isPointerType())
return VisitPointerCompoundAssignOperator(E);

assert(!E->getType()->isPointerType() && "Handled above");
assert(!E->getType()->isFloatingType() && "Handled above");

Expand Down Expand Up @@ -2383,6 +2380,39 @@ ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
return OffsetSum;
}

/// Emit casts from a PrimType to another PrimType.
template <class Emitter>
bool ByteCodeExprGen<Emitter>::emitPrimCast(PrimType FromT, PrimType ToT,
QualType ToQT, const Expr *E) {

if (FromT == PT_Float) {
// Floating to floating.
if (ToT == PT_Float) {
const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT);
return this->emitCastFP(ToSem, getRoundingMode(E), E);
}

// Float to integral.
if (isIntegralType(ToT) || ToT == PT_Bool)
return this->emitCastFloatingIntegral(ToT, E);
}

if (isIntegralType(FromT) || FromT == PT_Bool) {
// Integral to integral.
if (isIntegralType(ToT) || ToT == PT_Bool)
return FromT != ToT ? this->emitCast(FromT, ToT, E) : true;

if (ToT == PT_Float) {
// Integral to floating.
const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT);
return this->emitCastIntegralFloating(FromT, ToSem, getRoundingMode(E),
E);
}
}

return false;
}

/// When calling this, we have a pointer of the local-to-destroy
/// on the stack.
/// Emit destruction of record types (or arrays of record types).
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Interp/ByteCodeExprGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
return FPO.getRoundingMode();
}

bool emitPrimCast(PrimType FromT, PrimType ToT, QualType ToQT, const Expr *E);
bool emitRecordDestruction(const Descriptor *Desc);
bool emitDerivedToBaseCasts(const RecordType *DerivedType,
const RecordType *BaseType, const Expr *E);
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/Interp/PrimType.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
return OS;
}

constexpr bool isIntegralType(PrimType T) { return T <= PT_Uint64; }
constexpr bool isIntegralType(PrimType T) { return T <= PT_Bool; }

/// Mapping from primitive types to their representation.
template <PrimType T> struct PrimConv;
Expand Down
32 changes: 32 additions & 0 deletions clang/test/AST/Interp/floats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,38 @@ namespace compound {
return a[1];
}
static_assert(ff() == 3, "");

constexpr float intPlusDouble() {
int a = 0;
a += 2.0;

return a;
}
static_assert(intPlusDouble() == 2, "");

constexpr double doublePlusInt() {
double a = 0.0;
a += 2;

return a;
}
static_assert(doublePlusInt() == 2, "");

constexpr float boolPlusDouble() {
bool a = 0;
a += 1.0;

return a;
}
static_assert(boolPlusDouble(), "");

constexpr bool doublePlusbool() {
double a = 0.0;
a += true;

return a;
}
static_assert(doublePlusbool() == 1.0, "");
}

namespace unary {
Expand Down

0 comments on commit adb1fb4

Please sign in to comment.