Skip to content

Commit

Permalink
bug(axes): enable 0-degree arrays
Browse files Browse the repository at this point in the history
If arr[key] returns an element of an array, arr[key, ...] returns a
0-degree array.
  • Loading branch information
Jacob-Stevens-Haas committed Jan 12, 2024
1 parent 23817f0 commit c11c0d6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
13 changes: 3 additions & 10 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __getitem__(self, key: Indexer | Sequence[Indexer], /):
base_indexer = key
output = super().__getitem__(base_indexer)
if not isinstance(output, AxesArray):
return output # why?
return output # return an element from the array
in_dim = self.shape
key, adv_inds = standardize_indexer(self, key)
bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
Expand Down Expand Up @@ -386,17 +386,10 @@ def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]:
"""Replace ellipsis in indexers with the appropriate amount of slice(None)"""
# [...].index errors if list contains numpy array
ellind = [ind for ind, val in enumerate(key) if val is ...][0]
new_key = []
n_new_dims = sum(ax_key is None or isinstance(ax_key, str) for ax_key in key)
n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1)
new_key = (
key[:ellind]
+ n_ellipsis_dims
* [
slice(None),
]
+ key[ellind + 1 + n_ellipsis_dims :]
)
new_key = key[:ellind] + key[ellind + 1 :]
new_key = new_key[:ellind] + (n_ellipsis_dims * [slice(None)]) + new_key[ellind:]
return new_key


Expand Down
28 changes: 28 additions & 0 deletions test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def test_simple_slice():
assert arr[0] == 1


# @pytest.mark.skip # TODO: make this pass
def test_0d_indexer():
arr = AxesArray(np.ones(2), {"ax_coord": 0})
arr_out = arr[1, ...]
assert arr_out.ndim == 0
assert arr_out.axes == {}
assert arr_out[()] == 1


def test_basic_indexing_modifies_axes():
axes = {"ax_time": 0, "ax_coord": 1}
arr = AxesArray(np.ones(4).reshape((2, 2)), axes)
Expand Down Expand Up @@ -428,3 +437,22 @@ def test_determine_adv_broadcasting():
res_nd, res_start = axes._determine_adv_broadcasting(indexers, [])
assert res_nd == 0
assert res_start is None


def test_replace_ellipsis():
key = [..., 0]
result = axes._expand_indexer_ellipsis(key, 2)
expected = [slice(None), 0]
assert result == expected


def test_strip_ellipsis():
key = [1, ...]
result = axes._expand_indexer_ellipsis(key, 1)
expected = [1]
assert result == expected

key = [..., 1]
result = axes._expand_indexer_ellipsis(key, 1)
expected = [1]
assert result == expected

0 comments on commit c11c0d6

Please sign in to comment.