diff --git a/markdown_it/tree.py b/markdown_it/tree.py index dcc46132..58398d9a 100644 --- a/markdown_it/tree.py +++ b/markdown_it/tree.py @@ -3,12 +3,32 @@ This module is not part of upstream JavaScript markdown-it. """ import textwrap -from typing import NamedTuple, Sequence, Tuple, Dict, List, Optional, Any +from typing import ( + NamedTuple, + Sequence, + Tuple, + Dict, + List, + Optional, + Any, + TypeVar, + Type, + overload, + Union, +) from .token import Token from .utils import _removesuffix +class _NesterTokens(NamedTuple): + opening: Token + closing: Token + + +_NodeType = TypeVar("_NodeType", bound="SyntaxTreeNode") + + class SyntaxTreeNode: """A Markdown syntax tree node. @@ -23,10 +43,6 @@ class SyntaxTreeNode: between """ - class _NesterTokens(NamedTuple): - opening: Token - closing: Token - def __init__(self) -> None: """Initialize a root node with no children. @@ -36,23 +52,33 @@ def __init__(self) -> None: self.token: Optional[Token] = None # Only containers have nester tokens - self.nester_tokens: Optional[SyntaxTreeNode._NesterTokens] = None + self.nester_tokens: Optional[_NesterTokens] = None # Root node does not have self.parent - self.parent: Optional["SyntaxTreeNode"] = None + self._parent: Any = None # Empty list unless a non-empty container, or unnested token that has # children (i.e. inline or img) - self.children: List["SyntaxTreeNode"] = [] + self._children: list = [] def __repr__(self) -> str: return f"{type(self).__name__}({self.type})" - def __getitem__(self, item: int) -> "SyntaxTreeNode": + @overload + def __getitem__(self: _NodeType, item: int) -> _NodeType: + ... + + @overload + def __getitem__(self: _NodeType, item: slice) -> List[_NodeType]: + ... + + def __getitem__( + self: _NodeType, item: Union[int, slice] + ) -> Union[_NodeType, List[_NodeType]]: return self.children[item] @classmethod - def from_tokens(cls, tokens: Sequence[Token]) -> "SyntaxTreeNode": + def from_tokens(cls: Type[_NodeType], tokens: Sequence[Token]) -> _NodeType: """Instantiate a `SyntaxTreeNode` from a token stream. This is the standard method for instantiating `SyntaxTreeNode`. @@ -61,12 +87,10 @@ def from_tokens(cls, tokens: Sequence[Token]) -> "SyntaxTreeNode": root._set_children_from_tokens(tokens) return root - def to_tokens(self) -> List[Token]: + def to_tokens(self: _NodeType) -> List[Token]: """Recover the linear token stream.""" - def recursive_collect_tokens( - node: "SyntaxTreeNode", token_list: List[Token] - ) -> None: + def recursive_collect_tokens(node: _NodeType, token_list: List[Token]) -> None: if node.type == "root": for child in node.children: recursive_collect_tokens(child, token_list) @@ -83,6 +107,22 @@ def recursive_collect_tokens( recursive_collect_tokens(self, tokens) return tokens + @property + def children(self: _NodeType) -> List[_NodeType]: + return self._children + + @children.setter + def children(self: _NodeType, value: List[_NodeType]) -> None: + self._children = value + + @property + def parent(self: _NodeType) -> Optional[_NodeType]: + return self._parent + + @parent.setter + def parent(self: _NodeType, value: Optional[_NodeType]) -> None: + self._parent = value + @property def is_root(self) -> bool: """Is the node a special root node?""" @@ -99,7 +139,7 @@ def is_nested(self) -> bool: return bool(self.nester_tokens) @property - def siblings(self) -> Sequence["SyntaxTreeNode"]: + def siblings(self: _NodeType) -> Sequence[_NodeType]: """Get siblings of the node. Gets the whole group of siblings, including self. @@ -125,7 +165,7 @@ def type(self) -> str: return _removesuffix(self.nester_tokens.opening.type, "_open") @property - def next_sibling(self) -> Optional["SyntaxTreeNode"]: + def next_sibling(self: _NodeType) -> Optional[_NodeType]: """Get the next node in the sequence of siblings. Returns `None` if this is the last sibling. @@ -136,7 +176,7 @@ def next_sibling(self) -> Optional["SyntaxTreeNode"]: return None @property - def previous_sibling(self) -> Optional["SyntaxTreeNode"]: + def previous_sibling(self: _NodeType) -> Optional[_NodeType]: """Get the previous node in the sequence of siblings. Returns `None` if this is the first sibling. @@ -147,11 +187,11 @@ def previous_sibling(self) -> Optional["SyntaxTreeNode"]: return None def _make_child( - self, + self: _NodeType, *, token: Optional[Token] = None, nester_tokens: Optional[_NesterTokens] = None, - ) -> "SyntaxTreeNode": + ) -> _NodeType: """Make and return a child node for `self`.""" if token and nester_tokens or not token and not nester_tokens: raise ValueError("must specify either `token` or `nester_tokens`") @@ -189,9 +229,7 @@ def _set_children_from_tokens(self, tokens: Sequence[Token]) -> None: raise ValueError(f"unclosed tokens starting {nested_tokens[0]}") child = self._make_child( - nester_tokens=SyntaxTreeNode._NesterTokens( - nested_tokens[0], nested_tokens[-1] - ) + nester_tokens=_NesterTokens(nested_tokens[0], nested_tokens[-1]) ) child._set_children_from_tokens(nested_tokens[1:-1])