In [None]:
import traceback
import numpy as np
import ast

def _in_otter():
    for frame in traceback.StackSummary.extract(traceback.walk_stack(None)):
        if frame.filename.endswith('ok_test.py'):
            return True
    return False


def _print_message(test, message):
    """Print a failure report for the check call whose source text is `test`.

    `message` is either a scalar such as a string (converted with str() and
    split into lines) or a list of lines; each line is printed indented.
    Outside of Otter the report is wrapped in bold magenta ANSI codes (35;1)
    and preceded by a "Yipes!" header naming the failing call.
    """
    # Otter already shows all the info in its HTML-formatted error messages,
    # so under Otter the header and the ANSI color codes are skipped.
    decorated = not _in_otter()

    if decorated:
        print("\u001b[35;1m")
        print("---------------------------------------------------------------------------")
        print("Yipes! " + test)
        print("                                                                           ")

    lines = message
    if np.shape(message) == ():
        lines = str(message).strip().split("\n")
    for line in lines:
        print("  ", line)

    if decorated:
        print("\u001b[0m")


def _arguments(test_line):
    test_line = test_line.strip()
    tree = ast.parse(test_line, mode='eval')
    args = [ test_line[x.col_offset:x.end_col_offset].strip() for x in tree.body.args]
    return args
    
    
def _extact_test_as_text():
    # The frame with the test call is the third from the top if we call 
    # a test function directly
    tbo = traceback.extract_stack()
    text = tbo[-3].line
    return text

In [None]:
# Display the class name of the running IPython shell (a quick environment probe).
type(get_ipython()).__name__

In [None]:
    # Debugging aid: print the source file of every frame on the current call
    # stack (presumably to inspect the filenames that _in_otter() examines).
    for frame in traceback.StackSummary.extract(
            traceback.walk_stack(None)):
        print(frame.filename)


In [None]:
def check(a):
    """Report a failure unless np.all(a) is true (every element truthy)."""
    if np.all(a):
        return
    _print_message(_extact_test_as_text(), "Expression is not True")
        

In [None]:
def check_equal(a, b):
    """Report a failure when np.testing.assert_equal considers a and b unequal."""
    try:
        np.testing.assert_equal(a, b)
    except AssertionError:
        _print_message(_extact_test_as_text(), f"{a} is not equal to {b}")

In [None]:
def _shorten(x):
    return np.array2string(np.array(x),threshold=10)

In [None]:
def term_at_index(arg, value, index):
    """Describe one operand of a failing comparison.

    arg   -- the operand's source text from the calling line
    value -- the operand's value
    index -- position to look at when value is not a scalar

    Returns (description, element). description is None when arg is already
    the literal repr of value; otherwise it reads "arg == value" for scalars
    or "arg[index] == element" for indexable values.
    """
    if arg == repr(value):
        # Written as a literal: its text already shows the value.
        return None, value
    if np.shape(value) != ():
        element = value[index]
        return f"{arg}[{index}] == {element}", element
    return f"{arg} == {repr(value)}", value
    

def _whole_term(arg, value):
    """Describe an operand as a whole; None when its source text equals its repr."""
    return None if arg == repr(value) else f"{arg} == {repr(value)}"


def _failure_line(terms, comparison):
    """Join the non-None operand descriptions and the failed comparison into one line."""
    shown = [t for t in terms if t is not None]
    if shown:
        return f"{' and '.join(shown)}, and {comparison} is False"
    return f"{comparison} is False"


def _python_value(v):
    """Convert a numpy scalar to the equivalent Python scalar; leave other values as-is."""
    return v.item() if isinstance(v, np.generic) else v


def internal_check(args, a, b, test_op, test_str):
    """Evaluate test_op(a, b) and describe where it is False.

    args     -- source texts of the two operands (args[0] for a, args[1] for b)
    test_op  -- elementwise test, e.g. lambda x, y: x == y
    test_str -- formats one compared pair as text, e.g. "1 == 2"

    Returns a list of message lines: [] when the test holds everywhere;
    otherwise up to three failing cases, followed by a count of any omitted ones.
    """
    result = test_op(a, b)
    if np.all(result):
        return []

    shape = np.shape(result)
    if shape == ():
        # A scalar result can come from comparing two whole containers
        # (e.g. list == list), so describe each operand as a whole rather
        # than indexing into it.
        terms = [_whole_term(args[0], a), _whole_term(args[1], b)]
        return [_failure_line(terms, test_str(a, b))]

    message = []
    if len(shape) == 1:
        false_indices = np.flatnonzero(np.logical_not(result))
        for i in false_indices[:3]:
            ai, av = term_at_index(args[0], a, i)
            bi, bv = term_at_index(args[1], b, i)
            message.append(_failure_line([ai, bi], test_str(av, bv)))
    else:
        # Multi-dimensional result: index the broadcast operands so that a
        # scalar or lower-rank operand can be read at every failing position.
        a_full, b_full = np.broadcast_arrays(a, b)
        false_indices = np.argwhere(np.logical_not(result))
        for row in false_indices[:3]:
            idx = tuple(int(k) for k in row)
            comparison = test_str(_python_value(a_full[idx]), _python_value(b_full[idx]))
            message.append(f"at index {idx}: {_failure_line([], comparison)}")
    if len(false_indices) > 3:
        message.append(f"... omitting {len(false_indices) - 3} more case(s) ...")
    return message
            
