In [None]:
#default_exp test

In [None]:
#export
from fastai_local.imports import *
from itertools import zip_longest
from fastai.gen_doc.nbdoc import show_doc

# Test

> Helper functions to quickly write tests in notebooks

## Simple test functions

We can test for equality (`test_eq`) or inequality (`test_ne`) of arrays, tensors, scalars, and lists of any of these, as well as for approximate equality (`test_close`) and identity (`test_is`). We can also check that code raises an exception when that's expected (`test_fail`).

In [None]:
#export
def test_fail(f, msg='', contains=''):
    "Fails with `msg` unless `f()` raises an exception and (optionally) has `contains` in `e.args`"
    try:
        f()
        assert False,f"Expected exception but none raised. {msg}"
    except Exception as e: assert not contains or contains in str(e)

In [None]:
def _fail_with_msg(): raise Exception("foobar")
test_fail(_fail_with_msg, contains="foo")

def _fail_bare(): raise Exception()
test_fail(_fail_bare)

# A callable that returns normally must make `test_fail` itself fail
test_fail(lambda: test_fail(lambda: None), contains="Expected exception")

In [None]:
#export
def test(a, b, cmp,cname=None):
    "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

In [None]:
# `test` passes silently when `cmp(a,b)` holds, and raises otherwise
test([1,2], [1,2], cmp=operator.eq)
test([1,2], [1],   cmp=operator.ne)
test_fail(lambda: test([1,2], [1],   cmp=operator.eq))
test_fail(lambda: test([1,2], [1,2], cmp=operator.ne))

In [None]:
#export
def equals(a,b):
    "Compares `a` and `b` for equality; supports sublists, tensors and arrays too"
    # 0-dim tensors fall through to `operator.eq`, which is what lets a scalar tensor compare
    # equal to a plain number; strings are compared directly rather than element-wise.
    cmp = (torch.equal    if isinstance(a, Tensor  ) and a.dim() else
           np.array_equal if isinstance(a, ndarray ) else
           operator.eq    if isinstance(a, str     ) else
           _all_equal     if isinstance(a, (list,tuple,Generator,Iterator)) else
           operator.eq)
    return cmp(a,b)

def _all_equal(a,b):
    "Element-wise `equals` over iterables `a` and `b`; inputs of different lengths are never equal"
    # Pad the shorter side with a private sentinel rather than `zip_longest`'s default `None`:
    # `None` padding would let `[1,None]` equal `[1]`, and would pass `None` to `torch.equal`
    # (a `TypeError`) for unequal-length lists of tensors.
    missing = object()
    return all(a_ is not missing and b_ is not missing and equals(a_,b_)
               for a_,b_ in zip_longest(a,b,fillvalue=missing))

In [None]:
test(['abc'], ['abc'], cmp=equals)  # lists are compared element-wise

In [None]:
#export
def nequals(a,b):
    "Logical negation of `equals`: true when `a` and `b` differ"
    same = equals(a,b)
    return not same

In [None]:
test(['abc'], ['ab'], cmp=nequals)  # the element strings differ, so the lists do too

## `test_eq`, `test_ne`, etc.

Just use `test_eq`/`test_ne` to test for `==`/`!=`. We define them using `test`:

In [None]:
#exports
def test_eq(a,b):
    "Assert `a` and `b` are equal according to `equals`, labelling failures with '=='"
    test(a, b, cmp=equals, cname='==')

In [None]:
# lists, including against a lazy iterator
test_eq([1,2], [1,2])
test_eq([1,2], map(int,[1,2]))
# a 0-dim tensor against a plain int, in either order
test_eq(torch.tensor(1), 1)
test_eq(1, torch.tensor(1))
# tensors, arrays, and lists mixing arrays with scalars
test_eq(torch.tensor([1,2]), torch.tensor([1,2]))
test_eq(array([1,2]), array([1,2]))
test_eq([array([1,2]),3], [array([1,2]),3])

In [None]:
#exports
def test_ne(a,b):
    "Assert `a` and `b` differ according to `nequals`, labelling failures with '!='"
    test(a, b, cmp=nequals, cname='!=')

In [None]:
# lists of different length, and of equal length with a differing element
test_ne([1,2], [1])
test_ne([1,2], [1,3])
# tensors and arrays that differ element-wise
test_ne(torch.tensor([1,2]), torch.tensor([1,1]))
test_ne(array([1,2]), array([1,1]))
# nested list missing its trailing scalar
test_ne([array([1,2]),3], [array([1,2])])

In [None]:
#exports
def is_close(a,b,eps=1e-5):
    "Is `a` within `eps` of `b`; arrays, tensors, lists and tuples must be close element-wise"
    # Array-likes (numpy arrays, tensors) give an element-wise boolean result; reduce it with `.all()`
    # so callers such as `test` get one truth value instead of an ambiguous multi-element one.
    if hasattr(a,'__array__') or hasattr(b,'__array__'): return (abs(a-b)<eps).all()
    # Plain lists/tuples don't support `-`, so compare them as numpy arrays
    if isinstance(a,(list,tuple)) or isinstance(b,(list,tuple)):
        return is_close(np.array(a), np.array(b), eps=eps)
    return abs(a-b)<eps

In [None]:
#exports
def test_close(a,b,eps=1e-5):
    "Assert that `a` and `b` lie within `eps` of each other (as judged by `is_close`)"
    close_enough = partial(is_close, eps=eps)
    test(a, b, close_enough, 'close')

In [None]:
test_close(1, 1.001, eps=1e-2)           # within the looser tolerance
test_fail(lambda: test_close(1, 1.001))  # but not within the default `eps=1e-5`

In [None]:
#exports
def test_is(a,b):
    "Assert `a` and `b` are the very same object (`a is b`), labelling failures with 'is'"
    test(a, b, cmp=operator.is_, cname='is')

In [None]:
same_obj = [1]
test_is(same_obj, same_obj)
test_fail(lambda: test_is([1], [1]))  # equal but distinct lists are not identical

## Export -

In [None]:
#hide
# Convert the notebooks to library scripts; with `all_fs=True` every notebook is processed (see the "Converted ..." output)
from fastai_local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 02_data_pipeline.ipynb.
Converted 03_data_source.ipynb.
Converted 05_data_core.ipynb.
Converted 06_pets_tutorial.ipynb.
Converted 99_export.ipynb.
Converted 99a_export2html.ipynb.
Converted _07_data_blocks.ipynb.
