# Imports

In [1]:
# export
from collections import namedtuple, defaultdict
import os
import re
from nbdev.imports import *

# Helpers

In [2]:
# hide
def run_tests(cases, func, verbose=False):
    nr_correct = 0
    for i, (n, c, r) in enumerate(cases):
        if verbose: print(f'({i + 1} / {len(cases)}) TEST {n}:')
        try:
            res = func(c)
            assert res == r, f'TEST FAILED WITH RESULT: {res}\nEXPECTED: {r}'
            nr_correct += (res == r)
            if verbose: print(f'TEST RESULT: SUCCESS\n')
        except Exception as e:
            if verbose: print(f'TEST FAILED WITH EXCEPTION:\n{e}\n')
            # raise e
    print('--------------- ALL TESTS COMPLETED ---------------')
    print(f'{nr_correct} / {len(cases)} Correct')

# Init

In [3]:
create_config('nbdev-rewrite', 'flpeters', nbs_path='.')

In [4]:
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Notebook Loading

In [7]:
#export
def read_nb(fname):
    "Read the notebook in `fname`."
    with open(Path(fname),'r', encoding='utf8') as f: return nbformat.reads(f.read(), as_version=4)

In [8]:
test_nb = read_nb('00_export.ipynb')

In [9]:
test_nb.keys()

dict_keys(['cells', 'metadata', 'nbformat', 'nbformat_minor'])

In [10]:
test_nb['metadata']

{'kernelspec': {'display_name': 'Python 3',
  'language': 'python',
  'name': 'python3'},
 'language_info': {'codemirror_mode': {'name': 'ipython', 'version': 3},
  'file_extension': '.py',
  'mimetype': 'text/x-python',
  'name': 'python',
  'nbconvert_exporter': 'python',
  'pygments_lexer': 'ipython3',
  'version': '3.7.3'},
 'toc': {'base_numbering': 1,
  'nav_menu': {},
  'number_sections': True,
  'sideBar': True,
  'skip_h1_title': False,
  'title_cell': 'Table of Contents',
  'title_sidebar': 'Contents',
  'toc_cell': False,
  'toc_position': {},
  'toc_section_display': True,
  'toc_window_display': False}}

In [11]:
f"{test_nb['nbformat']}.{test_nb['nbformat_minor']}"

'4.4'

In [12]:
test_nb['cells'][0]

{'cell_type': 'markdown', 'metadata': {}, 'source': '# Imports'}

In [14]:
len(test_nb['cells'])

53

# Keyword Comments

`detect_comments()` is used to find and extract all comments from a code block.  
It's main purpose is to avoid matching on "comments" that are actually just part of a string, and not real python comments. One example would be: 

    """
    # export
    """
A naive parser would see the literal "#" and match that statement. In reality however, this code snippet is a string, and might be e.g. part of a test suit (which is how this bug was found in the first place), and not really meant to be exported.

This detects all comments in a piece of code, excluding those that are a part of a string.  
View https://docs.python.org/3/reference/lexical_analysis.html#strings for more info

In [125]:
# from numba import jit

In [124]:
# @jit
def iter_comments(src:str, pure_comments_only:bool=True, line_limit=None):
    in_lstr = in_sstr = False
    count, quote = 1, ''
    for i, line in enumerate(src.splitlines()[:line_limit]):
        is_pure, escape, prev_c = True, False, '\n'
        for j, c in enumerate(line):
            # we can't break as soon as not is_pure, because we have to detect if a multiline string beginns
            if is_pure and (not (c.isspace() or c == '#')): is_pure = False
            if (in_sstr or in_lstr):
                # assert (in_sstr and not in_lstr) or (in_lstr and not in_sstr)
                if escape: count = 0
                else:
                    if (c == quote):
                        count = ((count + 1) if (c == prev_c) else 1)
                        if in_sstr: in_sstr = False
                        elif (in_lstr and (count == 3)): count, in_lstr = 0, False
                escape = False if escape else (c == '\\')
            else:                    
                if (c == '#'):
                    if (pure_comments_only and is_pure): yield (line, (i, j))
                    elif (not pure_comments_only):       yield (line[j:], (i, j))
                    break
                elif c == "'" or c == '"':
                    count = ((count + 1) if (c == prev_c) else 1)
                    if count == 1: in_sstr = True
                    elif count == 3: count, in_lstr = 0, True
                    else: pass#raise SyntaxError(f'Unexpected quote repetition count: {count} Should be either 1 or 3')
                    quote = c
            prev_c = c

