In [None]:
# default_exp util.test_functions

# Test Functions

Support a few test functions instead of using a particular test tool. I'm not sure if I want to keep these, but I've been using them while I've been building the library.

* `check_raises` asserts that a block of code raises an error.
* `check_is_near` compares numbers and arrays, if necessary. Wraps `numpy.isclose`.
* `check_equals` wraps `check_is_near`, enforcing equality.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import numpy as np
from contextlib import contextmanager

In [None]:
#export
@contextmanager
def check_raises(**kw):
    """Assert that the wrapped code raises an error.

    Keyword arguments:
        message: text of the `AssertionError` raised when the wrapped code
            finishes without raising
            (`check_raises(message="Custom message")`).
        exception: exception class, or tuple of classes, the wrapped code is
            expected to raise (`check_raises(exception=ValueError)`).
            Defaults to `Exception`.

    Raises:
        AssertionError: if the wrapped code does not raise, or raises an
            `Exception` that is not an instance of `exception`.

    Exceptions outside the `Exception` hierarchy (e.g. `KeyboardInterrupt`)
    that are not the expected exception propagate unchanged.
    """
    message = kw.get('message', "Expected to raise, did not.")
    expected_exception = kw.get('exception', Exception)
    try:
        yield
    except expected_exception:
        # The expected error occurred: swallow it and leave the `with` block.
        return
    except Exception as e:
        # Something was raised, but not what the caller asked for. Chain the
        # original error so its traceback is still visible.
        raise AssertionError(
            f"Expected to raise {expected_exception}. Instead received {e.__class__.__name__}"
        ) from e
    # Reached only when the wrapped code completed without raising. An explicit
    # raise (rather than `assert`) keeps the check active under `python -O`.
    raise AssertionError(message)

def check_is_near(a, b, message=None, **kw):
    """Assert that `a` and `b` are element-wise close.

    Wraps `numpy.isclose`; `a` and `b` may be scalars, sequences or arrays of
    any dimensionality.

    Args:
        a, b: values to compare.
        message: text of the `AssertionError` on failure. Defaults to
            "Expected {a} to be close to {b}."
        **kw: forwarded to `numpy.isclose` (e.g. `atol`, `rtol`).

    Raises:
        AssertionError: if any compared element is not close.
    """
    if message is None:
        message = f"Expected {a} to be close to {b}."
    # np.all reduces over every axis. The builtin all() would iterate the rows
    # of an n-d result and fail with an ambiguous-truth-value ValueError.
    if not np.all(np.isclose(a, b, **kw)):
        raise AssertionError(message)

def check_equals(a, b, **kw):
    """Check that two values are equal.

    Not type sensitive. Delegates to `check_is_near` with the absolute
    tolerance forced to zero. The relative tolerance is left at
    `numpy.isclose`'s default; pass `rtol=0` for an exact comparison.

    Args:
        a, b: values to compare (scalars, sequences or arrays).
        **kw: forwarded to `check_is_near`. A caller-supplied `message`
            replaces the default "Expected {a} to equal {b}."; any supplied
            `atol` is overridden with 0.

    Returns:
        Whatever `check_is_near` returns.
    """
    # Order matters: the default message may be overridden by the caller,
    # but atol=0 is applied last so it always wins.
    kw = {'message': f"Expected {a} to equal {b}.", **kw, 'atol': 0}
    return check_is_near(a, b, **kw)

In [None]:
# check_raises: the expected exception is swallowed.
with check_raises(exception=ValueError):
    raise ValueError

# check_raises: an unexpected exception type is reported as an AssertionError.
with check_raises(exception=AssertionError):
    with check_raises(exception=ValueError):
        raise ZeroDivisionError

# check_is_near: scalars and arrays within tolerance pass; a mismatch fails.
check_is_near(1, 1.01, atol=0.01)
check_is_near(np.zeros(3) + 1, 1)
with check_raises(exception=AssertionError):
    check_is_near(1, 2)

# check_equals: scalars, lists and arrays; a small difference fails.
check_equals(1, 1)
check_equals([1,1], 1)
check_equals(np.zeros(2) + 1, 1)
check_equals([1,2], [1,2])
with check_raises(exception=AssertionError):
    check_equals([1,2], [1,2.01])