Skip to content

Commit

Permalink
Merge pull request #4541 from ElDeveloper/issue-2367
Browse files Browse the repository at this point in the history
BUG:change formatting of assert_array_almost_equal
  • Loading branch information
charris committed Mar 26, 2014
2 parents 0953088 + a11c162 commit d35d5c1
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 10 deletions.
68 changes: 68 additions & 0 deletions numpy/testing/tests/test_utils.py
Expand Up @@ -134,6 +134,49 @@ def test_recarrays(self):

self._test_not_equal(c, b)

class TestBuildErrorMessage(unittest.TestCase):
def test_build_err_msg_defaults(self):
x = np.array([1.00001, 2.00002, 3.00003])
y = np.array([1.00002, 2.00003, 3.00004])
err_msg = 'There is a mismatch'

a = build_err_msg([x, y], err_msg)
b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
'1.00001, 2.00002, 3.00003])\n DESIRED: array([ 1.00002, '
'2.00003, 3.00004])')
self.assertEqual(a, b)

def test_build_err_msg_no_verbose(self):
x = np.array([1.00001, 2.00002, 3.00003])
y = np.array([1.00002, 2.00003, 3.00004])
err_msg = 'There is a mismatch'

a = build_err_msg([x, y], err_msg, verbose=False)
b = '\nItems are not equal: There is a mismatch'
self.assertEqual(a, b)

def test_build_err_msg_custom_names(self):
x = np.array([1.00001, 2.00002, 3.00003])
y = np.array([1.00002, 2.00003, 3.00004])
err_msg = 'There is a mismatch'

a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
b = ('\nItems are not equal: There is a mismatch\n FOO: array([ '
'1.00001, 2.00002, 3.00003])\n BAR: array([ 1.00002, 2.00003, '
'3.00004])')
self.assertEqual(a, b)

def test_build_err_msg_custom_precision(self):
x = np.array([1.000000001, 2.00002, 3.00003])
y = np.array([1.000000002, 2.00003, 3.00004])
err_msg = 'There is a mismatch'

a = build_err_msg([x, y], err_msg, precision=10)
b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array([ '
'1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array([ '
'1.000000002, 2.00003 , 3.00004 ])')
self.assertEqual(a, b)

class TestEqual(TestArrayEqual):
def setUp(self):
self._assert_func = assert_equal
Expand Down Expand Up @@ -239,6 +282,31 @@ def test_complex(self):
self._test_not_equal(x, y)
self._test_not_equal(x, z)

def test_error_message(self):
"""Check the message is formatted correctly for the decimal value"""
x = np.array([1.00000000001, 2.00000000002, 3.00003])
y = np.array([1.00000000002, 2.00000000003, 3.00004])

# test with a different amount of decimal digits
# note that we only check for the formatting of the arrays themselves
b = ('x: array([ 1.00000000001, 2.00000000002, 3.00003 '
' ])\n y: array([ 1.00000000002, 2.00000000003, 3.00004 ])')
try:
self._assert_func(x, y, decimal=12)
except AssertionError as e:
# remove anything that's not the array string
self.assertEqual(str(e).split('%)\n ')[1], b)

# with the default value of decimal digits, only the 3rd element differs
# note that we only check for the formatting of the arrays themselves
b = ('x: array([ 1. , 2. , 3.00003])\n y: array([ 1. , '
'2. , 3.00004])')
try:
self._assert_func(x, y)
except AssertionError as e:
# remove anything that's not the array string
self.assertEqual(str(e).split('%)\n ')[1], b)

class TestApproxEqual(unittest.TestCase):
def setUp(self):
self._assert_func = assert_approx_equal
Expand Down
28 changes: 18 additions & 10 deletions numpy/testing/utils.py
Expand Up @@ -9,8 +9,9 @@
import re
import operator
import warnings
from functools import partial
from .nosetester import import_nose
from numpy.core import float32, empty, arange
from numpy.core import float32, empty, arange, array_repr, ndarray

if sys.version_info[0] >= 3:
from io import StringIO
Expand Down Expand Up @@ -190,8 +191,7 @@ def memusage(processName="python", instance=0):
win32pdh.PDH_FMT_LONG, None)

def build_err_msg(arrays, err_msg, header='Items are not equal:',
verbose=True,
names=('ACTUAL', 'DESIRED')):
verbose=True, names=('ACTUAL', 'DESIRED'), precision=8):
msg = ['\n' + header]
if err_msg:
if err_msg.find('\n') == -1 and len(err_msg) < 79-len(header):
Expand All @@ -200,8 +200,15 @@ def build_err_msg(arrays, err_msg, header='Items are not equal:',
msg.append(err_msg)
if verbose:
for i, a in enumerate(arrays):

if isinstance(a, ndarray):
# precision argument is only needed if the objects are ndarrays
r_func = partial(array_repr, precision=precision)
else:
r_func = repr

try:
r = repr(a)
r = r_func(a)
except:
r = '[repr failed]'
if r.count('\n') > 3:
Expand Down Expand Up @@ -575,7 +582,7 @@ def assert_approx_equal(actual,desired,significant=7,err_msg='',verbose=True):
raise AssertionError(msg)

def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
header=''):
header='', precision=6):
from numpy.core import array, isnan, isinf, any, all, inf
x = array(x, copy=False, subok=True)
y = array(y, copy=False, subok=True)
Expand All @@ -592,7 +599,7 @@ def chk_same_position(x_id, y_id, hasval='nan'):
msg = build_err_msg([x, y],
err_msg + '\nx and y %s location mismatch:' \
% (hasval), verbose=verbose, header=header,
names=('x', 'y'))
names=('x', 'y'), precision=precision)
raise AssertionError(msg)

try:
Expand All @@ -603,7 +610,7 @@ def chk_same_position(x_id, y_id, hasval='nan'):
+ '\n(shapes %s, %s mismatch)' % (x.shape,
y.shape),
verbose=verbose, header=header,
names=('x', 'y'))
names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)

Expand Down Expand Up @@ -648,7 +655,7 @@ def chk_same_position(x_id, y_id, hasval='nan'):
err_msg
+ '\n(mismatch %s%%)' % (match,),
verbose=verbose, header=header,
names=('x', 'y'))
names=('x', 'y'), precision=precision)
if not cond :
raise AssertionError(msg)
except ValueError as e:
Expand All @@ -657,7 +664,7 @@ def chk_same_position(x_id, y_id, hasval='nan'):
header = 'error during assertion:\n\n%s\n\n%s' % (efmt, header)

msg = build_err_msg([x, y], err_msg, verbose=verbose, header=header,
names=('x', 'y'))
names=('x', 'y'), precision=precision)
raise ValueError(msg)

def assert_array_equal(x, y, err_msg='', verbose=True):
Expand Down Expand Up @@ -825,7 +832,8 @@ def compare(x, y):
return around(z, decimal) <= 10.0**(-decimal)

assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,
header=('Arrays are not almost equal to %d decimals' % decimal))
header=('Arrays are not almost equal to %d decimals' % decimal),
precision=decimal)


def assert_array_less(x, y, err_msg='', verbose=True):
Expand Down

0 comments on commit d35d5c1

Please sign in to comment.