Skip to content
Cannot retrieve contributors at this time
158 lines (132 sloc) 6.42 KB
from warnings import warn

from numba.core import types, config, sigutils
from numba.core.errors import NumbaDeprecationWarning

from .compiler import (compile_device, declare_device_function, Dispatcher,
                       compile_device_template)
from .simulator.kernel import FakeCUDAKernel
# Warning text template emitted when a deprecated keyword argument is passed;
# ``{0}`` is replaced with the name of the offending keyword.
_msg_deprecated_signature_arg = (
    "Deprecated keyword argument `{0}`. "
    "Signatures should be passed as the first positional argument."
)
def jitdevice(func, link=[], debug=None, inline=False, opt=True):
    """Wrapper for device-jit.

    :param func: The Python function to compile as a device function.
    :param link: Must be empty; a non-empty value is rejected.
    :param debug: Debug flag forwarded to the compiler. When ``None`` the
       value of ``config.CUDA_DEBUGINFO_DEFAULT`` is used instead.
    :param inline: Forwarded unchanged to ``compile_device_template``.
    :param opt: Forwarded unchanged to ``compile_device_template``.
    :return: Whatever ``compile_device_template`` returns for ``func``.
    :raises ValueError: If ``link`` is non-empty.
    """
    # ``None`` means "not specified": fall back to the configured default.
    debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
    # ``link`` is only read, never mutated, so the shared default list is safe.
    if link:
        raise ValueError("link keyword invalid for device function")
    return compile_device_template(func, debug=debug, inline=inline, opt=opt)
