Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
328 lines (254 sloc) 10.2 KB
import re
from llvmlite.llvmpy import core as lc
from llvmlite import ir as llvmir
from llvmlite import binding as ll
from numba.core import typing, types, utils, datamodel, cgutils
from numba.core.utils import cached_property
from numba.core.base import BaseContext
from numba.core.callconv import MinimalCallConv
from numba.roc import codegen
from .hlc import DATALAYOUT
# Calling-convention strings assigned to the ``calling_convention``
# attribute of LLVM functions in this module.
# CC_SPIR_KERNEL is applied to kernel entry points by set_hsa_kernel().
CC_SPIR_KERNEL = "spir_kernel"
# CC_SPIR_FUNC is applied to device functions and to the callee declared by
# the kernel wrapper.  NOTE(review): the empty string presumably leaves the
# default calling convention in place -- confirm.
CC_SPIR_FUNC = ""
# -----------------------------------------------------------------------------
# Typing
class HSATypingContext(typing.BaseContext):
    """Typing context that adds the HSA and math declaration registries."""

    def load_additional_registries(self):
        """Install the ``hsadecl`` and ``mathdecl`` registries, in that order."""
        # Imported lazily, inside the method, as in the rest of this module.
        from . import hsadecl, mathdecl

        for decl_module in (hsadecl, mathdecl):
            self.install_registry(decl_module.registry)
# -----------------------------------------------------------------------------
# Implementation
# Matches any single character outside a-z / 0-9 (case-insensitively).
# HSATargetContext.mangler() replaces each match with a hex escape.
VALID_CHARS = re.compile(r'[^a-z0-9]', re.I)

# Address spaces, listed in numeric order
SPIR_GENERIC_ADDRSPACE = 0
SPIR_GLOBAL_ADDRSPACE = 1
SPIR_REGION_ADDRSPACE = 2
SPIR_LOCAL_ADDRSPACE = 3
SPIR_CONSTANT_ADDRSPACE = 4
SPIR_PRIVATE_ADDRSPACE = 5
SPIR_CONSTANT_32BIT_ADDRSPACE = 6

# Version numbers written into the "opencl.spir.version" and
# "opencl.ocl.version" module metadata.
SPIR_VERSION = (2, 0)
class GenericPointerModel(datamodel.PrimitiveModel):
    """Primitive data model whose backend type is a generic-address-space pointer.

    The backend type is a pointer, in ``SPIR_GENERIC_ADDRSPACE``, to the data
    type that the data model manager reports for ``fe_type.dtype``.
    """

    def __init__(self, dmm, fe_type):
        pointee_model = dmm.lookup(fe_type.dtype)
        pointee_type = pointee_model.get_data_type()
        pointer_type = pointee_type.as_pointer(SPIR_GENERIC_ADDRSPACE)
        super().__init__(dmm, fe_type, pointer_type)
def _init_data_model_manager():
    """Return a copy of the default data model manager set up for HSA.

    The copy has ``GenericPointerModel`` registered for ``types.CPointer``;
    the default manager itself is not modified by this function.
    """
    manager = datamodel.default_manager.copy()
    manager.register(types.CPointer, GenericPointerModel)
    return manager


