In [None]:
# export
from local.imports import *
from local.notebook.core import *
from local.notebook.export import *
import nbformat,inspect
from nbformat.sign import NotebookNotary
from nbconvert.preprocessors import ExecutePreprocessor
from local.test import *
from local.core import *

In [None]:
# default_exp notebook.test

# Extracting tests from notebooks

> Functions that collect the cells containing tests (optionally filtered by test flags) and execute them

In [None]:
_re_all_flag = re.compile("""
# Matches any line with #all_something and catches that something in a group:
^         # beginning of line (since re.MULTILINE is passed)
\s*       # any number of whitespace
\#\s*     # # then any number of whitespace
all_(\S+) # all_ followed by a group with any non-whitespace chars
\s*       # any number of whitespace
$         # end of line (since re.MULTILINE is passed)
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [None]:
# export
def check_all_flag(cells):
    "Return the flag captured from the first cell with an `# all_<flag>` line (e.g. `'slow'`), or `None` if no cell has one"
    for cell in cells:
        # Run the regex once per cell and reuse the match (it used to be searched twice)
        match = check_re(cell, _re_all_flag)
        if match: return match.groups()[0]
    return None

In [None]:
# The wikitext tutorial carries a notebook-wide `slow` flag; the export notebook carries none
nb = read_nb("35_tutorial_wikitext.ipynb")
test_eq(check_all_flag(nb['cells']), 'slow')
nb = read_nb("91_notebook_export.ipynb")
test_eq(check_all_flag(nb['cells']), None)

In [None]:
_re_flags = re.compile("""
# Matches any line with a test flad and catches it in a group:
^               # beginning of line (since re.MULTILINE is passed)
\s*             # any number of whitespace
\#\s*           # # then any number of whitespace
(slow|cuda|cpp) # all test flags
\s*             # any number of whitespace
$               # end of line (since re.MULTILINE is passed)
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [None]:
def get_cell_flags(cell):
    "List the test flags (`slow`/`cuda`/`cpp` comment lines) in a code `cell`, in order; any other cell type gives `[]`"
    is_code = cell['cell_type'] == 'code'
    return _re_flags.findall(cell['source']) if is_code else []

In [None]:
# Flags are collected in order from code cells only; cells without flag lines give an empty list
test_eq(get_cell_flags(dict(cell_type='code', source="#hide\n# slow\n")), ['slow'])
test_eq(get_cell_flags(dict(cell_type='code', source="#hide\n# slow\n # cuda")), ['slow', 'cuda'])
test_eq(get_cell_flags(dict(cell_type='markdown', source="#hide\n# slow\n # cuda")), [])
test_eq(get_cell_flags(dict(cell_type='code', source="#hide\n")), [])

In [None]:
# export
def _add_import_cell(mod):
    "Return an import cell for `mod`"
    return {'cell_type': 'code',
            'execution_count': None,
            'metadata': {'hide_input': True},
            'outputs': [],
            'source': f"\nfrom local.{mod} import *"}

In [None]:
# Detects cells containing a line that starts (after optional whitespace) with `# export` or `# exports`,
# case-insensitively, on any line of the cell (re.MULTILINE).
# NOTE(review): the pattern is not anchored with `$`, so any text may follow the tag on the same line
# (e.g. `# export some.module`, but also `# exporter`) — confirm that permissiveness is intended.
_re_is_export = re.compile(r"""
# Matches any text with #export or #exports flag:
^         # beginning of line (since re.MULTILINE is passed)
\s*       # any number of whitespace
\#\s*     # # then any number of whitespace
exports?  # export or exports
\s*       # any number of whitespace
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [None]:
_re_has_import = re.compile(r"""
# Matches any text with import statement:
^         # beginning of line (since re.MULTILINE is passed)
\s*       # any number of whitespace
import    # # then any number of whitespace
\s+  
|
\s*
from
\s+\S+\s+
import
\s+
""", re.IGNORECASE | re.MULTILINE | re.VERBOSE)

In [None]:
# export
class NoExportPreprocessor(ExecutePreprocessor):
    "An `ExecutePreprocessor` that only executes code cells that are not pure exports and whose test flags are all enabled"
    @delegates(ExecutePreprocessor.__init__)
    def __init__(self, flags, verbose=False, **kwargs):
        # `flags`: test flags enabled for this run; a cell carrying any flag not in `flags` is skipped
        # `verbose`: print the source of each cell before executing it (previously always done)
        self.flags,self.verbose = flags,verbose
        super().__init__(**kwargs)

    def preprocess_cell(self, cell, resources, index):
        "Execute `cell` unless it is not code, is an exported cell without imports, or has a disabled test flag"
        if 'source' not in cell or cell['cell_type'] != "code": return cell, resources
        src = cell['source']
        # Exported cells are skipped (presumably their definitions come from the import cell `_test_nb`
        # prepends), but exported cells with import statements still run so their imported names exist
        if _re_is_export.search(src) and not _re_has_import.search(src): return cell, resources
        if any(f not in self.flags for f in get_cell_flags(cell)): return cell, resources
        if self.verbose: print(src)
        return super().preprocess_cell(cell, resources, index)

In [None]:
def _test_nb(nb, flags=None, mod=None, name=None):
    """Execute the testable cells of notebook `nb` in a `python3` kernel (600s per-cell timeout).

    `flags`: test flag(s) enabled for this run (e.g. `"cuda"`); cells tagged with other flags are skipped.
    `mod`: module star-imported from `local.` before the cells run; defaults to `find_default_export(nb['cells'])`.
    `name`: currently unused, kept for backward compatibility.
    The caller's `nb` is not modified."""
    # Honour an explicitly passed `mod`; previously it was always overwritten
    if mod is None: mod = find_default_export(nb['cells'])
    # Work on a new cell list: inserting into `nb['cells']` mutated the caller's notebook
    # and stacked a duplicate import cell on every call
    cells = list(nb['cells'])
    if mod is not None: cells.insert(0, _add_import_cell(mod))
    ep = NoExportPreprocessor(L(flags), timeout=600, kernel_name='python3')
    pnb = nbformat.from_dict({**nb, 'cells': cells})
    ep.preprocess(pnb)

In [None]:
# Execute the fp16 callback notebook's tests with the `cuda` test flag enabled
nb = read_nb("18_callback_fp16.ipynb")
_test_nb(nb, "cuda")


from local.callback.fp16 import *
#export
from local.torch_basics import *
from local.test import *
from local.layers import *
from local.data.all import *
from local.notebook.showdoc import show_doc
from local.optimizer import *
from local.learner import *
from local.callback.progress import *
#default_exp callback.fp16
#hide
from local.utils.test import *
# export 
from local.utils.fp16_utils import convert_network, model_grads_to_master_grads, master_params_to_model_params
model = nn.Sequential(nn.Linear(10,30), nn.BatchNorm1d(30), nn.Linear(30,2)).cuda()
model = convert_network(model, torch.float16)

for i,t in enumerate([torch.float16, torch.float32, torch.float16]):
    test_eq(model[i].weight.dtype, t)
    test_eq(model[i].bias.dtype,   t)
    
model = nn.Sequential(nn.Linear(10,30), BatchNorm(30, ndim=1), nn.Linear(30,2)).cuda()
model = convert_network(model, torch.float16)

for i,t in enumerate([torch.float16, torch.float32, torch.float16]):
    test_eq(model[i].weight.dtype,

In [None]:
os.environ.