Skip to content

Commit

Permalink
Workaround numpy 1.x assert_allclose false-positive result in compari…
Browse files Browse the repository at this point in the history
…ng complex infinities.
  • Loading branch information
pearu committed Apr 4, 2024
1 parent 026f309 commit 2ef5bc6
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 45 deletions.
20 changes: 5 additions & 15 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,22 +1725,12 @@ def negative(self, x):

def sqrt(self, x):
ctx = x.context
# workaround mpmath bugs:
if isinstance(x, ctx.mpc):
if ctx.isinf(x.real) and ctx.isinf(x.imag):
if x.real > 0: return x
ninf = x.real
inf = -ninf
if x.imag > 0: return ctx.make_mpc((inf._mpf_, inf._mpf_))
return ctx.make_mpc((inf._mpf_, inf._mpf_))
elif ctx.isfinite(x.real) and ctx.isinf(x.imag):
if x.imag > 0:
inf = x.imag
return ctx.make_mpc((inf._mpf_, inf._mpf_))
else:
ninf = x.imag
inf = -ninf
return ctx.make_mpc((inf._mpf_, ninf._mpf_))
# Workaround mpmath 1.3 bug in sqrt(+-inf+-infj) evaluation (see mpmath/mpmath#776).
# TODO(pearu): remove this function when mpmath 1.4 or newer
# will be the required test dependency.
if ctx.isinf(x.imag):
return ctx.make_mpc((ctx.inf._mpf_, x.imag._mpf_))
return ctx.sqrt(x)

def expm1(self, x):
Expand Down
119 changes: 89 additions & 30 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3567,21 +3567,37 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind):
mposj=(slice(s0 + 3 + s03, s0 + 3 + 2 * s03), s1 + 1),
)

# The regions are split to real and imaginary parts (of function
# return values) to (i) workaround numpy 1.x assert_allclose bug
# in comparing complex infinities, and (ii) expose more details
# about failing cases:
s_dict_parts = dict()
for k, v in s_dict.items():
s_dict_parts[k + '.real'] = v
s_dict_parts[k + '.imag'] = v

# Start with an assumption that all regions are problematic for a
# particular function:
regions_with_inaccuracies = list(s_dict)
regions_with_inaccuracies = list(s_dict_parts)

# Next, we'll remove non-problematic regions from the
# regions_with_inaccuracies list by explicitly keeping problematic
# regions:
def regions_with_inaccuracies_keep(*to_keep):
to_keep_parts = []
for r in to_keep:
if r.endswith('.real') or r.endswith('.imag'):
to_keep_parts.append(r)
else:
to_keep_parts.append(r + '.real')
to_keep_parts.append(r + '.imag')
for item in regions_with_inaccuracies[:]:
if item not in to_keep:
if item not in to_keep_parts:
regions_with_inaccuracies.remove(item)

if name == 'absolute':
if is_cuda and dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real')
else:
regions_with_inaccuracies.clear()

Expand All @@ -3590,95 +3606,122 @@ def regions_with_inaccuracies_keep(*to_keep):

elif name == 'square':
if is_cuda:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real')
if is_cpu:
regions_with_inaccuracies_keep('ninf', 'pinf')
regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real')

elif name == 'log':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')

elif name == 'log10':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf', 'pinf', 'ninfj', 'pinfj', 'zero')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')

elif name == 'log1p':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
# TODO(pearu): after landing openxla/xla#10503, switch to
# regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')

elif name == 'exp':
regions_with_inaccuracies_keep('pos', 'pinf', 'mpos')
regions_with_inaccuracies_keep('pos.imag', 'pinf.imag', 'mpos.imag')

elif name == 'exp2':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos', 'mnegj', 'mposj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mpos.imag', 'mnegj', 'mposj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mpos')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag')

elif name == 'expm1' and xla_extension_version < 250:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')

elif name == 'sinc':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'mq1', 'mq2', 'mq3', 'mq4',
'mneg.real', 'mpos.real', 'mnegj', 'mposj',
'ninf.imag', 'pinf.imag', 'ninfj.real', 'pinfj.real')

elif name == 'tan':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'negj', 'posj', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
regions_with_inaccuracies_keep('q1.imag', 'q2.imag', 'q3.imag', 'q4.imag', 'negj.imag', 'posj.imag',
'ninfj.imag', 'pinfj.imag', 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag',
'ninf.imag', 'pinf.imag')

