Skip to content

Commit

Permalink
better typing
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jan 13, 2023
1 parent 5fe0a91 commit c53a9af
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions loopy/schedule/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,31 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from immutables import Map
from typing import Generic, Hashable, Tuple, TypeVar, Iterator, Optional, List
from dataclasses import dataclass

# {{{ tree data structure

T = TypeVar("T")
NodeT = TypeVar("NodeT", bound=Hashable)


@dataclass(frozen=True)
class Tree(Generic[T]):
class Tree(Generic[NodeT]):
"""
An immutable tree implementation.
An immutable n-ary tree containing nodes of type :class:`NodeT`.
.. automethod:: ancestors
.. automethod:: parent
.. automethod:: children
.. automethod:: create_node
.. automethod:: add_node
.. automethod:: depth
.. automethod:: rename_node
.. automethod:: replace_node
.. automethod:: move_node
.. note::
Almost all the operations are implemented recursively. NOT suitable for
deep trees. At the very least if the Python implementation is CPython
this allocates a new stack frame for each iteration of the operation.
Expand All @@ -49,7 +57,7 @@ def from_root(root: T):
Map({root: None}))

@property
def root(self) -> T:
def root(self) -> NodeT:
guess = set(self._child_to_parent).pop()
parent_of_guess = self.parent(guess)
while parent_of_guess is not None:
Expand All @@ -74,7 +82,10 @@ def ancestors(self, node: T) -> FrozenSet[T]:

return frozenset([parent]) | self.ancestors(parent)

def parent(self, node: T) -> OptionalT[T]:
def parent(self, node: NodeT) -> Optional[NodeT]:
"""
Returns the parent of *node*.
"""
if not self.is_a_node(node):
raise ValueError(f"'{node}' not in tree.")

Expand All @@ -86,7 +97,10 @@ def children(self, node: T) -> FrozenSet[T]:

return self._parent_to_children[node]

def depth(self, node: T) -> int:
def depth(self, node: NodeT) -> int:
"""
Returns the depth of *node*.
"""
if not self.is_a_node(node):
raise ValueError(f"'{node}' not in tree.")

Expand All @@ -99,22 +113,22 @@ def depth(self, node: T) -> int:

return 1 + self.depth(parent_of_node)

def is_root(self, node: T) -> bool:
def is_root(self, node: NodeT) -> bool:
if not self.is_a_node(node):
raise ValueError(f"'{node}' not in tree.")

return self.parent(node) is None

def is_leaf(self, node: T) -> bool:
def is_leaf(self, node: NodeT) -> bool:
if not self.is_a_node(node):
raise ValueError(f"'{node}' not in tree.")

return len(self.children(node)) == 0

def is_a_node(self, node: T) -> bool:
def is_a_node(self, node: NodeT) -> bool:
return node in self._child_to_parent

def add_node(self, node: T, parent: T) -> "Tree[T]":
def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]":
"""
Returns a :class:`Tree` with added node *node* having a parent
*parent*.
Expand All @@ -129,9 +143,9 @@ def add_node(self, node: T, parent: T) -> "Tree[T]":
.set(node, frozenset())),
self._child_to_parent.set(node, parent))

def rename_node(self, node: T, new_id: T) -> "Tree[T]":
def replace_node(self, node: NodeT, new_id: NodeT) -> "Tree[NodeT]":
"""
Returns a copy of *self* with *node* renamed to *new_id*.
Returns a copy of *self* with *node* replaced with *new_id*.
"""
if not self.is_a_node(node):
raise ValueError(f"'{node}' not present in tree.")
Expand Down Expand Up @@ -173,7 +187,7 @@ def rename_node(self, node: T, new_id: T) -> "Tree[T]":
return Tree(new_parent_to_children,
new_child_to_parent)

def move_node(self, node: T, new_parent: OptionalT[T]) -> "Tree[T]":
def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]":
"""
Returns a copy of *self* with node *node* as a child of *new_parent*.
"""
Expand Down Expand Up @@ -228,7 +242,7 @@ def __str__(self) -> str:
├── D
└── E
"""
def rec(node):
def rec(node: NodeT) -> List[str]:
children_result = [rec(c) for c in self.children(node)]

def post_process_non_last_child(child):
Expand All @@ -245,7 +259,7 @@ def post_process_last_child(child):

return "\n".join(rec(self.root))

def nodes(self) -> Iterator[T]:
def nodes(self) -> Iterator[NodeT]:
return iter(self._child_to_parent.keys())

# }}}

0 comments on commit c53a9af

Please sign in to comment.