Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
464 lines (372 sloc) 14.6 KB
import copy
from collections import namedtuple
import ctypes
import re
import numpy as np
from numba.core.typing.templates import ConcreteTemplate
from numba.core import types, config, compiler
from .hlc import hlc
from .hsadrv import devices, driver, enums, drvapi
from .hsadrv.error import HsaKernelLaunchError
from numba.roc import gcn_occupancy
from numba.roc.hsadrv.driver import hsa, dgpu_present
from .hsadrv import devicearray
from numba.core.typing.templates import AbstractTemplate
from numba.core.compiler_lock import global_compiler_lock
@global_compiler_lock
def compile_hsa(pyfunc, return_type, args, debug):
    """Compile *pyfunc* for the HSA target, stopping after LLVM lowering.

    Runs numba's ``compiler.compile_extra`` with the HSA typing/target
    contexts and with flags that skip native code generation, the CPython
    and cfunc wrappers, and NRT.  The resulting library is finalized before
    the compile result is returned.

    ``debug`` is accepted but currently unused (TODO).  The whole call runs
    under ``global_compiler_lock``.
    """
    # Importing the descriptor the first time initializes the HSA backend.
    from .descriptor import HSATargetDesc
    typing_context = HSATargetDesc.typingctx
    target_context = HSATargetDesc.targetctx

    compile_flags = compiler.Flags()
    # Lower to LLVM only: no native code, no host-side wrappers.
    for flag_name in ('no_compile', 'no_cpython_wrapper', 'no_cfunc_wrapper'):
        compile_flags.set(flag_name)
    compile_flags.unset('nrt')

    result = compiler.compile_extra(typingctx=typing_context,
                                    targetctx=target_context,
                                    func=pyfunc,
                                    args=args,
                                    return_type=return_type,
                                    flags=compile_flags,
                                    locals={})
    # NOTE: dependent libraries are not linked here (was a commented-out
    # targetctx.link_dependencies call).
    result.library.finalize()
    return result
def compile_kernel(pyfunc, args, debug=False):
    """Compile *pyfunc* as a void-returning kernel and wrap it in HSAKernel.

    The lowered LLVM function is turned into a kernel by the target
    context's ``prepare_hsa_kernel``; the resulting module and name, plus the
    signature's argument types, build the returned ``HSAKernel``.
    """
    result = compile_hsa(pyfunc, types.void, args, debug=debug)
    llvm_function = result.library.get_function(result.fndesc.llvm_func_name)
    kernel_argtypes = result.signature.args
    prepared = result.target_context.prepare_hsa_kernel(llvm_function,
                                                        kernel_argtypes)
    return HSAKernel(llvm_module=prepared.module,
                     name=prepared.name,
                     argtypes=kernel_argtypes)
def compile_device(pyfunc, return_type, args, debug=False):
    """Compile *pyfunc* as an HSA device function with a fixed signature.

    The compiled function is marked as a device function and registered
    with both the typing context (via a ConcreteTemplate with the single
    compiled signature) and the target context, keyed by a new
    ``DeviceFunction`` handle, which is returned.
    """
    result = compile_hsa(pyfunc, return_type, args, debug=debug)
    llvm_function = result.library.get_function(result.fndesc.llvm_func_name)
    result.target_context.mark_hsa_device(llvm_function)

    handle = DeviceFunction(result)

    class device_function_template(ConcreteTemplate):
        key = handle
        cases = [result.signature]

    result.typing_context.insert_user_function(handle,
                                               device_function_template)
    result.target_context.insert_user_function(handle, result.fndesc,
                                               [result.library])
    return handle
def compile_device_template(pyfunc):
    """Compile a DeviceFunctionTemplate

    Wraps *pyfunc* in a ``DeviceFunctionTemplate`` and registers it with the
    HSA typing context through an AbstractTemplate whose ``generic`` compiles
    the template for the argument types seen at each call site.
    """
    from .descriptor import HSATargetDesc
    template = DeviceFunctionTemplate(pyfunc)

    class device_function_template(AbstractTemplate):
        key = template

        def generic(self, args, kws):
            # Keyword arguments are not supported for device templates.
            assert not kws
            return template.compile(args)

    HSATargetDesc.typingctx.insert_user_function(template,
                                                 device_function_template)
    return template
