Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ds_permute types and add tests #3281

Merged
merged 3 commits into from
Sep 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 14 additions & 4 deletions numba/roc/hsadecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,25 @@ class Hsa_activelanepermute_wavewidth(ConcreteTemplate):
cases = [signature(ty, ty, types.uint32, ty, types.bool_)
for ty in (types.integer_domain|types.real_domain)]


class _Hsa_ds_permuting(ConcreteTemplate):
    """Typing cases shared by the ds_permute-style intrinsics.

    Every case has the shape ``signature(return, index, source)``: the index
    may be int32 or int64, the source may be int32 or float32, and the return
    type is always the same as the source type.
    """
    # One case per (source type, index type) pair, generated in the order
    # int32/int32, int32/int64, float32/int32, float32/int64.
    cases = [signature(src_ty, idx_ty, src_ty)
             for src_ty in (types.int32, types.float32)
             for idx_ty in (types.int32, types.int64)]
    # NOTE(review): presumably stops case matching from applying unsafe
    # conversions to the arguments -- confirm against ConcreteTemplate.
    unsafe_casting = False


@intrinsic
class Hsa_ds_permute(ConcreteTemplate):
class Hsa_ds_permute(_Hsa_ds_permuting):
key = roc.ds_permute
cases = [signature(types.int32, types.int32, types.int32)]


@intrinsic
class Hsa_ds_bpermute(ConcreteTemplate):
class Hsa_ds_bpermute(_Hsa_ds_permuting):
key = roc.ds_bpermute
cases = [signature(types.int32, types.int32, types.int32)]


# hsa.shared submodule -------------------------------------------------------

Expand Down
13 changes: 9 additions & 4 deletions numba/roc/hsaimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,24 @@ def _impl(context, builder, sig, args):
"""
args are (index, src)
"""
assert sig.args[0] == sig.args[1]
assert sig.return_type == sig.args[1]
idx, src = args
i32 = Type.int(32)
fnty = Type.function(i32, [i32, i32])
fn = builder.module.declare_intrinsic(intrinsic_name, fnty=fnty)
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the idx might be an int64; it is ok to truncate it to int32 as
# wavefront_size is never likely to overflow an int32
idx = builder.trunc(idx, i32)
four = lc.Constant.int(i32, 4)
idx = builder.mul(idx, four)
return builder.call(fn, (idx, src))
# the bitcast lets a float32 travel as a packed i32; the result is cast back on return
result = builder.call(fn, (idx, builder.bitcast(src, i32)))
return builder.bitcast(result, context.get_value_type(sig.return_type))
return _impl

lower(stubs.ds_permute, types.int32, types.int32)(_gen_ds_permute('llvm.amdgcn.ds.permute'))
lower(stubs.ds_bpermute, types.int32, types.int32)(_gen_ds_permute('llvm.amdgcn.ds.bpermute'))
lower(stubs.ds_permute, types.Any, types.Any)(_gen_ds_permute('llvm.amdgcn.ds.permute'))
lower(stubs.ds_bpermute, types.Any, types.Any)(_gen_ds_permute('llvm.amdgcn.ds.bpermute'))

