In [86]:
import numpy as np
from sympy import ntt as sntt
from sympy import intt as sintt
from sympy.ntheory import isprime, primitive_root

In [14]:
f = np.array([1,4,0,0])
# f_hat = np.array([5, 249, 766, 522]) # para m = 769
# f_hat = np.array([0, 4, 2, 3]) # para m=5

N_t = 4

## Fermat numbers
# alpha_t = 2
# m_t = 5
# alpha = 4
# m = 17
alpha_t = 2
m_t = 257

In [15]:
np.mod(8**2, 257)

64

In [16]:
def verify_params_ntt(N, alpha, p):
    n = np.arange(1, N)
    return (np.gcd(np.power(alpha,N), p) == 1) and np.all((np.gcd(np.power(alpha,n)-1, p) == 1))

In [4]:
verify_params_ntt(4, 2, 5)

True

In [43]:
verify_params_ntt(4, 2, 257)

True

In [5]:
verify_params_ntt(4, 4, 17)

True

In [6]:
verify_params_ntt(8, 2, 257)

True

In [82]:
verify_params_ntt(256, 3, 257)

True

In [92]:
verify_params_ntt(512, 4607,  262657)

True

In [95]:
verify_params_ntt(512, 3,  262657)

True

In [96]:
verify_params_ntt(512, 5,  262657)

True

# About 2-D NTT

In [9]:
# a_b2 = np.array([[0,0,1,0, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [1,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0]])

# b_b2 = np.array([[0,1,0,0, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [0,0,1,1, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0],
#                  [0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0]])

In [10]:
def _NTT2D(x, alpha1, alpha2, M, N, F_t, k, l):
    gsum = 0
    for m in range(M):
        for n in range(N):
            gsum +=np.mod(x[m,n] * np.power(alpha1, np.mod(m*k, M)) * np.power(alpha2, np.mod(n*l, N)), F_t)
    return np.mod(gsum, F_t)

def _INTT2D(x, alpha1, alpha2, M, N, F_t, m, n):
    gsum = 0
    inv_M = get_inv(M, F_t)
    inv_N = get_inv(N, F_t)
    for k in range(M):
        for l in range(N):
            gsum +=np.mod(inv_M * inv_N * x[k,l] * np.power(alpha1, np.mod(-m*k, M)) * np.power(alpha2, np.mod(-n*l, N)), F_t)
    return np.mod(gsum, F_t)

In [11]:
def NTT2D(x):
    L_L = []
    for k in range(x.shape[0]):
        L = []
        for l in range(x.shape[1]):
            L.append(_NTT2D(x, 2, 2, x.shape[0], x.shape[1], 257, k, l))
        L_L.append(L)
    
    return np.array(L_L)

def INTT2D(hat_x):
    L_L = []
    for m in range(hat_x.shape[0]):
        L = []
        for n in range(hat_x.shape[1]):
            L.append(_INTT2D(hat_x, 2, 2, hat_x.shape[0], hat_x.shape[1], 257, m, n))
        L_L.append(L)
    
    return np.array(L_L)

In [12]:
#INTT2D(np.multiply(NTT2D(a_b2), NTT2D(a_b2)))

In [13]:
#np.multiply(NTT2D(a_b2), NTT2D(a_b2)).shape

# Proof #NTT

In [9]:
def _NTT(f, n, N, alpha, m):
    gsum=0
    for k in range(0, N):
        gsum += f[k]*np.power(alpha, n*k)
    
    return np.mod(gsum, m)

def NTT_base(f, alpha, m):
    N = len(f)
    l_ret = []
    for n in range(0, N):
        l_ret.append(_NTT(f, n, N, alpha, m))
    
    return np.array(l_ret)

In [10]:
def NTT_slow(f, alpha, p):
    f = np.asarray(f, dtype=int)
    N = f.shape[0]
    
    n = np.arange(N)
    k = n.reshape((N, 1))
    
    if not verify_params_ntt(N, alpha, p): raise Exception('The parameters alpha, p and N are not coprimes !!')
    
    M = np.power(alpha, n*k)
    
    return np.mod(np.dot(M, f), p)

In [67]:
def FNTT(f, alpha, p):
    f = np.asarray(f, dtype=int)
    N = f.shape[0]
    
    if N % 2 > 0:
        raise ValueError("size of x must be a power of 2")
    elif N <= 32:
        return NTT_slow(f, alpha, p)
    else:
        F_even = FNTT(f[::2], alpha, p)
        F_odd = FNTT(f[1::2], alpha, p)
        
        M = np.power(alpha, np.arange(N))
        
