From 052b1ed2fd182115fe39abfb36c2615bf5948877 Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Mon, 18 Jul 2016 19:54:56 +0000 Subject: [PATCH] [NVPTX] Force minimum alignment of 4 for byval arguments of device-side functions. Taking address of a byval variable in PTX is legal, but currently runs into miscompilation by ptxas on sm_50+ (NVIDIA issue 1789042). Work around the issue by enforcing minimum alignment on byval arguments of device functions. The change is a no-op on SASS level for sm_3x where ptxas already aligns local copy by at least 4. Differential Revision: https://reviews.llvm.org/D22428 llvm-svn: 275893 --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 14 +++++++++++++- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 15 ++++++++++----- llvm/test/CodeGen/NVPTX/param-align.ll | 8 ++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 660016bfcd055..a196b7afd179e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -1589,7 +1589,19 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { unsigned align = PAL.getParamAlignment(paramIndex + 1); if (align == 0) align = DL.getABITypeAlignment(ETy); - + // Work around a bug in ptxas. When PTX code takes address of + // byval parameter with alignment < 4, ptxas generates code to + // spill argument into memory. Alas on sm_50+ ptxas generates + // SASS code that fails with misaligned access. To work around + // the problem, make sure that we align byval parameters by at + // least 4. Matching change must be made in LowerCall() where we + // prepare parameters for the call. + // + // TODO: this will need to be undone when we get to support multi-TU + // device-side compilation as it breaks ABI compatibility with nvcc. + // Hopefully ptxas bug is fixed by then. 
+ if (!isKernelFunc && align < 4) + align = 4; unsigned sz = DL.getTypeAllocSize(ETy); O << "\t.param .align " << align << " .b8 "; printParamName(I, paramIndex, O); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index f28c89cd976ab..dba685548dc89 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1072,6 +1072,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, MachineFunction &MF = DAG.getMachineFunction(); const Function *F = MF.getFunction(); auto &DL = MF.getDataLayout(); + bool isKernel = llvm::isKernelFunction(*F); SDValue tempChain = Chain; Chain = DAG.getCALLSEQ_START(Chain, @@ -1337,11 +1338,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // The ByValAlign in the Outs[OIdx].Flags is alway set at this point, // so we don't need to worry about natural alignment or not. // See TargetLowering::LowerCallTo(). - SDValue DeclareParamOps[] = { - Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), dl, MVT::i32), - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(sz, dl, MVT::i32), InFlag - }; + + // Enforce minimum alignment of 4 to work around ptxas miscompile + // for sm_50+. See corresponding alignment adjustment in + // emitFunctionParamList() for details. 
+ if (!isKernel && ArgAlign < 4) + ArgAlign = 4; + SDValue DeclareParamOps[] = {Chain, DAG.getConstant(ArgAlign, dl, MVT::i32), + DAG.getConstant(paramCount, dl, MVT::i32), + DAG.getConstant(sz, dl, MVT::i32), InFlag}; Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, DeclareParamOps); InFlag = Chain.getValue(1); diff --git a/llvm/test/CodeGen/NVPTX/param-align.ll b/llvm/test/CodeGen/NVPTX/param-align.ll index 84ccb650d40d9..6d58fda59aee3 100644 --- a/llvm/test/CodeGen/NVPTX/param-align.ll +++ b/llvm/test/CodeGen/NVPTX/param-align.ll @@ -23,3 +23,11 @@ define ptx_device void @t3(%struct.float2* byval %x) { ; CHECK: .param .align 4 .b8 t3_param_0[8] ret void } + +;;; Need at least 4-byte alignment in order to avoid miscompilation by +;;; ptxas for sm_50+ +define ptx_device void @t4(i8* byval %x) { +; CHECK: .func t4 +; CHECK: .param .align 4 .b8 t4_param_0[1] + ret void +}