[HIP] Fix comdat of template kernel handle #66283

Merged on Sep 14, 2023 (1 commit)

yxsamliu
Collaborator

Currently, clang emits LLVM IR that fails the verifier for the following code:

template<typename T>
__global__ void foo(T x);

void bar() {
  foo<<<1, 1>>>(0);
}

This is due to clang putting the kernel handle for foo into a comdat, which is not allowed because the kernel handle is a declaration.

The situation is similar to calling a declaration-only template function: the callee is a declaration in LLVM IR and is not put into a comdat. By contrast, a called template function that has a body is put into a comdat.

Fixes: SWDEV-419769
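
The guard this change adds can be modeled in isolation. The following is a minimal sketch, not clang code: `KernelDeclModel` and `reachesComdatHelper` are hypothetical names standing in for the `FunctionDecl` queries in the patch, and only the boolean condition `!FT || FT->isThisDeclarationADefinition()` is taken from the diff.

```cpp
// Hypothetical stand-in for the two facts the patch queries on the kernel's
// FunctionDecl. These are illustrative fields, not clang's AST classes.
struct KernelDeclModel {
  bool fromTemplate;    // models FD->getPrimaryTemplate() != nullptr
  bool templateHasBody; // models FT->isThisDeclarationADefinition()
};

// Mirrors the guard `!FT || FT->isThisDeclarationADefinition()`.
// Returns true when the call to maybeSetTrivialComdat would be reached.
constexpr bool reachesComdatHelper(KernelDeclModel K) {
  return !K.fromTemplate || K.templateHasBody;
}

// Non-template kernel: the comdat helper is still consulted, as before.
static_assert(reachesComdatHelper(KernelDeclModel{false, false}),
              "non-template kernels keep the old path");
// Template kernel with a body: the handle may go into a comdat.
static_assert(reachesComdatHelper(KernelDeclModel{true, true}),
              "defined template kernels keep the old path");
// Declaration-only template kernel: the comdat call is skipped.
static_assert(!reachesComdatHelper(KernelDeclModel{true, false}),
              "declaration-only template kernels skip the comdat call");
```

Only the declaration-only template case skips the call; in the other cases `maybeSetTrivialComdat` still makes the final decision about whether a comdat is set.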

@yxsamliu yxsamliu requested a review from a team as a code owner September 13, 2023 20:05
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen labels Sep 13, 2023
@llvmbot
Collaborator

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-clang

Full diff: https://github.com/llvm/llvm-project/pull/66283.diff

2 Files Affected:

  • (modified) clang/lib/CodeGen/CGCUDANV.cpp (+4-1)
  • (modified) clang/test/CodeGenCUDA/kernel-stub-name.cu (+12-1)
diff --git a/clang/lib/CodeGen/CGCUDANV.cpp b/clang/lib/CodeGen/CGCUDANV.cpp
index 08769c98dc298a0..0efe7e8db0183fe 100644
--- a/clang/lib/CodeGen/CGCUDANV.cpp
+++ b/clang/lib/CodeGen/CGCUDANV.cpp
@@ -1234,7 +1234,10 @@ llvm::GlobalValue *CGNVCUDARuntime::getKernelHandle(llvm::Function *F,
   Var->setAlignment(CGM.getPointerAlign().getAsAlign());
   Var->setDSOLocal(F->isDSOLocal());
   Var->setVisibility(F->getVisibility());
-  CGM.maybeSetTrivialComdat(*GD.getDecl(), *Var);
+  auto *FD = cast<FunctionDecl>(GD.getDecl());
+  auto *FT = FD->getPrimaryTemplate();
+  if (!FT || FT->isThisDeclarationADefinition())
+    CGM.maybeSetTrivialComdat(*FD, *Var);
   KernelHandles[F->getName()] = Var;
   KernelStubs[Var] = F;
   return Var;
diff --git a/clang/test/CodeGenCUDA/kernel-stub-name.cu b/clang/test/CodeGenCUDA/kernel-stub-name.cu
index 9884046fcd0fd0c..008d66bd590b759 100644
--- a/clang/test/CodeGenCUDA/kernel-stub-name.cu
+++ b/clang/test/CodeGenCUDA/kernel-stub-name.cu
@@ -26,12 +26,13 @@
 // GNU: @[[HNSKERN:_ZN2ns8nskernelEv]] = constant ptr @[[NSSTUB:_ZN2ns23__device_stub__nskernelEv]], align 8
 // GNU: @[[HTKERN:_Z10kernelfuncIiEvv]] = linkonce_odr constant ptr @[[TSTUB:_Z25__device_stub__kernelfuncIiEvv]], comdat, align 8
 // GNU: @[[HDKERN:_Z11kernel_declv]] = external constant ptr, align 8
+// GNU: @[[HTDKERN:_Z16temp_kernel_declIiEvT_]] = external constant ptr, align 8
 
 // MSVC: @[[HCKERN:ckernel]] = dso_local constant ptr @[[CSTUB:__device_stub__ckernel]], align 8
 // MSVC: @[[HNSKERN:"\?nskernel@ns@@YAXXZ.*"]] = dso_local constant ptr @[[NSSTUB:"\?__device_stub__nskernel@ns@@YAXXZ"]], align 8
 // MSVC: @[[HTKERN:"\?\?\$kernelfunc@H@@YAXXZ.*"]] = linkonce_odr dso_local constant ptr @[[TSTUB:"\?\?\$__device_stub__kernelfunc@H@@YAXXZ.*"]], comdat, align 8
 // MSVC: @[[HDKERN:"\?kernel_decl@@YAXXZ.*"]] = external dso_local constant ptr, align 8
-
+// MSVC: @[[HTDKERN:"\?\?\$temp_kernel_decl@H@@YAXH.*"]] = external dso_local constant ptr, align 8
 extern "C" __global__ void ckernel() {}
 
 namespace ns {
@@ -43,6 +44,9 @@ __global__ void kernelfunc() {}
 
 __global__ void kernel_decl();
 
+template<class T>
+__global__ void temp_kernel_decl(T x);
+
 extern "C" void (*kernel_ptr)();
 extern "C" void *void_ptr;
 
@@ -69,13 +73,16 @@ extern "C" void launch(void *kern);
 // CHECK: call void @[[NSSTUB]]()
 // CHECK: call void @[[TSTUB]]()
 // GNU: call void @[[DSTUB:_Z26__device_stub__kernel_declv]]()
+// GNU: call void @[[TDSTUB:_Z31__device_stub__temp_kernel_declIiEvT_]](
 // MSVC: call void @[[DSTUB:"\?__device_stub__kernel_decl@@YAXXZ"]]()
+// MSVC: call void @[[TDSTUB:"\?\?\$__device_stub__temp_kernel_decl@H@@YAXH@Z"]](
 
 extern "C" void fun1(void) {
   ckernel<<<1, 1>>>();
   ns::nskernel<<<1, 1>>>();
   kernelfunc<int><<<1, 1>>>();
   kernel_decl<<<1, 1>>>();
+  temp_kernel_decl<<<1, 1>>>(1);
 }
 
 // Template kernel stub functions
@@ -86,6 +93,7 @@ extern "C" void fun1(void) {
 // Check declaration of stub function for external kernel.
 
 // CHECK: declare{{.*}}@[[DSTUB]]
+// CHECK: declare{{.*}}@[[TDSTUB]]
 
 // Check kernel handle is used for passing the kernel as a function pointer.
 
@@ -94,11 +102,13 @@ extern "C" void fun1(void) {
 // CHECK: call void @launch({{.*}}[[HNSKERN]]
 // CHECK: call void @launch({{.*}}[[HTKERN]]
 // CHECK: call void @launch({{.*}}[[HDKERN]]
+// CHECK: call void @launch({{.*}}[[HTDKERN]]
 extern "C" void fun2() {
   launch((void *)ckernel);
   launch((void *)ns::nskernel);
   launch((void *)kernelfunc<int>);
   launch((void *)kernel_decl);
+  launch((void *)temp_kernel_decl<int>);
 }
 
 // Check kernel handle is used for assigning a kernel to a function pointer.
@@ -148,3 +158,4 @@ extern "C" void fun5() {
 // CHECK: call{{.*}}@__hipRegisterFunction{{.*}}@[[HTKERN]]{{.*}}@[[TKERN]]
 // NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}__device_stub
 // NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}kernel_decl
+// NEG-NOT: call{{.*}}@__hipRegisterFunction{{.*}}temp_kernel_decl

@@ -43,6 +44,9 @@ __global__ void kernelfunc() {}

__global__ void kernel_decl();

template<class T>
__global__ void temp_kernel_decl(T x);
Member

Nit: rename temp -> template? temp is strongly associated with 'temporary'.

Collaborator Author

will do

@yxsamliu yxsamliu merged commit d7e1932 into llvm:main Sep 14, 2023
2 checks passed
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Oct 13, 2023
Change-Id: Ib7e560c95fdc7c46e3799074af23e008f35d9430