Skip to content

Commit

Permalink
CUDA: Compile modules with debug one at a time with NVVM
Browse files Browse the repository at this point in the history
Includes:

- Calls `llvm_to_ptx` once for each IR module for debug.
- Don't adjust linkage of functions in linked modules when debugging,
  because we need device functions to be externally visible.
- Fixed setting of NVVM options when calling `compile_cuda` from kernel
  compilation and device function template compilation.
- Removes debug_pubnames patch

Outcomes:

- The "Error: Debugging support cannot be enabled when number of debug
  compile units is more than 1" message is no longer produced with NVVM
  3.4.
- CUDA test suite passes, apart from those tests that check PTX, because
  get_asm_str() is returning a list of strings when debug is True.
- NVVM 7.0: Everything still seems to "work" as much as it did before.
  Stepping may be more stable, but this needs a bit more verification
  (could just be my late night perception).

Testing outside the test suite:

- Reproducers from Issue numba#5311 in the post description, and from
  c200chromebook.
- The code posted in Discourse thread 449, with debug=True, opt=0 added.

These will need to be made into appropriate test cases - they exposed
some problems with the linkage.
  • Loading branch information
gmarkall committed Mar 18, 2021
1 parent 1f9d320 commit ce840bc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
17 changes: 13 additions & 4 deletions numba/cuda/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,12 @@ def get_asm_str(self, cc=None):
options['arch'] = arch

irs = [str(mod) for mod in self.modules]
ptx = nvvm.llvm_to_ptx(irs, **options)
ptx = ptx.decode().strip('\x00').strip()
if options.get('debug', False):
ptx = [nvvm.llvm_to_ptx(ir, **options) for ir in irs]
ptx = [x.decode().strip('\x00').strip() for x in ptx]
else:
ptx = nvvm.llvm_to_ptx(irs, **options)
ptx = ptx.decode().strip('\x00').strip()

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % self._name).center(80, '-'))
Expand All @@ -136,7 +140,11 @@ def get_cubin(self, cc=None):

ptx = self.get_asm_str(cc=cc)
linker = driver.Linker(max_registers=self._max_registers, cc=cc)
linker.add_ptx(ptx.encode())
if self._nvvm_options.get('debug', False):
for x in ptx:
linker.add_ptx(x.encode())
else:
linker.add_ptx(ptx.encode())
for path in self._linking_files:
linker.add_file_guess_ext(path)
cubin_buf, size = linker.complete()
Expand Down Expand Up @@ -232,7 +240,8 @@ def finalize(self):
# https://github.com/numba/numba/pull/890
for library in self._linking_libraries:
for fn in library._module.functions:
if not fn.is_declaration:
if (not fn.is_declaration and
not self._nvvm_options.get('debug', False)):
fn.linkage = 'linkonce_odr'

self._finalized = True
Expand Down
26 changes: 18 additions & 8 deletions numba/cuda/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,15 @@ def compile(self, args):
Returns the `CompileResult`.
"""
if args not in self._compileinfos:
# Repro for this fix: Issue #5311 reproducer
nvvm_options = {
'debug': self.debug,
# XXX TBC 'fastmath': fastmath, BUG!!! (test case seems to
# check this)
'opt': 3 if self.opt else 0
}
cres = compile_cuda(self.py_func, None, args, debug=self.debug,
inline=self.inline)
inline=self.inline, nvvm_options=nvvm_options)
first_definition = not self._compileinfos
self._compileinfos[args] = cres
libs = [cres.library]
Expand Down Expand Up @@ -467,19 +474,22 @@ def __init__(self, py_func, argtypes, link=None, debug=False, inline=False,
self.debug = debug
self.extensions = extensions or []

cres = compile_cuda(self.py_func, types.void, self.argtypes,
debug=self.debug,
inline=inline,
fastmath=fastmath)
fname = cres.fndesc.llvm_func_name
args = cres.signature.args

# Repro / test case for this fix - kernel that calls power - see e.g.
# discourse thread 449
nvvm_options = {
'debug': self.debug,
'fastmath': fastmath,
'opt': 3 if opt else 0
}

cres = compile_cuda(self.py_func, types.void, self.argtypes,
debug=self.debug,
inline=inline,
fastmath=fastmath,
nvvm_options=nvvm_options)
fname = cres.fndesc.llvm_func_name
args = cres.signature.args

tgt_ctx = cres.target_context
filename = cres.type_annotation.filename
linenum = int(cres.type_annotation.linenum)
Expand Down
1 change: 1 addition & 0 deletions numba/cuda/cudadrv/nvvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ def llvm_to_ptx(llvmir, **opts):
cu.lazy_add_module(libdevice.get())

ptx = cu.compile(**opts)
return ptx
# XXX remove debug_pubnames seems to be necessary sometimes
return patch_ptx_debug_pubnames(ptx)

Expand Down

0 comments on commit ce840bc

Please sign in to comment.