Question concerning FFT operation. #4
Great question! There may be two interpretations of your question: how we
make it causal, and how we compute a correct FFT. I'll try to answer both :)
For causal language modeling, we want a causal convolution, so we pad the
input and kernel with zeros (this is what the fft_size = seqlen * 2 is
doing). Then the line at the end that truncates the output to seqlen size
([..., :seqlen]) cuts off the extra to get a causal convolution.
You may instead be asking about the fftshift ordering, in which case torch.fft
does the correct shifting in both directions. You can check that the output
of a torch.fft call returns a correctly ordered answer!
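The padding-then-truncation described above can be sketched as follows (a minimal illustration added for clarity, not code from the repository; variable names are mine):

```python
import torch

# Zero-pad both signals to 2 * seqlen, multiply in frequency space,
# then keep only the first seqlen outputs: that slice is the causal part.
seqlen = 8
u = torch.randn(seqlen)  # input sequence
k = torch.randn(seqlen)  # convolution kernel

fft_size = 2 * seqlen  # padding makes the circular convolution linear
u_f = torch.fft.rfft(u, n=fft_size)
k_f = torch.fft.rfft(k, n=fft_size)
y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]

# Reference: direct causal convolution, y[t] = sum_{s<=t} k[s] * u[t-s].
y_ref = torch.zeros(seqlen)
for t in range(seqlen):
    for s in range(t + 1):
        y_ref[t] += k[s] * u[t - s]

print(torch.allclose(y, y_ref, atol=1e-4))  # True
```

Each output y[t] depends only on u[0..t], which is what makes the truncated FFT convolution causal.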
On Sun, Mar 12, 2023, veritas9872 wrote:

> https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/standalone_hyena.py#L17C2-L21
>
> Hello. Thank you for the great work in this paper. I only have a minor question concerning the code.
>
> When performing the FFT, it is my understanding that the inputs should be shifted before and after the operation to be equivalent to the DFT. Therefore, fftshift(fft(ifftshift(x))) and fftshift(ifft(ifftshift(X))) are the correct methods.
>
> Because the rfft function removes half of the frequency space, I believe that the correct transformation should be rfft(ifftshift(x)) and fftshift(irfft(X)) for the conversions to and from the frequency domain. This may not impact the model performance, and there may be no great difference in the outputs, but I believe that it may be worth noting.
>
> I have included the following links for reference.
> https://groups.google.com/g/comp.soft-sys.matlab/c/rUcc0bRRZf4?pli=1
> https://dsp.stackexchange.com/questions/66716/why-do-we-have-to-rearrange-a-vector-and-shift-the-zero-point-to-the-first-index
Thank you for the quick response! Due to an implementation issue, the FFT and IFFT require center frequency shifting to accurately calculate the DFT. This discussion may also be helpful: pytorch/pytorch#51022
From looking at your first link, I believe this is a difference between the
MATLAB and PyTorch implementations. You can check the result of the FFT
convolution against a manually computed convolution; it should agree to
within numerical error.
On Sun, Mar 12, 2023, veritas9872 wrote:

> Thank you for the quick response! I think that my question is slightly different.
>
> The FFTShift and IFFTShift operations move the low-frequency regions to the center of the sequence.
>
> [image: illustration of fftshift]
> https://user-images.githubusercontent.com/33523965/224548039-61a4a934-7419-459d-9a2b-c999e8922412.png
>
> Due to an implementation issue, the FFT and IFFT require center frequency shifting to accurately calculate the DFT. While this may be canceled out, I was curious if this might affect the result.
AFAIK, this is due to the fact that MATLAB arrays are 1-indexed, which forced many communities working with MATLAB to adopt the …
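For reference, the shift operations under discussion only reorder the array: fftshift rolls the zero-frequency (first) bin to the center, and ifftshift undoes it (the two differ for odd lengths). A small NumPy illustration (my addition, not from the thread):

```python
import numpy as np

x = np.arange(7)  # pretend index 0 is the zero-frequency bin
print(np.fft.fftshift(x))                    # [4 5 6 0 1 2 3]
print(np.fft.ifftshift(np.fft.fftshift(x)))  # [0 1 2 3 4 5 6]
```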
I have tested the function, and I believe that this is indeed the issue. The following code shows that shifting is unnecessary for the FFT in PyTorch. Thank you for your help!

```python
from scipy import signal
import torch
import numpy as np

@torch.inference_mode()
def test1():
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True

@torch.inference_mode()
def test2():
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False
```
The PyTorch and NumPy functions produce identical results. The MATLAB implementation does seem to have been the issue.
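A quick spot-check of that agreement (my own snippet, using only the standard APIs): both torch.fft.rfft and numpy.fft.rfft put the zero-frequency bin first, so they can be compared directly with no shifting.

```python
import numpy as np
import torch

x = np.random.rand(13)
t = torch.fft.rfft(torch.from_numpy(x)).numpy()  # float64 in, complex128 out
n = np.fft.rfft(x)
print(np.allclose(t, n))  # True
```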
Another question, though. Is taking the front of the resultant convolved sequence the desired behavior? I believe that the middle part, corresponding to … The resulting code would be as follows.

```python
seqlen = u.shape[-1]
fft_size = 2 * seqlen
k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # Explicit norm mode for better readability.
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[seqlen//2:seqlen//2+seqlen]
```
Thanks for verifying! Could you elaborate as to why that would be more desirable? If you don't take the first …
I see that the desired result is to take only the first part of the output sequence, instead of the region with the maximum overlap. Thank you for the explanation!
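The trade-off can be made concrete with SciPy (an illustrative sketch I am adding; variable names are mine): the first L samples of the full convolution form the causal output, while the centered L samples match mode='same', which looks ahead and is therefore non-causal.

```python
import numpy as np
from scipy import signal

L = 13
u = np.random.rand(L)
k = np.random.rand(L)

full = signal.convolve(u, k, mode='full')      # length 2L - 1
causal = full[:L]                              # front slice: y[t] uses only u[:t+1]
centered = signal.convolve(u, k, mode='same')  # middle slice: maximum overlap

# mode='same' is the centered window of the full convolution.
print(np.allclose(centered, full[(L - 1) // 2 : (L - 1) // 2 + L]))  # True
```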
Hello, can you kindly explain why setting L = L * 2 makes it causal? You wrote: "For causal language modeling, we want a causal convolution, so we pad the input and kernel with zeros (this is what that fft_size = seqlen * 2 is doing)". Where is the padding operation performed? Thank you.
This line sets the FFT size:
https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/standalone_hyena.py#L15
The torch FFT does the padding implicitly (as does FlashFFTConv).
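A quick check of that implicit padding (my snippet, standard torch APIs only): passing n to torch.fft.rfft zero-pads the signal on the right, exactly as if we had padded explicitly.

```python
import torch

L = 8
x = torch.randn(L)
implicit = torch.fft.rfft(x, n=2 * L)                          # pads to 2L internally
explicit = torch.fft.rfft(torch.nn.functional.pad(x, (0, L)))  # L zeros on the right
print(torch.allclose(implicit, explicit))  # True
```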
Thank you for your explanation. Setting n=2 * L pads the input and k to a length of 2L. However, the padding seems to be done at the end of the sequence, which is inconsistent with the usual nn.Conv1d approach of adding zeros at the beginning and then truncating to the first seqlen outputs. I can't understand how this ensures causality. Can you kindly explain it? Thank you.
See this blog post; it has to do with the mathematics of the FFT:
https://hazyresearch.stanford.edu/blog/2023-12-11-conv-tutorial
nn.Conv1d padding operates differently.
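To complement the blog post, here is a small equivalence check (a sketch I am adding, not code from the repository): right-padding to 2L for the FFT and keeping the first L outputs gives the same result as the nn.Conv1d-style recipe of left-padding by L - 1 (with the kernel flipped, since conv1d actually computes cross-correlation).

```python
import torch
import torch.nn.functional as F

L = 16
u = torch.randn(L)
k = torch.randn(L)

# FFT path: implicit right zero-padding, then truncate to the first L samples.
y_fft = torch.fft.irfft(
    torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(k, n=2 * L), n=2 * L
)[:L]

# conv1d path: left-pad with L - 1 zeros, flip the kernel for a true convolution.
y_conv = F.conv1d(
    F.pad(u.view(1, 1, -1), (L - 1, 0)), k.flip(-1).view(1, 1, -1)
).view(-1)

print(torch.allclose(y_fft, y_conv, atol=1e-4))  # True
```

Both recipes compute the same causal convolution; the zeros just end up on different sides because the FFT version removes the non-causal tail by truncation instead of never computing it.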