In [76]:
# export
from nbdev.imports import *
from nbdev.core import *

In [77]:
# default_exp export
# default_cls_lvl 3

# Converting notebooks to modules

> The functions that transform notebooks into a library

## Reading a notebook

### What's a notebook?

A jupyter notebook is a json file behind the scenes. We can just read it with the json module, which will return a nested dictionary of dictionaries/lists of dictionaries, but there are some small differences between reading the json and using the tools from `nbformat` so we'll use this one.

In [78]:
#export
def read_nb(fname):
    "Load the notebook stored at `fname` (a string or a path) as an nbformat version 4 notebook."
    nb_path = Path(fname)
    with open(nb_path, 'r', encoding='utf8') as nb_file:
        raw_json = nb_file.read()
    return nbformat.reads(raw_json, as_version=4)

`fname` can be a string or a pathlib object.

In [79]:
test_nb = read_nb('01_export.ipynb')

The root has four keys: `cells` contains the cells of the notebook, `metadata` some stuff around the version of python used to execute the notebook, `nbformat` and `nbformat_minor` the version of nbformat. 

In [80]:
test_nb.keys()

dict_keys(['cells', 'metadata', 'nbformat', 'nbformat_minor'])

In [81]:
test_nb['metadata']

{'kernelspec': {'display_name': 'Python 3',
  'language': 'python',
  'name': 'python3'},
 'language_info': {'codemirror_mode': {'name': 'ipython', 'version': 3},
  'file_extension': '.py',
  'mimetype': 'text/x-python',
  'name': 'python',
  'nbconvert_exporter': 'python',
  'pygments_lexer': 'ipython3',
  'version': '3.7.4'}}

In [82]:
f"{test_nb['nbformat']}.{test_nb['nbformat_minor']}"

'4.4'

The cells key then contains a list of cells. Each one is a new dictionary that contains entries like the type (code or markdown), the source (what is written in the cell) and the output (for code cells).

In [83]:
test_nb['cells'][0]

{'cell_type': 'code',
 'execution_count': 1,
 'metadata': {'hide_input': False},
 'outputs': [],
 'source': '# export\nfrom nbdev.imports import *\nfrom nbdev.core import *'}

### Finding patterns

In [84]:
# export
def check_re(cell, pat, code_only=True):
    "Search the source of `cell` for the regex `pat`; returns the match, or `None` (also for non-code cells when `code_only`)."
    if code_only:
        # Only code cells are inspected unless the caller asks otherwise
        if cell['cell_type'] != 'code': return None
    # A string pattern is compiled case-insensitively, with `^`/`$` matching at each line
    regex = re.compile(pat, re.IGNORECASE | re.MULTILINE) if isinstance(pat, str) else pat
    return regex.search(cell['source'])

`pat` can be a string or a compiled regex, if `code_only=True`, ignores markdown cells.

In [85]:
# Work on a copy of the first cell so the `cell_type` change below does not alter `test_nb`
cell = test_nb['cells'][0].copy()
# `pat` can be given as a string or as a compiled regex
assert check_re(cell, '# export') is not None
assert check_re(cell, re.compile('# export')) is not None
assert check_re(cell, '# bla') is None
# Markdown cells are skipped unless `code_only=False` is passed
cell['cell_type'] = 'markdown'
assert check_re(cell, '# export') is None
assert check_re(cell, '# export', code_only=False) is not None

