Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL] Implement array temporary support #79382

Merged
merged 9 commits into from
Apr 1, 2024

Conversation

llvm-beanz
Copy link
Collaborator

@llvm-beanz llvm-beanz commented Jan 24, 2024

HLSL constant sized array function parameters do not decay to pointers.
Instead constant sized array types are preserved as unique types for
overload resolution, template instantiation and name mangling.

This implements the change by adding a new ArrayParameterType which
represents a non-decaying ConstantArrayType. The new type behaves the
same as ConstantArrayType except that it does not decay to a pointer.

Values of ConstantArrayType in HLSL decay during overload resolution
via a new HLSLArrayRValue cast to ArrayParameterType.

ArrayParamterType values are passed indirectly by-value to functions
in IR generation resulting in callee generated memcpy instructions.

The behavior of HLSL function calls is documented in the draft language specification under the Expr.Post.Call heading.

Additionally the design of this implementation approach is documented in Clang's documentation

Resolves #70123

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:codegen clang:static analyzer HLSL HLSL Language Support labels Jan 24, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-debuginfo
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-clang-static-analyzer-1

@llvm/pr-subscribers-clang-modules

Author: Chris B (llvm-beanz)

Changes

In HLSL function parameters are passed by value, including array parameters. This change introduces a new AST node to represent array temporary expressions. They behave as lvalues to temporary arrays and decay to pointers for overload resolution and code generation.

The behavior of HLSL function calls is documented in the draft language specification under the Expr.Post.Call heading.

Additionally the design of this implementation approach is documented in Clang's documentation


Full diff: https://github.com/llvm/llvm-project/pull/79382.diff

20 Files Affected:

  • (modified) clang/include/clang/AST/Expr.h (+38)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+2)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+3)
  • (modified) clang/lib/AST/Expr.cpp (+11)
  • (modified) clang/lib/AST/ExprClassification.cpp (+1)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+4)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+4)
  • (modified) clang/lib/AST/StmtProfile.cpp (+9)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+1)
  • (modified) clang/lib/CodeGen/CGExprAgg.cpp (+6)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+5)
  • (modified) clang/lib/Sema/TreeTransform.h (+13)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+2-1)
  • (added) clang/test/CodeGenHLSL/ArrayTemporary.hlsl (+49)
  • (added) clang/test/SemaHLSL/ArrayTemporary.hlsl (+46)
  • (modified) clang/tools/libclang/CXCursor.cpp (+1)
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 59f0aee2c0cedd..c28747e7a3796f 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -6651,6 +6651,44 @@ class RecoveryExpr final : public Expr,
   friend class ASTStmtWriter;
 };
 
+/// HLSLArrayTemporaryExpr - In HLSL, default parameter passing is by value
+/// including for arrays. This AST node represents a materialized temporary of a
+/// constant size arrray.
+class HLSLArrayTemporaryExpr : public Expr {
+  Expr *SourceExpr;
+
+  HLSLArrayTemporaryExpr(Expr *S)
+      : Expr(HLSLArrayTemporaryExprClass, S->getType(), VK_LValue, OK_Ordinary),
+        SourceExpr(S) {}
+
+  HLSLArrayTemporaryExpr(EmptyShell Empty)
+      : Expr(HLSLArrayTemporaryExprClass, Empty), SourceExpr(nullptr) {}
+
+public:
+  static HLSLArrayTemporaryExpr *Create(const ASTContext &Ctx, Expr *S);
+  static HLSLArrayTemporaryExpr *CreateEmpty(const ASTContext &Ctx);
+
+  const Expr *getSourceExpr() const { return SourceExpr; }
+  Expr *getSourceExpr() { return SourceExpr; }
+  void setSourceExpr(Expr *S) { SourceExpr = S; }
+
+  SourceLocation getBeginLoc() const { return SourceExpr->getBeginLoc(); }
+
+  SourceLocation getEndLoc() const { return SourceExpr->getEndLoc(); }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == HLSLArrayTemporaryExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+};
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_AST_EXPR_H
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 2aee6a947141b6..f7f8bbbb05cf20 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3171,6 +3171,8 @@ DEF_TRAVERSE_STMT(OMPTargetParallelGenericLoopDirective,
 DEF_TRAVERSE_STMT(OMPErrorDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(HLSLArrayTemporaryExpr, {})
+
 // OpenMP clauses.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index cec301dfca2817..3d14d33a14b878 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -295,3 +295,6 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPErrorDirective : StmtNode<OMPExecutableDirective>;
+
+// HLSL Extensions
+def HLSLArrayTemporaryExpr : StmtNode<Expr>;
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index f1efa98e175edf..5903e2763dbe59 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
   case ConceptSpecializationExprClass:
   case RequiresExprClass:
   case SYCLUniqueStableNameExprClass:
+  case HLSLArrayTemporaryExprClass:
     // These never have a side-effect.
     return false;
 
@@ -5227,3 +5228,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
       alignof(OMPIteratorExpr));
   return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
 }
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::Create(const ASTContext &Ctx, Expr *Base) {
+  return new (Ctx) HLSLArrayTemporaryExpr(Base);
+}
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::CreateEmpty(const ASTContext &Ctx) {
+  return new (Ctx) HLSLArrayTemporaryExpr(EmptyShell());
+}
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index ffa7c6802ea6e1..f07f2756c3f81d 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
   case Expr::OMPArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
   case Expr::OMPIteratorExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return Cl::CL_LValue;
 
     // C99 6.5.2.5p5 says that compound literals are lvalues.
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index f1d07d022b2584..ef474983b4b7dd 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -16044,6 +16044,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
   case Expr::CoyieldExprClass:
   case Expr::SYCLUniqueStableNameExprClass:
   case Expr::CXXParenListInitExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return ICEDiag(IK_NotICE, E->getBeginLoc());
 
   case Expr::InitListExprClass: {
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 40b1e086ddd0c6..7da53e71719c14 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -4701,6 +4701,10 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
     E = cast<ConstantExpr>(E)->getSubExpr();
     goto recurse;
 
+  case Expr::HLSLArrayTemporaryExprClass:
+    E = cast<HLSLArrayTemporaryExpr>(E)->getSourceExpr();
+    goto recurse;
+
   // FIXME: invent manglings for all these.
   case Expr::BlockExprClass:
   case Expr::ChooseExprClass:
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 9b5bf436c13be9..dd2a4f8be403da 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -2749,6 +2749,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
   OS << ")";
 }
 