def jit(func_or_sig=None, argtypes=None, device=False, inline=False,
        link=[], debug=None, opt=True, **kws):
    """
    JIT compile a python function conforming to the CUDA Python specification.
    If a signature is supplied, then a function is returned that takes a
    function to compile.

    :param func_or_sig: A function to JIT compile, or a signature of a function
       to compile. If a function is supplied, then a
       :class:`numba.cuda.compiler.AutoJitCUDAKernel` is returned. If a
       signature is supplied, then a function is returned. The returned
       function accepts another function, which it will compile and then return
       a :class:`numba.cuda.compiler.AutoJitCUDAKernel`.

       .. note:: A kernel cannot have any return value.
    :param argtypes: (Deprecated) Argument types; combined with the optional
       ``restype`` keyword to build a signature when ``func_or_sig`` is
       ``None``. Passing it emits a ``NumbaDeprecationWarning``.
    :param device: Indicates whether this is a device function.
    :type device: bool
    :param inline: Forwarded to ``compile_device`` when a device function is
       compiled with an explicit signature.
    :param bind: (Deprecated) Force binding to CUDA context immediately.
       Passing it emits a ``NumbaDeprecationWarning``.
    :type bind: bool
    :param link: A list of files containing PTX source to link with the function
    :type link: list
    :param debug: If True, check for exceptions thrown when executing the
       kernel. Since this degrades performance, this should only be used for
       debugging purposes.  Defaults to False.  (The default value can be
       overridden by setting environment variable ``NUMBA_CUDA_DEBUGINFO=1``.)
    :param fastmath: If true, enables flush-to-zero and fused-multiply-add,
       disables precise division and square root. This parameter has no effect
       on device function, whose fastmath setting depends on the kernel function
       from which they are called.
    :param max_registers: Limit the kernel to using at most this number of
       registers per thread. Useful for increasing occupancy.
    :param opt: Whether to compile from LLVM IR to PTX with optimization
       enabled. When ``True``, ``-opt=3`` is passed to NVVM. When
       ``False``, ``-opt=0`` is passed to NVVM. Defaults to ``True``.
    :type opt: bool
    :raises NotImplementedError: If ``link`` is given while the simulator is
       enabled, or if ``boundscheck`` is requested.
    :raises ValueError: If a list of signatures is given, or ``func_or_sig``
       is neither a function, a signature, nor ``None``.
    :raises TypeError: If a kernel signature has a non-void return type.
    """
    # ``None`` means "not specified": fall back to the configured default.
    debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug

    if link and config.ENABLE_CUDASIM:
        raise NotImplementedError('Cannot link PTX in the simulator')

    if kws.get('boundscheck'):
        raise NotImplementedError("bounds checking is not supported for CUDA")

    if argtypes is not None:
        msg = _msg_deprecated_signature_arg.format('argtypes')
        warn(msg, category=NumbaDeprecationWarning)

    if 'bind' in kws:
        msg = _msg_deprecated_signature_arg.format('bind')
        warn(msg, category=NumbaDeprecationWarning)
    # NOTE(review): ``bind`` is assigned unconditionally so it is always
    # defined for the Dispatcher calls below. The reviewed text lost its
    # indentation, so whether this originally sat under an ``else:`` cannot
    # be determined - confirm. A caller-supplied ``bind`` stays in ``kws``.
    bind = True

    fastmath = kws.get('fastmath', False)

    if argtypes is None and not sigutils.is_signature(func_or_sig):
        if func_or_sig is None:
            # Called as ``@jit(...)`` with options only: return a decorator.
            # NOTE(review): the simulator guard is reconstructed from the two
            # consecutive ``autojitwrapper`` definitions - confirm.
            if config.ENABLE_CUDASIM:
                def autojitwrapper(func):
                    return FakeCUDAKernel(func, device=device,
                                          fastmath=fastmath, debug=debug)
            else:
                def autojitwrapper(func):
                    return jit(func, device=device, debug=debug, opt=opt,
                               **kws)

            return autojitwrapper

        # func_or_sig is a function
        # NOTE(review): the ``if config.ENABLE_CUDASIM`` head of this chain is
        # reconstructed (the text shows a bare ``elif device:``) - confirm.
        if config.ENABLE_CUDASIM:
            return FakeCUDAKernel(func_or_sig, device=device,
                                  fastmath=fastmath, debug=debug)
        elif device:
            return jitdevice(func_or_sig, debug=debug, opt=opt, **kws)
        else:
            targetoptions = kws.copy()
            targetoptions['debug'] = debug
            targetoptions['opt'] = opt
            targetoptions['link'] = link
            sigs = None
            # NOTE(review): the tail of this call was truncated in the
            # reviewed text; ``targetoptions`` is assumed to be the missing
            # argument since it is built immediately above - confirm.
            return Dispatcher(func_or_sig, sigs, bind=bind,
                              targetoptions=targetoptions)

    # From here on a signature (or the deprecated argtypes form) was given,
    # so the result is always a decorator awaiting the function.
    # NOTE(review): simulator guard reconstructed, as above - confirm.
    if config.ENABLE_CUDASIM:
        def jitwrapper(func):
            # NOTE(review): truncated call; ``debug=debug`` restored to match
            # the other FakeCUDAKernel constructions here - confirm.
            return FakeCUDAKernel(func, device=device, fastmath=fastmath,
                                  debug=debug)
        return jitwrapper

    if isinstance(func_or_sig, list):
        msg = 'Lists of signatures are not yet supported in CUDA'
        raise ValueError(msg)
    elif sigutils.is_signature(func_or_sig):
        sigs = [func_or_sig]
    elif func_or_sig is None:
        # Handle the deprecated argtypes / restype specification
        restype = kws.get('restype', types.void)
        sigs = [restype(*argtypes)]
    else:
        raise ValueError("Expecting signature or list of signatures")

    for sig in sigs:
        restype, argtypes = convert_types(sig, argtypes)

        # Only kernels are restricted to ``void``; device functions may
        # return values.
        if restype and not device and restype != types.void:
            raise TypeError("CUDA kernel must have void return type.")

    def kernel_jit(func):
        targetoptions = kws.copy()
        targetoptions['debug'] = debug
        targetoptions['link'] = link
        targetoptions['opt'] = opt
        # NOTE(review): truncated call; ``targetoptions`` restored - confirm.
        return Dispatcher(func, sigs, bind=bind,
                          targetoptions=targetoptions)

    def device_jit(func):
        # NOTE(review): truncated call; ``debug=debug`` is assumed to be the
        # only missing argument - confirm against ``compile_device``.
        return compile_device(func, restype, argtypes, inline=inline,
                              debug=debug)

    if device:
        return device_jit
    return kernel_jit
def declare_device(name, restype=None, argtypes=None):
    """Declare a device function called ``name``.

    The type specification is first passed through ``convert_types`` (so
    ``restype`` may be a full signature), then handed to
    ``declare_device_function``.

    :param name: Name forwarded to ``declare_device_function``.
    :param restype: Return type, or a signature covering both return and
       argument types.
    :param argtypes: Argument types.
    :return: Whatever ``declare_device_function`` returns.
    """
    resolved_restype, resolved_argtypes = convert_types(restype, argtypes)
    return declare_device_function(name, resolved_restype, resolved_argtypes)
def convert_types(restype, argtypes):
    """Split a signature into its return type and argument types.

    :param restype: A return type, or anything ``sigutils.is_signature``
       recognises as a signature.
    :param argtypes: Argument types; replaced when ``restype`` is a signature.
    :return: A ``(restype, argtypes)`` pair.
    """
    # Nothing to unpack when the caller already supplied separate pieces.
    if not sigutils.is_signature(restype):
        return restype, argtypes

    # eval type string
    # normalize_signature's result is unpacked as (argument types, return type).
    argtypes, restype = sigutils.normalize_signature(restype)
    return restype, argtypes
You can’t perform that action at this time.