Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed Fast Fourier Transforms #1218

Merged
merged 65 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
3bc7037
implement fftn, first draft
ClaudiaComito Sep 13, 2023
5190e00
implement general , add
ClaudiaComito Sep 14, 2023
e0e48de
split fft_op and fftn_op, implement inverse and real fft
ClaudiaComito Sep 14, 2023
9a1cb98
add TODO hermitian fft
ClaudiaComito Sep 14, 2023
05af244
add fft tests first draft
ClaudiaComito Sep 14, 2023
2ce5a09
implement tests first draft
ClaudiaComito Sep 18, 2023
a3cdcc3
update fft/__init__.py
ClaudiaComito Sep 18, 2023
ef98d05
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Sep 18, 2023
8afe87a
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
mrfh92 Oct 5, 2023
2e6c411
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
mrfh92 Oct 5, 2023
e427baa
Merge branch 'features/1097-Provide_Fast_Fourier_Transform_FFT' of gi…
ClaudiaComito Oct 11, 2023
4c5d057
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 11, 2023
e5e015c
expand tests fft
ClaudiaComito Oct 11, 2023
c76d61c
expand tests ffftn
ClaudiaComito Oct 11, 2023
2be8a24
expand tests and fix errors
ClaudiaComito Oct 13, 2023
f8dd67e
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 13, 2023
74ae80e
add Hermitian FFTs
ClaudiaComito Oct 13, 2023
35b0d43
heat/fft/tests/test_fft.py
ClaudiaComito Oct 13, 2023
4e7ae87
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 16, 2023
f578522
raise IndexError, not ValueError, when axes don't match dimensions
ClaudiaComito Oct 16, 2023
23fea9e
expand tests
ClaudiaComito Oct 16, 2023
75e289f
edit error message for better understanding
ClaudiaComito Oct 16, 2023
02b0a29
replace == with allclose for 2D FFTs
ClaudiaComito Oct 18, 2023
c464f92
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 18, 2023
47a2fe8
fix error
ClaudiaComito Oct 18, 2023
e2c10dd
remove redundant communication
ClaudiaComito Oct 19, 2023
3f0eac6
remove redundant tests
ClaudiaComito Oct 19, 2023
585bac3
fix bug in axes handling
ClaudiaComito Oct 19, 2023
432d6ff
test hermitian FFT
ClaudiaComito Oct 19, 2023
51e7b33
cast numpy fft2 to complex64
ClaudiaComito Oct 19, 2023
f241144
expand tests
ClaudiaComito Oct 19, 2023
eee0001
edit error messages
ClaudiaComito Oct 19, 2023
acfbea3
remove unnecessary axis check
ClaudiaComito Oct 19, 2023
6d6b0fd
test inverse ffts as well
ClaudiaComito Oct 19, 2023
1cd62f7
skip comm-intensive tests on gpu
ClaudiaComito Oct 19, 2023
ba73d62
introduce helper functions for real fft operations
ClaudiaComito Oct 19, 2023
d526700
fix output shape wrt Nyquist frequency
ClaudiaComito Oct 24, 2023
e27243e
double precision tests
ClaudiaComito Oct 24, 2023
60eb72e
debugging: introduce synchronization
ClaudiaComito Oct 24, 2023
dedc752
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 24, 2023
c56169a
debugging
ClaudiaComito Oct 24, 2023
5b7d134
debugging
ClaudiaComito Oct 24, 2023
4005afa
debugging
ClaudiaComito Oct 24, 2023
8bea8c8
specify default even size of last fft dim for inverse real ops
ClaudiaComito Oct 25, 2023
50b8e41
add tests
ClaudiaComito Oct 25, 2023
3e6715b
fix output shape calc when input is real
ClaudiaComito Oct 25, 2023
05d405d
expand tests
ClaudiaComito Oct 25, 2023
718edc0
expand tests
ClaudiaComito Oct 25, 2023
ae5576a
expand tests
ClaudiaComito Oct 25, 2023
2d83337
add ihfft2, ihfft
ClaudiaComito Oct 25, 2023
010b757
expand tests
ClaudiaComito Oct 25, 2023
0fb9d2f
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Oct 30, 2023
d5ec286
implement fftfreq, fftshift operations and tests
ClaudiaComito Nov 1, 2023
96c1462
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Nov 7, 2023
0993f99
fix local output dtype mismatch when local input tensor is empty
ClaudiaComito Nov 13, 2023
5f58dee
remove print statements
ClaudiaComito Nov 13, 2023
371abe3
expand tests
ClaudiaComito Nov 14, 2023
c46e5d8
simplify dealing with multi-axis real FFT
ClaudiaComito Nov 16, 2023
6d0b132
cannot be a list
ClaudiaComito Nov 16, 2023
62f8c4c
update documentation
ClaudiaComito Nov 16, 2023
0f349ba
edit docs
ClaudiaComito Nov 17, 2023
23116c7
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Nov 21, 2023
52aefdc
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Nov 22, 2023
83c7b1a
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Nov 22, 2023
e4409cb
Merge branch 'main' into features/1097-Provide_Fast_Fourier_Transform…
ClaudiaComito Nov 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions heat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import core
from . import classification
from . import cluster
from . import fft
from . import graph
from . import naive_bayes
from . import nn
Expand Down
11 changes: 8 additions & 3 deletions heat/core/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def sanitize_axis(

"""
# scalars are handled like unsplit matrices
if len(shape) == 0:
original_axis = axis
ndim = len(shape)

if ndim == 0:
axis = None

if axis is not None and not isinstance(axis, int) and not isinstance(axis, tuple):
Expand All @@ -160,7 +163,9 @@ def sanitize_axis(
axis = tuple(dim + len(shape) if dim < 0 else dim for dim in axis)
for dim in axis:
if dim < 0 or dim >= len(shape):
raise ValueError(f"axis {axis} is out of bounds for shape {shape}")
raise ValueError(
f"axis {original_axis} is out of bounds for {ndim}-dimensional array"
)
return axis

if axis is None or 0 <= axis < len(shape):
Expand All @@ -169,7 +174,7 @@ def sanitize_axis(
axis += len(shape)

if axis < 0 or axis >= len(shape):
raise ValueError(f"axis {axis} is out of bounds for shape {shape}")
raise ValueError(f"axis {original_axis} is out of bounds for {ndim}-dimensional array")

return axis

Expand Down
8 changes: 4 additions & 4 deletions heat/core/tests/test_suites/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def assert_array_equal(self, heat_array, expected_array):
f"Local shapes do not match. Got {heat_array.lshape} expected {expected_array[slices].shape}",
)
# compare local tensors to corresponding slice of expected_array
is_allclose = np.allclose(heat_array.larray.cpu(), expected_array[slices])
ht_is_allclose = ht.array(
[is_allclose], dtype=ht.bool, is_split=0, device=heat_array.device
is_allclose = torch.tensor(
np.allclose(heat_array.larray.cpu(), expected_array[slices]), dtype=torch.int32
)
self.assertTrue(ht.all(ht_is_allclose))
heat_array.comm.Allreduce(MPI.IN_PLACE, is_allclose, MPI.SUM)
self.assertTrue(is_allclose == heat_array.comm.size)

def assert_func_equal(
self,
Expand Down
5 changes: 5 additions & 0 deletions heat/fft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
import the fft functions into the fft namespace
"""

from .fft import *
Loading
Loading