Skip to content

Commit

Permalink
feat(axes): Enable np.reshape on AxesArrays
Browse files Browse the repository at this point in the history
Only a limited subset of reshapes with obvious relabeling semantics are allowed:
For this version, it's just an outer product of some axes

Also clean up typing and documentation and add reshape tests
  • Loading branch information
Jacob-Stevens-Haas committed Jan 13, 2024
1 parent c11c0d6 commit bb1c73d
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 16 deletions.
130 changes: 115 additions & 15 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class _AxisMapping:

def __init__(
self,
axes: dict[str, Union[int, Sequence[int]]] = None,
axes: Optional[dict[str, Union[int, Sequence[int]]]] = None,
in_ndim: int = 0,
):
if axes is None:
Expand Down Expand Up @@ -75,9 +75,7 @@ def coerce_sequence(obj):
)

@staticmethod
def _compat_axes(
in_dict: dict[str, Sequence[int]]
) -> dict[str, Union[Sequence[int], int]]:
def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]:
"""Like fwd_map, but unpack single-element axis lists"""
axes = {}
for k, v in in_dict.items():
Expand Down Expand Up @@ -156,20 +154,35 @@ def ndim(self):
class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
"""A numpy-like array that keeps track of the meaning of its axes.
Limitations:
* Not all numpy functions, such as ``np.flatten()``, does not have an
implementation for AxesArray, a regular numpy array is returned.
* For functions that are implemented for `AxesArray`, such as
``np.reshape()``, use the numpy function rather than the bound
method (e.g. arr.reshape)
* Such such functions may raise ValueErrors where numpy would not, when
it is impossible to determine the output axis labels.
Bound methods, such as arr.reshape, are not implemented. Use the functions.
While the functions in the numpy namespace will work on ``AxesArray``
objects, the documentation must be found in their equivalent names here.
Parameters:
input_array (array-like): the data to create the array.
axes (dict): A dictionary of axis labels to shape indices.
Allowed keys:
- ax_time: int
- ax_coord: int
- ax_sample: int
- ax_spatial: List[int]
input_array: the data to create the array.
axes: A dictionary of axis labels to shape indices. Axes labels must
be of the format "ax_name". indices can be either an int or a
list of ints.
Raises:
AxesWarning if axes does not match shape of input_array
* AxesWarning if axes does not match shape of input_array.
* ValueError if assigning the same axis index to multiple meanings or
assigning an axis beyond ndim.
"""

def __new__(cls, input_array, axes):
_ax_map: _AxisMapping

def __new__(cls, input_array: NDArray, axes: dict[str, int | list[int]]):
obj = np.asarray(input_array).view(cls)
if axes is None:
axes = {}
Expand Down Expand Up @@ -226,10 +239,10 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
new_map = _AxisMapping(
self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes)
)
for new_ax_ind, new_ax_name in new_axes:
for insert_counter, (new_ax_ind, new_ax_name) in enumerate(new_axes):
new_map = _AxisMapping(
new_map.insert_axis(new_ax_ind, new_ax_name),
len(in_dim) - len(remove_axes) + len(new_axes),
in_ndim=len(in_dim) - len(remove_axes) + (insert_counter + 1),
)
output._ax_map = new_map
return output
Expand Down Expand Up @@ -342,6 +355,72 @@ def concatenate(arrays, axis=0):
return AxesArray(np.concatenate(parents, axis), axes=ax_list[0])


@implements(np.reshape)
def reshape(a: AxesArray, newshape: int | tuple[int], order="C"):
"""Gives a new shape to an array without changing its data.
Args:
a: Array to be reshaped
newshape: int or tuple of ints
The new shape should be compatible with the original shape. In
addition, the axis labels must make sense when the data is
translated to a new shape. Currently, the only use case supported
is to flatten an outer product of two or more axes with the same
label and size.
order: Must be "C"
"""
if order != "C":
raise ValueError("AxesArray only supports reshaping in 'C' order currently.")
out = np.reshape(np.asarray(a), newshape, order) # handle any regular errors

new_axes = {}
if isinstance(newshape, int):
newshape = [newshape]
newshape = list(newshape)
explicit_new_size = np.multiply.reduce(np.array(newshape))
if explicit_new_size < 0:
replace_ind = newshape.index(-1)
newshape[replace_ind] = a.size // (-1 * explicit_new_size)

