Skip to content

Commit

Permalink
ENH: add function to standardize basic indexing keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed May 18, 2023
1 parent 218e1f4 commit 0d358de
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
26 changes: 23 additions & 3 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def __getitem__(self, key, /):
isinstance(key, basic_indexer),
isinstance(key, tuple) and all(isinstance(k, basic_indexer) for k in key),
):
pass
key = _standardize_basic_indexer(self, key)

return output
if any( # fancy indexing
isinstance(key, Sequence) and not isinstance(key, tuple),
isinstance(key, np.ndarray),
Expand All @@ -162,8 +164,9 @@ def __getitem__(self, key, /):
):
# check if integer or boolean indexing
# if integer, check which dimensions get broadcast where
# if multiple, axes are merged. If adjacent, merged inplace, otherwise moved to beginning
pass
# if multiple, axes are merged. If adjacent, merged inplace,
# otherwise moved to beginning
return output
else:
raise TypeError(f"AxisArray {self} does not know how to slice with {key}")
# mulligan structured arrays, etc.
Expand Down Expand Up @@ -297,6 +300,23 @@ def concatenate(arrays, axis=0):
return AxesArray(np.concatenate(parents, axis), axes=ax_list[0])


def _standardize_basic_indexer(arr: np.ndarray, key):
"""Convert to a tuple of slices, ints, and None."""
if isinstance(key, tuple):
if not any(ax_key is Ellipsis for ax_key in key):
key = (*key, Ellipsis)
slicedim = sum(isinstance(ax_key, slice | int) for ax_key in key)
final_key = []
for ax_key in key:
inner_iterator = (ax_key,)
if ax_key is Ellipsis:
inner_iterator = (arr.ndim - slicedim) * (slice(None),)
for el in inner_iterator:
final_key.append(el)
return tuple(final_key)
return _standardize_basic_indexer(arr, (key,))


def comprehend_axes(x):
axes = {}
axes["ax_coord"] = len(x.shape) - 1
Expand Down
12 changes: 11 additions & 1 deletion test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy.testing import assert_raises

from pysindy import AxesArray
from pysindy.utils import axes
from pysindy.utils.axes import _AxisMapping
from pysindy.utils.axes import AxesWarning

Expand Down Expand Up @@ -176,7 +177,7 @@ def test_conflicting_axes_defn():


@pytest.mark.skip("giving error")
def test_fancy_indexing_modifies_axes():
def test_fancy_getitem_modifies_axes():
axes = {"ax_time": 0, "ax_coord": 1}
arr = AxesArray(np.ones(4).reshape((2, 2)), axes)
slim = arr[1, :]
Expand All @@ -187,6 +188,15 @@ def test_fancy_indexing_modifies_axes():
assert fat.ax_coord == 2


def test_standardize_basic_indexer():
arr = np.arange(6).reshape(2, 3)
result = axes._standardize_basic_indexer(arr, Ellipsis)
assert result == (slice(None), slice(None))

result = axes._standardize_basic_indexer(arr, (np.newaxis, 1, 1, Ellipsis))
assert result == (None, 1, 1)


def test_reduce_AxisMapping():
ax_map = _AxisMapping(
{
Expand Down

0 comments on commit 0d358de

Please sign in to comment.