77 changes: 34 additions & 43 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,59 +537,50 @@ void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
raw_ostream &O) const {
// If the NVVM IR has some of reqntid* specified, then output
// the reqntid directive, and set the unspecified ones to 1.
// If none of reqntid* is specified, don't output reqntid directive.
unsigned reqntidx, reqntidy, reqntidz;
bool specified = false;
if (!getReqNTIDx(F, reqntidx))
reqntidx = 1;
else
specified = true;
if (!getReqNTIDy(F, reqntidy))
reqntidy = 1;
else
specified = true;
if (!getReqNTIDz(F, reqntidz))
reqntidz = 1;
else
specified = true;

if (specified)
O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
// If none of Reqntid* is specified, don't output reqntid directive.
unsigned Reqntidx, Reqntidy, Reqntidz;
Reqntidx = Reqntidy = Reqntidz = 1;
bool ReqSpecified = false;
ReqSpecified |= getReqNTIDx(F, Reqntidx);
ReqSpecified |= getReqNTIDy(F, Reqntidy);
ReqSpecified |= getReqNTIDz(F, Reqntidz);

if (ReqSpecified)
O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
<< "\n";

// If the NVVM IR has some of maxntid* specified, then output
// the maxntid directive, and set the unspecified ones to 1.
// If none of maxntid* is specified, don't output maxntid directive.
unsigned maxntidx, maxntidy, maxntidz;
specified = false;
if (!getMaxNTIDx(F, maxntidx))
maxntidx = 1;
else
specified = true;
if (!getMaxNTIDy(F, maxntidy))
maxntidy = 1;
else
specified = true;
if (!getMaxNTIDz(F, maxntidz))
maxntidz = 1;
else
specified = true;

if (specified)
O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
unsigned Maxntidx, Maxntidy, Maxntidz;
Maxntidx = Maxntidy = Maxntidz = 1;
bool MaxSpecified = false;
MaxSpecified |= getMaxNTIDx(F, Maxntidx);
MaxSpecified |= getMaxNTIDy(F, Maxntidy);
MaxSpecified |= getMaxNTIDz(F, Maxntidz);

if (MaxSpecified)
O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
<< "\n";

unsigned mincta;
if (getMinCTASm(F, mincta))
O << ".minnctapersm " << mincta << "\n";
unsigned Mincta = 0;
if (getMinCTASm(F, Mincta))
O << ".minnctapersm " << Mincta << "\n";

unsigned Maxnreg = 0;
if (getMaxNReg(F, Maxnreg))
O << ".maxnreg " << Maxnreg << "\n";

unsigned maxnreg;
if (getMaxNReg(F, maxnreg))
O << ".maxnreg " << maxnreg << "\n";
// .maxclusterrank directive requires SM_90 or higher, make sure that we
// filter it out for lower SM versions, as it causes a hard ptxas crash.
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
unsigned Maxclusterrank = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to ignore this directive if the metadata exists, but we're targeting a pre-sm_90 GPU?

It may be useful for non-clang LLVM users (e.g XLA) to be able to always specify launch bounds metadata, and let LLVM decide on what it can do with it. Generating the directive for older GPUs would result in ptxas error, while ignoring it would still allow the kernels to compile and work, the same as would be the case if the metadata was correctly absent. I don't think there's not much point to require users to jump through more hoops just to achieve exactly the same result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, ptxas reacts to a sample with .maxclusterrank with pre SM_90 with a hard error:

ptxas --gpu-name sm_75 --output-file cluster_rank.o cluster_rank.s

ptxas cluster_rank.s, line 18; error   : Feature '.maxclusterrank' requires .target sm_90 or higher
ptxas fatal   : Ptx assembly aborted due to errors

Do I understand you right, that you'd like to see a check similar to what we do in SemaDeclAttr and filter out the directive on targets < SM_90?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not have a good way to issue any diagnostics from LLVM, so the choice would be to either reject the IR as invalid, or make an effort to compile to valid PTX. Right now we're neither here nor there.

I'd be fine with either of the options above. That said, ignoring metadata which we can't apply seems OK to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've talked to @alinas who has more experience dealing with IR and she also thinks that ignoring maxclusterrank metadata on older GPUs is the right choice here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done in: 261840a

if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
O << ".maxclusterrank " << Maxclusterrank << "\n";
}

std::string
NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
const TargetRegisterClass *RC = MRI->getRegClass(Reg);

std::string Name;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ bool getMaxNTIDz(const Function &F, unsigned &z) {
return findOneNVVMAnnotation(&F, "maxntidz", z);
}

bool getMaxClusterRank(const Function &F, unsigned &x) {
return findOneNVVMAnnotation(&F, "maxclusterrank", x);
}

bool getReqNTIDx(const Function &F, unsigned &x) {
return findOneNVVMAnnotation(&F, "reqntidx", x);
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ bool getReqNTIDx(const Function &, unsigned &);
bool getReqNTIDy(const Function &, unsigned &);
bool getReqNTIDz(const Function &, unsigned &);

bool getMaxClusterRank(const Function &, unsigned &);
bool getMinCTASm(const Function &, unsigned &);
bool getMaxNReg(const Function &, unsigned &);
bool isKernelFunction(const Function &);
Expand Down
26 changes: 26 additions & 0 deletions llvm/test/CodeGen/NVPTX/maxclusterrank.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s --check-prefixes=CHECK,CHECK_SM_90
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK_SM_80

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-unknown-unknown"

; CHECK: .maxntid 128, 1, 1
; CHECK: .minnctapersm 2
; CHECK_SM_90: .maxclusterrank 8
; CHECK_SM_80-NOT: .maxclusterrank 8

; Make sure that for SM version prior to 90 `.maxclusterrank` directive is
; sielently ignored.
define dso_local void @_Z18TestMaxClusterRankv() {
entry:
%a = alloca i32, align 4
store volatile i32 1, ptr %a, align 4
ret void
}

!nvvm.annotations = !{!0, !1, !2, !3}

!0 = !{ptr @_Z18TestMaxClusterRankv, !"kernel", i32 1}
!1 = !{ptr @_Z18TestMaxClusterRankv, !"maxntidx", i32 128}
!2 = !{ptr @_Z18TestMaxClusterRankv, !"minctasm", i32 2}
!3 = !{ptr @_Z18TestMaxClusterRankv, !"maxclusterrank", i32 8}