class DeviceFunctionTemplate(object):
    """Unmaterialized device function

    Holds the Python function and compiles it lazily, once per distinct
    tuple of argument types.
    """

    def __init__(self, pyfunc, debug=False):
        self.py_func = pyfunc
        self.debug = debug
        # argument-type tuple -> compile result from compile_hsa
        self._compileinfos = {}

    def compile(self, args):
        """Return the compiled signature for argument types *args*.

        A cache miss compiles via ``compile_hsa``, marks the function as an
        HSA device function and registers it with the target context: the
        very first specialization uses ``insert_user_function``, later ones
        use ``add_user_function``.  A cache hit only returns the signature.
        """
        cache = self._compileinfos
        if args in cache:
            return cache[args].signature

        result = compile_hsa(self.py_func, None, args, debug=self.debug)
        llvm_function = result.library.get_function(
            result.fndesc.llvm_func_name)
        result.target_context.mark_hsa_device(llvm_function)

        # Must be decided before this specialization is stored.
        is_first_specialization = len(cache) == 0
        cache[args] = result
        libraries = [result.library]
        if is_first_specialization:
            result.target_context.insert_user_function(self, result.fndesc,
                                                       libraries)
        else:
            result.target_context.add_user_function(self, result.fndesc,
                                                    libraries)
        return result.signature
class DeviceFunction(object):
    """Handle for a device function compiled by ``compile_device``.

    Instances serve as the key under which the function is registered with
    the typing and target contexts.
    """

    def __init__(self, cres):
        # compile result the handle refers to
        self.cres = cres
def _ensure_list(val):
    """Return *val* as a new list.

    Tuples and lists are copied element-wise; any other value becomes a
    one-element list.
    """
    if isinstance(val, (tuple, list)):
        return list(val)
    return [val]
def _ensure_size_or_append(val, size):
    """Pad list *val* in place with 1s until it has *size* entries.

    Lists already at least *size* long are left unchanged.  Returns None.
    """
    missing = size - len(val)
    val.extend([1] * missing)
class HSAKernelBase(object):
    """Define interface for configurable kernels

    Holds the launch configuration (global size, local size, stream).  The
    configuration methods return configured copies and leave ``self`` as is.
    """

    def __init__(self):
        self.global_size = (1,)
        self.local_size = (1,)
        self.stream = None

    def copy(self):
        return copy.copy(self)

    def configure(self, global_size, local_size=None, stream=None):
        """Return a copy of this kernel configured for launch.

        Scalars are treated as 1-D sizes.  When *local_size* is given, the
        shorter of the two size lists is padded with 1s so both have the same
        number of dimensions.  A missing or empty *local_size* is stored as
        None.
        """
        gsize = _ensure_list(global_size)
        lsize = local_size
        if lsize is not None:
            lsize = _ensure_list(lsize)
            ndim = max(len(gsize), len(lsize))
            for dims in (gsize, lsize):
                _ensure_size_or_append(dims, ndim)
        configured = self.copy()
        configured.global_size = tuple(gsize)
        if lsize:
            configured.local_size = tuple(lsize)
        else:
            configured.local_size = None
        configured.stream = stream
        return configured

    def forall(self, nelem, local_size=64, stream=None):
        """Configure a 1-D launch over *nelem* work-items.

        The work-group size is *local_size*, capped at *nelem*.
        """
        group = min(nelem, local_size)
        return self.configure(nelem, group, stream=stream)

    def __getitem__(self, args):
        """Configure via CUDA-style ``kernel[griddim, blockdim(, stream)]``.

        *blockdim* becomes the local size; the global size is the per-
        dimension product of *griddim* and *blockdim* after both are padded
        with 1s to the same number of dimensions.
        """
        griddim = _ensure_list(args[0])
        blockdim = _ensure_list(args[1])
        ndim = max(len(griddim), len(blockdim))
        _ensure_size_or_append(griddim, ndim)
        _ensure_size_or_append(blockdim, ndim)
        total = []
        for groups, items in zip(griddim, blockdim):
            total.append(groups * items)
        return self.configure(total, blockdim, *args[2:])