+void StmtPrinter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *Node) {
+  PrintExpr(Node->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 // Stmt method implementations
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index dd0838edab7b3f..0c5acb596f1179 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2433,6 +2433,15 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void StmtProfiler::VisitHLSLArrayTemporaryExpr(
+    const HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
                    bool Canonical, bool ProfileLambdaExpr) const {
   StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index c5f6b6d3a99f0b..cd7ad03208654a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1590,6 +1590,7 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
   case Expr::CXXUuidofExprClass:
     return EmitCXXUuidofLValue(cast<CXXUuidofExpr>(E));
   case Expr::LambdaExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return EmitAggExprToLValue(E);
 
   case Expr::ExprWithCleanupsClass: {
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 810b28f25fa18b..57471da27133d9 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -235,6 +235,8 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
     RValue Res = CGF.EmitAtomicExpr(E);
     EmitFinalDestCopy(E->getType(), Res);
   }
+
+  void VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E);
 };
 }  // end anonymous namespace.
 
@@ -1923,6 +1925,10 @@ void AggExprEmitter::VisitDesignatedInitUpdateExpr(DesignatedInitUpdateExpr *E)
   VisitInitListExpr(E->getUpdater());
 }
 
+void AggExprEmitter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E) {
+  Visit(E->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 //                        Entry Points into this File
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 75730ea888afb4..0e4c11bf5b2614 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1414,6 +1414,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Expr::SourceLocExprClass:
   case Expr::ConceptSpecializationExprClass:
   case Expr::RequiresExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     // These expressions can never throw.
     return CT_Cannot;
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 457fa377355a97..54f990f50d8576 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10524,6 +10524,11 @@ Sema::PerformCopyInitialization(const InitializedEntity &Entity,
   Expr *InitE = Init.get();
   assert(InitE && "No initialization expression?");
 
+  if (LangOpts.HLSL)
+    if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType()))
+      if (AdjTy->getOriginalType()->isConstantArrayType())
+        InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE);
+
   if (EqualLoc.isInvalid())
     EqualLoc = InitE->getBeginLoc();
 
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e55e752b9cc354..3772e9a61f5f97 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -15461,6 +15461,19 @@ TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) {
   return getSema().ActOnCapturedRegionEnd(Body.get());
 }
 
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformHLSLArrayTemporaryExpr(
+    HLSLArrayTemporaryExpr *E) {
+  ExprResult SrcExpr = getDerived().TransformExpr(E->getSourceExpr());
+  if (SrcExpr.isInvalid())
+    return ExprError();
+
+  if (!getDerived().AlwaysRebuild() && SrcExpr.get() == E->getSourceExpr())
+    return E;
+
+  return HLSLArrayTemporaryExpr::Create(getSema().Context, SrcExpr.get());
+}
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 85ecfa1a1a0bf2..2203a0add1f648 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2776,6 +2776,14 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
   VisitOMPLoopDirective(D);
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtReader::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index e5836f5dcbe955..9951426e2f2d81 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2825,6 +2825,14 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
   Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE;
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtWriter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index 24e91a22fd6884..39cdec788648fb 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1821,7 +1821,8 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
     case Stmt::CapturedStmtClass:
     case Stmt::OMPUnrollDirectiveClass:
