Skip to content

Commit db344b9

Browse files
committed
[CUDA][HIP] Fix CTAD for host/device constructors
Currently Clang does not allow using CTAD in CUDA/HIP device functions, because deduction guides are treated as host functions. This patch fixes that by treating deduction guides as host+device. The rationale is that deduction guides never generate code in IR, and there is an existing check for device/host correctness of the constructors themselves.

Also suppress duplicate implicit deduction guides synthesized from host and device constructors with identical signatures, to prevent ambiguity. This ensures CTAD works correctly in CUDA/HIP for constructors with different target attributes.

Example:

```
#include <tuple>

__host__ __device__ void func() {
  std::tuple<int, int> t = std::tuple(1, 1);
}
```

This compiles with nvcc but fails with clang for CUDA/HIP.

Reference: https://godbolt.org/z/WhT1GrhWE

Fixes: ROCm/ROCm#5646
Fixes: #146646
1 parent 2fc42c7 commit db344b9

File tree

4 files changed

+125
-2
lines changed

4 files changed

+125
-2
lines changed

clang/docs/HIPSupport.rst

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,46 @@ Example Usage
287287
basePtr->virtualFunction(); // Allowed since obj is constructed in device code
288288
}
289289

290+
C++17 Class Template Argument Deduction (CTAD) Support
291+
======================================================
292+
293+
Clang supports C++17 Class Template Argument Deduction (CTAD) in both host and device code for HIP.
294+
This allows you to omit template arguments when creating class template instances, letting the compiler
295+
deduce them from constructor arguments.
296+
297+
.. code-block:: c++
298+
299+
#include <tuple>
300+
301+
__host__ __device__ void func() {
302+
std::tuple<int, int> t = std::tuple(1, 1);
303+
}
304+
305+
In the above example, ``std::tuple(1, 1)`` automatically deduces the type to be ``std::tuple<int, int>``.
306+
307+
Deduction Guides
308+
----------------
309+
310+
User-defined deduction guides are also supported. Since deduction guides are not executable code and only
311+
participate in type deduction, they are treated as ``__host__ __device__`` by the compiler, regardless of
312+
explicit target attributes. This ensures they are available for deduction in both host and device contexts.
313+
314+
.. code-block:: c++
315+
316+
template <typename T>
317+
struct MyType {
318+
T value;
319+
__host__ __device__ MyType(T v) : value(v) {}
320+
};
321+
322+
// User-defined deduction guide
323+
template <typename T>
324+
MyType(T) -> MyType<T>;
325+
326+
__device__ void deviceFunc() {
327+
MyType m(10); // Deduces MyType<int>
328+
}
329+
290330
Host and Device Attributes of Default Destructors
291331
===================================================
292332

