From 3bc7037c72e31bb7888aefe1b3e3c92e931bcaeb Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 13 Sep 2023 12:56:16 +0200 Subject: [PATCH 01/50] implement fftn, first draft --- heat/__init__.py | 1 + heat/fft/__init__.py | 5 ++ heat/fft/fft.py | 115 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 heat/fft/__init__.py create mode 100644 heat/fft/fft.py diff --git a/heat/__init__.py b/heat/__init__.py index 6f6d7959dd..84c4afc11b 100644 --- a/heat/__init__.py +++ b/heat/__init__.py @@ -9,6 +9,7 @@ from . import core from . import classification from . import cluster +from . import fft from . import graph from . import naive_bayes from . import nn diff --git a/heat/fft/__init__.py b/heat/fft/__init__.py new file mode 100644 index 0000000000..da3b4f307e --- /dev/null +++ b/heat/fft/__init__.py @@ -0,0 +1,5 @@ +""" +import the graph functions into the graph namespace +""" + +from .fft import * diff --git a/heat/fft/fft.py b/heat/fft/fft.py new file mode 100644 index 0000000000..27d8fa9eb8 --- /dev/null +++ b/heat/fft/fft.py @@ -0,0 +1,115 @@ +"""Provides a collection of Discrete Fast Fourier Transforms (DFFT) and their inverses.""" + +import torch + +from ..core.communication import MPI +from ..core.dndarray import DNDarray +from ..core.stride_tricks import sanitize_axis +from ..core.types import promote_types, heat_type_of +from ..core.factories import array, zeros + +from typing import Type, Union, Tuple, Any, Iterable, Optional + +__all__ = ["fftn"] + +# TODO: implement __fft_op to deal with the different operations + + +def fftn( + x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional discrete Fourier Transform. + + This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed, with the real transform + performed over the last axis, while the remaining transforms are complex. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also + not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + (default is None) + norm : str, optional + Normalization mode (see `numpy.fft` for details). (default is None) + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + try: + local_x = x.larray + except AttributeError: + raise TypeError("x must be a DNDarray, is {}".format(type(x))) + + # check if axes are valid + axes = sanitize_axis(x.gshape, axes) + split = x.split + + # non-distributed DNDarray + if not x.is_distributed(): + result = torch.fft.fftn(local_x, s=s, dim=axes, norm=norm) + return array(result, split=x.split, device=x.device, comm=x.comm) + + # distributed DNDarray: + # calculate output shape + output_shape = list(x.shape) + if s is not None: + if axes is None: + axes = tuple(range(x.ndim)[-len(s) :]) + for i, axis in enumerate(axes): + output_shape[axis] = s[i] + else: + s = tuple(output_shape[axis] for axis in axes) + output_shape = tuple(output_shape) + + fft_along_split = x.split in axes + + # FFT along non-split axes only + if not fft_along_split: + result = torch.fft.fftn(local_x, s=s, dim=axes, norm=norm) + return DNDarray( + result, + gshape=output_shape, + dtype=heat_type_of(result), + split=x.split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) + + # FFT along split axis + if split != 0: + # transpose x so redistribution starts from axis 0 + transpose_axes = list(range(x.ndim)) + transpose_axes[0], transpose_axes[split] = transpose_axes[split], transpose_axes[0] + x = x.transpose(transpose_axes) + + # redistribute x from axis 0 to 1 + _ = x.resplit(axis=1) + # FFT along axis 0 (now non-split) + split_index = axes.index(split) + partial_result = fftn(_, s=(s[split_index],), axes=(0,), norm=norm) + del _ + # redistribute partial result from axis 1 to 0 + partial_result.resplit_(axis=0) + if split != 0: + # transpose x, partial_result back to original shape + x = x.transpose(transpose_axes) + partial_result = partial_result.transpose(transpose_axes) + + # now apply FFT along leftover (non-split) axes + axes = list(axes) + axes.remove(split) + axes = tuple(axes) + s = list(s) + s = s[:split_index] + s[split_index + 1 :] + s = tuple(s) + result = fftn(partial_result, s=s, axes=axes, norm=norm) + del partial_result + return array(result, is_split=split, device=x.device, comm=x.comm) From 5190e0074d9a6c8ca0f21c54c117e1566da0443d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 14 Sep 2023 05:46:12 +0200 Subject: [PATCH 02/50] implement general , add --- heat/fft/fft.py | 128 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 85 insertions(+), 43 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 27d8fa9eb8..61ddefc1e7 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -10,51 +10,30 @@ from typing import Type, Union, Tuple, Any, Iterable, Optional -__all__ = ["fftn"] +__all__ = ["fft2", "fftn"] -# TODO: implement __fft_op to deal with the different operations - -def fftn( - x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None -) -> DNDarray: +# TODO: implement __fft_op, __fftn_op to deal with the different operations +def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: """ - Compute the N-dimensional discrete Fourier Transform. - - This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional - array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed, with the real transform - performed over the last axis, while the remaining transforms are complex. - - Parameters - ---------- - x : DNDarray - Input array, can be complex - s : Tuple[int, ...], optional - Shape of the output along the transformed axes. (default is x.shape) - axes : Tuple[int, ...], optional - Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also - not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. - (default is None) - norm : str, optional - Normalization mode (see `numpy.fft` for details). (default is None) - - Notes - ----- - This function requires MPI communication if the input array is distributed and the split axis is transformed. + Helper function for fftn """ try: local_x = x.larray except AttributeError: raise TypeError("x must be a DNDarray, is {}".format(type(x))) + original_split = x.split - # check if axes are valid + # sanitize kwargs + axes = kwargs.get("axes", None) axes = sanitize_axis(x.gshape, axes) - split = x.split + s = kwargs.get("s", None) + norm = kwargs.get("norm", None) # non-distributed DNDarray if not x.is_distributed(): - result = torch.fft.fftn(local_x, s=s, dim=axes, norm=norm) - return array(result, split=x.split, device=x.device, comm=x.comm) + result = fftn_op(local_x, s=s, dim=axes, norm=norm) + return array(result, split=original_split, device=x.device, comm=x.comm) # distributed DNDarray: # calculate output shape @@ -68,48 +47,111 @@ def fftn( s = tuple(output_shape[axis] for axis in axes) output_shape = tuple(output_shape) - fft_along_split = x.split in axes + fft_along_split = original_split in axes # FFT along non-split axes only if not fft_along_split: - result = torch.fft.fftn(local_x, s=s, dim=axes, norm=norm) + result = fftn_op(local_x, s=s, dim=axes, norm=norm) return DNDarray( result, gshape=output_shape, dtype=heat_type_of(result), - split=x.split, + split=original_split, device=x.device, comm=x.comm, balanced=x.balanced, ) # FFT along split axis - if split != 0: + if original_split != 0: # transpose x so redistribution starts from axis 0 transpose_axes = list(range(x.ndim)) - transpose_axes[0], transpose_axes[split] = transpose_axes[split], transpose_axes[0] + transpose_axes[0], transpose_axes[original_split] = ( + transpose_axes[original_split], + transpose_axes[0], + ) x = x.transpose(transpose_axes) # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) - split_index = axes.index(split) - partial_result = fftn(_, s=(s[split_index],), axes=(0,), norm=norm) + split_index = axes.index(original_split) + partial_result = __fftn_op(_, fftn_op, s=(s[split_index],), axes=(0,), norm=norm) del _ # redistribute partial result from axis 1 to 0 partial_result.resplit_(axis=0) - if split != 0: + if original_split != 0: # transpose x, partial_result back to original shape x = x.transpose(transpose_axes) partial_result = partial_result.transpose(transpose_axes) # now apply FFT along leftover (non-split) axes axes = list(axes) - axes.remove(split) + axes.remove(original_split) axes = tuple(axes) s = list(s) s = s[:split_index] + s[split_index + 1 :] s = tuple(s) - result = fftn(partial_result, s=s, axes=axes, norm=norm) + result = __fftn_op(partial_result, fftn_op, s=s, axes=axes, norm=norm) del partial_result - return array(result, is_split=split, device=x.device, comm=x.comm) + return array(result.larray, is_split=original_split, device=x.device, comm=x.comm) + + +def fft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional discrete Fourier Transform. + + This function computes the 2-dimensional discrete Fourier Transform over the specified axes in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, the last two axes are transformed, while the + remaining axes are left unchanged. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, int], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, int], optional + Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also + not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + (default is (-2, -1)) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.fft2, s=s, axes=axes, norm=norm) + + +def fftn( + x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional discrete Fourier Transform. + + This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed, with the real transform + performed over the last axis, while the remaining transforms are complex. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also + not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + (default is None) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.fftn, s=s, axes=axes, norm=norm) From e0e48de56b58e04d5ddd7e5baa9ced81b4db2c0a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:35:17 +0200 Subject: [PATCH 03/50] split fft_op and fftn_op, implement inverse and real fft --- heat/fft/fft.py | 357 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 355 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 61ddefc1e7..773cd58b26 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -10,10 +10,95 @@ from typing import Type, Union, Tuple, Any, Iterable, Optional -__all__ = ["fft2", "fftn"] +__all__ = [ + "fft", + "fft2", + "fftn", + "ifft", + "ifft2", + "ifftn", + "rfft", + "rfft2", + "rfftn", + "irfft", + "irfft2", + "irfftn", +] + + +def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: + """ + Helper function for fft + """ + try: + local_x = x.larray + except AttributeError: + raise TypeError("x must be a DNDarray, is {}".format(type(x))) + original_split = x.split + + # sanitize kwargs + axis = kwargs.get("axis", None) + axis = sanitize_axis(x.gshape, axis) + n = kwargs.get("n", None) + norm = kwargs.get("norm", None) + + # non-distributed DNDarray + if not x.is_distributed(): + result = fft_op(local_x, n=n, dim=axis, norm=norm) + return array(result, split=original_split, device=x.device, comm=x.comm) + + # distributed DNDarray: + # calculate output shape + output_shape = list(x.shape) + if n is not None: + if axis is None: + output_shape[-1] = n + else: + output_shape[axis] = n + + fft_along_split = original_split == axis + + # FFT along non-split axis + if not fft_along_split: + result = fft_op(local_x, n=n, dim=axis, norm=norm) + return DNDarray( + result, + gshape=output_shape, + dtype=heat_type_of(result), + split=original_split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) + + # FFT along split axis + if original_split != 0: + # transpose x so redistribution starts from axis 0 + transpose_axes = list(range(x.ndim)) + transpose_axes[0], transpose_axes[original_split] = ( + transpose_axes[original_split], + transpose_axes[0], + ) + x = x.transpose(transpose_axes) + + # redistribute x + if x.ndim > 1: + _ = x.resplit(axis=1) + else: + _ = x.resplit(axis=None) + # FFT along axis 0 (now non-split) + result = __fft_op(_, fft_op, n=n, axis=0, norm=norm) + del _ + # redistribute partial result back to axis 0 + result.resplit_(axis=0) + if original_split != 0: + # transpose x, partial_result back to original shape + x = x.transpose(transpose_axes) + result = result.transpose(transpose_axes) + + return result -# TODO: implement __fft_op, __fftn_op to deal with the different operations def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: """ Helper function for fftn @@ -27,6 +112,9 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # sanitize kwargs axes = kwargs.get("axes", None) axes = sanitize_axis(x.gshape, axes) + repeated_axes = axes is not None and len(axes) != len(set(axes)) + if repeated_axes: + raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") s = kwargs.get("s", None) norm = kwargs.get("norm", None) @@ -97,6 +185,37 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: return array(result.larray, is_split=original_split, device=x.device, comm=x.comm) +def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional discrete Fourier Transform. + + This function computes the one-dimensional discrete Fourier Transform over the specified axis in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, the last axis is transformed, while the remaining + axes are left unchanged. + + Parameters + ---------- + x : DNDarray + Input array, can be complex. WARNING: If x is 1-D and distributed, the entire array is copied on each MPI process. + n : int, optional + Length of the transformed axis of the output. If not given, the length is taken to be the length of the input + along the axis specified by axis. If `n` is smaller than the length of the input, the input is cropped. If `n` is + larger, the input is padded with zeros. Default: None. + axis : int, optional + Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one + dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. + Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`fft2`). + """ + return __fft_op(x, torch.fft.fft, n=n, axis=axis, norm=norm) + + def fft2( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None ) -> DNDarray: @@ -155,3 +274,237 @@ def fftn( This function requires MPI communication if the input array is distributed and the split axis is transformed. """ return __fftn_op(x, torch.fft.fftn, s=s, axes=axes, norm=norm) + + +def ifft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional inverse discrete Fourier Transform. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + n : int, optional + Length of the transformed axis of the output. If not given, the length is taken to be the length of the input + along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is cropped. If `n` is + larger, the input is padded with zeros. Default: None. + axis : int, optional + Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. + Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`ifft2`). + """ + return __fft_op(x, torch.fft.ifft, n=n, axis=axis, norm=norm) + + +def ifft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional inverse discrete Fourier Transform. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, int], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, int], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is (-2, -1)) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.ifft2, s=s, axes=axes, norm=norm) + + +def ifftn( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional inverse discrete Fourier Transform. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is None) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.ifftn, s=s, axes=axes, norm=norm) + + +def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional inverse discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + n : int, optional + Length of the transformed axis of the output. If not given, the length is taken to be the length of the input + along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is cropped. If `n` is + larger, the input is padded with zeros. Default: None. + axis : int, optional + Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. + Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`irfft2`). + """ + return __fft_op(x, torch.fft.irfft, n=n, axis=axis, norm=norm) + + +def irfft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional inverse discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, int], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, int], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is (-2, -1)) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.irfft2, s=s, axes=axes, norm=norm) + + +def irfftn( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional inverse discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, can be complex + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is None) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.irfftn, s=s, axes=axes, norm=norm) + + +def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, must be float. + n : int, optional + Length of the transformed axis of the output. If not given, the length is taken to be the length of the input + along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is cropped. If `n` is + larger, the input is padded with zeros. Default: None. + axis : int, optional + Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. + Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`rfft2`). + """ + return __fft_op(x, torch.fft.rfft, n=n, axis=axis, norm=norm) + + +def rfft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, must be float. + s : Tuple[int, int], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, int], optional + Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is (-2, -1)) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.rfft2, s=s, axes=axes, norm=norm) + + +def rfftn( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional discrete Fourier Transform for real input. + + Parameters + ---------- + x : DNDarray + Input array, must be float. + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is None) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __fftn_op(x, torch.fft.rfftn, s=s, axes=axes, norm=norm) From 9a1cb989d84c718c96657f46bac0a8ea42456fbf Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:43:59 +0200 Subject: [PATCH 04/50] add TODO hermitian fft --- heat/fft/fft.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 773cd58b26..8b24748348 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -23,6 +23,16 @@ "irfft", "irfft2", "irfftn", + # "hfft", + # "hfft2", + # "hfftn", + # "ihfft", + # "ihfft2", + # "ihfftn", + # "fftfreq", + # "rfftfreq", + # "fftshift", + # "ifftshift", ] From 05af2441dd285efe6635fdd56cde92cb1daf3b56 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:44:56 +0200 Subject: [PATCH 05/50] add fft tests first draft --- heat/fft/tests/__init__.py | 0 heat/fft/tests/test_fft.py | 43 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 heat/fft/tests/__init__.py create mode 100644 heat/fft/tests/test_fft.py diff --git a/heat/fft/tests/__init__.py b/heat/fft/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py new file mode 100644 index 0000000000..202aec6750 --- /dev/null +++ b/heat/fft/tests/test_fft.py @@ -0,0 +1,43 @@ +import numpy as np +import torch + +import heat as ht +from heat.core.tests.test_suites.basic_test import TestCase + + +class TestFFT(TestCase): + def test_fft(self): + pass + + def test_ifft(self): + pass + + def test_rfft(self): + pass + + def test_irfft(self): + pass + + def test_fft2(self): + pass + + def test_ifft2(self): + pass + + def test_rfft2(self): + pass + + def test_irfft2(self): + pass + + def test_fftn(self): + pass + + def test_ifftn(self): + pass + + def test_rfftn(self): + pass + + def test_irfftn(self): + pass From 2ce5a09449cabcb3b86a8a1657d32cf7c75ea79b Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:42:08 +0200 Subject: [PATCH 06/50] implement tests first draft --- heat/fft/fft.py | 4 ++-- heat/fft/tests/test_fft.py | 45 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 8b24748348..05dd2d831d 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -73,7 +73,7 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: result = fft_op(local_x, n=n, dim=axis, norm=norm) return DNDarray( result, - gshape=output_shape, + gshape=tuple(output_shape), dtype=heat_type_of(result), split=original_split, device=x.device, @@ -152,7 +152,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: result = fftn_op(local_x, s=s, dim=axes, norm=norm) return DNDarray( result, - gshape=output_shape, + gshape=tuple(output_shape), dtype=heat_type_of(result), split=original_split, device=x.device, diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 202aec6750..c7d1da476e 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -7,10 +7,51 @@ class TestFFT(TestCase): def test_fft(self): - pass + # 1D non-distributed + x = ht.random.randn(6) + y = ht.fft.fft(x) + np_y = np.fft.fft(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assert_array_equal(y, np_y) + + # n-D distributed + x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last axis + y = ht.fft.fft(x) + np_y = np.fft.fft(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # FFT along distributed axis + y = ht.fft.fft(x, axis=0) + np_y = np.fft.fft(x.numpy(), axis=0) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # complex input + x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last axis (distributed) + x.resplit_(axis=2) + y = ht.fft.fft(x) + np_y = np.fft.fft(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 2) + self.assert_array_equal(y, np_y) def test_ifft(self): - pass + # 1D non-distributed + x = ht.random.randn(6) + x_fft = ht.fft.fft(x) + y = ht.fft.ifft(x_fft) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assert_array_equal(y, x.numpy()) def test_rfft(self): pass From a3cdcc3991b216953da59310d18c7fc5a6171b05 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:55:08 +0200 Subject: [PATCH 07/50] update fft/__init__.py --- heat/fft/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/fft/__init__.py b/heat/fft/__init__.py index da3b4f307e..bb7f51f486 100644 --- a/heat/fft/__init__.py +++ b/heat/fft/__init__.py @@ -1,5 +1,5 @@ """ -import the graph functions into the graph namespace +import the fft functions into the fft namespace """ from .fft import * From e5e015cea364e6cce02739c485eea83788f45e35 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 11 Oct 2023 11:55:38 +0200 Subject: [PATCH 08/50] expand tests fft --- heat/fft/tests/test_fft.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index c7d1da476e..9f7de3de4e 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -15,6 +15,15 @@ def test_fft(self): self.assertEqual(y.shape, x.shape) self.assert_array_equal(y, np_y) + # 1D distributed + x = ht.random.randn(6, split=0) + y = ht.fft.fft(x) + np_y = np.fft.fft(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + # n-D distributed x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) # FFT along last axis @@ -25,11 +34,12 @@ def test_fft(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - # FFT along distributed axis - y = ht.fft.fft(x, axis=0) - np_y = np.fft.fft(x.numpy(), axis=0) + # FFT along distributed axis, n not None + n = 8 + y = ht.fft.fft(x, axis=0, n=n) + np_y = np.fft.fft(x.numpy(), axis=0, n=n) self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) + self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) @@ -37,13 +47,19 @@ def test_fft(self): x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) # FFT along last axis (distributed) x.resplit_(axis=2) - y = ht.fft.fft(x) - np_y = np.fft.fft(x.numpy()) + y = ht.fft.fft(x, n=n) + np_y = np.fft.fft(x.numpy(), n=n) self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) + self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 2) self.assert_array_equal(y, np_y) + # exceptions + # wrong input type + x = np.random.randn(6, 3, 3) + with self.assertRaises(TypeError): + ht.fft.fft(x) + def test_ifft(self): # 1D non-distributed x = ht.random.randn(6) From c76d61c9081edad502d80f31d2b2db3c6a90a2ed Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 11 Oct 2023 12:33:30 +0200 Subject: [PATCH 09/50] expand tests ffftn --- heat/fft/fft.py | 11 +++++++++++ heat/fft/tests/test_fft.py | 19 +++++++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 05dd2d831d..e827c5c0de 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -117,6 +117,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: local_x = x.larray except AttributeError: raise TypeError("x must be a DNDarray, is {}".format(type(x))) + original_split = x.split # sanitize kwargs @@ -126,6 +127,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: if repeated_axes: raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") s = kwargs.get("s", None) + s = sanitize_axis(x.gshape, s) norm = kwargs.get("norm", None) # non-distributed DNDarray @@ -142,6 +144,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: for i, axis in enumerate(axes): output_shape[axis] = s[i] else: + axes = tuple(range(x.ndim)) s = tuple(output_shape[axis] for axis in axes) output_shape = tuple(output_shape) @@ -170,6 +173,14 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: ) x = x.transpose(transpose_axes) + # original split is 0 and fft is along axis 0 + if x.ndim == 1: + _ = x.resplit(axis=None) + result = __fftn_op(_, fftn_op, **kwargs) + del _ + result.resplit_(axis=0) + return result + # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 9f7de3de4e..afb00c402a 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -62,7 +62,7 @@ def test_fft(self): def test_ifft(self): # 1D non-distributed - x = ht.random.randn(6) + x = ht.random.randn(6, dtype=ht.float64) x_fft = ht.fft.fft(x) y = ht.fft.ifft(x_fft) self.assertIsInstance(y, ht.DNDarray) @@ -88,7 +88,22 @@ def test_irfft2(self): pass def test_fftn(self): - pass + # 1D non-distributed + x = ht.random.randn(6) + y = ht.fft.fftn(x) + np_y = np.fft.fftn(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assert_array_equal(y, np_y) + + # 1D distributed + x = ht.random.randn(6, split=0) + y = ht.fft.fftn(x) + np_y = np.fft.fftn(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) def test_ifftn(self): pass From 2be8a24aee5fcb8272b5c5b07d9dba9b58d0b5e4 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 13 Oct 2023 17:10:39 +0200 Subject: [PATCH 10/50] expand tests and fix errors --- heat/fft/fft.py | 5 ++++- heat/fft/tests/test_fft.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index e827c5c0de..76dbf69033 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -127,7 +127,10 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: if repeated_axes: raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") s = kwargs.get("s", None) - s = sanitize_axis(x.gshape, s) + if s is not None and len(s) > x.ndim: + raise ValueError( + f"Input is {x.ndim}-dimensional, so s can be at most {x.ndim} elements long. Got {len(s)} elements instead." + ) norm = kwargs.get("norm", None) # non-distributed DNDarray diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index afb00c402a..6bd45bbc8f 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -105,6 +105,34 @@ def test_fftn(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) + # n-D distributed + x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last 2 axes + y = ht.fft.fftn(x, s=(6, 6)) + np_y = np.fft.fftn(x.numpy(), s=(6, 6)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # FFT along distributed axis + y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) + np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # exceptions + # wrong input type + x = torch.randn(6, 3, 3) + with self.assertRaises(TypeError): + ht.fft.fftn(x) + # s larger than dimensions + x = ht.random.randn(6, 3, 3, split=0) + with self.assertRaises(ValueError): + ht.fft.fftn(x, s=(10, 10, 10, 10)) + def test_ifftn(self): pass From 74ae80edaa03a60b2ae51541ed54b313e3be2f87 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 13 Oct 2023 18:05:01 +0200 Subject: [PATCH 11/50] add Hermitian FFTs --- heat/fft/fft.py | 116 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 107 insertions(+), 9 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 76dbf69033..e7ccfa5169 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -14,21 +14,21 @@ "fft", "fft2", "fftn", + "hfft", + "hfft2", + "hfftn", "ifft", "ifft2", "ifftn", - "rfft", - "rfft2", - "rfftn", - "irfft", - "irfft2", - "irfftn", - # "hfft", - # "hfft2", - # "hfftn", # "ihfft", # "ihfft2", # "ihfftn", + "irfft", + "irfft2", + "irfftn", + "rfft", + "rfft2", + "rfftn", # "fftfreq", # "rfftfreq", # "fftshift", @@ -300,6 +300,104 @@ def fftn( return __fftn_op(x, torch.fft.fftn, s=s, axes=axes, norm=norm) +def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional discrete Fourier Transform of a Hermitian symmetric signal. + + This function computes the one-dimensional discrete Fourier Transform over the specified axis in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, the last axis is transformed, while the remaining + axes are left unchanged. The input signal is assumed to be Hermitian-symmetric, i.e. `x[..., i] = x[..., -i].conj()`. + + Parameters + ---------- + x : DNDarray + Input array + n : int, optional + Length of the transformed axis of the output. + If `n` is not None, the input array is either zero-padded or trimmed to length `n` before the transform. + Default: `2 * (x.shape[axis] - 1)`. + axis : int, optional + Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one + dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + """ + if n is None: + n = 2 * (x.shape[axis] - 1) + return __fft_op(x, torch.fft.hfft, n=n, axis=axis, norm=norm) + + +def hfft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional discrete Fourier Transform of a Hermitian symmetric signal. + + This function computes the 2-dimensional discrete Fourier Transform over the specified axes in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, the last two axes are transformed, while the + remaining axes are left unchanged. The input signal is assumed to be Hermitian-symmetric, i.e. `x[..., i] = x[..., -i].conj()`. + + Parameters + ---------- + x : DNDarray + Input array + s : Tuple[int, int], optional + Shape of the signal along the transformed axes. If `s` is specified, the input array is either zero-padded or trimmed to length `s` before the transform. + If `s` is not given, the last dimension defaults to even output: `s[-1] = 2 * (x.shape[-1] - 1)`. + axes : Tuple[int, int], optional + Axes over which to compute the FFT. If not given, the last two dimensions are transformed. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + if s is None: + s = (x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1)) + return __fftn_op(x, torch.fft.hfft2, s=s, axes=axes, norm=norm) + + +def hfftn( + x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional discrete Fourier Transform of a Hermitian symmetric signal. + + This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional + array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed. + + Parameters + ---------- + x : DNDarray + Input array + s : Tuple[int, ...], optional + Shape of the signal along the transformed axes. If `s` is specified, the input array is either zero-padded or trimmed to length `s` before the transform. + If `s` is not given, the last dimension defaults to even output: `s[-1] = 2 * (x.shape[-1] - 1)`. + axes : Tuple[int, ...], optional + Axes over which to compute the FFT. If not given, all dimensions are transformed. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + if s is None: + if axes is not None: + s = list(x.shape[axis] for axis in axes) + else: + s = list(x.shape) + s[-1] = 2 * (s[-1] - 1) + s = tuple(s) + + return __fftn_op(x, torch.fft.hfftn, s=s, axes=axes, norm=norm) + + def ifft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ Compute the one-dimensional inverse discrete Fourier Transform. From 35b0d43f5979230543ff5b88389fad4648084427 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 13 Oct 2023 18:07:35 +0200 Subject: [PATCH 12/50] heat/fft/tests/test_fft.py --- heat/fft/tests/test_fft.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 6bd45bbc8f..b203239797 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -116,11 +116,12 @@ def test_fftn(self): self.assert_array_equal(y, np_y) # FFT along distributed axis + x.resplit_(axis=1) y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) self.assertIsInstance(y, ht.DNDarray) self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 0) + self.assertTrue(y.split == 1) self.assert_array_equal(y, np_y) # exceptions From f578522d3fa76a4c2fe9eb665184e9bf32245aa8 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:26:41 +0200 Subject: [PATCH 13/50] raise IndexError, not ValueError, when axes don't match dimensions --- heat/fft/fft.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index e7ccfa5169..3087af6980 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -48,7 +48,10 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: # sanitize kwargs axis = kwargs.get("axis", None) - axis = sanitize_axis(x.gshape, axis) + try: + axis = sanitize_axis(x.gshape, axis) + except ValueError as e: + raise IndexError(e) n = kwargs.get("n", None) norm = kwargs.get("norm", None) @@ -122,7 +125,10 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # sanitize kwargs axes = kwargs.get("axes", None) - axes = sanitize_axis(x.gshape, axes) + try: + axes = sanitize_axis(x.gshape, axes) + except ValueError as e: + raise IndexError(e) repeated_axes = axes is not None and len(axes) != len(set(axes)) if repeated_axes: raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") From 23fea9e73ef393034db7b98ec01875e1043a2724 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:27:31 +0200 Subject: [PATCH 14/50] expand tests --- heat/fft/tests/test_fft.py | 41 +++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index b203239797..dab70e5200 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -17,20 +17,22 @@ def test_fft(self): # 1D distributed x = ht.random.randn(6, split=0) - y = ht.fft.fft(x) - np_y = np.fft.fft(x.numpy()) + n = 8 + y = ht.fft.fft(x, n=n) + np_y = np.fft.fft(x.numpy(), n=n) self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) + self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) # n-D distributed x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) # FFT along last axis - y = ht.fft.fft(x) - np_y = np.fft.fft(x.numpy()) + n = 5 + y = ht.fft.fft(x, n=n) + np_y = np.fft.fft(x.numpy(), n=n) self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) + self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) @@ -60,6 +62,30 @@ def test_fft(self): with self.assertRaises(TypeError): ht.fft.fft(x) + def test_fft2(self): + # 2D FFT along non-split axes + x = ht.random.randn(10, 6, 6, split=0) + y = ht.fft.fft2(x) + np_y = np.fft.fft2(x.numpy()) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # 2D FFT along split axes + x = ht.random.randn(10, 6, 6, split=0) + axes = (0, 1) + y = ht.fft.fft2(x, axes=axes) + np_y = np.fft.fft2(x.numpy(), axes=axes) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # exceptions + x = ht.arange(10, split=0) + with self.assertRaises(IndexError): + ht.fft.fft2(x) + def test_ifft(self): # 1D non-distributed x = ht.random.randn(6, dtype=ht.float64) @@ -75,9 +101,6 @@ def test_rfft(self): def test_irfft(self): pass - def test_fft2(self): - pass - def test_ifft2(self): pass From 75e289f0c821831de0982de9ef85cdc9cb57ceae Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:28:29 +0200 Subject: [PATCH 15/50] edit error message for better understanding --- heat/core/stride_tricks.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index 2424ad23c9..10010d6f15 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -151,7 +151,10 @@ def sanitize_axis( """ # scalars are handled like unsplit matrices - if len(shape) == 0: + original_axis = axis.copy() if isinstance(axis, list) else axis + ndim = len(shape) + + if ndim == 0: axis = None if axis is not None and not isinstance(axis, int) and not isinstance(axis, tuple): @@ -160,7 +163,9 @@ def sanitize_axis( axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis) for dim in axis: if dim < 0 or dim >= len(shape): - raise ValueError(f"axis {axis} is out of bounds for shape {shape}") + raise ValueError( + f"axis {original_axis} is out of bounds for {ndim}-dimensional array" + ) return axis if axis is None or 0 <= axis < len(shape): @@ -169,7 +174,7 @@ def sanitize_axis( axis += len(shape) if axis < 0 or axis >= len(shape): - raise ValueError(f"axis {axis} is out of bounds for shape {shape}") + raise ValueError(f"axis {original_axis} is out of bounds for {ndim}-dimensional array") return axis From 02b0a296d4c21f760e063b7575337050a4e8a090 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:24:47 +0200 Subject: [PATCH 16/50] replace == with allclose for 2D FFTs --- heat/fft/tests/test_fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index dab70e5200..37337acad2 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -79,7 +79,7 @@ def test_fft2(self): np_y = np.fft.fft2(x.numpy(), axes=axes) self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) - self.assert_array_equal(y, np_y) + self.assertTrue(ht.allclose(y, np_y)) # exceptions x = ht.arange(10, split=0) From 47a2fe856a70718cd7d5d08caaed85a51087317f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:46:26 +0200 Subject: [PATCH 17/50] fix error --- heat/fft/tests/test_fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 37337acad2..e21812c0f4 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -79,7 +79,7 @@ def test_fft2(self): np_y = np.fft.fft2(x.numpy(), axes=axes) self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) - self.assertTrue(ht.allclose(y, np_y)) + self.assertTrue(ht.allclose(y, ht.array(np_y, split=y.split))) # exceptions x = ht.arange(10, split=0) From e2c10ddefa17f7c98ec5abf09425986b451ef279 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:26:33 +0200 Subject: [PATCH 18/50] remove redundant communication --- heat/core/tests/test_suites/basic_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_suites/basic_test.py b/heat/core/tests/test_suites/basic_test.py index 4ef3871419..965eda75f9 100644 --- a/heat/core/tests/test_suites/basic_test.py +++ b/heat/core/tests/test_suites/basic_test.py @@ -129,11 +129,11 @@ def assert_array_equal(self, heat_array, expected_array): f"Local shapes do not match. Got {heat_array.lshape} expected {expected_array[slices].shape}", ) # compare local tensors to corresponding slice of expected_array - is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices]) - ht_is_allclose = ht.array( - [is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device + is_allclose = torch.tensor( + np.allclose(heat_array.larray.cpu(), expected_array[slices]), dtype=torch.int32 ) - self.assertTrue(ht.all(ht_is_allclose)) + heat_array.comm.Allreduce(MPI.IN_PLACE, is_allclose, MPI.SUM) + self.assertTrue(is_allclose == heat_array.comm.size) def assert_func_equal( self, From 3f0eac6de206158051ec5c89cf584340239089ea Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:27:37 +0200 Subject: [PATCH 19/50] remove redundant tests --- heat/fft/tests/test_fft.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index e21812c0f4..b294037753 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -67,8 +67,6 @@ def test_fft2(self): x = ht.random.randn(10, 6, 6, split=0) y = ht.fft.fft2(x) np_y = np.fft.fft2(x.numpy()) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) @@ -77,9 +75,8 @@ def test_fft2(self): axes = (0, 1) y = ht.fft.fft2(x, axes=axes) np_y = np.fft.fft2(x.numpy(), axes=axes) - self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) - self.assertTrue(ht.allclose(y, ht.array(np_y, split=y.split))) + self.assert_array_equal(y, np_y) # exceptions x = ht.arange(10, split=0) From 585bac3589b725b3f4a545908370432c7f84cf0f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:28:11 +0200 Subject: [PATCH 20/50] fix bug in axes handling --- heat/fft/fft.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 3087af6980..c6342c7ff8 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -153,7 +153,8 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: for i, axis in enumerate(axes): output_shape[axis] = s[i] else: - axes = tuple(range(x.ndim)) + if axes is None: + axes = tuple(range(x.ndim)) s = tuple(output_shape[axis] for axis in axes) output_shape = tuple(output_shape) From 432d6ff3a644db4ee18e2456e790b94a7a54b1c5 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:09:48 +0200 Subject: [PATCH 21/50] test hermitian FFT --- heat/fft/tests/test_fft.py | 64 ++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index b294037753..d5183a1cd6 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -83,30 +83,6 @@ def test_fft2(self): with self.assertRaises(IndexError): ht.fft.fft2(x) - def test_ifft(self): - # 1D non-distributed - x = ht.random.randn(6, dtype=ht.float64) - x_fft = ht.fft.fft(x) - y = ht.fft.ifft(x_fft) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) - self.assert_array_equal(y, x.numpy()) - - def test_rfft(self): - pass - - def test_irfft(self): - pass - - def test_ifft2(self): - pass - - def test_rfft2(self): - pass - - def test_irfft2(self): - pass - def test_fftn(self): # 1D non-distributed x = ht.random.randn(6) @@ -154,6 +130,46 @@ def test_fftn(self): with self.assertRaises(ValueError): ht.fft.fftn(x, s=(10, 10, 10, 10)) + def test_hfft(self): + # follows example in torch.fft.hfft docs + x = ht.zeros((3, 5), split=0) + edges = [1, 3, 7] + for i, n in enumerate(edges): + x[i] = ht.linspace(0, n, 5) + + inv_fft = ht.fft.ifft(x) + # inv_fft is hermitian symmetric along the rows + # we can reconstruct the original signal by transforming the first half of the rows only + reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) + self.assertTrue(ht.allclose(reconstructed_x, x)) + n = 2 * (x.shape[-1] - 1) + reconstructed_x = ht.fft.hfft(inv_fft[:3]) + self.assertEqual(reconstructed_x.shape, (3, n)) + + def test_ifft(self): + # 1D non-distributed + x = ht.random.randn(6, dtype=ht.float64) + x_fft = ht.fft.fft(x) + y = ht.fft.ifft(x_fft) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, x.shape) + self.assert_array_equal(y, x.numpy()) + + def test_rfft(self): + pass + + def test_irfft(self): + pass + + def test_ifft2(self): + pass + + def test_rfft2(self): + pass + + def test_irfft2(self): + pass + def test_ifftn(self): pass From 51e7b3319b20e14373b33e1e476b79d2a1b14573 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:11:16 +0200 Subject: [PATCH 22/50] cast numpy fft2 to complex64 --- heat/fft/tests/test_fft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index d5183a1cd6..effc818525 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -66,7 +66,7 @@ def test_fft2(self): # 2D FFT along non-split axes x = ht.random.randn(10, 6, 6, split=0) y = ht.fft.fft2(x) - np_y = np.fft.fft2(x.numpy()) + np_y = np.fft.fft2(x.numpy()).astype(np.complex64) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) @@ -74,7 +74,7 @@ def test_fft2(self): x = ht.random.randn(10, 6, 6, split=0) axes = (0, 1) y = ht.fft.fft2(x, axes=axes) - np_y = np.fft.fft2(x.numpy(), axes=axes) + np_y = np.fft.fft2(x.numpy(), axes=axes).astype(np.complex64) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) From f241144f3183b275716e4fd79b88cd3b93d0969e Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:28:55 +0200 Subject: [PATCH 23/50] expand tests --- heat/fft/tests/test_fft.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index effc818525..0b6c36ef78 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -61,20 +61,27 @@ def test_fft(self): x = np.random.randn(6, 3, 3) with self.assertRaises(TypeError): ht.fft.fft(x) + # axis out of range + x = ht.random.randn(6, 3, 3) + with self.assertRaises(IndexError): + ht.fft.fft(x, axis=3) + # n-D axes + with self.assertRaises(TypeError): + ht.fft.fft(x, axis=(0, 1)) def test_fft2(self): # 2D FFT along non-split axes - x = ht.random.randn(10, 6, 6, split=0) + x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) y = ht.fft.fft2(x) - np_y = np.fft.fft2(x.numpy()).astype(np.complex64) + np_y = np.fft.fft2(x.numpy()) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) # 2D FFT along split axes - x = ht.random.randn(10, 6, 6, split=0) + x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) axes = (0, 1) y = ht.fft.fft2(x, axes=axes) - np_y = np.fft.fft2(x.numpy(), axes=axes).astype(np.complex64) + np_y = np.fft.fft2(x.numpy(), axes=axes) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) From eee000131d1c028eefc77e5e0df8cffd16212ec2 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:29:23 +0200 Subject: [PATCH 24/50] edit error messages --- heat/fft/fft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index c6342c7ff8..34e6147473 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -43,7 +43,7 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: try: local_x = x.larray except AttributeError: - raise TypeError("x must be a DNDarray, is {}".format(type(x))) + raise TypeError(f"x must be a DNDarray, is {type(x)}") original_split = x.split # sanitize kwargs @@ -52,6 +52,8 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: axis = sanitize_axis(x.gshape, axis) except ValueError as e: raise IndexError(e) + if isinstance(axis, tuple) and len(axis) > 1: + raise TypeError(f"axis must be an integer, got {axis}") n = kwargs.get("n", None) norm = kwargs.get("norm", None) From acfbea3dbfe1955144e0217c2961d65dd8f7a611 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:01:48 +0200 Subject: [PATCH 25/50] remove unnecessary axis check --- heat/fft/fft.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 34e6147473..f62067107a 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -66,10 +66,7 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: # calculate output shape output_shape = list(x.shape) if n is not None: - if axis is None: - output_shape[-1] = n - else: - output_shape[axis] = n + output_shape[axis] = n fft_along_split = original_split == axis From 6d6b0fd9b9753ac061b71e66dc5aa0b3a5ec2624 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 14:02:42 +0200 Subject: [PATCH 26/50] test inverse ffts as well --- heat/fft/tests/test_fft.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 0b6c36ef78..97cfa8cf59 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -6,7 +6,7 @@ class TestFFT(TestCase): - def test_fft(self): + def test_fft_ifft(self): # 1D non-distributed x = ht.random.randn(6) y = ht.fft.fft(x) @@ -14,6 +14,8 @@ def test_fft(self): self.assertIsInstance(y, ht.DNDarray) self.assertEqual(y.shape, x.shape) self.assert_array_equal(y, np_y) + backwards = ht.fft.ifft(y) + self.assertTrue(ht.allclose(backwards, x)) # 1D distributed x = ht.random.randn(6, split=0) @@ -69,13 +71,15 @@ def test_fft(self): with self.assertRaises(TypeError): ht.fft.fft(x, axis=(0, 1)) - def test_fft2(self): + def test_fft2_ifft2(self): # 2D FFT along non-split axes x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) y = ht.fft.fft2(x) np_y = np.fft.fft2(x.numpy()) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) + backwards = ht.fft.ifft2(y) + self.assertTrue(ht.allclose(backwards, x)) # 2D FFT along split axes x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) @@ -84,13 +88,15 @@ def test_fft2(self): np_y = np.fft.fft2(x.numpy(), axes=axes) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) + backwards = ht.fft.ifft2(y, axes=axes) + self.assertTrue(ht.allclose(backwards, x)) # exceptions x = ht.arange(10, split=0) with self.assertRaises(IndexError): ht.fft.fft2(x) - def test_fftn(self): + def test_fftn_ifftn(self): # 1D non-distributed x = ht.random.randn(6) y = ht.fft.fftn(x) @@ -98,6 +104,8 @@ def test_fftn(self): self.assertIsInstance(y, ht.DNDarray) self.assertEqual(y.shape, x.shape) self.assert_array_equal(y, np_y) + backwards = ht.fft.ifftn(y) + self.assertTrue(ht.allclose(backwards, x)) # 1D distributed x = ht.random.randn(6, split=0) @@ -137,9 +145,9 @@ def test_fftn(self): with self.assertRaises(ValueError): ht.fft.fftn(x, s=(10, 10, 10, 10)) - def test_hfft(self): + def test_hfft_ihfft(self): # follows example in torch.fft.hfft docs - x = ht.zeros((3, 5), split=0) + x = ht.zeros((3, 5), split=0, dtype=ht.float64) edges = [1, 3, 7] for i, n in enumerate(edges): x[i] = ht.linspace(0, n, 5) From 1cd62f7bfb7c89884ec067089b047f7369c2df7a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:25:43 +0200 Subject: [PATCH 27/50] skip comm-intensive tests on gpu --- heat/fft/tests/test_fft.py | 122 +++++++++++++++++++------------------ 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 97cfa8cf59..48ab091bfa 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -38,25 +38,27 @@ def test_fft_ifft(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - # FFT along distributed axis, n not None - n = 8 - y = ht.fft.fft(x, axis=0, n=n) - np_y = np.fft.fft(x.numpy(), axis=0, n=n) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 0) - self.assert_array_equal(y, np_y) - - # complex input - x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) - # FFT along last axis (distributed) - x.resplit_(axis=2) - y = ht.fft.fft(x, n=n) - np_y = np.fft.fft(x.numpy(), n=n) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 2) - self.assert_array_equal(y, np_y) + # on GPU, test only on less than 4 processes + if x.device == ht.cpu or x.comm.size < 4: + # FFT along distributed axis, n not None + n = 8 + y = ht.fft.fft(x, axis=0, n=n) + np_y = np.fft.fft(x.numpy(), axis=0, n=n) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # complex input + x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last axis (distributed) + x.resplit_(axis=2) + y = ht.fft.fft(x, n=n) + np_y = np.fft.fft(x.numpy(), n=n) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 2) + self.assert_array_equal(y, np_y) # exceptions # wrong input type @@ -81,15 +83,16 @@ def test_fft2_ifft2(self): backwards = ht.fft.ifft2(y) self.assertTrue(ht.allclose(backwards, x)) - # 2D FFT along split axes - x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) - axes = (0, 1) - y = ht.fft.fft2(x, axes=axes) - np_y = np.fft.fft2(x.numpy(), axes=axes) - self.assertTrue(y.split == 0) - self.assert_array_equal(y, np_y) - backwards = ht.fft.ifft2(y, axes=axes) - self.assertTrue(ht.allclose(backwards, x)) + if x.device == ht.cpu or x.comm.size < 4: + # 2D FFT along split axes + x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) + axes = (0, 1) + y = ht.fft.fft2(x, axes=axes) + np_y = np.fft.fft2(x.numpy(), axes=axes) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + backwards = ht.fft.ifft2(y, axes=axes) + self.assertTrue(ht.allclose(backwards, x)) # exceptions x = ht.arange(10, split=0) @@ -126,14 +129,16 @@ def test_fftn_ifftn(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - # FFT along distributed axis - x.resplit_(axis=1) - y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) - np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 1) - self.assert_array_equal(y, np_y) + # on GPU, test only on less than 4 processes + if x.device == ht.cpu or x.comm.size < 4: + # FFT along distributed axis + x.resplit_(axis=1) + y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) + np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 1) + self.assert_array_equal(y, np_y) # exceptions # wrong input type @@ -161,32 +166,29 @@ def test_hfft_ihfft(self): reconstructed_x = ht.fft.hfft(inv_fft[:3]) self.assertEqual(reconstructed_x.shape, (3, n)) - def test_ifft(self): - # 1D non-distributed - x = ht.random.randn(6, dtype=ht.float64) - x_fft = ht.fft.fft(x) - y = ht.fft.ifft(x_fft) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, x.shape) - self.assert_array_equal(y, x.numpy()) - - def test_rfft(self): - pass - - def test_irfft(self): - pass - - def test_ifft2(self): - pass - - def test_rfft2(self): - pass + def test_hfftn_ihfftn(self): + # follows example in torch.fft.hfftn docs + x = ht.random.randn(10, 6, 6, dtype=ht.float64) + inv_fft = ht.fft.ifftn(x) + reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) + self.assertTrue(ht.allclose(reconstructed_x, x)) - def test_irfft2(self): - pass + def test_rfft_irfft(self): + # n-D distributed + x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last axis + y = ht.fft.fft(x) + np_y = np.fft.fft(x.numpy()) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + backwards = ht.fft.irfft(y, n=x.shape[-1]) + self.assertTrue(ht.allclose(backwards, x)) - def test_ifftn(self): - pass + # exceptions + # complex input + x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + with self.assertRaises(TypeError): + ht.fft.rfft(x) def test_rfftn(self): pass From ba73d6280131dd620c01d0e4687d2c3ecccecfdd Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:26:20 +0200 Subject: [PATCH 28/50] introduce helper functions for real fft operations --- heat/fft/fft.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index f62067107a..ccc7ec0bb6 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -215,6 +215,28 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: return array(result.larray, is_split=original_split, device=x.device, comm=x.comm) +def __real_fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: + try: + result = __fft_op(x, fft_op, **kwargs) + except RuntimeError as e: + if "real input tensor" in str(e): + raise TypeError(f"Input array must be real, is {x.dtype}.") + else: + raise e + return result + + +def __real_fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: + try: + result = __fftn_op(x, fftn_op, **kwargs) + except RuntimeError as e: + if "real input tensor" in str(e): + raise TypeError(f"Input array must be real, is {x.dtype}.") + else: + raise e + return result + + def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ Compute the one-dimensional discrete Fourier Transform. @@ -482,6 +504,28 @@ def ifftn( return __fftn_op(x, torch.fft.ifftn, s=s, axes=axes, norm=norm) +def ihfftn( + x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None +) -> DNDarray: + """ + Compute the N-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. + + Parameters + ---------- + x : DNDarray + Input array, must be real + s : Tuple[int, ...], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, ...], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is None) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + """ + return __real_fftn_op(x, torch.fft.ihfftn, s=s, axes=axes, norm=norm) + + def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ Compute the one-dimensional inverse discrete Fourier Transform for real input. @@ -583,7 +627,7 @@ def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`rfft2`). """ - return __fft_op(x, torch.fft.rfft, n=n, axis=axis, norm=norm) + return __real_fft_op(x, torch.fft.rfft, n=n, axis=axis, norm=norm) def rfft2( @@ -609,7 +653,7 @@ def rfft2( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - return __fftn_op(x, torch.fft.rfft2, s=s, axes=axes, norm=norm) + return __real_fftn_op(x, torch.fft.rfft2, s=s, axes=axes, norm=norm) def rfftn( @@ -635,4 +679,4 @@ def rfftn( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - return __fftn_op(x, torch.fft.rfftn, s=s, axes=axes, norm=norm) + return __real_fftn_op(x, torch.fft.rfftn, s=s, axes=axes, norm=norm) From d5267006d79616c3fcf21113a050b0d6fbf8f470 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:17:09 +0200 Subject: [PATCH 29/50] fix output shape wrt Nyquist frequency --- heat/fft/fft.py | 158 ++++++++++++++++++++++-------------------------- 1 file changed, 72 insertions(+), 86 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index ccc7ec0bb6..f765ca26e3 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -52,31 +52,23 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: axis = sanitize_axis(x.gshape, axis) except ValueError as e: raise IndexError(e) - if isinstance(axis, tuple) and len(axis) > 1: + if axis is None: + axis = x.ndim - 1 + elif isinstance(axis, tuple) and len(axis) > 1: raise TypeError(f"axis must be an integer, got {axis}") n = kwargs.get("n", None) norm = kwargs.get("norm", None) - # non-distributed DNDarray - if not x.is_distributed(): - result = fft_op(local_x, n=n, dim=axis, norm=norm) - return array(result, split=original_split, device=x.device, comm=x.comm) - - # distributed DNDarray: - # calculate output shape - output_shape = list(x.shape) - if n is not None: - output_shape[axis] = n - fft_along_split = original_split == axis - + output_shape = list(x.shape) # FFT along non-split axis - if not fft_along_split: - result = fft_op(local_x, n=n, dim=axis, norm=norm) + if not x.is_distributed() or not fft_along_split: + torch_result = fft_op(local_x, n=n, dim=axis, norm=norm) + output_shape[axis] = torch_result.shape[axis] return DNDarray( - result, + torch_result, gshape=tuple(output_shape), - dtype=heat_type_of(result), + dtype=heat_type_of(torch_result), split=original_split, device=x.device, comm=x.comm, @@ -99,16 +91,16 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: else: _ = x.resplit(axis=None) # FFT along axis 0 (now non-split) - result = __fft_op(_, fft_op, n=n, axis=0, norm=norm) + ht_result = __fft_op(_, fft_op, n=n, axis=0, norm=norm) del _ # redistribute partial result back to axis 0 - result.resplit_(axis=0) + ht_result.resplit_(axis=0) if original_split != 0: # transpose x, partial_result back to original shape x = x.transpose(transpose_axes) - result = result.transpose(transpose_axes) + ht_result = ht_result.transpose(transpose_axes) - return result + return ht_result def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: @@ -138,34 +130,23 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: ) norm = kwargs.get("norm", None) - # non-distributed DNDarray - if not x.is_distributed(): - result = fftn_op(local_x, s=s, dim=axes, norm=norm) - return array(result, split=original_split, device=x.device, comm=x.comm) - - # distributed DNDarray: - # calculate output shape - output_shape = list(x.shape) - if s is not None: - if axes is None: + if axes is None: + if s is not None: axes = tuple(range(x.ndim)[-len(s) :]) - for i, axis in enumerate(axes): - output_shape[axis] = s[i] - else: - if axes is None: + else: axes = tuple(range(x.ndim)) - s = tuple(output_shape[axis] for axis in axes) - output_shape = tuple(output_shape) fft_along_split = original_split in axes - + output_shape = list(x.shape) # FFT along non-split axes only - if not fft_along_split: - result = fftn_op(local_x, s=s, dim=axes, norm=norm) + if not x.is_distributed() or not fft_along_split: + torch_result = fftn_op(local_x, s=s, dim=axes, norm=norm) + for axis in axes: + output_shape[axis] = torch_result.shape[axis] return DNDarray( - result, + torch_result, gshape=tuple(output_shape), - dtype=heat_type_of(result), + dtype=heat_type_of(torch_result), split=original_split, device=x.device, comm=x.comm, @@ -185,56 +166,73 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # original split is 0 and fft is along axis 0 if x.ndim == 1: _ = x.resplit(axis=None) - result = __fftn_op(_, fftn_op, **kwargs) + ht_result = __fftn_op(_, fftn_op, **kwargs).resplit_(axis=0) del _ - result.resplit_(axis=0) - return result + return ht_result + + # transform decomposition: split axis first, then the rest + # if fft operation requires real input, switch to generic operation: + real_to_generic_fftn_ops = { + torch.fft.rfftn: torch.fft.fftn, + torch.fft.rfft2: torch.fft.fft2, + torch.fft.ihfftn: torch.fft.ifftn, + torch.fft.ihfft2: torch.fft.ifft2, + } + real_op = fftn_op in real_to_generic_fftn_ops + if real_op: + fftn_op = real_to_generic_fftn_ops[fftn_op] + nyquist_axis = axes[-1] + if s is not None: + nyquist_freq = s[-1] // 2 + 1 + else: + nyquist_freq = x.shape[-1] // 2 + 1 # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) split_index = axes.index(original_split) - partial_result = __fftn_op(_, fftn_op, s=(s[split_index],), axes=(0,), norm=norm) + if s is not None: + partial_s = (s[split_index],) + else: + partial_s = None + partial_ht_result = __fftn_op(_, fftn_op, s=partial_s, axes=(0,), norm=norm) + output_shape[original_split] = partial_ht_result.shape[0] del _ # redistribute partial result from axis 1 to 0 - partial_result.resplit_(axis=0) + partial_ht_result.resplit_(axis=0) if original_split != 0: - # transpose x, partial_result back to original shape + # transpose x, partial_ht_result back to original shape x = x.transpose(transpose_axes) - partial_result = partial_result.transpose(transpose_axes) + partial_ht_result = partial_ht_result.transpose(transpose_axes) # now apply FFT along leftover (non-split) axes axes = list(axes) axes.remove(original_split) axes = tuple(axes) - s = list(s) - s = s[:split_index] + s[split_index + 1 :] - s = tuple(s) - result = __fftn_op(partial_result, fftn_op, s=s, axes=axes, norm=norm) - del partial_result - return array(result.larray, is_split=original_split, device=x.device, comm=x.comm) + if s is not None: + s = list(s) + s = s[:split_index] + s[split_index + 1 :] + s = tuple(s) + ht_result = __fftn_op(partial_ht_result, fftn_op, s=s, axes=axes, norm=norm) + del partial_ht_result + if real_op: + # discard elements beyond Nyquist frequency on last transformed axis + nyquist_slice = [slice(None)] * ht_result.ndim + nyquist_slice[nyquist_axis] = slice(0, nyquist_freq) + ht_result = ht_result[(nyquist_slice)].balance_() + return ht_result def __real_fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: - try: - result = __fft_op(x, fft_op, **kwargs) - except RuntimeError as e: - if "real input tensor" in str(e): - raise TypeError(f"Input array must be real, is {x.dtype}.") - else: - raise e - return result + if x.larray.is_complex(): + raise TypeError(f"Input array must be real, is {x.dtype}.") + return __fft_op(x, fft_op, **kwargs) def __real_fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: - try: - result = __fftn_op(x, fftn_op, **kwargs) - except RuntimeError as e: - if "real input tensor" in str(e): - raise TypeError(f"Input array must be real, is {x.dtype}.") - else: - raise e - return result + if x.larray.is_complex(): + raise TypeError(f"Input array must be real, is {x.dtype}.") + return __fftn_op(x, fftn_op, **kwargs) def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: @@ -354,8 +352,6 @@ def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr ----- This function requires MPI communication if the input array is transformed along the distribution axis. """ - if n is None: - n = 2 * (x.shape[axis] - 1) return __fft_op(x, torch.fft.hfft, n=n, axis=axis, norm=norm) @@ -385,8 +381,6 @@ def hfft2( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - if s is None: - s = (x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1)) return __fftn_op(x, torch.fft.hfft2, s=s, axes=axes, norm=norm) @@ -415,14 +409,6 @@ def hfftn( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ - if s is None: - if axes is not None: - s = list(x.shape[axis] for axis in axes) - else: - s = list(x.shape) - s[-1] = 2 * (s[-1] - 1) - s = tuple(s) - return __fftn_op(x, torch.fft.hfftn, s=s, axes=axes, norm=norm) @@ -611,7 +597,7 @@ def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr Parameters ---------- x : DNDarray - Input array, must be float. + Input array, must be real. n : int, optional Length of the transformed axis of the output. If not given, the length is taken to be the length of the input along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is cropped. If `n` is @@ -639,7 +625,7 @@ def rfft2( Parameters ---------- x : DNDarray - Input array, must be float. + Input array, must be real. s : Tuple[int, int], optional Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, int], optional @@ -665,7 +651,7 @@ def rfftn( Parameters ---------- x : DNDarray - Input array, must be float. + Input array, must be real. s : Tuple[int, ...], optional Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, ...], optional From e27243e4376e483609d0a7e2ffb8336c00ffd9aa Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:17:44 +0200 Subject: [PATCH 30/50] double precision tests --- heat/fft/tests/test_fft.py | 46 +++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 48ab091bfa..ef7b3b4e4e 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -8,7 +8,7 @@ class TestFFT(TestCase): def test_fft_ifft(self): # 1D non-distributed - x = ht.random.randn(6) + x = ht.random.randn(6, dtype=ht.float64) y = ht.fft.fft(x) np_y = np.fft.fft(x.numpy()) self.assertIsInstance(y, ht.DNDarray) @@ -83,16 +83,15 @@ def test_fft2_ifft2(self): backwards = ht.fft.ifft2(y) self.assertTrue(ht.allclose(backwards, x)) - if x.device == ht.cpu or x.comm.size < 4: - # 2D FFT along split axes - x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) - axes = (0, 1) - y = ht.fft.fft2(x, axes=axes) - np_y = np.fft.fft2(x.numpy(), axes=axes) - self.assertTrue(y.split == 0) - self.assert_array_equal(y, np_y) - backwards = ht.fft.ifft2(y, axes=axes) - self.assertTrue(ht.allclose(backwards, x)) + # 2D FFT along split axes + x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) + axes = (0, 1) + y = ht.fft.fft2(x, axes=axes) + np_y = np.fft.fft2(x.numpy(), axes=axes) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + backwards = ht.fft.ifft2(y, axes=axes) + self.assertTrue(ht.allclose(backwards, x)) # exceptions x = ht.arange(10, split=0) @@ -190,8 +189,25 @@ def test_rfft_irfft(self): with self.assertRaises(TypeError): ht.fft.rfft(x) - def test_rfftn(self): - pass + def test_rfftn_irfftn(self): + # n-D distributed + x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last 2 axes + y = ht.fft.rfftn(x, axes=(1, 2)) + np_y = np.fft.rfftn(x.numpy(), axes=(1, 2)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + # FFT along all axes + # TODO: comment this out after merging indexing PR + # y = ht.fft.rfftn(x) + # backwards = ht.fft.irfftn(y, s=x.shape) + # self.assertTrue(ht.allclose(backwards, x)) - def test_irfftn(self): - pass + # exceptions + # complex input + x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + with self.assertRaises(TypeError): + ht.fft.rfftn(x) From 60eb72e724c036219bb2c8132a25af71e8e8afe9 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:53:13 +0200 Subject: [PATCH 31/50] debugging: introduce synchronization --- heat/fft/fft.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index f765ca26e3..0da8eb77e6 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -65,15 +65,16 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: if not x.is_distributed() or not fft_along_split: torch_result = fft_op(local_x, n=n, dim=axis, norm=norm) output_shape[axis] = torch_result.shape[axis] - return DNDarray( - torch_result, - gshape=tuple(output_shape), - dtype=heat_type_of(torch_result), - split=original_split, - device=x.device, - comm=x.comm, - balanced=x.balanced, - ) + # return DNDarray( + # torch_result, + # gshape=tuple(output_shape), + # dtype=heat_type_of(torch_result), + # split=original_split, + # device=x.device, + # comm=x.comm, + # balanced=x.balanced, + # ) + return array(torch_result, is_split=original_split, comm=x.comm) # FFT along split axis if original_split != 0: @@ -143,15 +144,16 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: torch_result = fftn_op(local_x, s=s, dim=axes, norm=norm) for axis in axes: output_shape[axis] = torch_result.shape[axis] - return DNDarray( - torch_result, - gshape=tuple(output_shape), - dtype=heat_type_of(torch_result), - split=original_split, - device=x.device, - comm=x.comm, - balanced=x.balanced, - ) + # return DNDarray( + # torch_result, + # gshape=tuple(output_shape), + # dtype=heat_type_of(torch_result), + # split=original_split, + # device=x.device, + # comm=x.comm, + # balanced=x.balanced, + # ) + return array(torch_result, is_split=original_split, comm=x.comm) # FFT along split axis if original_split != 0: From c56169a2587bc191521a8666bd4fb28fffaf4588 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:19:52 +0200 Subject: [PATCH 32/50] debugging --- heat/fft/tests/test_fft.py | 44 +++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index ef7b3b4e4e..2d9eff7982 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -149,28 +149,28 @@ def test_fftn_ifftn(self): with self.assertRaises(ValueError): ht.fft.fftn(x, s=(10, 10, 10, 10)) - def test_hfft_ihfft(self): - # follows example in torch.fft.hfft docs - x = ht.zeros((3, 5), split=0, dtype=ht.float64) - edges = [1, 3, 7] - for i, n in enumerate(edges): - x[i] = ht.linspace(0, n, 5) - - inv_fft = ht.fft.ifft(x) - # inv_fft is hermitian symmetric along the rows - # we can reconstruct the original signal by transforming the first half of the rows only - reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) - self.assertTrue(ht.allclose(reconstructed_x, x)) - n = 2 * (x.shape[-1] - 1) - reconstructed_x = ht.fft.hfft(inv_fft[:3]) - self.assertEqual(reconstructed_x.shape, (3, n)) - - def test_hfftn_ihfftn(self): - # follows example in torch.fft.hfftn docs - x = ht.random.randn(10, 6, 6, dtype=ht.float64) - inv_fft = ht.fft.ifftn(x) - reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) - self.assertTrue(ht.allclose(reconstructed_x, x)) + # def test_hfft_ihfft(self): + # # follows example in torch.fft.hfft docs + # x = ht.zeros((3, 5), split=0, dtype=ht.float64) + # edges = [1, 3, 7] + # for i, n in enumerate(edges): + # x[i] = ht.linspace(0, n, 5) + + # inv_fft = ht.fft.ifft(x) + # # inv_fft is hermitian symmetric along the rows + # # we can reconstruct the original signal by transforming the first half of the rows only + # reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) + # self.assertTrue(ht.allclose(reconstructed_x, x)) + # n = 2 * (x.shape[-1] - 1) + # reconstructed_x = ht.fft.hfft(inv_fft[:3]) + # self.assertEqual(reconstructed_x.shape, (3, n)) + + # def test_hfftn_ihfftn(self): + # # follows example in torch.fft.hfftn docs + # x = ht.random.randn(10, 6, 6, dtype=ht.float64) + # inv_fft = ht.fft.ifftn(x) + # reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) + # self.assertTrue(ht.allclose(reconstructed_x, x)) def test_rfft_irfft(self): # n-D distributed From 5b7d134268a918339cf0815a8ad202138e858dc1 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:24:35 +0200 Subject: [PATCH 33/50] debugging --- heat/fft/tests/test_fft.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 2d9eff7982..6e05fe719f 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -149,21 +149,23 @@ def test_fftn_ifftn(self): with self.assertRaises(ValueError): ht.fft.fftn(x, s=(10, 10, 10, 10)) - # def test_hfft_ihfft(self): - # # follows example in torch.fft.hfft docs - # x = ht.zeros((3, 5), split=0, dtype=ht.float64) - # edges = [1, 3, 7] - # for i, n in enumerate(edges): - # x[i] = ht.linspace(0, n, 5) - - # inv_fft = ht.fft.ifft(x) - # # inv_fft is hermitian symmetric along the rows - # # we can reconstruct the original signal by transforming the first half of the rows only - # reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) - # self.assertTrue(ht.allclose(reconstructed_x, x)) - # n = 2 * (x.shape[-1] - 1) - # reconstructed_x = ht.fft.hfft(inv_fft[:3]) - # self.assertEqual(reconstructed_x.shape, (3, n)) + def test_hfft_ihfft(self): + # follows example in torch.fft.hfft docs + x = ht.zeros((3, 5), split=0, dtype=ht.float64) + edges = [1, 3, 7] + for i, n in enumerate(edges): + x[i] = ht.linspace(0, n, 5) + + print(f"DEBUGGING: ON RANK {x.comm.rank}, lshape is {x.lshape}, gshape is {x.gshape}") + inv_fft = ht.fft.ifft(x) + + # inv_fft is hermitian symmetric along the rows + # we can reconstruct the original signal by transforming the first half of the rows only + reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) + self.assertTrue(ht.allclose(reconstructed_x, x)) + n = 2 * (x.shape[-1] - 1) + reconstructed_x = ht.fft.hfft(inv_fft[:3]) + self.assertEqual(reconstructed_x.shape, (3, n)) # def test_hfftn_ihfftn(self): # # follows example in torch.fft.hfftn docs From 4005afa2fbfdebd99b90a9204e3832ea48734255 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:59:01 +0200 Subject: [PATCH 34/50] debugging --- heat/fft/fft.py | 60 +++++++++++++++++++++++++++++++++----- heat/fft/tests/test_fft.py | 4 +-- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 0da8eb77e6..1f4bdb7762 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -59,12 +59,38 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: n = kwargs.get("n", None) norm = kwargs.get("norm", None) - fft_along_split = original_split == axis output_shape = list(x.shape) + # TODO: define Nyquist freq for real input transforms + # if fft operation requires real input, switch to generic operation: + real_to_generic_fft_ops = { + torch.fft.rfft: torch.fft.fft, + torch.fft.ihfft: torch.fft.ifft, + } + real_op = fft_op in real_to_generic_fft_ops + if real_op: + fft_op = real_to_generic_fft_ops[fft_op] + if n is not None: + nyquist_freq = n // 2 + 1 + else: + nyquist_freq = x.shape[axis] // 2 + 1 + output_shape[axis] = nyquist_freq + else: + if n is not None: + output_shape[axis] = n + + fft_along_split = original_split == axis # FFT along non-split axis if not x.is_distributed() or not fft_along_split: - torch_result = fft_op(local_x, n=n, dim=axis, norm=norm) - output_shape[axis] = torch_result.shape[axis] + if local_x.numel() == 0: + # empty tensor, return empty tensor with consistent shape + local_shape = output_shape.copy() + local_shape[original_split] = 0 + print("DEBUGGING: LOCAL SHAPE IS ", local_shape) + torch_result = torch.empty( + tuple(local_shape), dtype=local_x.dtype, device=local_x.device + ) + else: + torch_result = fft_op(local_x, n=n, dim=axis, norm=norm) # return DNDarray( # torch_result, # gshape=tuple(output_shape), @@ -74,7 +100,8 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: # comm=x.comm, # balanced=x.balanced, # ) - return array(torch_result, is_split=original_split, comm=x.comm) + print("DEBUGGING: TORCH RESULT IS ", torch_result.shape) + return array(torch_result, is_split=original_split, device=x.device, comm=x.comm) # FFT along split axis if original_split != 0: @@ -86,11 +113,14 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: ) x = x.transpose(transpose_axes) + # transform decomposition: split axis first, then the rest + # redistribute x if x.ndim > 1: _ = x.resplit(axis=1) else: _ = x.resplit(axis=None) + # FFT along axis 0 (now non-split) ht_result = __fft_op(_, fft_op, n=n, axis=0, norm=norm) del _ @@ -101,6 +131,12 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: x = x.transpose(transpose_axes) ht_result = ht_result.transpose(transpose_axes) + if real_op: + # discard elements beyond Nyquist frequency on last transformed axis + nyquist_slice = [slice(None)] * ht_result.ndim + nyquist_slice[axis] = slice(0, nyquist_freq) + ht_result = ht_result[(nyquist_slice)].balance_() + return ht_result @@ -141,9 +177,17 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: output_shape = list(x.shape) # FFT along non-split axes only if not x.is_distributed() or not fft_along_split: - torch_result = fftn_op(local_x, s=s, dim=axes, norm=norm) - for axis in axes: - output_shape[axis] = torch_result.shape[axis] + if local_x.numel() == 0: + # empty tensor, return empty tensor with consistent shape + local_shape = output_shape.copy() + local_shape[original_split] = 0 + torch_result = torch.empty( + tuple(local_shape), dtype=local_x.dtype, device=local_x.device + ) + else: + torch_result = fftn_op(local_x, s=s, dim=axes, norm=norm) + # for axis in axes: + # output_shape[axis] = torch_result.shape[axis] # return DNDarray( # torch_result, # gshape=tuple(output_shape), @@ -153,7 +197,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # comm=x.comm, # balanced=x.balanced, # ) - return array(torch_result, is_split=original_split, comm=x.comm) + return array(torch_result, is_split=original_split, device=x.device, comm=x.comm) # FFT along split axis if original_split != 0: diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 6e05fe719f..61b9dc2580 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -164,8 +164,8 @@ def test_hfft_ihfft(self): reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) self.assertTrue(ht.allclose(reconstructed_x, x)) n = 2 * (x.shape[-1] - 1) - reconstructed_x = ht.fft.hfft(inv_fft[:3]) - self.assertEqual(reconstructed_x.shape, (3, n)) + # reconstructed_x = ht.fft.hfft(inv_fft[:3]) + # self.assertEqual(reconstructed_x.shape, (3, n)) # def test_hfftn_ihfftn(self): # # follows example in torch.fft.hfftn docs From 8bea8c8f50af40548ab9f63827d967bb58e6bff3 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 05:03:57 +0200 Subject: [PATCH 35/50] specify default even size of last fft dim for inverse real ops --- heat/fft/fft.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 1f4bdb7762..1fa8e9e751 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -398,6 +398,8 @@ def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr ----- This function requires MPI communication if the input array is transformed along the distribution axis. """ + if n is None: + n = 2 * (x.shape[axis] - 1) return __fft_op(x, torch.fft.hfft, n=n, axis=axis, norm=norm) @@ -427,6 +429,8 @@ def hfft2( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ + if s is None: + s = (x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1)) return __fftn_op(x, torch.fft.hfft2, s=s, axes=axes, norm=norm) @@ -455,6 +459,13 @@ def hfftn( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ + if s is None: + if axes is None: + s = list(x.shape[i] for i in range(x.ndim)) + else: + s = list(x.shape[i] for i in axes) + s[-1] = 2 * (s[-1] - 1) + s = tuple(s) return __fftn_op(x, torch.fft.hfftn, s=s, axes=axes, norm=norm) @@ -581,6 +592,8 @@ def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDar If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`irfft2`). """ + if n is None: + n = 2 * (x.shape[axis] - 1) return __fft_op(x, torch.fft.irfft, n=n, axis=axis, norm=norm) @@ -607,6 +620,8 @@ def irfft2( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ + if s is None: + s = (x.shape[axes[0]], 2 * (x.shape[axes[1]] - 1)) return __fftn_op(x, torch.fft.irfft2, s=s, axes=axes, norm=norm) @@ -633,6 +648,13 @@ def irfftn( ----- This function requires MPI communication if the input array is distributed and the split axis is transformed. """ + if s is None: + if axes is None: + s = list(x.shape[i] for i in range(x.ndim)) + else: + s = list(x.shape[i] for i in axes) + s[-1] = 2 * (s[-1] - 1) + s = tuple(s) return __fftn_op(x, torch.fft.irfftn, s=s, axes=axes, norm=norm) From 50b8e416bdbef4b2b191c76ca6139aaab0852455 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 05:04:18 +0200 Subject: [PATCH 36/50] add tests --- heat/fft/tests/test_fft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 61b9dc2580..6e05fe719f 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -164,8 +164,8 @@ def test_hfft_ihfft(self): reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) self.assertTrue(ht.allclose(reconstructed_x, x)) n = 2 * (x.shape[-1] - 1) - # reconstructed_x = ht.fft.hfft(inv_fft[:3]) - # self.assertEqual(reconstructed_x.shape, (3, n)) + reconstructed_x = ht.fft.hfft(inv_fft[:3]) + self.assertEqual(reconstructed_x.shape, (3, n)) # def test_hfftn_ihfftn(self): # # follows example in torch.fft.hfftn docs From 3e6715bd3b2f00635c3a1e771cc897a7853f9611 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 05:59:31 +0200 Subject: [PATCH 37/50] fix output shape calc when input is real --- heat/fft/fft.py | 80 ++++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 1fa8e9e751..f8afce3635 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -38,7 +38,7 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: """ - Helper function for fft + Helper function for 1-dimensional FFT. """ try: local_x = x.larray @@ -47,36 +47,32 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: original_split = x.split # sanitize kwargs - axis = kwargs.get("axis", None) + axis = kwargs.get("axis") try: axis = sanitize_axis(x.gshape, axis) except ValueError as e: raise IndexError(e) - if axis is None: - axis = x.ndim - 1 - elif isinstance(axis, tuple) and len(axis) > 1: + if isinstance(axis, tuple) and len(axis) > 1: raise TypeError(f"axis must be an integer, got {axis}") n = kwargs.get("n", None) + if n is None: + n = x.shape[axis] norm = kwargs.get("norm", None) + # calculate output shape: + # if operation requires real input, output size of last transformed dimension is the Nyquist frequency output_shape = list(x.shape) - # TODO: define Nyquist freq for real input transforms - # if fft operation requires real input, switch to generic operation: real_to_generic_fft_ops = { torch.fft.rfft: torch.fft.fft, torch.fft.ihfft: torch.fft.ifft, } real_op = fft_op in real_to_generic_fft_ops if real_op: - fft_op = real_to_generic_fft_ops[fft_op] - if n is not None: - nyquist_freq = n // 2 + 1 - else: - nyquist_freq = x.shape[axis] // 2 + 1 + nyquist_freq = n // 2 + 1 output_shape[axis] = nyquist_freq else: - if n is not None: - output_shape[axis] = n + output_shape[axis] = n + print("DEBUGGING: n, output_shape = ", n, output_shape) fft_along_split = original_split == axis # FFT along non-split axis @@ -121,6 +117,9 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: else: _ = x.resplit(axis=None) + # if operation requires real input, switch to generic transform + if real_op: + fft_op = real_to_generic_fft_ops[fft_op] # FFT along axis 0 (now non-split) ht_result = __fft_op(_, fft_op, n=n, axis=0, norm=norm) del _ @@ -152,29 +151,46 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: original_split = x.split # sanitize kwargs - axes = kwargs.get("axes", None) - try: - axes = sanitize_axis(x.gshape, axes) - except ValueError as e: - raise IndexError(e) - repeated_axes = axes is not None and len(axes) != len(set(axes)) - if repeated_axes: - raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") s = kwargs.get("s", None) if s is not None and len(s) > x.ndim: raise ValueError( f"Input is {x.ndim}-dimensional, so s can be at most {x.ndim} elements long. Got {len(s)} elements instead." ) - norm = kwargs.get("norm", None) - + axes = kwargs.get("axes", None) if axes is None: if s is not None: axes = tuple(range(x.ndim)[-len(s) :]) else: axes = tuple(range(x.ndim)) + else: + try: + axes = sanitize_axis(x.gshape, axes) + except ValueError as e: + raise IndexError(e) + if s is None: + s = tuple(x.shape[axis] for axis in axes) + repeated_axes = axes is not None and len(axes) != len(set(axes)) + if repeated_axes: + raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") + norm = kwargs.get("norm", None) - fft_along_split = original_split in axes + # calculate output shape: + # if operation requires real input, output size of last transformed dimension is the Nyquist frequency output_shape = list(x.shape) + for i, axis in enumerate(axes): + output_shape[axis] = s[i] + real_to_generic_fftn_ops = { + torch.fft.rfftn: torch.fft.fftn, + torch.fft.rfft2: torch.fft.fft2, + torch.fft.ihfftn: torch.fft.ifftn, + torch.fft.ihfft2: torch.fft.ifft2, + } + real_op = fftn_op in real_to_generic_fftn_ops + if real_op: + nyquist_freq = s[-1] // 2 + 1 + output_shape[axes[-1]] = nyquist_freq + + fft_along_split = original_split in axes # FFT along non-split axes only if not x.is_distributed() or not fft_along_split: if local_x.numel() == 0: @@ -218,20 +234,8 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # transform decomposition: split axis first, then the rest # if fft operation requires real input, switch to generic operation: - real_to_generic_fftn_ops = { - torch.fft.rfftn: torch.fft.fftn, - torch.fft.rfft2: torch.fft.fft2, - torch.fft.ihfftn: torch.fft.ifftn, - torch.fft.ihfft2: torch.fft.ifft2, - } - real_op = fftn_op in real_to_generic_fftn_ops if real_op: fftn_op = real_to_generic_fftn_ops[fftn_op] - nyquist_axis = axes[-1] - if s is not None: - nyquist_freq = s[-1] // 2 + 1 - else: - nyquist_freq = x.shape[-1] // 2 + 1 # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) @@ -264,7 +268,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: if real_op: # discard elements beyond Nyquist frequency on last transformed axis nyquist_slice = [slice(None)] * ht_result.ndim - nyquist_slice[nyquist_axis] = slice(0, nyquist_freq) + nyquist_slice[axes[-1]] = slice(0, nyquist_freq) ht_result = ht_result[(nyquist_slice)].balance_() return ht_result From 05d405de363a2ff7a55088cf84dd8883cee09272 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 05:59:53 +0200 Subject: [PATCH 38/50] expand tests --- heat/fft/tests/test_fft.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 6e05fe719f..9d0d308094 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -167,19 +167,19 @@ def test_hfft_ihfft(self): reconstructed_x = ht.fft.hfft(inv_fft[:3]) self.assertEqual(reconstructed_x.shape, (3, n)) - # def test_hfftn_ihfftn(self): - # # follows example in torch.fft.hfftn docs - # x = ht.random.randn(10, 6, 6, dtype=ht.float64) - # inv_fft = ht.fft.ifftn(x) - # reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) - # self.assertTrue(ht.allclose(reconstructed_x, x)) + def test_hfftn_ihfftn(self): + # follows example in torch.fft.hfftn docs + x = ht.random.randn(10, 6, 6, dtype=ht.float64) + inv_fft = ht.fft.ifftn(x) + reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) + self.assertTrue(ht.allclose(reconstructed_x, x)) def test_rfft_irfft(self): # n-D distributed x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) # FFT along last axis - y = ht.fft.fft(x) - np_y = np.fft.fft(x.numpy()) + y = ht.fft.rfft(x) + np_y = np.fft.rfft(x.numpy()) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) backwards = ht.fft.irfft(y, n=x.shape[-1]) From 718edc0da95991a63fa4dfed7f76719f0407e706 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 07:36:22 +0200 Subject: [PATCH 39/50] expand tests --- heat/fft/fft.py | 5 +---- heat/fft/tests/test_fft.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index f8afce3635..81fafc424e 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -241,10 +241,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) split_index = axes.index(original_split) - if s is not None: - partial_s = (s[split_index],) - else: - partial_s = None + partial_s = (s[split_index],) partial_ht_result = __fftn_op(_, fftn_op, s=partial_s, axes=(0,), norm=norm) output_shape[original_split] = partial_ht_result.shape[0] del _ diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 9d0d308094..8d6226bb59 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -75,7 +75,7 @@ def test_fft_ifft(self): def test_fft2_ifft2(self): # 2D FFT along non-split axes - x = ht.random.randn(10, 6, 6, split=0, dtype=ht.float64) + x = ht.random.randn(3, 6, 6, split=0, dtype=ht.float64) y = ht.fft.fft2(x) np_y = np.fft.fft2(x.numpy()) self.assertTrue(y.split == 0) @@ -167,12 +167,20 @@ def test_hfft_ihfft(self): reconstructed_x = ht.fft.hfft(inv_fft[:3]) self.assertEqual(reconstructed_x.shape, (3, n)) + def test_hfft2_ihfft2(self): + x = ht.random.randn(10, 6, 6, dtype=ht.float64) + inv_fft = ht.fft.ifft2(x) + reconstructed_x = ht.fft.hfft2(inv_fft, s=x.shape[-2:]) + self.assertTrue(ht.allclose(reconstructed_x, x)) + def test_hfftn_ihfftn(self): # follows example in torch.fft.hfftn docs x = ht.random.randn(10, 6, 6, dtype=ht.float64) inv_fft = ht.fft.ifftn(x) reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) self.assertTrue(ht.allclose(reconstructed_x, x)) + reconstructed_x_no_s = ht.fft.hfftn(inv_fft) + self.assertEqual(reconstructed_x_no_s.shape[-1], 2 * (inv_fft.shape[-1] - 1)) def test_rfft_irfft(self): # n-D distributed @@ -184,6 +192,8 @@ def test_rfft_irfft(self): self.assert_array_equal(y, np_y) backwards = ht.fft.irfft(y, n=x.shape[-1]) self.assertTrue(ht.allclose(backwards, x)) + backwards_no_n = ht.fft.irfft(y) + self.assertEqual(backwards_no_n.shape[-1], 2 * (y.shape[-1] - 1)) # exceptions # complex input From ae5576a10dadbb17497306c5889f7a0e218e0adc Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 08:49:35 +0200 Subject: [PATCH 40/50] expand tests --- heat/fft/fft.py | 2 +- heat/fft/tests/test_fft.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 81fafc424e..75589b0293 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -22,7 +22,7 @@ "ifftn", # "ihfft", # "ihfft2", - # "ihfftn", + "ihfftn", "irfft", "irfft2", "irfftn", diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 8d6226bb59..8f7271aa11 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -211,7 +211,8 @@ def test_rfftn_irfftn(self): self.assertEqual(y.shape, np_y.shape) self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - + backwards = ht.fft.irfftn(y, s=x.shape[-2:]) + self.assertTrue(ht.allclose(backwards, x)) # FFT along all axes # TODO: comment this out after merging indexing PR # y = ht.fft.rfftn(x) @@ -223,3 +224,17 @@ def test_rfftn_irfftn(self): x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) with self.assertRaises(TypeError): ht.fft.rfftn(x) + + def test_rfft2_irfft2(self): + # n-D distributed + x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last 2 axes + y = ht.fft.rfft2(x, axes=(1, 2)) + np_y = np.fft.rfft2(x.numpy(), axes=(1, 2)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) + + backwards = ht.fft.irfft2(y, s=x.shape[-2:]) + self.assertTrue(ht.allclose(backwards, x)) From 2d83337e291f1f9cb0bd20812a7c6a18cea83b7b Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 10:00:44 +0200 Subject: [PATCH 41/50] add ihfft2, ihfft --- heat/fft/fft.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 75589b0293..f644c6c47c 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -20,8 +20,8 @@ "ifft", "ifft2", "ifftn", - # "ihfft", - # "ihfft2", + "ihfft", + "ihfft2", "ihfftn", "irfft", "irfft2", @@ -548,6 +548,56 @@ def ifftn( return __fftn_op(x, torch.fft.ifftn, s=s, axes=axes, norm=norm) +def ihfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: + """ + Compute the one-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. + + Parameters + ---------- + x : DNDarray + Input array, must be real + n : int, optional + Length of the transformed axis of the output. If not given, the length is taken to be the length of the input + along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is cropped. If `n` is + larger, the input is padded with zeros. Default: None. + axis : int, optional + Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is transformed along the distribution axis. + """ + return __real_fft_op(x, torch.fft.ihfft, n=n, axis=axis, norm=norm) + + +def ihfft2( + x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None +) -> DNDarray: + """ + Compute the 2-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. + + Parameters + ---------- + x : DNDarray + Input array, must be real + s : Tuple[int, int], optional + Shape of the output along the transformed axes. (default is x.shape) + axes : Tuple[int, int], optional + Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is + also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple + times. (default is (-2, -1)) + norm : str, optional + Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. + """ + return __real_fftn_op(x, torch.fft.ihfft2, s=s, axes=axes, norm=norm) + + def ihfftn( x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: From 010b757f507e0af820fe411c3d8bb6d9a05c31ea Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 25 Oct 2023 10:01:14 +0200 Subject: [PATCH 42/50] expand tests --- heat/fft/tests/test_fft.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 8f7271aa11..465ff6ec1c 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -150,31 +150,26 @@ def test_fftn_ifftn(self): ht.fft.fftn(x, s=(10, 10, 10, 10)) def test_hfft_ihfft(self): - # follows example in torch.fft.hfft docs x = ht.zeros((3, 5), split=0, dtype=ht.float64) edges = [1, 3, 7] for i, n in enumerate(edges): x[i] = ht.linspace(0, n, 5) + inv_fft = ht.fft.ihfft(x) - print(f"DEBUGGING: ON RANK {x.comm.rank}, lshape is {x.lshape}, gshape is {x.gshape}") - inv_fft = ht.fft.ifft(x) - - # inv_fft is hermitian symmetric along the rows - # we can reconstruct the original signal by transforming the first half of the rows only - reconstructed_x = ht.fft.hfft(inv_fft[:3], n=5) + # inv_fft is hermitian-symmetric along the rows + reconstructed_x = ht.fft.hfft(inv_fft, n=5) self.assertTrue(ht.allclose(reconstructed_x, x)) n = 2 * (x.shape[-1] - 1) - reconstructed_x = ht.fft.hfft(inv_fft[:3]) - self.assertEqual(reconstructed_x.shape, (3, n)) + reconstructed_x = ht.fft.hfft(inv_fft, n=n) + self.assertEqual(reconstructed_x.shape[-1], n) def test_hfft2_ihfft2(self): x = ht.random.randn(10, 6, 6, dtype=ht.float64) - inv_fft = ht.fft.ifft2(x) + inv_fft = ht.fft.ihfft2(x) reconstructed_x = ht.fft.hfft2(inv_fft, s=x.shape[-2:]) self.assertTrue(ht.allclose(reconstructed_x, x)) def test_hfftn_ihfftn(self): - # follows example in torch.fft.hfftn docs x = ht.random.randn(10, 6, 6, dtype=ht.float64) inv_fft = ht.fft.ifftn(x) reconstructed_x = ht.fft.hfftn(inv_fft, s=x.shape) From d5ec286a68da2ebc76d52881e8ff8d854e7bffbe Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Wed, 1 Nov 2023 08:32:39 +0100 Subject: [PATCH 43/50] implement fftfreq, fftshift operations and tests --- heat/fft/fft.py | 280 +++++++++++++++++++++++++++++++------ heat/fft/tests/test_fft.py | 39 ++++++ 2 files changed, 280 insertions(+), 39 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index f644c6c47c..4d936e7297 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -5,21 +5,25 @@ from ..core.communication import MPI from ..core.dndarray import DNDarray from ..core.stride_tricks import sanitize_axis -from ..core.types import promote_types, heat_type_of -from ..core.factories import array, zeros +from ..core.types import heat_type_is_exact, promote_types, heat_type_of, canonical_heat_type +from ..core.factories import array, arange +from ..core.devices import Device from typing import Type, Union, Tuple, Any, Iterable, Optional __all__ = [ "fft", "fft2", + "fftfreq", "fftn", + "fftshift", "hfft", "hfft2", "hfftn", "ifft", "ifft2", "ifftn", + "ifftshift", "ihfft", "ihfft2", "ihfftn", @@ -28,11 +32,8 @@ "irfftn", "rfft", "rfft2", + "rfftfreq", "rfftn", - # "fftfreq", - # "rfftfreq", - # "fftshift", - # "ifftshift", ] @@ -149,43 +150,58 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: raise TypeError("x must be a DNDarray, is {}".format(type(x))) original_split = x.split + output_shape = list(x.shape) + shift_op = fftn_op in [torch.fft.fftshift, torch.fft.ifftshift] + real_to_generic_fftn_ops = { + torch.fft.rfftn: torch.fft.fftn, + torch.fft.rfft2: torch.fft.fft2, + torch.fft.ihfftn: torch.fft.ifftn, + torch.fft.ihfft2: torch.fft.ifft2, + } + real_op = fftn_op in real_to_generic_fftn_ops # sanitize kwargs - s = kwargs.get("s", None) - if s is not None and len(s) > x.ndim: - raise ValueError( - f"Input is {x.ndim}-dimensional, so s can be at most {x.ndim} elements long. Got {len(s)} elements instead." - ) - axes = kwargs.get("axes", None) - if axes is None: - if s is not None: - axes = tuple(range(x.ndim)[-len(s) :]) - else: + if shift_op: + # only keyword argument `axes` is supported + axes = kwargs.get("axes", None) + if axes is None: axes = tuple(range(x.ndim)) + else: + try: + axes = sanitize_axis(x.gshape, axes) + except ValueError as e: + raise IndexError(e) + torch_kwargs = {"dim": axes} else: - try: - axes = sanitize_axis(x.gshape, axes) - except ValueError as e: - raise IndexError(e) - if s is None: - s = tuple(x.shape[axis] for axis in axes) + s = kwargs.get("s", None) + if s is not None and len(s) > x.ndim: + raise ValueError( + f"Input is {x.ndim}-dimensional, so s can be at most {x.ndim} elements long. Got {len(s)} elements instead." + ) + axes = kwargs.get("axes", None) + if axes is None: + if s is not None: + axes = tuple(range(x.ndim)[-len(s) :]) + else: + axes = tuple(range(x.ndim)) + else: + try: + axes = sanitize_axis(x.gshape, axes) + except ValueError as e: + raise IndexError(e) + if s is None: + s = tuple(x.shape[axis] for axis in axes) + norm = kwargs.get("norm", None) + for i, axis in enumerate(axes): + output_shape[axis] = s[i] + torch_kwargs = {"s": s, "dim": axes, "norm": norm} + repeated_axes = axes is not None and len(axes) != len(set(axes)) if repeated_axes: raise NotImplementedError("Multiple transforms over the same axis not implemented yet.") - norm = kwargs.get("norm", None) # calculate output shape: # if operation requires real input, output size of last transformed dimension is the Nyquist frequency - output_shape = list(x.shape) - for i, axis in enumerate(axes): - output_shape[axis] = s[i] - real_to_generic_fftn_ops = { - torch.fft.rfftn: torch.fft.fftn, - torch.fft.rfft2: torch.fft.fft2, - torch.fft.ihfftn: torch.fft.ifftn, - torch.fft.ihfft2: torch.fft.ifft2, - } - real_op = fftn_op in real_to_generic_fftn_ops if real_op: nyquist_freq = s[-1] // 2 + 1 output_shape[axes[-1]] = nyquist_freq @@ -201,7 +217,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: tuple(local_shape), dtype=local_x.dtype, device=local_x.device ) else: - torch_result = fftn_op(local_x, s=s, dim=axes, norm=norm) + torch_result = fftn_op(local_x, **torch_kwargs) # for axis in axes: # output_shape[axis] = torch_result.shape[axis] # return DNDarray( @@ -256,11 +272,14 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: axes = list(axes) axes.remove(original_split) axes = tuple(axes) - if s is not None: - s = list(s) - s = s[:split_index] + s[split_index + 1 :] - s = tuple(s) - ht_result = __fftn_op(partial_ht_result, fftn_op, s=s, axes=axes, norm=norm) + if shift_op: + ht_result = __fftn_op(partial_ht_result, fftn_op, axes=axes) + else: + if s is not None: + s = list(s) + s = s[:split_index] + s[split_index + 1 :] + s = tuple(s) + ht_result = __fftn_op(partial_ht_result, fftn_op, s=s, axes=axes, norm=norm) del partial_ht_result if real_op: # discard elements beyond Nyquist frequency on last transformed axis @@ -270,6 +289,61 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: return ht_result +def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: + """ + Helper function for fftfreq + """ + n = kwargs.get("n", None) + d = kwargs.get("d", None) + dtype = kwargs.get("dtype", None) + split = kwargs.get("split", None) + device = kwargs.get("device", None) + + if not isinstance(n, int): + raise TypeError(f"n must be an integer, is {type(n)}") + if not isinstance(d, (int, float)): + raise TypeError(f"d must be a number, is {type(d)}") + if dtype is not None: + if heat_type_is_exact(dtype): + raise TypeError(f"dtype must be a float or complex type, is {dtype}") + # extract torch dtype from heat dtype + try: + torch_dtype = dtype.torch_type() + except AttributeError: + raise TypeError(f"dtype must be a heat dtype, is {type(dtype)}") + else: + torch_dtype = None + + # early out for non-distributed fftfreq + if split is None: + return array(fftfreq_op(n, d=d, dtype=torch_dtype), device=device, split=None) + + # distributed fftfreq + if split != 0: + raise IndexError(f"`fftfreq` returns a 1-D array, `split` must be 0 or None, is {split}") + + # calculate parameters of the global frequency spectrum + channel_width = array(1.0 / (n * d), dtype=dtype, device=device, split=None) + n_is_even = n % 2 == 0 + if n_is_even: + middle_channel = n // 2 + else: + middle_channel = n // 2 + 1 + + # allocate global fftfreq array + # if real operation, return only positive frequencies + if fftfreq_op == torch.fft.rfftfreq: + freqs = arange(middle_channel, dtype=dtype, device=device, split=split) + else: + freqs = arange(n, dtype=dtype, device=device, split=split) + # second half of fftfreq returns negative frequencies in inverse order + freqs[middle_channel:] -= n + + # calculate global frequencies + freqs *= channel_width + return freqs + + def __real_fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: if x.larray.is_complex(): raise TypeError(f"Input array must be real, is {x.dtype}.") @@ -343,6 +417,50 @@ def fft2( return __fftn_op(x, torch.fft.fft2, s=s, axes=axes, norm=norm) +def fftfreq( + n: int, + d: Union[int, float] = 1.0, + dtype: Optional[Type] = None, + split: Optional[int] = None, + device: Optional[Union[str, Device]] = None, +) -> DNDarray: + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned float tensor contains the frequency bin centers in cycles per unit of the sample spacing (with zero + at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second. + + Parameters + ---------- + n : int + Window length. + d : Union[int, float], optional + Sample spacing (inverse of the sampling rate). Defaults to 1. + dtype : Type, optional + The desired data type of the output. Defaults to `float32`. + split : int, optional + The axis along which to split the result. If not given, the result is not split. + device : str or Device, optional + The device on which to place the output. If not given, the output is placed on the current device. + + Returns + ------- + out : DNDarray + Array of length `n` containing the sample frequencies. + + Examples + -------- + >>> import heat as ht + >>> ht.fft.fftfreq(5, 0.1) + DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float32, device=cpu:0, split=None) + >>> ht.fft.fftfreq(5, 0.1, dtype=ht.float64) + DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float64, device=cpu:0, split=None) + >>> ht.fft.fftfreq(5, 0.1, split=0) + DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float32, device=cpu:0, split=0) + """ + return __fftfreq_op(torch.fft.fftfreq, n=n, d=d, dtype=dtype, split=split, device=device) + + def fftn( x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: @@ -373,6 +491,27 @@ def fftn( return __fftn_op(x, torch.fft.fftn, s=s, axes=axes, norm=norm) +def fftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> DNDarray: + """ + Shift the zero-frequency component to the center of the spectrum. + + This function swaps half-spaces for all axes listed (defaults to all). Note that ``y[0]`` is the Nyquist component + only if ``len(x)`` is even. + + Parameters + ---------- + x : DNDarray + Input array + axes : int or Iterable[int], optional + Axes over which to shift. Default is None, which shifts all axes. + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is shifted. + """ + return __fftn_op(x, torch.fft.fftshift, axes=axes) + + def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ Compute the one-dimensional discrete Fourier Transform of a Hermitian symmetric signal. @@ -548,6 +687,25 @@ def ifftn( return __fftn_op(x, torch.fft.ifftn, s=s, axes=axes, norm=norm) +def ifftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> DNDarray: + """ + TODO: reelaborate docs + The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x. + + Parameters + ---------- + x : DNDarray + Input array + axes : int or Iterable[int], optional + Axes over which to shift. Default is None, which shifts all axes. + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is shifted. + """ + return __fftn_op(x, torch.fft.ifftshift, axes=axes) + + def ihfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ Compute the one-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. @@ -761,6 +919,50 @@ def rfft2( return __real_fftn_op(x, torch.fft.rfft2, s=s, axes=axes, norm=norm) +def rfftfreq( + n: int, + d: Union[int, float] = 1.0, + dtype: Optional[Type] = None, + split: Optional[int] = None, + device: Optional[Union[str, Device]] = None, +) -> DNDarray: + """ + Return the Discrete Fourier Transform sample frequencies. + + The returned float tensor contains the frequency bin centers in cycles per unit of the sample spacing (with zero + at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second. + + Parameters + ---------- + n : int + Window length. + d : Union[int, float], optional + Sample spacing (inverse of the sampling rate). Defaults to 1. + dtype : Type, optional + The desired data type of the output. Defaults to `float32`. + split : int, optional + The axis along which to split the result. If not given, the result is not split. + device : str or Device, optional + The device on which to place the output. If not given, the output is placed on the current device. + + Returns + ------- + out : DNDarray + Array of length `n` containing the sample frequencies. + + Examples + -------- + >>> import heat as ht + >>> ht.fft.rfftfreq(5, 0.1) + DNDarray([0., 1., 2.], dtype=ht.float32, device=cpu:0, split=None) + >>> ht.fft.rfftfreq(5, 0.1, dtype=ht.float64) + DNDarray([0., 1., 2.], dtype=ht.float64, device=cpu:0, split=None) + >>> ht.fft.rfftfreq(5, 0.1, split=0) + DNDarray([0., 1., 2.], dtype=ht.float32, device=cpu:0, split=0) + """ + return __fftfreq_op(torch.fft.rfftfreq, n=n, d=d, dtype=dtype, split=split, device=device) + + def rfftn( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 465ff6ec1c..87e4140c20 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -149,6 +149,45 @@ def test_fftn_ifftn(self): with self.assertRaises(ValueError): ht.fft.fftn(x, s=(10, 10, 10, 10)) + def test_fftfreq_rfftfreq(self): + # non-distributed + n = 10 + d = 0.1 + y = ht.fft.fftfreq(n, d=d) + np_y = np.fft.fftfreq(n, d=d) + self.assertEqual(y.shape, np_y.shape) + self.assert_array_equal(y, np_y) + + # distributed + y = ht.fft.fftfreq(n, d=d, split=0) + self.assertEqual(y.shape, np_y.shape) + self.assert_array_equal(y, np_y) + + # real + n = 107 + d = 0.22365 + y = ht.fft.rfftfreq(n, d=d) + np_y = np.fft.rfftfreq(n, d=d) + self.assertEqual(y.shape, np_y.shape) + self.assert_array_equal(y, np_y) + + # exceptions + # wrong input type + n = 10 + d = 0.1 + with self.assertRaises(TypeError): + ht.fft.fftfreq(n, d=d, dtype=ht.int32) + + def test_fftshift_ifftshift(self): + # non-distributed + x = ht.fft.fftfreq(10) + y = ht.fft.fftshift(x) + np_y = np.fft.fftshift(x.numpy()) + self.assertEqual(y.shape, np_y.shape) + self.assert_array_equal(y, np_y) + backwards = ht.fft.ifftshift(y) + self.assertTrue(ht.allclose(backwards, x)) + def test_hfft_ihfft(self): x = ht.zeros((3, 5), split=0, dtype=ht.float64) edges = [1, 3, 7] From 0993f99347d9d665f046bd69fdbf16fc941a7ba7 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 13 Nov 2023 06:14:12 +0100 Subject: [PATCH 44/50] fix local output dtype mismatch when local input tensor is empty --- heat/fft/fft.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 4d936e7297..989764b9cd 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -73,7 +73,6 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: output_shape[axis] = nyquist_freq else: output_shape[axis] = n - print("DEBUGGING: n, output_shape = ", n, output_shape) fft_along_split = original_split == axis # FFT along non-split axis @@ -82,7 +81,6 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: # empty tensor, return empty tensor with consistent shape local_shape = output_shape.copy() local_shape[original_split] = 0 - print("DEBUGGING: LOCAL SHAPE IS ", local_shape) torch_result = torch.empty( tuple(local_shape), dtype=local_x.dtype, device=local_x.device ) @@ -97,7 +95,6 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: # comm=x.comm, # balanced=x.balanced, # ) - print("DEBUGGING: TORCH RESULT IS ", torch_result.shape) return array(torch_result, is_split=original_split, device=x.device, comm=x.comm) # FFT along split axis @@ -152,6 +149,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: original_split = x.split output_shape = list(x.shape) shift_op = fftn_op in [torch.fft.fftshift, torch.fft.ifftshift] + inverse_real_op = fftn_op in [torch.fft.irfftn, torch.fft.irfft2] real_to_generic_fftn_ops = { torch.fft.rfftn: torch.fft.fftn, torch.fft.rfft2: torch.fft.fft2, @@ -210,11 +208,17 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: # FFT along non-split axes only if not x.is_distributed() or not fft_along_split: if local_x.numel() == 0: - # empty tensor, return empty tensor with consistent shape + # empty tensor, return empty tensor with consistent shape and dtype local_shape = output_shape.copy() local_shape[original_split] = 0 + if inverse_real_op: + output_dtype = local_x.real.dtype + else: + # local_x is empty, memory footprint not an issue + _ = local_x * 1j + output_dtype = _.dtype torch_result = torch.empty( - tuple(local_shape), dtype=local_x.dtype, device=local_x.device + tuple(local_shape), dtype=output_dtype, device=local_x.device ) else: torch_result = fftn_op(local_x, **torch_kwargs) From 5f58dee63ee2a568d2eeec3b15b88d49d2d06a19 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 13 Nov 2023 06:14:36 +0100 Subject: [PATCH 45/50] remove print statements --- heat/fft/tests/test_fft.py | 56 ++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index 87e4140c20..d8db02c494 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -38,27 +38,25 @@ def test_fft_ifft(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - # on GPU, test only on less than 4 processes - if x.device == ht.cpu or x.comm.size < 4: - # FFT along distributed axis, n not None - n = 8 - y = ht.fft.fft(x, axis=0, n=n) - np_y = np.fft.fft(x.numpy(), axis=0, n=n) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 0) - self.assert_array_equal(y, np_y) + # FFT along distributed axis, n not None + n = 8 + y = ht.fft.fft(x, axis=0, n=n) + np_y = np.fft.fft(x.numpy(), axis=0, n=n) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 0) + self.assert_array_equal(y, np_y) - # complex input - x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) - # FFT along last axis (distributed) - x.resplit_(axis=2) - y = ht.fft.fft(x, n=n) - np_y = np.fft.fft(x.numpy(), n=n) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 2) - self.assert_array_equal(y, np_y) + # complex input + x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + # FFT along last axis (distributed) + x.resplit_(axis=2) + y = ht.fft.fft(x, n=n) + np_y = np.fft.fft(x.numpy(), n=n) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 2) + self.assert_array_equal(y, np_y) # exceptions # wrong input type @@ -128,16 +126,14 @@ def test_fftn_ifftn(self): self.assertTrue(y.split == 0) self.assert_array_equal(y, np_y) - # on GPU, test only on less than 4 processes - if x.device == ht.cpu or x.comm.size < 4: - # FFT along distributed axis - x.resplit_(axis=1) - y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) - np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) - self.assertIsInstance(y, ht.DNDarray) - self.assertEqual(y.shape, np_y.shape) - self.assertTrue(y.split == 1) - self.assert_array_equal(y, np_y) + # FFT along distributed axis + x.resplit_(axis=1) + y = ht.fft.fftn(x, axes=(0, 1), s=(10, 8)) + np_y = np.fft.fftn(x.numpy(), axes=(0, 1), s=(10, 8)) + self.assertIsInstance(y, ht.DNDarray) + self.assertEqual(y.shape, np_y.shape) + self.assertTrue(y.split == 1) + self.assert_array_equal(y, np_y) # exceptions # wrong input type From 371abe33f050c4ad7a97c56b6c99f7ec8aea696d Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Tue, 14 Nov 2023 06:22:35 +0100 Subject: [PATCH 46/50] expand tests --- heat/fft/fft.py | 14 ++++++++++---- heat/fft/tests/test_fft.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 989764b9cd..9d364fd910 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -261,8 +261,11 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) split_index = axes.index(original_split) - partial_s = (s[split_index],) - partial_ht_result = __fftn_op(_, fftn_op, s=partial_s, axes=(0,), norm=norm) + if shift_op: + partial_ht_result = __fftn_op(_, fftn_op, axes=(0,)) + else: + partial_s = (s[split_index],) + partial_ht_result = __fftn_op(_, fftn_op, s=partial_s, axes=(0,), norm=norm) output_shape[original_split] = partial_ht_result.shape[0] del _ # redistribute partial result from axis 1 to 0 @@ -304,9 +307,12 @@ def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: device = kwargs.get("device", None) if not isinstance(n, int): - raise TypeError(f"n must be an integer, is {type(n)}") + raise ValueError(f"n must be an integer, is {type(n)}") if not isinstance(d, (int, float)): - raise TypeError(f"d must be a number, is {type(d)}") + if isinstance(d, complex): + # numpy supports complex d, torch doesn't + raise NotImplementedError("Support for complex d not implemented yet.") + raise TypeError(f"d must be a scalar, is {type(d)}") if dtype is not None: if heat_type_is_exact(dtype): raise TypeError(f"dtype must be a float or complex type, is {dtype}") diff --git a/heat/fft/tests/test_fft.py b/heat/fft/tests/test_fft.py index d8db02c494..afc9cb1f09 100644 --- a/heat/fft/tests/test_fft.py +++ b/heat/fft/tests/test_fft.py @@ -173,6 +173,19 @@ def test_fftfreq_rfftfreq(self): d = 0.1 with self.assertRaises(TypeError): ht.fft.fftfreq(n, d=d, dtype=ht.int32) + # unsupported n + n = 10.7 + with self.assertRaises(ValueError): + ht.fft.fftfreq(n, d=d) + # unsupported d + # torch does not support complex d + n = 10 + d = 0.1 + 1j + with self.assertRaises(NotImplementedError): + ht.fft.fftfreq(n, d=d) + d = ht.array(0.1) + with self.assertRaises(TypeError): + ht.fft.fftfreq(n, d=d) def test_fftshift_ifftshift(self): # non-distributed @@ -184,6 +197,20 @@ def test_fftshift_ifftshift(self): backwards = ht.fft.ifftshift(y) self.assertTrue(ht.allclose(backwards, x)) + # distributed + # (following fftshift example from torch.fft) + x = ht.fft.fftfreq(5, d=1 / 5, split=0) + 0.1 * ht.fft.fftfreq(5, d=1 / 5, split=0).reshape( + 5, 1 + ) + y = ht.fft.fftshift(x, axes=(0, 1)) + np_y = np.fft.fftshift(x.numpy(), axes=(0, 1)) + self.assert_array_equal(y, np_y) + + # exceptions + # wrong axis + with self.assertRaises(IndexError): + ht.fft.fftshift(x, axes=(0, 2)) + def test_hfft_ihfft(self): x = ht.zeros((3, 5), split=0, dtype=ht.float64) edges = [1, 3, 7] @@ -214,7 +241,7 @@ def test_hfftn_ihfftn(self): def test_rfft_irfft(self): # n-D distributed - x = ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + x = ht.random.randn(10, 8, 3, dtype=ht.float64, split=0) # FFT along last axis y = ht.fft.rfft(x) np_y = np.fft.rfft(x.numpy()) @@ -227,7 +254,7 @@ def test_rfft_irfft(self): # exceptions # complex input - x = x + 1j * ht.random.randn(10, 8, 6, dtype=ht.float64, split=0) + x = x + 1j * ht.random.randn(10, 8, 3, dtype=ht.float64, split=0) with self.assertRaises(TypeError): ht.fft.rfft(x) From c46e5d8a866788978505a0297bf71bd95207d8d3 Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 16 Nov 2023 14:53:13 +0100 Subject: [PATCH 47/50] simplify dealing with multi-axis real FFT --- heat/fft/fft.py | 70 +++++++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 9d364fd910..46582303dd 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -86,16 +86,15 @@ def __fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: ) else: torch_result = fft_op(local_x, n=n, dim=axis, norm=norm) - # return DNDarray( - # torch_result, - # gshape=tuple(output_shape), - # dtype=heat_type_of(torch_result), - # split=original_split, - # device=x.device, - # comm=x.comm, - # balanced=x.balanced, - # ) - return array(torch_result, is_split=original_split, device=x.device, comm=x.comm) + return DNDarray( + torch_result, + gshape=tuple(output_shape), + dtype=heat_type_of(torch_result), + split=original_split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) # FFT along split axis if original_split != 0: @@ -222,18 +221,15 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: ) else: torch_result = fftn_op(local_x, **torch_kwargs) - # for axis in axes: - # output_shape[axis] = torch_result.shape[axis] - # return DNDarray( - # torch_result, - # gshape=tuple(output_shape), - # dtype=heat_type_of(torch_result), - # split=original_split, - # device=x.device, - # comm=x.comm, - # balanced=x.balanced, - # ) - return array(torch_result, is_split=original_split, device=x.device, comm=x.comm) + return DNDarray( + torch_result, + gshape=tuple(output_shape), + dtype=heat_type_of(torch_result), + split=original_split, + device=x.device, + comm=x.comm, + balanced=x.balanced, + ) # FFT along split axis if original_split != 0: @@ -253,10 +249,6 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: return ht_result # transform decomposition: split axis first, then the rest - # if fft operation requires real input, switch to generic operation: - if real_op: - fftn_op = real_to_generic_fftn_ops[fftn_op] - # redistribute x from axis 0 to 1 _ = x.resplit(axis=1) # FFT along axis 0 (now non-split) @@ -286,13 +278,11 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: s = list(s) s = s[:split_index] + s[split_index + 1 :] s = tuple(s) + # if fft operation requires real input, switch to generic operation for the second pass + if real_op: + fftn_op = real_to_generic_fftn_ops[fftn_op] ht_result = __fftn_op(partial_ht_result, fftn_op, s=s, axes=axes, norm=norm) del partial_ht_result - if real_op: - # discard elements beyond Nyquist frequency on last transformed axis - nyquist_slice = [slice(None)] * ht_result.ndim - nyquist_slice[axes[-1]] = slice(0, nyquist_freq) - ht_result = ht_result[(nyquist_slice)].balance_() return ht_result @@ -355,12 +345,18 @@ def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: def __real_fft_op(x: DNDarray, fft_op: callable, **kwargs) -> DNDarray: + """ + Helper function for real 1-D FFTs. + """ if x.larray.is_complex(): raise TypeError(f"Input array must be real, is {x.dtype}.") return __fft_op(x, fft_op, **kwargs) def __real_fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: + """ + Helper function for real N-D FFTs. + """ if x.larray.is_complex(): raise TypeError(f"Input array must be real, is {x.dtype}.") return __fftn_op(x, fftn_op, **kwargs) @@ -379,11 +375,11 @@ def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarra x : DNDarray Input array, can be complex. WARNING: If x is 1-D and distributed, the entire array is copied on each MPI process. n : int, optional - Length of the transformed axis of the output. If not given, the length is taken to be the length of the input - along the axis specified by axis. If `n` is smaller than the length of the input, the input is cropped. If `n` is + Length of the transformed axis of the output. If not given, the length is assumed to be the length of the input + along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is truncated. If `n` is larger, the input is padded with zeros. Default: None. axis : int, optional - Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one + Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if `x` has only one dimension. Default: -1. norm : str, optional Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". @@ -478,8 +474,7 @@ def fftn( Compute the N-dimensional discrete Fourier Transform. This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional - array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed, with the real transform - performed over the last axis, while the remaining transforms are complex. + array by means of the Fast Fourier Transform (FFT). Parameters ---------- @@ -977,7 +972,8 @@ def rfftn( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: """ - Compute the N-dimensional discrete Fourier Transform for real input. + Compute the N-dimensional discrete Fourier Transform for real input. By default, all axes are transformed, with the real transform + performed over the last axis, while the remaining transforms are complex. Parameters ---------- From 6d0b13289d5161daff3a4e45a7d95cd4010a347f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 16 Nov 2023 14:54:13 +0100 Subject: [PATCH 48/50] cannot be a list --- heat/core/stride_tricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/stride_tricks.py b/heat/core/stride_tricks.py index 10010d6f15..266a901044 100644 --- a/heat/core/stride_tricks.py +++ b/heat/core/stride_tricks.py @@ -151,7 +151,7 @@ def sanitize_axis( """ # scalars are handled like unsplit matrices - original_axis = axis.copy() if isinstance(axis, list) else axis + original_axis = axis ndim = len(shape) if ndim == 0: From 62f8c4cf82fbfd8a1fd2bdf149fc06d078a60f2f Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Thu, 16 Nov 2023 16:04:15 +0100 Subject: [PATCH 49/50] update documentation --- heat/fft/fft.py | 77 ++++++++++++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 46582303dd..55ec9093ae 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -288,13 +288,14 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: """ - Helper function for fftfreq + Helper function for ``fftfreq`` and ``rfftfreq`` operations. """ n = kwargs.get("n", None) d = kwargs.get("d", None) dtype = kwargs.get("dtype", None) split = kwargs.get("split", None) device = kwargs.get("device", None) + comm = kwargs.get("comm", None) if not isinstance(n, int): raise ValueError(f"n must be an integer, is {type(n)}") @@ -316,7 +317,7 @@ def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: # early out for non-distributed fftfreq if split is None: - return array(fftfreq_op(n, d=d, dtype=torch_dtype), device=device, split=None) + return array(fftfreq_op(n, d=d, dtype=torch_dtype), device=device, split=None, comm=comm) # distributed fftfreq if split != 0: @@ -333,9 +334,9 @@ def __fftfreq_op(fftfreq_op: callable, **kwargs) -> DNDarray: # allocate global fftfreq array # if real operation, return only positive frequencies if fftfreq_op == torch.fft.rfftfreq: - freqs = arange(middle_channel, dtype=dtype, device=device, split=split) + freqs = arange(middle_channel, dtype=dtype, device=device, split=split, comm=comm) else: - freqs = arange(n, dtype=dtype, device=device, split=split) + freqs = arange(n, dtype=dtype, device=device, split=split, comm=comm) # second half of fftfreq returns negative frequencies in inverse order freqs[middle_channel:] -= n @@ -364,16 +365,14 @@ def __real_fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray: def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ - Compute the one-dimensional discrete Fourier Transform. - - This function computes the one-dimensional discrete Fourier Transform over the specified axis in an M-dimensional + Compute the one-dimensional discrete Fourier Transform over the specified axis in an M-dimensional array by means of the Fast Fourier Transform (FFT). By default, the last axis is transformed, while the remaining axes are left unchanged. Parameters ---------- x : DNDarray - Input array, can be complex. WARNING: If x is 1-D and distributed, the entire array is copied on each MPI process. + Input array, can be complex. WARNING: If x is 1-D and distributed, the entire array is copied on each MPI process. See Notes. n : int, optional Length of the transformed axis of the output. If not given, the length is assumed to be the length of the input along the axis specified by `axis`. If `n` is smaller than the length of the input, the input is truncated. If `n` is @@ -382,13 +381,23 @@ def fft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarra Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if `x` has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ifft` : inverse 1-dimensional FFT + :func:`fft2` : 2-dimensional FFT + :func:`fftn` : N-dimensional FFT + :func:`rfft` : 1-dimensional FFT of a real signal + :func:`hfft` : 1-dimensional FFT of a Hermitian symmetric sequence + :func:`fftfreq` : frequency bins for given FFT parameters + :func:`rfftfreq` : frequency bins for real FFT Notes ----- This function requires MPI communication if the input array is transformed along the distribution axis. If the input array is 1-D and distributed, this function copies the entire array on each MPI process! i.e. if the array is very large, you might run out of memory. - Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them all at once (see :func:`fft2`). + Hint: if you are looping through a batch of 1-D arrays to transform them, consider stacking them into a 2-D DNDarray and transforming them in one go (see :func:`fft2`). """ return __fft_op(x, torch.fft.fft, n=n, axis=axis, norm=norm) @@ -397,9 +406,7 @@ def fft2( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None ) -> DNDarray: """ - Compute the 2-dimensional discrete Fourier Transform. - - This function computes the 2-dimensional discrete Fourier Transform over the specified axes in an M-dimensional + Compute the 2-dimensional discrete Fourier Transform over the specified axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). By default, the last two axes are transformed, while the remaining axes are left unchanged. @@ -414,7 +421,15 @@ def fft2( not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. (default is (-2, -1)) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ifft2` : inverse 2-dimensional FFT + :func:`fft` : 1-dimensional FFT + :func:`fftn` : N-dimensional FFT + :func:`rfft2` : 2-dimensional FFT of a real signal + :func:`hfft2` : 2-dimensional FFT of a Hermitian symmetric sequence Notes ----- @@ -429,11 +444,12 @@ def fftfreq( dtype: Optional[Type] = None, split: Optional[int] = None, device: Optional[Union[str, Device]] = None, + comm: Optional[MPI.Comm] = None, ) -> DNDarray: """ - Return the Discrete Fourier Transform sample frequencies. + Return the Discrete Fourier Transform sample frequencies for a signal of size ``n``. - The returned float tensor contains the frequency bin centers in cycles per unit of the sample spacing (with zero + The returned ``DNDarray`` contains the frequency bin centers in cycles per unit of the sample spacing (with zero at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second. Parameters @@ -443,28 +459,26 @@ def fftfreq( d : Union[int, float], optional Sample spacing (inverse of the sampling rate). Defaults to 1. dtype : Type, optional - The desired data type of the output. Defaults to `float32`. + The desired data type of the output. Defaults to `ht.float32`. split : int, optional - The axis along which to split the result. If not given, the result is not split. + The axis along which to split the result. Can be None or 0, as the output is 1-dimensional. Defaults to None, i.e. non-distributed output. device : str or Device, optional The device on which to place the output. If not given, the output is placed on the current device. + comm : MPI.Comm, optional + The MPI communicator to use for distributing the output. If not given, the default communicator is used. Returns ------- out : DNDarray - Array of length `n` containing the sample frequencies. + Array of length ``n`` containing the sample frequencies. If ``split`` is 0, the array is evenly distributed among the available MPI processes. - Examples + See Also -------- - >>> import heat as ht - >>> ht.fft.fftfreq(5, 0.1) - DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float32, device=cpu:0, split=None) - >>> ht.fft.fftfreq(5, 0.1, dtype=ht.float64) - DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float64, device=cpu:0, split=None) - >>> ht.fft.fftfreq(5, 0.1, split=0) - DNDarray([-2., -1., 0., 1., 2.], dtype=ht.float32, device=cpu:0, split=0) + :func:`rfftfreq` : frequency bins for :func:`rfft` """ - return __fftfreq_op(torch.fft.fftfreq, n=n, d=d, dtype=dtype, split=split, device=device) + return __fftfreq_op( + torch.fft.fftfreq, n=n, d=d, dtype=dtype, split=split, device=device, comm=comm + ) def fftn( @@ -930,6 +944,7 @@ def rfftfreq( dtype: Optional[Type] = None, split: Optional[int] = None, device: Optional[Union[str, Device]] = None, + comm: Optional[MPI.Comm] = None, ) -> DNDarray: """ Return the Discrete Fourier Transform sample frequencies. @@ -949,6 +964,8 @@ def rfftfreq( The axis along which to split the result. If not given, the result is not split. device : str or Device, optional The device on which to place the output. If not given, the output is placed on the current device. + comm : MPI.Comm, optional + The MPI communicator to use for distributing the output. If not given, the default communicator is used. Returns ------- @@ -965,7 +982,9 @@ def rfftfreq( >>> ht.fft.rfftfreq(5, 0.1, split=0) DNDarray([0., 1., 2.], dtype=ht.float32, device=cpu:0, split=0) """ - return __fftfreq_op(torch.fft.rfftfreq, n=n, d=d, dtype=dtype, split=split, device=device) + return __fftfreq_op( + torch.fft.rfftfreq, n=n, d=d, dtype=dtype, split=split, device=device, comm=comm + ) def rfftn( From 0f349ba595c8840902b12827f1fe25522ea96fec Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Fri, 17 Nov 2023 15:28:09 +0100 Subject: [PATCH 50/50] edit docs --- heat/fft/fft.py | 229 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 155 insertions(+), 74 deletions(-) diff --git a/heat/fft/fft.py b/heat/fft/fft.py index 55ec9093ae..d1b7d6e8eb 100644 --- a/heat/fft/fft.py +++ b/heat/fft/fft.py @@ -5,11 +5,11 @@ from ..core.communication import MPI from ..core.dndarray import DNDarray from ..core.stride_tricks import sanitize_axis -from ..core.types import heat_type_is_exact, promote_types, heat_type_of, canonical_heat_type +from ..core.types import heat_type_is_exact, heat_type_of from ..core.factories import array, arange from ..core.devices import Device -from typing import Type, Union, Tuple, Any, Iterable, Optional +from typing import Type, Union, Tuple, Iterable, Optional __all__ = [ "fft", @@ -418,7 +418,7 @@ def fft2( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, int], optional Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also - not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. (default is (-2, -1)) norm : str, optional Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". @@ -467,11 +467,6 @@ def fftfreq( comm : MPI.Comm, optional The MPI communicator to use for distributing the output. If not given, the default communicator is used. - Returns - ------- - out : DNDarray - Array of length ``n`` containing the sample frequencies. If ``split`` is 0, the array is evenly distributed among the available MPI processes. - See Also -------- :func:`rfftfreq` : frequency bins for :func:`rfft` @@ -498,10 +493,18 @@ def fftn( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, ...], optional Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is also - not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. (default is None) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ifftn` : inverse N-dimensional FFT + :func:`fft` : 1-dimensional FFT + :func:`fft2` : 2-dimensional FFT + :func:`rfftn` : N-dimensional FFT of a real signal + :func:`hfftn` : N-dimensional FFT of a Hermitian symmetric sequence Notes ----- @@ -524,6 +527,10 @@ def fftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> D axes : int or Iterable[int], optional Axes over which to shift. Default is None, which shifts all axes. + See Also + -------- + :func:`ifftshift` : The inverse of `fftshift`. + Notes ----- This function requires MPI communication if the input array is distributed and the split axis is shifted. @@ -551,7 +558,15 @@ def hfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence + :func:`fft` : 1-dimensional FFT + :func:`rfft` : 1-dimensional FFT of a real signal Notes ----- @@ -580,9 +595,17 @@ def hfft2( Shape of the signal along the transformed axes. If `s` is specified, the input array is either zero-padded or trimmed to length `s` before the transform. If `s` is not given, the last dimension defaults to even output: `s[-1] = 2 * (x.shape[-1] - 1)`. axes : Tuple[int, int], optional - Axes over which to compute the FFT. If not given, the last two dimensions are transformed. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + Axes over which to compute the FFT. If not given, the last two dimensions are transformed. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default: (-2, -1). norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence + :func:`fft2` : 2-dimensional FFT + :func:`rfft2` : 2-dimensional FFT of a real signal Notes ----- @@ -610,9 +633,17 @@ def hfftn( Shape of the signal along the transformed axes. If `s` is specified, the input array is either zero-padded or trimmed to length `s` before the transform. If `s` is not given, the last dimension defaults to even output: `s[-1] = 2 * (x.shape[-1] - 1)`. axes : Tuple[int, ...], optional - Axes over which to compute the FFT. If not given, all dimensions are transformed. Repeated indices in `axes` means that the transform over that axis is performed multiple times. + Axes over which to compute the FFT. If not given, all dimensions are transformed. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`fftn` : N-dimensional FFT + :func:`rfftn` : N-dimensional FFT of a real signal Notes ----- @@ -643,7 +674,15 @@ def ifft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr axis : int, optional Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`fft` : forward 1-dimensional FFT + :func:`ifft2` : inverse 2-dimensional FFT + :func:`ifftn` : inverse N-dimensional FFT + :func:`irfft` : inverse 1-dimensional FFT of a real sequence + :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian symmetric sequence Notes ----- @@ -668,10 +707,17 @@ def ifft2( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, int], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is (-2, -1)) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default: (-2, -1). norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`fft2` : forward 2-dimensional FFT + :func:`ifft` : inverse 1-dimensional FFT + :func:`ifftn` : inverse N-dimensional FFT + :func:`irfft2` : inverse 2-dimensional FFT of a real sequence + :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian symmetric sequence Notes ----- @@ -694,10 +740,17 @@ def ifftn( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, ...], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is None) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`fftn` : forward N-dimensional FFT + :func:`ifft` : inverse 1-dimensional FFT + :func:`ifft2` : inverse 2-dimensional FFT + :func:`irfftn` : inverse N-dimensional FFT of a real sequence + :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian symmetric sequence Notes ----- @@ -708,8 +761,7 @@ def ifftn( def ifftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> DNDarray: """ - TODO: reelaborate docs - The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x. + The inverse of fftshift. Parameters ---------- @@ -718,6 +770,10 @@ def ifftshift(x: DNDarray, axes: Optional[Union[int, Iterable[int]]] = None) -> axes : int or Iterable[int], optional Axes over which to shift. Default is None, which shifts all axes. + See Also + -------- + :func:`fftshift` : Shift the zero-frequency component to the center of the spectrum. + Notes ----- This function requires MPI communication if the input array is distributed and the split axis is shifted. @@ -740,7 +796,15 @@ def ihfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDar axis : int, optional Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`hfft` : 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence + :func:`rfft` : 1-dimensional FFT of a real signal + :func:`irfft` : inverse 1-dimensional FFT of a real sequence Notes ----- @@ -753,7 +817,7 @@ def ihfft2( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None ) -> DNDarray: """ - Compute the 2-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. + Compute the inverse of a 2-dimensional discrete Fourier Transform of a Hermitian-symmetric signal. The output is Hermitian-symmetric. Parameters ---------- @@ -763,10 +827,17 @@ def ihfft2( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, int], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is (-2, -1)) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default is (-2, -1). norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`hfft2` : 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfftn` : inverse N-dimensional FFT of a Hermitian-symmetric sequence + :func:`rfft2` : 2-dimensional FFT of a real signal + :func:`irfft2` : inverse 2-dimensional FFT of a real sequence Notes ----- @@ -779,7 +850,7 @@ def ihfftn( x: DNDarray, s: Tuple[int, ...] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: """ - Compute the N-dimensional inverse discrete Fourier Transform of a real signal. The output is Hermitian-symmetric. + Compute the inverse of a N-dimensional discrete Fourier Transform of Hermitian-symmetric signal. The output is Hermitian-symmetric. Parameters ---------- @@ -789,17 +860,28 @@ def ihfftn( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, ...], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is None) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default: None. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`hfftn` : N-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfft` : inverse 1-dimensional FFT of a Hermitian-symmetric sequence + :func:`ihfft2` : inverse 2-dimensional FFT of a Hermitian-symmetric sequence + :func:`rfftn` : N-dimensional FFT of a real signal + :func:`irfftn` : inverse N-dimensional FFT of a real sequence + + Notes + ----- + This function requires MPI communication if the input array is distributed and the split axis is transformed. """ return __real_fftn_op(x, torch.fft.ihfftn, s=s, axes=axes, norm=norm) def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ - Compute the one-dimensional inverse discrete Fourier Transform for real input. + Compute the inverse of a one-dimensional discrete Fourier Transform of real signal. The output is real. Parameters ---------- @@ -812,7 +894,15 @@ def irfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDar axis : int, optional Axis over which to compute the inverse FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`irfft2` : inverse 2-dimensional FFT + :func:`irfftn` : inverse N-dimensional FFT + :func:`rfft` : 1-dimensional FFT of a real signal + :func:`hfft` : 1-dimensional FFT of a Hermitian symmetric sequence + :func:`fft` : 1-dimensional FFT Notes ----- @@ -829,20 +919,27 @@ def irfft2( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None ) -> DNDarray: """ - Compute the 2-dimensional inverse discrete Fourier Transform for real input. + Compute the inverse of a 2-dimensional discrete real Fourier Transform. The output is real. Parameters ---------- x : DNDarray Input array, can be complex s : Tuple[int, int], optional - Shape of the output along the transformed axes. (default is x.shape) + Shape of the output along the transformed axes. axes : Tuple[int, int], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is (-2, -1)) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. Default is (-2, -1)) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". + + See Also + -------- + :func:`irfft` : inverse 1-dimensional FFT + :func:`irfftn` : inverse N-dimensional FFT + :func:`rfft2` : 2-dimensional FFT of a real signal + :func:`hfft2` : 2-dimensional FFT of a Hermitian symmetric sequence + :func:`fft2` : 2-dimensional FFT Notes ----- @@ -857,20 +954,21 @@ def irfftn( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: """ - Compute the N-dimensional inverse discrete Fourier Transform for real input. + Compute the inverse of an N-dimensional discrete Fourier Transform of real signal. + The output is real. Parameters ---------- x : DNDarray - Input array, can be complex + Input array, assumed to be Hermitian-symmetric along the transformed axes, with the last transformed axis only containing the positive half of the frequencies. s : Tuple[int, ...], optional - Shape of the output along the transformed axes. (default is x.shape) + Shape of the output along the transformed axes. If ``s`` is not specified, the last transposed axis is reconstructued in full, i.e. `s[-1] = 2 * (x.shape[axes[-1]] - 1)`. axes : Tuple[int, ...], optional Axes over which to compute the inverse FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is None) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. + (default is None) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". Notes ----- @@ -888,7 +986,7 @@ def irfftn( def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarray: """ - Compute the one-dimensional discrete Fourier Transform for real input. + Compute the one-dimensional discrete Fourier Transform of real input. The output is Hermitian-symmetric. Parameters ---------- @@ -901,7 +999,7 @@ def rfft(x: DNDarray, n: int = None, axis: int = -1, norm: str = None) -> DNDarr axis : int, optional Axis over which to compute the FFT. If not given, the last axis is used, or the only axis if x has only one dimension. Default: -1. norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". Notes ----- @@ -916,7 +1014,7 @@ def rfft2( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, int] = (-2, -1), norm: str = None ) -> DNDarray: """ - Compute the 2-dimensional discrete Fourier Transform for real input. + Compute the 2-dimensional discrete Fourier Transform of real input. The output is Hermitian-symmetric. Parameters ---------- @@ -926,10 +1024,9 @@ def rfft2( Shape of the output along the transformed axes. (default is x.shape) axes : Tuple[int, int], optional Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is (-2, -1)) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. (default is (-2, -1)) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". Notes ----- @@ -949,7 +1046,7 @@ def rfftfreq( """ Return the Discrete Fourier Transform sample frequencies. - The returned float tensor contains the frequency bin centers in cycles per unit of the sample spacing (with zero + The returned float DNDarray contains the frequency bin centers in cycles per unit of the sample spacing (with zero at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second. Parameters @@ -966,21 +1063,6 @@ def rfftfreq( The device on which to place the output. If not given, the output is placed on the current device. comm : MPI.Comm, optional The MPI communicator to use for distributing the output. If not given, the default communicator is used. - - Returns - ------- - out : DNDarray - Array of length `n` containing the sample frequencies. - - Examples - -------- - >>> import heat as ht - >>> ht.fft.rfftfreq(5, 0.1) - DNDarray([0., 1., 2.], dtype=ht.float32, device=cpu:0, split=None) - >>> ht.fft.rfftfreq(5, 0.1, dtype=ht.float64) - DNDarray([0., 1., 2.], dtype=ht.float64, device=cpu:0, split=None) - >>> ht.fft.rfftfreq(5, 0.1, split=0) - DNDarray([0., 1., 2.], dtype=ht.float32, device=cpu:0, split=0) """ return __fftfreq_op( torch.fft.rfftfreq, n=n, d=d, dtype=dtype, split=split, device=device, comm=comm @@ -991,21 +1073,20 @@ def rfftn( x: DNDarray, s: Tuple[int, int] = None, axes: Tuple[int, ...] = None, norm: str = None ) -> DNDarray: """ - Compute the N-dimensional discrete Fourier Transform for real input. By default, all axes are transformed, with the real transform - performed over the last axis, while the remaining transforms are complex. + Compute the N-dimensional discrete Fourier Transform of real input. By default, all axes are transformed, with the real transform + performed over the last axis, while the remaining transforms are complex. The output is Hermitian-symmetric, with the last transformed axis having length `s[-1] // 2 + 1` (the positive part of the spectrum). Parameters ---------- x : DNDarray Input array, must be real. s : Tuple[int, ...], optional - Shape of the output along the transformed axes. (default is x.shape) + Shape of the output along the transformed axes. axes : Tuple[int, ...], optional Axes over which to compute the FFT. If not given, the last `len(s)` axes are used, or all axes if `s` is - also not specified. Repeated indices in `axes` means that the transform over that axis is performed multiple - times. (default is None) + also not specified. Repeated transforms over an axis, i.e. repeated indices in ``axes``, are not supported yet. (default is None) norm : str, optional - Normalization mode: 'forward', 'backward', or 'ortho' (see `numpy.fft` for details). Default is "backward". + Normalization mode: 'forward', 'backward', or 'ortho'. Indicates in what direction the forward/backward pair of transforms is normalized. Default is "backward". Notes -----