curr_base = 0
for curr_new in range(len(newshape)):
if curr_base >= a.ndim:
raise ValueError(
"Cannot reshape an AxesArray this way. Adding a length-1 axis at"
f" dimension {curr_new} not understood."
)
base_name = a._ax_map.reverse_map[curr_base]
if a.shape[curr_base] == newshape[curr_new]:
_compat_axes_append(new_axes, base_name, curr_new)
curr_base += 1
elif newshape[curr_new] == 1:
raise ValueError(
f"Cannot reshape an AxesArray this way. Inserting a new axis at"
f" dimension {curr_new} of new shape is not supported"
)
else: # outer product
remaining = newshape[curr_new]
while remaining > 1:
if a._ax_map.reverse_map[curr_base] != base_name:
raise ValueError(
"Cannot reshape an AxesArray this way. It would combine"
f" {base_name} with {a._ax_map.reverse_map[curr_base]}"
)
remaining, error = divmod(remaining, a.shape[curr_base])
if error:
raise ValueError(
f"Cannot reshape an AxesArray this way. Array dimension"
f" {curr_base} has size {a.shape[curr_base]}, must divide into"
f" newshape dimension {curr_new} with size"
f" {newshape[curr_new]}."
)
curr_base += 1

_compat_axes_append(new_axes, base_name, curr_new)

return AxesArray(out, axes=new_axes)


def standardize_indexer(
arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[Sequence[StandardIndexer], tuple[KeyIndex, ...]]:
Expand Down Expand Up @@ -524,3 +603,24 @@ def wrap_axes(axes: dict, obj):
except KeyError:
pass
return obj


def _compat_axes_append(
axes_dict: dict[str, Union[int, list[int]]],
ax_name: str,
newaxis: Union[int, list[int]],
) -> None:
if isinstance(newaxis, int):
try:
axes_dict[ax_name].append(newaxis)
except KeyError:
axes_dict[ax_name] = newaxis
except AttributeError:
axes_dict[ax_name] = [axes_dict[ax_name], newaxis]
else:
try:
axes_dict[ax_name] += newaxis
except KeyError:
axes_dict[ax_name] = newaxis
except AttributeError:
axes_dict[ax_name] = [axes_dict[ax_name], *newaxis]
37 changes: 36 additions & 1 deletion test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,41 @@ def test_n_elements():
assert arr2.n_coord == 4


def test_reshape_outer_product():
arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]})
merge = np.reshape(arr, (4,))
assert merge.axes == {"ax_a": 0}


def test_reshape_fill_outer_product():
arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]})
merge = np.reshape(arr, (-1,))
assert merge.axes == {"ax_a": 0}


def test_reshape_fill_regular():
arr = AxesArray(np.arange(8).reshape((2, 2, 2)), {"ax_a": [0, 1], "ax_b": 2})
merge = np.reshape(arr, (4, -1))
assert merge.axes == {"ax_a": 0, "ax_b": 1}


def test_illegal_reshape():
arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": [0, 1]})
# melding across axes
with pytest.raises(ValueError, match="Cannot reshape an AxesArray"):
np.reshape(arr, (4, 1))

# Add a hidden 1 in the middle! maybe a matching 1

# different name outer product
arr = AxesArray(np.arange(4).reshape((2, 2)), {"ax_a": 0, "ax_b": 1})
with pytest.raises(ValueError, match="Cannot reshape an AxesArray"):
np.reshape(arr, (4,))
# newaxes
with pytest.raises(ValueError, match="Cannot reshape an AxesArray"):
np.reshape(arr, (2, 1, 2))


def test_warn_toofew_axes():
axes = {"ax_time": 0, "ax_coord": 1}
with pytest.warns(AxesWarning):
Expand Down Expand Up @@ -334,7 +369,7 @@ def test_reduce_twisted_AxisMapping():


def test_reduce_misordered_AxisMapping():
ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 7)
ax_map = _AxisMapping({"ax_a": [0, 1], "ax_b": 2, "ax_c": 3}, 4)
result = ax_map.remove_axis([2, 1])
expected = {"ax_a": 0, "ax_c": 1}
assert result == expected
Expand Down

0 comments on commit bb1c73d

Please sign in to comment.