# Per-context record built by _CachedProgram.get: the kernel symbol looked up
# in the frozen executable, the executable itself, and the host-accessible
# kernarg region used to allocate kernel arguments.
# The typename matches the binding name so instances pickle and repr
# correctly (it was previously "_CachedEntry").
_CacheEntry = namedtuple("_CacheEntry", ['symbol', 'executable',
                                         'kernarg_region'])
class _CachedProgram(object):
    """Load a kernel's GCN code object lazily, once per HSA context."""

    def __init__(self, entry_name, binary):
        self._entry_name = entry_name
        self._binary = binary
        # key: hsa context -> _CacheEntry
        self._cache = {}

    def get(self):
        """Return ``(ctx, entry)`` for the current HSA context.

        On first use in a context the code object is deserialized, loaded
        and frozen, and the resulting ``_CacheEntry`` is cached.
        """
        ctx = devices.get_context()
        result = self._cache.get(ctx)
        # The program does not exist as GCN in this context yet.
        if result is None:
            result = self._load(ctx)
            self._cache[ctx] = result
        return ctx, result

    def _load(self, ctx):
        """Deserialize, load and freeze the binary for *ctx*'s agent."""
        symbol = '{0}'.format(self._entry_name)
        agent = ctx.agent

        # Expose the binary to the driver through a ctypes byte array.
        ba = bytearray(self._binary)
        bblob = ctypes.c_byte * len(self._binary)
        bas = bblob.from_buffer(ba)
        code_ptr = drvapi.hsa_code_object_t()
        driver.hsa.hsa_code_object_deserialize(
            ctypes.addressof(bas),
            len(self._binary),
            None,
            ctypes.byref(code_ptr)
        )
        code = driver.CodeObject(code_ptr)

        ex = driver.Executable()
        ex.load(agent, code)
        ex.freeze()
        symobj = ex.get_symbol(agent, symbol)

        # First host-accessible global region that supports kernargs.
        # Initialized so the "not found" case is reported explicitly
        # instead of raising UnboundLocalError.
        kernarg_region = None
        for reg in agent.regions.globals:
            if (reg.host_accessible and
                    reg.supports(enums.HSA_REGION_GLOBAL_FLAG_KERNARG)):
                kernarg_region = reg
                break
        if kernarg_region is None:
            raise RuntimeError("no host-accessible kernarg region found "
                               "for kernel %r" % (self._entry_name,))

        return _CacheEntry(symbol=symobj, executable=ex,
                           kernarg_region=kernarg_region)