-    case Stmt::OMPMetaDirectiveClass: {
+    case Stmt::OMPMetaDirectiveClass: 
+    case Stmt::HLSLArrayTemporaryExprClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
       Engine.addAbortedBlock(node, currBldrCtx->getBlock());
       break;
diff --git a/clang/test/CodeGenHLSL/ArrayTemporary.hlsl b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
new file mode 100644
index 00000000000000..411e3182f45809
--- /dev/null
+++ b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK-LABEL: define void {{.*}}call{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x float]
+// CHECK: [[Tmp:%.*]] = alloca [2 x float]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 8, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 8, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x float], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn{{.*}}(ptr noundef [[Decay]])
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK-LABEL: define void {{.*}}call2{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: [[Tmp:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 32, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 32, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [4 x %struct.Obj], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn2{{.*}}(ptr noundef [[Decay]])
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK-LABEL: define void {{.*}}call3{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x [2 x float]]
+// CHECK: [[Tmp:%.*]] = alloca [2 x [2 x float]]
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Arr]], ptr align 4 {{.*}}, i32 16, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 16, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x [2 x float]], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn3{{.*}}(ptr noundef [[Decay]])
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/test/SemaHLSL/ArrayTemporary.hlsl b/clang/test/SemaHLSL/ArrayTemporary.hlsl
new file mode 100644
index 00000000000000..32f4fc0f9abdca
--- /dev/null
+++ b/clang/test/SemaHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float *)' lvalue Function {{.*}} 'fn' 'void (float *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2]' lvalue
+
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(Obj *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (Obj *)' lvalue Function {{.*}} 'fn2' 'void (Obj *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Obj *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'Obj[4]' lvalue
+
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float (*)[2])' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float (*)[2])' lvalue Function {{.*}} 'fn3' 'void (float (*)[2])'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float (*)[2]' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2][2]' lvalue
+
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 978adac5521aaa..43b088c573c0a4 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -335,6 +335,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::ObjCSubscriptRefExprClass:
   case Stmt::RecoveryExprClass:
   case Stmt::SYCLUniqueStableNameExprClass:
+  case Stmt::HLSLArrayTemporaryExprClass:
     K = CXCursor_UnexposedExpr;
     break;
 

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-hlsl

Author: Chris B (llvm-beanz)

Changes

In HLSL function parameters are passed by value, including array parameters. This change introduces a new AST node to represent array temporary expressions. They behave as lvalues to temporary arrays and decay to pointers for overload resolution and code generation.

The behavior of HLSL function calls is documented in the draft language specification under the Expr.Post.Call heading.

Additionally the design of this implementation approach is documented in Clang's documentation


Full diff: https://github.com/llvm/llvm-project/pull/79382.diff

20 Files Affected:

  • (modified) clang/include/clang/AST/Expr.h (+38)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+2)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+3)
  • (modified) clang/lib/AST/Expr.cpp (+11)
  • (modified) clang/lib/AST/ExprClassification.cpp (+1)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+4)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+4)
  • (modified) clang/lib/AST/StmtProfile.cpp (+9)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+1)
  • (modified) clang/lib/CodeGen/CGExprAgg.cpp (+6)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+5)
  • (modified) clang/lib/Sema/TreeTransform.h (+13)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+2-1)
  • (added) clang/test/CodeGenHLSL/ArrayTemporary.hlsl (+49)
  • (added) clang/test/SemaHLSL/ArrayTemporary.hlsl (+46)
  • (modified) clang/tools/libclang/CXCursor.cpp (+1)
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 59f0aee2c0cedd..c28747e7a3796f 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -6651,6 +6651,44 @@ class RecoveryExpr final : public Expr,
   friend class ASTStmtWriter;
 };
 
+/// HLSLArrayTemporaryExpr - In HLSL, default parameter passing is by value
+/// including for arrays. This AST node represents a materialized temporary of a
+/// constant size arrray.
+class HLSLArrayTemporaryExpr : public Expr {
+  Expr *SourceExpr;
+
+  HLSLArrayTemporaryExpr(Expr *S)
+      : Expr(HLSLArrayTemporaryExprClass, S->getType(), VK_LValue, OK_Ordinary),
+        SourceExpr(S) {}
+
+  HLSLArrayTemporaryExpr(EmptyShell Empty)
+      : Expr(HLSLArrayTemporaryExprClass, Empty), SourceExpr(nullptr) {}
+
+public:
+  static HLSLArrayTemporaryExpr *Create(const ASTContext &Ctx, Expr *S);
+  static HLSLArrayTemporaryExpr *CreateEmpty(const ASTContext &Ctx);
+
+  const Expr *getSourceExpr() const { return SourceExpr; }
+  Expr *getSourceExpr() { return SourceExpr; }
+  void setSourceExpr(Expr *S) { SourceExpr = S; }
+
+  SourceLocation getBeginLoc() const { return SourceExpr->getBeginLoc(); }
+
+  SourceLocation getEndLoc() const { return SourceExpr->getEndLoc(); }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == HLSLArrayTemporaryExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+};
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_AST_EXPR_H
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 2aee6a947141b6..f7f8bbbb05cf20 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3171,6 +3171,8 @@ DEF_TRAVERSE_STMT(OMPTargetParallelGenericLoopDirective,
 DEF_TRAVERSE_STMT(OMPErrorDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(HLSLArrayTemporaryExpr, {})
+
 // OpenMP clauses.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index cec301dfca2817..3d14d33a14b878 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -295,3 +295,6 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPErrorDirective : StmtNode<OMPExecutableDirective>;
+
+// HLSL Extensions
+def HLSLArrayTemporaryExpr : StmtNode<Expr>;
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index f1efa98e175edf..5903e2763dbe59 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
   case ConceptSpecializationExprClass:
   case RequiresExprClass:
   case SYCLUniqueStableNameExprClass:
+  case HLSLArrayTemporaryExprClass:
     // These never have a side-effect.
     return false;
 
@@ -5227,3 +5228,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
       alignof(OMPIteratorExpr));
   return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
 }
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::Create(const ASTContext &Ctx, Expr *Base) {
+  return new (Ctx) HLSLArrayTemporaryExpr(Base);
+}
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::CreateEmpty(const ASTContext &Ctx) {
+  return new (Ctx) HLSLArrayTemporaryExpr(EmptyShell());
+}
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index ffa7c6802ea6e1..f07f2756c3f81d 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
   case Expr::OMPArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
   case Expr::OMPIteratorExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return Cl::CL_LValue;
 
     // C99 6.5.2.5p5 says that compound literals are lvalues.
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index f1d07d022b2584..ef474983b4b7dd 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -16044,6 +16044,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
   case Expr::CoyieldExprClass:
   case Expr::SYCLUniqueStableNameExprClass:
   case Expr::CXXParenListInitExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return ICEDiag(IK_NotICE, E->getBeginLoc());
 
   case Expr::InitListExprClass: {
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 40b1e086ddd0c6..7da53e71719c14 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -4701,6 +4701,10 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
     E = cast<ConstantExpr>(E)->getSubExpr();
     goto recurse;
 
+  case Expr::HLSLArrayTemporaryExprClass:
+    E = cast<HLSLArrayTemporaryExpr>(E)->getSourceExpr();
+    goto recurse;
+
   // FIXME: invent manglings for all these.
   case Expr::BlockExprClass:
   case Expr::ChooseExprClass:
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 9b5bf436c13be9..dd2a4f8be403da 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -2749,6 +2749,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
   OS << ")";
 }
 
