diff --git a/docs/using.md b/docs/using.md index ec21b6bc..c07f32e7 100644 --- a/docs/using.md +++ b/docs/using.md @@ -265,7 +265,7 @@ Here's some text and an image ![title](image.png) > a *quote* """) -node = SyntaxTreeNode.from_tokens(tokens) +node = SyntaxTreeNode(tokens) print(node.pretty(indent=2, show_text=True)) ``` diff --git a/markdown_it/tree.py b/markdown_it/tree.py index 58398d9a..fac28dd5 100644 --- a/markdown_it/tree.py +++ b/markdown_it/tree.py @@ -12,7 +12,6 @@ Optional, Any, TypeVar, - Type, overload, Union, ) @@ -33,8 +32,7 @@ class SyntaxTreeNode: """A Markdown syntax tree node. A class that can be used to construct a tree representation of a linear - `markdown-it-py` token stream. Use `SyntaxTreeNode.from_tokens` to - initialize instead of the `__init__` method. + `markdown-it-py` token stream. Each node in the tree represents either: - root of the Markdown document @@ -43,10 +41,12 @@ class SyntaxTreeNode: between """ - def __init__(self) -> None: - """Initialize a root node with no children. + def __init__( + self, tokens: Sequence[Token] = (), *, create_root: bool = True + ) -> None: + """Initialize a `SyntaxTreeNode` from a token stream. - You probably need `SyntaxTreeNode.from_tokens` instead. + If `create_root` is True, create a root node for the document. """ # Only nodes representing an unnested token have self.token self.token: Optional[Token] = None @@ -61,6 +61,28 @@ def __init__(self) -> None: # children (i.e. inline or img) self._children: list = [] + if create_root: + self._set_children_from_tokens(tokens) + return + + if not tokens: + raise ValueError( + "Can only create root from empty token sequence." + " Set `create_root=True`." + ) + elif len(tokens) == 1: + inline_token = tokens[0] + if inline_token.nesting: + raise ValueError( + "Unequal nesting level at the start and end of token stream." + ) + self.token = inline_token + if inline_token.children: + self._set_children_from_tokens(inline_token.children) + else: + self.nester_tokens = _NesterTokens(tokens[0], tokens[-1]) + self._set_children_from_tokens(tokens[1:-1]) + def __repr__(self) -> str: return f"{type(self).__name__}({self.type})" @@ -77,16 +99,6 @@ def __getitem__( ) -> Union[_NodeType, List[_NodeType]]: return self.children[item] - @classmethod - def from_tokens(cls: Type[_NodeType], tokens: Sequence[Token]) -> _NodeType: - """Instantiate a `SyntaxTreeNode` from a token stream. - - This is the standard method for instantiating `SyntaxTreeNode`. - """ - root = cls() - root._set_children_from_tokens(tokens) - return root - def to_tokens(self: _NodeType) -> List[Token]: """Recover the linear token stream.""" @@ -186,23 +198,14 @@ def previous_sibling(self: _NodeType) -> Optional[_NodeType]: return self.siblings[self_index - 1] return None - def _make_child( - self: _NodeType, - *, - token: Optional[Token] = None, - nester_tokens: Optional[_NesterTokens] = None, - ) -> _NodeType: - """Make and return a child node for `self`.""" - if token and nester_tokens or not token and not nester_tokens: - raise ValueError("must specify either `token` or `nester_tokens`") - child = type(self)() - if token: - child.token = token - else: - child.nester_tokens = nester_tokens + def _add_child( + self, + tokens: Sequence[Token], + ) -> None: + """Make a child node for `self`.""" + child = type(self)(tokens, create_root=False) child.parent = self self.children.append(child) - return child def _set_children_from_tokens(self, tokens: Sequence[Token]) -> None: """Convert the token stream to a tree structure and set the resulting @@ -211,27 +214,22 @@ def _set_children_from_tokens(self, tokens: Sequence[Token]) -> None: while reversed_tokens: token = reversed_tokens.pop() - if token.nesting == 0: - child = self._make_child(token=token) - if token.children: - child._set_children_from_tokens(token.children) + if not token.nesting: + self._add_child([token]) continue - - assert token.nesting == 1 + if token.nesting != 1: + raise ValueError("Invalid token nesting") nested_tokens = [token] nesting = 1 - while reversed_tokens and nesting != 0: + while reversed_tokens and nesting: token = reversed_tokens.pop() nested_tokens.append(token) nesting += token.nesting - if nesting != 0: + if nesting: raise ValueError(f"unclosed tokens starting {nested_tokens[0]}") - child = self._make_child( - nester_tokens=_NesterTokens(nested_tokens[0], nested_tokens[-1]) - ) - child._set_children_from_tokens(nested_tokens[1:-1]) + self._add_child(nested_tokens) def pretty( self, *, indent: int = 2, show_text: bool = False, _current: int = 0 diff --git a/tests/test_tree.py b/tests/test_tree.py index bd04527e..6e0fd1ed 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -10,14 +10,14 @@ def test_tree_to_tokens_conversion(): tokens = MarkdownIt().parse(EXAMPLE_MARKDOWN) - tokens_after_roundtrip = SyntaxTreeNode.from_tokens(tokens).to_tokens() + tokens_after_roundtrip = SyntaxTreeNode(tokens).to_tokens() assert tokens == tokens_after_roundtrip def test_property_passthrough(): tokens = MarkdownIt().parse(EXAMPLE_MARKDOWN) heading_open = tokens[0] - tree = SyntaxTreeNode.from_tokens(tokens) + tree = SyntaxTreeNode(tokens) heading_node = tree.children[0] assert heading_open.tag == heading_node.tag assert tuple(heading_open.map) == heading_node.map @@ -32,7 +32,7 @@ def test_property_passthrough(): def test_type(): tokens = MarkdownIt().parse(EXAMPLE_MARKDOWN) - tree = SyntaxTreeNode.from_tokens(tokens) + tree = SyntaxTreeNode(tokens) # Root type is "root" assert tree.type == "root" # "_open" suffix must be stripped from nested token type @@ -44,7 +44,7 @@ def test_type(): def test_sibling_traverse(): tokens = MarkdownIt().parse(EXAMPLE_MARKDOWN) - tree = SyntaxTreeNode.from_tokens(tokens) + tree = SyntaxTreeNode(tokens) paragraph_inline_node = tree.children[1].children[0] text_node = paragraph_inline_node.children[0] assert text_node.type == "text" @@ -70,5 +70,5 @@ def test_pretty(file_regression): > a *quote* """ ) - node = SyntaxTreeNode.from_tokens(tokens) + node = SyntaxTreeNode(tokens) file_regression.check(node.pretty(indent=2, show_text=True), extension=".xml")