Skip to content

Commit

Permalink
Add diagonal operators constructed by replication of a base operator
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed May 15, 2024
1 parent 71b15ff commit 7c734cb
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 7 deletions.
3 changes: 2 additions & 1 deletion scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function
from ._linop import ComposedLinearOperator, LinearOperator
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack
from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from .xray import Parallel2dProjector, XRayTransform

Expand All @@ -29,6 +29,7 @@
"FiniteDifference",
"SingleAxisFiniteDifference",
"Identity",
"DiagonalReplicated",
"VerticalStack",
"DiagonalStack",
"MatrixOperator",
Expand Down
31 changes: 30 additions & 1 deletion scico/linop/_stack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,6 +13,7 @@

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.operator._stack import DiagonalReplicated as DReplicated
from scico.operator._stack import DiagonalStack as DStack
from scico.operator._stack import VerticalStack as VStack

Expand Down Expand Up @@ -142,3 +143,31 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type
if self.collapse_input:
return snp.stack(result)
return snp.blockarray(result)


class DiagonalReplicated(DReplicated, LinearOperator):
""" """

def __init__(
self,
op: LinearOperator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):

if not isinstance(op, LinearOperator):
raise TypeError("Argument op must be of type LinearOperator.")

super().__init__(
op,
replicates,
input_axis=input_axis,
output_axis=output_axis,
map_type=map_type,
**kwargs,
)

self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis)
5 changes: 3 additions & 2 deletions scico/operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,11 +13,12 @@
from ._operator import Operator
from .biconvolve import BiConvolve
from ._func import operator_from_function, Abs, Angle, Exp
from ._stack import DiagonalStack, VerticalStack
from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated

__all__ = [
"Operator",
"BiConvolve",
"DiagonalReplicated",
"DiagonalStack",
"VerticalStack",
"operator_from_function",
Expand Down
63 changes: 62 additions & 1 deletion scico/operator/_stack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# Copyright (C) 2023-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,6 +13,8 @@

import numpy as np

import jax

from typing_extensions import TypeGuard

import scico.numpy as snp
Expand Down Expand Up @@ -234,3 +236,62 @@ def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
if self.collapse_output:
return snp.stack(result)
return snp.blockarray(result)


class DiagonalReplicated(Operator):
""" """

def __init__(
self,
op: Operator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):
""" """
if map_type not in ["auto", "pmap", "vmap"]:
raise ValueError("Argument map_type must be one of 'auto', 'pmap, or 'vmap'.")
if input_axis < 0 or input_axis >= len(op.input_shape):
raise ValueError(
"Argument input_axis must be positive and less than the number of axes "
"in the input shape of op."
)
if is_nested(op.input_shape):
raise ValueError("Argument op may not be an Operator taking BlockArray input.")
self.op = op
self.replicates = replicates
self.input_axis = input_axis
self.output_axis = self.input_axis if output_axis is None else output_axis

if map_type == "auto":
self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap
else:
if map_type == "pmap" and replicates > jax.device_count():
raise ValueError(
"Requested pmap mapping but number of replicates exceeds device count."
)
else:
self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap

eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis)

input_shape = (
op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :]
)
output_shape = (
op.output_shape[0 : self.output_axis]
+ (replicates,)
+ op.output_shape[self.output_axis :]
)

super().__init__(
input_shape=input_shape, # type: ignore
output_shape=output_shape, # type: ignore
eval_fn=eval_fn,
input_dtype=op.input_dtype,
output_dtype=op.output_dtype,
jit=False,
**kwargs,
)
23 changes: 22 additions & 1 deletion scico/test/linop/test_linop_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
import pytest

import scico.numpy as snp
from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack
from scico.linop import (
Convolve,
DiagonalReplicated,
DiagonalStack,
Identity,
Sum,
VerticalStack,
)
from scico.operator import Abs
from scico.random import randn
from scico.test.linop.test_linop import adjoint_test


Expand Down Expand Up @@ -166,3 +174,16 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (S1, S1)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

def test_adjoint(self):
x, key = randn((2, 3, 4), key=self.key)
A = Sum(x.shape[1:], axis=-1)
D = DiagonalReplicated(A, x.shape[0])
y = D.T(D(x))
np.testing.assert_allclose(y[0], A.T(A(x[0])))
np.testing.assert_allclose(y[1], A.T(A(x[1])))
42 changes: 41 additions & 1 deletion scico/test/operator/test_op_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import pytest

import scico.numpy as snp
from scico.operator import Abs, DiagonalStack, Operator, VerticalStack
from scico.operator import (
Abs,
DiagonalReplicated,
DiagonalStack,
Operator,
VerticalStack,
)
from scico.random import randn

TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x)))
TestOpB = Operator(
Expand Down Expand Up @@ -140,3 +147,36 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (A1.output_shape, A1.output_shape)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("map_type", ["auto", "vmap"])
@pytest.mark.parametrize("input_axis", [0, 1])
def test_map_auto_vmap(self, input_axis, map_type):
x, key = randn((2, 3, 4), key=self.key)
mapshape = (3, 4) if input_axis == 0 else (2, 4)
replicates = x.shape[input_axis]
A = Abs(mapshape)
D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type)
y = D(x)
assert y.shape[input_axis] == replicates

@pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test")
def test_map_auto_pmap(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, map_type="pmap")
y = D(x)
assert y.shape[0] == replicates

def test_output_axis(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, output_axis=1)
y = D(x)
assert y.shape == (3, 2, 4)

0 comments on commit 7c734cb

Please sign in to comment.