[CUDA][HIP] Fix deduction guide #69366

Open
wants to merge 1 commit into
base: main
from

Conversation

yxsamliu
Collaborator

Currently, clang treats every implicit deduction guide as host device. When a class has a device constructor and a host constructor with the same parameters, this generates two identical implicit deduction guides, which causes ambiguity.

Since an implicit deduction guide is derived from a constructor, it should take the same host/device attribute as the originating constructor. This matches nvcc behavior.
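
As a rough illustration of the derivation this description relies on, here is a plain C++17 sketch (not CUDA, and not taken from the patch; `Box` is a hypothetical name). It only shows that class template argument deduction goes through a guide synthesized from a constructor. In CUDA/HIP, a class declaring both a `__host__` and a `__device__` constructor taking `T` would get one such guide per constructor, which is where the two identical candidates come from when both guides are treated as host device.

```cpp
#include <type_traits>

// Hypothetical class template, used only for illustration (not from the patch).
template <typename T>
struct Box {
  // The compiler derives an implicit deduction guide from this constructor,
  // roughly equivalent to:  template <typename T> Box(T) -> Box<T>;
  Box(T v) : value(v) {}
  T value;
};

// Deduction selects the guide derived from Box(T).
static_assert(std::is_same_v<decltype(Box(42)), Box<int>>);
static_assert(std::is_same_v<decltype(Box(1.5)), Box<double>>);
```

Because the guide is synthesized per constructor, giving it the constructor's own host/device attribute keeps the host-derived and device-derived guides distinguishable.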

@llvmbot added the clang and clang:frontend labels Oct 17, 2023
@llvmbot
Collaborator

llvmbot commented Oct 17, 2023

@llvm/pr-subscribers-clang

Author: Yaxun (Sam) Liu (yxsamliu)

Changes

Currently, clang treats every implicit deduction guide as host device. When a class has a device constructor and a host constructor with the same parameters, this generates two identical implicit deduction guides, which causes ambiguity.

Since an implicit deduction guide is derived from a constructor, it should take the same host/device attribute as the originating constructor. This matches nvcc behavior.


Full diff: https://github.com/llvm/llvm-project/pull/69366.diff

7 Files Affected:

  • (modified) clang/docs/HIPSupport.rst (+55)
  • (modified) clang/include/clang/AST/DeclCXX.h (+1-8)
  • (modified) clang/lib/AST/DeclCXX.cpp (+21)
  • (modified) clang/lib/Sema/SemaCUDA.cpp (+4-1)
  • (modified) clang/lib/Sema/SemaInit.cpp (+6-9)
  • (modified) clang/lib/Sema/SemaTemplate.cpp (+13-5)
  • (added) clang/test/SemaCUDA/deduction-guide.cu (+77)
diff --git a/clang/docs/HIPSupport.rst b/clang/docs/HIPSupport.rst
index 8b4649733a9c777..8a9802e19e6367f 100644
--- a/clang/docs/HIPSupport.rst
+++ b/clang/docs/HIPSupport.rst
@@ -176,3 +176,58 @@ Predefined Macros
    * - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
      - Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.
 
+Support for Deduction Guides
+============================
+
+Explicit Deduction Guides
+-------------------------
+
+Explicit deduction guides in HIP can be annotated with either the
+``__host__`` or ``__device__`` attributes. If no attribute is provided,
+it defaults to ``__host__``.
+
+.. code-block:: cpp
+
+   template <typename T>
+   class MyArray {
+       //...
+   };
+
+   template <typename T>
+   MyArray(T)->MyArray<T>;
+
+   __device__ MyArray(float)->MyArray<int>;
+
+   // Uses of the deduction guides
+   MyArray arr1 = 10;      // Uses the default host guide
+   __device__ void foo() {
+       MyArray arr2 = 3.14f; // Uses the device guide
+   }
+
+Implicit Deduction Guides
+-------------------------
+Implicit deduction guides derived from constructors inherit the same host or
+device attributes as the originating constructor.
+
+.. code-block:: cpp
+
+   template <typename T>
+   class MyVector {
+   public:
+       __device__ MyVector(T) { /* ... */ }
+       //...
+   };
+
+   // The implicit deduction guide for MyVector will be `__device__` due to the device constructor
+
+   __device__ void foo() {
+       MyVector vec(42);  // Uses the implicit device guide derived from the constructor
+   }
+
+Availability Checks
+--------------------
+When a deduction guide (either explicit or implicit) is used, HIP checks its
+availability based on its host/device attributes and the context in a similar
+way as checking a function. Utilizing a deduction guide in an incompatible context
+results in a compile-time error.
+
diff --git a/clang/include/clang/AST/DeclCXX.h b/clang/include/clang/AST/DeclCXX.h
index 5eaae6bdd2bc63e..863ced731d42b2f 100644
--- a/clang/include/clang/AST/DeclCXX.h
+++ b/clang/include/clang/AST/DeclCXX.h
@@ -1948,14 +1948,7 @@ class CXXDeductionGuideDecl : public FunctionDecl {
                         ExplicitSpecifier ES,
                         const DeclarationNameInfo &NameInfo, QualType T,
                         TypeSourceInfo *TInfo, SourceLocation EndLocation,
-                        CXXConstructorDecl *Ctor, DeductionCandidate Kind)
-      : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
-                     SC_None, false, false, ConstexprSpecKind::Unspecified),
-        Ctor(Ctor), ExplicitSpec(ES) {
-    if (EndLocation.isValid())
-      setRangeEnd(EndLocation);
-    setDeductionCandidateKind(Kind);
-  }
+                        CXXConstructorDecl *Ctor, DeductionCandidate Kind);
 
   CXXConstructorDecl *Ctor;
   ExplicitSpecifier ExplicitSpec;
diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp
index 9107525a44f22c2..e0683173e24f440 100644
--- a/clang/lib/AST/DeclCXX.cpp
+++ b/clang/lib/AST/DeclCXX.cpp
@@ -2113,6 +2113,27 @@ ExplicitSpecifier ExplicitSpecifier::getFromDecl(FunctionDecl *Function) {
   }
 }
 
+CXXDeductionGuideDecl::CXXDeductionGuideDecl(
+    ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
+    ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
+    TypeSourceInfo *TInfo, SourceLocation EndLocation, CXXConstructorDecl *Ctor,
+    DeductionCandidate Kind)
+    : FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
+                   SC_None, false, false, ConstexprSpecKind::Unspecified),
+      Ctor(Ctor), ExplicitSpec(ES) {
+  if (EndLocation.isValid())
+    setRangeEnd(EndLocation);
+  setDeductionCandidateKind(Kind);
+  // If Ctor is not nullptr, this deduction guide is implicitly derived from
+  // the ctor, therefore it should have the same host/device attribute.
+  if (Ctor && C.getLangOpts().CUDA) {
+    if (Ctor->hasAttr<CUDAHostAttr>())
+      this->addAttr(CUDAHostAttr::CreateImplicit(C));
+    if (Ctor->hasAttr<CUDADeviceAttr>())
+      this->addAttr(CUDADeviceAttr::CreateImplicit(C));
+  }
+}
+
 CXXDeductionGuideDecl *CXXDeductionGuideDecl::Create(
     ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
     ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index d993499cf4a6e6e..d1d59ad1b9fc4b1 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -149,10 +149,13 @@ Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
     return CFT_Device;
   } else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
     return CFT_Host;
-  } else if ((D->isImplicit() || !D->isUserProvided()) &&
+  } else if (!isa<CXXDeductionGuideDecl>(D) &&
+             (D->isImplicit() || !D->isUserProvided()) &&
              !IgnoreImplicitHDAttr) {
     // Some implicit declarations (like intrinsic functions) are not marked.
     // Set the most lenient target on them for maximal flexibility.
+    // Implicit deduction guides are derived from constructors and their
+    // host/device attributes are determined by their originating constructors.
     return CFT_HostDevice;
   }
 
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 8f945bc764befa9..12df8853e7dd760 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -8872,15 +8872,12 @@ ExprResult InitializationSequence::Perform(Sema &S,
           return ExprError();
 
         // Build an expression that constructs a temporary.
-        CurInit = S.BuildCXXConstructExpr(Loc, Step->Type,
-                                          FoundFn, Constructor,
-                                          ConstructorArgs,
-                                          HadMultipleCandidates,
-                                          /*ListInit*/ false,
-                                          /*StdInitListInit*/ false,
-                                          /*ZeroInit*/ false,
-                                          CXXConstructExpr::CK_Complete,
-                                          SourceRange());
+        CurInit = S.BuildCXXConstructExpr(
+            Kind.getLocation(), Step->Type, FoundFn, Constructor,
+            ConstructorArgs, HadMultipleCandidates,
+            /*ListInit*/ false,
+            /*StdInitListInit*/ false,
+            /*ZeroInit*/ false, CXXConstructExpr::CK_Complete, SourceRange());
         if (CurInit.isInvalid())
           return ExprError();
 
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 6389ec708bf34ae..0b854f06a95743b 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2685,19 +2685,27 @@ void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
     AddedAny = true;
   }
 
+  // Build simple deduction guide and set CUDA host/device attributes.
+  auto BuildSimpleDeductionGuide = [&](auto T) {
+    auto *DG = cast<CXXDeductionGuideDecl>(
+        cast<FunctionTemplateDecl>(Transform.buildSimpleDeductionGuide(T))
+            ->getTemplatedDecl());
+    if (LangOpts.CUDA) {
+      DG->addAttr(CUDAHostAttr::CreateImplicit(getASTContext()));
+      DG->addAttr(CUDADeviceAttr::CreateImplicit(getASTContext()));
+    }
+    return DG;
+  };
   // C++17 [over.match.class.deduct]
   //    --  If C is not defined or does not declare any constructors, an
   //    additional function template derived as above from a hypothetical
   //    constructor C().
   if (!AddedAny)
-    Transform.buildSimpleDeductionGuide(std::nullopt);
+    BuildSimpleDeductionGuide(std::nullopt);
 
   //    -- An additional function template derived as above from a hypothetical
   //    constructor C(C), called the copy deduction candidate.
-  cast<CXXDeductionGuideDecl>(
-      cast<FunctionTemplateDecl>(
-          Transform.buildSimpleDeductionGuide(Transform.DeducedType))
-          ->getTemplatedDecl())
+  BuildSimpleDeductionGuide(Transform.DeducedType)
       ->setDeductionCandidateKind(DeductionCandidate::Copy);
 }
 
diff --git a/clang/test/SemaCUDA/deduction-guide.cu b/clang/test/SemaCUDA/deduction-guide.cu
new file mode 100644
index 000000000000000..df69979de6de72e
--- /dev/null
+++ b/clang/test/SemaCUDA/deduction-guide.cu
@@ -0,0 +1,77 @@
+// RUN: %clang_cc1 -fsyntax-only -verify=expected,host %s
+// RUN: %clang_cc1 -fcuda-is-device -fsyntax-only -verify=expected,dev %s
+
+#include "Inputs/cuda.h"
+
+// Implicit deduction guide for host.
+template <typename T>
+struct HGuideImp {       // expected-note {{candidate template ignored: could not match 'HGuideImp<T>' against 'int'}}
+   HGuideImp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+                         // dev-note@-1 {{'<deduction guide for HGuideImp><int>' declared here}}
+};
+
+// Explicit deduction guide for host.
+template <typename T>
+struct HGuideExp {       // expected-note {{candidate template ignored: could not match 'HGuideExp<T>' against 'int'}}
+   HGuideExp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+};
+template<typename T>
+HGuideExp(T) -> HGuideExp<T>; // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
+                              // dev-note@-1 {{'<deduction guide for HGuideExp><int>' declared here}}
+
+// Implicit deduction guide for device.
+template <typename T>
+struct DGuideImp {                  // expected-note {{candidate template ignored: could not match 'DGuideImp<T>' against 'int'}}
+   __device__ DGuideImp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+                                    // host-note@-1 {{'<deduction guide for DGuideImp><int>' declared here}}
+};
+
+// Explicit deduction guide for device.
+template <typename T>
+struct DGuideExp {                   // expected-note {{candidate template ignored: could not match 'DGuideExp<T>' against 'int'}}
+   __device__ DGuideExp(T value) {}  // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+};
+
+template<typename T>
+__device__ DGuideExp(T) -> DGuideExp<T>; // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
+                                         // host-note@-1 {{'<deduction guide for DGuideExp><int>' declared here}}
+
+template <typename T>
+struct HDGuide {
+   __device__ HDGuide(T value) {}
+   HDGuide(T value) {}
+};
+
+template<typename T>
+HDGuide(T) -> HDGuide<T>;
+
+template<typename T>
+__device__ HDGuide(T) -> HDGuide<T>;
+
+void hfun() {
+    HGuideImp hgi = 10;
+    HGuideExp hge = 10;
+    DGuideImp dgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideImp'}}
+    DGuideExp dge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideExp'}}
+    HDGuide hdg = 10;
+}
+
+__device__ void dfun() {
+    HGuideImp hgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideImp'}}
+    HGuideExp hge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideExp'}}
+    DGuideImp dgi = 10;
+    DGuideExp dge = 10;
+    HDGuide hdg = 10;
+}
+
+__host__ __device__ void hdfun() {
+    HGuideImp hgi = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideImp><int>' in __host__ __device__ function}}
+    HGuideExp hge = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideExp><int>' in __host__ __device__ function}}
+    DGuideImp dgi = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideImp><int>' in __host__ __device__ function}}
+    DGuideExp dge = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideExp><int>' in __host__ __device__ function}}
+    HDGuide hdg = 10;
+}
+
+HGuideImp hgi = 10;
+HGuideExp hge = 10;
+HDGuide hdg = 10;

Currently clang assumes implicit deduction guides to be host
device. This generates two identical implicit deduction
guides when a class has a device and a host constructor
with the same input parameters, which causes ambiguity.

Since an implicit deduction guide is derived from a constructor,
it should take the same host/device attribute as the originating
constructor. This matches nvcc behavior as seen in
https://godbolt.org/z/sY1vdYWKe and https://godbolt.org/z/vTer7xa3j
@yxsamliu
Collaborator Author

nvcc behavior can be seen here

https://godbolt.org/z/sY1vdYWKe

https://godbolt.org/z/vTer7xa3j

@Artem-B
Member

Artem-B commented Oct 30, 2023

@ldionne - Can you take a look if that would have unintended consequences for libc++?

@ldionne
Member

ldionne commented Oct 31, 2023

> @ldionne - Can you take a look if that would have unintended consequences for libc++?

Honestly, I don't know. I don't know CUDA nearly well enough to understand all the implications here. All I know is that this seems to be a pretty significant "fork" of C++ in terms of its semantics, and the likelihood that everything will just happen to work as designed is kinda small (but hopefully it does). In my (uneducated) opinion, host vs device should probably be handled closer to a link-time failure. That way you'd steer clear of any complicated front-end concepts like SFINAE, overload resolution and all the stuff that is incredibly complicated in C++. If you modify any of the rules there, the likelihood of introducing issues is really large IMO.

@tahonermann
Contributor

For what it is worth, the described behavior sounds right to me from a design perspective. The fact that it matches nvcc behavior is a very good hint that it is the desired behavior as well. I haven't reviewed the code changes, but as long as they implement what is described, I'd give this a thumbs up.

@yxsamliu
Collaborator Author

yxsamliu commented Nov 15, 2023

> @ldionne - Can you take a look if that would have unintended consequences for libc++?
>
> Honestly, I don't know. I don't know CUDA nearly well enough to understand all the implications here. All I know is that this seems to be a pretty significant "fork" of C++ in terms of its semantics, and the likelihood that everything will just happen to work as designed is kinda small (but hopefully it does). In my (uneducated) opinion, host vs device should probably be handled closer to a link-time failure. That way you'd steer clear of any complicated front-end concepts like SFINAE, overload resolution and all the stuff that is incredibly complicated in C++. If you modify any of the rules there, the likelihood of introducing issues is really large IMO.

I do agree that the later we can defer host/device-based overload resolution, the better. However, I doubt we can avoid it without breaking existing CUDA/HIP code.

The reason is that we need correct overload resolution to create the correct AST, especially when there is template instantiation. When we resolve overloaded functions, the host function candidate and the device function candidate can have different signatures. If we do not consider host/device attributes, we could end up calling a host function on the device side because it is a better match for the argument types. Then all subsequent AST creation is wrong.

To avoid host/device-based overload resolution, we would have to restrict overloading so that ignoring host/device attributes does not affect the created AST. For example, we could allow only host device functions, or require that each host function have a corresponding device function with the same signature. We could add an extension for CUDA/HIP that requires host/device overloading to satisfy this restriction. I can see lots of things being simplified with such an extension.

However, for normal CUDA/HIP code, I don't think we can avoid host/device-based overload resolution.
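
The overload-resolution concern in the comment above can be sketched in plain C++. This is only an illustration under stated assumptions: `DEMO_HOST` and `DEMO_DEVICE` are hypothetical stand-ins for `__host__` and `__device__` that expand to nothing, which models a front end that ignores the attributes, and `pick`, `device_caller`, and `Target` are made-up names that do not appear in the patch.

```cpp
// Hypothetical stand-ins for __host__ and __device__. They expand to nothing,
// which models a front end that ignores the attributes in overload resolution.
#define DEMO_HOST
#define DEMO_DEVICE

enum class Target { Host, Device };

// Host and device candidates with different signatures.
DEMO_HOST constexpr Target pick(int) { return Target::Host; }
DEMO_DEVICE constexpr Target pick(double) { return Target::Device; }

// Stands in for device code. With the attributes ignored, the int argument is
// an exact match for pick(int), so the host candidate is selected.
DEMO_DEVICE constexpr Target device_caller() { return pick(1); }

static_assert(device_caller() == Target::Host,
              "attribute-blind resolution picks the host overload");
```

With the attributes taken into account, the host candidate would be expected to be rejected in device code, leaving `pick(double)` as the selected overload, so the two approaches build different ASTs for the same call.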
