In [16]:
import random
import math
import sys

def miller_rabin(p, s=11):
    """Probabilistic Miller-Rabin primality test.

    Args:
        p: integer to test.
        s: number of random witnesses (rounds). By the standard
           Miller-Rabin bound a composite p survives all rounds with
           probability at most 4**-s.

    Returns:
        False if p is definitely composite, True if p is probably prime.
    """
    # Trivial cases; randrange(2, p-1) below needs p > 3.
    if p < 2:
        return False
    if p in (2, 3):
        return True
    if p & 1 == 0:
        return False

    # Decompose p-1 = 2**u * r with r odd. Use an integer shift: r/2 is
    # float division and silently loses precision once p exceeds 2**53.
    r = p - 1
    u = 0
    while r & 1 == 0:
        u += 1
        r >>= 1

    for _ in range(s):
        a = random.randrange(2, p - 1)  # random witness in {2, ..., p-2}
        z = pow(a, r, p)
        if z == 1 or z == p - 1:
            continue  # this witness says "probably prime"
        # Square up to u-1 times looking for -1 (mod p).
        for _ in range(u - 1):
            z = pow(z, 2, p)
            if z == p - 1:
                break
            if z == 1:
                return False  # non-trivial square root of 1 -> composite
        if z != p - 1:
            return False
    return True


def is_prime(n, s=11):
    """Return True if n is (probably) prime.

    Trial division by the odd primes below 1000 cheaply removes most
    composites (and settles small n exactly); survivors are handed to
    miller_rabin with s rounds.

    Args:
        n: integer to test.
        s: Miller-Rabin rounds forwarded to miller_rabin.
    """
    # All odd primes under 1000; 2 is handled by the explicit checks below.
    lowPrimes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
                 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
                 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269,
                 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367,
                 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461,
                 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571,
                 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661,
                 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773,
                 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883,
                 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
    if n == 2:
        return True  # the only even prime
    if n < 3 or n & 1 == 0:
        return False
    for p in lowPrimes:
        if n == p:
            return True
        if n % p == 0:
            return False
    return miller_rabin(n, s)

def generate_large_prime(k, s=11):
    """Return a random (probable) prime with exactly k bits.

    Candidates are drawn uniformly from [2**(k-1), 2**k) and tested with
    is_prime(n, s). The original authors cite an error probability below
    2**-80 for s=11; that bound depends on k (TODO confirm for small k).

    Args:
        k: desired bit length.
        s: Miller-Rabin rounds forwarded to is_prime.

    Raises:
        Exception: if no prime is found within int(100*(log2(k)+1)) draws.
    """
    max_attempts = int(100 * (math.log(k, 2) + 1))
    for _ in range(max_attempts):
        # random.randrange is a Mersenne Twister: deterministic and not
        # suitable for serious cryptographic purposes.
        n = random.randrange(2**(k - 1), 2**k)
        if is_prime(n, s):
            return n
    # Report the real number of attempts (a countdown variable is 0 here).
    raise Exception("Failure after %d tries." % max_attempts)

In [17]:
import math

# Set to 1 to dump per-stage intermediate values of the forward / inverse
# NTT (IterativeForwardNTT / IterativeInverseNTT) to text files.
DEBUG_MODE_NTT, DEBUG_MODE_INTT = 0, 0

# Modular inverse (https://stackoverflow.com/questions/4798654/modular-multiplicative-inverse-function-in-python)
def egcd(a, b):
    """Extended Euclidean algorithm (recursive).

    Returns a tuple (g, x, y) such that g == gcd(a, b) and a*x + b*y == g.
    """
    if a == 0:
        return b, 0, 1
    g, y, x = egcd(b % a, a)
    quotient = b // a
    return g, x - quotient * y, y

def modinv(a, m):
    """Return the multiplicative inverse of a modulo m, in [0, m).

    Raises:
        Exception: if gcd(a, m) != 1, i.e. no inverse exists.
    """
    gcd, coeff, _ = egcd(a, m)
    if gcd == 1:
        return coeff % m
    raise Exception('Modular inverse does not exist')

# Bit-Reverse integer
def intReverse(a, n):
    """Reverse the binary digits of a, zero-padded to at least n digits.

    Example: intReverse(1, 3) -> 0b100 == 4.
    """
    spec = '{:0' + str(n) + 'b}'
    bits = spec.format(a)
    return int(''.join(reversed(bits)), 2)

