-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to constant expression evaluator #164700
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Constant expression evaluator. Add tests. Fix/Add support for other minor necessary things.
|
@llvm/pr-subscribers-hlsl Author: Sarah Spall (spall) ChangesAdd support to handle these casts in the constant expression evaluator.
Add tests Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff 4 Files Affected:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
/*Diag=*/true);
}
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+ QualType SourceTy, QualType DestTy,
+ APValue const &Original, APValue &Result) {
+ // boolean must be checked before integer
+ // since IsIntegerType() is true for bool
+ if (SourceTy->isBooleanType()) {
+ if (DestTy->isBooleanType()) {
+ Result = Original;
+ return true;
+ }
+ if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ // TODO destty is wrong here if destty is float....
+ // can we use sourcety here?
+ }
+ if (DestTy->isFloatingType()) {
+ APValue Result2 = APValue(APFloat(0.0));
+ if (!HandleIntToFloatCast(Info, E, FPO,
+ Info.Ctx.getIntTypeForBitwidth(64, true),
+ Result.getInt(), DestTy, Result2.getFloat()))
+ return false;
+ Result = Result2;
+ }
+ return true;
+ }
+ if (SourceTy->isIntegerType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = APValue(APFloat(0.0));
+ return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+ DestTy, Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(
+ HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+ return true;
+ }
+ } else if (SourceTy->isRealFloatingType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = Original;
+ return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+ Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(APSInt());
+ return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+ DestTy, Result.getInt());
+ }
+ }
+
+ // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+ // << SourceTy << DestTy;
+ return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+ const Expr *E, APValue &Result,
+ QualType ResultType,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &ElTypes) {
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+ {&Result, ResultType, 0}};
+
+ unsigned ElI = 0;
+ while (!WorkList.empty() && ElI < Elements.size()) {
+ auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+ if (Type->isRealFloatingType() || Type->isBooleanType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ ElI++;
+ continue;
+ }
+ if (Type->isIntegerType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ if (BitWidth > 0) {
+ if (!Res->isInt())
+ return false;
+ APSInt &Int = Res->getInt();
+ unsigned OldBitWidth = Int.getBitWidth();
+ unsigned NewBitWidth = BitWidth;
+ if (NewBitWidth < OldBitWidth)
+ Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+ }
+ ElI++;
+ continue;
+ }
+ if (Type->isVectorType()) {
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ SmallVector<APValue> Vals(NumEl);
+ for (unsigned I = 0; I < NumEl; ++I) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+ Vals[I]))
+ return false;
+ ElI++;
+ }
+ *Res = APValue(Vals.data(), NumEl);
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ *Res = APValue(APValue::UninitArray(), Size, Size);
+ for (int64_t I = Size - 1; I > -1; --I) {
+ WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ unsigned NumBases = 0;
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+ NumBases = CXXRD->getNumBases();
+
+ *Res = APValue(APValue::UninitStruct(), NumBases,
+ std::distance(RD->field_begin(), RD->field_end()));
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+ // we need to traverse backwards
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+ }
+ }
+
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ unsigned FDBW = 0;
+ if (FD->isUnnamedBitField())
+ continue;
+ if (FD->isBitField()) {
+ FDBW = FD->getBitWidthValue();
+ }
+
+ ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+ FD->getType(), FDBW);
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+ const FPOptions FPO,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &SrcTypes,
+ SmallVectorImpl<QualType> &DestTypes,
+ SmallVectorImpl<APValue> &Results) {
+
+ assert((Elements.size() == SrcTypes.size()) &&
+ (Elements.size() == DestTypes.size()));
+
+ for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+ APValue Original = Elements[I];
+ QualType SourceTy = SrcTypes[I];
+ QualType DestTy = DestTypes[I];
+
+ if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+ return false;
+ }
+ return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+ SmallVector<QualType> WorkList = {BaseTy};
+
+ unsigned Size = 0;
+ while (!WorkList.empty()) {
+ QualType Type = WorkList.pop_back_val();
+ if (Type->isRealFloatingType() || Type->isIntegerType() ||
+ Type->isBooleanType()) {
+ ++Size;
+ continue;
+ }
+ if (Type->isVectorType()) {
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ Size += NumEl;
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ for (uint64_t I = 0; I < Size; ++I) {
+ WorkList.push_back(ElTy);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+ // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ WorkList.push_back(BS.getType());
+ }
+ }
+
+ // visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ WorkList.push_back(FD->getType());
+ }
+ continue;
+ }
+ }
+ return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+ QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+ SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+ unsigned Populated = 0;
+ while (!WorkList.empty() && Populated < Size) {
+ auto [Work, Type] = WorkList.pop_back_val();
+
+ if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+ Elements.push_back(Work);
+ Types.push_back(Type);
+ Populated++;
+ continue;
+ }
+ if (Work.isVector()) {
+ assert(Type->isVectorType() && "Type mismatch.");
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+ I++) {
+ Elements.push_back(Work.getVectorElt(I));
+ Types.push_back(ElTy);
+ Populated++;
+ }
+ continue;
+ }
+ if (Work.isArray()) {
+ assert(Type->isConstantArrayType() && "Type mismatch.");
+ QualType ElTy =
+ cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+ for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+ WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+ }
+ continue;
+ }
+
+ if (Work.isStruct()) {
+ assert(Type->isRecordType() && "Type mismatch.");
+
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ SmallVector<std::pair<APValue, QualType>> ReverseList;
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ // if (FD->isBitField()) {
+ ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+ FD->getType());
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ if (CXXRD->getNumBases() > 0) {
+ assert(CXXRD->getNumBases() == 1);
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+ const APValue &Base = Work.getStructBase(0);
+
+ // Can happen in error cases.
+ if (!Base.isStruct())
+ return false;
+
+ WorkList.emplace_back(Base, BS.getType());
+ }
+ }
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
namespace {
/// A handle to a complete object (an object that is not a subobject of
/// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
case CK_UserDefinedConversion:
return StmtVisitorTy::Visit(E->getSubExpr());
+ case CK_HLSLArrayRValue: {
+ const Expr *SubExpr = E->getSubExpr();
+ if (!SubExpr->isGLValue()) {
+ APValue Val;
+ if (!Evaluate(Val, Info, SubExpr))
+ return false;
+ return DerivedSuccess(Val, E);
+ }
+
+ LValue LVal;
+ if (!EvaluateLValue(SubExpr, LVal, Info))
+ return false;
+ APValue RVal;
+ // Note, we use the subexpression's type in order to retain cv-qualifiers.
+ if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+ RVal))
+ return false;
+ return DerivedSuccess(RVal, E);
+ }
case CK_LValueToRValue: {
LValue LVal;
if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
Result = *Value;
return true;
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ APValue Tmp;
+ handleDefaultInitValue(E->getType(), Tmp);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
}
}
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // this cast doesn't handle splatting from scalars when result is a vector
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+
+ // check there is only one element and cast and splat it
+ assert(Elements.size() == 1 &&
+ "HLSLAggregateSplatCast RHS must contain one element");
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 1> ResultEls(1);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+
+ SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+ return Success(SplatEls, E);
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 4> Elements;
+ SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+ SmallVector<QualType, 4> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+ // cast elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 4> ResultEls(NElts);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+ return Success(ResultEls, E);
+ }
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
}
@@ -13029,6 +13488,7 @@ namespace {
bool VisitCallExpr(const CallExpr *E) {
return handleCallExpr(E, Result, &This);
}
+ bool VisitCastExpr(const CastExpr *E);
bool VisitInitListExpr(const InitListExpr *E,
QualType AllocType = QualType());
bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
return true;
}
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+ const Expr *SE = E->getSubExpr();
+
+ switch (E->getCastKind()) {
+ default:
+ return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
+ }
+}
+
bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
QualType AllocType) {
const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_NoOp:
case CK_LValueToRValueBitCast:
case CK_HLSLArrayRValue:
- case CK_HLSLElementwiseCast:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SubExpr))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+ LVal, Val...
[truncated]
|
|
@llvm/pr-subscribers-clang Author: Sarah Spall (spall) ChangesAdd support to handle these casts in the constant expression evaluator.
Add tests Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff 4 Files Affected:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
/*Diag=*/true);
}
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+ QualType SourceTy, QualType DestTy,
+ APValue const &Original, APValue &Result) {
+ // boolean must be checked before integer
+ // since IsIntegerType() is true for bool
+ if (SourceTy->isBooleanType()) {
+ if (DestTy->isBooleanType()) {
+ Result = Original;
+ return true;
+ }
+ if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ // TODO destty is wrong here if destty is float....
+ // can we use sourcety here?
+ }
+ if (DestTy->isFloatingType()) {
+ APValue Result2 = APValue(APFloat(0.0));
+ if (!HandleIntToFloatCast(Info, E, FPO,
+ Info.Ctx.getIntTypeForBitwidth(64, true),
+ Result.getInt(), DestTy, Result2.getFloat()))
+ return false;
+ Result = Result2;
+ }
+ return true;
+ }
+ if (SourceTy->isIntegerType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = APValue(APFloat(0.0));
+ return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+ DestTy, Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(
+ HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+ return true;
+ }
+ } else if (SourceTy->isRealFloatingType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = Original;
+ return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+ Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(APSInt());
+ return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+ DestTy, Result.getInt());
+ }
+ }
+
+ // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+ // << SourceTy << DestTy;
+ return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+ const Expr *E, APValue &Result,
+ QualType ResultType,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &ElTypes) {
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+ {&Result, ResultType, 0}};
+
+ unsigned ElI = 0;
+ while (!WorkList.empty() && ElI < Elements.size()) {
+ auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+ if (Type->isRealFloatingType() || Type->isBooleanType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ ElI++;
+ continue;
+ }
+ if (Type->isIntegerType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ if (BitWidth > 0) {
+ if (!Res->isInt())
+ return false;
+ APSInt &Int = Res->getInt();
+ unsigned OldBitWidth = Int.getBitWidth();
+ unsigned NewBitWidth = BitWidth;
+ if (NewBitWidth < OldBitWidth)
+ Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+ }
+ ElI++;
+ continue;
+ }
+ if (Type->isVectorType()) {
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ SmallVector<APValue> Vals(NumEl);
+ for (unsigned I = 0; I < NumEl; ++I) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+ Vals[I]))
+ return false;
+ ElI++;
+ }
+ *Res = APValue(Vals.data(), NumEl);
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ *Res = APValue(APValue::UninitArray(), Size, Size);
+ for (int64_t I = Size - 1; I > -1; --I) {
+ WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ unsigned NumBases = 0;
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+ NumBases = CXXRD->getNumBases();
+
+ *Res = APValue(APValue::UninitStruct(), NumBases,
+ std::distance(RD->field_begin(), RD->field_end()));
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+ // we need to traverse backwards
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+ }
+ }
+
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ unsigned FDBW = 0;
+ if (FD->isUnnamedBitField())
+ continue;
+ if (FD->isBitField()) {
+ FDBW = FD->getBitWidthValue();
+ }
+
+ ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+ FD->getType(), FDBW);
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+ const FPOptions FPO,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &SrcTypes,
+ SmallVectorImpl<QualType> &DestTypes,
+ SmallVectorImpl<APValue> &Results) {
+
+ assert((Elements.size() == SrcTypes.size()) &&
+ (Elements.size() == DestTypes.size()));
+
+ for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+ APValue Original = Elements[I];
+ QualType SourceTy = SrcTypes[I];
+ QualType DestTy = DestTypes[I];
+
+ if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+ return false;
+ }
+ return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+ SmallVector<QualType> WorkList = {BaseTy};
+
+ unsigned Size = 0;
+ while (!WorkList.empty()) {
+ QualType Type = WorkList.pop_back_val();
+ if (Type->isRealFloatingType() || Type->isIntegerType() ||
+ Type->isBooleanType()) {
+ ++Size;
+ continue;
+ }
+ if (Type->isVectorType()) {
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ Size += NumEl;
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ for (uint64_t I = 0; I < Size; ++I) {
+ WorkList.push_back(ElTy);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+ // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ WorkList.push_back(BS.getType());
+ }
+ }
+
+ // visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ WorkList.push_back(FD->getType());
+ }
+ continue;
+ }
+ }
+ return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+ QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+ SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+ unsigned Populated = 0;
+ while (!WorkList.empty() && Populated < Size) {
+ auto [Work, Type] = WorkList.pop_back_val();
+
+ if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+ Elements.push_back(Work);
+ Types.push_back(Type);
+ Populated++;
+ continue;
+ }
+ if (Work.isVector()) {
+ assert(Type->isVectorType() && "Type mismatch.");
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+ I++) {
+ Elements.push_back(Work.getVectorElt(I));
+ Types.push_back(ElTy);
+ Populated++;
+ }
+ continue;
+ }
+ if (Work.isArray()) {
+ assert(Type->isConstantArrayType() && "Type mismatch.");
+ QualType ElTy =
+ cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+ for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+ WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+ }
+ continue;
+ }
+
+ if (Work.isStruct()) {
+ assert(Type->isRecordType() && "Type mismatch.");
+
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ SmallVector<std::pair<APValue, QualType>> ReverseList;
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ // if (FD->isBitField()) {
+ ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+ FD->getType());
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ if (CXXRD->getNumBases() > 0) {
+ assert(CXXRD->getNumBases() == 1);
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+ const APValue &Base = Work.getStructBase(0);
+
+ // Can happen in error cases.
+ if (!Base.isStruct())
+ return false;
+
+ WorkList.emplace_back(Base, BS.getType());
+ }
+ }
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
namespace {
/// A handle to a complete object (an object that is not a subobject of
/// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
case CK_UserDefinedConversion:
return StmtVisitorTy::Visit(E->getSubExpr());
+ case CK_HLSLArrayRValue: {
+ const Expr *SubExpr = E->getSubExpr();
+ if (!SubExpr->isGLValue()) {
+ APValue Val;
+ if (!Evaluate(Val, Info, SubExpr))
+ return false;
+ return DerivedSuccess(Val, E);
+ }
+
+ LValue LVal;
+ if (!EvaluateLValue(SubExpr, LVal, Info))
+ return false;
+ APValue RVal;
+ // Note, we use the subexpression's type in order to retain cv-qualifiers.
+ if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+ RVal))
+ return false;
+ return DerivedSuccess(RVal, E);
+ }
case CK_LValueToRValue: {
LValue LVal;
if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
Result = *Value;
return true;
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ APValue Tmp;
+ handleDefaultInitValue(E->getType(), Tmp);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
}
}
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // this cast doesn't handle splatting from scalars when result is a vector
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+
+ // check there is only one element and cast and splat it
+ assert(Elements.size() == 1 &&
+ "HLSLAggregateSplatCast RHS must contain one element");
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 1> ResultEls(1);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+
+ SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+ return Success(SplatEls, E);
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 4> Elements;
+ SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+ SmallVector<QualType, 4> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+ // cast elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 4> ResultEls(NElts);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+ return Success(ResultEls, E);
+ }
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
}
@@ -13029,6 +13488,7 @@ namespace {
bool VisitCallExpr(const CallExpr *E) {
return handleCallExpr(E, Result, &This);
}
+ bool VisitCastExpr(const CastExpr *E);
bool VisitInitListExpr(const InitListExpr *E,
QualType AllocType = QualType());
bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
return true;
}
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+ const Expr *SE = E->getSubExpr();
+
+ switch (E->getCastKind()) {
+ default:
+ return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
+ }
+}
+
bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
QualType AllocType) {
const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_NoOp:
case CK_LValueToRValueBitCast:
case CK_HLSLArrayRValue:
- case CK_HLSLElementwiseCast:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SubExpr))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+ LVal, Val...
[truncated]
|
Add support to handle these casts in the constant expression evaluator.
Add tests
Closes #125766
Closes #125321