Skip to content

Commit

Permalink
Merge pull request #4780 from ivirshup/gcd
Browse files Browse the repository at this point in the history
Implement np.gcd and math.gcd
  • Loading branch information
stuartarchibald committed Dec 4, 2019
2 parents 78eb7a3 + c2fc751 commit 9e12ffe
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ Math operations
square Yes Yes
reciprocal Yes Yes
conjugate Yes Yes
gcd Yes Yes
============== ============= ===============


Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/pysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ The following functions from the :mod:`math` module are supported:
* :func:`math.floor`
* :func:`math.frexp`
* :func:`math.gamma`
* :func:`math.gcd`
* :func:`math.hypot`
* :func:`math.isfinite`
* :func:`math.isinf`
Expand Down
52 changes: 51 additions & 1 deletion numba/targets/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from llvmlite.llvmpy.core import Type

from numba.targets.imputils import Registry, impl_ret_untracked
from numba import types, cgutils, utils, config
from numba import types, typeof, cgutils, utils, config
from numba.extending import overload
from numba.typing import signature
from numba.unsafe.numbers import trailing_zeros


registry = Registry()
Expand Down Expand Up @@ -410,3 +412,51 @@ def degrees_float_impl(context, builder, sig, args):
def pow_impl(context, builder, sig, args):
impl = context.get_function(operator.pow, sig)
return impl(builder, args)

# -----------------------------------------------------------------------------


def _unsigned(T):
"""Convert integer to unsigned integer of equivalent width."""
pass

@overload(_unsigned)
def _unsigned_impl(T):
if T in types.unsigned_domain:
return lambda T: T
elif T in types.signed_domain:
newT = getattr(types, 'uint{}'.format(T.bitwidth))
return lambda T: newT(T)


def gcd_impl(context, builder, sig, args):
xty, yty = sig.args
assert xty == yty == sig.return_type
x, y = args

def gcd(a, b):
"""
Stein's algorithm, heavily cribbed from Julia implementation.
"""
T = type(a)
if a == 0: return abs(b)
if b == 0: return abs(a)
za = trailing_zeros(a)
zb = trailing_zeros(b)
k = min(za, zb)
# Uses np.*_shift instead of operators due to return types
u = _unsigned(abs(np.right_shift(a, za)))
v = _unsigned(abs(np.right_shift(b, zb)))
while u != v:
if u > v:
u, v = v, u
v -= u
v = np.right_shift(v, trailing_zeros(v))
r = np.left_shift(T(u), k)
return r

res = context.compile_internal(builder, gcd, sig, args)
return impl_ret_untracked(context, builder, sig.return_type, res)

if utils.PYVERSION >= (3, 5):
lower(math.gcd, types.Integer, types.Integer)(gcd_impl)
8 changes: 8 additions & 0 deletions numba/targets/npyfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,14 @@ def np_complex_power_impl(context, builder, sig, args):
return numbers.complex_power_impl(context, builder, sig, args)


########################################################################
# numpy greatest common denominator

def np_gcd_impl(context, builder, sig, args):
_check_arity_and_homogeneity(sig, args, 2)
return mathimpl.gcd_impl(context, builder, sig, args)


########################################################################
# Numpy style complex sign

Expand Down
14 changes: 14 additions & 0 deletions numba/targets/ufunc_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,20 @@ def _fill_ufunc_db(ufunc_db):
'DD->D': npyfuncs.np_complex_power_impl,
}

if v >= (1, 15):
ufunc_db[np.gcd] = {
'bb->b': npyfuncs.np_gcd_impl,
'BB->B': npyfuncs.np_gcd_impl,
'hh->h': npyfuncs.np_gcd_impl,
'HH->H': npyfuncs.np_gcd_impl,
'ii->i': npyfuncs.np_gcd_impl,
'II->I': npyfuncs.np_gcd_impl,
'll->l': npyfuncs.np_gcd_impl,
'LL->L': npyfuncs.np_gcd_impl,
'qq->q': npyfuncs.np_gcd_impl,
'QQ->Q': npyfuncs.np_gcd_impl,
}

