Skip to content

Commit

Permalink
Check for blockarray inputs in from_operator (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Feb 23, 2022
1 parent 54f9b39 commit f7395cf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
9 changes: 9 additions & 0 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import scico.numpy as snp
from scico._generic_operators import Operator
from scico.array import is_nested
from scico.typing import DType, JaxArray, Shape

from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar
Expand Down Expand Up @@ -234,6 +235,7 @@ def __truediv__(self, scalar):
h_is_dft=True,
)

@staticmethod
def from_operator(
H: Operator, ndims: Optional[int] = None, center: Optional[Shape] = None, jit: bool = True
):
Expand All @@ -251,6 +253,13 @@ def from_operator(
jit: If ``True``, jit the resulting `CircularConvolve`.
"""

if is_nested(H.input_shape):
raise ValueError(
f"H.input_shape ({H.input_shape}) suggests that H "
"takes a BlockArray as input, which is not supported "
"by this function."
)

if ndims is None:
ndims = len(H.input_shape)
else:
Expand Down
11 changes: 10 additions & 1 deletion scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import scico.numpy as snp
from scico.linop import CircularConvolve, Convolve
from scico.linop import CircularConvolve, Convolve, Diagonal
from scico.random import randint, randn, uniform
from scico.test.linop.test_linop import adjoint_test

Expand Down Expand Up @@ -138,3 +138,12 @@ def test_from_operator(self, axes_shape_spec, input_dtype, jit_old_op, jit_new_o
B = CircularConvolve.from_operator(A, ndims, jit=jit_new_op)

np.testing.assert_allclose(A @ x, B @ x, atol=1e-5)

def test_from_operator_block_array(self):
"""`from_operator` should throw an exception if asked to work
on an operator with blockarray inputs."""

H = Diagonal(snp.zeros(((1, 2), (3,))))

with pytest.raises(ValueError):
CircularConvolve.from_operator(H)

0 comments on commit f7395cf

Please sign in to comment.