+void StmtPrinter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *Node) {
+  PrintExpr(Node->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 // Stmt method implementations
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index dd0838edab7b3f..0c5acb596f1179 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2433,6 +2433,15 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void StmtProfiler::VisitHLSLArrayTemporaryExpr(
+    const HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
                    bool Canonical, bool ProfileLambdaExpr) const {
   StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index c5f6b6d3a99f0b..cd7ad03208654a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1590,6 +1590,7 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
   case Expr::CXXUuidofExprClass:
     return EmitCXXUuidofLValue(cast<CXXUuidofExpr>(E));
   case Expr::LambdaExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return EmitAggExprToLValue(E);
 
   case Expr::ExprWithCleanupsClass: {
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 810b28f25fa18b..57471da27133d9 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -235,6 +235,8 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
     RValue Res = CGF.EmitAtomicExpr(E);
     EmitFinalDestCopy(E->getType(), Res);
   }
+
+  void VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E);
 };
 }  // end anonymous namespace.
 
@@ -1923,6 +1925,10 @@ void AggExprEmitter::VisitDesignatedInitUpdateExpr(DesignatedInitUpdateExpr *E)
   VisitInitListExpr(E->getUpdater());
 }
 
+void AggExprEmitter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E) {
+  Visit(E->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 //                        Entry Points into this File
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 75730ea888afb4..0e4c11bf5b2614 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1414,6 +1414,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Expr::SourceLocExprClass:
   case Expr::ConceptSpecializationExprClass:
   case Expr::RequiresExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     // These expressions can never throw.
     return CT_Cannot;
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 457fa377355a97..54f990f50d8576 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10524,6 +10524,11 @@ Sema::PerformCopyInitialization(const InitializedEntity &Entity,
   Expr *InitE = Init.get();
   assert(InitE && "No initialization expression?");
 
+  if (LangOpts.HLSL)
+    if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType()))
+      if (AdjTy->getOriginalType()->isConstantArrayType())
+        InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE);
+
   if (EqualLoc.isInvalid())
     EqualLoc = InitE->getBeginLoc();
 
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e55e752b9cc354..3772e9a61f5f97 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -15461,6 +15461,19 @@ TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) {
   return getSema().ActOnCapturedRegionEnd(Body.get());
 }
 
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformHLSLArrayTemporaryExpr(
+    HLSLArrayTemporaryExpr *E) {
+  ExprResult SrcExpr = getDerived().TransformExpr(E->getSourceExpr());
+  if (SrcExpr.isInvalid())
+    return ExprError();
+
+  if (!getDerived().AlwaysRebuild() && SrcExpr.get() == E->getSourceExpr())
+    return E;
+
+  return HLSLArrayTemporaryExpr::Create(getSema().Context, SrcExpr.get());
+}
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 85ecfa1a1a0bf2..2203a0add1f648 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2776,6 +2776,14 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
   VisitOMPLoopDirective(D);
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtReader::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index e5836f5dcbe955..9951426e2f2d81 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2825,6 +2825,14 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
   Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE;
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtWriter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index 24e91a22fd6884..39cdec788648fb 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1821,7 +1821,8 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
     case Stmt::CapturedStmtClass:
     case Stmt::OMPUnrollDirectiveClass:
-    case Stmt::OMPMetaDirectiveClass: {
+    case Stmt::OMPMetaDirectiveClass: 
+    case Stmt::HLSLArrayTemporaryExprClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
       Engine.addAbortedBlock(node, currBldrCtx->getBlock());
       break;
diff --git a/clang/test/CodeGenHLSL/ArrayTemporary.hlsl b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
new file mode 100644
index 00000000000000..411e3182f45809
--- /dev/null
+++ b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK-LABEL: define void {{.*}}call{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x float]
+// CHECK: [[Tmp:%.*]] = alloca [2 x float]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 8, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 8, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x float], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn{{.*}}(ptr noundef [[Decay]])
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK-LABEL: define void {{.*}}call2{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: [[Tmp:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 32, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 32, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [4 x %struct.Obj], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn2{{.*}}(ptr noundef [[Decay]])
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK-LABEL: define void {{.*}}call3{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x [2 x float]]
+// CHECK: [[Tmp:%.*]] = alloca [2 x [2 x float]]
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Arr]], ptr align 4 {{.*}}, i32 16, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 16, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x [2 x float]], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn3{{.*}}(ptr noundef [[Decay]])
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/test/SemaHLSL/ArrayTemporary.hlsl b/clang/test/SemaHLSL/ArrayTemporary.hlsl
new file mode 100644
index 00000000000000..32f4fc0f9abdca
--- /dev/null
+++ b/clang/test/SemaHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float *)' lvalue Function {{.*}} 'fn' 'void (float *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2]' lvalue
+
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(Obj *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (Obj *)' lvalue Function {{.*}} 'fn2' 'void (Obj *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Obj *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'Obj[4]' lvalue
+
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float (*)[2])' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float (*)[2])' lvalue Function {{.*}} 'fn3' 'void (float (*)[2])'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float (*)[2]' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2][2]' lvalue
+
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 978adac5521aaa..43b088c573c0a4 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -335,6 +335,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::ObjCSubscriptRefExprClass:
   case Stmt::RecoveryExprClass:
   case Stmt::SYCLUniqueStableNameExprClass:
+  case Stmt::HLSLArrayTemporaryExprClass:
     K = CXCursor_UnexposedExpr;
     break;
 

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-clang-codegen

Author: Chris B (llvm-beanz)

Changes

In HLSL function parameters are passed by value, including array parameters. This change introduces a new AST node to represent array temporary expressions. They behave as lvalues to temporary arrays and decay to pointers for overload resolution and code generation.

The behavior of HLSL function calls is documented in the draft language specification under the Expr.Post.Call heading.

Additionally the design of this implementation approach is documented in Clang's documentation


Full diff: https://github.com/llvm/llvm-project/pull/79382.diff

20 Files Affected:

  • (modified) clang/include/clang/AST/Expr.h (+38)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+2)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+3)
  • (modified) clang/lib/AST/Expr.cpp (+11)
  • (modified) clang/lib/AST/ExprClassification.cpp (+1)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+4)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+4)
  • (modified) clang/lib/AST/StmtProfile.cpp (+9)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+1)
  • (modified) clang/lib/CodeGen/CGExprAgg.cpp (+6)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+5)
  • (modified) clang/lib/Sema/TreeTransform.h (+13)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+2-1)
  • (added) clang/test/CodeGenHLSL/ArrayTemporary.hlsl (+49)
  • (added) clang/test/SemaHLSL/ArrayTemporary.hlsl (+46)
  • (modified) clang/tools/libclang/CXCursor.cpp (+1)
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 59f0aee2c0ceddf..c28747e7a3796f3 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -6651,6 +6651,44 @@ class RecoveryExpr final : public Expr,
   friend class ASTStmtWriter;
 };
 
