Skip to content

Commit

Permalink
Add axes= keyword to fft dtype sample computation
Browse files Browse the repository at this point in the history
This is important for irfft2 dtype analysis
  • Loading branch information
mrocklin committed Aug 21, 2018
1 parent 9636454 commit 5b25a09
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion dask/array/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ def func(a, s=None, axes=None):
_dtype = dtype
if _dtype is None:
sample = np.ones(a.ndim * (8,), dtype=a.dtype)
_dtype = fft_func(sample).dtype
try:
_dtype = fft_func(sample, axes=axes).dtype
except TypeError:
_dtype = fft_func(sample).dtype

for each_axis in axes:
if len(a.chunks[each_axis]) != 1:
Expand Down

0 comments on commit 5b25a09

Please sign in to comment.