diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 7db7a01998a4a..98a3719ea7cfc 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -3828,6 +3828,330 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT, /*Diag=*/true); } +static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E, + QualType SourceTy, QualType DestTy, + APValue const &Original, APValue &Result) { + // boolean must be checked before integer + // since IsIntegerType() is true for bool + if (SourceTy->isBooleanType()) { + if (DestTy->isBooleanType()) { + Result = Original; + return true; + } + if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) { + bool BoolResult; + if (!HandleConversionToBool(Original, BoolResult)) + return false; + uint64_t IntResult = BoolResult; + QualType IntType = DestTy->isIntegerType() + ? DestTy + : Info.Ctx.getIntTypeForBitwidth(64, false); + Result = APValue(Info.Ctx.MakeIntValue(IntResult, IntType)); + } + if (DestTy->isFloatingType()) { + APValue Result2 = APValue(APFloat(0.0)); + if (!HandleIntToFloatCast(Info, E, FPO, + Info.Ctx.getIntTypeForBitwidth(64, false), + Result.getInt(), DestTy, Result2.getFloat())) + return false; + Result = Result2; + } + return true; + } + if (SourceTy->isIntegerType()) { + if (DestTy->isRealFloatingType()) { + Result = APValue(APFloat(0.0)); + return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(), + DestTy, Result.getFloat()); + } + if (DestTy->isBooleanType()) { + bool BoolResult; + if (!HandleConversionToBool(Original, BoolResult)) + return false; + uint64_t IntResult = BoolResult; + Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy)); + return true; + } + if (DestTy->isIntegerType()) { + Result = APValue( + HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt())); + return true; + } + } else if (SourceTy->isRealFloatingType()) { + if (DestTy->isRealFloatingType()) { + Result = Original; + return HandleFloatToFloatCast(Info, E, SourceTy, DestTy, + Result.getFloat()); + } + if (DestTy->isBooleanType()) { + bool BoolResult; + if (!HandleConversionToBool(Original, BoolResult)) + return false; + uint64_t IntResult = BoolResult; + Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy)); + return true; + } + if (DestTy->isIntegerType()) { + Result = APValue(APSInt()); + return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(), + DestTy, Result.getInt()); + } + } + + return false; +} + +// do the heavy lifting for casting to aggregate types +// because we have to deal with bitfields specially +static bool constructAggregate(EvalInfo &Info, const FPOptions FPO, + const Expr *E, APValue &Result, + QualType ResultType, + SmallVectorImpl &Elements, + SmallVectorImpl &ElTypes) { + + SmallVector> WorkList = { + {&Result, ResultType, 0}}; + + unsigned ElI = 0; + while (!WorkList.empty() && ElI < Elements.size()) { + auto [Res, Type, BitWidth] = WorkList.pop_back_val(); + + if (Type->isRealFloatingType()) { + if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI], + *Res)) + return false; + ElI++; + continue; + } + if (Type->isIntegerType()) { + if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI], + *Res)) + return false; + if (BitWidth > 0) { + if (!Res->isInt()) + return false; + APSInt &Int = Res->getInt(); + unsigned OldBitWidth = Int.getBitWidth(); + unsigned NewBitWidth = BitWidth; + if (NewBitWidth < OldBitWidth) + Int = Int.trunc(NewBitWidth).extend(OldBitWidth); + } + ElI++; + continue; + } + if (Type->isVectorType()) { + QualType ElTy = Type->castAs()->getElementType(); + unsigned NumEl = Type->castAs()->getNumElements(); + SmallVector Vals(NumEl); + for (unsigned I = 0; I < NumEl; ++I) { + if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI], + Vals[I])) + return false; + ElI++; + } + *Res = APValue(Vals.data(), NumEl); + continue; + } + if (Type->isConstantArrayType()) { + QualType ElTy = cast(Info.Ctx.getAsArrayType(Type)) + ->getElementType(); + uint64_t Size = + cast(Info.Ctx.getAsArrayType(Type))->getZExtSize(); + *Res = APValue(APValue::UninitArray(), Size, Size); + for (int64_t I = Size - 1; I > -1; --I) { + WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u); + } + continue; + } + if (Type->isRecordType()) { + const RecordDecl *RD = Type->getAsRecordDecl(); + + unsigned NumBases = 0; + if (auto *CXXRD = dyn_cast(RD)) + NumBases = CXXRD->getNumBases(); + + *Res = APValue(APValue::UninitStruct(), NumBases, + std::distance(RD->field_begin(), RD->field_end())); + + SmallVector> ReverseList; + // we need to traverse backwards + // Visit the base classes. + if (auto *CXXRD = dyn_cast(RD)) { + if (CXXRD->getNumBases() > 0) { + assert(CXXRD->getNumBases() == 1); + const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0]; + ReverseList.emplace_back(&Res->getStructBase(0), BS.getType(), 0u); + } + } + + // Visit the fields. + for (FieldDecl *FD : RD->fields()) { + unsigned FDBW = 0; + if (FD->isUnnamedBitField()) + continue; + if (FD->isBitField()) { + FDBW = FD->getBitWidthValue(); + } + + ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()), + FD->getType(), FDBW); + } + + std::reverse(ReverseList.begin(), ReverseList.end()); + llvm::append_range(WorkList, ReverseList); + continue; + } + return false; + } + return true; +} + +static bool handleElementwiseCast(EvalInfo &Info, const Expr *E, + const FPOptions FPO, + SmallVectorImpl &Elements, + SmallVectorImpl &SrcTypes, + SmallVectorImpl &DestTypes, + SmallVectorImpl &Results) { + + assert((Elements.size() == SrcTypes.size()) && + (Elements.size() == DestTypes.size())); + + for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) { + APValue Original = Elements[I]; + QualType SourceTy = SrcTypes[I]; + QualType DestTy = DestTypes[I]; + + if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I])) + return false; + } + return true; +} + +static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) { + + SmallVector WorkList = {BaseTy}; + + unsigned Size = 0; + while (!WorkList.empty()) { + QualType Type = WorkList.pop_back_val(); + if (Type->isRealFloatingType() || Type->isIntegerType() || + Type->isBooleanType()) { + ++Size; + continue; + } + if (Type->isVectorType()) { + unsigned NumEl = Type->castAs()->getNumElements(); + Size += NumEl; + continue; + } + if (Type->isConstantArrayType()) { + QualType ElTy = cast(Info.Ctx.getAsArrayType(Type)) + ->getElementType(); + uint64_t Size = + cast(Info.Ctx.getAsArrayType(Type))->getZExtSize(); + for (uint64_t I = 0; I < Size; ++I) { + WorkList.push_back(ElTy); + } + continue; + } + if (Type->isRecordType()) { + const RecordDecl *RD = Type->getAsRecordDecl(); + + // Visit the base classes. + if (auto *CXXRD = dyn_cast(RD)) { + if (CXXRD->getNumBases() > 0) { + assert(CXXRD->getNumBases() == 1); + const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0]; + WorkList.push_back(BS.getType()); + } + } + + // visit the fields. + for (FieldDecl *FD : RD->fields()) { + if (FD->isUnnamedBitField()) + continue; + WorkList.push_back(FD->getType()); + } + continue; + } + } + return Size; +} + +static bool flattenAPValue(const ASTContext &Ctx, APValue Value, + QualType BaseTy, SmallVectorImpl &Elements, + SmallVectorImpl &Types, unsigned Size) { + + SmallVector> WorkList = {{Value, BaseTy}}; + unsigned Populated = 0; + while (!WorkList.empty() && Populated < Size) { + auto [Work, Type] = WorkList.pop_back_val(); + + if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool + Elements.push_back(Work); + Types.push_back(Type); + Populated++; + continue; + } + if (Work.isVector()) { + assert(Type->isVectorType() && "Type mismatch."); + QualType ElTy = Type->castAs()->getElementType(); + for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size; + I++) { + Elements.push_back(Work.getVectorElt(I)); + Types.push_back(ElTy); + Populated++; + } + continue; + } + if (Work.isArray()) { + assert(Type->isConstantArrayType() && "Type mismatch."); + QualType ElTy = + cast(Ctx.getAsArrayType(Type))->getElementType(); + for (int64_t I = Work.getArraySize() - 1; I > -1; --I) { + WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy); + } + continue; + } + + if (Work.isStruct()) { + assert(Type->isRecordType() && "Type mismatch."); + + const RecordDecl *RD = Type->getAsRecordDecl(); + + SmallVector> ReverseList; + // Visit the fields. + for (FieldDecl *FD : RD->fields()) { + if (FD->isUnnamedBitField()) + continue; + ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()), + FD->getType()); + } + + std::reverse(ReverseList.begin(), ReverseList.end()); + llvm::append_range(WorkList, ReverseList); + + // Visit the base classes. + if (auto *CXXRD = dyn_cast(RD)) { + if (CXXRD->getNumBases() > 0) { + assert(CXXRD->getNumBases() == 1); + const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0]; + const APValue &Base = Work.getStructBase(0); + + // Can happen in error cases. + if (!Base.isStruct()) + return false; + + WorkList.emplace_back(Base, BS.getType()); + } + } + continue; + } + return false; + } + return true; +} + namespace { /// A handle to a complete object (an object that is not a subobject of /// another object). @@ -8666,6 +8990,25 @@ class ExprEvaluatorBase case CK_UserDefinedConversion: return StmtVisitorTy::Visit(E->getSubExpr()); + case CK_HLSLArrayRValue: { + const Expr *SubExpr = E->getSubExpr(); + if (!SubExpr->isGLValue()) { + APValue Val; + if (!Evaluate(Val, Info, SubExpr)) + return false; + return DerivedSuccess(Val, E); + } + + LValue LVal; + if (!EvaluateLValue(SubExpr, LVal, Info)) + return false; + APValue RVal; + // Note, we use the subexpression's type in order to retain cv-qualifiers. + if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal, + RVal)) + return false; + return DerivedSuccess(RVal, E); + } case CK_LValueToRValue: { LValue LVal; if (!EvaluateLValue(E->getSubExpr(), LVal, Info)) @@ -10850,6 +11193,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) { Result = *Value; return true; } + case CK_HLSLAggregateSplatCast: { + APValue Val; + const Expr *SE = E->getSubExpr(); + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + unsigned NEls = elementwiseSize(Info, E->getType()); + // flatten the source + SmallVector SrcEls; + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls)) + return Error(E); + + // check there is only one and splat it + assert(SrcEls.size() == 1); + SmallVector SplatEls(NEls, SrcEls[0]); + SmallVector SplatType(NEls, SrcTypes[0]); + + APValue Tmp; + handleDefaultInitValue(E->getType(), Tmp); + + // cast the elements and construct our struct result + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls, + SplatType)) + return Error(E); + + return true; + } + case CK_HLSLElementwiseCast: { + APValue Val; + const Expr *SE = E->getSubExpr(); + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + // must be dealing with a record; + if (Val.isLValue()) { + LValue LVal; + LVal.setFrom(Info.Ctx, Val); + if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val)) + return false; + } + + // flatten the source + SmallVector SrcEls; + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, + UINT_MAX)) + return Error(E); + + // cast the elements and construct our struct result + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + + if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls, + SrcTypes)) + return Error(E); + + return true; + } } } @@ -11345,6 +11749,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) { Elements.push_back(Val.getVectorElt(I)); return Success(Elements, E); } + case CK_HLSLAggregateSplatCast: { + APValue Val; + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + // this cast doesn't handle splatting from scalars when result is a vector + SmallVector Elements; + SmallVector DestTypes = {VTy->getElementType()}; + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts)) + return Error(E); + + // check there is only one element and cast and splat it + assert(Elements.size() == 1 && + "HLSLAggregateSplatCast RHS must contain one element"); + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + SmallVector ResultEls(1); + if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes, + ResultEls)) + return Error(E); + + SmallVector SplatEls(NElts, ResultEls[0]); + return Success(SplatEls, E); + } + case CK_HLSLElementwiseCast: { + APValue Val; + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + // must be dealing with a record; + if (Val.isLValue()) { + LValue LVal; + LVal.setFrom(Info.Ctx, Val); + if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val)) + return false; + } + + SmallVector Elements; + SmallVector DestTypes(NElts, VTy->getElementType()); + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts)) + return Error(E); + // cast elements + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + SmallVector ResultEls(NElts); + if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes, + ResultEls)) + return Error(E); + return Success(ResultEls, E); + } default: return ExprEvaluatorBaseTy::VisitCastExpr(E); } @@ -13043,6 +13499,7 @@ namespace { bool VisitCallExpr(const CallExpr *E) { return handleCallExpr(E, Result, &This); } + bool VisitCastExpr(const CastExpr *E); bool VisitInitListExpr(const InitListExpr *E, QualType AllocType = QualType()); bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E); @@ -13113,6 +13570,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) { return true; } +bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) { + const Expr *SE = E->getSubExpr(); + + switch (E->getCastKind()) { + default: + return ExprEvaluatorBaseTy::VisitCastExpr(E); + case CK_HLSLAggregateSplatCast: { + APValue Val; + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + unsigned NEls = elementwiseSize(Info, E->getType()); + // flatten the source + SmallVector SrcEls; + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls)) + return Error(E); + + // check there is only one and splat it + assert(SrcEls.size() == 1); + SmallVector SplatEls(NEls, SrcEls[0]); + SmallVector SplatType(NEls, SrcTypes[0]); + + // cast the elements + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls, + SplatType)) + return Error(E); + + return true; + } + case CK_HLSLElementwiseCast: { + APValue Val; + + if (!Evaluate(Val, Info, SE)) + return Error(E); + + // must be dealing with a record; + if (Val.isLValue()) { + LValue LVal; + LVal.setFrom(Info.Ctx, Val); + if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val)) + return false; + } + + // flatten the source + SmallVector SrcEls; + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, + UINT_MAX)) + return Error(E); + + // cast the elements + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls, + SrcTypes)) + return Error(E); + + return true; + } + } +} + bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E, QualType AllocType) { const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType( @@ -16815,7 +17336,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) { case CK_NoOp: case CK_LValueToRValueBitCast: case CK_HLSLArrayRValue: - case CK_HLSLElementwiseCast: return ExprEvaluatorBaseTy::VisitCastExpr(E); case CK_MemberPointerToBoolean: @@ -16962,6 +17482,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) { return Error(E); return Success(Val.getVectorElt(0), E); } + case CK_HLSLElementwiseCast: { + APValue Val; + + if (!Evaluate(Val, Info, SubExpr)) + return Error(E); + + // must be dealing with a record; + if (Val.isLValue()) { + LValue LVal; + LVal.setFrom(Info.Ctx, Val); + if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(), + LVal, Val)) + return false; + } + + SmallVector Elements; + SmallVector DestTypes(1, DestType); + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SrcType, Elements, SrcTypes, 1)) + return Error(E); + + // cast our single element + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + APValue ResultVal; + if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0], + ResultVal)) + return Error(E); + return Success(ResultVal, E); + } } llvm_unreachable("unknown cast resulting in integral value"); @@ -17499,6 +18048,9 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) { default: return ExprEvaluatorBaseTy::VisitCastExpr(E); + case CK_HLSLAggregateSplatCast: + llvm_unreachable("invalid cast kind for floating value"); + case CK_IntegralToFloating: { APSInt IntResult; const FPOptions FPO = E->getFPFeaturesInEffect( @@ -17537,6 +18089,36 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) { return Error(E); return Success(Val.getVectorElt(0), E); } + case CK_HLSLElementwiseCast: { + APValue Val; + + if (!Evaluate(Val, Info, SubExpr)) + return Error(E); + + // must be dealing with a record; + if (Val.isLValue()) { + LValue LVal; + LVal.setFrom(Info.Ctx, Val); + if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(), + LVal, Val)) + return false; + } + + SmallVector Elements; + SmallVector DestTypes(1, E->getType()); + SmallVector SrcTypes; + if (!flattenAPValue(Info.Ctx, Val, SubExpr->getType(), Elements, SrcTypes, + 1)) + return Error(E); + + // cast our single element + const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts()); + APValue ResultVal; + if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0], + ResultVal)) + return Error(E); + return Success(ResultVal, E); + } } } diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl new file mode 100644 index 0000000000000..7df41f24ee0d9 --- /dev/null +++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl @@ -0,0 +1,89 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s + +// expected-no-diagnostics + +struct Base { + double D; + uint64_t2 U; + int16_t I : 5; + uint16_t I2: 5; +}; + +struct R : Base { + int G : 10; + int : 30; + float F; +}; + +struct B1 { + float A; + float B; +}; + +struct B2 : B1 { + int C; + int D; + bool BB; +}; + +// tests for HLSLAggregateSplatCast +export void fn() { + // result type vector + // splat from a vector of size 1 + + constexpr float1 Y = {1.0}; + constexpr int1 A1 = {1}; + constexpr float4 F4 = (float4)Y; + _Static_assert(F4[0] == 1.0, "Woo!"); + _Static_assert(F4[1] == 1.0, "Woo!"); + _Static_assert(F4[2] == 1.0, "Woo!"); + _Static_assert(F4[3] == 1.0, "Woo!"); + + // result type array + // splat from a scalar + constexpr float F = 3.33; + constexpr int B6[6] = (int[6])F; + _Static_assert(B6[0] == 3, "Woo!"); + _Static_assert(B6[1] == 3, "Woo!"); + _Static_assert(B6[2] == 3, "Woo!"); + _Static_assert(B6[3] == 3, "Woo!"); + _Static_assert(B6[4] == 3, "Woo!"); + _Static_assert(B6[5] == 3, "Woo!"); + + // splat from a vector of size 1 + constexpr uint64_t2 A7[2] = (uint64_t2[2])A1; + _Static_assert(A7[0][0] == 1, "Woo!"); + _Static_assert(A7[0][1] == 1, "Woo!"); + _Static_assert(A7[1][0] == 1, "Woo!"); + _Static_assert(A7[1][1] == 1, "Woo!"); + + // result type struct + // splat from a scalar + constexpr double D = 100.6789; + constexpr R SR = (R)D; + _Static_assert(SR.D == 100.6789, "Woo!"); + _Static_assert(SR.U[0] == 100, "Woo!"); + _Static_assert(SR.U[1] == 100, "Woo!"); + _Static_assert(SR.I == 4, "Woo!"); + _Static_assert(SR.I2 == 4, "Woo!"); + _Static_assert(SR.G == 100, "Woo!"); + _Static_assert(SR.F == 100.6789, "Woo!"); + + // splat from a vector of size 1 + constexpr float1 A100 = {1000.1111}; + constexpr B2 SB2 = (B2)A100; + _Static_assert(SB2.A == 1000.1111, "Woo!"); + _Static_assert(SB2.B == 1000.1111, "Woo!"); + _Static_assert(SB2.C == 1000, "Woo!"); + _Static_assert(SB2.D == 1000, "Woo!"); + _Static_assert(SB2.BB == true, "Woo!"); + + // splat from a bool to an int and float etc + constexpr bool B = true; + constexpr B2 SB3 = (B2)B; + _Static_assert(SB3.A == 1.0, "Woo!"); + _Static_assert(SB3.B == 1.0, "Woo!"); + _Static_assert(SB3.C == 1, "Woo!"); + _Static_assert(SB3.D == 1, "Woo!"); + _Static_assert(SB3.BB == true, "Woo!"); +} diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl new file mode 100644 index 0000000000000..1689fb091b624 --- /dev/null +++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl @@ -0,0 +1,91 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s + +// expected-no-diagnostics + +struct Base { + double D; + uint64_t2 U; + int16_t I : 5; + uint16_t I2: 5; +}; + +struct R : Base { + int G : 10; + int : 30; + float F; +}; + +struct B1 { + float A; + float B; +}; + +struct B2 : B1 { + int C; + int D; + bool BB; +}; + +export void fn() { + _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!"); + + // This compiling successfully verifies that the array constant expression + // gets truncated to a float at compile time for instantiation via the + // flat cast + _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!"); + + // truncation tests + // result type int + // truncate from struct + constexpr B1 SB1 = {1.0, 3.0}; + constexpr float Blah = SB1.A; + constexpr int X = (int)SB1; + _Static_assert(X == 1, "Woo!"); + + // result type float + // truncate from array + constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0}; + constexpr float F = (float)Arr; + _Static_assert(F == 4.0, "Woo!"); + + // result type vector + // truncate from array of vector + constexpr int2 Arr2[2] = {5,6,7,8}; + constexpr int2 I2 = (int2)Arr2; + _Static_assert(I2[0] == 5, "Woo!"); + _Static_assert(I2[1] == 6, "Woo!"); + + // lhs and rhs are same "size" tests + + // result type vector from array + constexpr int4 I4 = (int4)Arr; + _Static_assert(I4[0] == 4, "Woo!"); + _Static_assert(I4[1] == 3, "Woo!"); + _Static_assert(I4[2] == 2, "Woo!"); + _Static_assert(I4[3] == 1, "Woo!"); + + // result type array from vector + constexpr double3 D3 = {100.11, 200.11, 300.11}; + constexpr float FArr[3] = (float[3])D3; + _Static_assert(FArr[0] == 100.11, "Woo!"); + _Static_assert(FArr[1] == 200.11, "Woo!"); + _Static_assert(FArr[2] == 300.11, "Woo!"); + + // result type struct from struct + constexpr B2 SB2 = {5.5, 6.5, 1000, 5000, false}; + constexpr Base SB = (Base)SB2; + _Static_assert(SB.D == 5.5, "Woo!"); + _Static_assert(SB.U[0] == 6, "Woo!"); + _Static_assert(SB.U[1] == 1000, "Woo!"); + _Static_assert(SB.I == 8, "Woo!"); + _Static_assert(SB.I2 == 0, "Woo!"); + + // Make sure we read bitfields correctly + constexpr Base BB = {222.22, {100, 200}, -2, 7}; + constexpr int Arr3[5] = (int[5])BB; + _Static_assert(Arr3[0] == 222, "Woo!"); + _Static_assert(Arr3[1] == 100, "Woo!"); + _Static_assert(Arr3[2] == 200, "Woo!"); + _Static_assert(Arr3[3] == -2, "Woo!"); + _Static_assert(Arr3[4] == 7, "Woo!"); +}