In [None]:
with __import__("importnb").Notebook():
    try:
        from . import schema
    except:
        import schema
import contextlib
import dataclasses
import enum
import io
import itertools
import json
import pathlib
import textwrap
import typing

import black
import click
import IPython
import nbformat
import pydantic
import ujson

import __main__


In [None]:
@click.command()
@click.argument("filenames", nargs=-1)
def main(filenames):
    # Entry point: clean, in place, every path in FILENAMES whose final
    # suffix is ``.ipynb``; any other path is skipped without a message.
    # (Kept as comments rather than a docstring: click would publish a
    # docstring as the command's ``--help`` text.)
    notebooks = (
        name for name in filenames if pathlib.Path(name).suffix == ".ipynb"
    )
    for notebook in notebooks:
        Notebook(notebook).fix()
    # click ignores the return value in standalone mode; 0 signals success
    # when ``main`` is invoked programmatically.
    return 0


In [None]:
# ``import IPython`` alone does not guarantee that the ``core.inputsplitter``
# submodule has been loaded, so import it explicitly before touching it.
# NOTE(review): ``IPythonInputSplitter`` is deprecated since IPython 7.0 in
# favour of ``IPython.core.inputtransformer2.TransformerManager``; switching
# would change the exact translated text, so it is left as-is for now.
import IPython.core.inputsplitter  # noqa: E402

# Shared translator used by ``Notebook.normalize``: its ``transform_cell``
# method rewrites IPython-specific syntax in cell source into plain Python.
translate = IPython.core.inputsplitter.IPythonInputSplitter()


In [None]:
@dataclasses.dataclass
class Notebook:
    """A Jupyter notebook file on disk that can be cleaned in place.

    ``fix`` strips outputs and execution counts, translates IPython syntax in
    code cells to plain Python, formats code cells with black and isort,
    drops blank cells, and writes the notebook back to ``filename``.
    """

    # Path to the ``.ipynb`` file; a ``str`` is accepted and coerced in
    # ``__post_init__``.
    filename: pathlib.Path
    # Parsed notebook JSON, filled from ``filename`` in ``__post_init__``.
    node: nbformat.NotebookNode = dataclasses.field(
        default_factory=nbformat.NotebookNode, repr=False
    )

    def __post_init__(self):
        """Coerce ``filename`` to a path and load the notebook JSON into ``node``."""
        self.filename = pathlib.Path(self.filename)
        # Notebook files are UTF-8 JSON; don't rely on the locale's encoding.
        with self.filename.open(encoding="utf-8") as file:
            self.node.update(ujson.load(file))

    @staticmethod
    def _match_source_shape(before, text):
        """Return *text* in the same representation as the original *before*.

        A cell's ``source`` may be a single string or a list of lines; when the
        original was a list, *text* is split back into lines with their line
        endings kept.
        """
        return text.splitlines(True) if isinstance(before, list) else text

    def strip_output(self, cell):
        """Replace a non-empty ``outputs`` list on *cell* with ``[]``; return the cell."""
        if cell.get("outputs"):
            cell.outputs = []
        return cell

    def normalize_execution(self, cell):
        """Reset a truthy ``execution_count`` on *cell* to ``None``; return the cell."""
        if cell.get("execution_count"):
            cell.execution_count = None
        return cell

    def normalize(self, cell):
        """Translate IPython-only syntax in a code cell into plain Python.

        Only code cells are transformed. Running the IPython input transformers
        over markdown/raw text would rewrite ordinary prose (for example a line
        starting with ``!`` or ``%``).
        """
        if cell.cell_type == "code":
            before = cell.source
            cell.source = self._match_source_shape(
                before, translate.transform_cell("".join(before))
            )
        return cell

    def blacken(self, cell):
        """Format a code cell with black at a line length of 88.

        A cell black cannot parse is left unchanged instead of aborting the
        whole notebook (and every file after it).
        """
        if cell.cell_type == "code":
            before = cell.source
            try:
                formatted = black.format_str(
                    "".join(before), mode=black.FileMode(line_length=88)
                )
            except ValueError:  # black.InvalidInput subclasses ValueError
                return cell
            cell.source = self._match_source_shape(before, formatted)
        return cell

    def isort(self, cell):
        """Sort the imports of a code cell with isort, best effort.

        Uses the isort >= 5 API (``isort.code``) when available and falls back
        to the isort 4 ``SortImports`` class otherwise. If isort raises or
        produces no output, the cell is left unchanged.
        """
        if cell.cell_type == "code":
            import isort

            before = cell.source
            text = "".join(before)
            try:
                if hasattr(isort, "code"):
                    sorted_text = isort.code(text)
                else:
                    sorted_text = isort.SortImports(file_contents=text).output
            except Exception:
                # Best effort: e.g. isort 5 raises on a skip-file comment.
                return cell
            # isort 4 can report ``None`` output (e.g. for skipped content).
            if sorted_text is not None:
                cell.source = self._match_source_shape(before, sorted_text)
        return cell

    def fix(self, *args):
        """Clean every cell, drop blank cells, and write the notebook back.

        Extra positional arguments are accepted and ignored.

        Returns:
            ``True`` if any cell was modified or removed, else ``False``. The
            file is rewritten in either case.
        """
        changed = False
        cleaned = []
        for raw in self.node.cells:
            cell = nbformat.from_dict(raw)
            # Shallow copy is enough: every step reassigns keys, never mutates values.
            snapshot = cell.copy()
            for step in (
                self.strip_output,
                self.normalize_execution,
                self.normalize,
                self.blacken,
                self.isort,
            ):
                cell = step(cell)
            changed = changed or snapshot != cell
            cleaned.append(cell)

        kept = [
            cell
            for cell in cleaned
            if "".join(cell.source).strip() or cell.get("outputs")
        ]
        changed = changed or len(kept) != len(cleaned)
        self.node.cells = kept

        # Serialize before opening the file for writing so a serialization
        # error cannot leave the notebook truncated.
        text = ujson.dumps(self.node)
        self.filename.write_text(text, encoding="utf-8")
        return changed
