diff --git a/pysindy/utils/axes.py b/pysindy/utils/axes.py index 731c5080b..dcfd6d8a1 100644 --- a/pysindy/utils/axes.py +++ b/pysindy/utils/axes.py @@ -42,7 +42,7 @@ class _AxisMapping: def __init__( self, - axes: dict[str, Union[int, Sequence[int]]] = None, + axes: Optional[dict[str, Union[int, Sequence[int]]]] = None, in_ndim: int = 0, ): if axes is None: @@ -75,9 +75,7 @@ def coerce_sequence(obj): ) @staticmethod - def _compat_axes( - in_dict: dict[str, Sequence[int]] - ) -> dict[str, Union[Sequence[int], int]]: + def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]: """Like fwd_map, but unpack single-element axis lists""" axes = {} for k, v in in_dict.items(): @@ -156,20 +154,35 @@ def ndim(self): class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray): """A numpy-like array that keeps track of the meaning of its axes. + Limitations: + * Not all numpy functions, such as ``np.flatten()``, does not have an + implementation for AxesArray, a regular numpy array is returned. + * For functions that are implemented for `AxesArray`, such as + ``np.reshape()``, use the numpy function rather than the bound + method (e.g. arr.reshape) + * Such such functions may raise ValueErrors where numpy would not, when + it is impossible to determine the output axis labels. + + Bound methods, such as arr.reshape, are not implemented. Use the functions. + While the functions in the numpy namespace will work on ``AxesArray`` + objects, the documentation must be found in their equivalent names here. + Parameters: - input_array (array-like): the data to create the array. - axes (dict): A dictionary of axis labels to shape indices. - Allowed keys: - - ax_time: int - - ax_coord: int - - ax_sample: int - - ax_spatial: List[int] + input_array: the data to create the array. + axes: A dictionary of axis labels to shape indices. Axes labels must + be of the format "ax_name". indices can be either an int or a + list of ints. Raises: - AxesWarning if axes does not match shape of input_array + * AxesWarning if axes does not match shape of input_array. + * ValueError if assigning the same axis index to multiple meanings or + assigning an axis beyond ndim. + """ - def __new__(cls, input_array, axes): + _ax_map: _AxisMapping + + def __new__(cls, input_array: NDArray, axes: dict[str, int | list[int]]): obj = np.asarray(input_array).view(cls) if axes is None: axes = {} @@ -226,10 +239,10 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /): new_map = _AxisMapping( self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes) ) - for new_ax_ind, new_ax_name in new_axes: + for insert_counter, (new_ax_ind, new_ax_name) in enumerate(new_axes): new_map = _AxisMapping( new_map.insert_axis(new_ax_ind, new_ax_name), - len(in_dim) - len(remove_axes) + len(new_axes), + in_ndim=len(in_dim) - len(remove_axes) + (insert_counter + 1), ) output._ax_map = new_map return output @@ -342,6 +355,72 @@ def concatenate(arrays, axis=0): return AxesArray(np.concatenate(parents, axis), axes=ax_list[0]) +@implements(np.reshape) +def reshape(a: AxesArray, newshape: int | tuple[int], order="C"): + """Gives a new shape to an array without changing its data. + + Args: + a: Array to be reshaped + newshape: int or tuple of ints + The new shape should be compatible with the original shape. In + addition, the axis labels must make sense when the data is + translated to a new shape. Currently, the only use case supported + is to flatten an outer product of two or more axes with the same + label and size. + order: Must be "C" + """ + if order != "C": + raise ValueError("AxesArray only supports reshaping in 'C' order currently.") + out = np.reshape(np.asarray(a), newshape, order) # handle any regular errors + + new_axes = {} + if isinstance(newshape, int): + newshape = [newshape] + newshape = list(newshape) + explicit_new_size = np.multiply.reduce(np.array(newshape)) + if explicit_new_size < 0: + replace_ind = newshape.index(-1) + newshape[replace_ind] = a.size // (-1 * explicit_new_size) + + curr_base = 0 + for curr_new in range(len(newshape)): + if curr_base >= a.ndim: + raise ValueError( + "Cannot reshape an AxesArray this way. Adding a length-1 axis at" + f" dimension {curr_new} not understood." + ) + base_name = a._ax_map.reverse_map[curr_base] + if a.shape[curr_base] == newshape[curr_new]: + _compat_axes_append(new_axes, base_name, curr_new) + curr_base += 1 + elif newshape[curr_new] == 1: + raise ValueError( + f"Cannot reshape an AxesArray this way. Inserting a new axis at" + f" dimension {curr_new} of new shape is not supported" + ) + else: # outer product + remaining = newshape[curr_new] + while remaining > 1: + if a._ax_map.reverse_map[curr_base] != base_name: + raise ValueError( + "Cannot reshape an AxesArray this way. It would combine" + f" {base_name} with {a._ax_map.reverse_map[curr_base]}" + ) + remaining, error = divmod(remaining, a.shape[curr_base]) + if error: + raise ValueError( + f"Cannot reshape an AxesArray this way. Array dimension" + f" {curr_base} has size {a.shape[curr_base]}, must divide into" + f" newshape dimension {curr_new} with size" + f" {newshape[curr_new]}." + ) + curr_base += 1 + + _compat_axes_append(new_axes, base_name, curr_new) + + return AxesArray(out, axes=new_axes) + + def standardize_indexer( arr: np.ndarray, key: Indexer | Sequence[Indexer] ) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]: @@ -524,3 +603,24 @@ def wrap_axes(axes: dict, obj): except KeyError: pass return obj + + +def _compat_axes_append( + axes_dict: dict[str, Union[int, list[int]]], + ax_name: str, + newaxis: Union[int, list[int]], +) -> None: + if isinstance(newaxis, int): + try: + axes_dict[ax_name].append(newaxis) + except KeyError: + axes_dict[ax_name] = newaxis + except AttributeError: + axes_dict[ax_name] = [axes_dict[ax_name], newaxis] + else: + try: + axes_dict[ax_name] += newaxis + except KeyError: + axes_dict[ax_name] = newaxis + except AttributeError: + axes_dict[ax_name] = [axes_dict[ax_name], *newaxis] diff --git a/test/utils/test_axes.py b/test/utils/test_axes.py index e3910e29e..7f19596c2 100644 --- a/test/utils/test_axes.py +++ b/test/utils/test_axes.py @@ -140,6 +140,41 @@ def test_n_elements(): assert arr2.n_coord == 4 +def test_reshape_outer_product(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + merge = np.reshape(arr, (4,)) + assert merge.axes == {"ax_a": 0} + + +def test_reshape_fill_outer_product(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + merge = np.reshape(arr, (-1,)) + assert merge.axes == {"ax_a": 0} + + +def test_reshape_fill_regular(): + arr = AxesArray(np.arange(8).reshape((2, 2, 2)), {"ax_a": [0, 1], "ax_b": 2}) + merge = np.reshape(arr, (4, -1)) + assert merge.axes == {"ax_a": 0, "ax_b": 1} + + +def test_illegal_reshape(): + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]}) + # melding across axes + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (4, 1)) + + # Add a hidden 1 in the middle! maybe a matching 1 + + # different name outer product + arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": 0, "ax_b": 1}) + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (4,)) + # newaxes + with pytest.raises(ValueError, match="Cannot reshape an AxesArray"): + np.reshape(arr, (2, 1, 2)) + + def test_warn_toofew_axes(): axes = {"ax_time": 0, "ax_coord": 1} with pytest.warns(AxesWarning): @@ -334,7 +369,7 @@ def test_reduce_twisted_AxisMapping(): def test_reduce_misordered_AxisMapping(): - ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 7) + ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 4) result = ax_map.remove_axis([2, 1]) expected = {"ax_a": 0, "ax_c": 1} assert result == expected