clang/lib/Sema/SemaCUDA.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ CUDAFunctionTarget SemaCUDA::IdentifyTarget(const FunctionDecl *D,
137137
if (D == nullptr)
138138
return CurCUDATargetCtx.Target;
139139

140+
// C++ deduction guides are never codegen'ed and only participate in template
141+
// argument deduction. Treat them as if they were always host+device so that
142+
// CUDA/HIP target checking never rejects their use based solely on target.
143+
if (isa<CXXDeductionGuideDecl>(D))
144+
return CUDAFunctionTarget::HostDevice;
145+
140146
if (D->hasAttr<CUDAInvalidTargetAttr>())
141147
return CUDAFunctionTarget::InvalidTarget;
142148

@@ -907,6 +913,12 @@ bool SemaCUDA::CheckCall(SourceLocation Loc, FunctionDecl *Callee) {
907913
if (ExprEvalCtx.isUnevaluated() || ExprEvalCtx.isConstantEvaluated())
908914
return true;
909915

916+
// C++ deduction guides participate in overload resolution but are not
917+
// callable functions and are never codegen'ed. Treat them as always
918+
// allowed for CUDA/HIP compatibility checking.
919+
if (isa<CXXDeductionGuideDecl>(Callee))
920+
return true;
921+
910922
// FIXME: Is bailing out early correct here? Should we instead assume that
911923
// the caller is a global initializer?
912924
FunctionDecl *Caller = SemaRef.getCurFunctionDecl(/*AllowLambda=*/true);

clang/lib/Sema/SemaTemplateDeductionGuide.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,33 @@ buildDeductionGuide(Sema &SemaRef, TemplateDecl *OriginalTemplate,
218218
TInfo->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams();
219219

220220
// Build the implicit deduction guide template.
221+
QualType GuideType = TInfo->getType();
222+
223+
// In CUDA/HIP mode, avoid creating duplicate implicit deduction guides with
224+
// identical function types. This can happen when there are separate
225+
// __host__ and __device__ constructors with the same signature; each would
226+
// otherwise synthesize its own implicit deduction guide, leading to
227+
// ambiguous CTAD purely due to target attributes. For such cases we keep the
228+
// first guide we created and skip building another one.
229+
if (IsImplicit && Ctor && SemaRef.getLangOpts().CUDA) {
230+
for (NamedDecl *Existing : DC->lookup(DeductionGuideName)) {
231+
auto *ExistingFT = dyn_cast<FunctionTemplateDecl>(Existing);
232+
auto *ExistingGuide =
233+
ExistingFT
234+
? dyn_cast<CXXDeductionGuideDecl>(ExistingFT->getTemplatedDecl())
235+
: dyn_cast<CXXDeductionGuideDecl>(Existing);
236+
if (!ExistingGuide)
237+
continue;
238+
239+
if (SemaRef.Context.hasSameType(ExistingGuide->getType(), GuideType)) {
240+
return Existing;
241+
}
242+
}
243+
}
244+
221245
auto *Guide = CXXDeductionGuideDecl::Create(
222-
SemaRef.Context, DC, LocStart, ES, Name, TInfo->getType(), TInfo, LocEnd,
223-
Ctor, DeductionCandidate::Normal, FunctionTrailingRC);
246+
SemaRef.Context, DC, LocStart, ES, Name, GuideType, TInfo, LocEnd, Ctor,
247+
DeductionCandidate::Normal, FunctionTrailingRC);
224248
Guide->setImplicit(IsImplicit);
225249
Guide->setParams(Params);
226250

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: %clang_cc1 -std=c++17 -triple nvptx64-nvidia-cuda -fsyntax-only \
2+
// RUN: -fcuda-is-device -verify=expected,dev %s
3+
// RUN: %clang_cc1 -std=c++17 -triple nvptx64-nvidia-cuda -fsyntax-only \
4+
// RUN: -verify %s
5+
6+
#include "Inputs/cuda.h"
7+
8+
template <class T>
9+
struct CTADType { // expected-note 2{{candidate constructor (the implicit copy constructor) not viable: requires 1 argument, but 3 were provided}}
10+
// expected-note@-1 2{{candidate constructor (the implicit move constructor) not viable: requires 1 argument, but 3 were provided}}
11+
T first;
12+
T second;
13+
14+
CTADType(T x) : first(x), second(x) {} // expected-note 2{{candidate constructor not viable: requires single argument 'x', but 3 arguments were provided}}
15+
__device__ CTADType(T x) : first(x), second(x) {} // expected-note 2{{candidate constructor not viable: requires single argument 'x', but 3 arguments were provided}}
16+
__host__ __device__ CTADType(T x, T y) : first(x), second(y) {} // expected-note 2{{candidate constructor not viable: requires 2 arguments, but 3 were provided}}
17+
CTADType(T x, T y, T z) : first(x), second(z) {} // dev-note {{'CTADType' declared here}}
18+
// expected-note@-1 {{candidate constructor not viable: call to __host__ function from __device__ function}}
19+
// expected-note@-2 {{candidate constructor not viable: call to __host__ function from __global__ function}}
20+
};
21+
22+
template <class T>
23+
CTADType(T, T) -> CTADType<T>;
24+
25+
__host__ __device__ void use_ctad_host_device() {
26+
CTADType ctad_from_two_args(1, 1);
27+
CTADType ctad_from_one_arg(1);
28+
CTADType ctad_from_three_args(1, 2, 3); // dev-error {{reference to __host__ function 'CTADType' in __host__ __device__ function}}
29+
}
30+
31+
__host__ void use_ctad_host() {
32+
CTADType ctad_from_two_args(1, 1);
33+
CTADType ctad_from_one_arg(1);
34+
CTADType ctad_from_three_args(1, 2, 3);
35+
}
36+
37+
__device__ void use_ctad_device() {
38+
CTADType ctad_from_two_args(1, 1);
39+
CTADType ctad_from_one_arg(1);
40+
CTADType<int> ctad_from_three_args(1, 2, 3); // expected-error {{no matching constructor for initialization of 'CTADType<int>'}}
41+
}
42+
43+
__global__ void use_ctad_global() {
44+
CTADType ctad_from_two_args(1, 1);
45+
CTADType ctad_from_one_arg(1);
46+
CTADType<int> ctad_from_three_args(1, 2, 3); // expected-error {{no matching constructor for initialization of 'CTADType<int>'}}
47+
}

0 commit comments

Comments
 (0)