In [1]:
    import traitlets, unittest, doctest, pidgy.base, re, ast, contextlib, IPython, inspect, sys, textwrap

In [2]:
    class NullOutputCheck(doctest.OutputChecker):
        """Output checker that accepts whatever an example printed.

        Used for examples whose output is irrelevant: only an exception raised
        while running the example can make it fail.
        """

        def check_output(self, *args):
            # Ignore want/got/optionflags entirely: every output is a match.
            accepted = True
            return accepted

    class InlineDoctestParser(doctest.DocTestParser):
        """Doctest parser that treats inline `code` spans as examples.

        A span is a backtick, a first character that is not a backtick, then the
        shortest run of non-newline characters up to the next backtick. The
        ``indent`` group is always empty; it exists only because
        ``DocTestParser.parse`` reads it.
        """

        _EXAMPLE_RE = re.compile(
            r'`(?P<indent>\s{0})'
            r'(?P<source>[^`].*?)'
            r'`'
        )

        def _parse_example(self, m, name, lineno):
            # No options and no expected exception. The expected output is the
            # placeholder "...", which matches anything under doctest.ELLIPSIS.
            source_text = m.group('source')
            options, want, exc_msg = None, "...", None
            return source_text, options, want, exc_msg



In [3]:
    @contextlib.contextmanager
    def ipython_compiler(shell):
        """Temporarily make ``doctest`` compile examples through an IPython shell.

        While the context is active, ``doctest.compile`` is replaced by a
        function. It runs each example through the shell's cell and AST
        transformers and compiles it with ``shell.compile`` under the filename
        ``In[<execution_count>]``. On exit the previous ``doctest.compile``
        binding is restored, even if the body raised. If there was no previous
        binding, the attribute is removed so lookups fall back to the builtin.
        This also makes nested use safe.

        Yields:
            None.
        """
        def compiler(input, filename, symbol, *args, **kwargs):
            # The filename, mode and flags doctest supplies are ignored.
            # Examples are always compiled in "single" mode, named after the
            # current execution count.
            return shell.compile(
                ast.Interactive(
                    body=shell.transform_ast(
                        shell.compile.ast_parse(
                            shell.transform_cell(textwrap.indent(input, ' ' * 4))
                        )
                    ).body
                ),
                F"In[{shell.last_execution_result.execution_count}]",
                "single",
            )

        missing = object()
        previous = vars(doctest).get("compile", missing)
        doctest.compile = compiler
        try:
            yield
        finally:
            # Undo the patch on every exit path so a failing test cannot leave
            # doctest permanently wired to this shell.
            if previous is missing:
                del doctest.compile
            else:
                doctest.compile = previous

In [4]:
    import doctest

In [5]:
    class Collect:
        """Mixin that turns notebook objects into ``unittest`` tests."""

        def collect_unittest(self, object):
            """Load every test method of the ``unittest.TestCase`` subclass ``object``."""
            return unittest.defaultTestLoader.loadTestsFromTestCase(object)

        def collect_function(self, object):
            """Wrap the plain function ``object`` as a single test case."""
            return unittest.FunctionTestCase(object)

        def collect_doctest(self, object, vars, name):
            """Build doctest cases from the source text ``object``.

            Two kinds of examples are collected:

            * ``>>>`` examples, checked with ``doctest.ELLIPSIS`` enabled;
            * inline `code` spans, whose output is never compared, so only
              raised exceptions fail them.

            Each doctest runs in a copy of ``vars`` and is named ``name``.

            Returns:
                A ``unittest.TestSuite`` with one case per non-empty kind, or
                ``None`` when the text contains no examples at all.
            """
            cases = []

            prompt_test = doctest.DocTestParser().get_doctest(object, vars, name, name, 1)
            if prompt_test.examples:
                cases.append(doctest.DocTestCase(prompt_test, doctest.ELLIPSIS))

            inline_test = InlineDoctestParser().get_doctest(object, vars, name, name, 1)
            if inline_test.examples:
                # The runner calls methods on the checker, so pass an instance,
                # not the class.
                cases.append(doctest.DocTestCase(inline_test, checker=NullOutputCheck()))

            # A plain TestSuite is used as the container: doctest.DocTestSuite()
            # with no module would scan the caller's module and add unrelated
            # docstring tests.
            if cases:
                return unittest.TestSuite(cases)
            return None

        def collect(self, *objects, vars, name):
            """Collect every testable item in ``objects`` into one suite.

            ``TestCase`` subclasses, source strings (doctests) and plain
            functions are collected. Anything else, and strings without
            examples, are skipped.

            Args:
                *objects: Candidate test objects.
                vars: Globals used for doctest examples.
                name: Name given to doctests built from strings.

            Returns:
                A ``unittest.TestSuite`` (possibly empty).
            """
            suite = unittest.TestSuite()
            for candidate in objects:
                if isinstance(candidate, type) and issubclass(candidate, unittest.TestCase):
                    test = self.collect_unittest(candidate)
                elif isinstance(candidate, str):
                    test = self.collect_doctest(candidate, vars, name)
                elif inspect.isfunction(candidate):
                    test = self.collect_function(candidate)
                else:
                    continue
                if test is not None:
                    suite.addTest(test)
            return suite

