AxesWarning = type("AxesWarning", (SyntaxWarning,), {})


class _AxisMapping:
    """Convenience wrapper for a two-way map between axis names and
    indexes.

    Arguments:
        axes: mapping from axis name to the index, or sequence of indexes,
            that carry that name.  Negative indexes count from the last
            axis, following the numpy convention, and are stored in their
            non-negative form.  ``None`` (default) means no labeled axes.
        in_ndim: number of axes of the array being described.

    Raises:
        ValueError: if an index lies outside ``[-in_ndim, in_ndim)`` or if
            two names (or one name twice) claim the same axis.

    Warns:
        AxesWarning: if the number of labeled axes differs from ``in_ndim``.
    """

    def __init__(
        self,
        axes: MutableMapping[str, Union[int, Sequence[int]]] = None,
        in_ndim: int = 0,
    ):
        if axes is None:
            axes = {}
        # Rank of the described array.  Kept explicitly because not every
        # axis has to be labeled, so it cannot be recovered from the maps.
        self.ndim = in_ndim
        # name -> sorted list of non-negative axis indexes
        self.fwd_map = {}
        # axis index -> name
        self.reverse_map = {}

        for ax_name, ax_ids in axes.items():
            normalized = []
            for ax_id in self._coerce_sequence(ax_ids):
                if not -in_ndim <= ax_id < in_ndim:
                    raise ValueError(
                        f"Assigned definition to axis {ax_id}, but array only has"
                        f" {in_ndim} axes"
                    )
                if ax_id < 0:
                    # Store one canonical index per axis so that duplicate
                    # detection and reduce() see -1 and ndim-1 as the same.
                    ax_id += in_ndim
                if ax_id in self.reverse_map:
                    raise ValueError(f"Assigned multiple definitions to axis {ax_id}")
                self.reverse_map[ax_id] = ax_name
                normalized.append(ax_id)
            # A fresh list is built here, so the caller's mapping is never
            # aliased or mutated.
            self.fwd_map[ax_name] = sorted(normalized)
        if len(self.reverse_map) != in_ndim:
            warnings.warn(
                f"{len(self.reverse_map)} axes labeled for array with {in_ndim} axes",
                AxesWarning,
            )

    @staticmethod
    def _coerce_sequence(obj):
        """Return ``obj`` as a new list, wrapping a lone index in a list."""
        if isinstance(obj, Sequence):
            return list(obj)
        return [obj]

    @staticmethod
    def _compat_axes(in_dict: dict[str, Sequence]) -> dict[str, Union[Sequence, int]]:
        """Turn single-element axis index lists into ints"""
        axes = {}
        for k, v in in_dict.items():
            if len(v) == 1:
                axes[k] = v[0]
            else:
                axes[k] = v
        return axes

    @property
    def compat_axes(self):
        """Axes dict in constructor format: int for names with one axis,
        list of ints otherwise.
        """
        return self._compat_axes(self.fwd_map)

    def reduce(self, axis: Union[int, None] = None):
        """Create an axes dict from self with specified axis
        removed and all greater axes decremented.

        Neither ``fwd_map`` nor ``reverse_map`` is modified.

        Arguments:
            axis: the axis index to remove.  Negative values count from the
                last axis.  By numpy ufunc convention, axis=None (default)
                removes _all_ axes.

        Returns:
            A new axes dict in the same format as ``compat_axes``.  A name
            whose only axis was removed is dropped; an unlabeled axis may
            be removed and simply shifts the labels above it.

        Raises:
            KeyError: if ``axis`` is outside ``[-ndim, ndim)``.
        """
        if axis is None:
            return {}
        if not -self.ndim <= axis < self.ndim:
            # KeyError retained: a failed lookup of the axis has always
            # surfaced to callers as this exception type.
            raise KeyError(axis)
        if axis < 0:
            axis += self.ndim
        new_axes = {}
        for ax_name, ax_ids in self.fwd_map.items():
            if ax_ids == [axis]:
                # The name's sole axis is the one being removed.
                continue
            # Shift only the indexes that sit above the removed axis; lower
            # indexes under the same name must stay where they are.
            new_axes[ax_name] = [
                ax_id - 1 if ax_id > axis else ax_id
                for ax_id in ax_ids
                if ax_id != axis
            ]
        return self._compat_axes(new_axes)
