In [68]:
from sympy.ntheory import isprime, primitive_root, nthroot_mod
from sympy.core.compatibility import as_int, iterable
from sympy.utilities.iterables import ibin

In [10]:
def _number_theoretic_transform(seq, prime, inverse=False):
    """Shared worker for the forward and inverse Number Theoretic Transform.

    The coefficients are reduced modulo ``prime``, zero-padded on the right
    up to the next power of two, put into bit-reversed order and then
    combined with iterative radix-2 butterflies over the integers modulo
    ``prime``.  All work happens on a private list; ``seq`` is not modified.

    Parameters
    ----------
    seq : iterable
        Integer coefficients to transform.  Each entry is passed through
        ``as_int``; whatever that raises for a non-integer propagates.
    prime : integer
        Prime modulus ``p``.  The padded length must divide ``p - 1``.
    inverse : bool, optional
        When true, the inverse root of unity is used and the result is
        multiplied by the modular inverse of the padded length.

    Returns
    -------
    list
        Transformed coefficients, each reduced modulo ``prime``.  An empty
        input gives back an empty list.

    Raises
    ------
    TypeError
        If ``seq`` is not iterable.
    ValueError
        If ``prime`` is not prime, or if the padded length does not divide
        ``prime - 1``.
    """

    if not iterable(seq):
        raise TypeError("Expected a sequence of integer coefficients "
                        "for Number Theoretic Transform")

    modulus = as_int(prime)
    if not isprime(modulus):
        raise ValueError("Expected prime modulus for "
                        "Number Theoretic Transform")

    coeffs = [as_int(c) % modulus for c in seq]
    if not coeffs:
        return coeffs

    # Smallest exponent whose power of two is >= the number of coefficients.
    given = len(coeffs)
    bits = (given - 1).bit_length()
    size = 1 << bits

    # An element of multiplicative order `size` only exists when size | p - 1.
    if (modulus - 1) % size:
        raise ValueError("Expected prime modulus of the form (m*2**k + 1)")

    coeffs.extend([0] * (size - given))

    # Bit-reversal permutation; the `idx < mirrored` guard swaps each pair once.
    for idx in range(1, size):
        mirrored = int(ibin(idx, bits, str=True)[::-1], 2)
        if idx < mirrored:
            coeffs[idx], coeffs[mirrored] = coeffs[mirrored], coeffs[idx]

    # A generator raised to (p - 1) / size has order exactly `size`.
    root = pow(primitive_root(modulus), (modulus - 1) // size, modulus)
    if inverse:
        # Fermat's little theorem: root**(p - 2) is the inverse of root mod p.
        root = pow(root, modulus - 2, modulus)

    # twiddles[k] == root**k mod p for k in [0, size // 2).
    twiddles = []
    power = 1
    for _ in range(size // 2):
        twiddles.append(power)
        power = power * root % modulus

    # One pass per level; the butterfly span doubles from 2 up to `size`.
    for level in range(1, bits + 1):
        span = 1 << level
        half = span // 2
        stride = size // span
        for start in range(0, size, span):
            for k in range(half):
                lo = start + k
                hi = lo + half
                even = coeffs[lo]
                odd = coeffs[hi] * twiddles[stride * k]
                coeffs[lo] = (even + odd) % modulus
                coeffs[hi] = (even - odd) % modulus

    if inverse:
        # Undo the factor of `size` that the unnormalised transform introduces.
        inv_size = pow(size, modulus - 2, modulus)
        coeffs = [c * inv_size % modulus for c in coeffs]

    return coeffs

In [17]:
def ntt(seq, prime):
    """Forward Number Theoretic Transform of ``seq`` modulo ``prime``.

    Delegates to ``_number_theoretic_transform`` with ``inverse=False`` and
    returns its result unchanged.
    """
    return _number_theoretic_transform(seq, prime, inverse=False)


def intt(seq, prime):
    """Inverse Number Theoretic Transform of ``seq`` modulo ``prime``.

    Delegates to ``_number_theoretic_transform`` with ``inverse=True`` and
    returns its result unchanged.
    """
    return _number_theoretic_transform(seq, prime, inverse=True)

In [12]:
a = [0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 2, 1, 2, 0, 2, 0, 0, 2, 0, 1, 2, 1, 0, 2, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 1, 1, 2, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 2, 1, 0, 0, 0, 0, 0, 1, 2, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 2, 0, 0, 2, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 2, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 2, 0, 0, 1, 0, 0, 2, 0, 0, 1, 0, 1, 0, 0, 0, 0, 2, 0, 2, 0, 0, 1, 0, 0, 2, 2, 0, 2, 0, 2, 0, 0, 0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 1, 1, 0, 2, 2, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 2, 1, 0, 0, 2, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
b = [3, 4, 1, 2, 3, 1, 3, 3, 3, 3, 2, 1, 4, 4, 3, 3, 2, 2, 2, 3, 2, 1, 2, 3, 2, 2, 4, 0, 4, 2, 2, 2, 0, 2, 4, 1, 0, 0, 2, 1, 1, 3, 1, 3, 2, 3, 2, 1, 3, 1, 4, 3, 0, 2, 3, 0, 2, 2, 3, 0, 1, 3, 2, 4, 2, 3, 0, 1, 4, 1, 3, 1, 2, 2, 4, 1, 1, 0, 4, 0, 4, 3, 3, 0, 0, 2, 2, 3, 3, 2, 3, 0, 1, 4, 3, 2, 2, 1, 2, 0, 2, 1, 2, 3, 1, 3, 3, 2, 4, 2, 0, 2, 1, 2, 0, 3, 3, 3, 3, 1, 4, 2, 2, 2, 4, 4, 3, 2, 1, 1, 1, 3, 0, 3, 2, 1, 4, 0, 0, 2, 2, 2, 3, 1, 2, 3, 0, 2, 3, 4, 1, 1, 3, 3, 3, 2, 1, 1, 1, 2, 4, 4, 1, 3, 4, 2, 4, 1, 4, 2, 4, 3, 1, 3, 4, 2, 3, 1, 2, 4, 4, 4, 0, 4, 3, 2, 3, 3, 4, 1, 2, 1, 0, 4, 1, 4, 3, 3, 0, 1, 2, 3, 2, 1, 4, 3, 3, 3, 2, 2, 3, 0, 3, 2, 3, 3, 2, 1, 4, 3, 4, 0, 3, 1, 3, 3, 1, 3, 1, 1, 4, 1, 3, 3, 3, 0, 1, 2, 1, 3, 0, 1, 3, 4, 2, 2, 3, 3, 3, 2, 1, 3, 0, 2, 1, 2, 4, 1, 3, 1, 2, 3, 3, 1, 3, 1, 1, 3, 0, 4, 2, 0, 1, 2, 2, 3, 3, 1, 4, 1, 2, 0, 4, 3, 1, 4, 2, 2, 2, 2, 4, 2, 1, 3, 2, 1, 2, 4, 3, 3, 2, 2, 3, 0, 1, 1, 3, 2, 0, 1, 0, 3, 0, 1, 2, 2, 1, 1, 2, 4, 2, 0, 2, 3, 2, 1, 1, 4, 4, 2, 0, 1, 1, 3, 3, 3, 3, 1, 2, 3, 0, 4, 1, 1, 3, 2, 0, 3, 1, 1, 2, 4, 1, 4, 3, 1, 1, 0, 1, 2, 2, 0, 1, 0, 3, 2, 2, 1, 2, 3, 2, 3, 1, 0, 2, 1, 1, 4, 2, 2, 3, 3, 2, 1, 1, 3, 3, 3, 3, 3, 0, 2, 0, 0, 0, 2, 3, 3, 3, 3, 2, 1, 4, 4, 3, 3, 1, 0, 4, 3, 3, 3, 2, 1, 0, 3, 3, 4, 3, 1, 2, 1, 1, 4, 2, 2, 3, 2, 2, 3, 2, 1, 3, 4, 3, 3, 0, 2, 3, 4, 1, 2, 2, 0, 4, 3, 3, 3, 4, 2, 1, 3, 3, 2, 0, 2, 2, 0, 3, 0, 3, 2, 1, 1, 2, 1, 3, 4, 1, 3, 2, 3, 3, 0, 1, 2, 0, 4, 4, 3, 2, 3, 3, 1, 0, 3, 1, 2, 3, 3, 2, 3, 2, 1, 4, 1, 2, 2, 0, 3, 0, 0, 4, 3, 3, 0, 4, 3, 3, 2, 4, 0]

In [13]:
# Prime modulus used for the transform below.  7681 - 1 = 7680 = 15 * 512,
# so it passes the (p - 1) % n == 0 check for a length-512 input.
MOD = 7681

In [16]:
# Forward transform of `a` modulo MOD; prints every output coefficient.
print(ntt(a, MOD))

[192, 3069, 3007, 3196, 2373, 5921, 6553, 6192, 4168, 3814, 1407, 7400, 2622, 6966, 958, 805, 5001, 4138, 3622, 2373, 6658, 3780, 495, 6312, 5062, 3051, 7259, 5585, 2566, 1678, 2467, 4419, 2087, 2879, 2480, 4771, 7248, 4537, 3229, 6211, 4274, 6204, 5493, 7110, 6666, 5682, 1294, 717, 4720, 3449, 1318, 7167, 606, 3561, 430, 7434, 4128, 3565, 4849, 7460, 360, 7001, 6077, 2113, 6703, 3890, 5875, 533, 7021, 919, 4744, 3596, 4486, 1744, 5156, 324, 4237, 2653, 4096, 4307, 3888, 3429, 4876, 1962, 7367, 5641, 2108, 349, 6451, 512, 3553, 3206, 7591, 5890, 6776, 7093, 1845, 5676, 6205, 1572, 4467, 3733, 516, 1683, 2896, 6303, 2327, 7494, 6949, 2065, 7234, 1322, 2706, 5809, 4544, 2863, 600, 5644, 6425, 1393, 5010, 2342, 4886, 2116, 585, 6280, 4834, 3502, 4666, 4614, 2961, 5847, 6135, 7116, 3159, 1566, 4351, 2861, 1110, 3664, 2426, 280, 1518, 2933, 6719, 3089, 5278, 2103, 2618, 5502, 612, 1792, 5167, 2219, 3278, 5842, 126, 2360, 6557, 2767, 5276, 5374, 3095, 1884, 2795, 3730, 5096, 5363, 399, 2833,

In [26]:
# A primitive root modulo 7681.  _number_theoretic_transform raises this
# generator to the power (p - 1) // n to obtain its n-th root of unity.
primitive_root(7681)

17

In [29]:
# Number of input coefficients.  When this is already a power of two the
# transform appends no zero padding.
len(a)

512

In [79]:
# Solve x**1024 = -1 (mod 12289).  Any solution squares its 1024th power to 1
# without reaching 1 earlier, so it has multiplicative order 2048.
nthroot_mod(-1, 1024, 12289)

mpz(7)

In [76]:
# Solve x**256 = -1 (mod 7681); any solution has multiplicative order 512.
nthroot_mod(-1, 256, 7681)

62

In [78]:
# Solve x**128 = -1 (mod 3329); any solution has multiplicative order 256.
nthroot_mod(-1, 128, 3329)

17

In [53]:
# Tabulate the primes of the form k*256 + 1 for k < 100.  For such a prime p,
# 256 divides p - 1, which is the divisibility _number_theoretic_transform
# checks before transforming a length-256 input.
import sympy
for k in range (100):
    if sympy.isprime(k*256+1):
        print(k, k*256+1)

1 257
3 769
13 3329
30 7681
31 7937
37 9473
42 10753
46 11777
48 12289
52 13313
55 14081
57 14593
60 15361
70 17921
72 18433
76 19457
87 22273
90 23041
91 23297


In [80]:
# NOTE(review): `/` is true division, so this evaluates to a float.  If a
# quotient modulo a prime p was intended, the modular form would be
# 2285**2 * pow(128, p - 2, p) % p -- confirm the intent.
(2285**2)/128

40790.8203125