def check_equal(a, b):
    """Report, via _print_message, the positions where `a == b` is False.

    Operands are named using their source text from the calling line.
    """
    text = _extact_test_as_text()
    failures = internal_check(
        _arguments(text), a, b,
        lambda x, y: x == y,
        lambda x, y: f"{repr(x)} == {repr(y)}",
    )
    if failures:
        _print_message(text, failures)
    
def check_close(a, b, atol=1e-5):
    """Report, via internal_check, where `a` is not close to `b`.

    Closeness is np.isclose(a, b, atol=atol), which also applies numpy's
    default relative tolerance (rtol=1e-05). The bounds shown in the message
    (b - atol .. b + atol) therefore ignore the small relative term.
    """
    text = _extact_test_as_text()
    args = _arguments(text)
    # atol must be passed by keyword: np.isclose's third positional
    # parameter is rtol, not atol.
    message = internal_check(args, a, b,
                             lambda x, y: np.isclose(x, y, atol=atol),
                             lambda x, y: f"{y-atol} <= {x} <= {y+atol}")
    if message != []:
        _print_message(text, message)
    
    
def _internal_ordering(a, relation, message):
    """Check that every adjacent pair in `a` satisfies `relation`.

    a        -- the values passed positionally to the public check,
                e.g. check_less_than(*a)
    relation -- elementwise test applied to each adjacent pair (x, y)
    message  -- callable formatting one failing pair as text, e.g. "1 < 0"

    Must be called directly from the public check_* function, because the
    caller's source line is found by stack depth.
    """
    # Stack layout: [-1] this function, [-2] the public check_* wrapper,
    # [-3] the user's line that called the wrapper. _extact_test_as_text()
    # cannot be used here because it assumes one frame fewer.
    text = traceback.extract_stack()[-3].line
    args = _arguments(text)
    lines = []
    for i in range(len(a) - 1):
        # `message` is the pair formatter; keep it distinct from the
        # accumulated output lines.
        pair_lines = internal_check(args[i:i + 2], a[i], a[i + 1], relation, message)
        if pair_lines and lines:
            # Blank line between failures from different adjacent pairs.
            lines.append("")
        lines += pair_lines
    if lines:
        _print_message(text, lines)


# Alias for the name the ordering checks in this notebook call.
_internal_order = _internal_ordering
    
def check_less_than(*a):
    """Report every adjacent pair of arguments for which x < y is False.

    Delegates to _internal_ordering, which must be called directly from here.
    """
    _internal_ordering(a, lambda x, y: x < y, lambda x, y: f"{repr(x)} < {repr(y)}")

def check_less_than_or_equal(*a):
    """Report every adjacent pair of arguments for which x <= y is False.

    Delegates to _internal_ordering, which must be called directly from here.
    """
    _internal_ordering(a, lambda x, y: x <= y, lambda x, y: f"{repr(x)} <= {repr(y)}")
        
        
def check_type(a, t):
    """Report a failure unless type(a) is exactly `t`.

    Subclasses do not count (e.g. True fails check_type(True, int)).
    """
    if type(a) is not t:
        # The caller's line is only needed for the report, so it is fetched
        # here (still directly from this function, as the helper requires).
        _print_message(_extact_test_as_text(), f"{repr(a)} does not have type {t.__name__}.")

def _grab_interval(*interval):
    if len(interval) == 1:
        interval = interval[0]
    if np.shape(interval) != (2,) or np.shape(interval[0]) != () or np.shape(interval[1]) != ():
        raise ValueError(f"Interval must be passed as two numbers " +
                         "or an array containing two numbers, not {interval}")            
    return interval
        