# Bit-Reversed index
def indexReverse(a, r):
    """Return a new list where element i of a lands at position intReverse(i, r).

    With r == log2(len(a)) this converts between normal and bit-reversed order.
    """
    permuted = [0] * len(a)
    for idx, value in enumerate(a):
        permuted[intReverse(idx, r)] = value
    return permuted

# forward ntt (takes input in normal order, produces output in bit-reversed order)
def IterativeForwardNTT(arrayIn, P, W, R):
    """Iterative forward NTT modulo P.

    Takes input in normal order and produces output in bit-reversed order.
    Each of the log2(N) stages applies the butterfly
        (a, b) -> ((a + b) mod P, ((a - b) * w) mod P)
    with twiddle w = W**((2**stage) * k) mod P.

    Args:
        arrayIn: input values (length N, expected to be a power of two).
            Not modified.
        P: the modulus.
        W: twiddle base; presumably a primitive N-th root of unity mod P
            (the script passes psi**2 mod q).
        R: only used in the debug dump, which prints twiddles as w*R mod P.

    Returns:
        A new list with the transformed values in bit-reversed order.

    When DEBUG_MODE_NTT is set, per-stage values and butterfly traces are
    written to NTT_DIN_DEBUG_1.txt / NTT_DIN_DEBUG_2.txt.
    """
    N = len(arrayIn)
    arrayOut = list(arrayIn)  # work on a copy; arrayIn is left untouched

    # Exact floor(log2(N)); int(math.log(N, 2)) goes through floating point.
    v = N.bit_length() - 1

    dbg_res = dbg_btf = None
    try:
        if DEBUG_MODE_NTT:
            dbg_res = open("NTT_DIN_DEBUG_1.txt","w") # Just result
            dbg_btf = open("NTT_DIN_DEBUG_2.txt","w") # BTF inputs
            for f in (dbg_res, dbg_btf):
                f.write("------------------------------ input: \n")
                for x in arrayOut:
                    f.write(str(x)+"\n")

        for i in range(v):
            if DEBUG_MODE_NTT:
                dbg_res.write("------------------------------ stage: "+str(i)+"\n")
                dbg_btf.write("------------------------------ stage: "+str(i)+"\n")

            half = 2 ** (v - i - 1)  # distance between the two butterfly inputs
            step = 2 ** i            # twiddle exponent stride in this stage
            # Twiddles depend only on k: compute them once per stage with
            # three-argument pow instead of W**e % P per butterfly (W**e is a
            # huge integer for 64-bit W).
            twiddles = [pow(W, step * k, P) for k in range(half)]

            for j in range(step):
                for k in range(half):
                    s = j * 2 * half + k
                    t = s + half
                    w = twiddles[k]

                    as_temp = arrayOut[s]
                    at_temp = arrayOut[t]

                    arrayOut[s] = (as_temp + at_temp) % P
                    arrayOut[t] = ((as_temp - at_temp) * w) % P

                    if DEBUG_MODE_NTT:
                        dbg_btf.write((str(s)+" "+str(t)+" "+str(step * k)).ljust(16)+"("+str(as_temp).ljust(12)+" "+str(at_temp).ljust(12)+" "+str((w*R) % P).ljust(12)+") -> ("+str(arrayOut[s]).ljust(12)+" "+str(arrayOut[t]).ljust(12)+")"+"\n")

            if DEBUG_MODE_NTT:
                for x in arrayOut:
                    dbg_res.write(str(x)+"\n")

        if DEBUG_MODE_NTT:
            for f in (dbg_res, dbg_btf):
                f.write("------------------------------ result: \n")
                for x in arrayOut:
                    f.write(str(x)+"\n")
    finally:
        # Close the debug dumps even if a stage raises.
        for f in (dbg_res, dbg_btf):
            if f is not None:
                f.close()

    return arrayOut