@lower(stubs.atomic.add, types.Array, types.intp, types.Any)
@lower(stubs.atomic.add, types.Array,
Expand Down
110 changes: 110 additions & 0 deletions numba/roc/tests/hsapy/test_intrinsics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import print_function, absolute_import, division

import numpy as np

from numba import unittest_support as unittest
from numba import roc
from numba.errors import TypingError
import operator as oper

# Wavefront size reported by the agent of the current ROC context, queried
# once at import time. The tests below use it as the launch size, the length
# of the input/output arrays and the modulus for permute indices.
_WAVESIZE = roc.get_context().agent.wavefront_size

@roc.jit(device=True)
def shuffle_up(val, width):
    """Permute ``val`` using the index ``(local_id + width) % _WAVESIZE``.

    Issues a wave barrier, then hands that index and ``val`` to
    ``roc.ds_permute`` and returns its result. The tests in this module
    expect the values gathered over one wavefront to equal
    ``np.roll(values, width)``.
    """
    lane = roc.get_local_id(0)
    roc.wavebarrier()
    # wrap around so the index always stays inside the wavefront
    return roc.ds_permute((lane + width) % _WAVESIZE, val)

@roc.jit(device=True)
def shuffle_down(val, width):
    """Permute ``val`` using the index ``(local_id - width) % _WAVESIZE``.

    Issues a wave barrier, then hands that index and ``val`` to
    ``roc.ds_permute`` and returns its result. The tests in this module
    expect the values gathered over one wavefront to equal
    ``np.roll(values, -width)``.
    """
    lane = roc.get_local_id(0)
    roc.wavebarrier()
    # wrap around so the index always stays inside the wavefront
    return roc.ds_permute((lane - width) % _WAVESIZE, val)

@roc.jit(device=True)
def broadcast(val, from_lane):
    """Return ``roc.ds_bpermute(from_lane, val)`` after a wave barrier.

    ``from_lane`` is passed as the index argument and ``val`` as the source
    argument. The tests in this module launch one wavefront with the same
    ``from_lane`` in every work-item and expect each output element to equal
    the input element at position ``from_lane``.
    """
    # The local id is not needed here: the index comes from the caller.
    roc.wavebarrier()
    return roc.ds_bpermute(from_lane, val)

def gen_kernel(shuffunc):
    """Build a ROC kernel that applies ``shuffunc`` once per work-item.

    The returned kernel takes ``(inp, outp, amount)``. Each work-item reads
    ``inp`` at its local id, calls ``shuffunc(value, amount)`` and stores the
    result in ``outp`` at the same position.
    """
    def kernel(inp, outp, amount):
        lid = roc.get_local_id(0)
        outp[lid] = shuffunc(inp[lid], amount)

    return roc.jit(kernel)


class TestDsPermute(unittest.TestCase):
    # Exercises roc.ds_permute / roc.ds_bpermute through small kernels that
    # are launched on a single work-group of _WAVESIZE work-items.

    def _check_rotation(self, inp):
        # Run both shuffle kernels for every offset in [-_WAVESIZE, _WAVESIZE)
        # and compare the output with the matching np.roll of the input.
        outp = np.zeros_like(inp)
        expectations = ((shuffle_down, oper.neg), (shuffle_up, oper.pos))
        for shuffler, sign in expectations:
            kernel = gen_kernel(shuffler)
            for shift in range(-_WAVESIZE, _WAVESIZE):
                kernel[1, _WAVESIZE](inp, outp, shift)
                np.testing.assert_allclose(outp, np.roll(inp, sign(shift)))

    @staticmethod
    def _make_broadcast_kernel():
        # Kernel storing broadcast(inp[local_id], lane) into outp[local_id].
        def kernel(inp, outp, lane):
            lid = roc.get_local_id(0)
            outp[lid] = broadcast(inp[lid], lane)

        return roc.jit(kernel)

    def test_ds_permute(self):
        # int32 payload
        self._check_rotation(np.arange(_WAVESIZE).astype(np.int32))

    def test_ds_permute_random_floats(self):
        # float32 payload
        self._check_rotation(np.linspace(0, 1, _WAVESIZE).astype(np.float32))

    def test_ds_permute_type_safety(self):
        """Check that float64 values are rejected, not downcast to float32."""
        kernel = gen_kernel(shuffle_down)
        inp = np.linspace(0, 1, _WAVESIZE).astype(np.float64)
        outp = np.zeros_like(inp)
        with self.assertRaises(TypingError) as raised:
            kernel[1, _WAVESIZE](inp, outp, 1)
        errmsg = raised.exception.msg
        self.assertIn('Invalid use of Function', errmsg)
        self.assertIn('with argument(s) of type(s): (float64, int64)', errmsg)

    def test_ds_bpermute(self):
        kernel = self._make_broadcast_kernel()
        inp = np.arange(_WAVESIZE).astype(np.int32)
        outp = np.zeros_like(inp)
        for lane in range(_WAVESIZE):
            kernel[1, _WAVESIZE](inp, outp, lane)
            # inp is arange, so the value held by `lane` is `lane` itself
            np.testing.assert_allclose(outp, lane)

    def test_ds_bpermute_random_floats(self):
        kernel = self._make_broadcast_kernel()
        inp = np.linspace(0, 1, _WAVESIZE).astype(np.float32)
        outp = np.zeros_like(inp)
        for lane in range(_WAVESIZE):
            kernel[1, _WAVESIZE](inp, outp, lane)
            np.testing.assert_allclose(outp, inp[lane])


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
85 changes: 45 additions & 40 deletions numba/roc/tests/hsapy/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from numba import unittest_support as unittest
from numba import roc, intp
from numba import roc, intp, int32


@roc.jit(device=True)
Expand Down Expand Up @@ -158,43 +158,48 @@ def shuffle_up(val, width):
res = roc.ds_permute(idx, val)
return res

@roc.jit(device=True)
def shuf_wave_inclusive_scan(val):
tid = roc.get_local_id(0)
lane = tid & (_WARPSIZE - 1)
def make_inclusive_scan(dtype):
@roc.jit(device=True)
def shuf_wave_inclusive_scan(val):
tid = roc.get_local_id(0)
lane = tid & (_WARPSIZE - 1)

roc.wavebarrier()
shuf = shuffle_up(val, 1)
if lane >= 1:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 1)
if lane >= 1:
val = dtype(val + shuf)