def check_between(a, *interval):
    """Report values of `a` outside the half-open interval [low, high).

    The interval may be given as two numbers or as one 2-element sequence.
    `a` may be a scalar, a list, or an array. For 1-D input, up to three
    failing elements are listed individually, followed by a count of the rest.
    """
    text = _extact_test_as_text()
    args = _arguments(text)

    low, high = _grab_interval(*interval)

    # Compare through an array so plain lists are checked elementwise too
    # (a bare `1 <= [1, 2]` raises TypeError).
    values = np.asarray(a)
    result = np.logical_and(low <= values, values < high)
    if np.all(result):
        return

    bounds = f"[{low},{high})"
    if np.ndim(result) == 1:
        message = []
        false_indices = np.flatnonzero(np.logical_not(result))
        for i in false_indices[:3]:
            ai, av = term_at_index(args[0], a, i)
            terms = "" if ai is None else ai + ", and "
            message.append(f"{terms}{repr(av)} is not in interval {bounds}")
        if len(false_indices) > 3:
            message.append(f"... omitting {len(false_indices) - 3} more case(s) ...")
    else:
        # Scalars and multi-dimensional arrays are reported as a whole.
        message = [f"{a} is not in interval {bounds}"]
    _print_message(text, message)
    


In [None]:
# Demo: check_between with the interval given as a list or as two numbers.
# An interval like [3,3] is empty, so those two calls report failures.
check_between(2,[1,3])
check_between(2,[3,3])
check_between(4,[3,3])
# The same checks with the value and the interval held in variables.
x=2
check_between(x,[1,3])
check_between(x,1,3)
y = [1,3]
check_between(x,y)

In [None]:
# Demo: check_type compares exact types, so each of these three calls fails
# (3.0 is a float, not an int).
check_type(3,list)
check_type(3.0,int)
x = "cow"
check_type(x,int)
# 'a' < 'A' is False: lowercase letters sort after uppercase ones.
check_less_than('a', 'A')

In [None]:
# Demo: check_equal on literals, on variables, and on a scalar against an
# array (compared elementwise).
check_equal(0,0)
check_equal(0,1)
x = 0
y = 1
check_equal(x,y)
check_equal(y,np.array([1,1,2,1]))

In [None]:
# Scalar comparisons: the second call fails.
check_less_than(1,2)
check_less_than(2,1)
x = np.array([1,2,3,1,2,3])
y = np.array([2,2,2,2,2,2])

# Array operands are compared elementwise (array vs. scalar, array vs. array).
check_less_than(x, 4)
check_less_than(x, y)
check_less_than(x, np.array([0,0,0,0,0,10]))

# Operands held in variables are reported using their names.
a=-1
check_less_than(1,a)
check_less_than(x, a)
check_less_than(x, 1.0000000001)

# More than two operands: each adjacent pair is checked in turn.
check_less_than(2,x,y)

In [None]:
# Demo: check_close with the default tolerance and with an explicit atol;
# the last call compares every element of an array against one scalar.
check_close(1,1.0001)
check_close(1,1.3,.3)
check_close(np.array([1,1,1]), 1.3,.3)

In [None]:
def check_close(a, b, atol=1e-5):
    """Report a failure unless `a` and `b` agree within tolerance.

    Uses np.testing.assert_allclose with absolute tolerance `atol` (plus that
    function's default relative tolerance). `b` may be a scalar, list or array.
    """
    try:
        np.testing.assert_allclose(a, b, atol=atol)
    except AssertionError:
        # np.subtract/np.add so the bounds can also be computed for a
        # list-valued `b`, where `b - atol` would raise TypeError.
        low, high = np.subtract(b, atol), np.add(b, atol)
        _print_message(_extact_test_as_text(), f"{a} is not between {low} and {high}")

In [None]:
def check_in(a, *r):
    """Report a failure unless `a` is a member of the given values.

    The values may be passed as one container (check_in(x, [1, 2, 3])) or as
    separate arguments (check_in(x, 1, 2, 3)).
    """
    container = r[0] if len(r) == 1 else r
    if a in container:
        return
    _print_message(_extact_test_as_text(), f"{a} is not in range {_shorten(container)}")

In [None]:
def _grab_interval(*interval):
    if len(interval) == 1:
        interval = interval[0]
    if np.shape(interval) != (2,):
        raise ValueError(f"Interval must be passed as two numbers " +
                         "or an array containing two numbers, not {interval}")            
    return interval
        
def check_between(a, *interval):
    """Report a failure unless every value of `a` lies in [low, high).

    The interval may be two numbers or one 2-element sequence. `a` may be a
    scalar, list or array; arrays are checked elementwise.
    """
    low, high = _grab_interval(*interval)
    values = np.asarray(a)
    # np.any keeps this valid for arrays, where `or` on elementwise
    # comparisons would raise "truth value ... is ambiguous".
    if np.any(values < low) or np.any(values >= high):
        _print_message(_extact_test_as_text(),
                       f"{a} is not in interval [{low}, {high})")

def check_between_or_equal(a, *interval):
    """Report a failure unless every value of `a` lies in the closed interval [low, high].

    The interval may be two numbers or one 2-element sequence. `a` may be a
    scalar, list or array; arrays are checked elementwise.
    """
    low, high = _grab_interval(*interval)
    values = np.asarray(a)
    # np.any keeps this valid for arrays (see check_between's rationale).
    if np.any(values < low) or np.any(values > high):
        _print_message(_extact_test_as_text(),
                       f"{a} is not in interval [{low}, {high}]")

