# Argument Parsing [Module]

In [None]:
# +default_exp -to argument_parsing -use_scope

## Logging

In [None]:
# +export -internal
# Module-level switches read by report_error()/report_warning(); change them with set_arg_parse_report_options().
arg_parse_REPORT_ERROR  :bool = True   # print errors and continue
arg_parse_REPORT_WARNING:bool = True   # print warnings and continue
arg_parse_RAISE_ERROR  :bool  = False  # raise errors instead of printing them (checked before REPORT)
arg_parse_RAISE_WARNING:bool  = False  # raise warnings as `Warning` instead of printing them (checked before REPORT)
arg_parse_SILENT:bool = False          # ignore all errors and warnings (checked first)

In [None]:
# +export
def set_arg_parse_report_options(report_error:bool=True, report_warning:bool=True,
                                 raise_error:bool=False, raise_warning:bool=False,
                                 silent:bool=False):
    """Set options for how the Argument Parsing Module will behave on encountering errors or warnings.
    Raise causes an exception to be raised, and it supersedes report.
    Report prints the information and then continues. If raise is set, then this setting is ignored.
    Silent overwrites all other settings and causes all errors and warnings to be ignored.
    The priority is thus: silent > raise > report"""
    global arg_parse_REPORT_ERROR, arg_parse_REPORT_WARNING
    global arg_parse_RAISE_ERROR, arg_parse_RAISE_WARNING
    global arg_parse_SILENT
    arg_parse_REPORT_ERROR, arg_parse_REPORT_WARNING = report_error, report_warning
    arg_parse_RAISE_ERROR , arg_parse_RAISE_WARNING  = raise_error , raise_warning
    # Silent when explicitly requested, or when nothing would be reported or raised anyway.
    # `or` is required here: with `and`, the default options (raise_* False) would silence everything.
    arg_parse_SILENT = bool(silent) or not (report_error or report_warning or raise_error or raise_warning)

In [None]:
# +export -internal
def report_error(err:Exception):
    """Dispatch an argument-parsing error according to the module options:
    ignored when arg_parse_SILENT, raised when arg_parse_RAISE_ERROR,
    otherwise printed as '[ErrorType]: message' when arg_parse_REPORT_ERROR."""
    if arg_parse_SILENT:
        return
    if arg_parse_RAISE_ERROR:
        raise err
    if arg_parse_REPORT_ERROR:
        print(f'[{err.__class__.__name__}]: {err}')

In [None]:
# +export -internal
def report_warning(warn:str):
    """Dispatch an argument-parsing warning according to the module options:
    nothing happens when arg_parse_SILENT; otherwise it is raised as `Warning(warn)`
    when arg_parse_RAISE_WARNING, or printed as '[Warning]: ...' when arg_parse_REPORT_WARNING."""
    if not arg_parse_SILENT:
        if arg_parse_RAISE_WARNING:
            raise Warning(warn)
        elif arg_parse_REPORT_WARNING:
            print(f'[Warning]: {warn}')

## Next Argument

This is just a fancy way of advancing the cursor and checking for out-of-bounds access.

In [None]:
# +export -internal
def get_next_argument(args:list, name:str, cursor:int, suppress_error:bool=False) -> (bool, int, str):
    """Advance `cursor` by one and fetch the argument found there.
    Returns (success, cursor, argument): on success the advanced cursor and args[cursor + 1].
    At the end of `args` it returns (False, the unchanged cursor, '') and, unless `suppress_error`
    is set, passes a SyntaxError naming `name` to report_error()."""
    next_pos = cursor + 1
    try:
        next_arg = args[next_pos]
    except IndexError:
        if not suppress_error:
            report_error(SyntaxError(f"End of arguments reached. Missing a value for argument '{name}' at position {next_pos}"))
        return False, cursor, ''
    return True, next_pos, next_arg

### Examples

In [None]:
# The cursor is at index 1 ('b'); the next argument is 'c' at index 2.
get_next_argument(['a', 'b', 'c'], 'b', 1)

(True, 2, 'c')

In [None]:
# Index 2 is the last element, so the error is reported and the cursor is returned unchanged.
get_next_argument(['a', 'b', 'c'], 'c', 2)

[SyntaxError]: End of arguments reached. Missing a value for argument 'c' at position 3


(False, 2, '')

In [None]:
# Same as above, but the error message is suppressed.
get_next_argument(['a', 'b', 'c'], 'c', 2, suppress_error=True)

(False, 2, '')

## Type conversion

The input to Argument Parsing is just a string, so values have to be converted based on the information provided by the caller.  These functions help to do that in a safe way.

In [None]:
# +export -internal
def to_integer(value:str) -> (bool, int, float):
    """Try converting a str to int.
    Return success, the value, and the float remainder that was dropped (0.0 for exact integers).
    On failure the original value and a remainder of None are returned.
    Plain integer strings are converted exactly: they are not round-tripped through float, which
    would lose precision above 2**53. Other numeric strings (e.g. '2.5', '1e3') go through float
    and are truncated towards zero; 'inf' and 'nan' fail."""
    if isinstance(value, str):
        try: return True, int(value), 0.0
        except ValueError: pass # not a plain integer literal, try the float route below
    try:
        f_value = float(value)
        int_value = int(f_value) # OverflowError for inf, ValueError for nan
    except (TypeError, ValueError, OverflowError): return False, value, None
    return True, int_value, f_value - int_value

In [None]:
# '-2.1' is truncated towards zero to -2 and the dropped part is returned; 'nice' is not a number.
to_integer('-2.1'), to_integer('nice')

((True, -2, -0.10000000000000009), (False, 'nice', None))

In [None]:
# +export -internal
def to_float(value:str) -> (bool, float):
    """Try converting a str to float.
    Return success, and the value (the original value on failure).
    Note that float() also accepts strings like 'nan', 'inf' and '-inf'."""
    # TODO: check if 'inf', 'nan', ...?
    try: return True, float(value)
    # only conversion failures are treated as "not a float"; other exceptions propagate
    except (TypeError, ValueError, OverflowError): return False, value

In [None]:
# 'nan' is accepted by float(); 'nice' is not.
to_float('-1e-3'), to_float('nan'), to_float('nice')

((True, -0.001), (True, nan), (False, 'nice'))

In [None]:
# +export -internal
def to_bool(value:str) -> (bool, bool):
    """Try converting a str to bool.
    'True' and 'False' are recognized, otherwise the value is cast to float, and then to bool.
    Return success, and the value (the original value on failure)."""
    if value == 'True' : return True, True
    if value == 'False': return True, False
    try: return True, bool(float(value))
    # only conversion failures are treated as "not a bool"; other exceptions propagate
    except (TypeError, ValueError, OverflowError): return False, value

In [None]:
# 'True'/'False' are matched literally; everything else goes through float() and then bool().
to_bool('1'), to_bool('0'), to_bool('True'), to_bool('False'), to_bool('abc')

((True, True), (True, False), (True, True), (True, False), (False, 'abc'))

In [None]:
# +export -internal
def to_unbounded_array(args:list, cursor:int) -> (bool, int, list):
    """Consume any number of values after position `cursor` until either reaching the end of args,
    or until finding a value starting with '-', denoting the beginning of a new argument.
    Return success, the cursor (on the last consumed value, or unchanged if none was consumed),
    and the list of values.
    Currently this can't actually fail... don't use unbounded lists kids."""
    values = []
    while True:
        string_success, next_cursor, value = get_next_argument(args, None, cursor, suppress_error=True)
        # stop at the end of args, or at the next '-'-prefixed keyword (which is left unconsumed);
        # startswith() also copes with an empty-string value, where value[0] would raise IndexError
        if not string_success or value.startswith('-'): break
        values.append(value)
        cursor = next_cursor
    return True, cursor, values

In [None]:
# '-3' starts with '-', so it ends the array instead of being read as a negative value.
to_unbounded_array(['-list', '1', '2', '-3'], 0)

(True, 2, ['1', '2'])

In [None]:
# +export -internal
def typify(type_or_value:object) -> (type, object):
    """Normalize a command entry that is either a type or a default value.
    A type gives (that type, None); any other value gives (type(value), value)."""
    if isinstance(type_or_value, type):
        return type_or_value, None
    return type(type_or_value), type_or_value

In [None]:
# A tuple is a value (a default), not a type, so its type and the value itself are returned.
typify((int, int)*2)

(tuple, (int, int, int, int))

## Parsing

In [None]:
# +export
def parse_arguments(command:dict, comment:str) -> (bool, dict, dict):
    """Find, cast, and return the values for the keys of `command` that are given in `comment`.
    `command` maps argument names to a type or a default value; `comment` is a whitespace-separated
    string of '-name [value ...]' groups. Returns (success, result, is_set): `result` is a shallow copy
    of `command` with the parsed values filled in, `is_set` records which names appeared in `comment`.
    Parsing stops at the first error."""
    result = command.copy() # copy needed?
    args   = comment.split()
    # TODO: check that the type of all commands is supported ahead of time?
    # TODO: handle quoted arguments?
    is_set = dict.fromkeys(command, False)
    state  = {'args': args, 'name': '', 'cursor': 0, 'inside_array': False}

    while state['cursor'] < len(args):
        token = args[state['cursor']]
        if not token.startswith('-'):
            report_error(SyntaxError(f"Argument {state['cursor']} does not start with a '-'."))
            return False, result, is_set
        name = token[1:] # remove '-'
        state['name'] = name # TODO: check that len(name) > 0?

        if name not in command: # TODO: improve this msg. maybe: "is not part of the command"?
            report_error(SyntaxError(f"Argument {state['cursor']} ('{name}') is not valid."))
            return False, result, is_set
        if is_set[name]: # TODO: improve error msg. maybe: "this is the second time this argument was given"?
            report_error(SyntaxError(f"Argument {state['cursor']} ('{name}') was given multiple times."))
            return False, result, is_set

        arg_type, arg_default = typify(command[name])
        if not handle_one_argument(result, state, arg_type, arg_default):
            return False, result, is_set # stop at first error
        is_set[name] = True
        state['cursor'] += 1

    return check_is_set(result, is_set), result, is_set

In [None]:
# +export -internal
def handle_one_argument(result:dict, state:dict, arg_type:type, arg_default:object) -> bool:
    """Parse the value(s) of the argument named state['name'] from state['args'] based on arg_type,
    and set result[state['name']] to the converted value.

    `state` holds 'args' (the list of tokens), 'name' (the argument name), 'cursor' (index of the token
    being handled) and 'inside_array' (True while parsing the elements of a bounded array). On return
    state['cursor'] is on the last token consumed (the flag itself for a bare bool flag).
    `arg_default` is the default value, or None when only a type was given; for list/tuple it selects
    unbounded (None) or bounded parsing.
    Returns True on success. Most failures are passed to report_error(); a failing element of a bounded
    array currently returns False without an additional message."""
    # NOTE: state and result are modified from here and essentially treated as pointers
    args     = state['args']
    arg_name = state['name']
    success  = True
    if arg_type == str:
        # get the next argument, advance cursor, set success
        string_success, state['cursor'], value = get_next_argument(args, arg_name, state['cursor'])
        # TODO: how to handle strings that start with a '-'
        if string_success: result[arg_name] = value
        else: success = False

    elif arg_type == bool:
        if state['inside_array']:
            # array elements are positional, so a bool inside an array needs an explicit value
            string_success, state['cursor'], value = get_next_argument(args, arg_name, state['cursor'])
            if string_success:
                bool_success, value = to_bool(value)
                if bool_success: result[arg_name] = value
                else:
                    report_error(ValueError(f"Value of argument {state['cursor']-1} ('{arg_name}') "
                                            f"was not convertable to bool. Please use 'True', 'False', '0', or '1'. (It was '{value}')"))
                    success = False
            else: success = False
        # special case where supplying the argument means True and not supplying it means use the default (False)
        else: result[arg_name] = True

    elif arg_type == int:
        # get the next argument, cast to int, check for remainder, advance cursor, set success
        string_success, state['cursor'], value = get_next_argument(args, arg_name, state['cursor'])
        if not string_success: return False
        int_success, value, remainder = to_integer(value)
        if int_success:
            result[arg_name] = value
            if remainder:
                report_warning(f"Junk on the end of the value for int argument "
                               f"{state['cursor']-1} ('{arg_name}'): {remainder}")
        else:
            report_error(ValueError(f"Value of argument {state['cursor']-1} ('{arg_name}') "
                                    f"was not an int. (It was '{value}')"))
            success = False

    elif arg_type == float:
        # get the next argument, cast to float, advance cursor, set success
        string_success, state['cursor'], value = get_next_argument(args, arg_name, state['cursor'])
        if not string_success: return False
        float_success, value = to_float(value)
        if float_success: result[arg_name] = value
        else:
            report_error(ValueError(f"Value of argument {state['cursor']-1} ('{arg_name}') "
                                    f"was not a float. (It was '{value}')"))
            success = False

    elif arg_type == list or arg_type == tuple:
        if arg_default is None: # unbounded list / tuple
            if state['inside_array']:
                report_error(SyntaxError(f"Using an unbounded list or tuple inside an array is not supported."))
                return False

            array_success, state['cursor'], value = to_unbounded_array(args, state['cursor'])
            if array_success: # NOTE: currently this can't actually fail... don't use unbounded lists kids.
                result[arg_name] = arg_type(value)
            else: success = False

        else: # predefined list
            # parse each element with its own sub-state, named like 'arg[i]', sharing the token list
            s = {'args': args, 'name': 'v', 'cursor': state['cursor'],
                 'inside_array': True}
            value = []
            for i, x in enumerate(arg_default):
                t, d = typify(x)
                n = f'{arg_name}[{i}]'
                s['name'] = n
                r = {n:d}
                member_success = handle_one_argument(r, s, t, d)
                if member_success: value.append(r[n])
                else: # TODO: Improve error message
                    # report_error(SyntaxError(f"Array argument {state['cursor']} ('{arg_name}') was not passed correctly."))
                    return False
            state['cursor'] = s['cursor']
            result[arg_name] = arg_type(value)

    else:
        report_error(TypeError(f"Argument {state['cursor']} ('{arg_name}') is of unsupported type {arg_type}."))
        success = False

    return success

In [None]:
# +export -internal
def check_is_set(result:dict, is_set:dict) -> bool:
    """Check if any required values (those without defaults) haven't been set yet.
    For every argument that was not given in the comment:
    - a bare `bool` type becomes False in `result`.
    - a bounded list/tuple default is checked element by element (recursively). If it is complete, it
      is rebuilt in `result`, so its bare bool elements become False.
    - any other bare type is reported via report_error().
    Returns False if at least one required value is missing."""
    complete = True
    for member in [m for m, given in is_set.items() if not given]:
        arg_type, arg_default = typify(result[member])
        if arg_default is not None:
            if arg_type not in (list, tuple): continue # a plain default value is always fine
            # a bounded list: check its elements as if they were separate arguments
            element_names = [f'{member}[{i}]' for i in range(len(arg_default))]
            elements = dict(zip(element_names, arg_default))
            if check_is_set(elements, dict.fromkeys(elements, False)): # re-set result
                result[member] = arg_type(elements[n] for n in element_names)
            else:
                complete = False
        elif arg_type == bool: # NOTE: Special case, not setting a boolean means it's False.
            result[member] = False # TODO: set is_set as well? what's the use-case here?
        else:
            report_error(ValueError(f"Argument '{member}' has not been set, and no default value was given."))
            complete = False
    return complete

## Documentation

This argument parser is largely inspired by these two videos by Jonathan Blow.
>[Part 1](https://youtu.be/TwqXTf7VfZk)  
>[Part 2](https://youtu.be/pgiVrhsGkKY)

This module basically provides only one function:  
```python
def parse_arguments(command:dict, comment:str) -> (bool, dict, dict)
```  

It takes one __"command" dictionary__, and a __"comment" string__.  

#### __The command__

is a simple key-value collection of expected flags, where an attribute name maps to either a type, or a default value, from which the type is inferred.  
```python
command = {
    'arg1':bool,
    'arg2':str,
    'arg3':32,
    'arg4':3.14,
}
```

#### __The comment__
is just a list of space-separated arguments, with words starting with a minus (`'-'`) denoting a keyword, and anything without a minus as the first character being a value of the previous keyword.  
```python
'-name bob -age 99 -celsius 30.5 -thirsty'
```  
is a valid string for the command  
```python
{
    'name'   : str,
    'weather': 'sunny',
    'celsius': float,
    'age'    : int,
    'thirsty': bool,
    'tired'  : bool
}
```

#### __The primitive types:__
Currently the following primitive types are supported:  
- `str`
    - a `str` argument requires one value.
    - e.g.: `-weather sunny`
- `bool`
    - a `bool` argument requires no values. setting the flag automatically sets the value to `True`.
    - writing `bool` is the same as using the default value `False`.
    - e.g.: `-is_wet`
- `int`
    - an `int` argument requires one value.
    - the value will first be cast to `float`, and then to `int`, partly due to how python works, and also to check for a remainder in case the provided value was actually in a float format.
    - e.g.: `-age 99`, `-negative -1`
- `float`
    - a `float` argument requires one value.
    - the value has to be castable to `float`. What is and what isn't a float can be surprising, so you should check the [casting rules](https://stackoverflow.com/a/20929983/) beforehand.
    - e.g.: `-pi 3.14`, `-negative -1.0`, `-weird nan`, `-large inf`, `-small -inf`
  
Any of these types can be declared either by just using the `type` directly, or by giving a default value of the specific `type`. All arguments that use the `type` directly have to be passed in the comment. If a default value is specified, or if the `type` is `bool`, the argument does not have to be passed in the comment, and instead the `result` will simply contain the default value. This changes with composite types (see below). Whether an argument was passed in the comment can be seen by looking at the `is_set` return value (see below).

  
##### __The composite types__
`list` and `tuple` (referred to as 'array' when it can be either one of them) are also supported; however, due to Python's dynamic typing, they have slightly different semantics.  

Specifying only the type `list` or `tuple`, will result in an 'unbounded array' of that type, meaning that all values following the keyword will be added to the array, until either the end of arguments is reached, or a value starts with a minus (`'-'`), which denotes the start of the next argument. All values of the array will be of type `str`. This kind of argument should be used with caution, because, for instance, negative values will be treated as the start of a new argument.  
```python
{
    'unbounded_list' : list,
    'unbounded_tuple': tuple,
}
```  

The other, better way to use arrays is to actually create an array containing the types, default values, and ordering you want the values to have. This can get arbitrarily complex, mixing and matching any supported primitive type you want. The only thing not allowed, is using an unbounded array (see above).  
All values will be cast to the corresponding type using all the same semantics as if they were single values (see above). The only exception to that is the `bool` type, where the value has to be either `'True'`, `'False'`, or interpretable as a `float`, which will then be cast to a `bool`. This means that e.g. `'0.0'` will result in `False`, and `'123'` will result in `True` (careful, check the [casting rules](https://docs.python.org/3.3/library/stdtypes.html?highlight=frozenset#truth-value-testing) first).
```python
{
    'arg1': [int]*5,
    'arg2': (3.14, 'pi', bool),
    'arg3': (bool, str, 123)*2,
    'arg4': [[0]*3, [1]*3, [str]*3],
    'arg5': [str, int, bool, True, [1, '2', 3, bool], (2.1, float)]
}
```

#### __The return value__
is a three-tuple of `(success, result, is_set)`.  
- `success` is a `bool`, saying whether or not parsing was successful. If it is `False`, the other two return values are not guaranteed to be valid. Unless error reporting has been disabled, there will be an error message with details on what happened to help debugging.  
- `result` is a `dict` with exactly the same keys as the input `command`, with the corresponding values set to whatever was extracted from the comment. In cases where `success` is `False`, this might only be partially filled out, so `success` should always be checked.
- `is_set` is a `dict`, which also contains exactly the same keys as the input `command`, this time mapping to a `bool`. It is `True` if `comment` contains a value for the particular argument, and `False` otherwise. In cases where a default value is given in `command`, the same rule applies: `is_set` is `True` only if the default was overwritten by an argument in `comment`. This holds even for `bool`s, which default to `False` even if no explicit default was given.

## Examples

In [None]:
# Example: a command mixing bare types, defaults, bounded arrays ('scoops', 'valid') and an unbounded list ('list').
command = {
    'test'  : bool,
    'sunny' : False,
    'toast' : str,
    'shots' : int,
    'scale' : float,
    'scoops': [str, int, bool, [1, 2, 3, bool], (float, float)],
    # 'valid' : (bool, bool),
    'valid' : (1, 1.23, bool, 'hi', [1, 2]),
    'nah'   : 'boi',
    'sweet' : bool,
    'nr'    : int,
    'list'  : list
}

comment = '-sunny -toast jelly -shots 25 -scale 69105.1234 -test -list 2 -scoops a 1 0 5 6 7 False 3.0 2.1 -nr 21'
# comment = '-sunny -toast jelly -shots 25 -scale 69105.1234 -test -nr 1'
# 'valid', 'nah' and 'sweet' are not in the comment: they keep their defaults, and bare bools become False.
parse_arguments(command, comment)

(True,
 {'test': True,
  'sunny': True,
  'toast': 'jelly',
  'shots': 25,
  'scale': 69105.1234,
  'scoops': ['a', 1, False, [5, 6, 7, False], (3.0, 2.1)],
  'valid': (1, 1.23, False, 'hi', [1, 2]),
  'nah': 'boi',
  'sweet': False,
  'nr': 21,
  'list': ['2']},
 {'test': True,
  'sunny': True,
  'toast': True,
  'shots': True,
  'scale': True,
  'scoops': True,
  'valid': False,
  'nah': False,
  'sweet': False,
  'nr': True,
  'list': True})

In [None]:
# %timeit parse_arguments(command, comment)

# Main [Module]

In [None]:
# +default_exp -to main

In [None]:
THIS_FILE = '00_export_v4.ipynb'  # name of this notebook; not used in the visible cells (presumably for self-export — TODO confirm)

In [None]:
# +export
from collections import namedtuple, defaultdict
import os
import re
from nbdev_rewrite.imports import *

from inspect import signature, currentframe

import functools
from types import MethodType,FunctionType

import ast
from ast import iter_fields, AST
import _ast

In [None]:
# +export
# Only import if executing as a python file, because then argument_parsing is in a different file.
if (__name__ != '__main__') or ('parse_arguments' not in globals()):
    # Inside the notebook the Argument Parsing cells above already defined parse_arguments;
    # as an exported module it has to come from the generated argument_parsing module.
    from nbdev_rewrite.argument_parsing import *
    assert 'parse_arguments' in globals(), "Missing the 'parse_arguments' function after import."

In [None]:
# _all_ = ['parse_arguments', 'set_arg_parse_report_options']

## Logging

This is a class for passing along contextual information during execution.  
The class is a linked list, which can be extended each time a new function is called.  
Every time a function is called, create a new StackTrace instance, and pass the current instance to it.

In [None]:
# +export -internal
main_REPORT_OPTIONAL_ERROR:bool = False  # read by StackTrace.report_optional_error(); set via set_main_report_options()

In [None]:
# +export
def set_main_report_options(report_optional_error:bool=False):
    """Set options for how the Main Module will behave on encountering errors or warnings.

    report_optional_error: when True, StackTrace.report_optional_error() prints the error
    and execution continues; when False (the default) such errors are not printed."""
    global main_REPORT_OPTIONAL_ERROR
    main_REPORT_OPTIONAL_ERROR = report_optional_error

In [None]:
# +export
class StackTrace:
    """A linked list of call-site records, used to print a traceback-like report for errors.

    Create one instance per function call and pass the caller's instance as `up`. Use `ext()` for a
    record that points into an external file (e.g. a notebook cell). report_error() prints the whole
    chain, outermost record first, followed by the error."""
    # string annotations allow referring to the class inside its own body (no placeholder class needed)
    up:'StackTrace' = None # the caller's record, or None for the outermost one
    namespace:str = None   # __qualname__ of the object passed to __init__ (None if it was falsy)
    lineno   :int = None   # line of the constructor call, or the line inside the external file
    extern:bool = False    # True for records created with ext()
    file:str    = None     # external file name (ext() only)
    cellno:int  = None     # cell number in the external file (ext() only)
    excerpt:str = None     # source line printed below an external record
    span:(int, int) = None # (offset, length) of the '^' marker printed below the excerpt

    def __init__(self, namespace:object, up:'StackTrace'=None):
        self.namespace = namespace.__qualname__ if namespace else None
        self.up = up
        # the line of the caller, i.e. where this StackTrace was constructed
        self.lineno = currentframe().f_back.f_lineno

    @classmethod
    def ext(cls, file:str, cellno:int=None, lineno:int=None, excerpt:str=None, up:'StackTrace'=None):
        "Create a record for a location in an external file instead of the calling Python code."
        st = cls(None, up=up)
        st.extern = True
        st.file = file
        st.cellno = cellno
        st.lineno = lineno
        st.excerpt = excerpt
        return st

    def __repr__(self):
        ln = self.lineno
        if self.extern:
            s = f"{'' if self.up is None else self.up.__repr__()}"\
                f"\n<{self.file}>, cell {self.cellno}, line {ln}\n"
            if self.excerpt:
                # pad single-digit (or missing) line numbers by one space
                x = f"--->{' ' if ((ln is None) or (0 <= ln <= 9)) else ''}{ln} "
                s += f"{x}{self.excerpt}\n"\
                     f"{(' ' * (len(x) + self.span[0]) + '^' * self.span[1]) if self.span else ''}"
            return s
        else: # the default
            return f"{'' if self.up is None else self.up.__repr__()}"\
                   f"<{__name__}>, line {ln} in <{self.namespace}>\n"

    def report_error(self, err:Exception, lineno=None, excerpt=None, span:(int, int)=None):
        """Print the stack trace followed by `err`.
        A truthy lineno/excerpt overwrites this record's value first; span is always overwritten."""
        if lineno: self.lineno = lineno
        if excerpt: self.excerpt = excerpt
        self.span = span
        err_type = err.__class__.__name__
        s = f"{'-'*75}\n"\
            f"{err_type}{' '*(41-len(err_type))}Stacktrace (most recent call last)\n"\
            f"{self.__repr__()}\n"\
            f"[{err_type}]: {err}"
        print(s)

    def report_optional_error(self, err:Exception, lineno=None, excerpt=None, span:(int, int)=None):
        "Like report_error(), but only prints when main_REPORT_OPTIONAL_ERROR is enabled."
        if main_REPORT_OPTIONAL_ERROR:
            # forward every argument, including the excerpt
            self.report_error(err=err, lineno=lineno, excerpt=excerpt, span=span)

### Example Code

In [None]:
def _part(st):
    "Example step: report a failure through the given StackTrace and signal it by returning False."
    st.report_error(Exception('Failed doing the thing'))
    return False

In [None]:
def _start():  # Example entry point: builds a two-level StackTrace and lets _part() report through it.
    st = StackTrace(_start)                        # records this line as the frame's lineno
    success = True
    success_part = _part(StackTrace(_part, up=st)) # child record linked to st via `up`
    if not success_part:
        success = False
        return 0  # NOTE(review): returns 0 instead of `success` (False) on failure; looks unintended
    return success

In [None]:
_start()  # prints the report produced in _part(); the cell's value is the 0 returned by _start()

---------------------------------------------------------------------------
Exception                                Stacktrace (most recent call last)
<__main__>, line 2 in <_start>
<__main__>, line 4 in <_part>

[Exception]: Failed doing the thing


0

In [None]:
_st=StackTrace(namespace=_start)  # outermost record
_st=StackTrace(_part, up=_st)     # nested record, linked via `up`
_st=StackTrace.ext(file='file.py', cellno=5, lineno=45, excerpt='# weird comment', up=_st)  # record pointing into an external file
_st.cellno=18                     # attributes can be changed after construction
_st.excerpt = '# another weird comment'
_st.report_error(SyntaxError('Failed to parse'), lineno=None, span=(2, 7))  # span = (offset, length) of the '^' marker

---------------------------------------------------------------------------
SyntaxError                              Stacktrace (most recent call last)
<__main__>, line 1 in <_start>
<__main__>, line 2 in <_part>

<file.py>, cell 18, line 45
--->45 # another weird comment
         ^^^^^^^
[SyntaxError]: Failed to parse


### Reference of Python Tracebacks

## Find and Parse Comments

### Finding comments in source code

In [None]:
# +export
# TODO: Only look for 0 indent comments?
def iter_comments(src:str, pure_comments_only:bool=True, line_limit:int=None) -> (str, (int, int)):
    """Detect all comments in a piece of code, excluding those that are a part of a string.
    Yields (text, (line_index, column)) with 0-based indices. With pure_comments_only (the default)
    only comments preceded by nothing but whitespace on their line are yielded, and text is the whole
    line. Otherwise every comment is yielded, and text starts at its '#'.
    Only the first `line_limit` lines are scanned (all lines if None)."""
    in_lstr = in_sstr = False # currently inside a triple-quoted (long) / single-quoted (short) string
    count, quote = 1, ''      # length of the current run of quote chars; quote char of the open string
    for i, line in enumerate(src.splitlines()[:line_limit]):
        is_pure, escape, prev_c = True, False, '\n'
        for j, c in enumerate(line):
            # we can't break as soon as not is_pure, because we have to detect if a multiline string begins
            if is_pure and (not (c.isspace() or c == '#')): is_pure = False
            if (in_sstr or in_lstr):
                # assert in_sstr ^ in_lstr # XOR
                if escape: count = 0
                else:
                    if (c == quote):
                        count = ((count + 1) if (c == prev_c) else 1)
                        if in_sstr:
                            in_sstr = False
                            # Keep count == 2 only for an empty string ('' or ""): a third quote right
                            # after it opens a long string. After a non-empty string the run is reset,
                            # so adjacent literals like "a""b" open a new short string.
                            if count != 2: count = 0
                        elif (in_lstr and (count == 3)): count, in_lstr = 0, False
                escape = False if escape else (c == '\\')
            else:
                if (c == '#'):
                    if (pure_comments_only and is_pure): yield (line, (i, j))
                    elif (not pure_comments_only):       yield (line[j:], (i, j))
                    break
                elif c == "'" or c == '"':
                    count = ((count + 1) if (c == prev_c) else 1)
                    if count == 1: in_sstr = True
                    elif count == 3: count, in_lstr = 0, True
                    else: assert False, 'If this code path happens, then the code keeping track of quotes is broken.'
                    quote = c
            prev_c = c

In [None]:
# A comment at column 0 of the first line is yielded as the whole line with position (0, 0).
list(iter_comments('# this is a zero indented comment'))

[('# this is a zero indented comment', (0, 0))]

### Parsing

This regex is used to remove whitespace and the '#' of python comments.  
The content of the comment will be added to a group, which can be extracted afterwards.

In [None]:
# +export -internal
# https://docs.python.org/3/library/re.html
# NOTE: `^\s?` allows at most one whitespace character before the '#', so comments indented further do not match.
# NOTE: re.IGNORECASE has no effect here, since the pattern contains no literal letters.
re_match_comment = re.compile(r"""
        ^              # start of the string
        \s?            # 0 or 1 whitespace
        \#+\s?         # 1 or more literal "#", then 0 or 1 whitespace
        (.*)           # group of arbitrary symbols (except new line)
        $              # end of the string
        """,re.IGNORECASE | re.VERBOSE) # re.MULTILINE is not passed, since this regex is used on each line separately.

In [None]:
# A comment at the start of the string matches.
re_match_comment.search('# hi')

<re.Match object; span=(0, 4), match='# hi'>

In [None]:
# Returns None: without re.MULTILINE, '^' only matches at the very start of the string.
re_match_comment.search('a\n# hi')

In [None]:
# Only the leading run of '#' is stripped; the second '#' is separated by a space and stays in the group.
re_match_comment.search('# # hi').groups()

('# hi',)

This specifies what a valid nbdev comment has to look like, and filters out everything whose syntax does not fit with any of the registered commands.

In [None]:
# +export
def parse_comment(all_commands:dict, comment:str, st:StackTrace) -> (bool, str, dict, dict):
    """Finds command names and arguments in comments and parses them with parse_arguments().
    A valid comment looks like '# +command -arg value ...', where 'command' is a key of
    `all_commands` whose value is the argument spec handed to parse_arguments().
    Returns (True, command, result, is_set) on success, and (False, None, None, None) otherwise.
    Syntax problems are reported via st.report_optional_error(); errors in the arguments themselves
    are reported by parse_arguments()."""
    failure = (False, None, None, None)
    res = re_match_comment.search(comment)
    if not res:
        st.report_optional_error(SyntaxError('Not a valid comment syntax.'))
        return failure

    all_args = res.group(1).split()
    if not all_args:
        st.report_optional_error(SyntaxError(f"Need at least one argument in comment. Received: '{comment}'"))
        return failure

    cmd, *args = all_args
    if not cmd.startswith('+'):
        st.report_optional_error(SyntaxError("The first argument (the command to execute) does not start with a '+'. "
                                             f"It was: '{cmd}'"), span=(1, 3))
        return failure

    cmd = cmd[1:] # remove the '+'
    if cmd not in all_commands:
        st.report_optional_error(KeyError(f"'{cmd}' is not a recognized command. See 'all_commands'."))
        return failure

    success, result, is_set = parse_arguments(all_commands[cmd], ' '.join(args))
    if not success: return failure

    return True, cmd, result, is_set

### Examples

In [None]:
# Example argument specs (name -> type or default) and the command registry passed to parse_comment().
kw_default_exp = {'scope': 'file' , 'to': str}
kw_export      = {'internal': bool, 'to': ''}

all_commands   = {'default_exp': kw_default_exp, 'export': kw_export}

In [None]:
# '-internal' is a bool flag; '-to' overrides the '' default.
parse_arguments(all_commands['export'], '-internal -to file.py')

(True, {'internal': True, 'to': 'file.py'}, {'internal': True, 'to': True})

In [None]:
# The same via parse_comment, which first checks the '# +command' syntax and looks the command up.
parse_comment(all_commands, '# +export -internal -to file.py', st=StackTrace(None))

(True,
 'export',
 {'internal': True, 'to': 'file.py'},
 {'internal': True, 'to': True})

## Find function, class and variable Names in Source Code

https://docs.python.org/3/library/ast.html

This code uses Python's built-in `ast` module to parse source code into an abstract syntax tree, from which the set of all variable, function, and class names is extracted.  
All names found that are not private (prefixed with a single underscore) are added to a set to get rid of duplicate names.  
It also separately parses the nbdev-reserved special variable name `_all_` and adds all assignments to it to the set.  

Some special cases (like fastai specific python extensions) are also handled here, although this will probably change in the future.

### debug help

In [None]:
# +export -internal
class Context:
    def __init__(self, cell_nr=None, export_nr=None):
        self.cell_nr   = cell_nr
        self.export_nr = export_nr
    def __repr__(self):
        return f'cell_nr: {self.cell_nr}, export_nr: {self.export_nr}'

In [None]:
# +export -internal
def lineno(node):
    "Format a string containing location information on ast nodes. Used for Debugging only."
    if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
        return f'line_nr: {node.lineno} col_offset: {node.col_offset}'
    else: return ''

In [None]:
# +export -internal
def info(context, node):
    "Format a string with the debug context and the source location of an ast node. Used for Debugging only."
    location = lineno(node)
    return f'\nLocation: {context} | {location}'

### Parsing

In [None]:
# +export -internal
def unwrap_attr(node:_ast.Attribute) -> str:
    "Joins a sequence of Attribute accesses together in a single string. e.g. numpy.array"
    parts, base = [node.attr], node.value
    # walk down the chain a.b.c -> a.b -> a, collecting attribute names from the outside in
    while isinstance(base, _ast.Attribute):
        parts.append(base.attr)
        base = base.value
    parts.append(base.id)
    return '.'.join(reversed(parts))

In [None]:
# +export -internal
def update_from_all_(node, names, c):
    """inplace, recursive update of set of names, by parsing the right side of a _all_ variable.
    - String literals and names add themselves.
    - Attribute accesses add their dotted path.
    - Lists, tuples and sets are searched recursively.
    Anything else raises SyntaxError; `c` is the debug context included in the message."""
    # Python 3.8+ parses string literals as _ast.Constant, and _ast.Str is not available there
    # (referencing it raised AttributeError). Older Pythons produce _ast.Str, which is still
    # recognised when the running interpreter provides it.
    if isinstance(node, _ast.Constant) and isinstance(node.value, str): names.add(node.value)
    elif isinstance(node, getattr(_ast, 'Str', ())): names.add(node.s)
    elif isinstance(node, _ast.Name): names.add(node.id)
    elif isinstance(node, _ast.Attribute): names.add(unwrap_attr(node))
    elif isinstance(node, (_ast.List, _ast.Tuple, _ast.Set)):
        for x in node.elts: update_from_all_(x, names, c)
    elif isinstance(node, _ast.Subscript) :
        raise SyntaxError(f'Subscript expression not allowed in _all_. {info(c, node)}')
    elif isinstance(node, _ast.Starred):
        raise SyntaxError(f'Starred expression *{node.value.id} not allowed in _all_. {info(c, node)}')
    else: raise SyntaxError(f'Can\'t resolve {node} to name, unknown type. {info(c, node)}')

In [None]:
# +export -internal
def unwrap_assign(node, names, c):
    "In-place, recursive update of the list `names` with every name an assignment target binds.\n"\
    "Accepts a target node or a list of them (as in `Assign.targets`); subscript targets contribute nothing."
    # NOTE(review): attribute targets contribute their dotted form (e.g. 'obj.attr'), which then lands in
    # the exported `__all__` via find_names; confirm this is intended.
    if isinstance(node, list) or isinstance(node, (_ast.List, _ast.Tuple)):
        children = node if isinstance(node, list) else node.elts
        for child in children: unwrap_assign(child, names, c)
    elif isinstance(node, _ast.Subscript): return  # e.g. `a[0] = 1` binds no new name
    elif isinstance(node, _ast.Name): names.append(node.id)
    elif isinstance(node, _ast.Starred): names.append(node.value.id)
    elif isinstance(node, _ast.Attribute): names.append(unwrap_attr(node))
    else: raise SyntaxError(f"Can't resolve {node} to name, unknown type. {info(c, node)}")

In [None]:
# +export -internal
def not_private(name):
    "False only for private names, i.e. exactly one leading underscore ('_x'); names starting with '__' count as public."
    return name.startswith('__') or not name.startswith('_')

In [None]:
# +export -internal
def add_names_A(node, names, c):
    "Handle assignments: add every public name bound by `node` (an Assign or AnnAssign) to the set `names`.\n"\
    "An assignment to the reserved name `_all_` adds the names listed on its right-hand side instead.\n"\
    "Raises TypeError for other node types and SyntaxError for misuse of `_all_`."
    tmp_names = list()
    if   isinstance(node, _ast.Assign):
        unwrap_assign(node.targets, tmp_names, c)
    elif isinstance(node, _ast.AnnAssign):
        unwrap_assign(node.target, tmp_names, c)
    else:  # explicit raise: an `assert` would be stripped under `python -O`
        raise TypeError(f'add_names_A only accepts _ast.Assign or _ast.AnnAssign, got {type(node).__name__}')
    for name in tmp_names:
        if not_private(name): names.add(name)
        # NOTE: special cases below can only use private variable names
        elif name == '_all_': # NOTE: _all_ is a keyword reserved by nbdev.
            if len(tmp_names) != 1:
                raise SyntaxError(f'Reserved keyword "_all_" can only be used in simple assignments. {info(c, node)}')
            if node.value is None:  # e.g. a bare annotation `_all_: list`
                raise SyntaxError(f'Reserved keyword "_all_" must be assigned a value. {info(c, node)}')
            update_from_all_(node.value, names, c)

In [None]:
# +export -internal
def decorators(node):
    "Yield the name of each decorator of a function/class definition node.\n"\
    "`@name` and `@name(...)` yield 'name'; dotted forms `@mod.name` / `@mod.name(...)` yield 'mod.name'.\n"\
    "Other decorator expressions are skipped instead of crashing."
    for d in node.decorator_list:
        if isinstance(d, _ast.Call): d = d.func  # `@deco(args)` -> inspect `deco`
        parts = []
        while isinstance(d, _ast.Attribute):  # collect a dotted chain from the right
            parts.append(d.attr)
            d = d.value
        if isinstance(d, _ast.Name): yield '.'.join([d.id, *reversed(parts)])

def fastai_patch(cls, node, names, c):
    "Record 'Class.method' for a function `node` decorated with fastai's @patch.\n"\
    "`cls` is the annotation of its first parameter: a class name, or a (nested) list/tuple/set of them.\n"\
    "Private class names are skipped; any other annotation form raises SyntaxError."
    if isinstance(cls, (_ast.List, _ast.Tuple, _ast.Set)):
        for member in cls.elts: fastai_patch(member, node, names, c)
    elif isinstance(cls, _ast.Name):
        if not_private(cls.id):
            names.add(f'{cls.id}.{node.name}')
    else:
        raise SyntaxError(f"Can't resolve {cls} to @patch annotation, unknown type. {info(c, node)}")

# ignoring `@typedispatch` might not even be necessary,
# since all names are added to a single set before being exported.
def add_names_FC(node, names, c, fastai_decorators=True):
    "Handle function and class definitions: add the public name defined by `node` to the set `names`.\n"\
    "With `fastai_decorators`, a function decorated with `@patch` instead adds 'Class.name' for the type\n"\
    "annotation of its first parameter, and `@typedispatch` definitions are skipped."
    decos = set(decorators(node)) if fastai_decorators else set()
    # Only functions have parameters to read a @patch target from; classes fall through to the normal case
    # (previously a decorated class crashed on `node.args`).
    if 'patch' in decos and not isinstance(node, _ast.ClassDef):
        if not (len(node.args.args) >= 1): raise SyntaxError(f'fastai\'s @patch decorator requires at least one parameter. {info(c, node)}')
        cls = node.args.args[0].annotation
        if cls is None: raise SyntaxError(f'fastai\'s @patch decorator requires a type annotation on the first parameter. {info(c, node)}')
        fastai_patch(cls, node, names, c)
    elif 'typedispatch' in decos: return # ignore @typedispatch
    elif not_private(node.name): names.add(node.name)

In [None]:
# +export
def find_names(code:str, context:Context=None) -> set:
    "Find all public function, class and variable names defined at the top level of the given source code.\n"\
    "Returns them as a set; `context` is only used for location info in SyntaxErrors."
    tree = ast.parse(code)
    names = set()
    for node in tree.body:
        if   isinstance(node, (_ast.Assign, _ast.AnnAssign)): add_names_A(node, names, context)
        # `async def` defines a module-level name just like `def`
        elif isinstance(node, (_ast.FunctionDef, _ast.AsyncFunctionDef, _ast.ClassDef)):
            add_names_FC(node, names, context)
    return names

### Examples

In [None]:
# Sanity check: a plain assignment exports its target name.
find_names('x = 1')

{'x'}

## Relativify import statements in output file

This part is responsible for transforming import statements.  
It only affects 'from' imports of the library the project belongs to.  
So if the project library is called "my_library", then `from my_library import *` might be transformed into `from . import *` in the output file.  
The relative path is generated in such a way that it will be a valid import from the file the code is exported to.

The "normal" `import module` statement does not allow relative module names, so it cannot be translated from an absolute version in the notebook to a relative one in the output file.  
Similarly, using a relative module name in the notebook in a `from .module import ...` statement does not work due to the interactive nature of the notebook environment.  
Those two cases are not supported for automatic translation, since they would require a very hacky solution that cannot be guaranteed to always be correct.

In [None]:
# +export -internal
def make_import_relative(p_from:Path, m_to:str)->str:
    "Rewrite the absolute module name `m_to` as a relative import, as seen from the file `p_from`.\n"\
    "`m_to` is returned unchanged when its top-level package does not appear in `p_from`."
    target = m_to.split('.')
    parts = str(p_from).split(os.path.sep)
    if target[0] not in parts: return m_to
    # anchor the comparison at the LAST occurrence of the top-level package in the path
    start = len(parts) - 1 - parts[::-1].index(target[0])
    here = parts[start:]
    shared = 0
    while shared < len(target) and here[shared] == target[shared]: shared += 1
    # one dot per remaining path component of `p_from`, then the rest of the module name
    return '.' * (len(here) - shared) + '.'.join(target[shared:])

In [None]:
# Fixtures for make_import_relative: module n_i is imported from file p_i (p1 is made absolute, the others stay relative).
n1, n2, n3, n4, n5 = 'nbdev.core', 'nbdev.core', 'nbdev.vision.transform', 'nbdev.notebook.core', 'nbdev.vision'
p1, p2, p3 = Path('./nbdev/data.py').absolute(), Path('./nbdev/vision/data.py'), Path('./nbdev/vision/data.py')
p4, p5     = Path('./nbdev/data/external.py'), Path('./nbdev/vision/learner.py')

In [None]:
# Expected relative form of n_i when written into file p_i.
test_eq(make_import_relative(p1, n1),'.core')
test_eq(make_import_relative(p2, n2),'..core')
test_eq(make_import_relative(p3, n3),'.transform')
test_eq(make_import_relative(p4, n4),'..notebook.core')
test_eq(make_import_relative(p5, n5),'.')

In [None]:
# +export -internal
# https://docs.python.org/3/library/re.html
# ASCII-only approximation of a Python identifier.
letter = 'a-zA-Z'
identifier = f'[{letter}_][{letter}0-9_]*'
# Matches one `from <lib>[.sub...] import ...` line (MULTILINE: any line of a cell, at any indentation).
# ReLibName is defined elsewhere; presumably it substitutes the project's library name for LIB_NAME and
# exposes the compiled pattern as `.re` (relativify_imports calls `re_import.re.sub`) -- TODO confirm.
# Groups: 1 indentation, 2 spacing after 'from', 3 module path, 4 everything after 'import'.
re_import = ReLibName(fr"""
    ^                             # start of the string / line
    (\ *)                         # any amount of whitespace (indenting)
    from(\ +)                     # 'from', followed by at least one whitespace
    (LIB_NAME(?:\.{identifier})*) # Name of the library, possibly followed by dot separated submodules
    \ +import(.+)                 # whitespace, then 'import', followed by arbitrary symbols except new line
    $                             # end of the string / line
    """, re.VERBOSE | re.MULTILINE)

In [None]:
# +export
def relativify_imports(origin:Path, code:str)->str:
    "Rewrite every `from <library>... import ...` line in `code` as a relative import, as seen from the file `origin`.\n"\
    "Indentation, the spacing after 'from' and everything after 'import' are kept; the spacing before 'import'\n"\
    "is normalised to a single space."
    def _rewrite(m):
        indent, gap, absolute_module, tail = m.group(1, 2, 3, 4)
        relative_module = make_import_relative(origin, m_to=absolute_module)
        return indent + 'from' + gap + relative_module + ' import' + tail
    return re_import.re.sub(_rewrite, code)

In [None]:
# Demo: only `from nbdev_rewrite... import` lines become relative; plain `import` lines, imports inside
# strings that do not start a line, and already-relative imports are left alone.
print(relativify_imports(Path('./nbdev_rewrite/submodule/data.py'),"""
import numpy as np, matplotlib.pyplot, moduleaaaabbb as mod
import nbdev_rewrite.vision
# Nothing to see here
from   nbdev_rewrite.abc   import array as arr, linalg.solve, module as mod
def function():
    "from nbdev_rewrite import *"
    pass
from     nbdev_rewrite import (abs, b as c, h,) # sure
from nbdev_rewrite import *
from nbdev_rewrite.core import* # ok
    from . import *
from nbdev_rewrite  import(
    abs
                  as a
    , # this is weird, but legal
                       absolute 
    as 
                  f
                  )"""))


import numpy as np, matplotlib.pyplot, moduleaaaabbb as mod
import nbdev_rewrite.vision
# Nothing to see here
from   ..abc import array as arr, linalg.solve, module as mod
def function():
    "from nbdev_rewrite import *"
    pass
from     .. import (abs, b as c, h,) # sure
from .. import *
from ..core import* # ok
    from . import *
from .. import(
    abs
                  as a
    , # this is weird, but legal
                       absolute 
    as 
                  f
                  )


## File I/O and Exporting

### Project Initialization

In [None]:
# +export
def init_config(lib_name='nbdev_rewrite', user='flpeters', nbs_path='.'):
    "Write a project config via `create_config`, unless a config file is already present."
    if Config().config_file.exists(): return  # never overwrite an existing config
    create_config(lib_name=lib_name, user=user, nbs_path=nbs_path)
init_config()

In [None]:
# +export
def init_lib():
    "Make sure the library folder exists and contains an `__init__.py` that declares the configured version."
    cfg = Config()
    init_file = cfg.lib_path/'__init__.py'
    if cfg.lib_path.exists() and init_file.exists(): return  # package already initialised
    cfg.lib_path.mkdir(parents=True, exist_ok=True)
    with init_file.open('w') as f:
        f.write(f'__version__ = "{cfg.version}"\n')
init_lib()

### File Loading

In [None]:
# +export
_reserved_dirs = (Config().lib_path, Config().nbs_path, Config().doc_path)
def crawl_directory(path:Path, recurse:bool=True) -> list:
    "Lazily yield the notebooks to convert.\n"\
    "`path` may be a file (yielded as-is), a directory (scanned for `.ipynb` files, descending into sub-directories\n"\
    "when `recurse` is set) or a list/tuple of such paths. While scanning, entries whose names start with '.' or '_'\n"\
    "and the directories listed in `_reserved_dirs` are skipped."
    # TODO: Handle symlinks?
    if isinstance(path, (list, tuple)):
        for sub_path in path: yield from crawl_directory(sub_path, recurse)
        return
    if path.is_file():
        yield path
        return
    for entry in path.iterdir():
        if entry.name.startswith(('.', '_')): continue
        if entry.is_file():
            if entry.name.endswith('.ipynb'): yield entry
        elif recurse and entry.is_dir() and entry not in _reserved_dirs:
            yield from crawl_directory(entry, recurse)
list(crawl_directory(Config().nbs_path))

[WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/00_export_v4.ipynb'),
 WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/99_index.ipynb')]

In [None]:
# +export
def read_nb(fname:Path) -> dict:
    "Read the UTF-8 notebook file `fname`, converted to nbformat version 4, and return it as a plain dict."
    raw_text = Path(fname).read_text(encoding='utf8')
    return dict(nbformat.reads(raw_text, as_version=4))

In [None]:
# Number of cells in this notebook (THIS_FILE is defined elsewhere in the notebook).
len(read_nb(THIS_FILE)['cells'])

198

In [None]:
# +export
@prefetch(max_prefetch=4)
def file_generator(path:Path=Config().nbs_path) -> (Path, dict):
    "Yield `(notebook_path, notebook_dict)` for every notebook that `crawl_directory` finds under `path`."
    # NOTE: the default `path` is evaluated once, when this cell is executed.
    # `prefetch` is defined elsewhere; presumably it reads up to 4 items ahead in the background -- TODO confirm.
    yield from ((nb_path, read_nb(nb_path)) for nb_path in crawl_directory(path))

In [None]:
# [len(x[1]['cells']) for x in file_generator()]

### Export Path Parsing

#### identify modules

Here we use pattern matching to identify valid module names.

In [None]:
# +export -internal
# https://docs.python.org/3/library/re.html
# NOTE: redefines `letter`/`identifier` with the same values used for `re_import`.
letter = 'a-zA-Z'
identifier = f'[{letter}_][{letter}0-9_]*'     # ASCII-only approximation of a Python identifier
module = fr'(?:{identifier}\.)*{identifier}'  # dot-separated identifiers, e.g. 'pkg.sub.mod'
# Bare expression: only displays the pattern in the notebook; a no-op in the exported module.
module

'(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*'

In [None]:
# +export -internal
# https://docs.python.org/3/library/re.html
# Whole-string match of a dotted module name.
# NOTE(review): without re.MULTILINE, `$` also matches right before a trailing '\n', so `.search('a.b\n')`
# succeeds; use `.fullmatch` where that matters.
re_match_module = re.compile(fr"""
        ^              # start of the string
        {module}       # definition for matching a module 
        $              # end of the string
        """, re.VERBOSE)

In [None]:
# A dotted module name matches over its full length.
re_match_module.search('module.main.test')

<re.Match object; span=(0, 16), match='module.main.test'>

In [None]:
# +export
def module_to_path(m:str)->Path:
    "Turn a module name into a path such that the exported file can be imported from the library "\
    "using the same expression."
    # Raises ValueError for invalid module names and for names ending in '.py'.
    # fullmatch: with `search`, the pattern's `$` would also accept a trailing newline ('a.b\n').
    if re_match_module.fullmatch(m) is None: raise ValueError(f"'{m}' is not a valid module name.")
    if m.endswith('.py'):
        raise ValueError(f"The module name '{m}' is not valid, because ending on '.py' "\
                         f"would produce a file called 'py.py' in the folder '{m.split('.')[-2]}', "\
                          "which is most likely not what was intended.\nTo name a file 'py.py', use the "\
                          "'-to_path' argument instead of '-to'.")
    # 'a.b.c' -> <lib>/a/b/c.py
    return Config().path_to('lib')/f"{os.path.sep.join(m.split('.'))}.py"

In [None]:
# Each dot becomes a sub-folder of the library; the last part becomes the .py file.
module_to_path('module.sub.file')

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/module/sub/file.py')

In [None]:
# Both are valid; only the value of the last expression is displayed.
module_to_path('main')
module_to_path('main.main')

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/main/main.py')

These functions might come in handy later on:

In [None]:
# ??importlib.util._resolve_name

In [None]:
# importlib.util.resolve_name??

In [None]:
# Resolve '..export' relative to the package 'module.test' -> 'module.export'.
importlib.util.resolve_name('..export', 'module.test')

'module.export'

#### identify paths

When the user explicitly passes a path, then this code is tasked with checking it for correctness and converting it to an absolute path from the perspective of the library path.

In [None]:
# +export -internal
def commonpath(*paths)->Path:
    "Return the longest sub-path shared by all `paths`, as a Path (thin wrapper around `os.path.commonpath`)."
    shared = os.path.commonpath(paths)
    return Path(shared)

In [None]:
# Two siblings share their parent folder.
commonpath(Path('c:/abc/fgh/a'), Path('c:/abc/fgh/b'))

WindowsPath('c:/abc/fgh')

In [None]:
# +export -internal
def in_directory(p:Path, d:Path)->bool:
    "Tests if `p` is pointing to something in the directory `d` (or to `d` itself).\n"\
    "Expects both `p` and `d` to be fully resolved and absolute paths."
    # Compare whole path components: a plain string-prefix test would wrongly report
    # '/lib/pkg_old/x.py' as being inside '/lib/pkg'.
    return p.parts[:len(d.parts)] == d.parts

In [None]:
def in_directory_slow_1(p, d)->bool:
    "Alternative implementation of `in_directory` based on `Path.relative_to` (not exported)."
    try: p.relative_to(d)
    except ValueError: return False  # relative_to raises ValueError when `p` is not below `d`
    else: return True
def in_directory_slow_2(p, d)->bool:
    "Alternative implementation of `in_directory`: `p` is inside `d` iff their common path is at least as deep as `d`."
    shared_depth = len(commonpath(p, d).parts)
    return shared_depth >= len(d.parts)

In [None]:
# A file directly inside the folder.
in_directory(p=Path('C:/abc/fgh/abc.txt'), d=Path('C:/abc/fgh/'))

True

In [None]:
# +export
def make_valid_path(s:str)->Path:
    "Resolve an export path argument into an absolute .py path.\n"\
    "Relative paths are taken relative to the library folder and must stay inside it; absolute paths are\n"\
    "accepted as given. Raises ValueError for paths escaping the library, and for a missing or non-.py suffix."
    raw = Path(s)
    lib = Config().path_to('lib')
    relative = not raw.is_absolute()
    target = ((lib/raw) if relative else raw).absolute().resolve()
    if relative and not in_directory(target, lib):
        raise ValueError("Relative export path beyond top level directory of library is not allowed by default. "\
                        f"Use an absolute path, or set <NOT IMPLEMENTED YET> flag on the command. ('{s}')")
    if not target.suffix: raise ValueError(f"The path '{s}' is missing a file type suffix like '.py'.")
    if target.suffix != '.py': raise ValueError(f"'{target.suffix}' is not a valid file ending. ('{s}')")
    return target

In [None]:
# '..' segments are resolved; the result still lies inside the library.
make_valid_path(Path('./module/../hi.py'))

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/hi.py')

In [None]:
# None of these may raise: relative paths that resolve back inside the library are accepted, and
# absolute paths ('d:/main.py') are not restricted. Only the last result is displayed.
make_valid_path('main.py')
make_valid_path('./main.py')
make_valid_path('../../nbdev_rewrite/nbdev_rewrite/main.py')
make_valid_path('d:/main.py')
make_valid_path('main/main.py')
make_valid_path('../nbdev_rewrite/main.py')

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/main.py')

## Notes

## Main

### Register Commands

`@register_command` stores the argument information of the registered command in the global variable `all_commands`, and a reference to the decorated function in `cmd2func`. Commands registered with `active=False` are not stored at all.

In [None]:
# +export
def register_command(cmd, args, active=True):
    "Decorator factory for comment commands.\n"\
    "Immediately records `args` (argument name -> default value) under `cmd` in `all_commands`, and maps `cmd`\n"\
    "to the decorated function in `cmd2func`. With `active=False` nothing is recorded and the function is returned as-is."
    if not active:
        return lambda f: f
    all_commands[cmd] = args
    def _store(func):
        cmd2func[cmd] = func
        return func
    return _store

In [None]:
# +export
all_commands = {}  # command name -> dict of its arguments (name -> default value), filled by @register_command
cmd2func     = {}  # command name -> function handling the command, filled by @register_command

In [None]:
# +export
@register_command(cmd='default_exp', # allow custom scope name that can be referenced in export?
                  args={'to': '', 'to_path': '', 'use_scope': False})
def kw_default_exp(file_info, cell_info, result, is_set):
    "Handle `default_exp`: register the default .py target of the notebook or, with '-use_scope', of the heading\n"\
    "scope this cell belongs to. Exactly one of '-to' / '-to_path' must be given; an existing target is never replaced."
    wants_module, wants_path = is_set['to'], is_set['to_path']
    if not (wants_module ^ wants_path):  # neither or both were given
        raise ValueError("The `default_exp` command expects exactly one of the arguments "\
                         f"'-to' or '-to_path' to be set, but recieved was: {result}")
    # key: this cell's heading scope, or (0,) for the whole notebook
    scope:tuple = cell_info['scope'] if result['use_scope'] else (0,)
    scopes:dict = file_info['export_scopes']
    previous:Path = scopes.get(scope, None)
    if wants_module: new_target:Path = module_to_path(result['to'])
    else:            new_target:Path = make_valid_path(result['to_path'])
    if previous is not None:
        raise ValueError(f"Overwriting an existing export target is not allowed. (cell nr. {cell_info['cell_nr']})"\
                        f"\n\t\t->(was: '{previous}', new: '{new_target}')")
    scopes[scope] = new_target

In [None]:
# +export
@register_command(cmd='export',
                  args={'internal': False, 'to': '', 'to_path':'', 'ignore_scope':False,
                        'cell_nr': 0, 'prepend': False, 'append': False})
def kw_export(file_info, cell_info, result, is_set):
    "Handle `export`: mark this cell for export to a .py file.\n"\
    "With '-to' / '-to_path' (mutually exclusive, not combinable with '-ignore_scope') the cell gets an extra target\n"\
    "of its own; otherwise it is counted towards its heading-scope target, or with '-ignore_scope' towards the\n"\
    "notebook default. Unless '-internal' is set, the cell's public names are collected for `__all__`."
    if is_set['to'] and is_set['to_path']:
        raise ValueError("The `export` command does not accept the '-to' and '-to_path' argument at the same time. "\
                         f"They are mutually exclusive. Recieved: {result}")
    cell_info['export_to_py'] = True  # using the command at all means the cell is exported
    if is_set['cell_nr']:
        cell_info['cell_nr'] = result['cell_nr']  # user override of this cell's number
    cell_info['is_internal'] = result['internal']
    if not result['internal']:  # internal cells contribute no names to __all__
        cell_info['names'] = find_names(cell_info['original_source_code'])
    if is_set['to']:        custom_target:Path = module_to_path(result['to'])
    elif is_set['to_path']: custom_target:Path = make_valid_path(result['to_path'])
    else:                   custom_target:Path = None
    if custom_target is None:
        counter = 'export_to_default' if result['ignore_scope'] else 'export_to_scope'
        cell_info[counter] += 1
    elif is_set['ignore_scope']:
        raise ValueError("Setting 'ignore_scope' is not allowed when exporting to a custom target "\
                        f"using 'to' or 'to_path'. (cell nr. {cell_info['cell_nr']})")
    else:
        cell_info['export_to'].append(custom_target)  # extra target just for this cell

    # TODO: support setting append or prepend
#     append, prepend = result['append'], result['prepend']
#     if append : cls.to[targ].append(cell)
#     if prepend: cls.to[targ].prepend(cell)
#     if (append and prepend):
#         report_warning(f'Cell nr. {cell.cell_nr} is being appended AND prepended to the output file.')
#     else: cls.to[targ].add(cell)

In [None]:
@register_command(cmd='set',
                  args={'file': '', 'use_names': True},
                  active=False)
def kw_set(file_info, cell_info, result, is_set):
    "Placeholder for the `set` command (predefined variables controlling execution behaviour); currently does nothing."
    # Registered with active=False, so the command is not recorded in all_commands / cmd2func yet.

#### Documentation

Command: `default_exp`  
Set the default file that cells of this notebook will be exported to.  
Args:
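- `to`: The target module name (e.g. `module.sub.file`), which is turned into a .py file path inside the library.
- `to_path`: An explicit target file path (relative to the library folder, or absolute) instead of a module name. Exactly one of `to` and `to_path` must be given.
- `use_scope`: If set, this default only applies to the cells under the current markdown heading (including its sub-headings); otherwise it applies to the whole notebook. Smaller scopes overwrite larger ones.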

Command: `export`  
This cell will be exported from the notebook to a .py file.  
Args:  
- `internal`: The variable, function and class names of this cell will not be added to `__all__` in the exported file, making them hidden from any `import *`.
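- `to`: Instead of exporting to the notebook or scope wide default file, this cell is exported to the module specified in this argument.
- `to_path`: Like `to`, but takes a file path instead of a module name. Mutually exclusive with `to`.
- `ignore_scope`: Export this cell to the notebook-wide default file, even if the current heading scope has its own default. Not allowed together with `to` or `to_path`.
- `cell_nr`: Overwrite the cell_nr of this cell. Every cell has this number, based on its position in the notebook file. Overwriting it is intended to reposition this cell in the output .py file by sorting cells by cell_nr (the exporter does not sort yet); it also changes the cell number shown in the exported info comment.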
- `prepend`: (Not implemented yet.) Every file has three "buckets" that cells can be added to: the 'before', 'normal', and 'after' bucket. Setting `prepend` to `True` means this cell will be added to the 'before' bucket. Cells in the 'before' bucket will appear before all cells in both the 'normal' and 'after' bucket in the output .py file.
- `append`: (Not implemented yet.) Setting `append` to `True` means this cell will be added to the 'after' bucket. Cells in the 'after' bucket will appear after all cells in both the 'before' and 'normal' bucket in the output .py file. Setting neither `prepend` nor `append` means this cell will be added to the 'normal' bucket. Use cases for these two arguments might be sending all imports in a notebook to the top of the .py file, or helping with correctly ordering cells exported to a different file (e.g. using the `to` arg).

Command: `set`  
Set some predefined variables that control execution behaviour. (Currently inactive: the command is registered with `active=False`.)  
Args:  
- `file`: If this is set, the variables will only be set on this specific file.
- `use_names`: Control whether or not a `__all__` with all (non internal) variable, function and class names should be inserted at the top of the file. Default is `True`.

### Do

changelog:  


TODO:
- Rethink the structure of data used to represent the program state.
    - To export a cell multiple times a list of some form is necessary
    - To give better error messages a mapping from commands back to the cells that use those commands might be helpful. Such a mapping would also allow cells to be added multiple times in order to export them more than once.
- Use the StackTrace class for error reporting in registered commands, when interacting with other files.
- Support Automatic / Explicit Versioning
- Add better debugging information, e.g. `default_exp` should show the previous occurrence in case it is defined multiple times.
- Implement code appending / prepending when exporting to non-default files
- improve error messages in name parsing

In [None]:
# Load this notebook for interactive inspection.
nb = read_nb(THIS_FILE)

### Parsing

In [None]:
# +export -internal
# https://docs.python.org/3/library/re.html
# Matches a markdown cell whose source STARTS with a heading (no re.MULTILINE, so only the first line counts).
# Group 1: the '#' run (its length is the heading level); group 2: the rest of the cell, including later
# lines (re.DOTALL). re.IGNORECASE has no effect here because the pattern contains no letters.
re_match_heading = re.compile(r"""
        ^              # start of the string
        (\#+)\s+       # 1 or more literal "#", then 1 or more whitespace
        (.*)           # group of arbitrary symbols (including new line)
        $              # end of the string
        """,re.IGNORECASE | re.VERBOSE | re.DOTALL)

In [None]:
# A level-2 heading: ('##', 'test').
res = re_match_heading.search('## test')
res.groups()

('##', 'test')

In [None]:
# Project root (folder of the config file); parse_file computes relative notebook paths from here.
Config().config_file.parent

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite')

In [None]:
# +export
def parse_file(file_path:Path, file:dict, st:StackTrace) -> (bool, dict):
    "Parse one notebook: run the handler of every recognised comment command in its code cells and track\n"\
    "markdown heading scopes. Returns `(success, file_info)`. `file_info` holds the notebook's path (absolute and\n"\
    "relative to the project root), its nbformat version, its export scopes (scope tuple -> target path, where (0,)\n"\
    "is the notebook-wide default) and one info dict per cell."
    success = True  # NOTE(review): never set to False below; failures surface as exceptions instead
    # Passed to iter_comments. When True, matched command comments are removed as whole lines below;
    # otherwise each line is only cut off at the comment's start column.
    pure_comments_only = True
    nb_version:(int, int) = (file['nbformat'], file['nbformat_minor'])
    metadata  :dict       =  file['metadata']  # NOTE(review): currently unused

    file_info = {
        'origin_file': file_path,
        # forward slashes so the header comment in exported files looks the same on every OS
        'relative_origin': os.path.relpath(file_path, Config().config_file.parent).replace('\\', '/'),
        'nb_version': nb_version,
        'export_scopes': {
            tuple([0]): None, # This is the default for an entire file.
        },
        'cells': list()
    }
    # Heading counters per level, e.g. [2, 1] = inside the 2nd top-level heading, under its 1st sub-heading.
    # A cell's scope is tuple(scope_count) at the moment the cell is reached.
    scope_count :[int] = [0]
    scope_level :int   = 0

    cells:list = file_info['cells']

    # StackTrace for errors raised while parsing comments; its cell/line/excerpt fields are updated as parsing
    # proceeds (StackTrace.ext is defined elsewhere).
    f_pc_st = StackTrace.ext(file=file_info['relative_origin'],
                             up=StackTrace(parse_comment, up=st))

    for i, cell in enumerate(file['cells']):
        cell_type   = cell['cell_type']
        cell_source = cell['source']
        cell_info = {
            'cell_nr' : i,
            'cell_type' : cell_type,
            'original_source_code' : cell_source,
            'processed_source_code': cell_source,
            'scope' : tuple(scope_count),
            'export_to_py' : False,
            'export_to_scope' : 0,
            'export_to_default' : 0,
            'is_internal' : None,
            'export_to' : [],
            'names' : None,
            'comments' : []
        }
        if cell_type == 'code':
            f_pc_st.cellno = i
            comments_to_remove = []
            for comment, (lineno, charno) in iter_comments(cell_source, pure_comments_only, line_limit=None):
                f_pc_st.lineno = lineno
                f_pc_st.excerpt = comment
                parsing_success, cmd, result, is_set = parse_comment(all_commands, comment, st=f_pc_st)
                if not parsing_success: continue  # not a command comment: leave it in the source
                print(f'Found: {cmd} @ ({i}, {lineno}, {charno}) with args: {result}')
                # the handler mutates file_info / cell_info in place
                if cmd in cmd2func: cmd2func[cmd](file_info, cell_info, result, is_set)
                else: raise ValueError(f"The command '{cmd}' in cell number {i} is recognized, "\
                                        "but is missing a corresponding action mapping in cmd2func.")
                cell_info['comments'].append(comment)
                comments_to_remove.append((lineno, charno))
            if len(comments_to_remove) > 0:
                lines = cell_source.splitlines()
                # Walk backwards so that removing a line does not shift the indices of the ones still to process
                # (assumes iter_comments yields comments in source order -- TODO confirm).
                if pure_comments_only:
                    for lineno, charno in comments_to_remove[::-1]: lines.pop(lineno)
                else:
                    for lineno, charno in comments_to_remove[::-1]: lines[lineno] = lines[lineno][:charno]
                cell_info['processed_source_code'] = '\n'.join(lines)

        elif cell_type == 'markdown':
            # NOTE: this cell's own cell_info['scope'] was taken above, before its heading updates the counters.
            res = re_match_heading.search(cell_source)
            if not (res is None): # this cell contains a heading
                heading_level, heading_name = res.groups()
                new_scope_level = len(heading_level) # number of '#' in the heading
                if new_scope_level > scope_level:
                    scope_count += ([0] * (new_scope_level - (len(scope_count)))) # extend list if necessary
                elif new_scope_level < scope_level:
                    scope_count = scope_count[:new_scope_level] # reset lower values
                scope_count[new_scope_level - 1] += 1
                scope_level = new_scope_level
            else: pass # this cell is regular markdown
        elif cell_type == 'raw': pass
        else: raise ValueError(f"Unknown cell_type '{cell_type}' in cell number {i}."\
                                "Should be 'code', 'markdown', or 'raw'.")
        cells.append(cell_info)
    return success, file_info

In [None]:
# +export
def load_and_parse_all(origin_path:Path, output_path:Path, recurse:bool, st:StackTrace) -> (bool, dict):
    "Load every .ipynb notebook under `origin_path` (descending into sub-directories if `recurse`) and pass them\n"\
    "one at a time to parse_file. Returns `(True, {'files': [file_info, ...]})`, or `(False, None)` as soon as one\n"\
    "notebook fails to parse. `output_path` is accepted but not used yet."
    # Previously this always crawled Config().nbs_path and ignored both `origin_path` and `recurse`.
    file_paths = crawl_directory(origin_path, recurse)

    # TODO: fine tune, or even pass an argument from the user on how many thread to use for prefetching files.
    #       num_cpus() from nbdev.imports can be used here
    # (named `prefetched` so it does not shadow the module-level `file_generator` function)
    prefetched = BackgroundGenerator(((file_path, read_nb(file_path)) for file_path in file_paths), max_prefetch=4)

    parsed_files = {
        # Add flags and settings variables above this line
        'files': list()
    }

    # TODO: use multithreading / multiprocessing per file / per bunch of cells
    for file_path, nb in prefetched:
        # if file_path.name != THIS_FILE: continue # For Debugging
        success, file_info = parse_file(file_path, nb, st=StackTrace(parse_file, up=st))
        # TODO: try parsing all the files, even if one fails?
        if not success:
            st.report_error(Exception(f'Error while parsing {file_path}'))
            return False, None
        # TODO: before returning, give any meta programm a chance to run.
        # maybe have parse_file return some additional information about any meta programm
        parsed_files['files'].append(file_info)

    return True, parsed_files

### Writing out

In [None]:
def stringify_names(names:set, sep='\n\n\n')->str:
    "Render the sorted `names` as an `__all__ = [...]` assignment preceded by `sep`.\n"\
    "A new line, indented by 11 spaces, starts once the current line's entries plus the next name would reach\n"\
    "80 characters (approximate: the next name's quotes and comma are not counted)."
    finished, current = [], ''
    for name in sorted(names):
        if len(current) + len(name) >= 80:
            finished.append(current)
            current = ' ' * 11
        current += f"'{name}', "
    wrapped = ''.join(line + '\n' for line in finished)
    return sep + '__all__ = [' + wrapped + current[:-2] + ']'

In [None]:
def stringify_names_2(names):
    "Single-line variant of stringify_names: render the sorted `names` as `__all__ = [...]` after three newlines."
    # Previously joined with ', ' inside one pair of quotes, producing the single string ['a, b'],
    # and rendered [''] for no names.
    if not names: return "\n\n\n__all__ = []"
    quoted = ', '.join(f"'{name}'" for name in sorted(names))
    return f"\n\n\n__all__ = [{quoted}]"

In [None]:
# Benchmark input for comparing stringify_names and stringify_names_2 with %timeit.
test_data = ['abc'] * 1000

In [None]:
# %timeit stringify_names(test_data)

In [None]:
# %timeit stringify_names_2(test_data)

In [None]:
# +export
def write_file(to:Path, orig:str, names:set, code:list, st:StackTrace) -> bool:
    "Write one generated module to `to`, creating parent folders as needed.\n"\
    "The content is an autogenerated-file warning naming `orig` (the source notebook, or None to point readers to\n"\
    "the per-cell info comments), an `__all__` listing the sorted `names`, then every chunk of `code`, all separated\n"\
    "by two blank lines. Returns True once written. `st` is currently unused."
    sep:str = '\n\n\n'
    if orig is None:
        warning = '# AUTOGENERATED! DO NOT EDIT! View info comment on each cell for file to edit.'
    else:
        warning = f'# AUTOGENERATED! DO NOT EDIT! File to edit: {orig} (unless otherwise specified).'
    # TODO: add line breaks at regular intervals
    quoted = ', '.join(f"'{name}'" for name in sorted(names))
    all_line:str = f'{sep}__all__ = [{quoted}]'  # '__all__ = []' when there are no names
    body:str = sep + sep.join(code)
    to.parent.mkdir(parents=True, exist_ok=True)
    with open(to, 'w', encoding='utf8') as f: f.write(f'{warning}{all_line}{body}')
    return True  # previously fell off the end and returned None despite the `-> bool` annotation

In [None]:
# +export
def write_out_all(parsed_files, st:StackTrace) -> bool:
    "Route every exported cell of every parsed notebook to its target .py file(s), then write each file once.\n"\
    "A cell goes to each of its explicit '-to'/'-to_path' targets, to the tightest heading scope that has its own\n"\
    "`default_exp` target, or otherwise to the notebook's default target. `parsed_files` is not modified.\n"\
    "Raises ValueError when two notebooks share a default target, or when a cell needs a default that is missing."
    # TODO: write one file at a time to disk, to the correct directory,
    # initialize a python module, if it doesn't already exists,
    # Handle mergers between multiple parsed_files. <-----------------
    zero_tuple = tuple([0])

    # target path -> collected __all__ names, code chunks, and the notebook using it as its default target
    export_files = defaultdict(lambda: {'names': set(), 'code': [], 'orig': None})

    def _add(to:Path, cell:dict, header:str, times:int=1):
        "Append `times` copies of `cell`'s code (with imports made relative to `to`) to the target `to`."
        state:dict = export_files[to]
        if not cell['is_internal']: state['names'].update(cell['names'])
        # TODO: implement code appending / prepending here
        chunk = f"{header}\n{relativify_imports(to, cell['processed_source_code'])}"
        state['code'].extend([chunk] * times)

    for file_info in parsed_files['files']:
        rel_orig:str = file_info['relative_origin']
        scopes:dict  = file_info['export_scopes']
        assert zero_tuple in scopes, 'No default in export Scopes.'
        scopes_available:bool = (len(scopes) > 1)
        default_export:Path = scopes[zero_tuple]
        # NOTE: Having no default is ok, as long as all cells still have a valid export target
        none_default  :bool = (default_export is None)

        if not none_default:
            default_state = export_files[default_export]
            if (default_state['orig'] is None): default_state['orig'] = rel_orig
            else: raise ValueError(f'Multiple files have {default_export} as the default export target. '\
                                   f'(old: {default_state["orig"]} | new: {rel_orig})')

        for cell in file_info['cells']:
            if not cell['export_to_py']: continue
            info_string = f"# {'Internal ' if cell['is_internal'] else ''}Cell nr. {cell['cell_nr']}"
            info_string_src = f"{info_string}; Comes from '{rel_orig}'"
            # Local counter: incrementing cell['export_to_default'] in place corrupted `parsed_files`,
            # so a second call exported these cells twice.
            n_default:int = cell['export_to_default']

            for to in cell['export_to']: _add(to, cell, info_string_src)

            if scopes_available:
                if cell['export_to_scope'] > 0:
                    # Scope matching: find the longest scope key that is a prefix of the cell's scope.
                    cell_scope:tuple = cell['scope']
                    best_fit, best_fit_len = zero_tuple, 0
                    # NOTE: The number of scopes should usually be relatively small
                    for k in scopes:
                        if len(k) > best_fit_len and k == cell_scope[:len(k)]:
                            best_fit, best_fit_len = k, len(k)
                    to:Path = scopes[best_fit]
                    if (best_fit == zero_tuple) or (to == default_export):
                        n_default += cell['export_to_scope']
                    else:
                        _add(to, cell, info_string_src, cell['export_to_scope'])
            else: n_default += cell['export_to_scope']

            if n_default > 0:
                if none_default:
                    raise ValueError(f'Export Target of cell {cell["cell_nr"]} is None. '\
                                     'Did you forget to add a default target using `default_exp`?')
                _add(default_export, cell, info_string, n_default)
        # NOTE: Files can't be written at this point, since there might be other notebooks exporting to the same file.

    for to, state in export_files.items():
        write_file(to=to, orig=state['orig'], names=state['names'], code=state['code'],
                   st=StackTrace(write_file, up=st))
    return True

### Main()

In [None]:
# +export
def main(origin_path:str=None, output_path:str=None, recurse:bool=True) -> (bool, dict):
    """Parse every notebook under `origin_path` and export the marked cells below `output_path`.

    Args:
        origin_path: Directory to read notebooks from. Defaults to `Config().nbs_path`;
            otherwise it is converted to an absolute `Path`.
        output_path: Directory the generated modules are written to. Defaults to
            `Config().lib_path`; otherwise it is converted to an absolute `Path`.
        recurse: Forwarded unchanged to `load_and_parse_all`
            (presumably controls descending into sub-directories — confirm there).

    Returns:
        A `(success, parsed_files)` tuple.
        - If loading/parsing fails, this is `(0, None)` and nothing is written.
        - Otherwise `success` is the value returned by `write_out_all` and
          `parsed_files` is the parse result produced by `load_and_parse_all`.
    """
    st = StackTrace(main)
    origin_path:Path = Config().nbs_path if origin_path is None else Path(origin_path).resolve()
    output_path:Path = Config().lib_path if output_path is None else Path(output_path).resolve()

    success, parsed_files = load_and_parse_all(origin_path, output_path, recurse,
                                               st=StackTrace(load_and_parse_all, up=st))
    if not success:
        return 0, None
    # NOTE: At this point all files are completely parsed, and any meta-program has run.

    # Link the child trace to `st` by keyword, like every other StackTrace(...) call in this module.
    success = write_out_all(parsed_files, st=StackTrace(write_out_all, up=st))
    return success, parsed_files

## Run

In [None]:
# +export
# Configure the argument-parsing module's reporting: report_error=False, every other flag at its default.
# NOTE(review): set_arg_parse_report_options sets arg_parse_SILENT whenever any of its four
# report/raise flags is False, so this call also silences argument-parsing warnings — confirm intended.
set_arg_parse_report_options(report_error=False)
# Apply the main module's report options using their default arguments.
set_main_report_options()

In [None]:
# Run the whole pipeline (parse all notebooks, then write the exported modules) with the default paths.
success, parsed_files = main();

Found: default_exp @ (1, 0, 0) with args: {'to': 'argument_parsing', 'to_path': '', 'use_scope': True}
Found: export @ (3, 0, 0) with args: {'internal': True, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': False, 'append': False}
Found: export @ (4, 0, 0) with args: {'internal': False, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': False, 'append': False}
Found: export @ (5, 0, 0) with args: {'internal': True, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': False, 'append': False}
Found: export @ (6, 0, 0) with args: {'internal': True, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': False, 'append': False}
Found: export @ (9, 0, 0) with args: {'internal': True, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': False, 'append': False}
Found: export @ (16, 0, 0) with args: {'internal': True, 'to': '', 'to_path': '', 'ignore_scope': False, 'cell_nr': 0, 'prepend': F

In [None]:
from nbdev_rewrite.main import *

In [None]:
# Re-run the pipeline using `main` as imported from the exported `nbdev_rewrite.main` module.
success, parsed_files = main();

In [None]:
# Inspect the first parsed file's export scopes: scope tuples mapped to their target module paths.
parsed_files['files'][0]['export_scopes']

{(0,): WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/main.py'),
 (1,): WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite/argument_parsing.py')}

In [None]:
# Collect the first file's cells whose 'export_to_py' entry is truthy.
# NOTE(review): the trailing `;` suppresses the notebook display, so this cell currently shows nothing.
[c for c in parsed_files['files'][0]['cells'] if c['export_to_py']];

## Develop new Stuff

In [None]:
# Sanity check: the `lib_path` attribute and `path_to` with the full or the short key agree.
Config().lib_path == Config().path_to('lib_path') == Config().path_to('lib')

True

In [None]:
# Resolve the library directory via the short 'lib' key and display it.
lib = Config().path_to('lib'); lib

WindowsPath('//DESKTOP-MDPTPCT/Projects/GitHub/nbdev_rewrite/nbdev_rewrite')

### Regex for matching import statements

In [None]:
# https://docs.python.org/3/library/re.html
# Regex building blocks for matching Python import statements.
letter = 'a-zA-Z'  # ASCII letters only; non-ASCII identifiers are not matched
identifier = f'[{letter}_][{letter}0-9_]*'  # letter/underscore, then letters, digits or underscores
module = fr'(?:{identifier}\.)*{identifier}'  # dotted path such as `a.b.c`
module

'(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*'

In [None]:
# Relative module: optional leading dots followed by a dotted path, or one or more dots alone (`.`, `..`).
relative_module = fr'(?:\.*{module}|\.+)'
name = identifier  # alias target used after `as`

In [None]:
# `as_name` first matches one ` as alias`, then is redefined as: an optional alias followed by any
# number of `, item [as alias]` entries (spaces only, no newlines).
# NOTE(review): the comma-separated items use `{module}`, so dotted names such as `linalg.solve` are
# accepted there, while Python's `from ... import` grammar only allows plain identifiers.
as_name  = fr'(?:\ +as\ +{name})'
as_name  = fr'{as_name}?(?:\ *,\ *{module}{as_name}?)*'

# `import a.b [as x], c [as y]` — groups: (first module, optional alias + remaining items).
import_1 = fr'import\ +({module})({as_name})'

# `from mod import name [as x], ...` — groups: (module, imported names).
import_2 = fr'from\ +({relative_module})\ +import\ +({identifier}{as_name})'

# Same alias/list pieces with `\s`, so a parenthesised name list may span several lines.
as_name_s  = fr'(?:\s+as\s+{name})'
as_name_s  = fr'{as_name_s}?(?:\s*,\s*{module}{as_name_s}?)*'
# `from mod import (name [as x], ...,)` with optional trailing comma — groups: (module, parenthesised list).
import_3   = fr'from\ +({relative_module})\ +import\ *(\(\s*{identifier}{as_name_s}\s*,?\s*\))'

# NOTE: The docs say 'module', but in reality relative imports work as well.
# `from mod import *` — group: (module).
import_4 = fr'from\ +({relative_module})\ +import\ *\*'

# NOTE: import_1 is not included, because it doesn't allow relative imports.
# The alternation has 5 capture groups: import_2 (2), import_3 (2), import_4 (1).
import_stmt = fr'(?:{import_2}|{import_3}|{import_4})'
import_stmt

'(?:from\\ +((?:\\.*(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*|\\.+))\\ +import\\ +([a-zA-Z_][a-zA-Z0-9_]*(?:\\ +as\\ +[a-zA-Z_][a-zA-Z0-9_]*)?(?:\\ *,\\ *(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*(?:\\ +as\\ +[a-zA-Z_][a-zA-Z0-9_]*)?)*)|from\\ +((?:\\.*(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*|\\.+))\\ +import\\ *(\\(\\s*[a-zA-Z_][a-zA-Z0-9_]*(?:\\s+as\\s+[a-zA-Z_][a-zA-Z0-9_]*)?(?:\\s*,\\s*(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*(?:\\s+as\\s+[a-zA-Z_][a-zA-Z0-9_]*)?)*\\s*,?\\s*\\))|from\\ +((?:\\.*(?:[a-zA-Z_][a-zA-Z0-9_]*\\.)*[a-zA-Z_][a-zA-Z0-9_]*|\\.+))\\ +import\\ *\\*)'

In [None]:
# https://docs.python.org/3/library/re.html
# Match an (optionally indented) import statement running from the start of a line to the end of a
# line. Parenthesised `import_3` lists may span several lines.
# NOTE: with re.MULTILINE, `^`/`$` match at every line boundary, not only the start/end of the string
# as the inline pattern comments say. Groups: (indentation, then the 5 groups of `import_stmt`).
re_test = re.compile(fr"""
        ^              # start of the string
        (\ *)          # capturing group of any amount of whitespace (indenting)
        {import_stmt}  # definition for matching a module 
        \ *            # non-capturing whitespace
                       # TODO: match any remaining character in case of e.g. comments
        $              # end of the string
        """, re.VERBOSE | re.MULTILINE)

In [None]:
# `import_1` is not part of `import_stmt`, so this search returns None (nothing is displayed).
re_test.search('import numpy as np, matplotlib.pyplot, moduleaaaabbb as mod') # import_1

In [None]:
# `import_2`: comma-separated names with aliases on a single line.
re_test.search('from numpy import array as arr, linalg.solve, module as mod').group() # import_2

'from numpy import array as arr, linalg.solve, module as mod'

In [None]:
# `import_3`: parenthesised name list with a trailing comma.
re_test.search('from numpy import (abs, b as c, h,)').group() # import_3

'from numpy import (abs, b as c, h,)'

In [None]:
# `import_4`: star imports, absolute and relative. Only the last expression's value is displayed.
re_test.search('from numpy import *').group() # import_4
re_test.search('from . import *').group() # import_4

'from . import *'

In [None]:
# Module name whose imports `repl` marks with the `<REL>` placeholder.
_The_Name = 'numpy'
# import_stmt
def repl(match):
    """Replacement callback for `re_test.sub`: rebuild one matched import statement.

    `match.groups()` must have six entries, in the order produced by `re_test`:
    the leading indentation, then (module, names) for `import_2`,
    (module, parenthesised names) for `import_3`, and the module for `import_4`.

    Returns `{indent}from {module} import {names}`, with `<REL>` placed in front of the
    module when it equals `_The_Name`. If none of the three forms is populated, the
    matched text is returned unchanged. The callback never returns None, which
    `re.sub` would treat as an empty replacement and thereby delete the statement.
    """
    print(match.groups())
    sp, n2, a2, n3, a3, n4 = match.groups()
    # Only one alternative of `import_stmt` can match, so at most one module group is set.
    if n2:
        module, names = n2, a2
    elif n3:
        module, names = n3, a3
    elif n4:
        module, names = n4, '*'
    else:
        return match.group(0)
    # A single return for every form guarantees no branch silently yields None.
    prefix = '<REL>' if module == _The_Name else ''
    return f'{sp}from {prefix}{module} import {names}'

# Run `repl` over a sample source covering every matched form: plain, indented (relative),
# one-line lists, and a multi-line parenthesised list with irregular spacing.
# Non-import lines and `import_1` statements are not matched and pass through unchanged.
res = re_test.sub(repl, """
import numpy as np, matplotlib.pyplot, moduleaaaabbb as mod
# Nothing to see here
from numpy import array as arr, linalg.solve, module as mod
def function():
    pass
from numpy import (abs, b as c, h,)
from numpy import *
    from . import *
from numpy  import(
    abs
                  as a
    ,
                       absolute 
    as 
                  f
                  )""")
print(res)

('', 'numpy', 'array as arr, linalg.solve, module as mod', None, None, None)
('', None, None, 'numpy', '(abs, b as c, h,)', None)
('', None, None, None, None, 'numpy')
('    ', None, None, None, None, '.')
('', None, None, 'numpy', '(\n    abs\n                  as a\n    ,\n                       absolute \n    as \n                  f\n                  )', None)

import numpy as np, matplotlib.pyplot, moduleaaaabbb as mod
# Nothing to see here
from <REL>numpy import array as arr, linalg.solve, module as mod
def function():
    pass


    from . import *

