Skip to content

Commit

Permalink
Create general function for constructing VerticalStack linop by itera…
Browse files Browse the repository at this point in the history
…ting over axes
  • Loading branch information
bwohlberg committed May 10, 2024
1 parent 19d37c9 commit a75f5be
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 20 deletions.
29 changes: 11 additions & 18 deletions scico/linop/_diff.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -17,11 +17,10 @@
import numpy as np

import scico.numpy as snp
from scico.numpy.util import parse_axes
from scico.typing import Axes, DType, Shape

from ._linop import LinearOperator
from ._stack import VerticalStack
from ._stack import VerticalStack, linop_over_axes


class FiniteDifference(VerticalStack):
Expand Down Expand Up @@ -81,22 +80,16 @@ def __init__(
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the :class:`LinearOperator`.
"""

if axes is None:
axes_list = tuple(range(len(input_shape)))
elif isinstance(axes, (list, tuple)):
axes_list = axes # type: ignore
else:
axes_list = (axes,)
self.axes = parse_axes(axes_list, input_shape)
single_kwargs = dict(
input_dtype=input_dtype, prepend=prepend, append=append, circular=circular, jit=False
self.axes, ops = linop_over_axes(
SingleAxisFiniteDifference,
input_shape,
axes=axes,
input_dtype=input_dtype,
prepend=prepend,
append=append,
circular=circular,
jit=False,
)
ops = [
SingleAxisFiniteDifference(input_shape, axis=axis, **single_kwargs) # type: ignore
for axis in axes_list
]

super().__init__(
ops, # type: ignore
jit=jit,
Expand Down
49 changes: 47 additions & 2 deletions scico/linop/_stack.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -9,12 +9,14 @@

from __future__ import annotations

from typing import Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import parse_axes
from scico.operator._stack import DiagonalStack as DStack
from scico.operator._stack import VerticalStack as VStack
from scico.typing import Axes, Shape

from ._linop import LinearOperator

Expand Down Expand Up @@ -142,3 +144,46 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type
if self.collapse_input:
return snp.stack(result)
return snp.blockarray(result)


def linop_over_axes(
linop: type[LinearOperator],
input_shape: Shape,
*args: Any,
axes: Optional[Axes] = None,
**kwargs: Any,
) -> List[LinearOperator]:
"""Construct a list of :class:`LinearOperator` by iterating over axes.
Construct a list of :class:`LinearOperator` by iterating over a
specified sequence of axes, passing each value in sequence to the
`axis` keyword argument of the :class:`LinearOperator` initializer.
Args:
linop: Type of :class:`LinearOperator` to construct for each axis.
input_shape: Shape of input array.
*args: Positional arguments for the :class:`LinearOperator`
initializer.
axes: Axis or axes over which to construct the list. If not
specified, or ``None``, use all axes corresponding to
`input_shape`.
**kwargs: Keyword arguments for the :class:`LinearOperator`
initializer.
Returns:
A tuple (`axes`, `ops`) where `axes` is a tuple of the axes used
to construct that list of list of :class:`LinearOperator`, and
`ops` is the list itself.
"""
if axes is None:
axes = tuple(range(len(input_shape)))
elif not isinstance(axes, (list, tuple)):
axes = (axes,)
if axes is None:
axis_list = tuple(range(len(input_shape)))
elif isinstance(axes, (list, tuple)):
axis_list = axes # type: ignore
else:
axis_list = (axes,)
axes = parse_axes(axes, input_shape)
return axes, [linop(input_shape, *args, axis=axis, **kwargs) for axis in axes] # type: ignore

0 comments on commit a75f5be

Please sign in to comment.