# Module-level manager; HSATargetContext.init() installs it on the context.
hsa_data_model_manager = _init_data_model_manager()
class HSATargetContext(BaseContext):
    """Target context (``BaseContext`` subclass) used for HSA code generation."""

    implement_powi_as_math_call = True
    generic_addrspace = SPIR_GENERIC_ADDRSPACE

    def init(self):
        """Create the code generator and target data, and install the HSA
        data model manager on this context."""
        self._internal_codegen = codegen.JITHSACodegen("numba.hsa.jit")
        layout = DATALAYOUT[utils.MACHINE_BITS]
        self._target_data = ll.create_target_data(layout)
        # Override data model manager
        self.data_model_manager = hsa_data_model_manager

    def load_additional_registries(self):
        """Insert the function definitions from ``hsaimpl`` and ``mathimpl``,
        in that order."""
        from . import hsaimpl, mathimpl

        for impl_module in (hsaimpl, mathimpl):
            self.insert_func_defn(impl_module.registry.functions)

    @cached_property
    def call_conv(self):
        """The ``HSACallConv`` bound to this context (created once, cached)."""
        return HSACallConv(self)

    def codegen(self):
        """Return the code generator created in ``init()``."""
        return self._internal_codegen

    @property
    def target_data(self):
        """The target data object created in ``init()``."""
        return self._target_data

    def mangler(self, name, argtypes):
        """Return the mangled device-function name for *name* and *argtypes*.

        The name and the ``str()`` of every argument type are joined with
        ``'.'``; every character matched by ``VALID_CHARS`` is then replaced
        by ``_<HEX>_`` (its code point in uppercase hex), and the result is
        prefixed with ``'hsapy_devfn_'``.
        """
        def hex_escape(match):
            return "_%X_" % ord(match.group(0))

        prefix = name + '.'
        qualified = prefix + '.'.join(str(argty) for argty in argtypes)
        return 'hsapy_devfn_' + VALID_CHARS.sub(hex_escape, qualified)

    def prepare_hsa_kernel(self, func, argtypes):
        """Prepare *func* as a kernel and return its generated wrapper.

        Sets ``linkonce_odr`` linkage on *func* and the data layout on its
        module before delegating to ``generate_kernel_wrapper``.
        """
        kernel_module = func.module
        func.linkage = 'linkonce_odr'
        kernel_module.data_layout = DATALAYOUT[self.address_size]
        return self.generate_kernel_wrapper(func, argtypes)

    def mark_hsa_device(self, func):
        """Mark *func* as a device function and return it."""
        # Adapt to SPIR
        func.calling_convention = CC_SPIR_FUNC
        func.linkage = 'linkonce_odr'
        return func

    def generate_kernel_wrapper(self, func, argtypes):
        """Build a SPIR kernel wrapper around *func* and return it.

        The wrapper is created in a separate ``"hsa.kernel.wrapper"`` module
        under the name ``hsaPy_<func.name>``.  It returns void and takes the
        packed argument types of *argtypes*, with every pointer argument
        moved to ``SPIR_GLOBAL_ADDRSPACE``.  Inside the wrapper each such
        pointer is cast back to its original address space before *func* is
        called through ``self.call_conv``.  The wrapper module is then linked
        into the module of *func*, and the wrapper looked up from that module
        is returned.
        """
        module = func.module
        kernel_name = func.name
        arginfo = self.get_arg_packer(argtypes)

        def to_global_pointer(lty):
            # Non-pointers pass through; ``None`` marks "no cast needed".
            if not isinstance(lty, llvmir.PointerType):
                return lty, None
            # Remember the original address space so it can be restored.
            return (lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE),
                    lty.addrspace)

        substituted = [to_global_pointer(lty)
                       for lty in arginfo.argument_types]
        if substituted:
            llargtys, changed = zip(*substituted)
        else:
            llargtys = changed = ()

        wrapperfnty = lc.Type.function(lc.Type.void(), llargtys)
        wrapper_module = self.create_module("hsa.kernel.wrapper")
        wrappername = 'hsaPy_{name}'.format(name=kernel_name)

        # Declare the wrapped function inside the wrapper module: it returns
        # an int and takes the return-value type from the calling convention
        # as its first parameter, followed by the packed arguments.
        retptr_ty = self.call_conv.get_return_type(types.pyobject)
        callee_fnty = lc.Type.function(
            lc.Type.int(), [retptr_ty] + list(arginfo.argument_types))
        callee = wrapper_module.add_function(callee_fnty, name=kernel_name)
        callee.calling_convention = CC_SPIR_FUNC

        wrapper = wrapper_module.add_function(wrapperfnty, name=wrappername)
        builder = lc.Builder(wrapper.append_basic_block(''))

        # Adjust address space of each kernel argument
        fixed_args = [
            av if adrsp is None else self.addrspacecast(builder, av, adrsp)
            for av, adrsp in zip(wrapper.args, changed)
        ]
        callargs = arginfo.from_arguments(builder, fixed_args)

        # XXX handle error status
        _status, _ = self.call_conv.call_function(
            builder, callee, types.void, argtypes, callargs)
        builder.ret_void()

        set_hsa_kernel(wrapper)

        # Link
        module.link_in(ll.parse_assembly(str(wrapper_module)))

        # To enable inlining which is essential because addrspacecast 1->0 is
        # illegal.  Inlining will optimize the addrspacecast out.
        # NOTE(review): this first assignment targets the declaration in the
        # temporary wrapper module, which was already serialized above, so it
        # presumably has no effect on ``module`` -- confirm before removing.
        callee.linkage = 'internal'
        wrapper = module.get_function(wrapper.name)
        module.get_function(callee.name).linkage = 'internal'
        return wrapper

    def declare_function(self, module, fndesc):
        """Declare the function as the base class does, then apply
        ``CC_SPIR_FUNC`` when its LLVM name starts with ``'hsapy_devfn'``."""
        fn = super().declare_function(module, fndesc)
        # XXX: Refactor fndesc instead of this special case
        if fndesc.llvm_func_name.startswith('hsapy_devfn'):
            fn.calling_convention = CC_SPIR_FUNC
        return fn

    def make_constant_array(self, builder, typ, ary):
        """
        Not implemented for this target; always raises NotImplementedError.
        """
        raise NotImplementedError

    def addrspacecast(self, builder, src, addrspace):
        """
        Emit an addrspacecast of pointer *src* to the same pointee type in
        *addrspace* and return the resulting value.
        """
        target_ty = llvmir.PointerType(src.type.pointee, addrspace=addrspace)
        return builder.addrspacecast(src, target_ty)
