CUDA: Fix source location on kernel entry and enable breakpoints to be set on kernels by mangled name #6841

Merged
merged 7 commits on Jun 23, 2021
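To see what the fix enables end to end, here is a hedged sketch (illustrative, not taken from this PR): compile a kernel with `debug=True`, then break on it in cuda-gdb by its mangled name. The kernel `axpy`, the launch configuration, and the elided mangled name are all assumptions for illustration.

```python
# Illustrative sketch: a debug-compiled kernel that cuda-gdb should now be
# able to stop in by mangled name, with a correct source location at entry.
import numpy as np
from numba import cuda

@cuda.jit(debug=True)
def axpy(r, a, x, y):
    i = cuda.grid(1)
    if i < r.size:
        r[i] = a * x[i] + y[i]

n = 16
r = np.zeros(n)
axpy[1, n](r, 2.0, np.ones(n), np.ones(n))

# In a cuda-gdb session, a breakpoint on the kernel's mangled name (elided
# here; it depends on the argument types) should bind at the def line:
#   (cuda-gdb) break _ZN6cudapy8__main__4axpy...
```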
40 changes: 20 additions & 20 deletions numba/core/debuginfo.py
@@ -12,19 +12,19 @@

class AbstractDIBuilder(metaclass=abc.ABCMeta):
@abc.abstractmethod
def mark_variable(self, builder, allocavalue, name, lltype, size, loc):
def mark_variable(self, builder, allocavalue, name, lltype, size, line):
"""Emit debug info for the variable.
"""
pass

@abc.abstractmethod
def mark_location(self, builder, loc):
def mark_location(self, builder, line):
"""Emit source location information to the given IRBuilder.
"""
pass

@abc.abstractmethod
def mark_subprogram(self, function, name, loc):
def mark_subprogram(self, function, name, line):
"""Emit source location information for the given function.
"""
pass
@@ -47,13 +47,13 @@ class DummyDIBuilder(AbstractDIBuilder):
def __init__(self, module, filepath):
pass

def mark_variable(self, builder, allocavalue, name, lltype, size, loc):
def mark_variable(self, builder, allocavalue, name, lltype, size, line):
pass

def mark_location(self, builder, loc):
def mark_location(self, builder, line):
pass

def mark_subprogram(self, function, name, loc):
def mark_subprogram(self, function, name, line):
pass

def initialize(self):
@@ -117,7 +117,7 @@ def _var_type(self, lltype, size):
})
return mdtype

def mark_variable(self, builder, allocavalue, name, lltype, size, loc):
def mark_variable(self, builder, allocavalue, name, lltype, size, line):
m = self.module
fnty = ir.FunctionType(ir.VoidType(), [ir.MetaDataType()] * 3)
decl = cgutils.get_or_insert_function(m, fnty, 'llvm.dbg.declare')
@@ -129,19 +129,19 @@ def mark_variable(self, builder, allocavalue, name, lltype, size, loc):
'arg': 0,
'scope': self.subprograms[-1],
'file': self.difile,
'line': loc.line,
'line': line,
'type': mdtype,
})
mdexpr = m.add_debug_info('DIExpression', {})

return builder.call(decl, [allocavalue, mdlocalvar, mdexpr])

def mark_location(self, builder, loc):
builder.debug_metadata = self._add_location(loc.line)
def mark_location(self, builder, line):
builder.debug_metadata = self._add_location(line)

def mark_subprogram(self, function, name, loc):
def mark_subprogram(self, function, name, line):
di_subp = self._add_subprogram(name=name, linkagename=function.name,
line=loc.line)
line=line)
function.set_metadata("dbg", di_subp)
# disable inlining for this function for easier debugging
function.attributes.add('noinline')
@@ -274,26 +274,26 @@ class NvvmDIBuilder(DIBuilder):
# Used in mark_location to remember last lineno to avoid duplication
_last_lineno = None

def mark_variable(self, builder, allocavalue, name, lltype, size, loc):
def mark_variable(self, builder, allocavalue, name, lltype, size, line):
# unsupported
pass

def mark_location(self, builder, loc):
def mark_location(self, builder, line):
# Avoid duplication
if self._last_lineno == loc.line:
if self._last_lineno == line:
return
self._last_lineno = loc.line
self._last_lineno = line
# Add call to an inline asm to mark line location
asmty = ir.FunctionType(ir.VoidType(), [])
asm = ir.InlineAsm(asmty, "// dbg {}".format(loc.line), "",
asm = ir.InlineAsm(asmty, "// dbg {}".format(line), "",
side_effect=True)
call = builder.call(asm, [])
md = self._di_location(loc.line)
md = self._di_location(line)
call.set_metadata('numba.dbg', md)

def mark_subprogram(self, function, name, loc):
def mark_subprogram(self, function, name, line):
self._add_subprogram(name=name, linkagename=function.name,
line=loc.line)
line=line)

#
# Helper methods to create the metadata nodes.
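The signature change above (``loc`` to ``line``) exists because the CUDA kernel wrapper (see `numba/cuda/target.py` below) has no Numba `Loc` object, only a filename and an int line number taken from the Python code object. As background, here is a minimal llvmlite sketch of the kind of `DISubprogram` that `mark_subprogram` attaches: it records both the source-level name and the mangled linkage name, which is what lets a debugger match a breakpoint set on the mangled symbol. This is a sketch under stated assumptions, not Numba's exact metadata; a real setup also needs the `DICompileUnit` and related nodes that `DIBuilder.initialize` emits.

```python
# Minimal sketch (llvmlite only): attach a DISubprogram carrying both the
# source name and the mangled linkage name to a function definition.
from llvmlite import ir

module = ir.Module(name='example')
fn = ir.Function(module, ir.FunctionType(ir.VoidType(), []),
                 name='_ZN6cudapy6kernelEv')  # illustrative mangled name
builder = ir.IRBuilder(fn.append_basic_block('entry'))
builder.ret_void()

difile = module.add_debug_info('DIFile', {
    'filename': 'kernel.py',
    'directory': '.',
})
disubp = module.add_debug_info('DISubprogram', {
    'name': 'kernel',         # what the user wrote
    'linkageName': fn.name,   # what the debugger matches breakpoints on
    'file': difile,
    'line': 10,               # the plain int the refactored API takes
    'isDefinition': True,
}, is_distinct=True)
fn.set_metadata('dbg', disubp)
print(module)  # the definition now carries a !dbg attachment
```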
6 changes: 3 additions & 3 deletions numba/core/lowering.py
@@ -87,7 +87,7 @@ def pre_lower(self):
self.pyapi = None
self.debuginfo.mark_subprogram(function=self.builder.function,
name=self.fndesc.qualname,
loc=self.func_ir.loc)
line=self.func_ir.loc.line)

def post_lower(self):
"""
@@ -316,7 +316,7 @@ def post_block(self, block):

def lower_inst(self, inst):
# Set debug location for all subsequent LL instructions
self.debuginfo.mark_location(self.builder, self.loc)
self.debuginfo.mark_location(self.builder, self.loc.line)
self.debug_print(str(inst))
if isinstance(inst, ir.Assign):
ty = self.typeof(inst.target.name)
@@ -1313,7 +1313,7 @@ def alloca_lltype(self, name, lltype):
sizeof = self.context.get_abi_sizeof(lltype)
self.debuginfo.mark_variable(self.builder, aptr, name=name,
lltype=lltype, size=sizeof,
loc=self.loc)
line=self.loc.line)
return aptr

def incref(self, typ, val):
50 changes: 39 additions & 11 deletions numba/cuda/codegen.py
@@ -98,31 +98,51 @@ def get_llvm_str(self):
return str(self._module)

def get_asm_str(self, cc=None):
return self._join_ptxes(self._get_ptxes(cc=cc))

def _get_ptxes(self, cc=None):
if not cc:
ctx = devices.get_context()
device = ctx.device
cc = device.compute_capability

ptx = self._ptx_cache.get(cc, None)
if ptx:
return ptx
ptxes = self._ptx_cache.get(cc, None)
if ptxes:
return ptxes

arch = nvvm.get_arch_option(*cc)
options = self._nvvm_options.copy()
options['arch'] = arch

irs = [str(mod) for mod in self.modules]
ptx = nvvm.llvm_to_ptx(irs, **options)
ptx = ptx.decode().strip('\x00').strip()

if options.get('debug', False):
# If we're compiling with debug, we need to compile modules with
# NVVM one at a time, because it does not support multiple modules
# with debug enabled:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#source-level-debugging-support
ptxes = [nvvm.llvm_to_ptx(ir, **options) for ir in irs]
else:
# Otherwise, we compile all modules with NVVM at once because this
# results in better optimization than separate compilation.
ptxes = [nvvm.llvm_to_ptx(irs, **options)]

# Sometimes the result from NVVM contains trailing whitespace and
# nulls, which we strip so that the assembly dump looks a little
# tidier.
ptxes = [x.decode().strip('\x00').strip() for x in ptxes]

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % self._name).center(80, '-'))
print(ptx)
print(self._join_ptxes(ptxes))
print('=' * 80)

self._ptx_cache[cc] = ptx
self._ptx_cache[cc] = ptxes

return ptxes

return ptx
def _join_ptxes(self, ptxes):
return "\n\n".join(ptxes)
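Distilled, the branch above chooses between per-module and joint compilation: NVVM's source-level debugging support does not allow linking multiple modules with debug info, while joint compilation optimizes across module boundaries. A self-contained sketch with a stand-in for `nvvm.llvm_to_ptx` (hypothetical, since the real call needs a CUDA toolkit):

```python
# Sketch of the debug/release compile split; `llvm_to_ptx` is a stand-in
# callable here, not the real nvvm.llvm_to_ptx.
def compile_ptxes(irs, llvm_to_ptx, debug=False):
    if debug:
        # One PTX per module: NVVM rejects multi-module debug compilation.
        return [llvm_to_ptx(ir) for ir in irs]
    # One joint compilation: better optimization across modules.
    return [llvm_to_ptx(irs)]

fake = lambda src: 'ptx<%r>' % (src,)
print('\n\n'.join(compile_ptxes(['mod_a', 'mod_b'], fake, debug=True)))
```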

def get_cubin(self, cc=None):
if cc is None:
@@ -134,11 +154,14 @@ def get_cubin(self, cc=None):
if cubin:
return cubin

ptx = self.get_asm_str(cc=cc)
linker = driver.Linker(max_registers=self._max_registers, cc=cc)
linker.add_ptx(ptx.encode())

ptxes = self._get_ptxes(cc=cc)
for ptx in ptxes:
linker.add_ptx(ptx.encode())
for path in self._linking_files:
linker.add_file_guess_ext(path)

cubin_buf, size = linker.complete()

# We take a copy of the cubin because it's owned by the linker
@@ -230,9 +253,14 @@ def finalize(self):
#
# See also discussion on PR #890:
# https://github.com/numba/numba/pull/890
#
# We don't adjust the linkage of functions when compiling for debug -
# because the device functions are in separate modules, we need them to
# be externally visible.
for library in self._linking_libraries:
for fn in library._module.functions:
if not fn.is_declaration:
if (not fn.is_declaration and
not self._nvvm_options.get('debug', False)):
fn.linkage = 'linkonce_odr'

self._finalized = True
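Read as a predicate, the new condition is: demote a function to `linkonce_odr` only if it is a definition and we are not compiling for debug. A hypothetical restatement (not code from the PR):

```python
# Hypothetical helper restating the linkage rule in finalize(): definitions
# are deduplicated with linkonce_odr, except under debug, where device
# functions live in separate modules and must stay externally visible.
def adjust_linkage(fn, debug):
    if not fn.is_declaration and not debug:
        fn.linkage = 'linkonce_odr'
```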
27 changes: 17 additions & 10 deletions numba/cuda/compiler.py
@@ -227,9 +227,11 @@ def compile_ptx(pyfunc, args, debug=False, lineinfo=False, device=False,
else:
fname = cres.fndesc.llvm_func_name
tgt = cres.target_context
filename = cres.type_annotation.filename
linenum = int(cres.type_annotation.linenum)
lib, kernel = tgt.prepare_cuda_kernel(cres.library, fname,
cres.signature.args, debug,
nvvm_options)
nvvm_options, filename, linenum)

cc = cc or config.CUDA_DEFAULT_PTX_CC
ptx = lib.get_asm_str(cc=cc)
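As a usage note, this path is reachable through the public `compile_ptx` API; a hedged example follows (the kernel and the inspection are illustrative, and assume the Numba version carrying this PR):

```python
# Hypothetical check that debug PTX now carries source locations for the
# kernel entry: look for .file/.loc directives in the output.
from numba import float32
from numba.cuda import compile_ptx

def add(r, x, y):
    r[0] = x[0] + y[0]

ptx, resty = compile_ptx(add, (float32[:], float32[:], float32[:]),
                         debug=True)
print(ptx)
```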
@@ -311,8 +313,8 @@ def compile(self, args):
"""
if args not in self.overloads:
nvvm_options = {
'opt': 3 if self.opt else 0,
'debug': self.debug,
'opt': 3 if self.opt else 0
}

cres = compile_cuda(self.py_func, None, args, debug=self.debug,
@@ -535,24 +537,29 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
self.lineinfo = lineinfo
self.extensions = extensions or []

cres = compile_cuda(self.py_func, types.void, self.argtypes,
debug=self.debug,
lineinfo=self.lineinfo,
inline=inline,
fastmath=fastmath)
fname = cres.fndesc.llvm_func_name
args = cres.signature.args

nvvm_options = {
'debug': self.debug,
'lineinfo': self.lineinfo,
'fastmath': fastmath,
'opt': 3 if opt else 0
}

cres = compile_cuda(self.py_func, types.void, self.argtypes,
debug=self.debug,
lineinfo=self.lineinfo,
inline=inline,
fastmath=fastmath,
nvvm_options=nvvm_options)
fname = cres.fndesc.llvm_func_name
args = cres.signature.args

tgt_ctx = cres.target_context
code = self.py_func.__code__
filename = code.co_filename
linenum = code.co_firstlineno
lib, kernel = tgt_ctx.prepare_cuda_kernel(cres.library, fname, args,
debug, nvvm_options,
filename, linenum,
max_registers)

if not link:
23 changes: 1 addition & 22 deletions numba/cuda/cudadrv/nvvm.py
@@ -708,28 +708,7 @@ def llvm_to_ptx(llvmir, **opts):
cu.add_module(mod.encode('utf8'))
cu.lazy_add_module(libdevice.get())

ptx = cu.compile(**opts)
# XXX remove debug_pubnames seems to be necessary sometimes
return patch_ptx_debug_pubnames(ptx)


def patch_ptx_debug_pubnames(ptx):
"""
Patch PTX to workaround .debug_pubnames NVVM error::

ptxas fatal : Internal error: overlapping non-identical data

"""
while True:
# Repeatedly remove debug_pubnames sections
start = ptx.find(b'.section .debug_pubnames')
if start < 0:
break
stop = ptx.find(b'}', start)
if stop < 0:
raise ValueError('missing "}"')
ptx = ptx[:start] + ptx[stop + 1:]
return ptx
return cu.compile(**opts)


re_metadata_def = re.compile(r"\!\d+\s*=")
17 changes: 14 additions & 3 deletions numba/cuda/target.py
@@ -126,7 +126,8 @@ def mangler(self, name, argtypes):
return itanium_mangler.mangle(name, argtypes)

def prepare_cuda_kernel(self, codelib, func_name, argtypes, debug,
nvvm_options, max_registers=None):
nvvm_options, filename, linenum,
max_registers=None):
"""
Adapt a code library ``codelib`` containing the Numba-compiled CUDA kernel
named ``func_name``, with argument types ``argtypes``, for NVVM.
@@ -143,6 +144,8 @@ def prepare_cuda_kernel(self, codelib, func_name, argtypes, debug,
argtypes: An iterable of the types of arguments to the kernel.
debug: Whether to compile with debug.
nvvm_options: Dict of NVVM options used when compiling the new library.
filename: The source filename that the function is contained in.
linenum: The source line that the function is on.
Comment on lines +147 to +148

Contributor: What happens if someone materialized a function from a string?

Member (author): filename is <string> and linenum is the line in the string on which the function began - I think this is normal / expected?
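The behaviour described in the reply is easy to check outside Numba (a standalone demo, not part of the PR):

```python
# A function materialized from a string reports '<string>' as co_filename,
# and co_firstlineno is its line within that string.
src = "def f():\n    return 1\n"
namespace = {}
exec(compile(src, '<string>', 'exec'), namespace)
f = namespace['f']
print(f.__code__.co_filename)     # <string>
print(f.__code__.co_firstlineno)  # 1
```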

max_registers: The max_registers argument for the code library.
"""
kernel_name = itanium_mangler.prepend_namespace(func_name, ns='cudapy')
@@ -152,11 +155,12 @@ def prepare_cuda_kernel(self, codelib, func_name, argtypes, debug,
max_registers=max_registers)
library.add_linking_library(codelib)
wrapper = self.generate_kernel_wrapper(library, kernel_name, func_name,
argtypes, debug)
argtypes, debug, filename,
linenum)
return library, wrapper

def generate_kernel_wrapper(self, library, kernel_name, func_name,
argtypes, debug):
argtypes, debug, filename, linenum):
"""
Generate the kernel wrapper in the given ``library``.
The function being wrapped has the name ``func_name`` and argument types
@@ -175,6 +179,11 @@ def generate_kernel_wrapper(self, library, kernel_name, func_name,
wrapfn = ir.Function(wrapper_module, wrapfnty, prefixed)
builder = ir.IRBuilder(wrapfn.append_basic_block(''))

if debug:
debuginfo = self.DIBuilder(module=wrapper_module, filepath=filename)
debuginfo.mark_subprogram(wrapfn, kernel_name, linenum)
debuginfo.mark_location(builder, linenum)

# Define error handling variables
def define_error_gv(postfix):
name = wrapfn.name + postfix
@@ -234,6 +243,8 @@ def define_error_gv(postfix):

nvvm.set_cuda_kernel(wrapfn)
library.add_ir_module(wrapper_module)
if debug:
debuginfo.finalize()
library.finalize()
wrapfn = library.get_function(wrapfn.name)
return wrapfn