Skip to content

Commit

Permalink
Add support for np.newaxis to scico.numpy.util.indexed_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed May 8, 2024
1 parent 19d37c9 commit 4637be0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
20 changes: 14 additions & 6 deletions scico/numpy/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
Expand Down Expand Up @@ -87,14 +87,17 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]:
Raises:
ValueError: If `idx` is an integer index that is out bounds for
the axis length.
the axis length or if the type of `idx` is not one of
`Ellipsis`, `int`, or `slice`.
"""
if idx is Ellipsis:
return length
if isinstance(idx, int):
if idx < -length or idx > length - 1:
raise ValueError(f"Index {idx} out of bounds for axis of length {length}.")
return None
if not isinstance(idx, slice):
raise ValueError(f"Index expression {idx} is of an unrecognized type.")
start, stop, stride = idx.indices(length)
if start > stop:
start = stop
Expand All @@ -112,19 +115,24 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]:
Shape of indexed/sliced array.
Raises:
ValueError: If `idx` is longer than `shape`.
ValueError: If any element of `idx` is not one of `Ellipsis`,
`int`, `slice`, or ``None`` (`np.newaxis`), or if an integer
index is out bounds for the corresponding axis length.
"""
if not isinstance(idx, tuple):
idx = (idx,)
if len(idx) > len(shape):
raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.")
idx_shape: List[Optional[int]] = list(shape)
offset = 0
newaxis = 0
for axis, ax_idx in enumerate(idx):
if ax_idx is None:
idx_shape.insert(axis, 1)
newaxis += 1
continue
if ax_idx is Ellipsis:
offset = len(shape) - len(idx)
continue
idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx)
idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx)
return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore


Expand Down
4 changes: 4 additions & 0 deletions scico/test/numpy/test_numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def test_slice_length_other(length, slc):
np.s_[..., 2:],
np.s_[..., 2:, :],
np.s_[1:, ..., 2:],
np.s_[np.newaxis],
np.s_[:, np.newaxis],
np.s_[np.newaxis, :, np.newaxis],
np.s_[np.newaxis, ..., 0:2, :],
),
)
def test_indexed_shape(shape, slc):
Expand Down

0 comments on commit 4637be0

Please sign in to comment.