In [100]:
def test_iter_comments(src): return list(iter_comments(src, True))
test_strings = [
("trippe quote(''')", """'''
# string
'''""", []),
('tripple quote(""")', '''"""
#string
"""''', []),
('single quote(")', '"\
\n#string\n\
"', []),
("single quote(')", "'\
\n#string\n\
'", []),
("simple comment", """
#comment
""", [('#comment', (1, 0))]),
("comment sandwich", """
'this is a string'
# this is a comment
'another string , but between is an actual comment'
""", [('# this is a comment', (2, 0))]),
("tricky case 2", """
  a #non-pure comment
'''
#string
'''
####comment""", [('####comment', (5, 0))]),
("end of string quote", """
'''
'
# still part of the string
'
'''
""", []),
("single end of string escape", """
'\\'\\\n#str\\\n\\''
""", []),
("weird escape sequence", """
'\\\n\\''
""", []),
("long end of string escape", """
'''
# string
\\'''
# string
'''
""", []),
("raw string escape", """
r'''
# string
\\'''
# string
'''
""", []),
("multiple strings", """
'''a''''''b'''
""", []),
]
run_tests(test_strings, test_iter_comments, verbose=False)

--------------- ALL TESTS COMPLETED ---------------
13 / 13 Correct


In [42]:
# export
class KeywordParser:
    def __init__(self, *init_keywords):
        self.parsers = {}
        for kw in init_keywords: self.parsers[kw] = self._create_parser(kw)

    def _create_parser(self, keyword):
        # TODO: decide on the syntax
        # TODO: Should there be any whitespace allowed before special comments?
        # TODO: Should more than one "#" be allowed for special comments?
        pattern = fr"""
        ^              # start of line, since MULTILINE is passed
        \s*            # any amount of whitespace
        \#+\s*          # literal "#", then any amount of whitespace
        {keyword}(.*)  # keyword followed by arbitrary symbols (except new line)
        $              # end of line, since MULTILINE is passed
        """
        return re.compile(pattern, re.IGNORECASE | re.MULTILINE | re.VERBOSE)

    def __getitem__(self, key):
        if key in self.parsers: return self.parsers[key]
        else:
            parser = self._create_parser(key)
            self.parsers[key] = parser
            return parser
        
#     def search(self, key, text):
#         return self[key].search('\n'.join(detect_comments(text)))
        
#     def _search_remove(self, key, text):
#         print('WARNING: _search_remove() DOESN\'T WORK YET')
#         # TODO: This function is supposed to remove the keyword comment from the input
#         # TODO: detect_comments() has to be modified to allow for the positions to be returned
#         parser = self[key]
#         text, locations = detect_comments(text)
#         for comment, l in zip(text, locations):
#             res = parser.search(comment)
#             if res: return res, l

In [43]:
# export
OptionsTuple = namedtuple(typename='Options',
                          field_names=['export_target', 'internal'],
                          defaults=[None, False])

In [44]:
# export
_re_legacy_options = re.compile(fr'^(i)?\s*([a-zA-Z0-9]+\S*|)\s*$')
def legacy_parse_options(options:str) -> OptionsTuple:
    res = _re_legacy_options.search(options)
    if res:
        internal, export_target = res.groups()
        return OptionsTuple(export_target=(export_target if export_target else None), internal=(internal == 'i'))
    else: return None

In [45]:
# export
def parse_options(options:str, legacy:bool=True) -> OptionsTuple:
    if (options is None) or (options == '') or (options.isspace()): return OptionsTuple()
    else:
        if legacy:
            res = legacy_parse_options(options)
            if res: return res
        # TODO: New Syntax for specifying keyword options
        raise NotImplementedError('this branch of parse_options() is not implemented yet.')