#         print(M)
        
        return np.concatenate([F_even + M[:N//2] * F_odd,
                               F_even + M[N//2:] * F_odd])

In [70]:
# FNTT(x_test, 5, 262657)

In [71]:
x_test = np.random.randint(low=0, high=100000, size=512)

In [72]:
%timeit sntt(x_test, 262657)
%timeit NTT_base(x_test, 5, 262657)
%timeit NTT_slow(x_test, 5, 262657)
%timeit FNTT(x_test, 5, 262657)

1.68 ms ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


  gsum += f[k]*np.power(alpha, n*k)


381 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.36 ms ± 8.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
427 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [73]:
print(np.allclose(NTT_base(x_test, 5, 262657), sntt(x_test, 262657)))
print(np.allclose(NTT_base(x_test, 5, 262657), NTT_slow(x_test, 5, 262657)))
print(np.allclose(NTT_base(x_test, 5, 262657), FNTT(x_test, 5, 262657)))

print(np.allclose(sntt(x_test, 262657), FNTT(x_test, 5, 262657)))

  gsum += f[k]*np.power(alpha, n*k)


False
True
False
False


In [74]:
x_test_hat = NTT_slow(x_test, 5, 262657)

In [20]:
%timeit sntt(f, 5)
%timeit NTT_base(f, 2, 5)
%timeit NTT_slow(f, 2, 5)

15.6 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
17.3 µs ± 72.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
13.6 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
True
True


In [None]:
print(np.allclose(NTT_base(f, 2, 5), sntt(f, 5)))
print(np.allclose(NTT_base(f, 2, 5), NTT_slow(f, 2, 5)))

In [18]:
f_hat = NTT_slow(f, 2, 5)

In [19]:
f_hat

array([0, 4, 2, 3])

# Proof # INTT

In [76]:
def _INTT(f_h, k, N, alpha, m):
    inv_N = get_inv(N, m)
    gsum=0
    for n in range(0, N):
        gsum += inv_N * f_h[n]*np.power(alpha, np.mod(-n*k, N))
    
    return np.mod(gsum, m)

def INTT_base(f_h, alpha, m):
    N = len(f_h)
    l_ret = []
    for k in range(0, N):
        l_ret.append(_INTT(f_h, k, N, alpha, m))
    
    return np.array(l_ret)

In [77]:
def INTT_slow(f, alpha, p):
    f = np.asarray(f, dtype=int)
    N = f.shape[0]
    inv_N = get_inv(N, p)
    
    n = np.arange(N)
    k = n.reshape((N, 1))
    
    if not verify_params_ntt(N, alpha, p): raise Exception('The parameters alpha, p and N are not coprimes !!')
    
    M = np.power(alpha, np.mod(-n*k, N))
    
    return np.mod(inv_N * np.dot(M, f), p)

In [24]:
%timeit sintt(f_hat, 5)
%timeit INTT_base(f_hat, 2, 5)
%timeit INTT_slow(f_hat, 2, 5)

17.5 µs ± 54.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
37.9 µs ± 109 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
18.3 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
True
True
True


In [39]:
print(np.allclose(sintt(f_hat, 5), f))
print(np.allclose(INTT_base(f_hat, 2, 5), f))
print(np.allclose(INTT_slow(f_hat, 2, 5), f))

True
True
True


In [78]:
%timeit sintt(x_test_hat, 262657)
%timeit INTT_base(x_test_hat, 5, 262657)
%timeit INTT_slow(x_test_hat, 5, 262657)

1.79 ms ± 9.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


  gsum += inv_N * f_h[n]*np.power(alpha, np.mod(-n*k, N))


560 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.59 ms ± 6.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
print(np.allclose(sintt(x_test_hat, 262657), x_test))
print(np.allclose(INTT_base(x_test_hat, 5, 262657), x_test))
print(np.allclose(INTT_slow(x_test_hat, 5, 262657), x_test))

print(np.allclose(INTT_base(x_test_hat, 5, 262657), sintt(x_test_hat, 262657)))
print(np.allclose(INTT_base(x_test_hat, 5, 262657), INTT_slow(x_test_hat, 5, 262657)))

False


  gsum += inv_N * f_h[n]*np.power(alpha, np.mod(-n*k, N))


False
False
False
True


In [26]:
f_n = INTT_slow(f_hat, 2, 5)

In [27]:
f_n

array([1, 4, 0, 0])

## Test 2D matrix

In [28]:
a_b2 = np.array([[0,1,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,1, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]])

b_b2 = np.array([[0,0,1,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [1,1,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0],
                 [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]])

In [29]:
alpha_t = 2
m_t = 257
#verify_params_ntt(8, 2, 257)

In [47]:
65536%257

1

In [30]:
a_b2_x = np.apply_along_axis(NTT_slow, 1, a_b2, alpha_t, m_t)

In [31]:
a_b2_xy = np.apply_along_axis(NTT_slow, 0, a_b2_x, alpha_t, m_t)

In [32]:
b_b2_x = np.apply_along_axis(NTT_slow, 1, b_b2, alpha_t, m_t)

In [33]:
b_b2_xy = np.apply_along_axis(NTT_slow, 0, b_b2_x, alpha_t, m_t)

In [34]:
c_h = np.multiply(a_b2_xy, b_b2_xy)

In [35]:
c_x = np.apply_along_axis(INTT_slow, 1, c_h, alpha_t, m_t)

In [36]:
c_xy = np.apply_along_axis(INTT_slow, 0, c_x, alpha_t, m_t)

In [37]:
c_xy

array([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

# Proofs

## Proof #1
$$\hat{f}^{\prime}_{n} = \sum^{N-1}_{k=1}{(kf_{k}) \alpha^{n(k-1)}} \mod{m}, \quad n=0, \ldots, N-1$$

In [None]:
for n in range(N_t):
    sum_o = 0
    for k in range(1,N_t):
        sum_o += np.mod(k * f[k] * np.power(alpha_t, n*(k-1)), m_t)
    print(f'n={n} f_hat_p={sum_o}')

## Proof #2

$$\hat{f}^{\prime}_{n} = \sum^{N-1}_{k=1}{k \left(N^{-1} \sum_{l=0}^{N-1} \hat{f}_{l} \alpha^{-l k}\right) \alpha^{n(k-1)}} \mod{m}, \quad n=0, \ldots, N-1$$

In [None]:
N_inv_t = get_inv(N_t, m_t)

for n in range(N_t):    
    sum_k = 0
    for k in range(1,N_t):
        
        sum_l = 0
        for l in range(N_t):
            sum_l += f_hat[l]*np.power(alpha_t, -l*k%N_t ) ## change 1
        sum_l *= N_inv_t ## change 2
        
        sum_k += k*sum_l*np.power(alpha_t, n*(k-1))
    sum_k = np.mod(sum_k, m_t)

    print(f'n={n} f_hat_p={sum_k}')

## Proof #3

$$\hat{f}^{\prime}_{n} = N^{-1}\sum_{l=0}^{N-1} \hat{f}_{l}\sum^{N-1}_{k=1} k \alpha^{k(n-l)}\alpha^{-n} \mod{m}, \quad n=0, \ldots, N-1$$

In [None]:
for n in range(N_t):    
    
    sum_l = 0
    for l in range(N_t):
        
        sum_k = 0
        for k in range(1,N_t):
            sum_k += k*np.power(alpha_t, (k*(n-l))%N_t ) ## change 1

        sum_l += f_hat[l]*sum_k
    
    sum_l = N_inv_t*sum_l*np.power(alpha_t, -n%N_t) ## change 2 and 3
    sum_l = np.mod(sum_l, m_t)

    print(f'n={n} f_hat_p={sum_l}')

## Proof #4

$$\hat{f}^{\prime}_{n} =N^{-1}\sum_{l=0}^{N-1} \hat{f}_{l}\alpha^{-n}  \sum^{N-1}_{k=1} k \alpha^{k(n-l)}  \mod{m}, \quad n=0, \ldots, N-1$$

In [None]:
for n in range(N_t):    
    
    sum_l = 0
    for l in range(N_t):
        
        sum_k = 0
        for k in range(1,N_t):
            sum_k += k*np.power(alpha_t, (k*(n-l))%N_t ) ## change 1
        
        # print(np.mod(sum_k,m))

        sum_l += f_hat[l]*sum_k
    
    sum_l = N_inv_t*np.power(alpha_t, -n%N_t)*sum_l ##change 2 and 3
    
    sum_l = np.mod(sum_l, m_t)

    print(f'n={n} f_hat_p={sum_l}')

## Proof #5

\begin{equation*}
\sum^{N-1}_{k=1} k \alpha^{k(n-l)} = \begin{cases}
  2^{-1}\left(N-1\right)N  &\text{ if } l=n  \\
  \frac{N}{\left(\alpha^{n-l} -1\right)} &\text{ if } l\neq n,
\end{cases}
\end{equation*}

In [None]:
inv_two = get_inv(2, m_t)

In [None]:
## Left 
for n in range(N_t):
    for l in range(N_t):
        sum_k=0
        for k in range(1,N_t):
            sum_k += k* np.power(alpha_t, (k*(n-l))%N_t )

        print(f'l={l},k={n}, sum= {np.mod(sum_k,m_t)}')
    print('')

In [None]:
## Right
for n in range(N_t):
    for l in range(N_t):
        if l==n:
            sum_k= inv_two*(N_t-1)*N_t
        else:
            pow_a = np.mod((n-l), N_t) ## only positive exponents
            a = np.power(alpha_t, pow_a)
            a_inv = get_inv(a-1, m_t)
            
            sum_k = N_t * a_inv


        print(f'l={l},l={n}, sum= {np.mod(sum_k,m_t)}')
    print('')

# Building the T_prime matrix

In [None]:
l_l = []
for n in range(N_t):
    n_l = []
    for l in range(N_t):
        print(n,l)
        if l==n:
            t_l_n= inv_two*(N_t-1)*np.power(alpha_t,-n%N_t)
        else:
            pow_a = np.mod((n-l), N_t) ## only positive exponents
            a = np.power(alpha_t, pow_a)
            a_inv = get_inv(a-1, m_t)
            
            t_l_n = np.power(alpha_t, -n%N_t)*a_inv

        n_l.append(t_l_n)
        # print(f'l={l},l={n}, sum= {t_l_n}')
    l_l.append(n_l)

In [None]:
T_prime = np.mod(np.array(l_l),m_t)

# f_h = np.array(f_h)

np.mod(np.dot(T_prime, f_hat),m_t)

In [None]:
T_prime

# Prime Matrix

In [None]:
N_t = 4

n = np.arange(N_t)
#l = np.arange(N_t)
l = n.reshape((N_t, 1))

alpha_t = 2
m_t = 5

vget_inv = np.vectorize(get_inv)

In [None]:
l_l = []
for n in range(N_t):
    n_l = []
    for l in range(N_t):
        print(n,l)
        if l==n:
            t_l_n= inv_two*(N_t-1)*np.power(alpha_t,-n%N_t)
        else:
            pow_a = np.mod((n-l), N_t) ## only positive exponents
            a = np.power(alpha_t, pow_a)
            a_inv = get_inv(a-1, m_t)
            
            t_l_n = np.power(alpha_t, -n%N_t)*a_inv

        n_l.append(t_l_n)
        # print(f'l={l},l={n}, sum= {t_l_n}')
    l_l.append(n_l)

In [None]:
vget_inv(np.power(alpha_t, np.mod((n-l), N_t))-1, m_t)

In [None]:
n

In [None]:
def f(r,x):
    return np.where(x<0.5,2*r*x, 2*r*(1-x))

In [None]:
# alpha_t = 2
# m_t = 5
def ff(n,l):
    return np.piecewise([n,n],

                 [l==n,
                 l!=n],

                 [vget_inv(2, m_t)*(N_t-1)*np.power(alpha_t,np.mod(-n, N_t)),
                  1*np.power(alpha_t,np.mod(-n, N_t))]

                )

In [None]:
ff(n,n)

In [None]:
def Tprime(n,l,N,inv_two,p,alpha):
    return np.where(l==n,
                    inv_two*(N-1)*np.power(alpha,np.mod(-n, N)),
                    get_inv(np.power(alpha, np.mod((n-l), N))-1, p)*np.power(alpha,np.mod(-n, N)))

In [None]:
def PrimeNTT(f_hat, alpha, p):
    f = np.asarray(f_hat, dtype=int)
    N = f.shape[0]
    
    inv_two = get_inv(2, m_t)
    
    n = np.arange(N)
    k = n.reshape((N, 1))
    
    T_prime = Tprime(n,k,N,inv_two,p,alpha)
    
    return np.mod(np.dot(T_prime, f_hat), p)

In [None]:
#PrimeNTT(f_hat, 2, 5)

In [2]:
from sympy.ntheory import isprime, primitive_root


def generate_twiddle_factors(n, q):
    # Produces `n` omegas (or twiddle factors) 
    # given the generator: x^b (mod q) of the 
    # prime field of q.
    assert isprime(q)
    
    x = primitive_root(q)
    
    # Applying Dirichlet's theorem, 
    # we have: q = bn + 1.
    b = (q - 1) // n
    
    omega = (x ** b) % q

    omegas = [1]
    for i in range(n):
        # Multiply (mod q) by the previous value.
        omegas.append((omegas[i] * omega) % q)

    return omegas[:n]  # Drop the last, needless value.


In [6]:
generate_twiddle_factors(16, 257)

[1, 249, 64, 2, 241, 128, 4, 225, 256, 8, 193, 255, 16, 129, 253, 32]

In [7]:
def naive_ntt(a, q, omegas):
    n = len(a)
    out = [0] * n
    
    for i in range(n):
        for j in range(n):
            out[i] = (out[i] + a[j] * omegas[(i * j) % n]) % q
    return out

In [22]:
oms = generate_twiddle_factors(4, 5)

In [23]:
naive_ntt(f, 5, oms)

[0, 4, 2, 3]

In [24]:
f_hat

array([0, 4, 2, 3])

In [27]:
import math


def cooley_tukey_ntt_opt(a, n, q, phis):
    """Cooley-Tukey DIT algorithm with an extra optimization.
    We can avoid computing bit reversed order with each call by
    pre-computing the phis in bit-reversed order.
    Requires:
     `phis` are provided in bit-reversed order.
     `n` is a power of two.
     `q` is equivalent to `1 mod 2n`.
    Reference:
       https://www.microsoft.com/en-us/research/wp-content/
       uploads/2016/05/RLWE-1.pdf
    """

    assert q % (2 * n) == 1, f'{q} is not equivalent to 1 mod {2 * n}'
    assert (n & (n - 1) == 0) and n > 0, f'n: {n} is not a power of 2.'

    t = n
    m = 1
    while m < n:
        t >>= 1
        for i in range(0, m):
            j1 = i * (t << 1)
            j2 = j1 + t - 1
            S = phis[m + i]
            for j in range(j1, j2 + 1):
                U = a[j]
                V = a[j + t] * S
                a[j] = (U + V) % q
                a[j + t] = (U - V) % q
        m <<= 1
    return a


def gentleman_sande_intt_opt(a, n, q, inv_phis):
    """Gentleman-Sande INTT butterfly algorithm.
    Assumes that inverse phis are stored in bit-reversed order.
    Reference:
       https://www.microsoft.com/en-us/research/wp-content/
       uploads/2016/05/RLWE-1.pdf
    """
    t = 1
    m = n
    while (m > 1):
        j1 = 0
        h = m >> 1
        for i in range(h):
            j2 = j1 + t - 1
            S = inv_phis[h + i]
            for j in range(j1, j2 + 1):
                U = a[j]
                V = a[j + t]
                a[j] = (U + V) % q
                a[j + t] = ((U - V) * S) % q
            j1 += (t << 1)
        t <<= 1
        m >>= 1

    shift_n = int(math.log2(n))
    return [(i >> shift_n) % q for i in a]

def get_bit_reversed(c, n, q):
    cc = c.copy()
    for i in range(n):
        rev_i = reverse_bits(i, n.bit_length() - 1)
        if rev_i > i:
            cc[i], cc[rev_i] = cc[rev_i], cc[i]

    return cc


def gen_phis(omegas, q):
    def legendre(x, q):
        return pow(x, (q - 1) // 2, q)

    def tonelli_shanks(x, q):
        # Finds the `sqrt(x) mod q`.
        # Source: https://rosettacode.org/wiki/Tonelli-Shanks_algorithm
        Q = q - 1
        s = 0
        while Q % 2 == 0:
            Q //= 2
            s += 1
        if s == 1:
            return pow(x, (q + 1) // 4, q)
        for z in range(2, q):
            if q - 1 == legendre(z, q):
                break
        c = pow(z, Q, q)
        r = pow(x, (Q + 1) // 2, q)
        t = pow(x, Q, q)
        m = s
        t2 = 0
        while (t - 1) % q != 0:
            t2 = (t * t) % q
            for i in range(1, m):
                if (t2 - 1) % q == 0:
                    break
                t2 = (t2 * t2) % q
            b = pow(c, 1 << (m - i - 1), q)
            r = (r * b) % q
            c = (b * b) % q
            t = (t * c) % q
            m = i
        return r

    return [tonelli_shanks(x, q) for x in omegas]

In [26]:
cooley_tukey_ntt(f, 5, oms)

array([0, 4, 2, 3])

In [32]:
oms = generate_twiddle_factors(512, 262657)

In [33]:
gen_phis(oms, 262657)

ValueError: negative shift count

In [None]:
%timeit sntt(x_test, 262657)
%timeit NTT_base(x_test, 5, 262657)
%timeit NTT_slow(x_test, 5, 262657)

In [None]:
print(np.allclose(NTT_base(x_test, 5, 262657), sntt(x_test, 262657)))
print(np.allclose(NTT_base(x_test, 5, 262657), NTT_slow(x_test, 5, 262657)))