+/// HLSLArrayTemporaryExpr - In HLSL, default parameter passing is by value
+/// including for arrays. This AST node represents a materialized temporary of a
+/// constant size arrray.
+class HLSLArrayTemporaryExpr : public Expr {
+  Expr *SourceExpr;
+
+  HLSLArrayTemporaryExpr(Expr *S)
+      : Expr(HLSLArrayTemporaryExprClass, S->getType(), VK_LValue, OK_Ordinary),
+        SourceExpr(S) {}
+
+  HLSLArrayTemporaryExpr(EmptyShell Empty)
+      : Expr(HLSLArrayTemporaryExprClass, Empty), SourceExpr(nullptr) {}
+
+public:
+  static HLSLArrayTemporaryExpr *Create(const ASTContext &Ctx, Expr *S);
+  static HLSLArrayTemporaryExpr *CreateEmpty(const ASTContext &Ctx);
+
+  const Expr *getSourceExpr() const { return SourceExpr; }
+  Expr *getSourceExpr() { return SourceExpr; }
+  void setSourceExpr(Expr *S) { SourceExpr = S; }
+
+  SourceLocation getBeginLoc() const { return SourceExpr->getBeginLoc(); }
+
+  SourceLocation getEndLoc() const { return SourceExpr->getEndLoc(); }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == HLSLArrayTemporaryExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+};
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_AST_EXPR_H
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 2aee6a947141b6c..f7f8bbbb05cf20c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3171,6 +3171,8 @@ DEF_TRAVERSE_STMT(OMPTargetParallelGenericLoopDirective,
 DEF_TRAVERSE_STMT(OMPErrorDirective,
                   { TRY_TO(TraverseOMPExecutableDirective(S)); })
 
+DEF_TRAVERSE_STMT(HLSLArrayTemporaryExpr, {})
+
 // OpenMP clauses.
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::TraverseOMPClause(OMPClause *C) {
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index cec301dfca2817b..3d14d33a14b8785 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -295,3 +295,6 @@ def OMPTargetTeamsGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPTargetParallelGenericLoopDirective : StmtNode<OMPLoopDirective>;
 def OMPErrorDirective : StmtNode<OMPExecutableDirective>;
+
+// HLSL Extensions
+def HLSLArrayTemporaryExpr : StmtNode<Expr>;
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index f1efa98e175edf5..5903e2763dbe59a 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
   case ConceptSpecializationExprClass:
   case RequiresExprClass:
   case SYCLUniqueStableNameExprClass:
+  case HLSLArrayTemporaryExprClass:
     // These never have a side-effect.
     return false;
 
@@ -5227,3 +5228,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
       alignof(OMPIteratorExpr));
   return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
 }
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::Create(const ASTContext &Ctx, Expr *Base) {
+  return new (Ctx) HLSLArrayTemporaryExpr(Base);
+}
+
+HLSLArrayTemporaryExpr *
+HLSLArrayTemporaryExpr::CreateEmpty(const ASTContext &Ctx) {
+  return new (Ctx) HLSLArrayTemporaryExpr(EmptyShell());
+}
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index ffa7c6802ea6e19..f07f2756c3f81de 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
   case Expr::OMPArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
   case Expr::OMPIteratorExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return Cl::CL_LValue;
 
     // C99 6.5.2.5p5 says that compound literals are lvalues.
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index f1d07d022b25848..ef474983b4b7dda 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -16044,6 +16044,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
   case Expr::CoyieldExprClass:
   case Expr::SYCLUniqueStableNameExprClass:
   case Expr::CXXParenListInitExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return ICEDiag(IK_NotICE, E->getBeginLoc());
 
   case Expr::InitListExprClass: {
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 40b1e086ddd0c61..7da53e71719c148 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -4701,6 +4701,10 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
     E = cast<ConstantExpr>(E)->getSubExpr();
     goto recurse;
 
+  case Expr::HLSLArrayTemporaryExprClass:
+    E = cast<HLSLArrayTemporaryExpr>(E)->getSourceExpr();
+    goto recurse;
+
   // FIXME: invent manglings for all these.
   case Expr::BlockExprClass:
   case Expr::ChooseExprClass:
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index 9b5bf436c13be99..dd2a4f8be403da3 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -2749,6 +2749,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
   OS << ")";
 }
 
