Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
58 lines (39 sloc) 1.38 KB
from numba.core import types, sigutils
from .compiler import (compile_kernel, compile_device, AutoJitHSAKernel,
compile_device_template)
def jit(signature=None, device=False):
    """JIT-compile a Python function for the HSA target.

    Three usage forms are accepted:

    * ``@jit()`` (no signature): returns the decorator from
      ``autojit(device=device)``.
    * ``@jit`` applied directly to a function: *signature* is then the
      function itself (it is not recognised by ``sigutils.is_signature``),
      and it is passed straight to the ``autojit`` decorator.
    * ``@jit(signature)``: returns an eager-compilation decorator, either
      ``_device_jit(signature)`` when *device* is true or
      ``_kernel_jit(signature)`` otherwise.

    :param signature: ``None``, a type signature understood by
        ``numba.core.sigutils``, or the function being decorated.
    :param device: compile a device function instead of a kernel.
    """
    # No signature at all: hand back the signature-less decorator.
    if signature is None:
        return autojit(device=device)

    # Bare ``@jit`` usage: the single positional argument is the function.
    if not sigutils.is_signature(signature):
        pyfunc = signature
        return autojit(device=device)(pyfunc)

    # Explicit signature: pick the matching eager-compilation decorator factory.
    factory = _device_jit if device else _kernel_jit
    return factory(signature)
def autojit(device=False):
    """Return the decorator used when no explicit signature is supplied.

    :param device: when true, select the device-function decorator;
        otherwise select the kernel decorator.
    :returns: ``_device_autojit`` or ``_kernel_autojit``.
    """
    return _device_autojit if device else _kernel_autojit
def _device_jit(signature):
    """Build a decorator that compiles a device function for *signature*.

    The signature is normalized immediately (so an invalid signature fails
    when the decorator is created), and the resulting argument and return
    types are passed to ``compile_device`` when the decorator is applied.
    """
    arg_types, return_type = sigutils.normalize_signature(signature)

    def _decorator(pyfunc):
        # compile_device takes the return type before the argument types.
        return compile_device(pyfunc, return_type, arg_types)

    return _decorator
def _kernel_jit(signature):
    """Build a decorator that compiles an HSA kernel for *signature*.

    :param signature: a signature accepted by
        ``sigutils.normalize_signature``.
    :raises TypeError: if the normalized return type is neither ``None``
        nor ``types.void``; kernels may not return a value.
    """
    arg_types, restype = sigutils.normalize_signature(signature)

    # Reject a non-void return type up front, before any compilation happens.
    has_return_value = restype is not None and restype != types.void
    if has_return_value:
        raise TypeError(
            "HSA kernel must have void return type but got {restype}".format(
                restype=restype
            )
        )

    def _decorator(pyfunc):
        # Only the argument types are needed; the return type is void.
        return compile_kernel(pyfunc, arg_types)

    return _decorator
def _device_autojit(pyfunc):
    """Decorate *pyfunc* as a device function with no explicit signature.

    Delegates to ``compile_device_template`` and returns its result.
    """
    template = compile_device_template(pyfunc)
    return template
def _kernel_autojit(pyfunc):
    """Decorate *pyfunc* as a kernel with no explicit signature.

    Wraps the function in an ``AutoJitHSAKernel`` and returns that object.
    """
    kernel = AutoJitHSAKernel(pyfunc)
    return kernel
You can’t perform that action at this time.