Skip to content

Conversation

spall
Copy link
Contributor

@spall spall commented Oct 22, 2025

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests
Closes #125766
Closes #125321

Constant expression evaluator. Add tests. Fix/Add support for
other minor necessary things.
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Oct 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

Changes

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff

4 Files Affected:

  • (modified) clang/lib/AST/ExprConstant.cpp (+586-1)
  • (added) clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl (+89)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl (+21)
  • (added) clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl (+76)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // boolean must be checked before integer
+  // since IsIntegerType() is true for bool
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      // TODO destty is wrong here if destty is float....
+      // can we use sourcety here?
+    }
+    if (DestTy->isFloatingType()) {
+      APValue Result2 = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO,
+                                Info.Ctx.getIntTypeForBitwidth(64, true),
+                                Result.getInt(), DestTy, Result2.getFloat()))
+        return false;
+      Result = Result2;
+    }
+    return true;
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+  //   << SourceTy << DestTy;
+  return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I) {
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    APValue Original = Elements[I];
+    QualType SourceTy = SrcTypes[I];
+    QualType DestTy = DestTypes[I];
+
+    if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+      return false;
+  }
+  return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      for (uint64_t I = 0; I < Size; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        // if (FD->isBitField()) {
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    APValue Tmp;
+    handleDefaultInitValue(E->getType(), Tmp);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
   }
 }
 
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // this cast doesn't handle splatting from scalars when result is a vector
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+
+    // check there is only one element and cast and splat it
+    assert(Elements.size() == 1 &&
+           "HLSLAggregateSplatCast RHS must contain one element");
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 1> ResultEls(1);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+
+    SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 4> Elements;
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<QualType, 4> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+    // cast elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13029,6 +13488,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
   return true;
 }
 
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
+  }
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-clang

Author: Sarah Spall (spall)

Changes

Add support to handle these casts in the constant expression evaluator.

  • HLSLAggregateSplatCast
  • HLSLElementwiseCast
  • HLSLArrayRValue

Add tests


Patch is 28.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164700.diff

4 Files Affected:

  • (modified) clang/lib/AST/ExprConstant.cpp (+586-1)
  • (added) clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl (+89)
  • (modified) clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl (+21)
  • (added) clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl (+76)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // boolean must be checked before integer
+  // since IsIntegerType() is true for bool
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      // TODO destty is wrong here if destty is float....
+      // can we use sourcety here?
+    }
+    if (DestTy->isFloatingType()) {
+      APValue Result2 = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO,
+                                Info.Ctx.getIntTypeForBitwidth(64, true),
+                                Result.getInt(), DestTy, Result2.getFloat()))
+        return false;
+      Result = Result2;
+    }
+    return true;
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+  //   << SourceTy << DestTy;
+  return false;
+}
+
+// do the heavy lifting for casting to aggregate types
+// because we have to deal with bitfields specially
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I) {
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    APValue Original = Elements[I];
+    QualType SourceTy = SrcTypes[I];
+    QualType DestTy = DestTypes[I];
+
+    if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+      return false;
+  }
+  return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      for (uint64_t I = 0; I < Size; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        // if (FD->isBitField()) {
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    APValue Tmp;
+    handleDefaultInitValue(E->getType(), Tmp);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
   }
 }
 
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // this cast doesn't handle splatting from scalars when result is a vector
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+
+    // check there is only one element and cast and splat it
+    assert(Elements.size() == 1 &&
+           "HLSLAggregateSplatCast RHS must contain one element");
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 1> ResultEls(1);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+
+    SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 4> Elements;
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<QualType, 4> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+    // cast elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13029,6 +13488,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
   return true;
 }
 
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
+  }
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Implement constant expression evaluator for HLSL splat cast [HLSL] Implement Constant expression evaluator for flat casts

2 participants