Skip to content
82 changes: 60 additions & 22 deletions markdown_it/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,32 @@
This module is not part of upstream JavaScript markdown-it.
"""
import textwrap
from typing import NamedTuple, Sequence, Tuple, Dict, List, Optional, Any
from typing import (
NamedTuple,
Sequence,
Tuple,
Dict,
List,
Optional,
Any,
TypeVar,
Type,
overload,
Union,
)

from .token import Token
from .utils import _removesuffix


class _NesterTokens(NamedTuple):
opening: Token
closing: Token


_NodeType = TypeVar("_NodeType", bound="SyntaxTreeNode")


class SyntaxTreeNode:
"""A Markdown syntax tree node.

Expand All @@ -23,10 +43,6 @@ class SyntaxTreeNode:
between
"""

class _NesterTokens(NamedTuple):
opening: Token
closing: Token

def __init__(self) -> None:
"""Initialize a root node with no children.

Expand All @@ -36,23 +52,33 @@ def __init__(self) -> None:
self.token: Optional[Token] = None

# Only containers have nester tokens
self.nester_tokens: Optional[SyntaxTreeNode._NesterTokens] = None
self.nester_tokens: Optional[_NesterTokens] = None

# Root node does not have self.parent
self.parent: Optional["SyntaxTreeNode"] = None
self._parent: Any = None

# Empty list unless a non-empty container, or unnested token that has
# children (i.e. inline or img)
self.children: List["SyntaxTreeNode"] = []
self._children: list = []

def __repr__(self) -> str:
return f"{type(self).__name__}({self.type})"

def __getitem__(self, item: int) -> "SyntaxTreeNode":
@overload
def __getitem__(self: _NodeType, item: int) -> _NodeType:
...

@overload
def __getitem__(self: _NodeType, item: slice) -> List[_NodeType]:
...

def __getitem__(
self: _NodeType, item: Union[int, slice]
) -> Union[_NodeType, List[_NodeType]]:
return self.children[item]

@classmethod
def from_tokens(cls, tokens: Sequence[Token]) -> "SyntaxTreeNode":
def from_tokens(cls: Type[_NodeType], tokens: Sequence[Token]) -> _NodeType:
"""Instantiate a `SyntaxTreeNode` from a token stream.

This is the standard method for instantiating `SyntaxTreeNode`.
Expand All @@ -61,12 +87,10 @@ def from_tokens(cls, tokens: Sequence[Token]) -> "SyntaxTreeNode":
root._set_children_from_tokens(tokens)
return root

def to_tokens(self) -> List[Token]:
def to_tokens(self: _NodeType) -> List[Token]:
"""Recover the linear token stream."""

def recursive_collect_tokens(
node: "SyntaxTreeNode", token_list: List[Token]
) -> None:
def recursive_collect_tokens(node: _NodeType, token_list: List[Token]) -> None:
if node.type == "root":
for child in node.children:
recursive_collect_tokens(child, token_list)
Expand All @@ -83,6 +107,22 @@ def recursive_collect_tokens(
recursive_collect_tokens(self, tokens)
return tokens

@property
def children(self: _NodeType) -> List[_NodeType]:
return self._children

@children.setter
def children(self: _NodeType, value: List[_NodeType]) -> None:
self._children = value

@property
def parent(self: _NodeType) -> Optional[_NodeType]:
return self._parent

@parent.setter
def parent(self: _NodeType, value: Optional[_NodeType]) -> None:
self._parent = value

@property
def is_root(self) -> bool:
"""Is the node a special root node?"""
Expand All @@ -99,7 +139,7 @@ def is_nested(self) -> bool:
return bool(self.nester_tokens)

@property
def siblings(self) -> Sequence["SyntaxTreeNode"]:
def siblings(self: _NodeType) -> Sequence[_NodeType]:
"""Get siblings of the node.

Gets the whole group of siblings, including self.
Expand All @@ -125,7 +165,7 @@ def type(self) -> str:
return _removesuffix(self.nester_tokens.opening.type, "_open")

@property
def next_sibling(self) -> Optional["SyntaxTreeNode"]:
def next_sibling(self: _NodeType) -> Optional[_NodeType]:
"""Get the next node in the sequence of siblings.

Returns `None` if this is the last sibling.
Expand All @@ -136,7 +176,7 @@ def next_sibling(self) -> Optional["SyntaxTreeNode"]:
return None

@property
def previous_sibling(self) -> Optional["SyntaxTreeNode"]:
def previous_sibling(self: _NodeType) -> Optional[_NodeType]:
"""Get the previous node in the sequence of siblings.

Returns `None` if this is the first sibling.
Expand All @@ -147,11 +187,11 @@ def previous_sibling(self) -> Optional["SyntaxTreeNode"]:
return None

def _make_child(
self,
self: _NodeType,
*,
token: Optional[Token] = None,
nester_tokens: Optional[_NesterTokens] = None,
) -> "SyntaxTreeNode":
) -> _NodeType:
"""Make and return a child node for `self`."""
if token and nester_tokens or not token and not nester_tokens:
raise ValueError("must specify either `token` or `nester_tokens`")
Expand Down Expand Up @@ -189,9 +229,7 @@ def _set_children_from_tokens(self, tokens: Sequence[Token]) -> None:
raise ValueError(f"unclosed tokens starting {nested_tokens[0]}")

child = self._make_child(
nester_tokens=SyntaxTreeNode._NesterTokens(
nested_tokens[0], nested_tokens[-1]
)
nester_tokens=_NesterTokens(nested_tokens[0], nested_tokens[-1])
)
child._set_children_from_tokens(nested_tokens[1:-1])

Expand Down