In [111]:
# export
keyword_parser = KeywordParser()
def parse_export(source:str) -> (bool, OptionsTuple):
    # TODO: This should check for all visibility affecting keywords, and prioritise the top most
    #       That would allow the user to overwrite any unwanted cases
    export, hide = keyword_parser['export'], keyword_parser['hide']
    for comment, location in iter_comments(source):
        res = export.search(comment)
        if res: return (True, parse_options(res.groups()[0]))
        res = hide.search(comment)
        if res: return (False, None)
    return (False, None)

In [47]:
# export
def find_exports(cells:list, default:str, code_only:bool=True) -> list:
    # check for each cell if it's supposed to be exported and aggregate cell content together with export options
    # remove whitespace at end of lines
    exports = []
    for i, cell in enumerate(cells):
        if code_only and (cell.cell_type != 'code'): continue
        else:
            source = cell.source
            to_export, options = parse_export(source)
            if to_export:
                assert options.export_target or default, f'Cell nr.{i} doesn\'t have an export target, \
                                                           and a default is not specified:\n{source}'
                if not options.export_target: options = options._replace(export_target=default)
                exports.append((source, options))
            else: continue
    return exports

In [48]:
test_nb['cells'][0].keys()

dict_keys(['cell_type', 'metadata', 'source'])

In [49]:
test_nb['cells'][0]

{'cell_type': 'markdown', 'metadata': {}, 'source': '# Imports'}

In [50]:
find_exports(test_nb['cells'], 'export', code_only=True)

