Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
219 lines (166 sloc) 7.69 KB
import itertools
import llvmlite.llvmpy.core as lc
from .cudadrv import nvvm
from .api import current_context
def declare_atomic_cas_int32(lmod):
    """Get or insert the ``___numba_cas_hack`` declaration in ``lmod``.

    The declared function takes a pointer to a 32-bit integer followed by
    two 32-bit integers, and returns a 32-bit integer.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    i32 = lc.Type.int(32)
    signature = lc.Type.function(i32, (lc.Type.pointer(i32), i32, i32))
    return lmod.get_or_insert_function(signature, '___numba_cas_hack')
# For atomic intrinsics, "numba_nvvm" prevents LLVM 9 onwards auto-upgrading
# them into atomicrmw instructions that are not recognized by NVVM. It is
# replaced with "nvvm" in llvm_to_ptx later, after the module has been parsed
# and dumped by LLVM.
def declare_atomic_add_float32(lmod):
    """Get or insert the float atomic-add intrinsic declaration in ``lmod``.

    The function is declared under its ``numba_nvvm`` placeholder name with
    the signature ``float (float*, float)``; the pointer argument is in
    address space 0.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    f32_ptr = lc.Type.pointer(f32, 0)
    signature = lc.Type.function(f32, (f32_ptr, f32))
    return lmod.get_or_insert_function(
        signature, name='llvm.numba_nvvm.atomic.load.add.f32.p0f32')
def declare_atomic_add_float64(lmod):
    """Get or insert the double atomic-add declaration in ``lmod``.

    The name of the declared function depends on the current device: when
    its compute capability is at least (6, 0) the ``numba_nvvm`` intrinsic
    name is used, otherwise ``___numba_atomic_double_add`` is used.
    Either way the signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    capability = current_context().device.compute_capability
    fname = ('llvm.numba_nvvm.atomic.load.add.f64.p0f64'
             if capability >= (6, 0)
             else '___numba_atomic_double_add')
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature, fname)
def declare_atomic_sub_float32(lmod):
    """Get or insert the ``___numba_atomic_float_sub`` declaration.

    The declared signature is ``float (float*, float)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    signature = lc.Type.function(f32, (lc.Type.pointer(f32), f32))
    return lmod.get_or_insert_function(signature,
                                       name='___numba_atomic_float_sub')
def declare_atomic_sub_float64(lmod):
    """Get or insert the ``___numba_atomic_double_sub`` declaration.

    The declared signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_sub')
def declare_atomic_max_float32(lmod):
    """Get or insert the ``___numba_atomic_float_max`` declaration.

    The declared signature is ``float (float*, float)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    signature = lc.Type.function(f32, (lc.Type.pointer(f32), f32))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_float_max')
def declare_atomic_max_float64(lmod):
    """Get or insert the ``___numba_atomic_double_max`` declaration.

    The declared signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_max')
def declare_atomic_min_float32(lmod):
    """Get or insert the ``___numba_atomic_float_min`` declaration.

    The declared signature is ``float (float*, float)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    signature = lc.Type.function(f32, (lc.Type.pointer(f32), f32))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_float_min')
def declare_atomic_min_float64(lmod):
    """Get or insert the ``___numba_atomic_double_min`` declaration.

    The declared signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_min')
def declare_atomic_nanmax_float32(lmod):
    """Get or insert the ``___numba_atomic_float_nanmax`` declaration.

    The declared signature is ``float (float*, float)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    signature = lc.Type.function(f32, (lc.Type.pointer(f32), f32))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_float_nanmax')
def declare_atomic_nanmax_float64(lmod):
    """Get or insert the ``___numba_atomic_double_nanmax`` declaration.

    The declared signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_nanmax')
def declare_atomic_nanmin_float32(lmod):
    """Get or insert the ``___numba_atomic_float_nanmin`` declaration.

    The declared signature is ``float (float*, float)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f32 = lc.Type.float()
    signature = lc.Type.function(f32, (lc.Type.pointer(f32), f32))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_float_nanmin')
