From 7c734cb72410f3ac3f16a196468230bd4a49da65 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 15 May 2024 15:43:53 -0600 Subject: [PATCH] Add diagonal operators constructed by replication of a base operator --- scico/linop/__init__.py | 3 +- scico/linop/_stack.py | 31 +++++++++++++- scico/operator/__init__.py | 5 ++- scico/operator/_stack.py | 63 +++++++++++++++++++++++++++- scico/test/linop/test_linop_stack.py | 23 +++++++++- scico/test/operator/test_op_stack.py | 42 ++++++++++++++++++- 6 files changed, 160 insertions(+), 7 deletions(-) diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index ee6c78e36..5e5ba00df 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -17,7 +17,7 @@ from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function from ._linop import ComposedLinearOperator, LinearOperator from ._matrix import MatrixOperator -from ._stack import DiagonalStack, VerticalStack +from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack from ._util import jacobian, operator_norm, power_iteration, valid_adjoint from .xray import Parallel2dProjector, XRayTransform @@ -29,6 +29,7 @@ "FiniteDifference", "SingleAxisFiniteDifference", "Identity", + "DiagonalReplicated", "VerticalStack", "DiagonalStack", "MatrixOperator", diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index db8b1b7b2..663aa7b64 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2022-2023 by SCICO Developers +# Copyright (C) 2022-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,6 +13,7 @@ import scico.numpy as snp from scico.numpy import Array, BlockArray +from scico.operator._stack import DiagonalReplicated as DReplicated from scico.operator._stack import DiagonalStack as DStack from scico.operator._stack import VerticalStack as VStack @@ -142,3 +143,31 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type if self.collapse_input: return snp.stack(result) return snp.blockarray(result) + + +class DiagonalReplicated(DReplicated, LinearOperator): + """ """ + + def __init__( + self, + op: LinearOperator, + replicates: int, + input_axis: int = 0, + output_axis: Optional[int] = None, + map_type: str = "auto", + **kwargs, + ): + + if not isinstance(op, LinearOperator): + raise TypeError("Argument op must be of type LinearOperator.") + + super().__init__( + op, + replicates, + input_axis=input_axis, + output_axis=output_axis, + map_type=map_type, + **kwargs, + ) + + self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis) diff --git a/scico/operator/__init__.py b/scico/operator/__init__.py index fee512369..8d3b01928 100644 --- a/scico/operator/__init__.py +++ b/scico/operator/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2021-2023 by SCICO Developers +# Copyright (C) 2021-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,11 +13,12 @@ from ._operator import Operator from .biconvolve import BiConvolve from ._func import operator_from_function, Abs, Angle, Exp -from ._stack import DiagonalStack, VerticalStack +from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated __all__ = [ "Operator", "BiConvolve", + "DiagonalReplicated", "DiagonalStack", "VerticalStack", "operator_from_function", diff --git a/scico/operator/_stack.py b/scico/operator/_stack.py index 9e16f05ea..7ecfbf8ae 100644 --- a/scico/operator/_stack.py +++ b/scico/operator/_stack.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright (C) 2023 by SCICO Developers +# Copyright (C) 2023-2024 by SCICO Developers # All rights reserved. BSD 3-clause License. # This file is part of the SCICO package. Details of the copyright and # user license can be found in the 'LICENSE' file distributed with the @@ -13,6 +13,8 @@ import numpy as np +import jax + from typing_extensions import TypeGuard import scico.numpy as snp @@ -234,3 +236,62 @@ def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]: if self.collapse_output: return snp.stack(result) return snp.blockarray(result) + + +class DiagonalReplicated(Operator): + """ """ + + def __init__( + self, + op: Operator, + replicates: int, + input_axis: int = 0, + output_axis: Optional[int] = None, + map_type: str = "auto", + **kwargs, + ): + """ """ + if map_type not in ["auto", "pmap", "vmap"]: + raise ValueError("Argument map_type must be one of 'auto', 'pmap, or 'vmap'.") + if input_axis < 0 or input_axis >= len(op.input_shape): + raise ValueError( + "Argument input_axis must be positive and less than the number of axes " + "in the input shape of op." + ) + if is_nested(op.input_shape): + raise ValueError("Argument op may not be an Operator taking BlockArray input.") + self.op = op + self.replicates = replicates + self.input_axis = input_axis + self.output_axis = self.input_axis if output_axis is None else output_axis + + if map_type == "auto": + self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap + else: + if map_type == "pmap" and replicates > jax.device_count(): + raise ValueError( + "Requested pmap mapping but number of replicates exceeds device count." + ) + else: + self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap + + eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis) + + input_shape = ( + op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :] + ) + output_shape = ( + op.output_shape[0 : self.output_axis] + + (replicates,) + + op.output_shape[self.output_axis :] + ) + + super().__init__( + input_shape=input_shape, # type: ignore + output_shape=output_shape, # type: ignore + eval_fn=eval_fn, + input_dtype=op.input_dtype, + output_dtype=op.output_dtype, + jit=False, + **kwargs, + ) diff --git a/scico/test/linop/test_linop_stack.py b/scico/test/linop/test_linop_stack.py index 37f77d4c8..0a2589d85 100644 --- a/scico/test/linop/test_linop_stack.py +++ b/scico/test/linop/test_linop_stack.py @@ -5,8 +5,16 @@ import pytest import scico.numpy as snp -from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack +from scico.linop import ( + Convolve, + DiagonalReplicated, + DiagonalStack, + Identity, + Sum, + VerticalStack, +) from scico.operator import Abs +from scico.random import randn from scico.test.linop.test_linop import adjoint_test @@ -166,3 +174,16 @@ def test_output_collapse(self): H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (S1, S1) + + +class TestDiagonalReplicated: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + def test_adjoint(self): + x, key = randn((2, 3, 4), key=self.key) + A = Sum(x.shape[1:], axis=-1) + D = DiagonalReplicated(A, x.shape[0]) + y = D.T(D(x)) + np.testing.assert_allclose(y[0], A.T(A(x[0]))) + np.testing.assert_allclose(y[1], A.T(A(x[1]))) diff --git a/scico/test/operator/test_op_stack.py b/scico/test/operator/test_op_stack.py index c981cdf26..7e3ed4747 100644 --- a/scico/test/operator/test_op_stack.py +++ b/scico/test/operator/test_op_stack.py @@ -5,7 +5,14 @@ import pytest import scico.numpy as snp -from scico.operator import Abs, DiagonalStack, Operator, VerticalStack +from scico.operator import ( + Abs, + DiagonalReplicated, + DiagonalStack, + Operator, + VerticalStack, +) +from scico.random import randn TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x))) TestOpB = Operator( @@ -140,3 +147,36 @@ def test_output_collapse(self): H = DiagonalStack((A1, A2), collapse_output=False) assert H.output_shape == (A1.output_shape, A1.output_shape) + + +class TestDiagonalReplicated: + def setup_method(self, method): + self.key = jax.random.PRNGKey(12345) + + @pytest.mark.parametrize("map_type", ["auto", "vmap"]) + @pytest.mark.parametrize("input_axis", [0, 1]) + def test_map_auto_vmap(self, input_axis, map_type): + x, key = randn((2, 3, 4), key=self.key) + mapshape = (3, 4) if input_axis == 0 else (2, 4) + replicates = x.shape[input_axis] + A = Abs(mapshape) + D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type) + y = D(x) + assert y.shape[input_axis] == replicates + + @pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test") + def test_map_auto_pmap(self): + x, key = randn((2, 3, 4), key=self.key) + A = Abs(x.shape[1:]) + replicates = x.shape[0] + D = DiagonalReplicated(A, replicates, map_type="pmap") + y = D(x) + assert y.shape[0] == replicates + + def test_output_axis(self): + x, key = randn((2, 3, 4), key=self.key) + A = Abs(x.shape[1:]) + replicates = x.shape[0] + D = DiagonalReplicated(A, replicates, output_axis=1) + y = D(x) + assert y.shape == (3, 2, 4)