Skip to content

Commit

Permalink
Basic migration of treenode module
Browse files Browse the repository at this point in the history
Moves treenode.py and test_treenode.py.
Updates some typing.
Updates imports from treenode.
  • Loading branch information
flamingbear committed Feb 14, 2024
1 parent 3bd8858 commit 8121b81
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 31 deletions.
4 changes: 2 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ def _open_datatree_netcdf(
**kwargs,
) -> DataTree:
from xarray.backends.api import open_dataset
from xarray.core.treenode import NodePath
from xarray.datatree_.datatree import DataTree
from xarray.datatree_.datatree.treenode import NodePath

ds = open_dataset(filename_or_obj, **kwargs)
tree_root = DataTree.from_dict({"/": ds})
Expand All @@ -159,7 +159,7 @@ def _open_datatree_netcdf(


def _iter_nc_groups(root, parent="/"):
from xarray.datatree_.datatree.treenode import NodePath
from xarray.core.treenode import NodePath

parent = NodePath(parent)
for path, group in root.groups.items():
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,8 +1048,8 @@ def open_datatree(
import zarr

from xarray.backends.api import open_dataset
from xarray.core.treenode import NodePath
from xarray.datatree_.datatree import DataTree
from xarray.datatree_.datatree.treenode import NodePath

zds = zarr.open_group(filename_or_obj, mode="r")
ds = open_dataset(filename_or_obj, engine="zarr", **kwargs)
Expand All @@ -1075,7 +1075,7 @@ def open_datatree(


def _iter_zarr_groups(root, parent="/"):
from xarray.datatree_.datatree.treenode import NodePath
from xarray.core.treenode import NodePath

parent = NodePath(parent)
for path, group in root.groups():
Expand Down
40 changes: 19 additions & 21 deletions xarray/datatree_/datatree/treenode.py → xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@

import sys
from collections import OrderedDict
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
from typing import (
TYPE_CHECKING,
Generic,
Iterator,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

from xarray.core.utils import Frozen, is_dict_like
Expand All @@ -24,6 +20,8 @@ class InvalidTreeError(Exception):
"""Raised when user attempts to create an invalid tree in some way."""


# TODO [MHS, 02/13/2024] I don't like the description here. It doesn't make
# sense on first glance.
class NotFoundInTreeError(ValueError):
"""Raised when operation can't be completed because one node is part of the expected tree."""

Expand Down Expand Up @@ -75,10 +73,10 @@ class TreeNode(Generic[Tree]):
(This class is heavily inspired by the anytree library's NodeMixin class.)
"""

_parent: Optional[Tree]
_parent: Tree | None
_children: OrderedDict[str, Tree]

def __init__(self, children: Optional[Mapping[str, Tree]] = None):
def __init__(self, children: Mapping[str, Tree] | None = None):
"""Create a parentless node."""
self._parent = None
self._children = OrderedDict()
Expand All @@ -91,7 +89,7 @@ def parent(self) -> Tree | None:
return self._parent

def _set_parent(
self, new_parent: Tree | None, child_name: Optional[str] = None
self, new_parent: Tree | None, child_name: str | None = None
) -> None:
# TODO is it possible to refactor in a way that removes this private method?

Expand Down Expand Up @@ -137,7 +135,7 @@ def _detach(self, parent: Tree | None) -> None:
self._parent = None
self._post_detach(parent)

def _attach(self, parent: Tree | None, child_name: Optional[str] = None) -> None:
def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
if parent is not None:
if child_name is None:
raise ValueError(
Expand Down Expand Up @@ -242,7 +240,7 @@ def _iter_parents(self: Tree) -> Iterator[Tree]:
yield node
node = node.parent

def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
def iter_lineage(self: Tree) -> tuple[Tree, ...]:
"""Iterate up the tree, starting from the current node."""
from warnings import warn

Expand All @@ -254,7 +252,7 @@ def iter_lineage(self: Tree) -> Tuple[Tree, ...]:
return tuple((self, *self.parents))

@property
def lineage(self: Tree) -> Tuple[Tree, ...]:
def lineage(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
from warnings import warn

Expand All @@ -266,12 +264,12 @@ def lineage(self: Tree) -> Tuple[Tree, ...]:
return self.iter_lineage()

@property
def parents(self: Tree) -> Tuple[Tree, ...]:
def parents(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
return tuple(self._iter_parents())

@property
def ancestors(self: Tree) -> Tuple[Tree, ...]:
def ancestors(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the most distant."""

from warnings import warn
Expand Down Expand Up @@ -306,7 +304,7 @@ def is_leaf(self) -> bool:
return self.children == {}

@property
def leaves(self: Tree) -> Tuple[Tree, ...]:
def leaves(self: Tree) -> tuple[Tree, ...]:
"""
All leaf nodes.
Expand Down Expand Up @@ -341,12 +339,12 @@ def subtree(self: Tree) -> Iterator[Tree]:
--------
DataTree.descendants
"""
from . import iterators
from xarray.datatree_.datatree import iterators

return iterators.PreOrderIter(self)

@property
def descendants(self: Tree) -> Tuple[Tree, ...]:
def descendants(self: Tree) -> tuple[Tree, ...]:
"""
Child nodes and all their child nodes.
Expand Down Expand Up @@ -431,7 +429,7 @@ def _post_attach(self: Tree, parent: Tree) -> None:
"""Method call after attaching to `parent`."""
pass

def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:
def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None:
"""
Return the child node with the specified key.
Expand All @@ -445,7 +443,7 @@ def get(self: Tree, key: str, default: Optional[Tree] = None) -> Optional[Tree]:

# TODO `._walk` method to be called by both `_get_item` and `_set_item`

def _get_item(self: Tree, path: str | NodePath) -> Union[Tree, T_DataArray]:
def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
"""
Returns the object lying at the given path.
Expand Down Expand Up @@ -488,7 +486,7 @@ def _set(self: Tree, key: str, val: Tree) -> None:
def _set_item(
self: Tree,
path: str | NodePath,
item: Union[Tree, T_DataArray],
item: Tree | T_DataArray,
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
) -> None:
Expand Down Expand Up @@ -580,8 +578,8 @@ class NamedNode(TreeNode, Generic[Tree]):
Implements path-like relationships to other nodes in its tree.
"""

_name: Optional[str]
_parent: Optional[Tree]
_name: str | None
_parent: Tree | None
_children: OrderedDict[str, Tree]

def __init__(self, name=None, children=None):
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .datatree import DataTree
from .extensions import register_datatree_accessor
from .mapping import TreeIsomorphismError, map_over_subtree
from .treenode import InvalidTreeError, NotFoundInTreeError
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError


__all__ = (
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
MappedDataWithCoords,
)
from .render import RenderTree
from .treenode import NamedNode, NodePath, Tree
from xarray.core.treenode import NamedNode, NodePath, Tree

try:
from xarray.core.variable import calculate_dimensions
Expand Down
2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import abc
from typing import Callable, Iterator, List, Optional

from .treenode import Tree
from xarray.core.treenode import Tree

"""These iterators are copied from anytree.iterators, with minor modifications."""

Expand Down
4 changes: 2 additions & 2 deletions xarray/datatree_/datatree/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from xarray import DataArray, Dataset

from .iterators import LevelOrderIter
from .treenode import NodePath, TreeNode
from xarray.core.treenode import NodePath, TreeNode

if TYPE_CHECKING:
from .datatree import DataTree
from xarray.core.datatree import DataTree


class TreeIsomorphismError(ValueError):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import pytest

from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode
from xarray.datatree_.datatree.iterators import LevelOrderIter, PreOrderIter
from xarray.datatree_.datatree.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode


class TestFamilyTree:
Expand Down

0 comments on commit 8121b81

Please sign in to comment.