def set_hsa_kernel(fn):
    """
    Make `fn` usable as a SPIR kernel.

    - Sets the SPIR kernel calling convention on `fn`.
    - Appends an entry for `fn` (with its argument metadata) to the
      "opencl.kernels" named metadata of its module.
    - Adds the SPIR and OpenCL version metadata to the module if they have
      no operands yet.

    The nounwind attribute is not set, and the "opencl.used.extensions",
    "opencl.used.optional.core.features" and "opencl.compiler.options"
    named metadata are not emitted; they do not seem to be necessary.
    """
    mod = fn.module

    # Set SPIR kernel calling convention
    fn.calling_convention = CC_SPIR_KERNEL

    # Mark kernels
    ocl_kernels = mod.get_or_insert_named_metadata("opencl.kernels")
    arg_md_generators = (
        gen_arg_addrspace_md,
        gen_arg_access_qual_md,
        gen_arg_type,
        gen_arg_type_qual,
        gen_arg_base_type,
    )
    kernel_entry = [fn] + [generate(fn) for generate in arg_md_generators]
    ocl_kernels.add(lc.MetaData.get(mod, kernel_entry))

    # SPIR version 2.0; the same constants are used for both version nodes.
    int_ty = lc.Type.int
    version_consts = [lc.Constant.int(int_ty(), part) for part in SPIR_VERSION]
    for md_name in ("opencl.spir.version", "opencl.ocl.version"):
        version_md = mod.get_or_insert_named_metadata(md_name)
        # Only populate a version node that is still empty.
        if not version_md.operands:
            version_md.add(lc.MetaData.get(mod, version_consts))
def gen_arg_addrspace_md(fn):
    """
    Build the kernel_arg_addr_space metadata node for `fn`.

    Pointer arguments are recorded as SPIR_GLOBAL_ADDRSPACE, every other
    argument as SPIR_PRIVATE_ADDRSPACE.
    """
    mod = fn.module
    arg_types = fn.type.pointee.args
    codes = [
        SPIR_GLOBAL_ADDRSPACE if cgutils.is_pointer(argty)
        else SPIR_PRIVATE_ADDRSPACE
        for argty in arg_types
    ]
    consts = [lc.Constant.int(lc.Type.int(), code) for code in codes]
    name = lc.MetaDataString.get(mod, "kernel_arg_addr_space")
    return lc.MetaData.get(mod, [name] + consts)
def gen_arg_access_qual_md(fn):
    """
    Build the kernel_arg_access_qual metadata node for `fn`.

    Every argument gets the access qualifier "none".
    """
    mod = fn.module
    # A single "none" string node is shared by all argument slots.
    none_md = lc.MetaDataString.get(mod, "none")
    qualifiers = [none_md] * len(fn.args)
    header = lc.MetaDataString.get(mod, "kernel_arg_access_qual")
    return lc.MetaData.get(mod, [header] + qualifiers)
def gen_arg_type(fn):
    """
    Build the kernel_arg_type metadata node for `fn`.

    Each argument is described by the ``str()`` of its LLVM type.
    """
    mod = fn.module
    type_names = [str(argty) for argty in fn.type.pointee.args]
    consts = [lc.MetaDataString.get(mod, text) for text in type_names]
    header = lc.MetaDataString.get(mod, "kernel_arg_type")
    return lc.MetaData.get(mod, [header] + consts)
def gen_arg_type_qual(fn):
    """
    Build the kernel_arg_type_qual metadata node for `fn`.

    Every argument gets an empty type qualifier string.
    """
    mod = fn.module
    arg_types = fn.type.pointee.args
    # One fresh empty-string node per argument.
    qualifiers = [lc.MetaDataString.get(mod, "") for _argty in arg_types]
    header = lc.MetaDataString.get(mod, "kernel_arg_type_qual")
    return lc.MetaData.get(mod, [header] + qualifiers)
def gen_arg_base_type(fn):
    """
    Build the kernel_arg_base_type metadata node for `fn`.

    Each argument is described by the ``str()`` of its LLVM type.
    """
    mod = fn.module
    type_names = [str(argty) for argty in fn.type.pointee.args]
    consts = [lc.MetaDataString.get(mod, text) for text in type_names]
    header = lc.MetaDataString.get(mod, "kernel_arg_base_type")
    return lc.MetaData.get(mod, [header] + consts)
class HSACallConv(MinimalCallConv):
    """Calling convention used by ``HSATargetContext``."""

    def call_function(self, builder, callee, resty, argtys, args, env=None):
        """
        Call the Numba-compiled *callee*.

        A temporary slot for the return value is allocated and initialized
        to the null value of its type, then passed as the first argument,
        followed by *args* packed according to *argtys*.  Returns a
        ``(status, value)`` pair: the status derived from the callee's
        return code, and the value loaded back from the temporary slot.
        *env* must be None.
        """
        assert env is None

        # The callee's first parameter is a pointer to the return value.
        ret_type = callee.args[0].type.pointee
        ret_slot = cgutils.alloca_once(builder, ret_type)
        # initialize return value
        builder.store(cgutils.get_null_value(ret_type), ret_slot)

        packer = self.context.get_arg_packer(argtys)
        packed_args = packer.as_arguments(builder, args)
        code = builder.call(callee, [ret_slot, *packed_args])

        status = self._get_return_status(builder, code)
        loaded = builder.load(ret_slot)
        out = self.context.get_returned_value(builder, resty, loaded)
        return status, out
You can’t perform that action at this time.