+void StmtPrinter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *Node) {
+  PrintExpr(Node->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 // Stmt method implementations
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index dd0838edab7b3f9..0c5acb596f11795 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -2433,6 +2433,15 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) {
   }
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void StmtProfiler::VisitHLSLArrayTemporaryExpr(
+    const HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
                    bool Canonical, bool ProfileLambdaExpr) const {
   StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index c5f6b6d3a99f0b2..cd7ad03208654a0 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1590,6 +1590,7 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
   case Expr::CXXUuidofExprClass:
     return EmitCXXUuidofLValue(cast<CXXUuidofExpr>(E));
   case Expr::LambdaExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     return EmitAggExprToLValue(E);
 
   case Expr::ExprWithCleanupsClass: {
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 810b28f25fa18bf..57471da27133d9e 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -235,6 +235,8 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
     RValue Res = CGF.EmitAtomicExpr(E);
     EmitFinalDestCopy(E->getType(), Res);
   }
+
+  void VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E);
 };
 }  // end anonymous namespace.
 
@@ -1923,6 +1925,10 @@ void AggExprEmitter::VisitDesignatedInitUpdateExpr(DesignatedInitUpdateExpr *E)
   VisitInitListExpr(E->getUpdater());
 }
 
+void AggExprEmitter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *E) {
+  Visit(E->getSourceExpr());
+}
+
 //===----------------------------------------------------------------------===//
 //                        Entry Points into this File
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 75730ea888afb41..0e4c11bf5b26140 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -1414,6 +1414,7 @@ CanThrowResult Sema::canThrow(const Stmt *S) {
   case Expr::SourceLocExprClass:
   case Expr::ConceptSpecializationExprClass:
   case Expr::RequiresExprClass:
+  case Expr::HLSLArrayTemporaryExprClass:
     // These expressions can never throw.
     return CT_Cannot;
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 457fa377355a97c..54f990f50d85768 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -10524,6 +10524,11 @@ Sema::PerformCopyInitialization(const InitializedEntity &Entity,
   Expr *InitE = Init.get();
   assert(InitE && "No initialization expression?");
 
+  if (LangOpts.HLSL)
+    if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType()))
+      if (AdjTy->getOriginalType()->isConstantArrayType())
+        InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE);
+
   if (EqualLoc.isInvalid())
     EqualLoc = InitE->getBeginLoc();
 
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index e55e752b9cc354c..3772e9a61f5f97f 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -15461,6 +15461,19 @@ TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) {
   return getSema().ActOnCapturedRegionEnd(Body.get());
 }
 
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformHLSLArrayTemporaryExpr(
+    HLSLArrayTemporaryExpr *E) {
+  ExprResult SrcExpr = getDerived().TransformExpr(E->getSourceExpr());
+  if (SrcExpr.isInvalid())
+    return ExprError();
+
+  if (!getDerived().AlwaysRebuild() && SrcExpr.get() == E->getSourceExpr())
+    return E;
+
+  return HLSLArrayTemporaryExpr::Create(getSema().Context, SrcExpr.get());
+}
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_LIB_SEMA_TREETRANSFORM_H
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 85ecfa1a1a0bf2a..2203a0add1f6482 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2776,6 +2776,14 @@ void ASTStmtReader::VisitOMPTargetParallelGenericLoopDirective(
   VisitOMPLoopDirective(D);
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtReader::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTReader Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index e5836f5dcbe955a..9951426e2f2d814 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2825,6 +2825,14 @@ void ASTStmtWriter::VisitOMPTargetParallelGenericLoopDirective(
   Code = serialization::STMT_OMP_TARGET_PARALLEL_GENERIC_LOOP_DIRECTIVE;
 }
 
+//===----------------------------------------------------------------------===//
+// HLSL AST Nodes
+//===----------------------------------------------------------------------===//
+
+void ASTStmtWriter::VisitHLSLArrayTemporaryExpr(HLSLArrayTemporaryExpr *S) {
+  VisitExpr(S);
+}
+
 //===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index 24e91a22fd6884f..39cdec788648fbe 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -1821,7 +1821,8 @@ void ExprEngine::Visit(const Stmt *S, ExplodedNode *Pred,
     case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
     case Stmt::CapturedStmtClass:
     case Stmt::OMPUnrollDirectiveClass:
-    case Stmt::OMPMetaDirectiveClass: {
+    case Stmt::OMPMetaDirectiveClass: 
+    case Stmt::HLSLArrayTemporaryExprClass: {
       const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState());
       Engine.addAbortedBlock(node, currBldrCtx->getBlock());
       break;
diff --git a/clang/test/CodeGenHLSL/ArrayTemporary.hlsl b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
new file mode 100644
index 000000000000000..411e3182f458091
--- /dev/null
+++ b/clang/test/CodeGenHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK-LABEL: define void {{.*}}call{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x float]
+// CHECK: [[Tmp:%.*]] = alloca [2 x float]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 8, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 8, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x float], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn{{.*}}(ptr noundef [[Decay]])
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK-LABEL: define void {{.*}}call2{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: [[Tmp:%.*]] = alloca [4 x %struct.Obj]
+// CHECK: call void @llvm.memset.p0.i32(ptr align 4 [[Arr]], i8 0, i32 32, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 32, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [4 x %struct.Obj], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn2{{.*}}(ptr noundef [[Decay]])
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK-LABEL: define void {{.*}}call3{{.*}}
+// CHECK: [[Arr:%.*]] = alloca [2 x [2 x float]]
+// CHECK: [[Tmp:%.*]] = alloca [2 x [2 x float]]
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Arr]], ptr align 4 {{.*}}, i32 16, i1 false)
+// CHECK: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[Arr]], i32 16, i1 false)
+// CHECK: [[Decay:%.*]] = getelementptr inbounds [2 x [2 x float]], ptr [[Tmp]], i32 0, i32 0
+// CHECK: call void {{.*}}fn3{{.*}}(ptr noundef [[Decay]])
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/test/SemaHLSL/ArrayTemporary.hlsl b/clang/test/SemaHLSL/ArrayTemporary.hlsl
new file mode 100644
index 000000000000000..32f4fc0f9abdca2
--- /dev/null
+++ b/clang/test/SemaHLSL/ArrayTemporary.hlsl
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump %s | Filecheck %s
+
+void fn(float x[2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float *)' lvalue Function {{.*}} 'fn' 'void (float *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2]' lvalue
+
+void call() {
+  float Arr[2] = {0, 0};
+  fn(Arr);
+}
+
+struct Obj {
+  float V;
+  int X;
+};
+
+void fn2(Obj O[4]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(Obj *)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (Obj *)' lvalue Function {{.*}} 'fn2' 'void (Obj *)'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'Obj *' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'Obj[4]' lvalue
+
+void call2() {
+  Obj Arr[4] = {};
+  fn2(Arr);
+}
+
+
+void fn3(float x[2][2]) { }
+
+// CHECK: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float (*)[2])' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float (*)[2])' lvalue Function {{.*}} 'fn3' 'void (float (*)[2])'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float (*)[2]' <ArrayToPointerDecay>
+// CHECK-NEXT: HLSLArrayTemporaryExpr {{.*}} 'float[2][2]' lvalue
+
+void call3() {
+  float Arr[2][2] = {{0, 0}, {1,1}};
+  fn3(Arr);
+}
diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp
index 978adac5521aaa0..43b088c573c0a4f 100644
--- a/clang/tools/libclang/CXCursor.cpp
+++ b/clang/tools/libclang/CXCursor.cpp
@@ -335,6 +335,7 @@ CXCursor cxcursor::MakeCXCursor(const Stmt *S, const Decl *Parent,
   case Stmt::ObjCSubscriptRefExprClass:
   case Stmt::RecoveryExprClass:
   case Stmt::SYCLUniqueStableNameExprClass:
+  case Stmt::HLSLArrayTemporaryExprClass:
     K = CXCursor_UnexposedExpr;
     break;
 

Copy link

github-actions bot commented Jan 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@@ -3569,6 +3569,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
case ConceptSpecializationExprClass:
case RequiresExprClass:
case SYCLUniqueStableNameExprClass:
case HLSLArrayTemporaryExprClass:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks wrong; I think you need the recursive visit here?

if (LangOpts.HLSL)
if (auto AdjTy = dyn_cast<DecayedType>(Entity.getType()))
if (AdjTy->getOriginalType()->isConstantArrayType())
InitE = HLSLArrayTemporaryExpr::Create(getASTContext(), InitE);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation is surprising. If arrays are supposed to be passed by value, decay shouldn't be happening in the first place.

Allowing decay to happen, then trying to fix it up later, has some weird implications: for example, void f(int[2]) and void f(int[4]) have the same type. You introduce pointer types in places where they shouldn't really exist, and casts which aren't really supposed to happen. And it's not clear to me the non-canonical type wrapper will be preserved in all the cases you need (preserving non-canonical types is best-effort; sometimes we're forced to canonicalize).

As an alternative approach, maybe we can introduce a new type to represent a "function argument of array type". When we construct the function type, instead of decaying to a pointer, it would "decay" to "function argument of array type". The result is sort of similar, but instead of a non-canonical DecayedType, you have a canonical type which would be reliably preserved across the compiler. For argument passing, initialization would work similarly to what you're doing now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooo. I like this idea a lot. Let me go give that a try. I was trying to avoid completely breaking array decay completely because at some point we’re going to align the language more with C/C++.

@llvm-beanz
Copy link
Collaborator Author

I force pushed an update here because the new implementation is basically a complete rewrite based on @efriedma-quic's suggestion. I think this is a significantly cleaner approach that more accurately models the behavior of HLSL.

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like the right direction.

@@ -10900,6 +10936,10 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer,
assert(LHS != RHS &&
"Equivalent pipe types should have already been handled!");
return {};
case Type::ArrayParameter:
if (RHS != LHS)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's impossible for RHS and LHS to be equal here.

// If an argument is an array paramter expression being passed through. Emit
// the argument to a temporary and pass the temporary as the call arg.
if (auto AT = dyn_cast<ArrayParameterType>(type)) {
args.add(EmitAnyExprToTemp(E), type);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this code should be necessary? E shouldn't be an lvalue or a record type, so we should just fall through to the same code at the end of the function, I think?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yea, I think I need to fix this a different way.

The case where this happens is:

void fn(float x[2]);

void call(float Arr[2]) {
  fn(Arr);
}

In the CallExpr to fn, the argument is an lvalue DeclRefExpr to the ParmVar.

I think the right way to do this is to insert an ImplicitCastExpr for CK_HLSLArrayRValue or CK_LValueToRValue.

@@ -1126,6 +1126,7 @@ class ConstExprEmitter :
case CK_NonAtomicToAtomic:
case CK_NoOp:
case CK_ConstructorConversion:
case CK_HLSLArrayRValue:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't accept CK_HLSLArrayRValue casts here (or if we do, we need to do something different with them).

@@ -2231,7 +2231,8 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
case CK_UserDefinedConversion:
return Visit(const_cast<Expr*>(E));

case CK_NoOp: {
case CK_NoOp:
case CK_HLSLArrayRValue: {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't be possible to hit this case; a value of ArrayParameterType isn't a scalar.

@@ -500,6 +500,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_NoOp:
case CK_LValueToRValue:
case CK_UserDefinedConversion:
case CK_HLSLArrayRValue:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result of an CK_HLSLArrayRValue can't be a complex type.

@@ -1442,6 +1445,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_BitCast:
case CK_ToUnion:
case CK_ToVoid:
case CK_HLSLArrayRValue:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should be false, like other lvalue-to-rvalue casts?

}
};

/// Represents the canonical version of C arrays with a specified constant size.
/// For example, the canonical type for 'int A[4 + 4*100]' is a
/// ConstantArrayType where the element type is 'int' and the size is 404.
class ConstantArrayType final
: public ArrayType,
private llvm::TrailingObjects<ConstantArrayType, const Expr *> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the TrailingObjects optimization here was added in 772e266? Not sure what the context was; CC @zygoloid

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size expression is usually not part of a constant array type. Normally, all we want as part of the type node is the bound, and the size expression if needed can be found on the ArrayTypeLoc. The rare special case being handled here is that if the array bound is instantiation-dependent but not dependent, then we need to preserve the expression in the type (not just in the TypeLoc) because it affects (for example) whether two template declarations involving such a type are considered to declare the same template. Because this is a very rare case, and ConstantArrayTypes can be relatively common (eg, as the types of string literals), it seemed preferable to not store the Expr* if it's not needed, but in absolute terms it's probably a minor difference in memory usage, so it's probably fine to remove this optimization.

If we want to keep the optimization, perhaps we could replace the Size and SizeExpr fields with a pointer union that can hold either a <= 63-bit array bound (the common case) or a pointer to a slab-allocated pair of APInt and const Expr* (in the case where the bound doesn't fit in 63 bits or the size expression is instantiation-dependent).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zygoloid do you think the optimization is important to preserve? I guessed it might be safe to eliminate because other array types store the SizeExpr pointer, but if it's important I can make a separate PR to change the implementation as you suggested and we can merge that change first.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general principle we try to keep our AST nodes small. This change isn't going to make a big difference by itself, but abandoning that principle and making a bunch of changes like this would. I think it's worth putting in a small amount of effort here, but if it doesn't work out nicely then it's fine to remove the optimization. We can find other places to save space.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My team is trying to be pedantic about tracking our work, so I've filed #85124 to track implementing @zygoloid's suggestion. I'll try and get a PR up for that in the next day or two then update this PR on top of that change.

llvm-beanz added a commit to llvm-beanz/llvm-project that referenced this pull request Mar 18, 2024
In PR llvm#79382, I need to add a new type that derives from
ConstantArrayType. This means that ConstantArrayType can no longer use
`llvm::TrailingObjects` to store the trailing optional Expr*.

This change refactors ConstantArrayType to store a 60-bit integer and
4-bits for the integer size in bytes. This replaces the APInt field
previously in the type but preserves enough information to recreate it
where needed.

To reduce the number of places where the APInt is re-constructed I've
also added some helper methods to the ConstantArrayType to allow some
common use cases that operate on either the stored small integer or the
APInt as appropriate.

Resolves llvm#85124.
llvm-beanz added a commit that referenced this pull request Mar 26, 2024
In PR #79382, I need to add a new type that derives from
ConstantArrayType. This means that ConstantArrayType can no longer use
`llvm::TrailingObjects` to store the trailing optional Expr*.

This change refactors ConstantArrayType to store a 60-bit integer and
4-bits for the integer size in bytes. This replaces the APInt field
previously in the type but preserves enough information to recreate it
where needed.

To reduce the number of places where the APInt is re-constructed I've
also added some helper methods to the ConstantArrayType to allow some
common use cases that operate on either the stored small integer or the
APInt as appropriate.

Resolves #85124.
HLSL constant sized array function parameters do not decay to pointers.
Instead constant sized array types are preserved as unique types for
overload resolution, template instantiation and name mangling.

This implements the change by adding a new `ArrayParameterType` which
represents a non-decaying `ConstantArrayType`. The new type behaves the
same as `ConstantArrayType` except that it does not decay to a pointer.

Values of `ConstantArrayType` in HLSL decay during overload resolution
via a new `HLSLArrayRValue` cast to `ArrayParameterType`.

`ArrayParamterType` values are passed indirectly by-value to functions
in IR generation resulting in callee generated memcpy instructions.
I beleive this captures all of the feedback from @efriedma-quic. Thank
you for your patience and great feedback!
case Type::ArrayParameter:
assert(LHS != RHS &&
"Equivalent pipe types should have already been handled!");
return LHS;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return {};? Also, assertion text is wrong.

You could try to merge like we do for ConstantArray... but it's unlikely to matter, since HLSL is a C++ variant (and type merging is basically a C-only thing).

@@ -2331,6 +2332,11 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
type = cast<DecayedType>(ty)->getPointeeType();
break;

case Type::ArrayParameter:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just use the existing "case" for ConstantArray?

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@coopp coopp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WOW.. This was pretty big. I went through it twice and didn't see anything that jumped out. Looks good to me.

@llvm-beanz llvm-beanz merged commit 9434c08 into llvm:main Apr 1, 2024
6 of 7 checks passed
@@ -0,0 +1,76 @@
; ModuleID = '/Users/cbieneman/dev/llvm-project/clang/test/SemaHLSL/ArrayTemporary.hlsl'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this .ll file necessary? It doesn't look like a lit test file, no RUN commands etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this file was added unintentionally, sent out #87346

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doh! I'll delete it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:static analyzer clang Clang issues not falling into any other category debuginfo HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL] Array by-value parameter passing
6 participants