diff --git a/mdformat/_cli.py b/mdformat/_cli.py index db9a8dcc..89b63b40 100644 --- a/mdformat/_cli.py +++ b/mdformat/_cli.py @@ -53,7 +53,8 @@ def run(cli_args: Sequence[str]) -> int: # noqa: C901 if not is_md_equal( original_str, formatted_str, - ignore_codeclasses=enabled_codeformatter_langs, + enabled_extensions=enabled_parserplugins, + enabled_codeformatters=enabled_codeformatter_langs, ): sys.stderr.write( f'Error: Could not format "{path_str}"\n' diff --git a/mdformat/_util.py b/mdformat/_util.py index cb900075..893ad683 100644 --- a/mdformat/_util.py +++ b/mdformat/_util.py @@ -1,21 +1,80 @@ -import re -from typing import Iterable +from html.parser import HTMLParser +from typing import Iterable, List, Optional, Set from markdown_it import MarkdownIt +import mdformat.plugins -def is_md_equal(md1: str, md2: str, *, ignore_codeclasses: Iterable[str] = ()) -> bool: + +def is_md_equal( + md1: str, + md2: str, + *, + enabled_extensions: Iterable[str] = (), + enabled_codeformatters: Iterable[str] = (), +) -> bool: """Check if two Markdown produce the same HTML. - Renders HTML from both Markdown strings, strips whitespace and - checks equality. Note that this is not a perfect solution, as there - can be meaningful whitespace in HTML, e.g. in a block. + Renders HTML from both Markdown strings, strip content of tags with + specified classes, and checks equality of the generated ASTs. """ - html1 = MarkdownIt().render(md1) - html2 = MarkdownIt().render(md2) - html1 = re.sub(r"\s+", "", html1) - html2 = re.sub(r"\s+", "", html2) - for codeclass in ignore_codeclasses: - html1 = re.sub(rf'.*', "", html1) - html2 = re.sub(rf'.*', "", html2) + ignore_classes = [f"language-{lang}" for lang in enabled_codeformatters] + mdit = MarkdownIt() + for name in enabled_extensions: + plugin = mdformat.plugins.PARSER_EXTENSIONS[name] + plugin.update_mdit(mdit) + ignore_classes.extend(getattr(plugin, "ignore_classes", [])) + html1 = HTML2AST().parse(mdit.render(md1), ignore_classes) + html2 = HTML2AST().parse(mdit.render(md2), ignore_classes) + return html1 == html2 + + +class HTML2AST(HTMLParser): + """Parser HTML to AST.""" + + def parse(self, text: str, strip_classes: Iterable[str] = ()) -> List[dict]: + self.tree: List[dict] = [] + self.current: Optional[dict] = None + self.feed(text) + self.strip_classes(self.tree, set(strip_classes)) + return self.tree + + def strip_classes(self, tree: List[dict], classes: Set[str]) -> List[dict]: + """Strip content from tags with certain classes.""" + items = [] + for item in tree: + if set(item["attrs"].get("class", "").split()).intersection(classes): + items.append({"tag": item["tag"], "attrs": item["attrs"]}) + continue + items.append(item) + item["children"] = self.strip_classes(item.get("children", []), classes) + if not item["children"]: + item.pop("children") + + return items + + def handle_starttag(self, tag: str, attrs: list) -> None: + tag_item = {"tag": tag, "attrs": dict(attrs), "parent": self.current} + if self.current is None: + self.tree.append(tag_item) + else: + children = self.current.setdefault("children", []) + children.append(tag_item) + self.current = tag_item + + def handle_endtag(self, tag: str) -> None: + # walk up the tree to the tag's parent + while self.current is not None: + if self.current["tag"] == tag: + self.current = self.current.pop("parent") + break + self.current = self.current.pop("parent") + + def handle_data(self, data: str) -> None: + # ignore data outside tabs + if self.current is not None: + # ignore empty lines and trailing whitespace + self.current["data"] = [ + li.rstrip() for li in data.splitlines() if li.strip() + ] diff --git a/mdformat/plugins.py b/mdformat/plugins.py index 3e5092b1..1fb56e49 100644 --- a/mdformat/plugins.py +++ b/mdformat/plugins.py @@ -46,6 +46,11 @@ def render_token( """ return None + def ignore_classes(self) -> List[str]: + """Return CSS classes to ignore when comparing the input/output HTML + equality.""" + return [] + def _load_parser_extensions() -> Dict[str, ParserExtensionInterface]: parser_extension_entrypoints = importlib_metadata.entry_points().get( diff --git a/mdformat/renderer/_token_renderers.py b/mdformat/renderer/_token_renderers.py index fe0d2e2f..36b0d12b 100644 --- a/mdformat/renderer/_token_renderers.py +++ b/mdformat/renderer/_token_renderers.py @@ -1,5 +1,6 @@ """A namespace for functions that render the Markdown of tokens from markdown- it-py.""" +import logging from typing import List, Optional from markdown_it.token import Token @@ -12,6 +13,8 @@ longest_consecutive_sequence, ) +LOGGER = logging.getLogger(__name__) + def default(tokens: List[Token], idx: int, options: dict, env: dict) -> str: """Default formatter for tokens that don't have one implemented.""" @@ -101,10 +104,10 @@ def fence(tokens: List[Token], idx: int, options: dict, env: dict) -> str: fmt_func = options["codeformatters"][lang] try: code_block = fmt_func(code_block, info_str) - except Exception: + except Exception as err: # Swallow exceptions so that formatter errors (e.g. due to # invalid code) do not crash mdformat. - pass + LOGGER.warning(f"Code formatting of '{lang}' failed: {err}") # The code block must not include as long or longer sequence of `fence_char`s # as the fence string itself diff --git a/tests/test_cli.py b/tests/test_cli.py index ea8c0842..9cdede3e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,6 +2,7 @@ import sys from mdformat._cli import run +from mdformat.plugins import CODEFORMATTERS UNFORMATTED_MARKDOWN = "\n\n# A header\n\n" FORMATTED_MARKDOWN = "# A header\n" @@ -63,6 +64,18 @@ def test_check__multi_fail(capsys, tmp_path): assert str(file_path2) in captured.err +def example_formatter(code, info): + return "dummy\n" + + +def test_formatter_plugin(tmp_path, monkeypatch): + monkeypatch.setitem(CODEFORMATTERS, "lang", example_formatter) + file_path = tmp_path / "test_markdown.md" + file_path.write_text("```lang\nother\n```\n") + assert run((str(file_path),)) == 0 + assert file_path.read_text() == "```lang\ndummy\n```\n" + + def test_dash_stdin(capsys, monkeypatch): monkeypatch.setattr(sys, "stdin", StringIO(UNFORMATTED_MARKDOWN)) run(("-",)) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index f17572d2..0226b180 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -7,10 +7,35 @@ import yaml import mdformat -from mdformat.plugins import PARSER_EXTENSIONS +from mdformat.plugins import CODEFORMATTERS, PARSER_EXTENSIONS from mdformat.renderer import MARKERS, MDRenderer +def example_formatter(code, info): + return "dummy\n" + + +def test_code_formatter(monkeypatch): + monkeypatch.setitem(CODEFORMATTERS, "lang", example_formatter) + text = mdformat.text( + dedent( + """\ + ```lang + a + ``` + """ + ), + codeformatters={"lang"}, + ) + assert text == dedent( + """\ + ```lang + dummy + ``` + """ + ) + + class ExampleFrontMatterPlugin: """A class for extending the base parser.""" diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..ed058a39 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,56 @@ +from mdformat._util import HTML2AST + + +def test_html2ast(): + data = HTML2AST().parse('

aj

b') + assert data == [ + { + "tag": "div", + "attrs": {}, + "children": [ + { + "tag": "p", + "attrs": {"class": "x"}, + "data": ["a"], + "children": [{"tag": "s", "attrs": {}, "data": ["j"]}], + } + ], + }, + {"tag": "a", "attrs": {}, "data": ["b"]}, + ] + + +def test_html2ast_multiline(): + data = HTML2AST().parse("
a\nb \nc \n\n
") + assert data == [{"tag": "div", "attrs": {}, "data": ["a", "b", "c"]}] + + +def test_html2ast_nested(): + data = HTML2AST().parse("bce") + assert data == [ + { + "tag": "a", + "attrs": {"d": "1"}, + "data": ["b"], + "children": [ + { + "tag": "a", + "attrs": {"d": "2"}, + "data": ["c"], + "children": [{"tag": "a", "attrs": {"d": "3"}, "data": ["e"]}], + } + ], + } + ] + + +def test_html2ast_strip(): + data = HTML2AST().parse('

aj

b', {"x"}) + assert data == [ + { + "tag": "div", + "attrs": {}, + "children": [{"tag": "p", "attrs": {"class": "x y"}}], + }, + {"tag": "a", "attrs": {}, "data": ["b"]}, + ]