
Question concerning FFT operation. #4

Closed
veritas9872 opened this issue Mar 12, 2023 · 13 comments

veritas9872 commented Mar 12, 2023

k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

Hello. Thank you for the great work in this paper. I only have a minor question concerning the code.

When performing the FFT, it is my understanding that the input should be shifted before the operation and the output shifted after it for the result to be equivalent to the DFT.

Therefore, fftshift(fft(ifftshift(x))) and fftshift(ifft(ifftshift(X))) are the correct methods.

Because the rfft function returns only the non-negative half of the frequency spectrum, I believe the correct transformations should be rfft(ifftshift(x)) and fftshift(irfft(X)) for the conversions to and from the frequency domain. This may not impact model performance, and there may be no great difference in the outputs, but I believe it is worth noting.

I have included the following links for reference.

https://groups.google.com/g/comp.soft-sys.matlab/c/rUcc0bRRZf4?pli=1

https://dsp.stackexchange.com/questions/66716/why-do-we-have-to-rearrange-a-vector-and-shift-the-zero-point-to-the-first-index
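
A minimal sketch of the convention in question, using only standard torch.fft calls (not code from the repo): fftshift reorders the frequency bins so that the zero-frequency (DC) bin sits at the center of the array, and fftshift(fft(ifftshift(x))) is the centered-DFT layout discussed in the links above.

import torch

n = 8
# torch.fft.fft places the DC bin at index 0, followed by the positive and then the negative frequencies.
print(torch.fft.fftfreq(n))                      # 0.000, 0.125, 0.250, 0.375, -0.500, -0.375, -0.250, -0.125
# fftshift reorders the bins so that DC sits at the center of the array.
print(torch.fft.fftshift(torch.fft.fftfreq(n)))  # -0.500, -0.375, ..., 0.000, ..., 0.375

# fftshift(fft(ifftshift(x))) is the "centered DFT": input and output are both
# indexed with the origin in the middle of the array.
x = torch.randn(n)
plain = torch.fft.fft(x)
centered = torch.fft.fftshift(torch.fft.fft(torch.fft.ifftshift(x)))
# The magnitudes agree up to reordering; only the indexing/phase convention differs.
print(torch.allclose(plain.abs(), torch.fft.ifftshift(centered.abs())))  # True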

DanFu09 (Contributor) commented Mar 12, 2023 via email

veritas9872 (Author) commented Mar 12, 2023

Thank you for the quick response!
I think that my question is slightly different.
The fftshift operation moves the low-frequency components to the center of the sequence, and ifftshift undoes that shift.
[image: illustration of fftshift moving the low-frequency components to the center]

Because of how the implementation orders the frequency bins, I thought the FFT and IFFT would require center-frequency shifting to accurately calculate the DFT.
While the shifts may cancel out, I was curious whether this might affect the result.

This discussion may also be helpful. pytorch/pytorch#51022

DanFu09 (Contributor) commented Mar 12, 2023 via email

Zymrael (Contributor) commented Mar 12, 2023

AFAIK, this is due to the fact that MATLAB arrays are 1-indexed, which led many communities working in MATLAB to adopt the fftshift + centered-DFT convention. You don't need fftshift in PyTorch code for the DFT result to be correct.
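
As a quick sanity check, a sketch using only standard torch.fft calls: torch.fft.fft already matches the textbook DFT definition with the origin at index 0, so no shift is needed for the values to be correct.

import math

import torch


def naive_dft(x: torch.Tensor) -> torch.Tensor:
    # Textbook DFT: X[k] = sum_t x[t] * exp(-2j * pi * k * t / n), with the origin at index 0.
    n = x.shape[-1]
    t = torch.arange(n, dtype=torch.float64)
    w = torch.exp(-2j * math.pi * t.reshape(-1, 1) * t / n)  # n x n DFT matrix
    return w @ x.to(w.dtype)


x = torch.randn(16, dtype=torch.float64)
print(torch.allclose(naive_dft(x), torch.fft.fft(x)))  # True: no fftshift required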

veritas9872 (Author) commented Mar 13, 2023

I have tested the function, and I believe this was indeed the issue.

The following code shows that shifting is unnecessary for the FFT in PyTorch.

Thank you for your help!

from scipy import signal
import torch
import numpy as np


@torch.inference_mode()
def test1():
    # FFT convolution without any fftshift/ifftshift matches the direct linear convolution.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')  # ground truth
    d = torch.fft.rfft(torch.from_numpy(a), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.from_numpy(b), n=2 * seq_len)
    f = torch.fft.irfft(d * e, n=2 * seq_len, norm='forward').numpy()[:-1]
    print(np.allclose(c, f))  # True


@torch.inference_mode()
def test2():
    # Applying ifftshift before and fftshift after the transforms breaks the equivalence.
    seq_len = 13
    a = np.random.rand(seq_len)
    b = np.random.rand(seq_len)
    c = signal.convolve(a, b, mode='full', method='direct')
    d = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(a)), n=2 * seq_len) / (2 * seq_len)
    e = torch.fft.rfft(torch.fft.ifftshift(torch.from_numpy(b)), n=2 * seq_len)
    f = torch.fft.fftshift(torch.fft.irfft(d * e, n=2 * seq_len, norm='forward')).numpy()[:-1]
    print(np.allclose(c, f))  # False


if __name__ == "__main__":
    test1()
    test2()

veritas9872 (Author) commented

The PyTorch and NumPy functions produce identical results. The MATLAB convention does seem to have been the source of my confusion.

veritas9872 (Author) commented Mar 13, 2023

Another question, though: is taking the front of the resulting convolved sequence the desired behavior? I believe that the middle part, corresponding to scipy.signal.convolve(..., mode='same'), may be more desirable.

The resulting code would be as follows.

seqlen = u.shape[-1]
fft_size = 2 * seqlen

k_f = torch.fft.rfft(k, n=fft_size, norm='forward')
u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size, norm='backward')  # Explicit norm mode for better readability.

if len(u.shape) > 3: k_f = k_f.unsqueeze(1)
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., seqlen // 2:seqlen // 2 + seqlen]

Zymrael (Contributor) commented Mar 13, 2023

Thanks for verifying! Could you elaborate on why that would be more desirable? If you don't take the first seqlen elements, your convolution is no longer causal. Padding is just an artifact to turn a circular convolution (for which the FFTConv method holds) into a linear convolution (which is what we want to compute); at the output, you need to select the first seqlen elements for the result to be correct.
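
A minimal sketch of this point, using scipy as the reference (variable names are illustrative, not from the repo): the first seqlen outputs of the zero-padded FFT convolution match the causal part of the full linear convolution, while the centered slice reproduces mode='same' and therefore also draws on future inputs.

import numpy as np
import torch
from scipy import signal

seq_len = 13
u = np.random.rand(seq_len)  # input sequence
k = np.random.rand(seq_len)  # convolution kernel
fft_size = 2 * seq_len

# Zero-padded FFT convolution, as in the repo code.
u_f = torch.fft.rfft(torch.from_numpy(u), n=fft_size)
k_f = torch.fft.rfft(torch.from_numpy(k), n=fft_size, norm='forward')
y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward').numpy()

full = signal.convolve(u, k, mode='full', method='direct')
same = signal.convolve(u, k, mode='same', method='direct')

# Taking the first seq_len outputs gives the causal result: output t only uses u[0..t].
print(np.allclose(y[:seq_len], full[:seq_len]))  # True

# The centered slice reproduces mode='same', but output t then also depends on
# future inputs u[t+1..], so it is no longer causal.
start = (seq_len - 1) // 2  # scipy's centering of the 'same' window within 'full'
print(np.allclose(y[start:start + seq_len], same))  # True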

veritas9872 (Author) commented

I see that the desired result is to take only the first part of the output sequence, instead of the region with the maximum overlap. Thank you for the explanation!

0205090923 commented

Hello, can you kindly explain why setting L = L * 2 makes it causal? In "For causal language modeling, we want a causal convolution, so we pad the input and kernel with zeros (this is what that fft_size = seqlen * 2 is doing)", where is the padding operation actually performed? Thank you.

DanFu09 (Contributor) commented Sep 5, 2024 via email

0205090923 commented

Thank you for your explanation. Setting n=2 * L zero-pads the input and k to a length of 2L. However, the padding seems to be applied at the end of the sequence, which is inconsistent with the usual approach of adding zeros at the beginning and then truncating to the first seqlen outputs in nn.Conv1d. I can't understand how this ensures causality. Can you kindly explain it? Thank you.

DanFu09 (Contributor) commented Sep 5, 2024 via email
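
A minimal sketch of the zero-padding argument (an illustration assuming the standard circular-to-linear convolution reasoning, not code from the repo): padding both sequences at the end to length 2L makes the circular convolution computed by the FFT equal to the ordinary linear convolution, and keeping only the first L outputs reproduces exactly what a left-padded nn.Conv1d computes, so each output depends only on current and past inputs.

import torch
import torch.nn.functional as F

L = 16
u = torch.randn(1, 1, L)  # (batch, channels, length)
k = torch.randn(L)        # kernel; k[0] weights the current time step

# FFT convolution: zero-pad both sequences (at the end) to length 2L, multiply the
# spectra, and keep only the first L outputs.
fft_size = 2 * L
y_fft = torch.fft.irfft(
    torch.fft.rfft(u, n=fft_size) * torch.fft.rfft(k, n=fft_size),
    n=fft_size,
)[..., :L]

# The same causal convolution with conv1d: left-pad the input with L - 1 zeros and
# flip the kernel (conv1d computes a cross-correlation).
y_conv = F.conv1d(F.pad(u, (L - 1, 0)), k.flip(-1).reshape(1, 1, L))

print(torch.allclose(y_fft, y_conv, atol=1e-5))  # True: both depend only on past inputs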