In [6]:
    class Definitions(pidgy.base.Trait, ast.NodeTransformer):
        """AST pass that records the names of function and class definitions.

        Each name is appended to ``self.parent.medial_test_definitions``. Nodes
        are returned unchanged, and their bodies are not visited (no
        ``generic_visit`` call). Definitions nested inside another def/class
        are therefore not recorded.
        """

        def _record_definition(self, node):
            # Remember the name and leave the tree untouched.
            self.parent.medial_test_definitions.append(node.name)
            return node

        visit_FunctionDef = _record_definition
        visit_ClassDef = _record_definition


In [7]:
    class Register:
        """Mixin that installs ``self.visitor`` into ``self.parent.ast_transformers``.

        The host class must provide ``parent`` and ``visitor``. ``parent`` is
        presumably the IPython shell, and must have an ``ast_transformers`` list.
        """

        def register(self):
            """Append ``self.visitor`` to the parent's AST transformers, at most once."""
            # Compare by identity. The list holds visitors, not Register
            # instances, so an isinstance(x, type(self)) test never matches and
            # repeated calls would stack duplicate transformers.
            transformers = self.parent.ast_transformers
            if not any(existing is self.visitor for existing in transformers):
                transformers.append(self.visitor)

        def unregister(self):
            """Remove every occurrence of ``self.visitor`` from the parent's AST transformers."""
            self.parent.ast_transformers = [
                existing
                for existing in self.parent.ast_transformers
                if existing is not self.visitor
            ]

In [8]:
    import re

    # Matches a traceback header plus its first frame (the "File ..." line and
    # the source line after it) when that frame's file is not a notebook input.
    # Input cells are compiled with the filename "In[N]" (see ipython_compiler),
    # so those frames are kept. The negative lookahead replaces the former
    # "[^In]" character class, which excluded any filename starting with "I" or
    # "n" rather than the "In[" prefix.
    santize_doctest = re.compile(
        r"""Traceback.+\n\s*File "(?!In\[)\S+", line [0-9]+.+in.+\n.+""", re.MULTILINE
    )

    def clean_doctest_traceback(str):
        """Tidy a doctest failure report for display.

        The first match of ``santize_doctest`` is removed: a traceback header
        and the first non-notebook frame. Leading whitespace is stripped, and
        each 70-dash separator line is replaced with a newline.
        """
        return santize_doctest.sub("", str).lstrip().replace("-" * 70, "\n")


In [9]:
    class TestingBase(Register, Collect):
        """Run the tests collected for a cell and report problems on stderr."""

        def display_result(self, result, test_result):
            """Write a summary plus every failure and error traceback to stderr.

            Both ``failures`` (assertion failures, including doctest mismatches)
            and ``errors`` (unexpected exceptions) are reported. Previously,
            errors were silently dropped. ``result`` is accepted for interface
            compatibility and is not used.
            """
            problems = test_result.failures + test_result.errors
            if problems:
                # Each entry is a (test case, formatted traceback) pair.
                msg = '\n'.join(trace for _, trace in problems)
                sys.stderr.writelines((str(test_result) + '\n' + msg).splitlines(True))

        def test(self, result, *object):
            """Run doctests from the cell source together with the given objects.

            Args:
                result: The IPython execution result. Its ``info.raw_cell`` is
                    scanned for doctest and inline examples.
                *object: Extra candidates (TestCase subclasses, functions).
            """
            filename = F"In[{self.parent.last_execution_result.execution_count}]"
            test_result = unittest.TestResult()
            suite = self.collect(
                result.info.raw_cell, *object, vars=self.parent.user_ns, name=filename
            )
            suite.run(test_result)
            self.display_result(result, test_result)

In [10]:
    class Testing(pidgy.base.Trait, TestingBase):
        """Run tests defined by a cell once the cell has executed without error.

        ``visitor`` (a ``Definitions`` transformer) records the names of
        functions and classes a cell defines in ``medial_test_definitions``.
        ``post_run_cell`` then tests the ones whose name starts with
        ``pattern``, or that pass the ``pidgy.util.istype`` check against
        ``unittest.TestCase``. It also runs doctests found in the cell source.
        """

        # Names recorded by the AST visitor for the cell being executed.
        medial_test_definitions = traitlets.List()
        # Name prefix that marks a definition as a test.
        pattern = traitlets.Unicode('test_')
        visitor = traitlets.Instance('ast.NodeTransformer')

        @traitlets.default('visitor')
        def _default_visitor(self):
            return Definitions(parent=self)

        @pidgy.implementation
        def post_run_cell(self, result):
            """Test this cell's new definitions and examples if the cell succeeded.

            Recorded names are consumed on every call, even when the cell
            raised. Definitions from a failed cell are therefore not tested
            alongside a later, successful cell.
            """
            names = list(self.medial_test_definitions)
            self.medial_test_definitions.clear()
            if result.error_before_exec or result.error_in_exec:
                return

            namespace = self.parent.user_ns
            tests = []
            for name in names:
                candidate = namespace.get(name, None)
                # NOTE(review): pidgy.util.istype is assumed to check for a
                # TestCase subclass. Only pidgy.base is imported explicitly in
                # this file; confirm pidgy.util is loaded.
                if name.startswith(self.pattern) or pidgy.util.istype(candidate, unittest.TestCase):
                    tests.append(candidate)

            with ipython_compiler(self.parent):
                self.test(result, *tests)