# inverse ntt (takes input in normal order, produces output in bit-reversed order)
def IterativeInverseNTT(arrayIn, P, W, R):
    """Iterative inverse NTT modulo P, including the final 1/N scaling.

    Uses the same stage/butterfly structure as the forward transform,
    (a, b) -> ((a + b) mod P, ((a - b) * w) mod P), so it likewise takes
    input in normal order and produces output in bit-reversed order. Every
    output is then multiplied by N**-1 mod P.

    Args:
        arrayIn: input values (length N, expected to be a power of two).
            Not modified.
        P: the modulus; N must be invertible modulo P.
        W: twiddle base; presumably the inverse of the forward root (the
            script passes w_inv).
        R: only used in the debug dump, which prints twiddles as w*R mod P.

    Returns:
        A new list with the scaled inverse transform in bit-reversed order
        (an empty list for empty input).

    Raises:
        Exception: from modinv, if N has no inverse modulo P.

    When DEBUG_MODE_INTT is set, per-stage values and butterfly traces are
    written to test/INTT_DIN_DEBUG_1.txt / test/INTT_DIN_DEBUG_2.txt.
    """
    N = len(arrayIn)
    if N == 0:
        # Nothing to transform; also avoids asking modinv for 0**-1.
        return []

    arrayOut = list(arrayIn)  # work on a copy; arrayIn is left untouched

    # Exact floor(log2(N)); int(math.log(N, 2)) goes through floating point.
    v = N.bit_length() - 1

    dbg_res = dbg_btf = None
    try:
        if DEBUG_MODE_INTT:
            dbg_res = open("test/INTT_DIN_DEBUG_1.txt","w") # Just result
            dbg_btf = open("test/INTT_DIN_DEBUG_2.txt","w") # BTF inputs
            for f in (dbg_res, dbg_btf):
                f.write("------------------------------ input: \n")
                for x in arrayOut:
                    f.write(str(x)+"\n")

        for i in range(v):
            if DEBUG_MODE_INTT:
                dbg_res.write("------------------------------ stage: "+str(i)+"\n")
                dbg_btf.write("------------------------------ stage: "+str(i)+"\n")

            half = 2 ** (v - i - 1)  # distance between the two butterfly inputs
            step = 2 ** i            # twiddle exponent stride in this stage
            # Twiddles depend only on k: compute them once per stage with
            # three-argument pow instead of W**e % P per butterfly.
            twiddles = [pow(W, step * k, P) for k in range(half)]

            for j in range(step):
                for k in range(half):
                    s = j * 2 * half + k
                    t = s + half
                    w = twiddles[k]

                    as_temp = arrayOut[s]
                    at_temp = arrayOut[t]

                    arrayOut[s] = (as_temp + at_temp) % P
                    arrayOut[t] = ((as_temp - at_temp) * w) % P

                    if DEBUG_MODE_INTT:
                        dbg_btf.write((str(s)+" "+str(t)+" "+str(step * k)).ljust(16)+"("+str(as_temp).ljust(12)+" "+str(at_temp).ljust(12)+" "+str((w*R) % P).ljust(12)+") -> ("+str(arrayOut[s]).ljust(12)+" "+str(arrayOut[t]).ljust(12)+")"+"\n")

            if DEBUG_MODE_INTT:
                for x in arrayOut:
                    dbg_res.write(str(x)+"\n")

        if DEBUG_MODE_INTT:
            for f in (dbg_res, dbg_btf):
                f.write("------------------------------ result: \n")
                for x in arrayOut:
                    f.write(str(x)+"\n")

        # Final 1/N scaling of the inverse transform.
        N_inv = modinv(N, P)
        for idx in range(N):
            arrayOut[idx] = (arrayOut[idx] * N_inv) % P

        if DEBUG_MODE_INTT:
            for f in (dbg_res, dbg_btf):
                f.write("------------------------------ result (with N_inv): \n")
                for x in arrayOut:
                    f.write(str(x)+"\n")
    finally:
        # Close the debug dumps even if a stage or modinv raises.
        for f in (dbg_res, dbg_btf):
            if f is not None:
                f.close()

    return arrayOut

In [18]:

# Copyright 2020
# Ahmet Can Mert <ahmetcanmert@sabanciuniv.edu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from math import log,ceil
from random import randint



# Test Generator for N-pt NTT/INTT with P Processing Elements

