From a75f5be6a1bca5164343c1f11308ee9da631120c Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 9 May 2024 18:04:49 -0600 Subject: [PATCH] Create general function for constructing VerticalStack linop by iterating over axes --- scico/linop/_diff.py | 29 ++++++++++--------------- scico/linop/_stack.py | 49 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 47c72e3c..5dd019d3 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2020-2023 by SCICO Developers +# Copyright (C) 2020-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -17,11 +17,10 @@ import numpy as np import scico.numpy as snp -from scico.numpy.util import parse_axes from scico.typing import Axes, DType, Shape from ._linop import LinearOperator -from ._stack import VerticalStack +from ._stack import VerticalStack, linop_over_axes class FiniteDifference(VerticalStack): @@ -81,22 +80,16 @@ def __init__( jit: If ``True``, jit the evaluation, adjoint, and gram functions of the :class:`LinearOperator`. """ - - if axes is None: - axes_list = tuple(range(len(input_shape))) - elif isinstance(axes, (list, tuple)): - axes_list = axes # type: ignore - else: - axes_list = (axes,) - self.axes = parse_axes(axes_list, input_shape) - single_kwargs = dict( - input_dtype=input_dtype, prepend=prepend, append=append, circular=circular, jit=False + self.axes, ops = linop_over_axes( + SingleAxisFiniteDifference, + input_shape, + axes=axes, + input_dtype=input_dtype, + prepend=prepend, + append=append, + circular=circular, + jit=False, ) - ops = [ - SingleAxisFiniteDifference(input_shape, axis=axis, **single_kwargs) # type: ignore - for axis in axes_list - ] - super().__init__( ops, # type: ignore jit=jit, diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index db8b1b7b..340eeafc 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -9,12 +9,14 @@ from __future__ import annotations -from typing import Optional, Sequence, Union +from typing import Any, List, Optional, Sequence, Union import scico.numpy as snp from scico.numpy import Array, BlockArray +from scico.numpy.util import parse_axes from scico.operator._stack import DiagonalStack as DStack from scico.operator._stack import VerticalStack as VStack +from scico.typing import Axes, Shape from ._linop import LinearOperator @@ -142,3 +144,46 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type if self.collapse_input: return snp.stack(result) return snp.blockarray(result) + + +def linop_over_axes( + linop: type[LinearOperator], + input_shape: Shape, + *args: Any, + axes: Optional[Axes] = None, + **kwargs: Any, +) -> List[LinearOperator]: + """Construct a list of :class:`LinearOperator` by iterating over axes. + + Construct a list of :class:`LinearOperator` by iterating over a + specified sequence of axes, passing each value in sequence to the + `axis` keyword argument of the :class:`LinearOperator` initializer. + + Args: + linop: Type of :class:`LinearOperator` to construct for each axis. + input_shape: Shape of input array. + *args: Positional arguments for the :class:`LinearOperator` + initializer. + axes: Axis or axes over which to construct the list. If not + specified, or ``None``, use all axes corresponding to + `input_shape`. + **kwargs: Keyword arguments for the :class:`LinearOperator` + initializer. + + Returns: + A tuple (`axes`, `ops`) where `axes` is a tuple of the axes used + to construct that list of list of :class:`LinearOperator`, and + `ops` is the list itself. + """ + if axes is None: + axes = tuple(range(len(input_shape))) + elif not isinstance(axes, (list, tuple)): + axes = (axes,) + if axes is None: + axis_list = tuple(range(len(input_shape))) + elif isinstance(axes, (list, tuple)): + axis_list = axes # type: ignore + else: + axis_list = (axes,) + axes = parse_axes(axes, input_shape) + return axes, [linop(input_shape, *args, axis=axis, **kwargs) for axis in axes] # type: ignore