Skip to content

Commit

Permalink
Merge pull request #69 from khaeru/types-2022-W33
Browse files Browse the repository at this point in the history
Improve typing for Quantity.assign_coords, .expand_dims, .name, and .sum
  • Loading branch information
khaeru committed Aug 16, 2022
2 parents 7fbf2b3 + 3a5920e commit f175969
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 19 deletions.
2 changes: 1 addition & 1 deletion genno/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def map_labels(mapper, values):

def rename_dims(
qty: Quantity,
new_name_or_name_dict: Union[Hashable, Mapping[Hashable, Hashable]] = None,
new_name_or_name_dict: Union[Hashable, Mapping[Any, Hashable]] = None,
**names: Hashable,
) -> Quantity:
"""Rename the dimensions of `qty`.
Expand Down
42 changes: 27 additions & 15 deletions genno/core/attrseries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import logging
import warnings
from functools import partial
from typing import Any, Hashable, Iterable, List, Mapping, Tuple, Union
from typing import (
Any,
Hashable,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -161,12 +172,7 @@ def drop_vars(

return self.droplevel(names)

def expand_dims(
self,
dim: Union[None, Mapping[Hashable, Any]] = None,
axis=None,
**dim_kwargs: Any,
):
def expand_dims(self, dim=None, axis=None, **dim_kwargs: Any) -> "AttrSeries":
"""Like :meth:`xarray.DataArray.expand_dims`."""
dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
if axis is not None:
Expand Down Expand Up @@ -419,16 +425,22 @@ def shift(
attrs=self.attrs,
)

def sum(self, *args, **kwargs):
def sum(
self,
dim: Optional[Union[Hashable, Sequence[Hashable]]] = None,
# Signature from xarray.DataArray
# *,
# skipna: bool | None = None,
# min_count: int | None = None,
keep_attrs: Optional[bool] = None,
**kwargs: Any,
) -> "AttrSeries":
"""Like :meth:`xarray.DataArray.sum`."""
obj = super()
obj = cast(pd.Series, super())
attrs = None

try:
dim = kwargs.pop("dim")
except KeyError:
dim = list(args)
args = tuple()
if not isinstance(dim, Sequence):
dim = () if dim is None else (dim,)

if len(dim) in (0, len(self.index.names)):
bad_dims = set(dim) - set(self.index.names)
Expand All @@ -445,7 +457,7 @@ def sum(self, *args, **kwargs):
# Result will be DataFrame; re-attach attrs when converted to AttrSeries
attrs = self.attrs

return AttrSeries(obj.sum(*args, **kwargs), attrs=attrs)
return AttrSeries(obj.sum(**kwargs), attrs=attrs)

def squeeze(self, dim=None, *args, **kwargs):
"""Like :meth:`xarray.DataArray.squeeze`."""
Expand Down
40 changes: 37 additions & 3 deletions genno/core/quantity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import update_wrapper
from typing import Any, Dict, Hashable, Mapping, Tuple, Union
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -21,6 +21,8 @@ class Quantity:
# To silence a warning in xarray
__slots__: Tuple[str, ...] = tuple()

_name: Optional[Hashable]

def __new__(cls, *args, **kwargs):
# Use _get_class() to retrieve either AttrSeries or SparseDataArray
return object.__new__(Quantity._get_class(cls))
Expand All @@ -38,8 +40,13 @@ def from_series(cls, series, sparse=True):
return cls._get_class().from_series(series, sparse)

@property
def name(self) -> Hashable:
... # pragma: no cover
def name(self) -> Optional[Hashable]:
"""The name of this quantity."""
return self._name # pragma: no cover

@name.setter
def name(self, value: Optional[Hashable]) -> None:
self._name = value # pragma: no cover

@property
def units(self):
Expand Down Expand Up @@ -100,6 +107,21 @@ def coords(self) -> xarray.core.coordinates.DataArrayCoordinates:
def dims(self) -> Tuple[Hashable, ...]:
... # pragma: no cover

def assign_coords(
self,
coords: Optional[Mapping[Any, Any]] = None,
**coords_kwargs: Any,
) -> "Quantity":
... # pragma: no cover

def expand_dims(
self,
dim=None,
axis=None,
**dim_kwargs: Any,
): # NB "Quantity" here offends mypy
... # pragma: no cover

def interp(
self,
coords: Mapping[Hashable, Any] = None,
Expand Down Expand Up @@ -138,6 +160,18 @@ def shift(
): # NB "Quantity" here offends mypy
... # pragma: no cover

def sum(
self,
dim: Optional[Union[Hashable, Sequence[Hashable]]] = None,
# Signature from xarray.DataArray
# *,
# skipna: bool | None = None,
# min_count: int | None = None,
keep_attrs: Optional[bool] = None,
**kwargs: Any,
) -> "Quantity":
... # pragma: no cover

def to_numpy(self) -> np.ndarray:
... # pragma: no cover

Expand Down

0 comments on commit f175969

Please sign in to comment.