ufunc_db[np.rint] = {
'f->f': npyfuncs.np_real_rint_impl,
'd->d': npyfuncs.np_real_rint_impl,
Expand Down
7 changes: 7 additions & 0 deletions numba/tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ def test_abs(self, flags=enable_pyobj_flags):
for x in complex_values:
self.assertPreciseEqual(cfunc(x), pyfunc(x))

for unsigned_type in types.unsigned_domain:
unsigned_values = [0, 10, 2, 2 ** unsigned_type.bitwidth - 1]
cr = compile_isolated(pyfunc, (unsigned_type,), flags=flags)
cfunc = cr.entry_point
for x in unsigned_values:
self.assertPreciseEqual(cfunc(x), pyfunc(x))

@tag('important')
def test_abs_npm(self):
self.test_abs(flags=no_pyobj_flags)
Expand Down
19 changes: 19 additions & 0 deletions numba/tests/test_mathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def lgamma(x):
def pow(x, y):
return math.pow(x, y)

def gcd(x, y):
return math.gcd(x, y)

def copysign(x, y):
return math.copysign(x, y)
Expand Down Expand Up @@ -623,6 +625,23 @@ def test_pow(self, flags=enable_pyobj_flags):
y_values = [x * 2 for x in x_values]
self.run_binary(pyfunc, x_types, x_values, y_values, flags)

@unittest.skipIf(utils.PYVERSION < (3, 5), "gcd added in Python 3.5")
def test_gcd(self, flags=enable_pyobj_flags):
from itertools import product, repeat, chain
pyfunc = gcd
signed_args = product(
sorted(types.signed_domain), *repeat((-2, -1, 0, 1, 2, 7, 10), 2)
)
unsigned_args = product(
sorted(types.unsigned_domain), *repeat((0, 1, 2, 7, 9, 16), 2)
)
x_types, x_values, y_values = zip(*chain(signed_args, unsigned_args))
self.run_binary(pyfunc, x_types, x_values, y_values, flags)

@unittest.skipIf(utils.PYVERSION < (3, 5), "gcd added in Python 3.5")
def test_gcd_npm(self):
self.test_gcd(flags=no_pyobj_flags)

@tag('important')
def test_pow_npm(self):
self.test_pow(flags=no_pyobj_flags)
Expand Down
4 changes: 4 additions & 0 deletions numba/tests/test_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ def test_power_ufunc(self, flags=no_pyobj_flags):
self.binary_ufunc_test(np.power, flags=flags,
positive_only=after_numpy_112)

def test_gcd_ufunc(self, flags=no_pyobj_flags):
if numpy_support.version >= (1, 15):
self.binary_ufunc_test(np.gcd, flags=flags, kinds="iu")

@tag('important')
def test_remainder_ufunc(self, flags=no_pyobj_flags):
self.binary_ufunc_test(np.remainder, flags=flags)
Expand Down
61 changes: 61 additions & 0 deletions numba/tests/test_unsafe_intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from numba.unsafe.ndarray import to_fixed_tuple, empty_inferred
from numba.unsafe.bytes import memcpy_region
from numba.unsafe.refcount import dump_refcount
from numba.unsafe.numbers import trailing_zeros, leading_zeros
from numba.errors import TypingError


Expand Down Expand Up @@ -152,3 +153,63 @@ def use_dump_refcount():
tupty = types.Tuple.from_types([aryty] * 2)
self.assertIn(pat.format(aryty), output)
self.assertIn(pat.format(tupty), output)


class TestZeroCounts(TestCase):
def test_zero_count(self):
lz = njit(lambda x: leading_zeros(x))
tz = njit(lambda x: trailing_zeros(x))

evens = [2, 42, 126, 128]

for T in types.unsigned_domain:
self.assertTrue(tz(T(0)) == lz(T(0)) == T.bitwidth)
for i in range(T.bitwidth):
val = T(2 ** i)
self.assertEqual(lz(val) + tz(val) + 1, T.bitwidth)
for n in evens:
self.assertGreater(tz(T(n)), 0)
self.assertEqual(tz(T(n + 1)), 0)

for T in types.signed_domain:
self.assertTrue(tz(T(0)) == lz(T(0)) == T.bitwidth)
for i in range(T.bitwidth - 1):
val = T(2 ** i)
self.assertEqual(lz(val) + tz(val) + 1, T.bitwidth)
self.assertEqual(lz(-val), 0)
self.assertEqual(tz(val), tz(-val))
for n in evens:
self.assertGreater(tz(T(n)), 0)
self.assertEqual(tz(T(n + 1)), 0)

def check_error_msg(self, func):
cfunc = njit(lambda *x: func(*x))
func_name = func._name

unsupported_types = filter(
lambda x: not isinstance(x, types.Integer), types.number_domain
)
for typ in unsupported_types:
with self.assertRaises(TypingError) as e:
cfunc(typ(2))
self.assertIn(
"{} is only defined for integers, but passed value was '{}'."
.format(func_name, typ),
str(e.exception),
)

# Testing w/ too many arguments
arg_cases = [(1, 2), ()]
for args in arg_cases:
with self.assertRaises(TypingError) as e:
cfunc(*args)
self.assertIn(
"Invalid use of Function({})".format(str(func)),
str(e.exception)
)

def test_trailing_zeros_error(self):
self.check_error_msg(trailing_zeros)

def test_leading_zeros_error(self):
self.check_error_msg(leading_zeros)
3 changes: 2 additions & 1 deletion numba/typing/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def generic(self, args, kws):
@infer_global(abs)
class Abs(ConcreteTemplate):
int_cases = [signature(ty, ty) for ty in sorted(types.signed_domain)]
uint_cases = [signature(ty, ty) for ty in sorted(types.unsigned_domain)]
real_cases = [signature(ty, ty) for ty in sorted(types.real_domain)]
complex_cases = [signature(ty.underlying_float, ty)
for ty in sorted(types.complex_domain)]
cases = int_cases + real_cases + complex_cases
cases = int_cases + uint_cases + real_cases + complex_cases


@infer_global(slice)
Expand Down
16 changes: 16 additions & 0 deletions numba/typing/mathdecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ class Math_pow(ConcreteTemplate):
signature(types.float64, types.float64, types.float64),
]


