from warnings import warn

from numba.core import types, config, sigutils
from numba.core.errors import DeprecationError, NumbaInvalidConfigWarning
from numba.cuda.compiler import declare_device_function
from numba.cuda.dispatcher import CUDADispatcher
from numba.cuda.simulator.kernel import FakeCUDAKernel


_msg_deprecated_signature_arg = ("Deprecated keyword argument `{0}`. "
                                 "Signatures should be passed as the first "
                                 "positional argument.")

def jit(func_or_sig=None, device=False, inline=False, link=None, debug=None,
        opt=True, lineinfo=False, cache=False, **kws):
    """
    JIT compile a Python function for CUDA GPUs.

    :param func_or_sig: A function to JIT compile, or *signatures* of a
       function to compile. If a function is supplied, then a
       :class:`Dispatcher <numba.cuda.dispatcher.CUDADispatcher>` is returned.
       Otherwise, ``func_or_sig`` may be a signature or a list of signatures,
       and a function is returned. The returned function accepts another
       function, which it will compile and then return a :class:`Dispatcher
       <numba.cuda.dispatcher.CUDADispatcher>`. See :ref:`jit-decorator` for
       more information about passing signatures.

       .. note:: A kernel cannot have any return value.
    :param device: Indicates whether this is a device function.
    :type device: bool
    :param link: A list of files containing PTX or CUDA C/C++ source to link
       with the function.
    :type link: list
    :param debug: If True, check for exceptions thrown when executing the
       kernel. Since this degrades performance, this should only be used for
       debugging purposes. If set to True, then ``opt`` should be set to
       False. Defaults to False. (The default value can be overridden by
       setting environment variable ``NUMBA_CUDA_DEBUGINFO=1``.)
    :param fastmath: When True, enables fastmath optimizations as outlined in
       the :ref:`CUDA Fast Math documentation <cuda-fast-math>`.
    :param max_registers: Request that the kernel is limited to using at most
       this number of registers per thread. The limit may not be respected if
       the ABI requires a greater number of registers than that requested.
       Useful for increasing occupancy.
    :param opt: Whether to compile from LLVM IR to PTX with optimization
                enabled. When ``True``, ``-opt=3`` is passed to NVVM. When
                ``False``, ``-opt=0`` is passed to NVVM. Defaults to ``True``.
    :type opt: bool
    :param lineinfo: If True, generate a line mapping between source code and
       assembly code. This enables inspection of the source code in NVIDIA
       profiling tools and correlation with program counter sampling.
    :type lineinfo: bool
    :param cache: If True, enables the file-based cache for this function.
    :type cache: bool
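
    A minimal sketch of both the lazy and eager forms (array sizes, launch
    configuration, and kernel bodies here are illustrative only; running it
    needs a CUDA device or ``NUMBA_ENABLE_CUDASIM=1``):

    .. code-block:: python

       import numpy as np
       from numba import cuda

       # Lazy compilation: specialized on first launch
       @cuda.jit
       def increment(x):
           i = cuda.grid(1)
           if i < x.size:
               x[i] += 1

       # Eager compilation: compiled up front for the given signature
       @cuda.jit("void(float32[:])")
       def scale(x):
           i = cuda.grid(1)
           if i < x.size:
               x[i] *= 2

       arr = np.zeros(256, dtype=np.float32)
       increment[4, 64](arr)   # launch 4 blocks of 64 threads
       scale[4, 64](arr)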
"""
    if link is None:
        link = []

    if link and config.ENABLE_CUDASIM:
        raise NotImplementedError('Cannot link PTX in the simulator')

    if kws.get('boundscheck'):
        raise NotImplementedError("bounds checking is not supported for CUDA")

    if kws.get('argtypes') is not None:
        msg = _msg_deprecated_signature_arg.format('argtypes')
        raise DeprecationError(msg)
    if kws.get('restype') is not None:
        msg = _msg_deprecated_signature_arg.format('restype')
        raise DeprecationError(msg)
    if kws.get('bind') is not None:
        msg = _msg_deprecated_signature_arg.format('bind')
        raise DeprecationError(msg)

    debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
    fastmath = kws.get('fastmath', False)
    extensions = kws.get('extensions', [])

    if debug and opt:
        msg = ("debug=True with opt=True (the default) "
               "is not supported by CUDA. This may result in a crash"
               " - set debug=False or opt=False.")
        warn(NumbaInvalidConfigWarning(msg))

    if debug and lineinfo:
        msg = ("debug and lineinfo are mutually exclusive. Use debug to get "
               "full debug info (this disables some optimizations), or "
               "lineinfo for line info only with code generation unaffected.")
        warn(NumbaInvalidConfigWarning(msg))

    if device and kws.get('link'):
        raise ValueError("link keyword invalid for device function")
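
    # ``func_or_sig`` dispatch, illustrated (the signature strings below are
    # examples only, not requirements):
    #
    #   @cuda.jit("void(float32[:])")     -> single signature, specialized
    #   @cuda.jit(["void(float32[:])",
    #              "void(float64[:])"])   -> list of signatures
    #   @cuda.jit                         -> no signature; compile lazily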
    if sigutils.is_signature(func_or_sig):
        signatures = [func_or_sig]
        specialized = True
    elif isinstance(func_or_sig, list):
        signatures = func_or_sig
        specialized = False
    else:
        signatures = None

    if signatures is not None:
        if config.ENABLE_CUDASIM:
            def jitwrapper(func):
                return FakeCUDAKernel(func, device=device, fastmath=fastmath)
            return jitwrapper
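
        # Eager path: compile for each requested signature now, then disable
        # further compilation so the dispatcher stays specialized.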
        def _jit(func):
            targetoptions = kws.copy()
            targetoptions['debug'] = debug
            targetoptions['lineinfo'] = lineinfo
            targetoptions['link'] = link
            targetoptions['opt'] = opt
            targetoptions['fastmath'] = fastmath
            targetoptions['device'] = device
            targetoptions['extensions'] = extensions

            disp = CUDADispatcher(func, targetoptions=targetoptions)

            if cache:
                disp.enable_caching()

            for sig in signatures:
                argtypes, restype = sigutils.normalize_signature(sig)

                if restype and not device and restype != types.void:
                    raise TypeError("CUDA kernel must have void return type.")

                if device:
                    from numba.core import typeinfer
                    with typeinfer.register_dispatcher(disp):
                        disp.compile_device(argtypes, restype)
                else:
                    disp.compile(argtypes)

            disp._specialized = specialized
            disp.disable_compile()

            return disp

        return _jit
    else:
        if func_or_sig is None:
            if config.ENABLE_CUDASIM:
                def autojitwrapper(func):
                    return FakeCUDAKernel(func, device=device,
                                          fastmath=fastmath)
            else:
                def autojitwrapper(func):
                    return jit(func, device=device, debug=debug, opt=opt,
                               lineinfo=lineinfo, link=link, cache=cache,
                               **kws)

            return autojitwrapper
        # func_or_sig is a function
        else:
            if config.ENABLE_CUDASIM:
                return FakeCUDAKernel(func_or_sig, device=device,
                                      fastmath=fastmath)
            else:
                targetoptions = kws.copy()
                targetoptions['debug'] = debug
                targetoptions['lineinfo'] = lineinfo
                targetoptions['opt'] = opt
                targetoptions['link'] = link
                targetoptions['fastmath'] = fastmath
                targetoptions['device'] = device
                targetoptions['extensions'] = extensions
                disp = CUDADispatcher(func_or_sig,
                                      targetoptions=targetoptions)

                if cache:
                    disp.enable_caching()

                return disp


def declare_device(name, sig):
    """
    Declare the signature of a foreign function. Returns a descriptor that can
    be used to call the function from a Python kernel.

    :param name: The name of the foreign function.
    :type name: str
    :param sig: The Numba signature of the function.
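
    A sketch of typical use, declaring and then calling an externally defined
    function (the name ``mul`` and the linked file are hypothetical; the
    linked source must provide the implementation):

    .. code-block:: python

       from numba import cuda

       mul = cuda.declare_device('mul', 'float32(float32, float32)')

       @cuda.jit(link=['mul.ptx'])
       def kernel(x, y, out):
           i = cuda.grid(1)
           if i < out.size:
               out[i] = mul(x[i], y[i])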
"""
argtypes, restype = sigutils.normalize_signature(sig)
if restype is None:
msg = 'Return type must be provided for device declarations'
raise TypeError(msg)
return declare_device_function(name, restype, argtypes)