Add InputError to CUDA fft() function #503

Closed
xqft opened this issue Jul 10, 2023 · 4 comments
Comments

@xqft
Member

xqft commented Jul 10, 2023

The CPU and Metal fft() operation functions handle the case where the input length is not a power of two and return an InputError in that case. CUDA's doesn't. This code should be removed and InputError should be added to CUDA's fft() op.
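A minimal sketch of the kind of guard described above, assuming a hypothetical FftError enum with an InputError(usize) variant and a hypothetical check_fft_input helper that the CUDA fft() op would call before doing any work. The crate's actual error types may be named and shaped differently.

```rust
/// Hypothetical error type for this sketch; the crate's real error enum may differ.
#[derive(Debug, PartialEq)]
pub enum FftError {
    /// Carries the offending input length.
    InputError(usize),
}

/// Rejects any input whose length is not a power of two.
/// `usize::is_power_of_two` returns false for 0, so an empty input is rejected too.
pub fn check_fft_input<T>(input: &[T]) -> Result<(), FftError> {
    if !input.len().is_power_of_two() {
        return Err(FftError::InputError(input.len()));
    }
    Ok(())
}
```

Calling such a check first would keep the CUDA path consistent with the CPU and Metal ones: a length-3 input yields Err(FftError::InputError(3)), while a length-4 input passes.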

@startup-dreamer
Contributor

startup-dreamer commented Jul 16, 2023

While going through Metal's fft() operation code, I found the comment // TODO: make a twiddle factor abstraction for handling invalid twiddles. I think this can be achieved through:

if input.len() != twiddles.len() {
    return Err(MetalError::InvalidTwiddles(input.len(), twiddles.len()));
}

@xqft, do you think I'm heading in the right direction here?

@startup-dreamer
Contributor

And to validate the twiddle factors:

    let twiddle_length = 1 << (input.len().trailing_zeros() as usize);
    if twiddles.len() != twiddle_length {
        return Err(MetalError::InvalidTwiddleFactors(twiddles.len()));
    }

we could do something like this. Please let me know if I missed something.

@xqft
Member Author

xqft commented Jul 19, 2023

Hello @startup-dreamer, sorry I just saw your comments!

The logic of your first snippet is incorrect: FFT actually uses input.len() / 2 twiddle factors. Computing the discrete Fourier transform of a set directly needs all input.len() factors, but because of the way FFT works it only needs half of them. The rest are computed implicitly in the sum and difference operations, because the second half of the factors are just the negatives of the first half.
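To make the N/2 count concrete, here is a minimal recursive radix-2 sketch over a toy prime field. It assumes p = 17 (whose multiplicative group has order 16 and is generated by 3) instead of the crate's field types; pow_mod, naive_dft and fft are hypothetical names, and the natural-order twiddle layout is an assumption, not necessarily how the CPU or Metal code stores them. Here w is a primitive Nth root of unity: the naive DFT uses every power of w, while fft only receives w^0, ..., w^(N/2 - 1) and produces out[k + N/2] with a subtraction, since w^(k + N/2) = -w^k.

```rust
/// Toy prime modulus (17 = 2^4 + 1), an assumption for illustration only.
const P: u64 = 17;

/// Square-and-multiply modular exponentiation.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

/// Naive DFT: X_k = sum_j x_j * w^(j*k), which touches every power of w.
fn naive_dft(input: &[u64], w: u64) -> Vec<u64> {
    let n = input.len() as u64;
    (0..n)
        .map(|k| {
            input
                .iter()
                .enumerate()
                .fold(0u64, |acc, (j, &x)| (acc + x * pow_mod(w, j as u64 * k)) % P)
        })
        .collect()
}

/// Recursive radix-2 FFT. `twiddles` holds only w^0, ..., w^(N/2 - 1).
fn fft(input: &[u64], twiddles: &[u64]) -> Vec<u64> {
    let n = input.len();
    if n == 1 {
        return input.to_vec();
    }
    // Half-size transforms use w^2, i.e. every other stored twiddle.
    let half_twiddles: Vec<u64> = twiddles.iter().step_by(2).copied().collect();
    let even: Vec<u64> = input.iter().step_by(2).copied().collect();
    let odd: Vec<u64> = input.iter().skip(1).step_by(2).copied().collect();
    let e = fft(&even, &half_twiddles);
    let o = fft(&odd, &half_twiddles);
    let mut out = vec![0; n];
    for k in 0..n / 2 {
        let t = twiddles[k] * o[k] % P;
        out[k] = (e[k] + t) % P;
        // w^(k + N/2) = -w^k, so this factor is never stored: subtract instead.
        out[k + n / 2] = (e[k] + P - t) % P;
    }
    out
}
```

With N = 8, w = 3^(16/8) = 9 and the stored twiddles are [1, 9, 13, 15], four values for eight outputs.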

Your second snippet does the same as the first: you are defining twiddle_length := 2^(log2(input.len())), which is equal to input.len() (for a power-of-two length, input.len().trailing_zeros() is exactly log2(input.len())).
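A quick way to see this is to evaluate both expressions for a few lengths; twiddle_length and expected_twiddle_count are hypothetical helper names. For every power of two, the snippet's value is the length itself, twice what the FFT uses (for other lengths it gives the largest power of two dividing the length, e.g. 4 for 12).

```rust
/// The value computed in the second snippet: 2^(trailing zero bits of len).
/// For len == 0 the shift amount equals usize::BITS, which panics in debug builds.
fn twiddle_length(len: usize) -> usize {
    1 << (len.trailing_zeros() as usize)
}

/// The number of twiddle factors a radix-2 FFT actually consumes.
fn expected_twiddle_count(len: usize) -> usize {
    len / 2
}
```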

Twiddle factors actually have a mathematical meaning: they're powers of a primitive Nth root of unity of the finite field we are operating on, where N = input.len(). In other words, if we call w a primitive Nth root of unity (you don't need to understand what a primitive root of unity is for now), then the twiddle factors are w^0, w^1, w^2, w^3, ..., w^(N/2 - 1). So validating twiddle factors means checking that the numbers we take as twiddle factors fulfill this definition, not just checking their length as you are suggesting. We could achieve this by defining something like a TwiddleFactors struct that is guaranteed to contain valid twiddle factors once it's constructed (private fields and a new() fn that makes sure of this), though there could be other approaches.
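One possible reading of the TwiddleFactors idea, as a minimal sketch: a struct whose only field is private and whose new() rejects anything that is not [w^0, ..., w^(N/2 - 1)] for a primitive Nth root of unity w. It assumes a toy field with p = 17 instead of the crate's field types, and TwiddleFactors, TwiddleError, as_slice and pow_mod are illustrative names, not the crate's API. The primitivity test relies on N being a power of two: w has order exactly N iff w^N = 1 and w^(N/2) != 1.

```rust
mod twiddles {
    /// Toy prime modulus (17 = 2^4 + 1), an assumption for illustration only.
    pub const P: u64 = 17;

    /// Square-and-multiply modular exponentiation.
    pub fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
        let mut acc = 1;
        base %= P;
        while exp > 0 {
            if exp & 1 == 1 {
                acc = acc * base % P;
            }
            base = base * base % P;
            exp >>= 1;
        }
        acc
    }

    #[derive(Debug, PartialEq)]
    pub enum TwiddleError {
        LengthNotPowerOfTwo(usize),
        WrongCount { expected: usize, got: usize },
        NotPrimitiveRoot(u64),
        NotAPowerOfRoot(usize),
    }

    /// Holds w^0, w^1, ..., w^(n/2 - 1) for a primitive n-th root of unity w.
    /// The field is private, so `new` is the only way to build one outside this module.
    #[derive(Debug)]
    pub struct TwiddleFactors {
        values: Vec<u64>,
    }

    impl TwiddleFactors {
        pub fn new(values: Vec<u64>, n: usize) -> Result<Self, TwiddleError> {
            if n < 2 || !n.is_power_of_two() {
                return Err(TwiddleError::LengthNotPowerOfTwo(n));
            }
            if values.len() != n / 2 {
                return Err(TwiddleError::WrongCount { expected: n / 2, got: values.len() });
            }
            // For n = 2 only w^0 is stored; the primitive 2nd root is -1 = P - 1.
            let w = if n == 2 { P - 1 } else { values[1] };
            // n is a power of two, so w has order exactly n iff w^n = 1 and w^(n/2) != 1.
            if pow_mod(w, n as u64) != 1 || pow_mod(w, (n / 2) as u64) == 1 {
                return Err(TwiddleError::NotPrimitiveRoot(w));
            }
            // Every stored value must be the matching power of w.
            for (i, &v) in values.iter().enumerate() {
                if v != pow_mod(w, i as u64) {
                    return Err(TwiddleError::NotAPowerOfRoot(i));
                }
            }
            Ok(Self { values })
        }

        pub fn as_slice(&self) -> &[u64] {
            &self.values
        }
    }
}
```

An fft() taking &TwiddleFactors would then only need to check that input.len() equals twice as_slice().len(), since holding the type already guarantees the values themselves are valid.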

You can check out the IsFFTField, get_primitive_roots_of_unity() and get_twiddles() implementations to learn more about twiddle factors. If you still want to work on this, let me know and I'll open a detailed issue where we can discuss it, so we don't keep going off-topic in this one. Thanks for your interest!

@MauroToscano
Collaborator

We are removing this CUDA implementation. See #831.