# -------------------------------------------------------------------------- TXT
# Output files, opened for writing (any previous contents are truncated):
#   PARAM.txt             - parameter words
#   NTT_DIN / NTT_DOUT    - forward NTT input / output vectors
#   INTT_DIN / INTT_DOUT  - inverse NTT input / output vectors
#   W / WINV              - twiddle factors for w / w_inv
(PRM_TXT, NTT_DIN_TXT, NTT_DOUT_TXT, INTT_DIN_TXT, INTT_DOUT_TXT,
 W_TXT, WINV_TXT) = [open(fname, "w") for fname in (
    "PARAM.txt", "NTT_DIN.txt", "NTT_DOUT.txt", "INTT_DIN.txt",
    "INTT_DOUT.txt", "W.txt", "WINV.txt")]
# -------------------------------------------------------------------------- TXT

# Parameter source:
#   PC = 1 -> use a pre-defined parameter set
#   PC = 0 -> generate parameters
PC = 1

# Number of Processing Elements
P = 1

# NTT parameters; placeholders until the parameter-selection step fills them in
q = psi = psi_inv = w = w_inv = n_inv = 0

if PC:
    # K is the bit-width of q (k = log2 q); it is also used to size R below.
    #N, K, q, psi = 1024, 64, 2**64-2**32+1, 816101479115663336

    N, K, q, psi = 1024, 64, 2**64-2**32+1, 9630642590298920985

    #N, K, q, psi = 1024, 27, 132120577, 73993
    #N, K, q, psi = 1024, 29, 463128577, 61961
    #N, K, q, psi = 2048, 30, 618835969, 327404
    #N, K, q, psi = 2048, 37, 137438691329, 22157790
    #N, K, q, psi = 4096, 25, 33349633, 8131
    #N, K, q, psi = 4096, 36, 68719230977, 29008497
    #N, K, q, psi = 4096, 55, 36028797009985537, 5947090524825
    #N, K, q, psi = 8192, 43, 8796092858369, 1734247217
    #N, K, q, psi = 16384, 49, 562949951881217, 45092463253
    #N, K, q, psi = 16384, 50, 1125899903500289, 68423600398
    #N, K, q, psi = 32768, 55, 36028797009985537, 5947090524825