roc.wavebarrier()
shuf = shuffle_up(val, 2)
if lane >= 2:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 2)
if lane >= 2:
val = dtype(val + shuf)

roc.wavebarrier()
shuf = shuffle_up(val, 4)
if lane >= 4:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 4)
if lane >= 4:
val = dtype(val + shuf)

roc.wavebarrier()
shuf = shuffle_up(val, 8)
if lane >= 8:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 8)
if lane >= 8:
val = dtype(val + shuf)

roc.wavebarrier()
shuf = shuffle_up(val, 16)
if lane >= 16:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 16)
if lane >= 16:
val = dtype(val + shuf)

roc.wavebarrier()
shuf = shuffle_up(val, 32)
if lane >= 32:
val += shuf
roc.wavebarrier()
shuf = shuffle_up(val, 32)
if lane >= 32:
val = dtype(val + shuf)

roc.wavebarrier()
return val
roc.wavebarrier()
return val
return shuf_wave_inclusive_scan


shuf_wave_inclusive_scan_int32 = make_inclusive_scan(int32)


@roc.jit(device=True)
Expand All @@ -212,7 +217,7 @@ def shuf_device_inclusive_scan(data, temp):
warpid = tid >> 6

# Scan warps in parallel
warp_scan_res = shuf_wave_inclusive_scan(data)
warp_scan_res = shuf_wave_inclusive_scan_int32(data)

roc.barrier()

Expand All @@ -224,7 +229,7 @@ def shuf_device_inclusive_scan(data, temp):

# Scan the partial sum by first wave
if warpid == 0:
shuf_wave_inclusive_scan(temp[lane])
shuf_wave_inclusive_scan_int32(temp[lane])

roc.barrier()

Expand Down Expand Up @@ -396,10 +401,10 @@ def foo(inp, mask, out):
tid = roc.get_local_id(0)
out[tid] = roc.ds_permute(inp[tid], mask[tid])

inp = np.arange(64, dtype=np.intp)
inp = np.arange(64, dtype=np.int32)
np.random.seed(0)
for i in range(10):
mask = np.random.randint(0, inp.size, inp.size).astype(np.uint32)
mask = np.random.randint(0, inp.size, inp.size).astype(np.int32)
out = np.zeros_like(inp)
foo[1, 64](inp, mask, out)
np.testing.assert_equal(inp[mask], out)
Expand All @@ -410,7 +415,7 @@ def foo(inp, out):
gid = roc.get_global_id(0)
out[gid] = shuffle_up(inp[gid], 1)

inp = np.arange(128, dtype=np.intp)
inp = np.arange(128, dtype=np.int32)
out = np.zeros_like(inp)
foo[1, 128](inp, out)

Expand All @@ -425,9 +430,9 @@ def test_shuf_wave_inclusive_scan(self):
@roc.jit
def foo(inp, out):
gid = roc.get_global_id(0)
out[gid] = shuf_wave_inclusive_scan(inp[gid])
out[gid] = shuf_wave_inclusive_scan_int32(inp[gid])

inp = np.arange(64, dtype=np.intp)
inp = np.arange(64, dtype=np.int32)
out = np.zeros_like(inp)
foo[1, 64](inp, out)
np.testing.assert_equal(inp.cumsum(), out)
Expand All @@ -436,10 +441,10 @@ def test_shuf_device_inclusive_scan(self):
@roc.jit
def foo(inp, out):
gid = roc.get_global_id(0)
temp = roc.shared.array(2, dtype=intp)
temp = roc.shared.array(2, dtype=int32)
out[gid] = shuf_device_inclusive_scan(inp[gid], temp)

inp = np.arange(128, dtype=np.intp)
inp = np.arange(128, dtype=np.int32)
out = np.zeros_like(inp)

foo[1, inp.size](inp, out)
Expand Down