Skip to content

Commit

Permalink
WIP: begin __getitem__ work to id axes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed May 1, 2023
1 parent 06c5b9a commit 218e1f4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
30 changes: 27 additions & 3 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,33 @@ def __getattr__(self, name):
return shape
raise AttributeError(f"'{type(self)}' object has no attribute '{name}'")

# def __getitem__(self, key, /):
# pass
# return super().__getitem__(self, key)
def __getitem__(self, key, /):
output = super().__getitem__(key)
# determine axes of output
in_dim = self.shape # noqa
out_dim = output.shape # noqa
remove_dims = [] # noqa
basic_indexer = Union[slice, int, type(Ellipsis), np.newaxis, type(None)]
if any( # basic indexing
isinstance(key, basic_indexer),
isinstance(key, tuple) and all(isinstance(k, basic_indexer) for k in key),
):
pass
if any( # fancy indexing
isinstance(key, Sequence) and not isinstance(key, tuple),
isinstance(key, np.ndarray),
isinstance(key, tuple) and any(isinstance(k, Sequence) for k in key),
isinstance(key, tuple) and any(isinstance(k, np.ndarray) for k in key), # ?
):
# check if integer or boolean indexing
# if integer, check which dimensions get broadcast where
# if multiple, axes are merged. If adjacent, merged inplace, otherwise moved to beginning
pass
else:
raise TypeError(f"AxisArray {self} does not know how to slice with {key}")
# mulligan structured arrays, etc.
return output

# def __getitem__(self, key, /):
# remove_axes = []
# if isinstance(key, int):
Expand Down
2 changes: 1 addition & 1 deletion test/utils/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
assert_(c is b)


@pytest.mark.skip("Expected error")
# @pytest.mark.skip("Expected error")
def test_ufunc_override_accumulate():
d = np.array([[1, 2, 3], [1, 2, 3]])
a = AxesArray(d, {"ax_time": [0, 1]})
Expand Down

0 comments on commit 218e1f4

Please sign in to comment.