diff --git a/xarray/tests/datatree/test_treenode.py b/xarray/tests/datatree/test_treenode.py index 245e8fab4fc..4789a9be6e3 100644 --- a/xarray/tests/datatree/test_treenode.py +++ b/xarray/tests/datatree/test_treenode.py @@ -1,5 +1,8 @@ from __future__ import annotations +from collections.abc import Iterator +from typing import cast + import pytest from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode @@ -8,20 +11,20 @@ class TestFamilyTree: def test_lonely(self): - root = TreeNode() + root: TreeNode = TreeNode() assert root.parent is None assert root.children == {} def test_parenting(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") assert mary.parent == john assert john.children["Mary"] is mary def test_no_time_traveller_loops(self): - john = TreeNode() + john: TreeNode = TreeNode() with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): john._set_parent(john, "John") @@ -29,8 +32,8 @@ def test_no_time_traveller_loops(self): with pytest.raises(InvalidTreeError, match="cannot be a parent of itself"): john.children = {"John": john} - mary = TreeNode() - rose = TreeNode() + mary: TreeNode = TreeNode() + rose: TreeNode = TreeNode() mary._set_parent(john, "Mary") rose._set_parent(mary, "Rose") @@ -41,11 +44,11 @@ def test_no_time_traveller_loops(self): rose.children = {"John": john} def test_parent_swap(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() mary._set_parent(john, "Mary") - steve = TreeNode() + steve: TreeNode = TreeNode() mary._set_parent(steve, "Mary") assert mary.parent == steve @@ -53,24 +56,24 @@ def test_parent_swap(self): assert "Mary" not in john.children def test_multi_child_family(self): - mary = TreeNode() - kate = TreeNode() - john = TreeNode(children={"Mary": mary, "Kate": kate}) + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary, "Kate": kate}) assert john.children["Mary"] is mary assert john.children["Kate"] is kate assert mary.parent is john assert kate.parent is john def test_disown_child(self): - mary = TreeNode() - john = TreeNode(children={"Mary": mary}) + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) mary.orphan() assert mary.parent is None assert "Mary" not in john.children def test_doppelganger_child(self): - kate = TreeNode() - john = TreeNode() + kate: TreeNode = TreeNode() + john: TreeNode = TreeNode() with pytest.raises(TypeError): john.children = {"Kate": 666} @@ -79,22 +82,22 @@ def test_doppelganger_child(self): john.children = {"Kate": kate, "Evil_Kate": kate} john = TreeNode(children={"Kate": kate}) - evil_kate = TreeNode() + evil_kate: TreeNode = TreeNode() evil_kate._set_parent(john, "Kate") assert john.children["Kate"] is evil_kate def test_sibling_relationships(self): - mary = TreeNode() - kate = TreeNode() - ashley = TreeNode() + mary: TreeNode = TreeNode() + kate: TreeNode = TreeNode() + ashley: TreeNode = TreeNode() TreeNode(children={"Mary": mary, "Kate": kate, "Ashley": ashley}) assert kate.siblings["Mary"] is mary assert kate.siblings["Ashley"] is ashley assert "Kate" not in kate.siblings def test_ancestors(self): - tony = TreeNode() - michael = TreeNode(children={"Tony": tony}) + tony: TreeNode = TreeNode() + michael: TreeNode = TreeNode(children={"Tony": tony}) vito = TreeNode(children={"Michael": michael}) assert tony.root is vito assert tony.parents == (michael, vito) @@ -103,7 +106,7 @@ def test_ancestors(self): class TestGetNodes: def test_get_child(self): - steven = TreeNode() + steven: TreeNode = TreeNode() sue = TreeNode(children={"Steven": steven}) mary = TreeNode(children={"Sue": sue}) john = TreeNode(children={"Mary": mary}) @@ -126,8 +129,8 @@ def test_get_child(self): assert mary._get_item("Sue/Steven") is steven def test_get_upwards(self): - sue = TreeNode() - kate = TreeNode() + sue: TreeNode = TreeNode() + kate: TreeNode = TreeNode() mary = TreeNode(children={"Sue": sue, "Kate": kate}) john = TreeNode(children={"Mary": mary}) @@ -138,7 +141,7 @@ def test_get_upwards(self): assert sue._get_item("../Kate") is kate def test_get_from_root(self): - sue = TreeNode() + sue: TreeNode = TreeNode() mary = TreeNode(children={"Sue": sue}) john = TreeNode(children={"Mary": mary}) # noqa @@ -147,8 +150,8 @@ def test_get_from_root(self): class TestSetNodes: def test_set_child_node(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) assert john.children["Mary"] is mary @@ -157,16 +160,16 @@ def test_set_child_node(self): assert mary.parent is john def test_child_already_exists(self): - mary = TreeNode() - john = TreeNode(children={"Mary": mary}) - mary_2 = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode(children={"Mary": mary}) + mary_2: TreeNode = TreeNode() with pytest.raises(KeyError): john._set_item("Mary", mary_2, allow_overwrite=False) def test_set_grandchild(self): - rose = TreeNode() - mary = TreeNode() - john = TreeNode() + rose: TreeNode = TreeNode() + mary: TreeNode = TreeNode() + john: TreeNode = TreeNode() john._set_item("Mary", mary) john._set_item("Mary/Rose", rose) @@ -177,8 +180,8 @@ def test_set_grandchild(self): assert rose.parent is mary def test_create_intermediate_child(self): - john = TreeNode() - rose = TreeNode() + john: TreeNode = TreeNode() + rose: TreeNode = TreeNode() # test intermediate children not allowed with pytest.raises(KeyError, match="Could not reach"): @@ -194,12 +197,12 @@ def test_create_intermediate_child(self): assert rose.parent == mary def test_overwrite_child(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) # test overwriting not allowed - marys_evil_twin = TreeNode() + marys_evil_twin: TreeNode = TreeNode() with pytest.raises(KeyError, match="Already a node object"): john._set_item("Mary", marys_evil_twin, allow_overwrite=False) assert john.children["Mary"] is mary @@ -214,8 +217,8 @@ def test_overwrite_child(self): class TestPruning: def test_del_child(self): - john = TreeNode() - mary = TreeNode() + john: TreeNode = TreeNode() + mary: TreeNode = TreeNode() john._set_item("Mary", mary) del john["Mary"] @@ -226,16 +229,16 @@ def test_del_child(self): del john["Mary"] -def create_test_tree(): - a = NamedNode(name="a") - b = NamedNode() - c = NamedNode() - d = NamedNode() - e = NamedNode() - f = NamedNode() - g = NamedNode() - h = NamedNode() - i = NamedNode() +def create_test_tree() -> tuple[NamedNode, NamedNode]: + a: NamedNode = NamedNode(name="a") + b: NamedNode = NamedNode() + c: NamedNode = NamedNode() + d: NamedNode = NamedNode() + e: NamedNode = NamedNode() + f: NamedNode = NamedNode() + g: NamedNode = NamedNode() + h: NamedNode = NamedNode() + i: NamedNode = NamedNode() a.children = {"b": b, "c": c} b.children = {"d": d, "e": e} @@ -249,7 +252,9 @@ def create_test_tree(): class TestIterators: def test_preorderiter(self): root, _ = create_test_tree() - result = [node.name for node in PreOrderIter(root)] + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], PreOrderIter(root)) + ] expected = [ "a", "b", @@ -265,7 +270,9 @@ def test_preorderiter(self): def test_levelorderiter(self): root, _ = create_test_tree() - result = [node.name for node in LevelOrderIter(root)] + result: list[str | None] = [ + node.name for node in cast(Iterator[NamedNode], LevelOrderIter(root)) + ] expected = [ "a", # root Node is unnamed "b", @@ -358,11 +365,11 @@ def test_levels(self): class TestRenderTree: def test_render_nodetree(self): - sam = NamedNode() - ben = NamedNode() - mary = NamedNode(children={"Sam": sam, "Ben": ben}) - kate = NamedNode() - john = NamedNode(children={"Mary": mary, "Kate": kate}) + sam: NamedNode = NamedNode() + ben: NamedNode = NamedNode() + mary: NamedNode = NamedNode(children={"Sam": sam, "Ben": ben}) + kate: NamedNode = NamedNode() + john: NamedNode = NamedNode(children={"Mary": mary, "Kate": kate}) printout = john.__str__() expected_nodes = [