Skip to content

Commit

Permalink
Merge pull request #8154 from testhound/testhound/native-cast-8138
Browse files Browse the repository at this point in the history
Testhound/native cast 8138
  • Loading branch information
esc committed Jul 4, 2022
2 parents 3023967 + 3c9cb3a commit 15e4878
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
12 changes: 12 additions & 0 deletions numba/cuda/cudadecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,18 @@ class Cuda_fp16_binary(ConcreteTemplate):
return Cuda_fp16_binary


@register_global(float)
class Float(AbstractTemplate):
    """Typing template for the builtin ``float()`` call on the CUDA target.

    Resolves ``float(x)`` only when *x* is a ``float16``; the returned
    signature keeps the value as ``float16`` (identity cast).  All other
    argument types fall through (``generic`` returns ``None``) to the
    default typing machinery.
    """

    def generic(self, args, kws):
        # float() takes no keyword arguments.
        assert not kws
        (arg,) = args
        return signature(arg, arg) if arg == types.float16 else None

def _genfp16_binary_comparison(l_key):
@register
class Cuda_fp16_cmp(ConcreteTemplate):
Expand Down
13 changes: 13 additions & 0 deletions numba/cuda/tests/cudapy/test_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
from numba.core import types
from numba.cuda.testing import (CUDATestCase, skip_on_cudasim,
skip_unless_cc_53)
from numba.types import float16, float32
import itertools
import unittest


def native_cast(x):
    """Kernel body used to exercise the builtin ``float()`` cast."""
    result = float(x)
    return result


def to_int8(x):
    """Kernel body casting *x* to a NumPy 8-bit signed integer."""
    converted = np.int8(x)
    return converted

Expand Down Expand Up @@ -234,6 +239,14 @@ def test_float_to_complex(self):
np.testing.assert_allclose(cfunc(-3.21),
pyfunc(fromty(-3.21)) + 0j)

@skip_on_cudasim('Compilation unsupported in the simulator')
def test_native_cast(self):
    """float() of a float16/float32 argument lowers to a same-width store.

    Compiles ``native_cast`` for each dtype and checks the PTX contains a
    store of the matching width (f16 values are stored as u16 in PTX),
    i.e. no unintended widening occurred.
    """
    cases = (
        (float32, "st.f32"),
        (float16, "st.u16"),
    )
    for argty, expected_store in cases:
        ptx, _ = cuda.compile_ptx(native_cast, (argty,), device=True)
        self.assertIn(expected_store, ptx)


if __name__ == '__main__':
unittest.main()

0 comments on commit 15e4878

Please sign in to comment.