class HSAKernel(HSAKernelBase):
    """
    A HSA kernel object

    The LLVM module is compiled to GCN (assembly + binary) at construction
    time; the binary is loaded per HSA context through ``_CachedProgram``
    when the kernel is launched.
    """

    def __init__(self, llvm_module, name, argtypes):
        super(HSAKernel, self).__init__()
        self._llvm_module = llvm_module
        self.assembly, self.binary = self._generateGCN()
        self.entry_name = name
        self.argument_types = tuple(argtypes)
        self._argloc = []
        # cached program
        self._cacheprog = _CachedProgram(entry_name=self.entry_name,
                                         binary=self.binary)
        self._parse_kernel_resource()

    def _parse_kernel_resource(self):
        """
        Temporary workaround for register limit

        Reads the SGPR/VGPR counts from the generated assembly; they feed
        the occupancy check in ``_sentry_resource_limit``.
        """
        self._wavefront_sgpr_count = self._search_resource_count(
            r"\bwavefront_sgpr_count\s*=\s*(\d+)")
        self._workitem_vgpr_count = self._search_resource_count(
            r"\bworkitem_vgpr_count\s*=\s*(\d+)")

    def _search_resource_count(self, pattern):
        """Return the integer captured by *pattern* in ``self.assembly``.

        Raises ValueError, instead of an opaque AttributeError on ``None``,
        when the assembly does not contain the field.
        """
        m = re.search(pattern, self.assembly)
        if m is None:
            raise ValueError("could not find %r in the GCN assembly of "
                             "kernel %r" % (pattern, self.entry_name))
        return int(m.group(1))

    def _sentry_resource_limit(self):
        """
        Raise HsaKernelLaunchError when ``gcn_occupancy`` reports limiting
        factors for the configured work-group size and register counts.
        """
        group_size = np.prod(self.local_size)
        limits = gcn_occupancy.get_limiting_factors(
            group_size=group_size,
            vgpr_per_workitem=self._workitem_vgpr_count,
            sgpr_per_wave=self._wavefront_sgpr_count)
        if limits.reasons:
            fmt = 'insufficient resources to launch kernel due to:\n{}'
            msg = fmt.format('\n'.join(limits.suggestions))
            raise HsaKernelLaunchError(msg)

    def _generateGCN(self):
        """Compile the LLVM module with hlc; returns (assembly, binary)."""
        hlcmod = hlc.Module()
        hlcmod.load_llvm(str(self._llvm_module))
        return hlcmod.generateGCN()

    def bind(self):
        """
        Bind kernel to device

        Returns ``(ctx, symbol, kernargs, kernarg_region)``.  *kernargs* is
        freshly allocated from *kernarg_region*, or None when the kernel's
        kernarg segment is empty; ``__call__`` frees it after launch.
        """
        ctx, entry = self._cacheprog.get()
        if entry.symbol.kernarg_segment_size > 0:
            sz = ctypes.sizeof(ctypes.c_byte) * \
                entry.symbol.kernarg_segment_size
            kernargs = entry.kernarg_region.allocate(sz)
        else:
            kernargs = None
        return ctx, entry.symbol, kernargs, entry.kernarg_region

    def _fill_kernargs(self, kernargs, args):
        """Flatten *args* to ctypes values and write them into *kernargs*.

        Each value is placed at the next offset aligned to its own size.
        Returns the list of write-back callables produced while unpacking
        array arguments.
        """
        # Unpack pyobject values into ctypes scalar values
        expanded_values = []
        # contains lambdas to execute on return
        retr = []
        for ty, val in zip(self.argument_types, args):
            _unpack_argument(ty, val, expanded_values, retr)

        base = 0
        for av in expanded_values:
            # Adjust for alignment
            align = ctypes.sizeof(av)
            base += _calc_padding_for_alignment(align, base)
            asptr = ctypes.cast(kernargs.value + base,
                                ctypes.POINTER(type(av)))
            asptr[0] = av
            base += align
        return retr

    def __call__(self, *args):
        """Launch the kernel with the current configuration.

        Raises TypeError if the number of arguments does not match the
        kernel signature (previously extra/missing arguments were silently
        dropped by ``zip`` and the kernel ran on unwritten kernargs).
        """
        if len(args) != len(self.argument_types):
            raise TypeError("kernel %r expects %d argument(s), got %d"
                            % (self.entry_name, len(self.argument_types),
                               len(args)))
        self._sentry_resource_limit()
        ctx, symbol, kernargs, kernarg_region = self.bind()

        try:
            retr = self._fill_kernargs(kernargs, args)
        except BaseException:
            # Nothing has been dispatched yet, so release the kernarg block
            # here; otherwise it would leak (e.g. on an unsupported type).
            if kernargs is not None:
                kernarg_region.free(kernargs)
            raise

        # Actual Kernel launch
        qq = ctx.default_queue
        if self.stream is None:
            hsa.implicit_sync()
            signal = None
        else:
            signal = hsa.create_signal(1)
            qq.insert_barrier(self.stream._get_last_signal())
        qq.dispatch(symbol, kernargs, workgroup_size=self.local_size,
                    grid_size=self.global_size, signal=signal)
        if self.stream is not None:
            self.stream._add_signal(signal)

        # retrieve auto converted arrays
        # NOTE(review): this runs right after dispatch even for stream
        # launches; confirm copy_to_host waits for the kernel in that case.
        for wb in retr:
            wb()

        # Free kernel region (deferred to the stream for async launches).
        if kernargs is not None:
            if self.stream is None:
                kernarg_region.free(kernargs)
            else:
                self.stream._add_callback(lambda: kernarg_region.free(kernargs))
