Skip to content

Commit

Permalink
Use typevars rather than specific unions
Browse files Browse the repository at this point in the history
Also from the chars of pydata#8208 -- this uses the TypeVars we define, which hopefully sets a standard and removes ambiguity for how to type functions. It also allows subclasses.

Where possible, it uses `T_Xarray`, otherwise `T_DataArrayOrSet` where it's not possible for `mypy` to narrow the concrete type down.
  • Loading branch information
max-sixty committed Sep 19, 2023
1 parent 2b444af commit 504c8f1
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 23 deletions.
6 changes: 2 additions & 4 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
indexes_all_equal,
safe_cast_to_index,
)
from xarray.core.types import T_Alignable
from xarray.core.types import T_Alignable, T_Xarray
from xarray.core.utils import is_dict_like, is_full_slice
from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions

if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray


Expand Down Expand Up @@ -903,7 +901,7 @@ def reindex(

def reindex_like(
obj: T_Alignable,
other: Dataset | DataArray,
other: T_Xarray,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
Expand Down
6 changes: 2 additions & 4 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
create_default_index_implicit,
)
from xarray.core.merge import merge_coordinates_without_align, merge_coords
from xarray.core.types import Self, T_DataArray
from xarray.core.types import Self, T_DataArray, T_Xarray
from xarray.core.utils import (
Frozen,
ReprObject,
Expand Down Expand Up @@ -915,9 +915,7 @@ def drop_indexed_coords(
return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes)


def assert_coordinate_consistent(
obj: T_DataArray | Dataset, coords: Mapping[Any, Variable]
) -> None:
def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.
obj: DataArray or Dataset
Expand Down
7 changes: 4 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import PANDAS_TYPES, MergeError
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import T_DataArrayOrSet
from xarray.core.utils import (
Default,
HybridMappingProxy,
Expand Down Expand Up @@ -1844,7 +1845,7 @@ def _reindex_callback(

def reindex_like(
self: T_DataArray,
other: DataArray | Dataset,
other: T_Xarray,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
Expand Down Expand Up @@ -2248,7 +2249,7 @@ def interp(

def interp_like(
self: T_DataArray,
other: DataArray | Dataset,
other: T_DataArrayOrSet,
method: InterpOptions = "linear",
assume_sorted: bool = False,
kwargs: Mapping[str, Any] | None = None,
Expand Down Expand Up @@ -5375,7 +5376,7 @@ def map_blocks(
func: Callable[..., T_Xarray],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] | None = None,
template: DataArray | Dataset | None = None,
template: T_Xarray | None = None,
) -> T_Xarray:
"""
Apply a function to each block of this DataArray.
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3398,7 +3398,7 @@ def _reindex_callback(

def reindex_like(
self: T_Dataset,
other: Dataset | DataArray,
other: T_Xarray,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
copy: bool = True,
Expand Down Expand Up @@ -8545,7 +8545,7 @@ def map_blocks(
func: Callable[..., T_Xarray],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] | None = None,
template: DataArray | Dataset | None = None,
template: T_Xarray | None = None,
) -> T_Xarray:
"""
Apply a function to each block of this Dataset.
Expand Down
6 changes: 2 additions & 4 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):
raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.")


def check_result_variables(
result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
):
def check_result_variables(result: T_Xarray, expected: Mapping[str, Any], kind: str):
if kind == "coords":
nice_str = "coordinate"
elif kind == "data_vars":
Expand Down Expand Up @@ -105,7 +103,7 @@ def make_meta(obj):


def infer_template(
func: Callable[..., T_Xarray], obj: DataArray | Dataset, *args, **kwargs
func: Callable[..., T_Xarray], obj: T_Xarray, *args, **kwargs
) -> T_Xarray:
"""Infer return object by running the function on meta objects."""
meta_args = [make_meta(arg) for arg in (obj,) + args]
Expand Down
10 changes: 4 additions & 6 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from xarray.core.indexes import PandasMultiIndex
from xarray.core.options import OPTIONS
from xarray.core.pycompat import DuckArrayModule
from xarray.core.types import T_Xarray
from xarray.core.utils import is_scalar, module_available

nc_time_axis_available = module_available("nc_time_axis")
Expand All @@ -32,7 +33,6 @@
from numpy.typing import ArrayLike

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import AspectOptions, ScaleOptions

try:
Expand Down Expand Up @@ -312,7 +312,7 @@ def _determine_cmap_params(


def _infer_xy_labels_3d(
darray: DataArray | Dataset,
darray: T_Xarray,
x: Hashable | None,
y: Hashable | None,
rgb: Hashable | None,
Expand Down Expand Up @@ -374,7 +374,7 @@ def _infer_xy_labels_3d(


def _infer_xy_labels(
darray: DataArray | Dataset,
darray: T_Xarray,
x: Hashable | None,
y: Hashable | None,
imshow: bool = False,
Expand Down Expand Up @@ -413,9 +413,7 @@ def _infer_xy_labels(


# TODO: Can by used to more than x or y, rename?
def _assert_valid_xy(
darray: DataArray | Dataset, xy: Hashable | None, name: str
) -> None:
def _assert_valid_xy(darray: T_Xarray, xy: Hashable | None, name: str) -> None:
"""
make sure x and y passed to plotting functions are valid
"""
Expand Down

0 comments on commit 504c8f1

Please sign in to comment.