@@ -30,93 +112,85 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): def __new__(cls, input_array, axes): obj = np.asarray(input_array).view(cls) - defaults = { - "ax_time": None, - "ax_coord": None, - "ax_sample": None, - "ax_spatial": [], - } - n_axes = sum(1 for k, v in axes.items() if v) if axes is None: - return obj + axes = {} in_ndim = len(input_array.shape) - if n_axes != in_ndim: - warnings.warn( - f"{n_axes} axes labeled for array with {in_ndim} axes", AxesWarning - ) - axes = {**defaults, **axes} - listed_axes = [ - el for k, v in axes.items() if isinstance(v, Collection) for el in v - ] - listed_axes += [ - v - for k, v in axes.items() - if not isinstance(v, Collection) and v is not None - ] - _reverse_map = {} - for axis in listed_axes: - if axis >= in_ndim: - raise ValueError( - f"Assigned definition to axis {axis}, but array only has" - f" {in_ndim} axes" - ) - ax_names = [ax_name for ax_name in axes if axes[ax_name] == axis] - if len(ax_names) > 1: - raise ValueError(f"Assigned multiple definitions to axis {axis}") - _reverse_map[axis] = ax_names[0] - obj.__dict__.update({**axes}) - obj.__dict__["_reverse_map"] = _reverse_map + obj.__ax_map = _AxisMapping(axes, in_ndim) return obj - def __getitem__(self, key, /): - remove_axes = [] - if isinstance(key, int): - remove_axes.append(key) - if isinstance(key, Sequence): - for axis, k in enumerate(key): - if isinstance(k, int): - remove_axes.append(axis) - new_item = super().__getitem__(key) - if not isinstance(new_item, AxesArray): - return new_item - for axis in remove_axes: - ax_name = self._reverse_map[axis] - if isinstance(new_item.__dict__[ax_name], int): - new_item.__dict__[ax_name] = None - else: - new_item.__dict__[ax_name].remove(axis) - new_item._reverse_map.pop(axis) - return new_item - - def __array_finalize__(self, obj) -> None: - if obj is None: - return - self._reverse_map = copy.deepcopy(getattr(obj, "_reverse_map", {})) - self.ax_time = getattr(obj, "ax_time", None) - 
self.ax_coord = getattr(obj, "ax_coord", None) - self.ax_sample = getattr(obj, "ax_sample", None) - self.ax_spatial = getattr(obj, "ax_spatial", []) - - @property - def n_spatial(self): - return tuple(self.shape[ax] for ax in self.ax_spatial) - - @property - def n_time(self): - return self.shape[self.ax_time] if self.ax_time is not None else 1 - @property - def n_sample(self): - return self.shape[self.ax_sample] if self.ax_sample is not None else 1 + def axes(self): + return self.__ax_map.compat_axes @property - def n_coord(self): - return self.shape[self.ax_coord] if self.ax_coord is not None else 1 + def _reverse_map(self): + return self.__ax_map.reverse_map @property def shape(self): return super().shape + def __getattr__(self, name): + parts = name.split("_", 1) + if parts[0] == "ax": + return self.axes[name] + if parts[0] == "n": + fwd_map = self.__ax_map.fwd_map + shape = tuple(self.shape[ax_id] for ax_id in fwd_map["ax_" + parts[1]]) + if len(shape) == 1: + return shape[0] + return shape + raise AttributeError(f"'{type(self)}' object has no attribute '{name}'") + + # def __getitem__(self, key, /): + # pass + # return super().__getitem__(self, key) + # def __getitem__(self, key, /): + # remove_axes = [] + # if isinstance(key, int): + # remove_axes.append(key) + # if isinstance(key, Sequence): + # for axis, k in enumerate(key): + # if isinstance(k, int): + # remove_axes.append(axis) + # new_item = super().__getitem__(key) + # if not isinstance(new_item, AxesArray): + # return new_item + # for axis in remove_axes: + # ax_name = self._reverse_map[axis] + # if isinstance(new_item.__dict__[ax_name], int): + # new_item.__dict__[ax_name] = None + # else: + # new_item.__dict__[ax_name].remove(axis) + # new_item._reverse_map.pop(axis) + # return new_item + + def __array_wrap__(self, out_arr, context=None): + return super().__array_wrap__(self, out_arr, context) + + def __array_finalize__(self, obj) -> None: + if obj is None: # explicit construction via 
super().__new__().. not called? + return + # view from numpy array, called in constructor but also tests + if all( + ( + not isinstance(obj, AxesArray), + self.shape == (), + not hasattr(self, "__ax_map"), + ) + ): + self.__ax_map = _AxisMapping({}) + # required by ravel() and view() used in numpy testing. Also for zeros_like... + elif all( + ( + isinstance(obj, AxesArray), + not hasattr(self, "__ax_map"), + self.shape == obj.shape, + ) + ): + self.__ax_map = _AxisMapping(obj.axes, len(obj.shape)) + # maybe add errors for incompatible views? + def __array_ufunc__( self, ufunc, method, *inputs, out=None, **kwargs ): # this method is called whenever you use a ufunc @@ -145,17 +219,30 @@ def __array_ufunc__( return if ufunc.nout == 1: results = (results,) - results = tuple( - (AxesArray(np.asarray(result), self.__dict__) if output is None else output) - for result, output in zip(results, outputs) - ) + if method == "reduce" and ( + "keepdims" not in kwargs.keys() or kwargs["keepdims"] is False + ): + axes = None + if kwargs["axis"] is not None: + axes = self.__ax_map.reduce(axis=kwargs["axis"]) + else: + axes = self.axes + final_results = [] + for result, output in zip(results, outputs): + if output is not None: + final_results.append(output) + elif axes is None: + final_results.append(result) + else: + final_results.append(AxesArray(np.asarray(result), axes)) + results = tuple(final_results) return results[0] if len(results) == 1 else results def __array_function__(self, func, types, args, kwargs): if func not in HANDLED_FUNCTIONS: arr = super(AxesArray, self).__array_function__(func, types, args, kwargs) if isinstance(arr, np.ndarray): - return AxesArray(arr, axes=self.__dict__) + return AxesArray(arr, axes=self.axes) elif arr is not None: return arr return @@ -177,7 +264,7 @@ def decorator(func): @implements(np.concatenate) def concatenate(arrays, axis=0): parents = [np.asarray(obj) for obj in arrays] - ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, 
AxesArray)] + ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)] for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]): if ax1 != ax2: raise TypeError("Concatenating >1 AxesArray with incompatible axes") diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index e0b89d876..65bb4c63d 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -5,11 +5,12 @@ from numpy.testing import assert_raises from pysindy import AxesArray +from pysindy.utils.axes import _AxisMapping from pysindy.utils.axes import AxesWarning def test_reduce_mean_noinf_recursion(): - arr = AxesArray(np.array([[1]]), {}) + arr = AxesArray(np.array([[1]]), {"ax_a": [0, 1]}) np.mean(arr, axis=0) @@ -26,31 +27,31 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): d = np.arange(5.0) # 1 input, 1 output - a = AxesArray(d, {}) + a = AxesArray(d, {"ax_time": 0}) b = np.sin(a) check = np.sin(d) assert_(np.all(check == b)) b = np.sin(d, out=(a,)) assert_(np.all(check == b)) assert_(b is a) - a = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) b = np.sin(a, out=a) assert_(np.all(check == b)) # 1 input, 2 outputs - a = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) b1, b2 = np.modf(a) b1, b2 = np.modf(d, out=(None, a)) assert_(b2 is a) - a = AxesArray(np.arange(5.0), {}) - b = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) + b = AxesArray(np.arange(5.0), {"ax_time": 0}) c1, c2 = np.modf(a, out=(a, b)) assert_(c1 is a) assert_(c2 is b) # 2 input, 1 output - a = AxesArray(np.arange(5.0), {}) - b = AxesArray(np.arange(5.0), {}) + a = AxesArray(np.arange(5.0), {"ax_time": 0}) + b = AxesArray(np.arange(5.0), {"ax_time": 0}) c = np.add(a, b, out=a) assert_(c is a) # some tests with a non-ndarray subclass @@ -59,13 +60,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): assert_(a.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) 
assert_(b.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) assert_raises(TypeError, np.add, a, b) - a = AxesArray(a, {}) + a = AxesArray(a, {"ax_time": 0}) assert_(a.__array_ufunc__(np.add, "__call__", a, b) is NotImplemented) assert_(b.__array_ufunc__(np.add, "__call__", a, b) == "A!") assert_(np.add(a, b) == "A!") # regression check for gh-9102 -- tests ufunc.reduce implicitly. d = np.array([[1, 2, 3], [1, 2, 3]]) - a = AxesArray(d, {}) + a = AxesArray(d, {"ax_time": [0, 1]}) c = a.any() check = d.any() assert_equal(c, check) @@ -89,6 +90,11 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): c = np.add.reduce(a, 1, None, b) assert_equal(c, check) assert_(c is b) + + +def test_ufunc_override_accumulate(): + d = np.array([[1, 2, 3], [1, 2, 3]]) + a = AxesArray(d, {"ax_time": [0, 1]}) check = np.add.accumulate(d, axis=0) c = np.add.accumulate(a, axis=0) assert_equal(c, check) @@ -123,14 +129,16 @@ def test_n_elements(): assert arr.n_spatial == (1, 2) assert arr.n_time == 3 assert arr.n_coord == 4 - assert arr.n_sample == 1 arr2 = np.concatenate((arr, arr), axis=arr.ax_time) assert arr2.n_spatial == (1, 2) assert arr2.n_time == 6 assert arr2.n_coord == 4 - assert arr2.n_sample == 1 + +def test_limited_slice(): + arr = np.empty(np.arange(1, 5)) + arr = AxesArray(arr, {"ax_spatial": [0, 1], "ax_time": 2, "ax_coord": 3}) arr3 = arr[..., :2, 0] assert arr3.n_spatial == (1, 2) assert arr3.n_time == 2 @@ -152,6 +160,13 @@ def test_toomany_axes(): AxesArray(np.ones(4).reshape((2, 2)), axes) +def test_conflicting_axes_defn(): + axes = {"ax_time": 0, "ax_coord": 0} + with pytest.raises(ValueError): + AxesArray(np.ones(4), axes) + + +# @pytest.mark.skip("giving error") def test_fancy_indexing_modifies_axes(): axes = {"ax_time": 0, "ax_coord": 1} arr = AxesArray(np.ones(4).reshape((2, 2)), axes) @@ -161,3 +176,24 @@ def test_fancy_indexing_modifies_axes(): assert slim.ax_coord == 1 assert fat.ax_time == [0, 1] assert fat.ax_coord == 2 + + +def 
test_reduce_AxisMapping(): + ax_map = _AxisMapping( + { + "ax_a": [0, 1], + "ax_b": 2, + "ax_c": 3, + "ax_d": 4, + "ax_e": [5, 6], + }, + 7, + ) + result = ax_map.reduce(3) + expected = { + "ax_a": [0, 1], + "ax_b": 2, + "ax_d": 3, + "ax_e": [4, 5], + } + assert result == expected