Skip to content

Commit

Permalink
BUG, SIMD: Fix invalid value encountered in rint/trunc/ceil/floor on …
Browse files Browse the repository at this point in the history
…armhf/neon
  • Loading branch information
seiko2plus authored and charris committed Dec 19, 2022
1 parent 783b6de commit 50989d8
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 68 deletions.
149 changes: 81 additions & 68 deletions numpy/core/src/common/simd/neon/math.h
Expand Up @@ -278,20 +278,25 @@ NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a)
return vrndnq_f32(a);
#else
// ARMv7 NEON only supports fp to int truncate conversion.
// a magic trick of adding 1.5 * 2**23 is used for rounding
// a magic trick of adding 1.5 * 2^23 is used for rounding
// to nearest even and then subtract this magic number to get
// the integer.
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
const npyv_f32 magic = vdupq_n_f32(12582912.0f); // 1.5 * 2**23
npyv_f32 round = vsubq_f32(vaddq_f32(a, magic), magic);
npyv_b32 overflow = vcleq_f32(vabsq_f32(a), vreinterpretq_f32_u32(vdupq_n_u32(0x4b000000)));
round = vbslq_f32(overflow, round, a);
// signed zero
round = vreinterpretq_f32_s32(vorrq_s32(
vreinterpretq_s32_f32(round),
vandq_s32(vreinterpretq_s32_f32(a), szero)
));
return round;
//
const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f));
const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero);
const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23
const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23
npyv_u32 nnan_mask = vceqq_f32(a, a);
// eliminate nans to avoid invalid fp errors
npyv_f32 abs_x = vabsq_f32(vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a))));
// round by adding and then subtracting the magic number 1.5 * 2^23
npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h);
// copysign
round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask ));
// return a unchanged if |a| > 2^23 or a is NaN
npyv_u32 mask = vcleq_f32(abs_x, two_power_23);
mask = vandq_u32(mask, nnan_mask);
return vbslq_f32(mask, round, a);
#endif
}
#if NPY_SIMD_F64
Expand All @@ -302,33 +307,30 @@ NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a)
#ifdef NPY_HAVE_ASIMD
#define npyv_ceil_f32 vrndpq_f32
#else
NPY_FINLINE npyv_f32 npyv_ceil_f32(npyv_f32 a)
{
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
NPY_FINLINE npyv_f32 npyv_ceil_f32(npyv_f32 a)
{
const npyv_u32 one = vreinterpretq_u32_f32(vdupq_n_f32(1.0f));
const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
/**
* On armv7, vcvtq.f32 handles special cases as follows:
* NaN returns 0
* +inf or +outrange returns 0x7fffffff (nan when reinterpreted as f32)
* -inf or -outrange returns 0x80000000 (-0.0f when reinterpreted as f32)
*/
npyv_s32 roundi = vcvtq_s32_f32(a);
npyv_f32 round = vcvtq_f32_s32(roundi);
const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f));
const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero);
const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23
const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23
npyv_u32 nnan_mask = vceqq_f32(a, a);
npyv_f32 x = vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a)));
// eliminate nans to avoid invalid fp errors
npyv_f32 abs_x = vabsq_f32(x);
// round by adding and then subtracting the magic number 1.5 * 2^23
npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h);
// copysign
round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask));
npyv_f32 ceil = vaddq_f32(round, vreinterpretq_f32_u32(
vandq_u32(vcltq_f32(round, a), one))
);
// respect signed zero, e.g. -0.5 -> -0.0
npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
vreinterpretq_s32_f32(ceil),
vandq_s32(vreinterpretq_s32_f32(a), szero)
));
// if nan or overflow return a
npyv_u32 nnan = npyv_notnan_f32(a);
npyv_u32 overflow = vorrq_u32(
vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
vandq_u32(vcltq_f32(round, x), one))
);
return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
// respects signed zero
ceil = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(ceil), sign_mask));
// return a unchanged if |a| > 2^23 or a is NaN
npyv_u32 mask = vcleq_f32(abs_x, two_power_23);
mask = vandq_u32(mask, nnan_mask);
return vbslq_f32(mask, ceil, a);
}
#endif
#if NPY_SIMD_F64
Expand All @@ -339,29 +341,37 @@ NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a)
#ifdef NPY_HAVE_ASIMD
#define npyv_trunc_f32 vrndq_f32
#else
NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
{
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
{
const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
const npyv_u32 exp_mask = vdupq_n_u32(0xff000000);
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
const npyv_u32 sign_mask = vandq_u32(
vreinterpretq_u32_f32(a), vreinterpretq_u32_s32(szero));

npyv_u32 nfinite_mask = vshlq_n_u32(vreinterpretq_u32_f32(a), 1);
nfinite_mask = vandq_u32(nfinite_mask, exp_mask);
nfinite_mask = vceqq_u32(nfinite_mask, exp_mask);
// eliminate nans/infs to avoid invalid fp errors
npyv_f32 x = vreinterpretq_f32_u32(
veorq_u32(nfinite_mask, vreinterpretq_u32_f32(a)));
/**
* On armv7, vcvtq.f32 handles special cases as follows:
* NaN returns 0
* +inf or +outrange returns 0x7fffffff (nan when reinterpreted as f32)
* -inf or -outrange returns 0x80000000 (-0.0f when reinterpreted as f32)
*/
npyv_s32 roundi = vcvtq_s32_f32(a);
npyv_f32 round = vcvtq_f32_s32(roundi);
npyv_s32 trunci = vcvtq_s32_f32(x);
npyv_f32 trunc = vcvtq_f32_s32(trunci);
// respect signed zero, e.g. -0.5 -> -0.0
npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
vreinterpretq_s32_f32(round),
vandq_s32(vreinterpretq_s32_f32(a), szero)
));
// if nan or overflow return a
npyv_u32 nnan = npyv_notnan_f32(a);
npyv_u32 overflow = vorrq_u32(
vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
trunc = vreinterpretq_f32_u32(
vorrq_u32(vreinterpretq_u32_f32(trunc), sign_mask));
// if overflow return a
npyv_u32 overflow_mask = vorrq_u32(
vceqq_s32(trunci, szero), vceqq_s32(trunci, max_int)
);
return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
// return a unchanged if the conversion overflowed or a is non-finite
return vbslq_f32(vorrq_u32(nfinite_mask, overflow_mask), a, trunc);
}
#endif
#if NPY_SIMD_F64
Expand All @@ -372,28 +382,31 @@ NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a)
#ifdef NPY_HAVE_ASIMD
#define npyv_floor_f32 vrndmq_f32
#else
NPY_FINLINE npyv_f32 npyv_floor_f32(npyv_f32 a)
{
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
NPY_FINLINE npyv_f32 npyv_floor_f32(npyv_f32 a)
{
const npyv_u32 one = vreinterpretq_u32_f32(vdupq_n_f32(1.0f));
const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f));
const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero);
const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23
const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23