else:
    # Input parameters
    #N, K = 256, 13
    #N, K = 256, 23
    #N, K = 512, 14
    N, K = 1024, 64
    #N, K = 1024, 29
    #N, K = 2048, 30
    #N, K = 4096, 60

    q = 2**64-2**32+1
    # A primitive 2N-th root of unity mod q exists only if 2N divides q-1.
    if q % (2*N) != 1:
        raise ValueError("q must satisfy q = 1 (mod 2N)")

    # Generate NTT parameters. For any gen, cand = gen**((q-1)/(2N)) has
    # cand**(2N) == 1, and cand**N == gen**((q-1)/2) == -1 exactly when gen
    # is a quadratic non-residue (Euler's criterion, q prime). The scan
    # therefore stops at the first non-residue instead of testing every
    # i in [2, q-2] as a root directly, which is infeasible for 64-bit q.
    for gen in range(2, q-1):
        cand = pow(gen, (q-1)//(2*N), q)
        if pow(cand, N, q) == q-1 and all(pow(cand, x, q) != 1 for x in range(1, 2*N)):
            psi = cand
            break
    else:
        raise ValueError("no primitive 2N-th root of unity found modulo q")

# Values derived from psi, shared by both parameter sources.
psi_inv = modinv(psi,q)
w       = pow(psi,2,q)   # psi**2; an N-th root of unity given psi is a 2N-th one
w_inv   = modinv(w,q)

# R = 2**((log2(N)+1) * ceil(K / (log2(N)+1))), using exact integer log2/ceil
# instead of floating-point math.log / ceil.
logN    = N.bit_length() - 1
R       = 2**((logN+1) * -(-K // (logN+1)))
n_inv   = modinv(N,q)
PE      = P*2

# Print parameters (the "PE" line shows P, the number of processing elements)
print("-----------------------")
for fmt, value in (("N      : {}", N),
                   ("K      : {}", K),
                   ("PE     : {}", P),
                   ("q      : {}", q),
                   ("psi    : {}", psi),
                   ("psi_inv: {}", psi_inv),
                   ("w      : {}", w),
                   ("w_inv  : {}", w_inv),
                   ("n_inv  : {}", n_inv),
                   ("log(R) : {}", int(log(R,2)))):
    print(fmt.format(value))
print("-----------------------")

# --------------------------------------------------------------------------

# Parameter words, one left-justified hex value per line; the order is
# recorded in the "// Input order" footer written below. The n_inv slot
# holds n_inv*R mod q.
for value in (N, q, w, w_inv, psi, psi_inv, (n_inv*R)%q, R):
    PRM_TXT.write(hex(value).replace("L","")[2:].ljust(20)+"\n")

for footer_line in ("// Input order:\n",
                    "// N\n",
                    "// q\n",
                    "// w\n",
                    "// w_inv\n",
                    "// psi\n",
                    "// psi_inv\n",
                    "// n_inv\n",
                    "// R\n",
                    "// \n"):
    PRM_TXT.write(footer_line)

PRM_TXT.write("// K :"+str(K)+"\n")
PRM_TXT.write("// PE:"+str(P)+"\n")

# --------------------------------------------------------------------------

# NTT/INTT operation: forward NTT (bit-reversed output) -> reorder to normal
# order -> inverse NTT (bit-reversed output) -> reorder back.
log_n = int(log(N,2))
A     = [randint(0,q-1) for _ in range(N)]
A_ntt = IterativeForwardNTT(A,q,w,R)
A_rev = indexReverse(A_ntt,log_n)
A_rec = IterativeInverseNTT(A_rev,q,w_inv,R)
A_res = indexReverse(A_rec,log_n)

# Sanity Check: the round trip must reproduce the random input exactly.
round_trip_ok = all(x == y for x, y in zip(A, A_res))
print("Sanity Check: NTT operation is correct." if round_trip_ok
      else "Sanity Check: Check your math with NTT/INTT operation.")

# Print input/output to txt (normal input - bit-reversed output).
# The entries at DEBUG_INDEX are also echoed to stdout for spot checks.
DEBUG_INDEX = 966
for i in range(N):
    if i == DEBUG_INDEX:
        print(hex(A[i]).replace("L","")[2:]+"\n")
        print(hex(A_ntt[i]).replace("L","")[2:]+"\n")
    NTT_DIN_TXT.write(hex(A[i]).replace("L","")[2:]+"\n")
    NTT_DOUT_TXT.write(hex(A_ntt[i]).replace("L","")[2:]+"\n")

for i in range(N):
    if i == DEBUG_INDEX:
        print(hex(A_rev[i]).replace("L","")[2:]+"\n")
        print(hex(A_rec[i]).replace("L","")[2:]+"\n")
    INTT_DIN_TXT.write(hex(A_rev[i]).replace("L","")[2:]+"\n")
    INTT_DOUT_TXT.write(hex(A_rec[i]).replace("L","")[2:]+"\n")

# Print TWs to txt: powers of w / w_inv, each multiplied by R mod q.
# pow(x, e, q) keeps the exponentiation modular (x**e is huge for 64-bit x).
for j in range(N.bit_length() - 1):
    for k in range(max(1, (N//PE) >> j)):
        for i in range(P):
            w_pow = ((P << j)*k + (i << j)) % (N//2)
            W_TXT.write(hex((pow(w, w_pow, q) * R) % q).replace("L","")[2:]+"\n")
            WINV_TXT.write(hex((pow(w_inv, w_pow, q) * R) % q).replace("L","")[2:]+"\n")

# Flush and close every output file. Without this the data can stay
# buffered for as long as the kernel keeps the file objects alive.
for out_file in (PRM_TXT, NTT_DIN_TXT, NTT_DOUT_TXT, INTT_DIN_TXT,
                 INTT_DOUT_TXT, W_TXT, WINV_TXT):
    out_file.close()

# --------------------------------------------------------------------------

-----------------------
N      : 1024
K      : 64
PE     : 1
q      : 18446744069414584321
psi    : 9630642590298920985
psi_inv: 8576353051786437773
w      : 4440654710286119610
w_inv  : 12337821426711963180
n_inv  : 18428729670909296641
log(R) : 66
-----------------------
Sanity Check: NTT operation is correct.
631e2c5363d85e1a

73e1dde5c2b99940

2105e8c9429a01b2

36642a76222f5286

