Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions clang/docs/HIPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,58 @@ Predefined Macros
* - ``HIP_API_PER_THREAD_DEFAULT_STREAM``
- Alias to ``__HIP_API_PER_THREAD_DEFAULT_STREAM__``. Deprecated.

Support for Deduction Guides
============================

Explicit Deduction Guides
-------------------------

Explicit deduction guides in HIP can be annotated with either the
``__host__`` or ``__device__`` attributes. If no attribute is provided,
it defaults to ``__host__``.

.. code-block:: cpp

template <typename T>
class MyArray {
//...
};

template <typename T>
MyArray(T)->MyArray<T>;

__device__ MyArray(float)->MyArray<int>;

// Uses of the deduction guides
MyArray arr1 = 10; // Uses the default host guide
__device__ void foo() {
MyArray arr2 = 3.14f; // Uses the device guide
}

Implicit Deduction Guides
-------------------------
Implicit deduction guides derived from constructors inherit the same host or
device attributes as the originating constructor.

.. code-block:: cpp

template <typename T>
class MyVector {
public:
__device__ MyVector(T) { /* ... */ }
//...
};

// The implicit deduction guide for MyVector will be `__device__` due to the device constructor

__device__ void foo() {
MyVector vec(42); // Uses the implicit device guide derived from the constructor
}

Availability Checks
--------------------
When a deduction guide (either explicit or implicit) is used, HIP checks its
availability based on its host/device attributes and the context in a similar
way as checking a function. Utilizing a deduction guide in an incompatible context
results in a compile-time error.

9 changes: 1 addition & 8 deletions clang/include/clang/AST/DeclCXX.h
Original file line number Diff line number Diff line change
Expand Up @@ -1948,14 +1948,7 @@ class CXXDeductionGuideDecl : public FunctionDecl {
ExplicitSpecifier ES,
const DeclarationNameInfo &NameInfo, QualType T,
TypeSourceInfo *TInfo, SourceLocation EndLocation,
CXXConstructorDecl *Ctor, DeductionCandidate Kind)
: FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
SC_None, false, false, ConstexprSpecKind::Unspecified),
Ctor(Ctor), ExplicitSpec(ES) {
if (EndLocation.isValid())
setRangeEnd(EndLocation);
setDeductionCandidateKind(Kind);
}
CXXConstructorDecl *Ctor, DeductionCandidate Kind);

CXXConstructorDecl *Ctor;
ExplicitSpecifier ExplicitSpec;
Expand Down
21 changes: 21 additions & 0 deletions clang/lib/AST/DeclCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,27 @@ ExplicitSpecifier ExplicitSpecifier::getFromDecl(FunctionDecl *Function) {
}
}

CXXDeductionGuideDecl::CXXDeductionGuideDecl(
ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
TypeSourceInfo *TInfo, SourceLocation EndLocation, CXXConstructorDecl *Ctor,
DeductionCandidate Kind)
: FunctionDecl(CXXDeductionGuide, C, DC, StartLoc, NameInfo, T, TInfo,
SC_None, false, false, ConstexprSpecKind::Unspecified),
Ctor(Ctor), ExplicitSpec(ES) {
if (EndLocation.isValid())
setRangeEnd(EndLocation);
setDeductionCandidateKind(Kind);
// If Ctor is not nullptr, this deduction guide is implicitly derived from
// the ctor, therefore it should have the same host/device attribute.
if (Ctor && C.getLangOpts().CUDA) {
if (Ctor->hasAttr<CUDAHostAttr>())
this->addAttr(CUDAHostAttr::CreateImplicit(C));
if (Ctor->hasAttr<CUDADeviceAttr>())
this->addAttr(CUDADeviceAttr::CreateImplicit(C));
}
}

CXXDeductionGuideDecl *CXXDeductionGuideDecl::Create(
ASTContext &C, DeclContext *DC, SourceLocation StartLoc,
ExplicitSpecifier ES, const DeclarationNameInfo &NameInfo, QualType T,
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/Sema/SemaCUDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,13 @@ Sema::CUDAFunctionTarget Sema::IdentifyCUDATarget(const FunctionDecl *D,
return CFT_Device;
} else if (hasAttr<CUDAHostAttr>(D, IgnoreImplicitHDAttr)) {
return CFT_Host;
} else if ((D->isImplicit() || !D->isUserProvided()) &&
} else if (!isa<CXXDeductionGuideDecl>(D) &&
(D->isImplicit() || !D->isUserProvided()) &&
!IgnoreImplicitHDAttr) {
// Some implicit declarations (like intrinsic functions) are not marked.
// Set the most lenient target on them for maximal flexibility.
// Implicit deduction duides are derived from constructors and their
// host/device attributes are determined by their originating constructors.
return CFT_HostDevice;
}

Expand Down
18 changes: 13 additions & 5 deletions clang/lib/Sema/SemaTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2685,19 +2685,27 @@ void Sema::DeclareImplicitDeductionGuides(TemplateDecl *Template,
AddedAny = true;
}

// Build simple deduction guide and set CUDA host/device attributes.
auto BuildSimpleDeductionGuide = [&](auto T) {
auto *DG = cast<CXXDeductionGuideDecl>(
cast<FunctionTemplateDecl>(Transform.buildSimpleDeductionGuide(T))
->getTemplatedDecl());
if (LangOpts.CUDA) {
DG->addAttr(CUDAHostAttr::CreateImplicit(getASTContext()));
DG->addAttr(CUDADeviceAttr::CreateImplicit(getASTContext()));
}
return DG;
};
// C++17 [over.match.class.deduct]
// -- If C is not defined or does not declare any constructors, an
// additional function template derived as above from a hypothetical
// constructor C().
if (!AddedAny)
Transform.buildSimpleDeductionGuide(std::nullopt);
BuildSimpleDeductionGuide(std::nullopt);

// -- An additional function template derived as above from a hypothetical
// constructor C(C), called the copy deduction candidate.
cast<CXXDeductionGuideDecl>(
cast<FunctionTemplateDecl>(
Transform.buildSimpleDeductionGuide(Transform.DeducedType))
->getTemplatedDecl())
BuildSimpleDeductionGuide(Transform.DeducedType)
->setDeductionCandidateKind(DeductionCandidate::Copy);
}

