From 4637be00379ff2b73abf75b6928fa7d828b03293 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 8 May 2024 08:54:46 -0600 Subject: [PATCH] Add support for np.newaxis to scico.numpy.util.indexed_shape --- scico/numpy/util.py | 20 ++++++++++++++------ scico/test/numpy/test_numpy_util.py | 4 ++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 54a1c497d..a31a7fe4b 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SPORCO package. Details of the copyright # and user license can be found in the 'LICENSE.txt' file distributed @@ -87,7 +87,8 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]: Raises: ValueError: If `idx` is an integer index that is out bounds for - the axis length. + the axis length or if the type of `idx` is not one of + `Ellipsis`, `int`, or `slice`. """ if idx is Ellipsis: return length @@ -95,6 +96,8 @@ def slice_length(length: int, idx: AxisIndex) -> Optional[int]: if idx < -length or idx > length - 1: raise ValueError(f"Index {idx} out of bounds for axis of length {length}.") return None + if not isinstance(idx, slice): + raise ValueError(f"Index expression {idx} is of an unrecognized type.") start, stop, stride = idx.indices(length) if start > stop: start = stop @@ -112,19 +115,24 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: Shape of indexed/sliced array. Raises: - ValueError: If `idx` is longer than `shape`. + ValueError: If any element of `idx` is not one of `Ellipsis`, + `int`, `slice`, or ``None`` (`np.newaxis`), or if an integer + index is out bounds for the corresponding axis length. """ if not isinstance(idx, tuple): idx = (idx,) - if len(idx) > len(shape): - raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") idx_shape: List[Optional[int]] = list(shape) offset = 0 + newaxis = 0 for axis, ax_idx in enumerate(idx): + if ax_idx is None: + idx_shape.insert(axis, 1) + newaxis += 1 + continue if ax_idx is Ellipsis: offset = len(shape) - len(idx) continue - idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) + idx_shape[axis + offset + newaxis] = slice_length(shape[axis + offset], ax_idx) return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore diff --git a/scico/test/numpy/test_numpy_util.py b/scico/test/numpy/test_numpy_util.py index faab01dc5..c781b56e1 100644 --- a/scico/test/numpy/test_numpy_util.py +++ b/scico/test/numpy/test_numpy_util.py @@ -99,6 +99,10 @@ def test_slice_length_other(length, slc): np.s_[..., 2:], np.s_[..., 2:, :], np.s_[1:, ..., 2:], + np.s_[np.newaxis], + np.s_[:, np.newaxis], + np.s_[np.newaxis, :, np.newaxis], + np.s_[np.newaxis, ..., 0:2, :], ), ) def test_indexed_shape(shape, slc):