elif name == 'sinh':
if is_cuda:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg', 'pos',
'ninf.imag', 'pinf.imag', 'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
'ninfj.real', 'pinfj.real')
if is_cpu:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')

regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj.imag', 'posj.imag', 'ninf.imag', 'pinf.imag',
'mq1.real', 'mq2.real', 'mq3.real', 'mq4.real', 'mneg', 'mpos',
'ninfj.real', 'pinfj.real')
elif name == 'cosh':
regions_with_inaccuracies_keep('neg', 'pos', 'ninf', 'pinf', 'mneg', 'mpos')
regions_with_inaccuracies_keep('neg.imag', 'pos.imag', 'ninf.imag', 'pinf.imag', 'mneg.imag', 'mpos.imag',
'ninfj.imag', 'pinfj.imag')

elif name == 'tanh':
regions_with_inaccuracies_keep('ninf', 'pinf', 'ninfj', 'pinfj')

elif name == 'arccos':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg',
'mpos.imag', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos.imag', 'mnegj')

elif name == 'arccosh':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mpos.imag', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mq4', 'mneg', 'mnegj')

elif name == 'arcsin':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj', 'mposj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
'mq1', 'mq2', 'mq3', 'mq4', 'mneg.imag', 'mpos.imag', 'mnegj', 'mposj')

elif name == 'arcsinh':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj',
'mq1.real', 'mq2', 'mq3', 'mq4.real', 'mneg.real', 'mpos.real', 'mnegj')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg', 'mnegj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg.real', 'pos.real', 'negj', 'posj.real', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq2', 'mq3', 'mneg.real', 'mnegj')

elif name == 'arctan':
if dtype == np.complex64:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mq1', 'mq2', 'mq3', 'mq4', 'mnegj', 'mposj')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj',
'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag')
if dtype == np.complex128:
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')

elif name == 'arctanh':
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
# TODO(pearu): after landing openxla/xla#10503, switch to
# regions_with_inaccuracies_keep('pos', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos')

elif name in {'cos', 'sin'}:
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')

elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1'}:
regions_with_inaccuracies.clear()
else:
assert 0 # unreachable

# Finally, perform the closeness tests per region:
unexpected_success_regions = []
for region_name, region_slice in s_dict.items():
for region_name, region_slice in s_dict_parts.items():
region = args[0][region_slice]
inexact_indices = np.where(normalized_result[region_slice] != normalized_expected[region_slice])
if region_name.endswith('.real'):
result_slice, expected_slice = result[region_slice].real, expected[region_slice].real
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].real, normalized_expected[region_slice].real
elif region_name.endswith('.imag'):
result_slice, expected_slice = result[region_slice].imag, expected[region_slice].imag
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice].imag, normalized_expected[region_slice].imag
else:
result_slice, expected_slice = result[region_slice], expected[region_slice]
normalized_result_slice, normalized_expected_slice = normalized_result[region_slice], normalized_expected[region_slice]

inexact_indices = np.where(normalized_result_slice != normalized_expected_slice)

if inexact_indices[0].size == 0:
inexact_samples = ''
Expand All @@ -3697,20 +3740,36 @@ def regions_with_inaccuracies_keep(*to_keep):
if kind == 'success' and region_name not in regions_with_inaccuracies:
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(
normalized_result[region_slice], normalized_expected[region_slice], atol=atol,
normalized_result_slice, normalized_expected_slice, atol=atol,
err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")

if kind == 'failure' and region_name in regions_with_inaccuracies:
try:
with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
self.assertAllClose(normalized_result[region_slice], normalized_expected[region_slice])
self.assertAllClose(normalized_result_slice, normalized_expected_slice)
except AssertionError as msg:
if str(msg).startswith('AssertionError not raised'):
unexpected_success_regions.append(region_name)
else:
raise # something else is wrong..

def eliminate_parts(seq):
# replace n.real and n.imag items in seq with n.
result = []
for part_name in seq:
name = part_name.split('.')[0]
if name in result:
continue
if name + '.real' in seq and name + '.imag' in seq:
result.append(name)
else:
result.append(part_name)
return result

regions_with_inaccuracies = eliminate_parts(regions_with_inaccuracies)
unexpected_success_regions = eliminate_parts(unexpected_success_regions)

if kind == 'success' and regions_with_inaccuracies:
reason = "xfail: problematic regions: " + ", ".join(regions_with_inaccuracies)
raise unittest.SkipTest(reason)
Expand Down

0 comments on commit 2ef5bc6

Please sign in to comment.