def _unpack_argument(ty, val, kernelargs, retr):
    """
    Convert arguments to ctypes and append to kernelargs

    *ty* is the numba type of *val*.  Arrays expand to several values (see
    ``_unpack_array_argument``); complex numbers expand to (real, imag).
    Unsupported types raise NotImplementedError(ty, val).
    """
    if isinstance(ty, types.Array):
        _unpack_array_argument(val, kernelargs, retr)
    elif isinstance(ty, types.Integer):
        # e.g. int32 -> ctypes.c_int32
        ctype = getattr(ctypes, "c_%s" % ty)
        kernelargs.append(ctype(val))
    elif ty == types.float64:
        kernelargs.append(ctypes.c_double(val))
    elif ty == types.float32:
        kernelargs.append(ctypes.c_float(val))
    elif ty == types.boolean:
        kernelargs.append(ctypes.c_uint8(int(val)))
    elif ty == types.complex64:
        kernelargs.append(ctypes.c_float(val.real))
        kernelargs.append(ctypes.c_float(val.imag))
    elif ty == types.complex128:
        kernelargs.append(ctypes.c_double(val.real))
        kernelargs.append(ctypes.c_double(val.imag))
    else:
        raise NotImplementedError(ty, val)


def _unpack_array_argument(val, kernelargs, retr):
    """Append the flattened array fields of *val* to *kernelargs*.

    Order: meminfo (NULL), parent (NULL), nitems, itemsize, data pointer,
    then one entry per shape extent and one per stride.  When a dGPU is
    present the data is moved to the device and, if a conversion happened,
    a copy-back callable is appended to *retr*.
    """
    c_intp = ctypes.c_ssize_t
    if dgpu_present:
        devary, conv = devicearray.auto_device(val, devices.get_context())
        if conv:
            retr.append(lambda: devary.copy_to_host(val))
        data = devary.device_ctypes_pointer
    else:
        data = ctypes.c_void_p(val.ctypes.data)
    null = ctypes.c_void_p(0)
    kernelargs.extend([null, null,
                       c_intp(val.size),
                       c_intp(val.dtype.itemsize),
                       data])
    kernelargs.extend(c_intp(extent) for extent in val.shape)
    kernelargs.extend(c_intp(stride) for stride in val.strides)
def _calc_padding_for_alignment(align, base):
    """
    Return how many bytes must be added to offset *base* (truncated to int)
    so that it becomes a multiple of *align*; 0 if it already is.
    """
    remainder = int(base) % align
    return align - remainder if remainder else 0
class AutoJitHSAKernel(HSAKernelBase):
    """Kernel that is compiled on demand for the argument types it is
    called with; one ``HSAKernel`` is kept per argument-type tuple.
    """

    def __init__(self, func):
        super(AutoJitHSAKernel, self).__init__()
        self.py_func = func
        # argument-type tuple -> compiled HSAKernel
        self.definitions = {}
        from .descriptor import HSATargetDesc
        self.typingctx = HSATargetDesc.typingctx

    def __call__(self, *args):
        """Specialize for *args*, apply this object's launch config, run."""
        specialized = self.specialize(*args)
        launcher = specialized.configure(self.global_size, self.local_size,
                                         self.stream)
        launcher(*args)

    def specialize(self, *args):
        """Return the kernel compiled for the numba types of *args*."""
        resolve = self.typingctx.resolve_argument_type
        argtypes = tuple(resolve(a) for a in args)
        if argtypes not in self.definitions:
            self.definitions[argtypes] = compile_kernel(self.py_func,
                                                        argtypes)
        return self.definitions[argtypes]
You can’t perform that action at this time.