Expand Down
85 changes: 85 additions & 0 deletions clang/test/SemaCUDA/deduction-guide.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// RUN: %clang_cc1 -fsyntax-only -verify=expected,host %s
// RUN: %clang_cc1 -fcuda-is-device -fsyntax-only -verify=expected,dev %s

#include "Inputs/cuda.h"

// Implicit deduction guide for host.
template <typename T>
struct HGuideImp { // expected-note {{candidate template ignored: could not match 'HGuideImp<T>' against 'int'}}
HGuideImp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
// dev-note@-1 {{'<deduction guide for HGuideImp><int>' declared here}}
// dev-note@-2 {{'HGuideImp' declared here}}
};

// Explicit deduction guide for host.
template <typename T>
struct HGuideExp { // expected-note {{candidate template ignored: could not match 'HGuideExp<T>' against 'int'}}
HGuideExp(T value) {} // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
// dev-note@-1 {{'HGuideExp' declared here}}
};
template<typename T>
HGuideExp(T) -> HGuideExp<T>; // expected-note {{candidate function not viable: call to __host__ function from __device__ function}}
// dev-note@-1 {{'<deduction guide for HGuideExp><int>' declared here}}

// Implicit deduction guide for device.
template <typename T>
struct DGuideImp { // expected-note {{candidate template ignored: could not match 'DGuideImp<T>' against 'int'}}
__device__ DGuideImp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
// host-note@-1 {{'<deduction guide for DGuideImp><int>' declared here}}
// host-note@-2 {{'DGuideImp' declared here}}
};

// Explicit deduction guide for device.
template <typename T>
struct DGuideExp { // expected-note {{candidate template ignored: could not match 'DGuideExp<T>' against 'int'}}
__device__ DGuideExp(T value) {} // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
// host-note@-1 {{'DGuideExp' declared here}}
};

template<typename T>
__device__ DGuideExp(T) -> DGuideExp<T>; // expected-note {{candidate function not viable: call to __device__ function from __host__ function}}
// host-note@-1 {{'<deduction guide for DGuideExp><int>' declared here}}

template <typename T>
struct HDGuide {
__device__ HDGuide(T value) {}
HDGuide(T value) {}
};

template<typename T>
HDGuide(T) -> HDGuide<T>;

template<typename T>
__device__ HDGuide(T) -> HDGuide<T>;

void hfun() {
HGuideImp hgi = 10;
HGuideExp hge = 10;
DGuideImp dgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideImp'}}
DGuideExp dge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'DGuideExp'}}
HDGuide hdg = 10;
}

__device__ void dfun() {
HGuideImp hgi = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideImp'}}
HGuideExp hge = 10; // expected-error {{no viable constructor or deduction guide for deduction of template arguments of 'HGuideExp'}}
DGuideImp dgi = 10;
DGuideExp dge = 10;
HDGuide hdg = 10;
}

__host__ __device__ void hdfun() {
HGuideImp hgi = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideImp><int>' in __host__ __device__ function}}
// dev-error@-1 {{reference to __host__ function 'HGuideImp' in __host__ __device__ function}}
HGuideExp hge = 10; // dev-error {{reference to __host__ function '<deduction guide for HGuideExp><int>' in __host__ __device__ function}}
// dev-error@-1 {{reference to __host__ function 'HGuideExp' in __host__ __device__ function}}
DGuideImp dgi = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideImp><int>' in __host__ __device__ function}}
// host-error@-1 {{reference to __device__ function 'DGuideImp' in __host__ __device__ function}}
DGuideExp dge = 10; // host-error {{reference to __device__ function '<deduction guide for DGuideExp><int>' in __host__ __device__ function}}
// host-error@-1 {{reference to __device__ function 'DGuideExp' in __host__ __device__ function}}
HDGuide hdg = 10;
}

HGuideImp hgi = 10;
HGuideExp hge = 10;
HDGuide hdg = 10;