# Four-step FFT
Four-step FFT (also known as Bailey's FFT) is a high-performance algorithm for computing the fast Fourier transform (FFT). This variation of the Cooley–Tukey FFT algorithm was originally designed for systems with hierarchical memory, which is common in modern computers. The algorithm treats the samples as a two-dimensional matrix (hence yet another name, the matrix FFT algorithm) and executes short FFTs on the columns and rows of that matrix.

## 1. How the Four-Step FFT Works
Here is a brief overview of how the "4-step" version of Bailey's FFT algorithm works:<br>
First: the data (in natural order) is arranged into a matrix.<br>
Second: each column of the matrix is processed independently with a standard FFT algorithm.<br>
Third: each element of the matrix is multiplied by a correction coefficient (called a twiddle factor).<br>
Fourth: each row of the matrix is processed independently with a standard FFT algorithm.<br>
Finally, the result is read out column by column (a transpose), which puts the output back in natural order.<br>
The following figure shows the workflow of this algorithm:<br>
<img src="workflow of 4-step_FFT.png" alt="Example Image" style="background-color: #f0f0f0; width:400px; display: block; margin: auto;">
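The index mapping behind these steps can be written out explicitly. This is a sketch in our own notation: $n = R \cdot C$ samples are stored row by row in an $R \times C$ matrix, $\omega_n$ is a primitive $n$-th root of unity, $\omega_R = \omega_n^{C}$ and $\omega_C = \omega_n^{R}$. If the input index is split as $iC + j$ and the output index as $k_1 + R k_2$, then

$$
\omega_n^{(iC+j)(k_1+Rk_2)} = \omega_R^{i k_1}\,\omega_n^{j k_1}\,\omega_C^{j k_2}
\quad\Longrightarrow\quad
X_{k_1 + R k_2} = \sum_{j=0}^{C-1} \omega_C^{j k_2}\,\underbrace{\omega_n^{j k_1}}_{\text{twiddle}}\,\underbrace{\sum_{i=0}^{R-1} a_{iC+j}\,\omega_R^{i k_1}}_{\text{column FFT}}
$$

The first identity holds because $\omega_n^{iCRk_2} = \omega_n^{n\,i k_2} = 1$. In the second, the inner sum is the length-$R$ FFT of column $j$ (second step), and $\omega_n^{j k_1}$ is the twiddle factor applied to entry $(k_1, j)$ (third step). The outer sum is the length-$C$ FFT of row $k_1$ (fourth step). Entry $(k_1, k_2)$ of the final matrix therefore holds $X_{k_1 + R k_2}$. This is why the code in Section 3 transposes the matrix before flattening it.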

## 2. How the Four-Step FFT Enhances Performance and Efficiency 
Take an input of size $n = 2^{10}$ as an example. The Cooley–Tukey FFT needs roughly $n \log_2 n = 2^{10} \cdot 10$ operations. With the four-step algorithm, we can reshape the input into a $2^{5} \times 2^{5}$ matrix and run an FFT on every column and every row. That takes $2^{5} \cdot 2^{5} \cdot 5 \cdot 2 = 2^{10} \cdot 10$ operations, the same amount. On top of that, the twiddle-factor step adds one multiplication per matrix element, i.e. another $2^{10}$ operations. So the four-step FFT actually performs slightly more arithmetic than the conventional FFT. What, then, are its advantages?
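The count above generalizes to any split of a length-$2^{L}$ input into $2^{r}$ rows and $2^{L-r}$ columns. The hypothetical helper below uses the same rough model as the text: a length-$m$ FFT costs $m \log_2 m$ operations, plus one twiddle multiplication per matrix entry. It is only a counting sketch, not a benchmark.

```python
def op_counts(log_n, log_rows):
    # Rough model from the text: a length-m FFT costs m * log2(m) operations.
    n = 1 << log_n
    log_cols = log_n - log_rows
    rows, cols = 1 << log_rows, 1 << log_cols
    direct = n * log_n                       # one length-n FFT
    column_ffts = cols * (rows * log_rows)   # `cols` FFTs of length `rows`
    row_ffts = rows * (cols * log_cols)      # `rows` FFTs of length `cols`
    twiddles = n                             # one multiplication per matrix entry
    return direct, column_ffts + row_ffts, twiddles

print(op_counts(10, 5))  # square 32 x 32 split:      (10240, 10240, 1024)
print(op_counts(10, 3))  # rectangular 8 x 128 split: (10240, 10240, 1024)
```

The FFT work is $n \cdot \text{log\_rows} + n \cdot \text{log\_cols} = n \log_2 n$ for every split. So the twiddle multiplications are the only arithmetic overhead, and the entries with $i = 0$ or $j = 0$ multiply by $\omega^0 = 1$ and could be skipped.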

Despite having the same computational complexity, the Four-Step FFT offers several practical advantages thanks to its memory access patterns and hardware compatibility:
* The Four-Step FFT splits one large FFT into many small FFTs (column FFTs and row FFTs), and each one touches only a single column or row of the matrix at a time. In a conventional radix-2 FFT, by contrast, the distance between butterfly partners grows up to half the input length, so the later stages access data spread across the entire input.
* The column FFTs are independent of one another, and so are the row FFTs, so each group can be computed in parallel. Hardware such as FPGAs and GPUs can exploit this to compute many FFTs simultaneously.
* Instead of holding the entire FFT input/output, the Four-Step FFT processes smaller chunks (single rows or columns of the matrix) that fit into limited on-chip RAM.<br>

Because of these advantages, the Four-Step FFT offers significant performance improvements on hardware, particularly FPGAs and GPUs, because it fits the hardware's architecture and memory hierarchy.
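The on-chip memory argument can be quantified with a back-of-the-envelope sketch. The helper below is hypothetical. It assumes each element occupies a fixed number of bytes (8 by default, e.g. one 64-bit field element). It compares the data a direct FFT must keep at hand with the data a single column or row FFT needs.

```python
def working_set_bytes(log_n, log_rows, bytes_per_element=8):
    # Bytes that must be resident for one transform: the whole input for a
    # direct FFT, versus one column or one row for a four-step sub-FFT.
    n = 1 << log_n
    rows, cols = 1 << log_rows, 1 << (log_n - log_rows)
    direct = n * bytes_per_element
    sub_fft = max(rows, cols) * bytes_per_element
    return direct, sub_fft

direct, sub_fft = working_set_bytes(20, 10)
print(direct // 1024, "KiB vs", sub_fft // 1024, "KiB")  # 8192 KiB vs 8 KiB
```

For $n = 2^{20}$, the whole transform needs 8 MiB under this assumption, while each $2^{10}$-point sub-FFT needs only 8 KiB. That is far more likely to fit in on-chip RAM.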

## 3. Code implementation of Four-step FFT
We will use Python code to walk through this algorithm step by step.

### 3.1 Finite Field
We perform the FFT computations in a finite field rather than over the complex numbers. The role of the complex roots of unity is played by powers of a primitive root of the field. An FFT computed in a finite field this way is also known as the Number Theoretic Transform (NTT).
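Concretely, take a prime modulus $p$ and a length $n$ that divides $p - 1$. The forward and inverse transforms computed by the code below are

$$
X_k = \sum_{j=0}^{n-1} a_j\,\omega^{jk} \bmod p,
\qquad
a_j = n^{-1} \sum_{k=0}^{n-1} X_k\,\omega^{-jk} \bmod p,
\qquad
\omega = g^{(p-1)/n} \bmod p,
$$

where $g$ is a primitive root modulo $p$, so $\omega$ has multiplicative order exactly $n$. For example, $p = 17$, $g = 11$ and $n = 8$ give $\omega = 11^{2} \bmod 17 = 2$.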

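```python
import numpy as np

class Field:
    # Basic arithmetic in the prime field GF(modulus); modulus must be prime
    def __init__(self, modulus, generator=11):
        self.modulus = modulus
        # Primitive root of the field; 11 is a primitive root mod 17 (used in the examples)
        self.generator = generator

    def add(self, x, y):
        return (x + y) % self.modulus

    def sub(self, x, y):
        return (x - y) % self.modulus

    def mul(self, x, y):
        return (x * y) % self.modulus

    def pow(self, x, exp):
        return pow(x, exp, self.modulus)

    def inv(self, x):
        # Fermat's little theorem: x^(p-2) is the inverse of x mod a prime p
        return pow(x, self.modulus - 2, self.modulus)

    def roots_of_unity(self, n):
        # [w^0, w^1, ..., w^(n-1)] for the primitive n-th root w = g^((p-1)/n)
        assert (self.modulus - 1) % n == 0, "n must divide modulus - 1"
        root = self.pow(self.generator, (self.modulus - 1) // n)
        return [self.pow(root, i) for i in range(n)]
```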
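The `roots_of_unity` method produces valid roots only if the generator it uses (11 here) is a primitive root modulo the chosen prime, which holds for 17. One standard way to check this for another prime is sketched below. The helper names are hypothetical and not part of the class above. The check uses the fact that $g$ generates the multiplicative group modulo $p$ exactly when $g^{(p-1)/q} \not\equiv 1$ for every prime $q$ dividing $p - 1$.

```python
def prime_factors(m):
    # Distinct prime factors by trial division; fine for moderate p - 1
    factors, d = set(), 2
    while d * d <= m:
        while m % d == 0:
            factors.add(d)
            m //= d
        d += 1
    if m > 1:
        factors.add(m)
    return factors

def is_primitive_root(g, p):
    # g generates (Z/pZ)* iff g^((p-1)/q) != 1 for every prime q dividing p - 1
    return all(pow(g, (p - 1) // q, p) != 1 for q in prime_factors(p - 1))

def nth_root_of_unity(g, p, n):
    # For a primitive root g and n | p - 1, g^((p-1)/n) has order exactly n
    assert (p - 1) % n == 0, "n must divide p - 1"
    return pow(g, (p - 1) // n, p)

print(is_primitive_root(11, 17))     # True
print(is_primitive_root(2, 17))      # False: 2 has order 8 mod 17
print(nth_root_of_unity(11, 17, 8))  # 2
```

For larger transforms, a common choice is the prime $998244353 = 119 \cdot 2^{23} + 1$. It has primitive root 3 and supports power-of-two lengths up to $2^{23}$.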

### 3.2 Forward and Inverse NTT
The following class implements the NTT. It covers the forward and inverse transforms in both DIT (decimation-in-time) and DIF (decimation-in-frequency) form, plus the matrix helpers the four-step FFT needs.

```python
class NTT:
    # Root tables, bit reversal, radix-2 DIT/DIF transforms (forward and inverse)
    # and the matrix helpers used by the four-step FFT
    def __init__(self, modulus, n):
        self.gf = Field(modulus)
        self.n = n

    def get_forward_roots(self, n):
        return self.gf.roots_of_unity(n)

    def get_inverse_roots(self, n):
        forward_roots = self.gf.roots_of_unity(n)
        return [self.gf.inv(r) for r in forward_roots]

    def bit_reversed_indices(self, n):
        # Reverse the log2(n)-bit binary representation of each index i
        logn = n.bit_length() - 1
        return [int(f"{i:0{logn}b}"[::-1], 2) for i in range(n)]

    def bit_reverse(self, a):
        # Return a new list with the elements of a in bit-reversed order
        n = len(a)
        indices = self.bit_reversed_indices(n)
        return [a[i] for i in indices]

    def dit(self, a, roots):
        # Iterative radix-2 decimation-in-time (Cooley-Tukey):
        # bit-reverse a copy of the input, then apply butterflies of growing size m
        n = len(a)
        a = self.bit_reverse(a)
        logn = n.bit_length() - 1
        for s in range(1, logn + 1):
            m = 1 << s
            wm = roots[n // m]  # primitive m-th root of unity
            for k in range(0, n, m):
                w = 1
                for j in range(m // 2):
                    u = a[k + j]
                    v = self.gf.mul(w, a[k + j + m // 2])
                    a[k + j] = self.gf.add(u, v)
                    a[k + j + m // 2] = self.gf.sub(u, v)
                    w = self.gf.mul(w, wm)
        return a

    def dif(self, a, roots):
        # Iterative radix-2 decimation-in-frequency (Gentleman-Sande):
        # butterflies of shrinking size m modify a in place, and the
        # bit-reversed result is reordered into natural order at the end
        n = len(a)
        logn = n.bit_length() - 1
        for s in range(logn, 0, -1):
            m = 1 << s
            wm = roots[n // m]  # primitive m-th root of unity
            for k in range(0, n, m):
                w = 1
                for j in range(m // 2):
                    u = a[k + j]
                    v = a[k + j + m // 2]
                    a[k + j] = self.gf.add(u, v)
                    a[k + j + m // 2] = self.gf.mul(w, self.gf.sub(u, v))
                    w = self.gf.mul(w, wm)
        return self.bit_reverse(a)

    def forward_dit(self, a):
        roots = self.get_forward_roots(len(a))
        return self.dit(a, roots)

    def inverse_dit(self, a):
        # Transform with the inverse roots, then scale by n^-1
        inverse_roots = self.get_inverse_roots(len(a))
        a = self.dit(a, inverse_roots)
        n_inv = self.gf.inv(len(a))
        return [self.gf.mul(x, n_inv) for x in a]

    def forward_dif(self, a):
        roots = self.get_forward_roots(len(a))
        return self.dif(a, roots)

    def inverse_dif(self, a):
        # Transform with the inverse roots, then scale by n^-1
        inverse_roots = self.get_inverse_roots(len(a))
        a = self.dif(a, inverse_roots)
        n_inv = self.gf.inv(len(a))
        return [self.gf.mul(x, n_inv) for x in a]

    def matrix(self, a, log_rows, log_cols):
        # Reshape a flat array into a (2^log_rows) x (2^log_cols) matrix, row by row.
        # NumPy's default int64 dtype is fine for small moduli, but the product of two
        # residues can overflow int64 once the modulus exceeds roughly 2^31.
        rows = 1 << log_rows
        cols = 1 << log_cols
        return np.array(a).reshape((rows, cols))

    def transpose_and_flatten(self, matrix):
        # Read the matrix column by column; int() turns NumPy scalars into Python ints
        return [int(element) for row in matrix.T for element in row]

    def apply_twiddles(self, wm, matrix):
        # Multiply each matrix[i, j] by wm^(i*j) in place; wm is a primitive n-th root of unity
        n, m = matrix.shape
        for i in range(n):
            for j in range(m):
                factor = self.gf.pow(wm, i * j)
                matrix[i, j] = self.gf.mul(matrix[i, j], factor)

    def apply_column_fft(self, matrix):
        # Forward NTT of every column, in place
        for j in range(matrix.shape[1]):
            column = matrix[:, j].tolist()
            matrix[:, j] = self.forward_dit(column)

    def apply_row_fft(self, matrix):
        # Forward NTT of every row, in place
        for i in range(matrix.shape[0]):
            matrix[i] = self.forward_dit(matrix[i].tolist())
```

Here is the test code for the NTT:

```python
def test_ntt():
    # Test forward_dit, inverse_dit, forward_dif and inverse_dif
    modulus = 17
    input_array = [1, 2, 3, 4, 5, 6, 7, 8]
    n = len(input_array)
    ntt = NTT(modulus, n)
    forward_roots = ntt.get_forward_roots(n)
    print("forward_roots:", forward_roots)
    inverse_roots = ntt.get_inverse_roots(n)
    print("inverse_roots:", inverse_roots)

    # Forward transform: DIT and DIF must agree
    forward_result_dit = ntt.forward_dit(input_array[:])
    forward_result_dif = ntt.forward_dif(input_array[:])
    assert forward_result_dit == forward_result_dif, "forward_result_dit is not equal to forward_result_dif!"
    print("forward_result is:", forward_result_dit)

    # Inverse transform: DIT and DIF must agree
    inverse_result_dit = ntt.inverse_dit(input_array[:])
    inverse_result_dif = ntt.inverse_dif(input_array[:])
    assert inverse_result_dit == inverse_result_dif, "inverse_result_dit is not equal to inverse_result_dif!"
    print("inverse_result is:", inverse_result_dit)

    # Round trip: the inverse of the forward result must restore the input
    result_back = ntt.inverse_dit(forward_result_dit)
    print("inverse back result is:", result_back)

    assert result_back == input_array, "NTT test Failed!"
    print("NTT tests passed!")

if __name__ == "__main__":
    test_ntt()
```

```text
forward_roots: [1, 2, 4, 8, 16, 15, 13, 9]
inverse_roots: [1, 9, 13, 15, 16, 8, 4, 2]
forward_result is: [2, 8, 14, 6, 13, 3, 12, 1]
inverse_result is: [13, 15, 10, 11, 8, 5, 6, 1]
inverse back result is: [1, 2, 3, 4, 5, 6, 7, 8]
NTT tests passed!
```
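The test above checks DIT against DIF, but both use the same root tables. An independent cross-check is a direct $O(n^2)$ evaluation of the defining sum $X_k = \sum_j a_j \omega^{jk} \bmod p$. The functions below are a hypothetical reference written for this note and are not part of the `NTT` class. `root` must be a primitive $n$-th root of unity: 2 for $n = 8$ and $p = 17$, as in the `forward_roots` output.

```python
def naive_ntt(a, root, modulus):
    # Direct O(n^2) evaluation of X_k = sum_j a_j * root^(j*k) mod modulus
    n = len(a)
    return [sum(a[j] * pow(root, j * k, modulus) for j in range(n)) % modulus
            for k in range(n)]

def naive_intt(x, root, modulus):
    # Inverse: transform with root^-1, then scale by n^-1
    # (inverses via Fermat's little theorem, so modulus must be prime)
    n = len(x)
    inv_root = pow(root, modulus - 2, modulus)
    n_inv = pow(n, modulus - 2, modulus)
    return [v * n_inv % modulus for v in naive_ntt(x, inv_root, modulus)]

x = naive_ntt([1, 2, 3, 4, 5, 6, 7, 8], 2, 17)
print(x)                     # [2, 8, 14, 6, 13, 3, 12, 1]
print(naive_intt(x, 2, 17))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

The first result matches the `forward_result` printed by `test_ntt`.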


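### 3.3 Four-Step FFT
The function below chains the matrix helpers of the `NTT` class into the four steps described in Section 1. It then transposes and flattens the result to obtain the output in natural order.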
```python
def four_step(array, log_rows, modulus):
    n = len(array)
    logn = n.bit_length() - 1
    log_cols = logn - log_rows
    assert n == 1 << logn, "length must be a power of two"
    assert log_rows > 0
    assert log_cols > 0
    assert (modulus - 1) % n == 0, "n must divide modulus - 1"

    ntt = NTT(modulus, n)

    # First step: arrange the array into a (2^log_rows) x (2^log_cols) matrix
    matrix = ntt.matrix(array, log_rows, log_cols)
    print("origin matrix is:", matrix)

    # Second step: FFT of each column
    ntt.apply_column_fft(matrix)
    print("after column fft, matrix is:", matrix)

    # Third step: multiply entry (i, j) by the twiddle factor wm^(i*j),
    # where wm is a primitive n-th root of unity
    wm = ntt.get_forward_roots(n)[1]
    ntt.apply_twiddles(wm, matrix)

    # Fourth step: FFT of each row
    ntt.apply_row_fft(matrix)
    print("after row fft, matrix is:", matrix)

    # Transpose the matrix and flatten it, which restores natural order
    out_array = ntt.transpose_and_flatten(matrix)
    print("after transpose and flatten, array is:", out_array)
    return out_array
```

Here is the test code for the four-step FFT algorithm:

```python
def test_four_step_fft():
    # A small example: modulus 17, n = 16
    modulus = 17
    n = 16
    log_rows = 2  # matrix shape is 4 x 4
    ntt = NTT(modulus, n)
    input_array = list(range(1, n + 1))  # [1, 2, ..., 16]

    # Run the four-step FFT
    four_step_result = four_step(input_array, log_rows, modulus)
    print("Four-Step FFT result:", four_step_result)

    # The result must match a direct FFT of the whole array
    direct_result = ntt.forward_dit(input_array)
    print("Direct FFT result is:", direct_result)
    assert four_step_result == direct_result, "Four-Step FFT failed!"
    print("Four-Step FFT tests passed!")

if __name__ == "__main__":
    test_four_step_fft()
```

```text
origin matrix is: [[ 1  2  3  4]
 [ 5  6  7  8]
 [ 9 10 11 12]
 [13 14 15 16]]
after column fft, matrix is: [[11 15  2  6]
 [11 11 11 11]
 [ 9  9  9  9]
 [ 7  7  7  7]]
after row fft, matrix is: [[ 0 11  9  7]
 [ 5 15 10 14]
 [16 12  6  2]
 [ 4  8  3 13]]
after transpose and flatten, array is: [0, 5, 16, 4, 11, 15, 12, 8, 9, 10, 6, 3, 7, 14, 2, 13]
Four-Step FFT result: [0, 5, 16, 4, 11, 15, 12, 8, 9, 10, 6, 3, 7, 14, 2, 13]
Direct FFT result is: [0, 5, 16, 4, 11, 15, 12, 8, 9, 10, 6, 3, 7, 14, 2, 13]
Four-Step FFT tests passed!
```
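The test above only exercises a square 4 × 4 layout, but the same steps work for rectangular splits. The sketch below checks this with a compact, vectorized variant written for this note. The names `dft_matrix` and `four_step_matrix` are hypothetical and not part of the code above. It also swaps the per-column and per-row radix-2 FFTs for plain DFT matrix products. Each group of independent transforms then becomes a single matrix multiplication, which shows the data flow rather than the performance.

```python
import numpy as np

def dft_matrix(length, root, modulus):
    # F[k, i] = root^(k*i) mod modulus, stored as Python ints (dtype=object) to avoid overflow
    return np.array([[pow(root, k * i, modulus) for i in range(length)]
                     for k in range(length)], dtype=object)

def four_step_matrix(a, rows, cols, root, modulus):
    # root must be a primitive (rows*cols)-th root of unity modulo a prime modulus
    m = np.array(a, dtype=object).reshape(rows, cols)         # step 1: row-major layout
    col_dft = dft_matrix(rows, pow(root, cols, modulus), modulus)
    m = col_dft.dot(m) % modulus                              # step 2: all columns at once
    twiddles = np.array([[pow(root, i * j, modulus) for j in range(cols)]
                         for i in range(rows)], dtype=object)
    m = m * twiddles % modulus                                # step 3: entry (i, j) times root^(i*j)
    row_dft = dft_matrix(cols, pow(root, rows, modulus), modulus)
    m = m.dot(row_dft.T) % modulus                            # step 4: all rows at once
    return [int(v) for v in m.T.flatten()]                    # read out column by column

data = list(range(1, 17))
for shape in [(4, 4), (2, 8), (8, 2)]:
    print(shape, four_step_matrix(data, *shape, 11, 17))
```

All three shapes should print the same list as the direct FFT result above. Using `dtype=object` keeps every entry a Python int, which sidesteps the int64 overflow concern for large moduli at the cost of speed.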