def check_strictly_between(a, *interval):
    """Report a failure unless every value of `a` lies in the open interval (low, high).

    The interval may be two numbers or one 2-element sequence. `a` may be a
    scalar, list or array; arrays are checked elementwise.
    """
    low, high = _grab_interval(*interval)
    values = np.asarray(a)
    # np.any keeps this valid for arrays, where `or` on elementwise
    # comparisons would raise "truth value ... is ambiguous".
    if np.any(values <= low) or np.any(values >= high):
        _print_message(_extact_test_as_text(),
                       f"{a} is not in interval ({low}, {high})")

# def _value_at_index_if_array(v, i):
#     if np.shape(v) == ():
#         return v
#     else:
#         return v.item(i)
        
# def _binary_less_than(source_terms, values):
#     result = np.less(values[0], values[1])
#     if np.all(result):
#         return None
#     else:
#         shape = np.shape(result)
#         print(shape)
#         if shape == ():
#             return f"check_less_than({source_terms[0]}, {source_terms[1]}) failed:\n" + \
#                    f"    Expression is not true: " + " < ".join([_shorten(x) for x in values])
#         elif len(shape) == 1:
#             false_indices = np.where(result == False)
#             return f"check_less_than({source_terms[0]}, {source_terms[1]}) failed:\n" + \
#                    f"    Expression is not true at indices {_shorten(false_indices[0])}"
#         else:
#             return f"Expression is not true: " + " < ".join([_shorten(x) for x in values])

# def check_less_than(*a):
#     text = _extact_test_as_text()
#     args = _arguments(text)
#     for i in range(len(a)-1):
#         result = _binary_less_than(args[i:i+2], a[i:i+2])
#         if result != None:
#             _print_message(_extact_test_as_text(), result)
#             return
        
            
        
def check_less_than(*a):
    """Report the first adjacent pair of arguments for which x < y fails.

    Pairs are compared with np.less, so arrays are checked elementwise. Only
    the first failing pair is reported; nothing is printed when all pairs hold.
    """
    for i in range(len(a) - 1):
        if not np.all(np.less(a[i], a[i + 1])):
            text = _extact_test_as_text()
            args = _arguments(text)
            # Show the operands' source text as written (not quoted via
            # _shorten), then the actual values that were compared.
            _print_message(text, [
                "Expression is not true: " + " < ".join(args[i:i + 2]),
                f"    values: {_shorten(a[i])} < {_shorten(a[i + 1])}",
            ])
            return

# def check_less_than(*a):
#     for i in range(len(a)-1):
#         result = np.less(a[i], a[i+1])
#         if not np.all(result):
#             text = _extact_test_as_text()
#             args = _arguments(text)
#             _print_message(_extact_test_as_text(), 
#                            f"Expression is not true: " + " < ".join([_shorten(x) for x in args[i:i+2]]))
#             return
        
        
def check_less_than_or_equal(*a):
    """Report the first adjacent pair of arguments for which x <= y fails.

    Pairs are compared with np.less_equal, so arrays are checked elementwise.
    Only the first failing pair is reported; nothing is printed when all hold.
    """
    for i in range(len(a) - 1):
        if not np.all(np.less_equal(a[i], a[i + 1])):
            text = _extact_test_as_text()
            args = _arguments(text)
            # Show the operands' source text as written (not quoted via
            # _shorten), then the actual values that were compared.
            _print_message(text, [
                "Expression is not true: " + " <= ".join(args[i:i + 2]),
                f"    values: {_shorten(a[i])} <= {_shorten(a[i + 1])}",
            ])
            return

    
# @register_cell_magic
# @needs_local_scope
# def test(line, cell, local_ns=None): 
#     original_count = _message_count
#     for text_as_text in cell.split("\n"):
#         if text_as_text != '':
#             try:
#                 eval(text_as_text, local_ns)
#             except Exception as e:
#                 etype, evalue, tb = sys.exc_info()
                
#                 # Take all the frames above the one for the call to eval, which
#                 # will have the file name <string>.
#                 files = [ frame.filename for frame in traceback.extract_tb(tb) ]
#                 index = files.index('<string>')
#                 limit = index - len(files) + 1
#                 tbo = traceback.format_exception(etype, evalue, tb, limit=limit)
                
#                 _print_message(line + ": " + text_as_text, "\n".join(tbo))

#     if original_count == _message_count:
#         if line.strip() != '':
#             print("\u001b[35;1mPassed all tests for " + line + "!\u001b[0m")
#         else:
#             print("\u001b[35;1mPassed all tests!\u001b[0m")
                
#     return None