[('# export\nfrom collections import namedtuple, defaultdict\nimport os\nimport re\nfrom nbdev.imports import *',
  Options(export_target='export', internal=False)),
 ('#export\ndef read_nb(fname):\n    "Read the notebook in `fname`."\n    with open(Path(fname),\'r\', encoding=\'utf8\') as f: return nbformat.reads(f.read(), as_version=4)',
  Options(export_target='export', internal=False)),
  Options(export_target='export', internal=False)),
 ("# export\nOptionsTuple = namedtuple(typename='Options',\n                          field_names=['export_target', 'internal'],\n                          defaults=[None, False])",
  Options(export_target='export', internal=False)),
 ("# export\n_re_legacy_options = re.compile(fr'^(i)?\\s*([a-zA-Z0-9]+\\S*|)\\s*$')\ndef legacy_parse_options(options:str) -> OptionsTuple:\n    res = _re_legacy_options.search(options)\n    if res:\n        internal, export_target = res.groups()\n        return OptionsTuple(export_target=(export_target if export_targe

## Tests

In [51]:
test_strings = [
("trippe quote(''')", """'''
#export
'''""", (False, None)),
('tripple quote(""")', '''"""
#export
"""''', (False, None)),
('single quote(")', '"\
\n#export\n\
"', (False, None)),
("single quote(')", "'\
\n#export\n\
'", (False, None)),
("correct", """
#export
""", (True, OptionsTuple())),
("tricky case 1", """
'this is a string'
#export
'this also, but between is an actual comment'
""", (True, OptionsTuple())),
("tricky case 2", """
  a #export
'''
#export
'''
####export""", (True, OptionsTuple())),
("tricky case 3", """
'''
'
# export
'
'''
""", (False, None)),
("tricky case 4", """
'''
\'
# export
\'
'''
""", (False, None)),
]
run_tests(test_strings, parse_export)

--------------- ALL TESTS COMPLETED ---------------
9 / 9 Correct


In [53]:
test_markup = [
('export', """
# export
""", (True, OptionsTuple())),
('comment layout', """
#export
""", (True, OptionsTuple())),
('export internal legacy', """
# exporti
""", (True, OptionsTuple(internal=True))),
('export internal', """
# export -i
""", (True, OptionsTuple(internal=True))),
('export show source', """
# export -s
""", (True, OptionsTuple())),
('export internal show', """
# export -i -s
""", (True, OptionsTuple(internal=True))),
('default empty', """

""", (False, None)),
('hide', """
# hide
""", (False, None)),
('multiple comments', """
# export
# hide
""", (True, OptionsTuple())),
('multi comment same line', """
# export hide
""", (True, OptionsTuple(export_target='hide'))),
('multiple comments default_exp', """
# export
# default_exp
""", (True, OptionsTuple())),
]
run_tests(test_markup, parse_export, verbose=False)

--------------- ALL TESTS COMPLETED ---------------
8 / 11 Correct


# Names

In [55]:
# export
import ast
from ast import iter_fields, AST
import _ast
# from pprint import pprint

In [56]:
# hide
def print_tree(node):
    if isinstance(node, (list, tuple)):
        for x in node:
            print_tree(x)
    elif hasattr(node, '_fields'):
        for f in node._fields:
            # print(f)
            print_tree(node.__getattribute__(f))
    else:
        print(node)
        # pass

In [57]:
# hide
class TestParser(ast.NodeVisitor):
    def visit(self, node):
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        print(f'{node.__class__.__name__} -> {visitor.__name__}')
        return visitor(node)
    
    def _default(self, node):
        # pprint(node.__dict__)
        print(f'attr:   {node._attributes}\nfields: {node._fields}\n{"-"*25}')
        
    def visit_Assign(self, node): print(node.targets[0].id) # self._default(node)
    
    def visit_FunctionDef(self, node):
        # self._default(node)
        print(node.name)
#         for d in node.decorator_list:
#             print(self.visit(d))
            
    def visit_ClassDef(self, node): print(node.name) # self._default(node)

In [58]:
# export
def remove_private_names(names):
    to_remove = {n for n in names if n.startswith('_')}
    return names.difference(to_remove)

In [243]:
def unwrap_attr(node:_ast.Attribute) -> str:
    if isinstance(node.value, _ast.Attribute): return '.'.join((unwrap_attr(node.value), node.attr))
    else: return '.'.join((node.value.id, node.attr))

In [332]:
# export
def update_recursive(node, names):
    """inplace, recursive updating of names"""
    if   isinstance(node, (_ast.List, _ast.Tuple)):
        for x in node.elts: update_recursive(x, names)
    elif isinstance(node, _ast.Name)   : names.append(node.id)
    elif isinstance(node, _ast.Starred): names.append(node.value.id)
    elif isinstance(node, _ast.Attribute) : names.append(unwrap_attr(node))
    elif isinstance(node, list):
        for x in node: update_recursive(x, names)
    else: raise SyntaxError(f'Can\'t resolve {node} to name, unknown type')

In [161]:
def not_private(name): return not (name.startswith('_') and (not name.startswith('__')))

In [333]:
def update_from_all_(node, names):
    if isinstance(node, _ast.Str): names.add(node.s)
    elif isinstance(node, _ast.Name): names.add(node.id)
    elif isinstance(node, _ast.Attribute): names.add(unwrap_attr(node))
    elif isinstance(node, (_ast.List, _ast.Tuple)):
        for x in node.elts: update_from_all_(x, names)
    elif isinstance(node, _ast.Starred): raise SyntaxError(f'Starred expression *{node.value.id} not allowed in _all_')
    else: raise SyntaxError(f'{node} {node._attributes} {node._fields}')

In [334]:
def add_names(node, names):
    tmp_names = list()
    update_recursive(node.targets, tmp_names)
    print(tmp_names)
    for name in tmp_names:
        if not_private(name): names.add(name)
        elif name == '_all_':
            assert len(tmp_names) == 1, 'reserved keyword _all_ can only be used in simple assignments'
            update_from_all_(node.value, out)
    print('------------')

In [165]:
# export
def find_names(code:str) -> list:
    tree = ast.parse(code) # @expensive
    names = set()
    for node in tree.body:
        if   isinstance(node, _ast.Assign): add_names(node, names) # update_recursive(node.targets, names)
        elif isinstance(node, (_ast.FunctionDef, _ast.ClassDef)) and not_private(node.name): names.add(node.name)
        else: pass
    # names = remove_private_names(names)
    return names

In [286]:
code = """_all_ = 'var_a'"""

In [298]:
code = """_all_ = ['var_a']"""

In [309]:
code = """_all_ = [*a]"""

In [322]:
code = """_all_ = ['var_a', var_b, a.b.c]"""

In [323]:
tree = ast.parse(code)

In [324]:
node = tree.body[0]

In [325]:
node._attributes, node._fields

(('lineno', 'col_offset'), ('targets', 'value'))

In [326]:
node.targets[0].id

'_all_'

In [277]:
node.value.elts

AttributeError: 'Str' object has no attribute 'elts'

In [260]:
attr = node.value.elts[2]

In [261]:
attr._attributes, attr._fields

AttributeError: 'list' object has no attribute '_attributes'

In [252]:
unwrap_attr(attr)

'a.b.c'

In [231]:
attr.value.value

<_ast.Name at 0x176042a0748>

In [227]:
attr.attr

'c'

In [225]:
attr.value, attr

(<_ast.Attribute at 0x176042a07b8>, <_ast.Attribute at 0x176042a0e48>)

In [254]:
find_names("""_all_ = ['var_a', var_b]""")

['_all_']
------------


set()

## Tests

In [251]:
test_assignment = [
('Default Assignment', """
a = 1
b = a
a = 2
""", {'a', 'b'}),
('Tuple unpacking', """
a, b = (1, 2)
""", {'a', 'b'}),
('unpacking to tuples and lists', """
(a, b) = (1, 2)
[a, b] = (1, 2)
""", {'a', 'b'}),
('unpacking to tuples and lists x2', """
([a], (b)) = (1, 2)
[[a, ((b))]] = (1, 2)
""", {'a', 'b'}),
('Multiple assignments', """
a = b = 2
""", {'a', 'b'}),
('List Deconstruction', """
head, *tail = [1,2,3,4,5]
""", {'head', 'tail'}),
('Private Variables', """
_a = 1
""", set()),
('Dunder Variables', """
__a = 1
""", {'__a'}),
('Attribues', """
a.b = 1
""", {'a.b'}),
('_all_ special keyword', """
_all_ = ['var_a', var_b, a.b, 'c.d', _abc]
""", {'var_a', 'var_b', 'a.b', 'c.d', '_abc'}),
]
run_tests(test_assignment, find_names, verbose=True)

(1 / 10) TEST Default Assignment:
['a']
------------
['b']
------------
['a']
------------
TEST RESULT: SUCCESS

(2 / 10) TEST Tuple unpacking:
['a', 'b']
------------
TEST RESULT: SUCCESS

(3 / 10) TEST unpacking to tuples and lists:
['a', 'b']
------------
['a', 'b']
------------
TEST RESULT: SUCCESS

(4 / 10) TEST unpacking to tuples and lists x2:
['a', 'b']
------------
['a', 'b']
------------
TEST RESULT: SUCCESS

(5 / 10) TEST Multiple assignments:
['a', 'b']
------------
TEST RESULT: SUCCESS

(6 / 10) TEST List Deconstruction:
['head', 'tail']
------------
TEST RESULT: SUCCESS

(7 / 10) TEST Private Variables:
['_a']
------------
TEST RESULT: SUCCESS

(8 / 10) TEST Dunder Variables:
['__a']
------------
TEST RESULT: SUCCESS

(9 / 10) TEST Attribues:
['a.b']
------------
TEST RESULT: SUCCESS

(10 / 10) TEST _all_ special keyword:
['_all_']
------------
TEST FAILED WITH EXCEPTION:
TEST FAILED WITH RESULT: set()
EXPECTED: {'var_a', 'c.d', 'a.b', 'var_b', '_abc'}

--------------- AL

In [170]:
test_funcdef = [
('Default function definition', """
def add(a, b):
    return a + b
""", {'add'}),
('Type Annotated function def', """
def calc(a:int, b:int) -> int:
    c:float = 2.0
    return (a + b) * c
""", {'calc'}),
('function decorators', """
@test1
@test2
def add(a, b):
    return a + b
""", {'add'}),
('@patch and more complex type annotations', """
@patch
def func (obj:(Class1, Class2), a:int)->int:
    pass
""", {'func'})
]
run_tests(test_funcdef, find_names)

--------------- ALL TESTS COMPLETED ---------------
4 / 4 Correct


In [171]:
test_classdef = [
('Default class definition', """
class Abc:
    pass
""", {'Abc'}),
('Default class def 2', """
class Abc():
    pass
""", {'Abc'}),
]
run_tests(test_classdef, find_names)

--------------- ALL TESTS COMPLETED ---------------
2 / 2 Correct


# Export

In [83]:
# export
class ExportCache:
    def __init__(self, default_export=None):
        self.tupletype = namedtuple(typename='exports', field_names=['export_code', 'export_names'])
        self.exports = defaultdict(self._create_exp)
        if default_export is not None: self[default_export]
    
    def _create_exp(self): return self.tupletype(export_code=list(), export_names=set())
    
    def __getitem__(self, key): return self.exports[key]
    
    def add_names(self, key, names): self[key].export_names.update(names)
            
    def add_code(self, key, code): self[key].export_code.append(code)

In [84]:
# export
def find_default_export(cells:list) -> str:
    # search through all cells to find the default_exp keyword and return it's value.
    # syntax checking
    # maybe do some sanity checking
    return 'export'
    pass

In [85]:
# export
def create_mod_file(orig_nbfname, targ_pyfname):
    # create the .py file in the correct folder, with a header saying where it was originally from
    pass

In [86]:
# export
def _notebook2script(cells=None, fname=None, silent=False, to_dict=False):
    """Convert a single notebook"""
    # if cells: print('WARNING: The Cells parameter is only used for testing purposes!')
    if fname is not None: raise NotImplementedError('WARNING: fname is a "must pass", but not yet')
    # load notebook content
    # load config
    default = find_default_export(cells)
    if default is None:
        print('WARNING: No default export file found! (should this crash, or see if each export has its own target?)')
    else:
        # maybe this should be done at the bottom, together with all the others
        # create_mod_file(original_nbfile_path, target_pyfile_path) # flipped in original code
        pass
    export_cache = ExportCache(default)
    # load _nbdev file and create a spec from it (no idea why this is needed)
    exports = find_exports(cells, default)
    for j, (code, options)  in enumerate(exports):
        # code = clean_code(code)
        e, i = options.export_target, options.internal
        if not i: export_cache.add_names(e, find_names(code))
        export_cache.add_code(e, code)
    # write_to_export_files(export_cache, default)
    # add names to _nbdev index
    # write code cell to file
    # save _nbdev file
    return export_cache

In [123]:
%%prun
ec = _notebook2script(test_nb['cells'])

 

In [160]:
for key in ec.exports.keys():
    exp = ec.exports[key]
    pprint(exp.export_code)
    print('')
    print(exp.export_names)

['# export\nfrom collections import namedtuple, defaultdict',
 '# export\nimport os',
 '# export\nimport re',
 '# export\n'
 'def run_tests(cases, func, verbose=False):\n'
 '    nr_correct = 0\n'
 '    for i, (n, c, r) in enumerate(cases):\n'
 "        if verbose: print(f'({i + 1} / {len(cases)}) TEST {n}:')\n"
 '        try:\n'
 '            res = func(c)\n'
 "            assert res == r, f'TEST FAILED WITH RESULT: {res}\\nEXPECTED: "
 "{r}'\n"
 '            nr_correct += (res == r)\n'
 "            if verbose: print(f'TEST RESULT: SUCCESS\\n')\n"
 '        except Exception as e:\n'
 "            if verbose: print(f'TEST FAILED WITH EXCEPTION:\\n{e}\\n')\n"
 '            # raise e\n'
 "    print('--------------- ALL TESTS COMPLETED ---------------')\n"
 "    print(f'{nr_correct} / {len(cases)} Correct')",
 '# export\nfrom nbdev.imports import *',
 '#export\n'
 'def read_nb(fname):\n'
 '    "Read the notebook in `fname`."\n'
 "    with open(Path(fname),'r', encoding='utf8') as f: retur

In [237]:
# export 
def notebook2script(fname=None, silent=False, to_dict=False):
    "Convert notebooks matching `fname` to modules"
    # initial checks
    if os.environ.get('IN_TEST',0): return  # don't export if running tests
    if fname is None:
        reset_nbdev_module()
        update_version()
        update_baseurl()
        files = [f for f in Config().nbs_path.glob('*.ipynb') if not f.name.startswith('_')]
    else: files = glob.glob(fname)
    d = collections.defaultdict(list) if to_dict else None
    for f in sorted(files): d = _notebook2script(f, silent=silent, to_dict=d)
    if to_dict: return d
    else: add_init(Config().lib_path)