    import pre_commit, prenotebook

`prenotebook` is designed primarily to be a `pre_commit` hook.

Some of its helpers will likely be useful outside the hook as well.

In [None]:
import contextlib
import dataclasses
import enum
import io
import itertools
import json
import pathlib
import textwrap
import typing

import black
import click
import IPython
import nbformat
import ujson


In [None]:
    class OutOfOrder(Exception):
        """Raised when a notebook's cells were not executed in top-to-bottom order.

        `Notebook.verify_order` raises it with the offending notebook's filename as its
        only argument. It derives from `Exception` (not `BaseException`) so ordinary
        `except Exception` handlers can catch it.
        """


In [None]:
    @click.command()
    @click.option("--sort-imports/--no-sort", default=True)
    @click.option("--line-width", default=100)
    @click.option("--keep-output/--output", default=False)
    @click.option("--keep-execution-count/--execution-count", default=False)
    @click.option("--keep-metadata/--metadata", default=False)
    @click.option("--keep-empty/--empty", default=False)
    @click.option("--verify-order/--dont-verify-order", default=True)
    @click.argument("filenames", nargs=-1)
    def main(
        filenames,
        sort_imports,
        line_width,
        keep_output,
        keep_execution_count,
        keep_metadata,
        keep_empty,
        verify_order,
        *,
        ret=0,
    ):
        """Clean the Jupyter notebooks among FILENAMES in place.

        Only paths with an .ipynb suffix are processed; other paths are ignored.
        \f
        With verify_order enabled, a notebook executed out of order raises
        OutOfOrder before it is rewritten, which stops processing the remaining
        files. Returns ``ret`` unchanged (0 by default).
        """
        for file in filenames:
            if pathlib.Path(file).suffix != ".ipynb":
                continue
            nb = Notebook(
                filename=file,
                line_width=line_width,
                sort_imports=sort_imports,
                keep_execution_count=keep_execution_count,
                keep_metadata=keep_metadata,
                keep_empty=keep_empty,
                keep_output=keep_output,
            )
            # Check the order before fix(): fix() may clear the execution counts
            # that the check reads.
            if verify_order:
                nb.verify_order()
            nb.fix()
        return ret


In [None]:
    @dataclasses.dataclass
    class Notebook:
        """A notebook file loaded from disk, plus the clean-up options `fix` applies.

        The JSON is read eagerly in `__post_init__`; `fix` rewrites the file in place.
        """

        filename: pathlib.Path  # a str is accepted too; converted in __post_init__
        line_width: int = 100  # black line length; a falsy value disables black
        sort_imports: bool = True
        keep_output: bool = False
        keep_execution_count: bool = False
        # NOTE(review): keep_metadata is stored but no method of this class applies it.
        keep_metadata: bool = False
        keep_empty: bool = False
        # Source formatter (str -> str); when None, black with `line_width` is used.
        black: typing.Optional[typing.Callable[[str], str]] = None
        node: nbformat.NotebookNode = dataclasses.field(
            default_factory=nbformat.NotebookNode, repr=False
        )

        def __post_init__(self):
            self.filename = pathlib.Path(self.filename)
            if self.black is None:
                # `self.line_width` is read on every call, not captured here.
                self.black = lambda source: black.format_str(
                    source, mode=black.FileMode(line_length=self.line_width)
                )
            # Notebooks are UTF-8 JSON; don't depend on the platform default encoding.
            with self.filename.open(encoding="utf-8") as file:
                self.node.update(ujson.load(file))

        @staticmethod
        def _set_source(cell, before, text):
            """Store `text` as the cell source, keeping the form of the original `before`.

            A list source is re-split into lines (keeping line endings). The trailing
            newline that formatters append is dropped when the original had none.
            """
            if text.endswith("\n") and not "".join(before).endswith("\n"):
                text = text.rstrip("\n")
            cell["source"] = text.splitlines(True) if isinstance(before, list) else text

        def strip_output(self, cell):
            """Clear the outputs of `cell` unless `keep_output` is set."""
            if not self.keep_output and "outputs" in cell and cell.outputs:
                cell.outputs = []
            return cell

        def normalize_execution(self, cell):
            """Reset the execution count of `cell` unless `keep_execution_count` is set."""
            if not self.keep_execution_count and "execution_count" in cell and cell.execution_count:
                cell.execution_count = None
            return cell

        def blacken(self, cell):
            """Format a code cell with `self.black` when `line_width` is truthy.

            Source the formatter cannot parse (for example IPython magics) is left
            unchanged.
            """
            if self.line_width and cell.cell_type == "code":
                before = cell.source
                try:
                    formatted = self.black("".join(before))
                except ValueError:
                    # black signals unparsable input with InvalidInput, a ValueError.
                    return cell
                self._set_source(cell, before, formatted)
            return cell

        def isort(self, cell):
            """Sort the imports of a code cell when `sort_imports` is set.

            Works with both the isort >= 5 API (`isort.code`) and the isort 4 API
            (`SortImports`). Sorting is best-effort: if isort fails or skips the
            cell, the cell is left unchanged.
            """
            if self.sort_imports and cell["cell_type"] == "code":
                import isort

                before = cell["source"]
                source = "".join(before)
                try:
                    if hasattr(isort, "code"):
                        text = isort.code(source)
                    else:
                        text = isort.SortImports(file_contents=source).output
                except Exception:
                    # Deliberate best-effort: keep the cell as-is on any isort error.
                    return cell
                # isort 4 reports a skipped file with `output = None`.
                if isinstance(text, str):
                    self._set_source(cell, before, text)
            return cell

        def verify_order(self, min=-1):
            """Raise `OutOfOrder` unless the cells were executed top to bottom.

            Cells whose execution count is missing or falsy are ignored. Every
            remaining count must be >= the previous one.

            :param min: lowest acceptable first execution count (name kept for callers).
            :raises OutOfOrder: carrying the notebook filename.
            """
            previous = min
            for cell in self.node.cells:
                count = cell.get("execution_count")
                if count:
                    if count < previous:
                        raise OutOfOrder(self.filename)
                    previous = count

        def fix(self, *args):
            """Clean every cell and rewrite the notebook file in place.

            Extra positional arguments are accepted and ignored. The file is always
            rewritten, even when nothing changed.

            :returns: True if any cell was modified or removed, otherwise False.
            """
            changed = False
            steps = (self.strip_output, self.normalize_execution, self.blacken, self.isort)
            for index, raw in enumerate(self.node.cells):
                cell = nbformat.from_dict(raw)
                # A shallow copy is enough: every step rebinds keys instead of
                # mutating their values in place.
                snapshot = cell.copy()
                for step in steps:
                    cell = step(cell)
                changed = changed or snapshot != cell
                self.node.cells[index] = cell

            if not self.keep_empty:
                kept = [
                    cell
                    for cell in self.node.cells
                    if "".join(cell.source).strip() or getattr(cell, "outputs", None)
                ]
                changed = changed or len(kept) != len(self.node.cells)
                self.node.cells = kept
            with self.filename.open("w", encoding="utf-8") as file:
                nbformat.write(self.node, file)
            return changed
