Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of h_center parameter in linop.CircularConvolve #299

Merged
merged 3 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import math
import operator
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Sequence, Tuple, Union

import numpy as np

from jax.dtypes import result_type

from jaxlib.xla_extension import DeviceArray

import scico.numpy as snp
from scico._generic_operators import Operator
from scico.numpy.util import is_nested
Expand All @@ -27,10 +29,10 @@
class CircularConvolve(LinearOperator):
r"""A circular convolution linear operator.

This linear operator implements circular, n-dimensional convolution
via pointwise multiplication in the DFT domain. In its simplest form,
it implements a single convolution and can be represented by linear
operator :math:`H` such that
This linear operator implements circular, multi-dimensional
convolution via pointwise multiplication in the DFT domain. In its
simplest form, it implements a single convolution and can be
represented by linear operator :math:`H` such that

.. math::
H \mb{x} = \mb{h} \ast \mb{x} \;,
Expand Down Expand Up @@ -83,7 +85,7 @@ def __init__(
ndims: Optional[int] = None,
input_dtype: DType = snp.float32,
h_is_dft: bool = False,
h_center: Optional[JaxArray] = None,
h_center: Optional[Union[JaxArray, Sequence, float, int]] = None,
jit: bool = True,
**kwargs,
):
Expand All @@ -99,7 +101,8 @@ def __init__(
h_is_dft: Flag indicating whether `h` is in the DFT domain.
h_center: Array of length `ndims` specifying the center of
the filter. Defaults to the upper left corner, i.e.,
`h_center = [0, 0, ..., 0]`, may be noninteger.
`h_center = [0, 0, ..., 0]`, may be noninteger. May be a
``float`` or ``int`` if `h` is one-dimensional.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""
Expand All @@ -124,7 +127,18 @@ def __init__(
output_dtype = result_type(h.dtype, input_dtype)

if self.h_center is not None:
offset = -self.h_center
if isinstance(self.h_center, DeviceArray):
offset = -self.h_center
else:
# support float or int values for h_center
if isinstance(self.h_center, (float, int)):
offset = -snp.array(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatting looks slightly odd here but I assume that's black's doing

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the relatively deeply nested if/else structure? That's just a natural consequence of the actual logic here.

[
self.h_center,
]
)
else: # support list/tuple values for h_center
offset = -snp.array(self.h_center)
shifts: Tuple[Array, ...] = np.ix_(
*tuple(
np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s))
Expand Down
21 changes: 21 additions & 0 deletions scico/test/linop/test_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,27 @@ def test_matches_convolve(self, input_dtype, jit):
desired = A @ x
np.testing.assert_allclose(actual, desired, atol=1e-6)

@pytest.mark.parametrize(
"center",
[
1,
[
1,
],
snp.array([2]),
],
)
def test_center(self, center):
x, key = uniform(minval=-1, maxval=1, shape=(16,), key=self.key)
h = snp.array([0.5, 1.0, 0.25])
A = CircularConvolve(h=h, input_shape=x.shape, h_center=center)
B = CircularConvolve(h=h, input_shape=x.shape)
if isinstance(center, int):
shift = -center
else:
shift = -center[0]
np.testing.assert_allclose(A @ x, snp.roll(B @ x, shift), atol=1e-5)

@pytest.mark.parametrize("axes_shape_spec", SHAPE_SPECS)
@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("jit_old_op", [True, False])
Expand Down