Skip to content

Commit

Permalink
Merge pull request #20297 from charris/backport-20153
Browse files Browse the repository at this point in the history
BUG, SIMD: Fix 64-bit/8-bit integer division by a scalar
  • Loading branch information
charris committed Nov 4, 2021
2 parents 6b3d17e + 86a8ae5 commit 3d1487b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 53 deletions.
7 changes: 4 additions & 3 deletions numpy/core/src/common/simd/intdiv.h
Expand Up @@ -162,11 +162,12 @@ NPY_FINLINE npy_uint64 npyv__divh128_u64(npy_uint64 high, npy_uint64 divisor)
npy_uint32 divisor_hi = divisor >> 32;
npy_uint32 divisor_lo = divisor & 0xFFFFFFFF;
// compute high quotient digit
npy_uint32 quotient_hi = (npy_uint32)(high / divisor_hi);
npy_uint64 quotient_hi = high / divisor_hi;
npy_uint64 remainder = high - divisor_hi * quotient_hi;
npy_uint64 base32 = 1ULL << 32;
while (quotient_hi >= base32 || quotient_hi*divisor_lo > base32*remainder) {
remainder += --divisor_hi;
--quotient_hi;
remainder += divisor_hi;
if (remainder >= base32) {
break;
}
Expand Down Expand Up @@ -200,7 +201,7 @@ NPY_FINLINE npyv_u8x3 npyv_divisor_u8(npy_uint8 d)
default:
l = npyv__bitscan_revnz_u32(d - 1) + 1; // ceil(log2(d))
l2 = (npy_uint8)(1 << l); // 2^l, overflow to 0 if l = 8
m = ((l2 - d) << 8) / d + 1; // multiplier
m = ((npy_uint16)((l2 - d) << 8)) / d + 1; // multiplier
sh1 = 1; sh2 = l - 1; // shift counts
}
npyv_u8x3 divisor;
Expand Down
75 changes: 25 additions & 50 deletions numpy/core/tests/test_simd.py
Expand Up @@ -329,7 +329,7 @@ def test_square(self):
data_square = [x*x for x in data]
square = self.square(vdata)
assert square == data_square

def test_max(self):
"""
Test intrinsics:
Expand Down Expand Up @@ -818,6 +818,7 @@ def test_arithmetic_intdiv(self):
if self._is_fp():
return

int_min = self._int_min()
def trunc_div(a, d):
"""
Divide towards zero works with large integers > 2^53,
Expand All @@ -830,57 +831,31 @@ def trunc_div(a, d):
return a // d
return (a + sign_d - sign_a) // d + 1

int_min = self._int_min() if self._is_signed() else 1
int_max = self._int_max()
rdata = (
0, 1, self.nlanes, int_max-self.nlanes,
int_min, int_min//2 + 1
)
divisors = (1, 2, 9, 13, self.nlanes, int_min, int_max, int_max//2)

for x, d in itertools.product(rdata, divisors):
data = self._data(x)
vdata = self.load(data)
data_divc = [trunc_div(a, d) for a in data]
divisor = self.divisor(d)
divc = self.divc(vdata, divisor)
assert divc == data_divc

if not self._is_signed():
return

safe_neg = lambda x: -x-1 if -x > int_max else -x
# test round division for signed integers
for x, d in itertools.product(rdata, divisors):
d_neg = safe_neg(d)
data = self._data(x)
data_neg = [safe_neg(a) for a in data]
vdata = self.load(data)
vdata_neg = self.load(data_neg)
divisor = self.divisor(d)
divisor_neg = self.divisor(d_neg)

# round towards zero
data_divc = [trunc_div(a, d_neg) for a in data]
divc = self.divc(vdata, divisor_neg)
assert divc == data_divc
data_divc = [trunc_div(a, d) for a in data_neg]
divc = self.divc(vdata_neg, divisor)
data = [1, -int_min] # to test overflow
data += range(0, 2**8, 2**5)
data += range(0, 2**8, 2**5-1)
bsize = self._scalar_size()
if bsize > 8:
data += range(2**8, 2**16, 2**13)
data += range(2**8, 2**16, 2**13-1)
if bsize > 16:
data += range(2**16, 2**32, 2**29)
data += range(2**16, 2**32, 2**29-1)
if bsize > 32:
data += range(2**32, 2**64, 2**61)
data += range(2**32, 2**64, 2**61-1)
# negate
data += [-x for x in data]
for dividend, divisor in itertools.product(data, data):
divisor = self.setall(divisor)[0] # cast
if divisor == 0:
continue
dividend = self.load(self._data(dividend))
data_divc = [trunc_div(a, divisor) for a in dividend]
divisor_parms = self.divisor(divisor)
divc = self.divc(dividend, divisor_parms)
assert divc == data_divc

# test truncate sign if the dividend is zero
vzero = self.zero()
for d in (-1, -10, -100, int_min//2, int_min):
divisor = self.divisor(d)
divc = self.divc(vzero, divisor)
assert divc == vzero

# test overflow
vmin = self.setall(int_min)
divisor = self.divisor(-1)
divc = self.divc(vmin, divisor)
assert divc == vmin

def test_arithmetic_reduce_sum(self):
"""
Test reduce sum intrinsics:
Expand Down

0 comments on commit 3d1487b

Please sign in to comment.