In [86]:
# export
# Flag for an export without destination: a line made only of `#export` or `#exports`
# (case-insensitive, optional whitespace). `is_export` maps such a cell to the default module.
_re_blank_export = re.compile(r"""
# Matches any line with #export or #exports without any module name:
^         # beginning of line (since re.MULTILINE is passed)
\s*       # any number of whitespace
\#\s*     # # then any number of whitespace
exports?  # export or exports
\s*       # any number of whitespace
$         # end of line (since re.MULTILINE is passed)
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [87]:
# export
# Flag for an export with an explicit destination: `#export mod.name` or `#exports mod.name`
# (case-insensitive). Group 1 is the dotted module name, which `is_export` turns into a path.
_re_mod_export = re.compile(r"""
# Matches any line with #export or #exports with a module name and catches it in group 1:
^         # beginning of line (since re.MULTILINE is passed)
\s*       # any number of whitespace
\#\s*     # # then any number of whitespace
exports?  # export or exports
\s*       # any number of whitespace
(\S+)     # catch a group with any non-whitespace chars
\s*       # any number of whitespace
$         # end of line (since re.MULTILINE is passed)
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [88]:
# export
def is_export(cell, default):
    """Check if `cell` is to be exported and returns the name of the module.

    Returns `default` for a bare `#export` flag, the module path (dots replaced by `os.path.sep`)
    for `#export mod.name`, and `None` when the cell has no export flag."""
    if check_re(cell, _re_blank_export):
        if default is None:
            # `source` is a string here, so show the whole cell rather than indexing a single character
            print(f"This cell doesn't have an export destination and was ignored:\n{cell['source']}")
        return default
    tst = check_re(cell, _re_mod_export)
    return os.path.sep.join(tst.groups()[0].split('.')) if tst else None

The cells to export are marked with an `#export` or `#exports` code, potentially with a module name where we want it exported. The default is given in a cell of the form `#default_exp bla` inside the notebook (usually at the top), though in this function, it needs to be passed (the final script will read the whole notebook to find it).

In [89]:
# The first cell of this notebook carries a bare `# export` flag: it goes to the default module
cell = test_nb['cells'][0].copy()
test_eq(is_export(cell, 'export'), 'export')
cell['source'] = "# exports"
test_eq(is_export(cell, 'export'), 'export')
# An explicit module name takes precedence over the default
cell['source'] = "# export mod"
test_eq(is_export(cell, 'export'), 'mod')
# Dotted module names are converted to a path (this expected value assumes `/` as separator)
cell['source'] = "# export mod.file"
test_eq(is_export(cell, 'export'), 'mod/file')
# A misspelled flag is not an export
cell['source'] = "# expt mod.file"
assert is_export(cell, 'export') is None

In [90]:
# export
# Flag declaring the default export module of a notebook: `#default_exp mod.name`, with the
# module name in group 1.
# NOTE(review): the inline regex comment on the `default_exp` line ("export or exports") is stale,
# that line matches the literal `default_exp`.
_re_default_exp = re.compile(r"""
# Matches any line with #default_exp with a module name and catches it in group 1:
^            # beginning of line (since re.MULTILINE is passed)
\s*          # any number of whitespace
\#\s*        # # then any number of whitespace
default_exp  # export or exports
\s*          # any number of whitespace
(\S+)        # catch a group with any non-whitespace chars
\s*          # any number of whitespace
$            # end of line (since re.MULTILINE is passed)
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [91]:
# export
def find_default_export(cells):
    "Return the module named by the first `#default_exp` flag found in `cells`, or `None` if there is none."
    # Lazily scan the cells so the search stops at the first one carrying the flag
    matches = (check_re(c, _re_default_exp) for c in cells)
    first_hit = next((m for m in matches if m), None)
    return first_hit.groups()[0] if first_hit else None

Stops at the first cell containing a `#default_exp` code and returns the value that follows it. Returns `None` if there is no cell with that code.

In [92]:
# This notebook declares `# default_exp export` in one of its first two cells...
test_eq(find_default_export(test_nb['cells']), 'export')
# ...so skipping those two cells leaves no default to find
assert find_default_export(test_nb['cells'][2:]) is None

### Exporting notebooks

We're now ready to export notebooks!

In [93]:
# export
def _create_mod_file(fname, nb_path):
    "Create a module file for `fname`."
    fname.parent.mkdir(parents=True, exist_ok=True)
    with open(fname, 'w') as f:
        f.write(f"#AUTOGENERATED! DO NOT EDIT! File to edit: dev/{nb_path.name} (unless otherwise specified).")
        f.write('\n\n__all__ = []')

In [94]:
#export
# Finds functions decorated with `@patch`. Group 1 is the function name; the annotation of the
# first argument lands in group 2 when it is a single class and in group 3 (parentheses included)
# when it is a tuple of classes. The group that does not apply is `None`.
_re_patch_func = re.compile(r"""
# Catches any function decorated with @patch, its name in group 1 and the patched class in group 2
@patch         # At any place in the cell, something that begins with @patch
\s*def         # Any number of whitespace (including a new line probably) followed by def
\s+            # One whitespace or more
([^\(\s]*)     # Catch a group composed of anything but whitespace or an opening parenthesis (name of the function)
\s*\(          # Any number of whitespace followed by an opening parenthesis
[^:]*          # Any number of character different of : (the name of the first arg that is type-annotated)
:\s*           # A column followed by any number of whitespace
(?:            # Non-catching group with either
([^,\s\(\)]*)  #    a group composed of anything but a comma, a parenthesis or whitespace (name of the class)
|              #  or
(\([^\)]*\)))  #    a group composed of something between parenthesis (tuple of classes)
\s*            # Any number of whitespace
(?:,|\))       # Non-catching group with either a comma or a closing parenthesis
""", re.VERBOSE)

In [95]:
#hide
# Single patched class: name in group 1, class in group 2, group 3 unused
tst = _re_patch_func.search("""
@patch
def func(obj:Class):""")
test_eq(tst.groups(), ("func", "Class", None))
# Whitespace before the parenthesis and extra arguments are tolerated
tst = _re_patch_func.search("""
@patch
def func (obj:Class, a)""")
test_eq(tst.groups(), ("func", "Class", None))
# Tuple of patched classes: group 2 unused, the parenthesized tuple in group 3
tst = _re_patch_func.search("""
@patch
def func (obj:(Class1, Class2), a)""")
test_eq(tst.groups(), ("func", None, "(Class1, Class2)"))

In [96]:
#export
# Captures, in group 1, a whole `@typedispatch`-decorated signature up to and including the colon,
# so that `export_names` can strip it before looking for definitions.
# NOTE(review): two inline regex comments are stale: the first line matches `@typedispatch`
# (not `@patch`), and `[^\(]*` also accepts whitespace in the function name.
_re_typedispatch_func = re.compile(r"""
# Catches any function decorated with @typedispatch
(@typedispatch  # At any place in the cell, catch a group with something that begins with @patch
\s*def          # Any number of whitespace (including a new line probably) followed by def
\s+             # One whitespace or more
[^\(]*          # Anything but whitespace or an opening parenthesis (name of the function)
\s*\(           # Any number of whitespace followed by an opening parenthesis
[^\)]*          # Any number of character different of )
\)\s*:)         # A closing parenthesis followed by whitespace and :
""", re.VERBOSE)

In [97]:
#hide
assert _re_typedispatch_func.search("@typedispatch\ndef func(a, b):").groups() == ('@typedispatch\ndef func(a, b):',)

In [98]:
#export
# Finds top-level (non-indented) `def` and `class` statements; group 1 is the defined name.
# The name group also accepts dots, which `export_names` relies on for its rewritten
# `def Class.method():` lines.
_re_class_func_def = re.compile(r"""
# Catches any 0-indented function or class definition with its name in group 1
^              # Beginning of a line (since re.MULTILINE is passed)
(?:def|class)  # Non-catching group for def or class
\s+            # One whitespace or more
([^\(\s]*)     # Catching group with any character except an opening parenthesis or a whitespace (name)
\s*            # Any number of whitespace
(?:\(|:)       # Non-catching group with either an opening parenthesis or a : (classes don't need ())
""", re.MULTILINE | re.VERBOSE)

In [99]:
#hide
test_eq(_re_class_func_def.search("class Class:").groups(), ('Class',))
test_eq(_re_class_func_def.search("def func(a, b):").groups(), ('func',))

In [100]:
#export
# Finds top-level (non-indented) assignments `name = ...`; group 1 is everything before the `=`
# that contains neither whitespace nor `=` (so dotted or subscripted targets are caught as well).
_re_obj_def = re.compile(r"""
# Catches any 0-indented object definition (bla = thing) with its name in group 1
^          # Beginning of a line (since re.MULTILINE is passed)
([^=\s]*)  # Catching group with any character except a whitespace or an equal sign
\s*=       # Any number of whitespace followed by an =
""", re.MULTILINE | re.VERBOSE)

In [101]:
#hide
test_eq(_re_obj_def.search("a = 1").groups(), ('a',))
test_eq(_re_obj_def.search("a=1").groups(), ('a',))

In [102]:
# export
def _not_private(n):
    for t in n.split('.'):
        if (t.startswith('_') and not t.startswith('__')) or t.startswith('@'): return False
    return '\\' not in t and '^' not in t and '[' not in t

def export_names(code, func_only=False):
    "List the public top-level functions and classes (and assigned objects unless `func_only`) defined in `code`."
    def _patch_to_def(match):
        # Rewrite a `@patch` function as one `def Class.name():` line per patched class,
        # so that the definition regex below reports it under its patched name
        func_name, single_cls, cls_tuple = match.groups()
        if single_cls is not None: targets = [single_cls]
        else: targets = re.split(', *', cls_tuple[1:-1])
        return '\n'.join(f"def {target}.{func_name}():" for target in targets)

    # `@typedispatch` functions are removed first: they must not be listed
    without_dispatch = _re_typedispatch_func.sub('', code)
    rewritten = _re_patch_func.sub(_patch_to_def, without_dispatch)
    found = _re_class_func_def.findall(rewritten)
    if not func_only: found = found + _re_obj_def.findall(rewritten)
    return [name for name in found if _not_private(name)]

This function only picks the zero-indented objects, functions or classes (we don't want the class methods for instance) and excludes private names (that begin with a single `_`; dunder names such as `__version__` are kept). It only returns function and class names when `func_only=True`.

In [103]:
#Top-level functions and classes are found, in order of appearance
test_eq(export_names("def my_func(x):\n  pass\nclass MyClass():"), ["my_func", "MyClass"])

#Indented funcs are ignored (funcs inside a class)
test_eq(export_names("  def my_func(x):\n  pass\nclass MyClass():"), ["MyClass"])

#Private funcs are ignored, dunder are not
test_eq(export_names("def _my_func():\n  pass\nclass MyClass():"), ["MyClass"])
test_eq(export_names("__version__ = 1:\n  pass\nclass MyClass():"), ["MyClass", "__version__"])

#trailing spaces
test_eq(export_names("def my_func ():\n  pass\nclass MyClass():"), ["my_func", "MyClass"])

#class without parenthesis
test_eq(export_names("def my_func ():\n  pass\nclass MyClass:"), ["my_func", "MyClass"])

#object and funcs
test_eq(export_names("def my_func ():\n  pass\ndefault_bla=[]:"), ["my_func", "default_bla"])
test_eq(export_names("def my_func ():\n  pass\ndefault_bla=[]:", func_only=True), ["my_func"])

#Private objects are ignored
test_eq(export_names("def my_func ():\n  pass\n_default_bla = []:"), ["my_func"])

#Objects with dots are privates if one part is private
test_eq(export_names("def my_func ():\n  pass\ndefault.bla = []:"), ["my_func", "default.bla"])
test_eq(export_names("def my_func ():\n  pass\ndefault._bla = []:"), ["my_func"])

#Monkey-patches with @patch are properly renamed
test_eq(export_names("@patch\ndef my_func(x:Class):\n  pass"), ["Class.my_func"])
test_eq(export_names("@patch\ndef my_func(x:Class):\n  pass", func_only=True), ["Class.my_func"])
test_eq(export_names("some code\n@patch\ndef my_func(x:Class, y):\n  pass"), ["Class.my_func"])
test_eq(export_names("some code\n@patch\ndef my_func(x:(Class1,Class2), y):\n  pass"), ["Class1.my_func", "Class2.my_func"])

#Check delegates
test_eq(export_names("@delegates(keep=True)\nclass someClass:\n  pass"), ["someClass"])

#Typedispatch decorated functions shouldn't be added
test_eq(export_names("@patch\ndef my_func(x:Class):\n  pass\n@typedispatch\ndef func(x: TensorImage): pass"), ["Class.my_func"])

In [104]:
#export
# Finds a single-line, non-indented `_all_ = [...]` statement (the way a cell asks for extra names
# in `__all__`); group 1 is the text between the brackets.
_re_all_def   = re.compile(r"""
# Catches a cell with defines \_all\_ = [\*\*] and get that \*\* in group 1
^_all_   #  Beginning of line (since re.MULTILINE is passed)
\s*=\s*  #  Any number of whitespace, =, any number of whitespace
\[       #  Opening [
([^\n\]]*) #  Catching group with anything except a ] or newline
\]       #  Closing ]
""", re.MULTILINE | re.VERBOSE)

#Same with __all__
# Unlike the pattern above, the bracket content may span several lines here (newlines are allowed).
_re__all__def = re.compile(r'^__all__\s*=\s*\[([^\]]*)\]', re.MULTILINE)

In [105]:
# export
def extra_add(code):
    """Catch adds to `__all__` required by a cell with `_all_=`

    Returns a tuple `(names, code)`: `names` is the list of entries found in the `_all_ = [...]`
    statement (double quotes normalized to single quotes, surrounding whitespace removed, empty
    entries dropped) and `code` is the cell code without that statement and without trailing
    newlines. When there is no `_all_` statement, returns `([], code)` with `code` untouched."""
    found = _re_all_def.search(code)
    if found is None: return [],code
    raw_names = found.groups()[0].replace('"', "'")
    # Strip each entry and skip empty ones, so `_all_ = []` or `[ 'a' ]` do not produce bogus names
    names = [n.strip() for n in raw_names.split(',')]
    names = [n for n in names if n]
    code = _re_all_def.sub('', code)
    # Remove the newlines left at the end of the code once the statement is gone
    code = re.sub(r'([^\n]|^)\n*$', r'\1', code)
    return names,code

In [106]:
# Double quotes are normalized to single quotes and spacing around commas does not matter
test_eq(extra_add('_all_ = ["func", "func1", "func2"]'), (["'func'", "'func1'", "'func2'"],''))
test_eq(extra_add('_all_ = ["func",   "func1" , "func2"]'), (["'func'", "'func1'", "'func2'"],''))
test_eq(extra_add("_all_ = ['func','func1', 'func2']\n"), (["'func'", "'func1'", "'func2'"],''))
# The rest of the cell is returned without the `_all_` statement and the newlines before it
test_eq(extra_add('code\n\n_all_ = ["func", "func1", "func2"]'), (["'func'", "'func1'", "'func2'"],'code'))

In [107]:
#export
def _add2add(fname, names, line_width=120):
    """Append `names` to the `__all__` list of the module file `fname`, in place.

    `names` are inserted verbatim (callers pass them already quoted). The rewritten `__all__`
    statement is wrapped at `line_width` characters, continuation lines being aligned under the
    first entry. Does nothing when `names` is empty."""
    if len(names) == 0: return
    with open(fname, 'r', encoding='utf8') as f: text = f.read()
    # Honor `line_width` (the width used to be hard-coded, making the parameter useless)
    tw = TextWrapper(width=line_width, initial_indent='', subsequent_indent=' '*11, break_long_words=False)
    re_all = _re__all__def.search(text)
    start,end = re_all.start(),re_all.end()
    # Reopen the list: drop the closing `]` and add a separator unless the list is still empty
    text_all = tw.wrap(f"{text[start:end-1]}{'' if text[end-2]=='[' else ', '}{', '.join(names)}]")
    with open(fname, 'w', encoding='utf8') as f: f.write(text[:start] + '\n'.join(text_all) + text[end:])

In [108]:
# Scratch file in the current directory, removed at the end of the cell
fname = 'test_add.txt'
with open(fname, 'w', encoding='utf8') as f: f.write("Bla\n__all__ = [my_file, MyClas]\nBli")
# A single name is appended on the same line; the text around `__all__` is preserved
_add2add(fname, ['new_function'])
with open(fname, 'r', encoding='utf8') as f:
    test_eq(f.read(), "Bla\n__all__ = [my_file, MyClas, new_function]\nBli")
# Enough names to overflow the line: the list is wrapped and aligned under the first entry
_add2add(fname, [f'new_function{i}' for i in range(10)])
with open(fname, 'r', encoding='utf8') as f:
    test_eq(f.read(), """Bla
__all__ = [my_file, MyClas, new_function, new_function0, new_function1, new_function2, new_function3, new_function4,
           new_function5, new_function6, new_function7, new_function8, new_function9]
Bli""")
os.remove(fname)

In [109]:
# export
def _relative_import(name, fname):
    mods = name.split('.')
    splits = str(fname).split(os.path.sep)
    if mods[0] not in splits: return name
    i=len(splits)-1
    while i>0 and splits[i] != mods[0]: i-=1
    splits = splits[i:]
    while len(mods)>0 and splits[0] == mods[0]: splits,mods = splits[1:],mods[1:]
    return '.' * (len(splits)) + '.'.join(mods)

In [110]:
# Same package as the file: one dot
test_eq(_relative_import('nbdev.core', Path.cwd()/'nbdev'/'data.py'), '.core')
# File one level deeper than the imported module: two dots
test_eq(_relative_import('nbdev.core', Path('nbdev')/'vision'/'data.py'), '..core')
test_eq(_relative_import('nbdev.vision.transform', Path('nbdev')/'vision'/'data.py'), '.transform')
test_eq(_relative_import('nbdev.notebook.core', Path('nbdev')/'data'/'external.py'), '..notebook.core')
# Importing the package the file lives in gives a bare dot
test_eq(_relative_import('nbdev.vision', Path('nbdev')/'vision'/'learner.py'), '.')

In [111]:
#export
#Catches any `from nbdev.bla import something` line (where nbdev is the name of the library):
#the indentation is in group 1, nbdev.bla in group 2 and the imported thing(s) in group 3.
#The library name is escaped and must be followed by a literal dot, so that neither regex
#metacharacters in the name nor an arbitrary character after it can produce a false match.
_re_import = re.compile(r'^(\s*)from (' + re.escape(Config().lib_name) + r'\.\S*) import (.*)$')

In [112]:
# export
def _deal_import(code_lines, fname):
    "Return `code_lines` with the absolute imports from the library rewritten relative to the file `fname`."
    def _to_relative(match):
        indent, module, imported = match.groups()
        return f"{indent}from {_relative_import(module, fname)} import {imported}"
    converted = []
    for line in code_lines:
        converted.append(_re_import.sub(_to_relative, line))
    return converted

In [113]:
#hide
lines = ["from nbdev.core import *", "nothing to see", "  from nbdev.vision import bla1, bla2", "from nbdev.vision import models"]
test_eq(_deal_import(lines, Path.cwd()/'nbdev'/'data.py'), [
    "from .core import *", "nothing to see", "  from .vision import bla1, bla2", "from .vision import models"
])

In [114]:
#export
def reset_nbdev_module():
    "Create (or overwrite) the `_nbdev.py` file of the library with an empty `index` and an empty list of `modules`."
    fname = Config().lib_path/'_nbdev.py'
    fname.parent.mkdir(parents=True, exist_ok=True)
    # Explicit utf8, like the other files written by the export functions
    with open(fname, 'w', encoding='utf8') as f:
        f.write("#AUTOGENERATED BY NBDEV! DO NOT EDIT!")
        f.write('\n\n__all__ = ["index", "modules"]')
        f.write('\n\nindex = {}')
        f.write('\n\nmodules = []')

In [115]:
# export
def get_nbdev_module():
    """Import the `_nbdev` module of the library and return a freshly reloaded version of it.

    If the import or the reload fails, prints a hint and returns `None`."""
    try:
        mod = importlib.import_module(f'{Config().lib_name}._nbdev')
        # Reload so that changes written to the file since the first import are picked up
        mod = importlib.reload(mod)
        return mod
    except Exception:
        # `Exception` instead of a bare `except`: KeyboardInterrupt and SystemExit still propagate
        print("Run `reset_nbdev_module` to create an empty skeletton.")
        return None

In [116]:
#export
# Locate the `index = {...}` and `modules = [...]` statements inside the text of `_nbdev.py`.
# Each match stops at the first closing brace/bracket, so the values must not contain one.
_re_index_idx = re.compile(r'index\s*=\s*{[^}]*}')
_re_index_mod = re.compile(r'modules\s*=\s*\[[^\]]*\]')

In [117]:
#export
def save_nbdev_module(mod):
    """Write `mod.index` and `mod.modules` back into the `_nbdev.py` file of the library.

    The file must already exist (see `reset_nbdev_module`); only its `index = {...}` and
    `modules = [...]` statements are replaced."""
    fname = Config().lib_path/'_nbdev.py'
    # Python source files are utf8 by default, so read and write the module as utf8 explicitly
    with open(fname, 'r', encoding='utf8') as f: code = f.read()
    t = ',\n         '.join([f'"{k}": "{v}"' for k,v in mod.index.items()])
    new_index = "index = {"+ t +"}"
    # A function replacement is inserted literally. A replacement *string* would be parsed for
    # backslash escapes, which fails on module paths built with a `\` separator.
    code = _re_index_idx.sub(lambda m: new_index, code)
    t = ',\n           '.join([f'"{f}"' for f in mod.modules])
    new_modules = f"modules = [{t}]"
    code = _re_index_mod.sub(lambda m: new_modules, code)
    with open(fname, 'w', encoding='utf8') as f: f.write(code)

In [118]:
#hide
# Move the real `_nbdev.py` out of the way while testing; it is put back in the `finally` clause
ind,ind_bak = Config().lib_path/'_nbdev.py',Config().lib_path/'_nbdev.bak'
if ind.exists(): shutil.move(ind, ind_bak)
try:
    # A freshly reset module has an empty index and no modules
    reset_nbdev_module()
    mod = get_nbdev_module()
    test_eq(mod.index, {})
    test_eq(mod.modules, [])

    # Changes made on the module object are persisted by `save_nbdev_module`...
    mod.index = {'foo':'bar'}
    mod.modules.append('lala.bla')
    save_nbdev_module(mod)

    # ...and are seen again when the module is re-imported
    mod = get_nbdev_module()
    test_eq(mod.index, {'foo':'bar'})
    test_eq(mod.modules, ['lala.bla'])
finally:
    if ind_bak.exists(): shutil.move(ind_bak, ind)

In [119]:
#export
def _notebook2script(fname, silent=False, to_dict=None):
    """Finds cells starting with `#export` and puts them into a new module

    With `to_dict=None` the code of each exported cell is appended to its module file (the default
    module file is recreated first). Otherwise the module files are left alone and a tuple
    `(cell index, notebook path, code)` is appended to `to_dict[module path]` for each exported cell.
    In both cases the `_nbdev.py` index is updated and saved. Returns `to_dict`, or `None` when the
    `IN_TEST` environment variable is set."""
    if os.environ.get('IN_TEST',0): return  # don't export if running tests
    fname = Path(fname)
    nb = read_nb(fname)
    default = find_default_export(nb['cells'])
    if default is not None:
        # Dotted module name -> relative path of the module file
        default = os.path.sep.join(default.split('.'))
        if to_dict is None: _create_mod_file(Config().lib_path/f'{default}.py', fname)
    mod = get_nbdev_module()
    # Destination module of every cell (`None` for the cells that are not exported)
    exports = [is_export(c, default) for c in nb['cells']]
    cells = [(i,c,e) for i,(c,e) in enumerate(zip(nb['cells'],exports)) if e is not None]
    for i,c,e in cells:
        fname_out = Config().lib_path/f'{e}.py'
        # Marker line written before the cell code: `#Cell` for the default module of this notebook,
        # `#Comes from <notebook>, cell` for any other destination
        orig = ('#C' if e==default else f'#Comes from {fname.name}, c') + 'ell\n'
        # The first line of the cell (the one holding the export flag) is dropped
        code = '\n\n' + orig + '\n'.join(_deal_import(c['source'].split('\n')[1:], fname_out))
        names = export_names(code)
        extra,code = extra_add(code)
        # Dotted names (patched methods for instance) are indexed below but kept out of `__all__`
        if to_dict is None: _add2add(fname_out, [f"'{f}'" for f in names if '.' not in f and len(f) > 0] + extra)
        mod.index.update({f: fname.name for f in names})
        # remove trailing spaces
        code = re.sub(r' +$', '', code, flags=re.MULTILINE)
        # Nothing is written for a cell that is empty apart from its marker
        if code != '\n\n' + orig[:-1]:
            if to_dict is not None: to_dict[fname_out].append((i, fname, code))
            else:
                with open(fname_out, 'a', encoding='utf8') as f: f.write(code)
        if f'{e}.py' not in mod.modules: mod.modules.append(f'{e}.py')
    save_nbdev_module(mod)

    if not silent: print(f"Converted {fname.name}.")
    return to_dict

In [120]:
_notebook2script('01_export.ipynb')

Converted 01_export.ipynb.


In [121]:
#export 
def notebook2script(fname=None, silent=False, to_dict=False):
    "Convert the notebook(s) matching the glob `fname`, or all the notebooks whose name doesn't start with `_` if `fname` is `None`."
    # initial checks
    if os.environ.get('IN_TEST',0): return  # don't export if running tests
    if fname is not None: nb_files = glob.glob(fname)
    else:
        # Converting everything starts again from an empty `_nbdev.py`
        reset_nbdev_module()
        nb_files = [p for p in Config().nbs_path.glob('*.ipynb') if not p.name.startswith('_')]
    collected = collections.defaultdict(list) if to_dict else None
    for nb_file in nb_files:
        collected = _notebook2script(nb_file, silent=silent, to_dict=collected)
    if to_dict: return collected

Finds cells starting with `#export` and puts them into the appropriate module.
* `fname`: the filename of one notebook to convert or a glob expression, will default to all notebooks that don't begin with an `_`

Examples of use in console:
```
notebook2script                                 # Parse all files
notebook2script --fname 00_export.ipynb         # Parse 00_export.ipynb
notebook2script --fname=nb*                     # Parse all files starting with nb*
```

### Finding the way back to notebooks

We need to get the name of the object we are looking for, and then we'll try to find it in our index file.

In [122]:
#export 
def _get_property_name(p):
    "Get the name of property `p`"
    if hasattr(p, 'fget'):
        return p.fget.func.__qualname__ if hasattr(p.fget, 'func') else p.fget.__qualname__
    else: return next(iter(re.findall(r'\'(.*)\'', str(p)))).split('.')[-1]

def get_name(obj):
    "Best-effort name of `obj`: `__name__`, then `_name`, then the origin of a type, then the getter of a property, then the end of `str(obj)`."
    if hasattr(obj, '__name__'): return obj.__name__
    if getattr(obj, '_name', False): return obj._name
    #for types
    if hasattr(obj,'__origin__'): return str(obj.__origin__).rsplit('.', 1)[-1]
    if type(obj)==property: return _get_property_name(obj)
    return str(obj).rsplit('.', 1)[-1]

In [123]:
# export
def qual_name(obj):
    """Get the qualified name of `obj`

    Uses `__qualname__` when available. For a bound method without it, the name is built from the
    name of the object the method is bound to and the name of the underlying function. Anything
    else falls back on `get_name`."""
    if hasattr(obj,'__qualname__'): return obj.__qualname__
    # `__func__` is the function wrapped by the bound method (this line used an undefined name `fn`)
    if inspect.ismethod(obj):       return f"{get_name(obj.__self__)}.{get_name(obj.__func__)}"
    return get_name(obj)

In [124]:
test_eq(get_name(in_ipython), 'in_ipython')
test_eq(get_name(DocsTestClass.test), 'test')

For properties defined using `property` or our own `add_props` helper, we approximate the name by looking at their getter functions, since we don't seem to have access to the property name itself. If everything fails (a getter cannot be found), we return the name of the object that contains the property. This suffices for `source_nb` to work.

In [125]:
#hide
class PropertyClass:
    # One property built on a lambda and one on a named getter
    p_lambda = property(lambda x: x)
    def some_getter(self): return 7
    p_getter = property(some_getter)

# The name reported for a property is the qualified name of its getter, not the attribute name
test_eq(get_name(PropertyClass.p_lambda), 'PropertyClass.<lambda>')
test_eq(get_name(PropertyClass.p_getter), 'PropertyClass.some_getter')
test_eq(get_name(PropertyClass), 'PropertyClass')

In [126]:
# export
def source_nb(func, is_name=None, return_all=False):
    "Name of the notebook where `func` (an object or its name) was defined, or `None` if it is not in the index."
    if not is_name: is_name = isinstance(func, str)
    index = get_nbdev_module().index
    candidate = func if is_name else qual_name(func)
    # Try the full dotted name first, then drop the last part until something is found
    while len(candidate) > 0:
        if candidate in index:
            notebook = index[candidate]
            return (candidate,notebook) if return_all else notebook
        candidate = candidate.rpartition('.')[0]
    return None

In [127]:
test_eq(qual_name(DocsTestClass), 'DocsTestClass')
test_eq(qual_name(DocsTestClass.test), 'DocsTestClass.test')

In [128]:
# export
# Reads back what the export wrote in a module. `_re_default_nb` gets the notebook name (group 1)
# from the autogenerated header line. `_re_cell` recognizes a cell marker: `#Cell` (group 1 is
# `None`) or `#Comes from <notebook>, cell` (group 1 is the notebook name).
_re_default_nb = re.compile(r'File to edit: dev/(\S+)\s+')
_re_cell = re.compile(r'^#Cell|^#Comes from\s+(\S+), cell')

You can either pass an object or its name (by default `is_name` will look if `func` is a string or not, but you can override if there is some inconsistent behavior). 

If passed a method of a class, the function will return the notebook in which the largest part of the function was defined in case there is a monkey-patching that defines `class.method` in a different notebook than `class`. If `return_all=True`, the function will return a tuple with the name by which the function was found and the notebook.

In [129]:
test_eq(source_nb(in_notebook), '00_core.ipynb')
test_eq(source_nb(DocsTestClass), '00_core.ipynb')
test_eq(source_nb(DocsTestClass.test), '00_core.ipynb')
assert source_nb(int) is None

## Reading the library

If someone decides to change a module instead of the notebooks, the following functions help update the notebooks accordingly.

In [130]:
# export
def _split(code):
    """Split the text `code` of an exported module into its cells.

    Returns a list of tuples `(notebook name, cell code)`, one per cell marker found, in order.
    The notebook is the one named in the marker or, for a plain `#Cell` marker, the one named in
    the header on the first line. The list is empty if the module contains no cell marker."""
    lines = code.split('\n')
    default_nb = _re_default_nb.search(lines[0]).groups()[0]
    s,res = 1,[]
    # Skip what comes before the first cell marker (header, `__all__`)
    while s < len(lines) and _re_cell.search(lines[s]) is None: s += 1
    # A module without any exported code has no marker at all
    if s >= len(lines): return res
    e = s+1
    while e < len(lines):
        # `s` is on a marker line; move `e` to the next marker (or to the end of the file)
        while e < len(lines) and _re_cell.search(lines[e]) is None: e += 1
        grps = _re_cell.search(lines[s]).groups()
        nb = grps[0] or default_nb
        content = lines[s+1:e]
        # Drop the blank lines separating this cell from the next one
        while len(content) > 1 and content[-1] == '': content = content[:-1]
        res.append((nb, '\n'.join(content)))
        s,e = e,e+1
    return res

In [131]:
#export
def _relimport2name(name, mod_name):
    """Convert the relative module `name`, as imported from the file `mod_name`, to an absolute module name.

    `mod_name` is the path of the importing module, with or without its `.py` extension; only the
    part starting at the last folder named like the library is used. Dot-only names of any
    length (`.`, `..`, ...) resolve to the corresponding parent package."""
    if mod_name.endswith('.py'): mod_name = mod_name[:-3]
    mods = mod_name.split(os.path.sep)
    i = last_index(Config().lib_name, mods)
    mods = mods[i:]
    # Each leading dot removes one trailing component: the module itself, then its parents
    n_dots = len(name) - len(name.lstrip('.'))
    rest = name[n_dots:]
    # Nothing is appended for a dot-only name (indexing past its end used to raise for `..`)
    return '.'.join(mods[:-n_dots] + ([rest] if rest else []))

In [132]:
# export
#Catches any `from .bla import something` line: the indentation is in group 1, the relative
#module .bla in group 2 and the imported thing(s) in group 3.
_re_loc_import = re.compile(r'(^\s*)from (\.\S*) import (.*)$')

In [133]:
# One dot: sibling module in the same package
test_eq(_relimport2name('.core', 'nbdev/data.py'), 'nbdev.core')
# Only the part of the path starting at the last folder named like the library is used
test_eq(_relimport2name('.core', 'home/sgugger/fastai_dev/nbdev/nbdev/data.py'), 'nbdev.core')
# Two dots: go up one package
test_eq(_relimport2name('..core', 'nbdev/vision/data.py'), 'nbdev.core')
test_eq(_relimport2name('.transform', 'nbdev/vision/data.py'), 'nbdev.vision.transform')
test_eq(_relimport2name('..notebook.core', 'nbdev/data/external.py'), 'nbdev.notebook.core')

In [138]:
#export
def _deal_loc_import(code, fname):
    "Return `code` with its relative imports rewritten as absolute imports, `fname` being the module the code comes from."
    def _to_absolute(match):
        indent, module, imported = match.groups()
        return f"{indent}from {_relimport2name(module, fname)} import {imported}"
    # The regex is anchored on a whole line, so substitute line by line
    fixed_lines = (_re_loc_import.sub(_to_absolute, line) for line in code.split('\n'))
    return '\n'.join(fixed_lines)

In [139]:
#hide
code = "from .core import *\nnothing to see\n  from .vision import bla1, bla2"
test_eq(_deal_loc_import(code, 'nbdev/data.py'), "from nbdev.core import *\nnothing to see\n  from nbdev.vision import bla1, bla2")

In [72]:
#export
def _script2notebook(fname, dic, silent=False):
    """Put the content of `fname` back in the notebooks it came from.

    `dic` is the dictionary built by `notebook2script(to_dict=True)`: for each module path, the list
    of `(cell index, notebook path, code)` tuples exported to it. The cells of the module `fname` are
    matched, in order, with those tuples to know which notebook cell each one has to update."""
    if os.environ.get('IN_TEST',0): return  # don't export if running tests
    fname = Path(fname)
    with open(fname, encoding='utf8') as f: code = f.read()
    splits = _split(code)
    assert len(splits)==len(dic[fname]), f"Exported file from notebooks should have {len(dic[fname])} cells but has {len(splits)}."
    # The module only records notebook *names*, so compare with the name of the exporting notebook.
    # (The previous check wrapped a generator in `np.all`, which is always truthy.)
    assert all(c1[0]==Path(c2[1]).name for c1,c2 in zip(splits, dic[fname])), \
        f"The cells of {fname.name} don't come from the same notebooks as when they were exported."
    # (cell index in the notebook, notebook name, new code)
    splits = [(c2[0],c1[0],c1[1]) for c1,c2 in zip(splits, dic[fname])]
    nb_fnames = {Config().nbs_path/s[1] for s in splits}
    for nb_fname in nb_fnames:
        nb = read_nb(nb_fname)
        for i,f,c in splits:
            # `f` is a notebook name and `nb_fname` a full path: compare names, otherwise nothing matches
            if f == nb_fname.name:
                c = _deal_loc_import(c, str(fname))
                # Keep the first line of the cell (the export flag), replace everything after it
                l = nb['cells'][i]['source'].split('\n')[0]
                nb['cells'][i]['source'] = l + '\n' + c
        NotebookNotary().sign(nb)
        nbformat.write(nb, str(nb_fname), version=4)

    if not silent: print(f"Converted {fname.relative_to(Config().lib_path)}.")

In [73]:
dic = notebook2script(silent=True, to_dict=True)
_script2notebook(Config().lib_path/'core.py', dic)

Converted core.py.


In [63]:
#export
def script2notebook(fname=None, silent=False):
    "Update the notebooks from the module(s) matching the glob `fname`, or from all the exported modules of the library if `fname` is `None`."
    if os.environ.get('IN_TEST',0): return
    # Mapping from each module to the notebook cells that were exported to it
    dic = notebook2script(silent=True, to_dict=True)
    exported = get_nbdev_module().modules

    if fname is not None: files = glob.glob(fname)
    else:
        # Only the python files that were generated from notebooks are considered
        files = [p for p in Config().lib_path.glob('**/*.py') if str(p.relative_to(Config().lib_path)) in exported]
    for module_file in files:
        _script2notebook(module_file, dic, silent=silent)

In [64]:
script2notebook()

Converted showdoc.py.
Converted test.py.
Converted core.py.
Converted export.py.
Converted export2html.py.


## Diff notebook - library

In [71]:
#export
import subprocess
from distutils.dir_util import copy_tree

In [107]:
#export
def diff_nb_script():
    """Print the diff between the notebooks and the library in `Config().lib_path`

    The library is saved in a temporary folder, regenerated from the notebooks, copied to a second
    temporary folder, then restored from the first one. The two copies are compared with the
    external `diff -ru` command."""
    lib_path = Config().lib_path
    with tempfile.TemporaryDirectory() as d1, tempfile.TemporaryDirectory() as d2:
        copy_tree(lib_path, d1)
        try:
            notebook2script(silent=True)
            copy_tree(lib_path, d2)
        finally:
            # Always put the library back as it was, even if the conversion failed midway
            shutil.rmtree(lib_path)
            shutil.copytree(d1, str(lib_path))
        # Compiled files are not part of the comparison
        for d in [d1, d2]:
            if (Path(d)/'__pycache__').exists(): shutil.rmtree(Path(d)/'__pycache__')
        res = subprocess.run(['diff', '-ru', d1, d2], stdout=subprocess.PIPE)
        print(res.stdout.decode('utf-8'))

In [110]:
diff_nb_script()




## Export

In [140]:
#hide
notebook2script()

Converted 00_core.ipynb.
Converted 03_export2html.ipynb.
Converted 04_test.ipynb.
Converted 01_export.ipynb.
Converted 02_showdoc.ipynb.
