Fix numba#5073: Slices of dynamic shared memory all alias
The root of the issue was that computations of indices and slice bounds
were incorrect because the shape of dynamic shared memory is generally
(0,). To fix this, we compute the shape (1-D only) of a dynamic shared
array from the dynamic shared memory size and the itemsize of the
array's type when the array is created.

This is implemented by reading the special register %dynamic_smem_size -
unfortunately NVVM doesn't provide an intrinsic for this, so we access
it using inline assembly.
gmarkall committed Jan 17, 2020
1 parent 2a497bb commit bdeacbd
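
To illustrate the effect of the change: after the fix, the 1-D shape of a
dynamic shared array is the dynamic shared memory size given at launch
divided by the array's itemsize. A minimal sketch (the kernel name and
sizes are illustrative, and a CUDA-capable device is assumed):

import numpy as np
from numba import cuda, int32

@cuda.jit
def dynsmem_shape(out):
    # Before the fix this shape was (0,); now it is derived from
    # %dynamic_smem_size and the itemsize of int32
    dynsmem = cuda.shared.array(0, dtype=int32)
    out[0] = dynsmem.shape[0]

out = np.zeros(1, dtype=np.int64)
nshared = 64 * 4  # bytes for 64 int32 elements
dynsmem_shape[1, 1, 0, nshared](out)
assert out[0] == 64
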
Showing 2 changed files with 291 additions and 23 deletions.
63 changes: 40 additions & 23 deletions numba/cuda/cudaimpl.py
@@ -4,7 +4,7 @@
 import operator
 import math
 
-from llvmlite.llvmpy.core import Type
+from llvmlite.llvmpy.core import Type, InlineAsm
 import llvmlite.llvmpy.core as lc
 import llvmlite.binding as ll
 
@@ -627,30 +627,38 @@ def _get_target_data(context):
 def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
                    can_dynsized=False):
     elemcount = reduce(operator.mul, shape)
+
+    # Check for valid shape for this type of allocation
+    dynamic_smem = elemcount <= 0 and can_dynsized
+    if elemcount <= 0 and not dynamic_smem:
+        raise ValueError("array length <= 0")
+
+    # Check that we support the requested dtype
+    other_supported_type = isinstance(dtype, (types.Record, types.Boolean))
+    if dtype not in types.number_domain and not other_supported_type:
+        raise TypeError("unsupported type: %s" % dtype)
+
     lldtype = context.get_data_type(dtype)
     laryty = Type.array(lldtype, elemcount)
 
     if addrspace == nvvm.ADDRSPACE_LOCAL:
-        # Special case local addrespace allocation to use alloca
+        # Special case local address space allocation to use alloca
         # NVVM is smart enough to only use local memory if no register is
         # available
         dataptr = cgutils.alloca_once(builder, laryty, name=symbol_name)
     else:
         lmod = builder.module
 
-        # Create global variable in the requested address-space
+        # Create global variable in the requested address space
         gvmem = lmod.add_global_variable(laryty, symbol_name, addrspace)
         # Specify alignment to avoid misalignment bug
         align = context.get_abi_sizeof(lldtype)
         # Alignment is required to be a power of 2 for shared memory. If it is
         # not a power of 2 (e.g. for a Record array) then round up accordingly.
         gvmem.align = 1 << (align - 1).bit_length()
 
-        if elemcount <= 0:
-            if can_dynsized:  # dynamic shared memory
-                gvmem.linkage = lc.LINKAGE_EXTERNAL
-            else:
-                raise ValueError("array length <= 0")
+        if dynamic_smem:
+            gvmem.linkage = lc.LINKAGE_EXTERNAL
         else:
             ## Comment out the following line to workaround a NVVM bug
             ## which generates a invalid symbol name when the linkage
@@ -660,36 +668,45 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
 
             gvmem.initializer = lc.Constant.undef(laryty)
 
-        other_supported_type = isinstance(dtype, (types.Record, types.Boolean))
-        if dtype not in types.number_domain and not other_supported_type:
-            raise TypeError("unsupported type: %s" % dtype)
-
         # Convert to generic address-space
         conv = nvvmutils.insert_addrspace_conv(lmod, Type.int(8), addrspace)
         addrspaceptr = gvmem.bitcast(Type.pointer(Type.int(8), addrspace))
         dataptr = builder.call(conv, [addrspaceptr])
 
-    return _make_array(context, builder, dataptr, dtype, shape)
+    return _make_array(context, builder, dataptr, dtype, shape,
+                       dynamic_smem=dynamic_smem)
 
 
-def _make_array(context, builder, dataptr, dtype, shape, layout='C'):
-    ndim = len(shape)
-    # Create array object
-    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
-    ary = context.make_array(aryty)(context, builder)
-
+def _make_array(context, builder, dataptr, dtype, shape, layout='C',
+                dynamic_smem=False):
     targetdata = _get_target_data(context)
     lldtype = context.get_data_type(dtype)
     itemsize = lldtype.get_abi_size(targetdata)
+
     # Compute strides
     rstrides = [itemsize]
     for i, lastsize in enumerate(reversed(shape[1:])):
         rstrides.append(lastsize * rstrides[-1])
     strides = [s for s in reversed(rstrides)]
 
-    kshape = [context.get_constant(types.intp, s) for s in shape]
     kstrides = [context.get_constant(types.intp, s) for s in strides]
 
+    # Compute shape
+    if dynamic_smem:
+        # Compute the shape based on the dynamic shared memory configuration.
+        # Unfortunately NVVM does not provide an intrinsic for the
+        # %dynamic_smem_size register, so we must read it using inline
+        # assembly.
+        get_dynshared_size = InlineAsm.get(Type.function(Type.int(), []),
+                                           "mov.u32 $0, %dynamic_smem_size;",
+                                           '=r', side_effect=True)
+        dynsmem_size = builder.zext(builder.call(get_dynshared_size, []),
+                                    Type.int(width=64))
+        # Only 1-D dynamic shared memory is supported so the following is a
+        # sufficient construction of the shape
+        kitemsize = context.get_constant(types.intp, itemsize)
+        kshape = [builder.udiv(dynsmem_size, kitemsize)]
+    else:
+        kshape = [context.get_constant(types.intp, s) for s in shape]
+
+    # Create array object
+    ndim = len(shape)
+    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
+    ary = context.make_array(aryty)(context, builder)
+
     context.populate_array(ary,
                            data=builder.bitcast(dataptr, ary.data.type),
                            shape=cgutils.pack_array(builder, kshape),
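
For reference, the %dynamic_smem_size read added above can be reproduced
standalone with llvmlite's ir layer (the commit itself uses the llvmpy
compatibility layer; the module and function names below are illustrative):

from llvmlite import ir

mod = ir.Module(name='dynsmem_sketch')
fn = ir.Function(mod, ir.FunctionType(ir.IntType(64), []),
                 name='read_dynsmem_size')
builder = ir.IRBuilder(fn.append_basic_block('entry'))

# '=r' constrains $0 to a 32-bit register receiving %dynamic_smem_size
asm = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
                   "mov.u32 $0, %dynamic_smem_size;", "=r",
                   side_effect=True)
size32 = builder.call(asm, [])
# Widen to 64 bits for index arithmetic, as the commit does
builder.ret(builder.zext(size32, ir.IntType(64)))
print(mod)
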
251 changes: 251 additions & 0 deletions numba/cuda/tests/cudapy/test_sm.py
@@ -124,6 +124,257 @@ def test_shared_bool(self):
        arr = np.random.randint(2, size=(1024,), dtype=np.bool_)
        self._test_shared(arr)

    def _test_dynshared_slice(self, func, arr, expected):
        # Check that slices of shared memory are correct
        # (See Bug #5073 - prior to the addition of these tests and
        # corresponding fix, slices of dynamic shared arrays all aliased each
        # other)
        nshared = arr.size * arr.dtype.itemsize
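        # Launch one block of one thread on stream 0 with nshared bytes of
        # dynamic shared memory ([griddim, blockdim, stream, sharedmem])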
        func[1, 1, 0, nshared](arr)
        np.testing.assert_array_equal(expected, arr)

    def test_dynshared_slice_write(self):
        # Test writing values into disjoint slices of dynamic shared memory
        @cuda.jit
        def slice_write(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[0:1]
            sm2 = dynsmem[1:2]

            sm1[0] = 1
            sm2[0] = 2
            x[0] = dynsmem[0]
            x[1] = dynsmem[1]

        arr = np.zeros(2, dtype=np.int32)
        expected = np.array([1, 2], dtype=np.int32)
        self._test_dynshared_slice(slice_write, arr, expected)

    def test_dynshared_slice_read(self):
        # Test reading values from disjoint slices of dynamic shared memory
        @cuda.jit
        def slice_read(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[0:1]
            sm2 = dynsmem[1:2]

            dynsmem[0] = 1
            dynsmem[1] = 2
            x[0] = sm1[0]
            x[1] = sm2[0]

        arr = np.zeros(2, dtype=np.int32)
        expected = np.array([1, 2], dtype=np.int32)
        self._test_dynshared_slice(slice_read, arr, expected)

    def test_dynshared_slice_diff_sizes(self):
        # Test reading values from disjoint slices of dynamic shared memory
        # with different sizes
        @cuda.jit
        def slice_diff_sizes(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[0:1]
            sm2 = dynsmem[1:3]

            dynsmem[0] = 1
            dynsmem[1] = 2
            dynsmem[2] = 3
            x[0] = sm1[0]
            x[1] = sm2[0]
            x[2] = sm2[1]

        arr = np.zeros(3, dtype=np.int32)
        expected = np.array([1, 2, 3], dtype=np.int32)
        self._test_dynshared_slice(slice_diff_sizes, arr, expected)

    def test_dynshared_slice_overlap(self):
        # Test reading values from overlapping slices of dynamic shared memory
        @cuda.jit
        def slice_overlap(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[0:2]
            sm2 = dynsmem[1:4]

            dynsmem[0] = 1
            dynsmem[1] = 2
            dynsmem[2] = 3
            dynsmem[3] = 4
            x[0] = sm1[0]
            x[1] = sm1[1]
            x[2] = sm2[0]
            x[3] = sm2[1]
            x[4] = sm2[2]

        arr = np.zeros(5, dtype=np.int32)
        expected = np.array([1, 2, 2, 3, 4], dtype=np.int32)
        self._test_dynshared_slice(slice_overlap, arr, expected)

    def test_dynshared_slice_gaps(self):
        # Test writing values to slices of dynamic shared memory doesn't write
        # outside the slice
        @cuda.jit
        def slice_gaps(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[1:3]
            sm2 = dynsmem[4:6]

            # Initial values for dynamic shared memory, some to be overwritten
            dynsmem[0] = 99
            dynsmem[1] = 99
            dynsmem[2] = 99
            dynsmem[3] = 99
            dynsmem[4] = 99
            dynsmem[5] = 99
            dynsmem[6] = 99

            sm1[0] = 1
            sm1[1] = 2
            sm2[0] = 3
            sm2[1] = 4

            x[0] = dynsmem[0]
            x[1] = dynsmem[1]
            x[2] = dynsmem[2]
            x[3] = dynsmem[3]
            x[4] = dynsmem[4]
            x[5] = dynsmem[5]
            x[6] = dynsmem[6]

        arr = np.zeros(7, dtype=np.int32)
        expected = np.array([99, 1, 2, 99, 3, 4, 99], dtype=np.int32)
        self._test_dynshared_slice(slice_gaps, arr, expected)

    def test_dynshared_slice_write_backwards(self):
        # Test writing values into disjoint slices of dynamic shared memory
        # with negative steps
        @cuda.jit
        def slice_write_backwards(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[1::-1]
            sm2 = dynsmem[3:1:-1]

            sm1[0] = 1
            sm1[1] = 2
            sm2[0] = 3
            sm2[1] = 4
            x[0] = dynsmem[0]
            x[1] = dynsmem[1]
            x[2] = dynsmem[2]
            x[3] = dynsmem[3]

        arr = np.zeros(4, dtype=np.int32)
        expected = np.array([2, 1, 4, 3], dtype=np.int32)
        self._test_dynshared_slice(slice_write_backwards, arr, expected)

    def test_dynshared_slice_nonunit_stride(self):
        # Test writing values into slice of dynamic shared memory with
        # non-unit stride
        @cuda.jit
        def slice_nonunit_stride(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[::2]

            # Initial values for dynamic shared memory, some to be overwritten
            dynsmem[0] = 99
            dynsmem[1] = 99
            dynsmem[2] = 99
            dynsmem[3] = 99
            dynsmem[4] = 99
            dynsmem[5] = 99

            sm1[0] = 1
            sm1[1] = 2
            sm1[2] = 3

            x[0] = dynsmem[0]
            x[1] = dynsmem[1]
            x[2] = dynsmem[2]
            x[3] = dynsmem[3]
            x[4] = dynsmem[4]
            x[5] = dynsmem[5]

        arr = np.zeros(6, dtype=np.int32)
        expected = np.array([1, 99, 2, 99, 3, 99], dtype=np.int32)
        self._test_dynshared_slice(slice_nonunit_stride, arr, expected)

    def test_dynshared_slice_nonunit_reverse_stride(self):
        # Test writing values into slice of dynamic shared memory with
        # reverse non-unit stride
        @cuda.jit
        def slice_nonunit_reverse_stride(x):
            dynsmem = cuda.shared.array(0, dtype=int32)
            sm1 = dynsmem[-1::-2]

            # Initial values for dynamic shared memory, some to be overwritten
            dynsmem[0] = 99
            dynsmem[1] = 99
            dynsmem[2] = 99
            dynsmem[3] = 99
            dynsmem[4] = 99
            dynsmem[5] = 99

            sm1[0] = 1
            sm1[1] = 2
            sm1[2] = 3

            x[0] = dynsmem[0]
            x[1] = dynsmem[1]
            x[2] = dynsmem[2]
            x[3] = dynsmem[3]
            x[4] = dynsmem[4]
            x[5] = dynsmem[5]

        arr = np.zeros(6, dtype=np.int32)
        expected = np.array([99, 3, 99, 2, 99, 1], dtype=np.int32)
        self._test_dynshared_slice(slice_nonunit_reverse_stride, arr, expected)

    def test_issue_5073(self):
        # An example with which Bug #5073 (slices of dynamic shared memory all
        # alias) was discovered. The kernel uses all threads in the block to
        # load values into slices of dynamic shared memory. One thread per
        # block then writes the loaded values back to a global array after
        # syncthreads().

        arr = np.arange(1024)
        nelem = len(arr)
        nthreads = 16
        nblocks = int(nelem / nthreads)
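        # Map the NumPy dtype to the corresponding Numba type so it can be
        # used as the dtype of a shared array inside the kernel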
        dt = nps.from_dtype(arr.dtype)
        nshared = nthreads * arr.dtype.itemsize
        chunksize = int(nthreads / 2)

        @cuda.jit
        def sm_slice_copy(x, y, chunksize):
            dynsmem = cuda.shared.array(0, dtype=dt)
            sm1 = dynsmem[0:chunksize]
            sm2 = dynsmem[chunksize:chunksize*2]

            tx = cuda.threadIdx.x
            bx = cuda.blockIdx.x
            bd = cuda.blockDim.x

            # load this block's chunk into shared
            i = bx * bd + tx
            if i < len(x):
                if tx < chunksize:
                    sm1[tx] = x[i]
                else:
                    sm2[tx - chunksize] = x[i]

            cuda.syncthreads()

            # one thread per block writes this block's chunk
            if tx == 0:
                for j in range(chunksize):
                    y[bd * bx + j] = sm1[j]
                    y[bd * bx + j + chunksize] = sm2[j]

        d_result = cuda.device_array_like(arr)
        sm_slice_copy[nblocks, nthreads, 0, nshared](arr, d_result, chunksize)
        host_result = d_result.copy_to_host()
        np.testing.assert_array_equal(arr, host_result)


if __name__ == '__main__':
    unittest.main()