def declare_atomic_nanmin_float64(lmod):
    """Get or insert the ``___numba_atomic_double_nanmin`` declaration.

    The declared signature is ``double (double*, double)``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    f64 = lc.Type.double()
    signature = lc.Type.function(f64, (lc.Type.pointer(f64), f64))
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_nanmin')
def insert_addrspace_conv(lmod, elemtype, addrspace):
    """Get or insert an ``llvm.nvvm.ptr.<space>.to.gen`` declaration.

    The declared function takes a pointer to ``elemtype`` in ``addrspace``
    and returns a pointer to ``elemtype`` in the default address space.

    :param lmod: module that should contain the declaration.
    :param elemtype: pointee type of both the argument and the result.
    :param addrspace: one of ``nvvm.ADDRSPACE_SHARED``,
        ``nvvm.ADDRSPACE_LOCAL`` or ``nvvm.ADDRSPACE_CONSTANT``.
    :return: the function object obtained from ``lmod``.
    :raises KeyError: if ``addrspace`` is not one of the three above.
    """
    space_names = {
        nvvm.ADDRSPACE_SHARED: 'shared',
        nvvm.ADDRSPACE_LOCAL: 'local',
        nvvm.ADDRSPACE_CONSTANT: 'constant',
    }
    space = space_names[addrspace]

    # The intrinsic name spells float/double as f32/f64; any other type
    # keeps its string form unchanged.
    printed = str(elemtype)
    suffix = {'float': 'f32', 'double': 'f64'}.get(printed, printed)
    intrinsic_name = 'llvm.nvvm.ptr.%s.to.gen.p0%s.p%d%s' % (
        space, suffix, addrspace, suffix)

    generic_ptr = lc.Type.pointer(elemtype)
    spaced_ptr = lc.Type.pointer(elemtype, addrspace)
    conv_fnty = lc.Type.function(generic_ptr, [spaced_ptr])
    return lmod.get_or_insert_function(conv_fnty, intrinsic_name)
def declare_string(builder, value):
    """Emit ``value`` as a constant string global and return a pointer to it.

    A global named ``_str`` is added in the constant address space of the
    module owning the builder's current function, with internal linkage
    and ``value`` (via ``lc.Constant.stringz``) as its initializer. The
    global is bitcast to an 8-bit integer pointer and passed through the
    address-space conversion function from ``insert_addrspace_conv``.

    :param builder: IR builder positioned inside a function.
    :param value: the string contents to store.
    :return: the result of the conversion call.
    """
    module = builder.basic_block.function.module
    text = lc.Constant.stringz(value)

    gvar = module.add_global_variable(text.type, name="_str",
                                      addrspace=nvvm.ADDRSPACE_CONSTANT)
    gvar.linkage = lc.LINKAGE_INTERNAL
    gvar.global_constant = True
    gvar.initializer = text

    i8 = lc.Type.int(8)
    const_i8_ptr = lc.Type.pointer(i8, nvvm.ADDRSPACE_CONSTANT)
    as_char_ptr = builder.bitcast(gvar, const_i8_ptr)
    to_generic = insert_addrspace_conv(module, i8, nvvm.ADDRSPACE_CONSTANT)
    return builder.call(to_generic, [as_char_ptr])
def declare_vprint(lmod):
    """Get or insert the ``vprintf`` declaration in ``lmod``.

    :param lmod: module that should contain the declaration.
    :return: the function object obtained from ``lmod``.
    """
    i8_ptr = lc.Type.pointer(lc.Type.int(8))
    # NOTE: vprintf() receives the format first; its second argument points
    # to the variable-length array of arguments that follow the format.
    signature = lc.Type.function(lc.Type.int(), [i8_ptr, i8_ptr])
    return lmod.get_or_insert_function(signature, "vprintf")
# -----------------------------------------------------------------------------
# Maps each short special-register name to the name of the
# ``llvm.nvvm.read.ptx.sreg.*`` intrinsic that reads it. Every intrinsic
# name is the common prefix followed by the short name itself.
SREG_MAPPING = {
    sreg: 'llvm.nvvm.read.ptx.sreg.' + sreg
    for sreg in (
        'tid.x', 'tid.y', 'tid.z',
        'ntid.x', 'ntid.y', 'ntid.z',
        'ctaid.x', 'ctaid.y', 'ctaid.z',
        'nctaid.x', 'nctaid.y', 'nctaid.z',
        'warpsize',
        'laneid',
    )
}
def call_sreg(builder, name):
    """Emit a call to the intrinsic registered for ``name``.

    The intrinsic is declared in ``builder.module`` as a function taking
    no arguments and returning ``lc.Type.int()``.

    :param builder: IR builder used to emit the call.
    :param name: a key of ``SREG_MAPPING`` (for example ``'tid.x'``).
    :return: the value produced by the emitted call.
    :raises KeyError: if ``name`` is not a key of ``SREG_MAPPING``.
    """
    lmod = builder.module
    nullary_int = lc.Type.function(lc.Type.int(), ())
    intrinsic = lmod.get_or_insert_function(nullary_int,
                                            name=SREG_MAPPING[name])
    return builder.call(intrinsic, ())
class SRegBuilder(object):
    """Convenience wrapper emitting special-register reads via a builder.

    Each accessor takes an axis suffix (``'x'``, ``'y'`` or ``'z'``) and
    delegates to ``call_sreg`` with the matching ``SREG_MAPPING`` key.
    """

    def __init__(self, builder):
        """Remember the IR builder used for every emitted call."""
        self.builder = builder

    def tid(self, xyz):
        """Emit a read of ``tid.<xyz>`` and return its value."""
        return call_sreg(self.builder, 'tid.%s' % xyz)

    def ctaid(self, xyz):
        """Emit a read of ``ctaid.<xyz>`` and return its value."""
        return call_sreg(self.builder, 'ctaid.%s' % xyz)

    def ntid(self, xyz):
        """Emit a read of ``ntid.<xyz>`` and return its value."""
        return call_sreg(self.builder, 'ntid.%s' % xyz)

    def nctaid(self, xyz):
        """Emit a read of ``nctaid.<xyz>`` and return its value."""
        return call_sreg(self.builder, 'nctaid.%s' % xyz)

    def getdim(self, xyz):
        """Emit and return ``ntid * ctaid + tid`` for the given axis."""
        emit = self.builder
        # The three reads are emitted in this order: tid, ntid, ctaid.
        tid = self.tid(xyz)
        ntid = self.ntid(xyz)
        ctaid = self.ctaid(xyz)
        scaled = emit.mul(ntid, ctaid)
        return emit.add(scaled, tid)
def get_global_id(builder, dim):
    """Emit ``SRegBuilder.getdim`` for the first ``dim`` axes of x, y, z.

    :param builder: IR builder used to emit the computations.
    :param dim: how many axes to compute, taken in the order x, y, z.
    :return: the single value when ``dim == 1``; otherwise a list with one
        value per computed axis.
    """
    sreg = SRegBuilder(builder)
    # islice keeps the evaluation lazy: getdim runs only for taken axes.
    ids = [sreg.getdim(axis) for axis in itertools.islice('xyz', dim)]
    return ids[0] if dim == 1 else ids
You can’t perform that action at this time.