if utils.PYVERSION >= (3, 5):
@infer_global(math.gcd)
class Math_gcd(ConcreteTemplate):
cases = [
signature(types.int64, types.int64, types.int64),
signature(types.int32, types.int32, types.int32),
signature(types.int16, types.int16, types.int16),
signature(types.int8, types.int8, types.int8),
signature(types.uint64, types.uint64, types.uint64),
signature(types.uint32, types.uint32, types.uint32),
signature(types.uint16, types.uint16, types.uint16),
signature(types.uint8, types.uint8, types.uint8),
]


@infer_global(math.frexp)
class Math_frexp(ConcreteTemplate):
cases = [
Expand Down
3 changes: 3 additions & 0 deletions numba/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ def generic(self, args, kws):
"sqrt", "square", "reciprocal",
"divide", "mod", "abs", "fabs" ]

if numpy_version >= (1, 15):
_math_operations.append("gcd")

_trigonometric_functions = [ "sin", "cos", "tan", "arcsin",
"arccos", "arctan", "arctan2",
"hypot", "sinh", "cosh", "tanh",
Expand Down
30 changes: 30 additions & 0 deletions numba/unsafe/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,33 @@ def codegen(cgctx, builder, typ, args):
retty = viewty.dtype
sig = retty(val, viewty)
return sig, codegen


@intrinsic
def trailing_zeros(typeingctx, src):
"""Counts trailing zeros in the binary representation of an integer."""
if not isinstance(src, types.Integer):
raise TypeError(
"trailing_zeros is only defined for integers, but passed value was"
" '{}'.".format(src)
)

def codegen(context, builder, signature, args):
[src] = args
return builder.cttz(src, ir.Constant(ir.IntType(1), 0))
return src(src), codegen


@intrinsic
def leading_zeros(typeingctx, src):
"""Counts leading zeros in the binary representation of an integer."""
if not isinstance(src, types.Integer):
raise TypeError(
"leading_zeros is only defined for integers, but passed value was "
"'{}'.".format(src)
)

def codegen(context, builder, signature, args):
[src] = args
return builder.ctlz(src, ir.Constant(ir.IntType(1), 0))
return src(src), codegen

0 comments on commit 9e12ffe

Please sign in to comment.