npyv_s32 roundi = vcvtq_s32_f32(a);
npyv_f32 round = vcvtq_f32_s32(roundi);
npyv_u32 nnan_mask = vceqq_f32(a, a);
npyv_f32 x = vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a)));
// eliminate nans to avoid invalid fp errors
npyv_f32 abs_x = vabsq_f32(x);
// round by adding and then subtracting the magic number 1.5 * 2^23
npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h);
// copysign
round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask));
npyv_f32 floor = vsubq_f32(round, vreinterpretq_f32_u32(
vandq_u32(vcgtq_f32(round, a), one)
));
// respect signed zero
npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
vreinterpretq_s32_f32(floor),
vandq_s32(vreinterpretq_s32_f32(a), szero)
vandq_u32(vcgtq_f32(round, x), one)
));
npyv_u32 nnan = npyv_notnan_f32(a);
npyv_u32 overflow = vorrq_u32(
vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
);

return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
// respects signed zero
floor = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(floor), sign_mask));
// return a unchanged if |a| > 2^23 or a is NaN
npyv_u32 mask = vcleq_f32(abs_x, two_power_23);
mask = vandq_u32(mask, nnan_mask);
return vbslq_f32(mask, floor, a);
}
#endif // NPY_HAVE_ASIMD
#if NPY_SIMD_F64
Expand Down
10 changes: 10 additions & 0 deletions numpy/core/tests/test_simd.py
Expand Up @@ -568,6 +568,16 @@ def test_special_cases(self):
nnan = self.notnan(self.setall(self._nan()))
assert nnan == [0]*self.nlanes

@pytest.mark.parametrize("intrin_name", [
    "rint", "trunc", "ceil", "floor"
])
def test_unary_invalid_fpexception(self, intrin_name):
    """Check that a rounding intrinsic does not report an invalid FP status.

    The intrinsic named by ``intrin_name`` is looked up on ``self`` and
    called on vectors built with ``self.setall`` from NaN, +inf and -inf.
    After each call, ``check_floatstatus(invalid=True)`` must be falsy.
    """
    intrin = getattr(self, intrin_name)
    for d in (float("nan"), float("inf"), -float("inf")):
        v = self.setall(d)
        # Presumably resets the status flags so that a stale flag from
        # earlier work is not attributed to this call.
        clear_floatstatus()
        intrin(v)
        assert not check_floatstatus(invalid=True)

@pytest.mark.parametrize("intrin_name", [
"cmpltq", "cmpleq", "cmpgtq", "cmpgeq"
Expand Down

0 comments on commit 50989d8

Please sign in to comment.