Skip to content

Commit

Permalink
Support complex numbers in jax.scipy.signal.convolve/correlate
Browse files Browse the repository at this point in the history
  • Loading branch information
lukepfister committed Jun 18, 2021
1 parent 6053890 commit c33388b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
35 changes: 24 additions & 11 deletions jax/_src/scipy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ def convolve(in1, in2, mode='full', method='auto',
precision=None):
if method != 'auto':
warnings.warn("convolve() ignores method argument")
if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
raise NotImplementedError("convolve() does not support complex inputs")
return _convolve_nd(in1, in2, mode, precision=precision)


Expand All @@ -76,8 +74,6 @@ def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
precision=None):
if boundary != 'fill' or fillvalue != 0:
raise NotImplementedError("convolve2d() only supports boundary='fill', fillvalue=0")
if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
raise NotImplementedError("convolve2d() does not support complex inputs")
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
raise ValueError("convolve2d() only supports 2-dimensional inputs.")
return _convolve_nd(in1, in2, mode, precision=precision)
Expand All @@ -88,21 +84,38 @@ def correlate(in1, in2, mode='full', method='auto',
precision=None):
if method != 'auto':
warnings.warn("correlate() ignores method argument")
if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
raise NotImplementedError("correlate() does not support complex inputs")
return _convolve_nd(in1, jnp.flip(in2), mode, precision=precision)
return _convolve_nd(in1, jnp.flip(in2.conj()), mode, precision=precision)


@_wraps(osp_signal.correlate2d)
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
precision=None):
if boundary != 'fill' or fillvalue != 0:
raise NotImplementedError("correlate2d() only supports boundary='fill', fillvalue=0")
if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating):
raise NotImplementedError("correlate2d() does not support complex inputs")
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
raise ValueError("correlate2d() only supports {ndim}-dimensional inputs.")
return _convolve_nd(in1[::-1, ::-1], in2, mode, precision=precision)[::-1, ::-1]
raise ValueError("correlate2d() only supports 2-dimensional inputs.")

swap = all(s1 <= s2 for s1, s2 in zip(in1.shape, in2.shape))
same_shape = all(s1 == s2 for s1, s2 in zip(in1.shape, in2.shape))

if mode == "same":
in1, in2 = in1[::-1, ::-1], in2.conj()
result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
elif mode == "valid":
if swap and not same_shape:
in1, in2 = in2[::-1, ::-1], in1.conj()
result = _convolve_nd(in1, in2, mode, precision=precision)
else:
in1, in2 = in1[::-1, ::-1], in2.conj()
result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
else:
if swap:
in1, in2 = in2[::-1, ::-1], in1.conj()
result = _convolve_nd(in1, in2, mode, precision=precision).conj()
else:
in1, in2 = in1[::-1, ::-1], in2.conj()
result = _convolve_nd(in1, in2, mode, precision=precision)[::-1, ::-1]
return result


@_wraps(osp_signal.detrend)
Expand Down
10 changes: 5 additions & 5 deletions tests/scipy_signal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
threedim_shapes = [(2, 2, 2), (3, 3, 2), (4, 4, 2), (5, 5, 2)]


default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer
default_dtypes = jtu.dtypes.floating + jtu.dtypes.integer + jtu.dtypes.complex


class LaxBackedScipySignalTests(jtu.JaxTestCase):
Expand All @@ -58,9 +58,9 @@ def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-8}
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format(
Expand All @@ -81,7 +81,7 @@ def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
osp_fun = partial(osp_op, mode=mode)
jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-14}
tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-12, np.complex64: 1e-2, np.complex128: 1e-12}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Expand All @@ -91,7 +91,7 @@ def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
"shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
for shape in [(5,), (4, 5), (3, 4, 5)]
for dtype in default_dtypes
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
Expand Down

0 comments on commit c33388b

Please sign in to comment.