Skip to content

Commit

Permalink
[x64] make fft functionality compatible with strict dtype promotion
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 15, 2022
1 parent acc7dc0 commit 297a296
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
15 changes: 4 additions & 11 deletions jax/_src/lax/fft.py
Expand Up @@ -30,20 +30,13 @@
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_client
from jax._src.lib import pocketfft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact

__all__ = [
"fft",
"fft_p",
]

def _promote_to_complex(arg):
dtype = dtypes.result_type(arg, np.complex64)
return lax.convert_element_type(arg, dtype)

def _promote_to_real(arg):
dtype = dtypes.result_type(arg, np.float32)
return lax.convert_element_type(arg, dtype)

def _str_to_fft_type(s: str) -> xla_client.FftType:
if s == "FFT":
return xla_client.FftType.FFT
Expand All @@ -68,9 +61,9 @@ def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int])
if typ == xla_client.FftType.RFFT:
if np.iscomplexobj(x):
raise ValueError("only real valued inputs supported for rfft")
x = _promote_to_real(x)
x, = _promote_dtypes_inexact(x)
else:
x = _promote_to_complex(x)
x, = _promote_dtypes_complex(x)
if len(fft_lengths) == 0:
# XLA FFT doesn't support 0-rank.
return x
Expand Down Expand Up @@ -137,7 +130,7 @@ def _irfft_transpose(t, fft_lengths):
x = fft(t, xla_client.FftType.RFFT, fft_lengths)
n = x.shape[-1]
is_odd = fft_lengths[-1] % 2
full = partial(lax.full_like, t, dtype=t.dtype)
full = partial(lax.full_like, t, dtype=x.dtype)
mask = lax.concatenate(
[full(1.0, shape=(1,)),
full(2.0, shape=(n - 2 + is_odd,)),
Expand Down
19 changes: 11 additions & 8 deletions jax/_src/numpy/fft.py
Expand Up @@ -16,6 +16,7 @@
import operator
import numpy as np

from jax import dtypes
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
Expand Down Expand Up @@ -87,7 +88,7 @@ def _fft_core(func_name, fft_type, a, s, axes, norm):
else:
s = [a.shape[axis] for axis in axes]
transformed = lax.fft(a, fft_type, tuple(s))
transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm)
transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name, norm)

if orig_axes is not None:
transformed = jnp.moveaxis(transformed, axes, orig_axes)
Expand Down Expand Up @@ -199,6 +200,7 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):

@_wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0):
dtype = dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
Expand All @@ -209,26 +211,27 @@ def fftfreq(n, d=1.0):
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))

k = jnp.zeros(n)
k = jnp.zeros(n, dtype=dtype)
if n % 2 == 0:
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
k = k.at[0: n // 2].set( jnp.arange(0, n // 2))
k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype))

# k[n // 2:] = jnp.arange(-n // 2, -1)
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0))
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype))

else:
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1))
k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype))

# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0))
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

return k / (d * n)


@_wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0):
dtype = dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
Expand All @@ -240,10 +243,10 @@ def rfftfreq(n, d=1.0):
"Got d = %s." % list(d))

if n % 2 == 0:
k = jnp.arange(0, n // 2 + 1)
k = jnp.arange(0, n // 2 + 1, dtype=dtype)

else:
k = jnp.arange(0, (n - 1) // 2 + 1)
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)

return k / (d * n)

Expand Down
11 changes: 11 additions & 0 deletions jax/_src/numpy/util.py
Expand Up @@ -275,6 +275,17 @@ def _promote_dtypes_inexact(*args):
for x in args]


def _promote_dtypes_complex(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a complex type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
for x in args]


def _complex_elem_type(dtype):
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/scipy/fft.py
Expand Up @@ -17,9 +17,10 @@
import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used
from jax import lax, numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import _wraps, _promote_dtypes_complex

def _W4(N, k):
N, k = _promote_dtypes_complex(N, k)
return jnp.exp(-.5j * jnp.pi * k / N)

def _dct_interleave(x, axis):
Expand Down Expand Up @@ -49,7 +50,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None):
N = x.shape[axis]
v = _dct_interleave(x, axis)
V = jnp.fft.fft(v, axis=axis)
k = lax.expand_dims(jnp.arange(N), [a for a in range(x.ndim) if a != axis])
k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
out = V * _W4(N, k)
out = 2 * out.real
if norm == 'ortho':
Expand All @@ -62,8 +63,10 @@ def _dct2(x, axes, norm):
N1, N2 = x.shape[axis1], x.shape[axis2]
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
V = jnp.fft.fftn(v, axes=axes)
k1 = lax.expand_dims(jnp.arange(N1), [a for a in range(x.ndim) if a != axis1])
k2 = lax.expand_dims(jnp.arange(N2), [a for a in range(x.ndim) if a != axis2])
k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis1])
k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis2])
out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
out = 2 * out.real
if norm == 'ortho':
Expand Down
8 changes: 6 additions & 2 deletions tests/fft_test.py
Expand Up @@ -23,7 +23,9 @@
import jax
from jax import lax
from jax import numpy as jnp
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import _promote_dtypes_complex

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -111,7 +113,7 @@ def testLaxFftAcceptsStringTypes(self):
def testLaxIrfftDoesNotMutateInputs(self, dtype):
if dtype == np.float64 and not config.x64_enabled:
raise self.skipTest("float64 requires jax_enable_x64=true")
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j)
x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes._to_complex_dtype(dtype))
y = np.asarray(jnp.fft.irfft2(x))
z = np.asarray(jnp.fft.irfft2(x))
self.assertAllClose(y, z)
Expand Down Expand Up @@ -157,7 +159,9 @@ def build_matrix(linear_func, size):
return jax.vmap(linear_func)(jnp.eye(size, size))

def func(x):
return jnp.fft.irfft(jnp.concatenate([jnp.zeros(1), x[:2] + 1j*x[2:]]))
x, = _promote_dtypes_complex(x)
return jnp.fft.irfft(jnp.concatenate([jnp.zeros_like(x, shape=1),
x[:2] + 1j*x[2:]]))

def func_transpose(x):
return jax.linear_transpose(func, x)(x)[0]
Expand Down

0 comments on commit 297a296

Please sign in to comment.