In [1]:
using LinearAlgebra
using PlotlyJS
using FFTW
using BenchmarkTools
using Kronecker
using Primes

In [2]:
# Random complex test vector; 9 = 3^2 exercises the composite (recursive) FFT path.
vecsize = 9
x = complex.(randn(vecsize), randn(vecsize));

In [3]:
# Dense DFT-matrix helpers (used as the FFT base case for prime sizes)

"""
    dft_matrix(n)

Return the `n × n` Discrete Fourier Transform matrix `F` with entries
`F[j+1, k+1] = exp(-2πi·j·k/n)` for `j, k = 0, …, n-1`, so that `F * x` is the
(unnormalized) DFT of a length-`n` vector `x`.

Reference: Van Loan, C. (1992). Computational frameworks for the fast Fourier
transform (algorithm 1.16). Algorithm 1.16 builds `F` column by column through
repeated multiplication. Here every entry is instead looked up from the `n`
distinct roots of unity, which avoids accumulating rounding error across
columns.

# Arguments
- `n`: matrix size (non-negative integer).

# Returns
- `Matrix{ComplexF64}` of size `(n, n)`. For `n == 1` this is the `1 × 1`
  matrix `[1]`, and for `n == 0` it is an empty `0 × 0` matrix.

# Throws
- `ArgumentError` if `n < 0`.
"""
function dft_matrix(n)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))

    # Only n distinct values occur in F, because ω^(j*k) == ω^(mod(j*k, n)).
    # Tabulate them once, each accurate to a few ulps.
    roots = [cis(-2π * r / n) for r in 0:n-1]

    # Index the table instead of multiplying columns together.
    return [roots[mod(j * k, n) + 1] for j in 0:n-1, k in 0:n-1]
end

"""
    naive_dft(x)

Compute the Discrete Fourier Transform of the vector `x` by dense
matrix-vector multiplication with `dft_matrix(length(x))`, which costs
O(n²) work.

# Arguments
- `x`: input vector.

# Returns
- The transformed vector `y = F * x`.
"""
naive_dft(x) = dft_matrix(length(x)) * x

"""
    naive_idft_unscaled(y)

Compute the inverse Discrete Fourier Transform of `y` without the `1/n`
scaling, as `F' * y` where `F = dft_matrix(length(y))` and `'` is the
conjugate transpose. Divide the result by `length(y)` to obtain the true
inverse transform.

# Arguments
- `y`: input vector.

# Returns
- The unscaled inverse transform `F' * y`.
"""
naive_idft_unscaled(y) = dft_matrix(length(y))' * y

naive_idft_unscaled (generic function with 1 method)

In [4]:
"""
    genfft_base(x, n, inverse = false)

Compute the unscaled DFT of the length-`n` vector `x` with a recursive
mixed-radix FFT. With `inverse = true`, compute the unscaled inverse DFT
instead. The method follows Section 2.1.4 of Van Loan, C. (1992),
Computational frameworks for the fast Fourier transform.

The length is split as `n = p * m`, where `p` is the smallest prime factor
of `n`:

- The `p` stride-`p` subsequences get length-`m` transforms, computed
  recursively.
- Those transforms are scaled by twiddle factors.
- The results are combined with `m` length-`p` transforms.

Prime lengths fall back to dense DFT-matrix multiplication.

# Arguments
- `x`: input vector; `length(x)` must equal `n`.
- `n`: transform length.
- `inverse`: if `true`, use the root of unity `exp(+2πi/n)` instead of
  `exp(-2πi/n)`.

# Returns
- A complex (`ComplexF64`) vector of length `n`. The inverse transform is
  *not* divided by `n`.
"""
function genfft_base(x, n, inverse = false)
    # Lengths 0 and 1: the transform is the identity. Also, 1 has no prime
    # factor to split on.
    if n <= 1
        return ComplexF64.(x)
    end

    # Base case: prime length, use dense DFT matrix multiplication.
    if isprime(n)
        return inverse ? naive_idft_unscaled(x) : naive_dft(x)
    end

    # Find the smallest factor p of n, which is necessarily prime. The lazy
    # generator stops at the first divisor instead of scanning all of 2:n.
    p = first(d for d in 2:n if n % d == 0)
    m = n ÷ p  # n = p * m

    pwr_sign = inverse ? 1 : -1

    # Step 1: for each j, transform the stride-p subsequence x[j+1:p:end]
    # (length m), then scale entry k by ω^(j*k).
    # Each twiddle is evaluated directly with `cis`; computing (ω^k)^j would
    # accumulate rounding error. Since j*k < p*m = n, no reduction mod n is
    # needed.
    z = similar(x, ComplexF64)
    for j in 0:p-1
        sub = genfft_base(@view(x[j+1:p:end]), m, inverse)
        for k in 0:m-1
            z[j*m + k + 1] = cis(pwr_sign * 2π * j * k / n) * sub[k+1]
        end
    end

    # Step 2: m length-p transforms across the p blocks of z.
    y = similar(x, ComplexF64)
    for k in 0:m-1
        y[k+1:m:end] .= genfft_base(@view(z[k+1:m:end]), p, inverse)
    end

    return y
end

"""
    genfft(x)

Forward Discrete Fourier Transform of `x`, computed by the recursive
mixed-radix FFT `genfft_base`. The result is unnormalized, using the
`exp(-2πi·j·k/n)` kernel.
"""
genfft(x) = genfft_base(x, length(x))

"""
    genifft(x)

Inverse Discrete Fourier Transform of `x`. It takes the unscaled inverse
transform from `genfft_base` and divides it by `length(x)`, so that
`genifft(genfft(x)) ≈ x`.
"""
function genifft(x)
    len = length(x)
    unscaled = genfft_base(x, len, true)
    return unscaled / len
end

genifft (generic function with 1 method)

In [5]:
# Relative max-norm error of genfft against FFTW's reference fft
let reference = fft(x)
    norm(reference - genfft(x), Inf) / norm(reference, Inf)
end

4.665379181428376e-16

In [6]:
# Relative max-norm error of the round trip genifft(genfft(x)) back to x
let roundtrip = genifft(genfft(x))
    norm(x - roundtrip, Inf) / norm(x, Inf)
end

3.9944749813616255e-16