In [1]:
def find_n(lbits,hbits):
    """Enumerate candidate ring degrees of the form 2^i * 3^j.

    Every pair (i, j) with 1 <= i <= hbits and 0 <= j <= hbits is tried,
    and a value n = 2^i * 3^j is kept when 2^lbits <= n <= 2^hbits.

    Returns the kept values as a list in ascending order.
    """
    exponent_pairs = (
        (i, j)
        for i in range(1, hbits + 1)
        for j in range(0, hbits + 1)
    )
    return sorted(
        n
        for n in (2^i * 3^j for i, j in exponent_pairs)
        if 2^lbits <= n <= 2^hbits
    )

def find_q(n,d,lbits,hbits):
    """List the primes q = k * (3n/d) + 1 with 2^lbits <= q <= 2^hbits.

    Parameters:
        n: ring degree; its prime factorization must be 2^a or 2^a * 3^b.
        d: divisor of n with d <= n/2.
        lbits, hbits: inclusive bounds on q, given as powers of two.

    Returns:
        A list of Sage Integers in ascending order, or the integer 0 when
        (n, d) is rejected by the checks above. The 0 sentinel is kept
        because existing callers test for it.
    """
    # Only the primes matter for validation: n must be 2^a or 2^a * 3^b.
    primes = [p for p, _exp in factor(n)]
    if primes not in ([2], [2, 3]):
        return 0

    # d must divide n and must not exceed n/2 (written without division).
    if (n % d != 0) or (2 * d > n):
        return 0

    # Exact integer arithmetic: true division here would turn w_order, and
    # therefore every returned q, into a Rational (or a float for Python ints).
    w_order = Integer(3 * n) // Integer(d)
    lower, upper = 2**lbits, 2**hbits

    qs = []
    k = 1
    while True:
        q = w_order * k + 1
        if q > upper:
            break
        if q >= lower and q in Primes():
            qs.append(q)
        k += 1

    return qs

def run_parameter_search(n_lbits=9, n_hbits=11, d_set=(1, 2, 3, 4), q_lbits=0, q_hbits=15):
    """Print, for every candidate n, the d values that admit primes q.

    Parameters (the defaults reproduce the original hard-coded search):
        n_lbits, n_hbits: bounds passed to find_n for the candidate n values.
        d_set: iterable of d values tried for each n.
        q_lbits, q_hbits: bounds passed to find_q for the primes q.

    Output goes to stdout; nothing is returned.
    """
    for n in find_n(n_lbits, n_hbits):
        # Read both exponents straight from the factorization so that a
        # missing prime yields 0 and a lone prime keeps its real exponent.
        exponents = dict(factor(n))
        a_exp = exponents.get(2, 0)
        b_exp = exponents.get(3, 0)

        print(f"--- [ n = {n} (2^{a_exp} * 3^{b_exp}) ] ---")

        for d in d_set:
            qs = find_q(n, d, q_lbits, q_hbits)
            # find_q returns 0 for a rejected (n, d) and a possibly empty
            # list otherwise; both are falsy.
            if qs:
                print(f"  d: {d} | qs: {qs}")
        print()

In [2]:
# Run the search and print, per candidate n, each d together with the primes q found for it.
run_parameter_search()

--- [ n = 512 (2^9 * 3^0) ] ---
  d: 1 | qs: [7681, 10753, 12289, 15361, 18433, 23041, 26113, 32257]
  d: 2 | qs: [769, 7681, 10753, 12289, 14593, 15361, 18433, 22273, 23041, 26113, 26881, 31489, 32257]
  d: 4 | qs: [769, 1153, 2689, 3457, 4993, 6529, 7297, 7681, 9601, 10369, 10753, 12289, 13441, 14593, 15361, 18049, 18433, 20353, 21121, 22273, 23041, 26113, 26497, 26881, 29569, 31489, 31873, 32257]

--- [ n = 576 (2^6 * 3^2) ] ---
  d: 1 | qs: [3457, 8641, 10369, 12097, 19009]
  d: 2 | qs: [2593, 3457, 8641, 10369, 12097, 16417, 19009, 21601, 25057, 28513, 30241]
  d: 3 | qs: [577, 1153, 3457, 6337, 7489, 8641, 10369, 12097, 13249, 14401, 18433, 19009, 20161, 21313, 23041, 26497, 27073, 30529, 32257]
  d: 4 | qs: [433, 1297, 2161, 2593, 3457, 3889, 6481, 8209, 8641, 10369, 12097, 15121, 16417, 17713, 19009, 19441, 21169, 21601, 23761, 25057, 28081, 28513, 30241, 32401]

--- [ n = 648 (2^3 * 3^4) ] ---
  d: 1 | qs: [3889, 9721, 17497, 19441]
  d: 2 | qs: [2917, 3889, 4861, 9721, 12637,

In [3]:
def find_gs(q):
    """Return the elements g of Z/qZ, 1 <= g < q, that survive every prime test.

    For each prime p dividing q - 1, an element is discarded when
    g^((q-1)/p) == 1. For prime q the survivors are the generators of the
    multiplicative group.

    Returns a list of IntegerModRing(q) elements in ascending order of g.
    """
    Zq = IntegerModRing(q)
    group_order = q - 1

    # Convert once up front so the result is always a list of Zq elements,
    # including when q - 1 has no prime factors and the loop never runs.
    gs = [Zq(g) for g in range(1, q)]

    for p, _exp in factor(group_order):
        # p divides q - 1, so floor division is exact and keeps the
        # exponent an integer instead of a Rational.
        exponent = group_order // p
        gs = [g for g in gs if g**exponent != 1]
    return gs

def find_ws(n,q,d):
    """Return the sorted distinct values g^k for g in find_gs(q).

    Here k = (q - 1) / (3n/d). Integer() raises if that quotient is not an
    integer, i.e. if 3n/d does not divide q - 1.
    """
    target_order = 3*n / d
    cofactor = Integer((q-1)/target_order)
    distinct_powers = {g^cofactor for g in find_gs(q)}
    return sorted(distinct_powers)

class NTRUPlusGenerator:
    """Derive and print NTT constants and zeta tables as C source text.

    Parameters:
        n, d: the tables are sized for n // d, and w is treated as having
            multiplicative order 3n // d.
        q: modulus; arithmetic is done in IntegerModRing(q).
        w: base root whose powers populate the zeta table.
        R_val: Montgomery radix used by tomont (default 2**16).
        L_val: Plantard modulus used by toplantard (default 2**32).

    The zeta table is built once in the constructor and stored in
    ``self.zetas``.
    """

    def __init__(self, n, q, d, w, R_val=2**16, L_val=2**32):
        self.n = n
        self.q = Integer(q)
        self.d = d
        self.w = w
        self.R = R_val
        self.L = L_val
        self.Zq = IntegerModRing(self.q)
        self.w_order = (3 * self.n) // self.d

        self.gen_zetas()

    def center(self, x):
        """Return the representative of x mod q in [-(q // 2), q - 1 - q // 2]."""
        qhalf = self.q // 2
        return ((Integer(x) + qhalf) % self.q) - qhalf

    def tomont(self, x):
        """Return x * R reduced mod q to its centered representative."""
        return self.center(Integer(x) * self.R)

    def toplantard(self, x):
        """Return ((x * -L mod q) * q^-1) mod L as a non-negative integer."""
        qprime = Integer(self.q**-1 % self.L)
        x = self.Zq(x) * (-self.L)
        x = Integer(x)
        # Reduce modulo the configured L rather than a fixed 32-bit mask so
        # the result stays correct when L_val is not 2**32.
        x = (x * qprime) % self.L
        return x

    def gen_zetas(self):
        """Build self.zetas from a tree of exponents of w.

        Row 0 of the tree holds w_order and row 1 holds w_order/6 and
        5*w_order/6. Each of the next Radix3 rows splits an exponent e into
        e/3, e/3 + w_order/3 and e/3 + 2*w_order/3; the remaining rows split
        e into e/2 and e/2 + w_order/2. Selected powers of w are appended to
        the table as the tree is filled.
        """
        fac = factor(self.n // self.d)
        if len(fac) == 2:
            # Assumes the two primes are 2 and 3, in that order; not checked.
            Radix2, Radix3 = fac[0][1], fac[1][1]
        elif len(fac) == 1:
            Radix2 = fac[0][1] if fac[0][0] == 2 else 0
            Radix3 = fac[0][1] if fac[0][0] == 3 else 0
        else:
            Radix2, Radix3 = 0, 0

        level = Radix2 + Radix3
        nn_d = self.n // self.d
        tree = zero_matrix(ZZ, level + 1, nn_d)
        tree[0, 0] = self.w_order

        root = self.Zq(self.w)  # hoisted: reused for every table entry

        zetas = [Integer(1)]
        tree[1, 0] = tree[0, 0] // 6
        tree[1, 1] = 5 * tree[0, 0] // 6
        zetas.append(root**(tree[1, 0]))

        # Three-way splits: two table entries (w^e and w^(2e)) per node.
        for l in range(1, Radix3 + 1):
            for i in range(2 * 3**(l - 1)):
                base = tree[l, i] // 3
                tree[l + 1, 3 * i] = base
                tree[l + 1, 3 * i + 1] = base + self.w_order // 3
                tree[l + 1, 3 * i + 2] = base + 2 * self.w_order // 3
                zetas.append(root**(tree[l + 1, 3 * i]))
                zetas.append(root**(tree[l + 1, 3 * i] * 2))

        # Two-way splits: one table entry (w^e) per node.
        for l in range(Radix3 + 1, level):
            for i in range(2 * 3**(Radix3) * 2**(l - (Radix3 + 1))):
                base = tree[l, i] // 2
                tree[l + 1, 2 * i] = base
                tree[l + 1, 2 * i + 1] = base + self.w_order // 2
                zetas.append(root**(tree[l + 1, 2 * i]))

        self.zetas = zetas

    def _core_constants(self):
        """Return (omega, (z - z^5)^-1, (n // d)^-1) as elements of Zq."""
        root = self.Zq(self.w)
        omega_val = root**(self.w_order // 3)
        z = root**(self.w_order // 6)
        zminz5inv = (z - z**5)**-1
        n_d_inv = self.Zq(self.n // self.d)**-1
        return omega_val, zminz5inv, n_d_inv

    def _print_zetas(self, ctype, fmt, convert):
        """Print self.zetas as a C array, eight converted entries per line."""
        l = len(self.zetas)
        print("const %s zetas[%d] = {" % (ctype, l))
        for i in range(0, l, 8):
            chunk = self.zetas[i : i + 8]
            line = "\t" + ", ".join(fmt % convert(v) for v in chunk)
            # Every row except the last ends with a trailing comma.
            if i + 8 < l:
                line += ","
            print(line)
        print("};")

    def print_zetas_mont(self):
        """Print the zeta table as int16_t values converted with tomont."""
        self._print_zetas("int16_t", "%5d", self.tomont)

    def print_zetas_plant(self):
        """Print the zeta table as uint32_t values converted with toplantard."""
        self._print_zetas("uint32_t", "0x%08Xu", self.toplantard)

    def print_params_mont(self):
        """Print the Montgomery-form #define constants and the zeta table."""
        print("========ntt consts for montgomery========")

        # 1. Montgomery Basic Params
        print("#define NTRUPLUS_R           %5d" % self.center(self.R))
        print("#define NTRUPLUS_RINV        %5d" % self.center(self.Zq(self.R)**-1))
        print("#define NTRUPLUS_RSQ         %5d" % self.center(self.Zq(self.R)**2))
        print("#define NTRUPLUS_QINV        %5d" % (self.q**-1 % self.R))

        # 2. NTT Core Constants
        omega_val, zminz5inv, n_d_inv = self._core_constants()
        print("#define NTRUPLUS_OMEGA       %5d" % self.tomont(omega_val))
        print("#define NTRUPLUS_ZMINUSZ5INV %5d" % self.tomont(zminz5inv))
        print("#define NTRUPLUS_NINV        %5d" % self.tomont(n_d_inv))
        print("#define NTRUPLUS_2NINV       %5d\n" % self.tomont(2 * n_d_inv))
        print()

        # 3. Zetas Table
        self.print_zetas_mont()

    def print_params_plant(self):
        """Print the Plantard-form #define constants and the zeta table."""
        print("========ntt consts for plantard========")

        # 1. Plantard Basic Params
        print("#define NTRUPLUS_R           0x%08Xu" % self.toplantard(1))
        print("#define NTRUPLUS_RINV        0x%08Xu" % self.toplantard((-self.L)**(-2)))
        print("#define NTRUPLUS_RSQ         0x%08Xu" % self.toplantard(-self.L))
        print("#define NTRUPLUS_QINV        0x%08Xu" % (self.q**-1 % self.L))

        # 2. NTT Core Constants
        omega_val, zminz5inv, n_d_inv = self._core_constants()
        print("#define NTRUPLUS_OMEGA       0x%08Xu" % self.toplantard(omega_val))
        # The "#define" prefix was missing here, which made this line
        # invalid in the generated C header.
        print("#define NTRUPLUS_ZMINUSZ5INV 0x%08Xu" % self.toplantard(zminz5inv))
        print("#define NTRUPLUS_NINV        0x%08Xu" % self.toplantard(n_d_inv))
        print("#define NTRUPLUS_2NINV       0x%08Xu" % self.toplantard(2*n_d_inv))
        print()

        # 3. Zetas Table
        self.print_zetas_plant()


In [4]:
# Parameter set for which a root w is selected.
n, q, d = 1152, 3457, 4

# Take the smallest of the candidate roots returned by find_ws.
ws = find_ws(n, q, d)
w = min(ws)

print(w)

9


In [5]:
# Constants for n = 768 with d = 4, modulus 3457 and root w = 22.
n = 768
q = 3457
d = 4
w = 22

gen = NTRUPlusGenerator(n, q, d, w)

# Montgomery constants first, then a blank separator, then Plantard constants.
gen.print_params_mont()
print()
gen.print_params_plant()

#define NTRUPLUS_R            -147
#define NTRUPLUS_RINV         -682
#define NTRUPLUS_RSQ           867
#define NTRUPLUS_QINV        12929
#define NTRUPLUS_OMEGA        -886
#define NTRUPLUS_ZMINUSZ5INV -1665
#define NTRUPLUS_NINV         -811
#define NTRUPLUS_2NINV       -1622


const int16_t zetas[192] = {
	 -147, -1033,  -682,  -248,  -708,   682,     1,  -722,
	 -723,  -257, -1124,  -867,  -256,  1484,  1262, -1590,
	 1611,   222,  1164, -1346,  1716, -1521,  -357,   395,
	 -455,   639,   502,   655,  -699,   541,    95, -1577,
	-1241,   550,   -44,    39,  -820,  -216,  -121,  -757,
	 -348,   937,   893,   387,  -603,  1713, -1105,  1058,
	 1449,   837,   901,  1637,  -569, -1617, -1530,  1199,
	   50,  -830,  -625,     4,   176,  -156,  1257, -1507,
	 -380,  -606,  1293,   661,  1428, -1580,  -565,  -992,
	  548,  -800,    64,  -371,   961,   641,    87,   630,
	  675,  -834,   205,    54, -1081,  1351,  1413, -1331,
	-1673, -1267, -1558,   281, -1464,  -588,  1015,   436,
	  22

In [6]:
# Constants for n = 864 with d = 3, modulus 3457 and root w = 9.
n = 864
q = 3457
d = 3
w = 9

gen = NTRUPlusGenerator(n, q, d, w)

# Montgomery constants first, then a blank separator, then Plantard constants.
gen.print_params_mont()
print()
gen.print_params_plant()

#define NTRUPLUS_R            -147
#define NTRUPLUS_RINV         -682
#define NTRUPLUS_RSQ           867
#define NTRUPLUS_QINV        12929
#define NTRUPLUS_OMEGA        -886
#define NTRUPLUS_ZMINUSZ5INV -1665
#define NTRUPLUS_NINV        -1693
#define NTRUPLUS_2NINV          71


const int16_t zetas[288] = {
	 -147, -1033, -1265,   708,   460,  1265,  -467,   727,
	  556,  1307,  -773,  -161,  1200, -1612,   570,  1529,
	 1135,  -556,  1120,   298,  -822, -1556,   -93,  1463,
	  532,  -377,  -909,    58,  -392,  -450,  1722,  1236,
	 -486,  -491, -1569, -1078,    36,  1289, -1443,  1628,
	 1664,  -725,  -952,    99, -1020,   353,  -599,  1119,
	  592,   839,  1622,   652,  1244,  -783, -1085,  -726,
	  566,  -284, -1369, -1292,   268,  -391,   781,  -172,
	   96, -1172,   211,   737,   473,  -445,  -234,   264,
	-1536,  1467,  -676, -1542,  -170,   635,  -705, -1332,
	 -658,   831, -1712,  1311,  1488,  -881,  1087, -1315,
	 1245,   -75,   791,    -6,  -875,  -697,   -70, -1162,
	  28

In [7]:
# Constants for n = 1152 with d = 4, modulus 3457 and root w = 9.
n = 1152
q = 3457
d = 4
w = 9

gen = NTRUPlusGenerator(n, q, d, w)

# Montgomery constants first, then a blank separator, then Plantard constants.
gen.print_params_mont()
print()
gen.print_params_plant()

#define NTRUPLUS_R            -147
#define NTRUPLUS_RINV         -682
#define NTRUPLUS_RSQ           867
#define NTRUPLUS_QINV        12929
#define NTRUPLUS_OMEGA        -886
#define NTRUPLUS_ZMINUSZ5INV -1665
#define NTRUPLUS_NINV        -1693
#define NTRUPLUS_2NINV          71


const int16_t zetas[288] = {
	 -147, -1033, -1265,   708,   460,  1265,  -467,   727,
	  556,  1307,  -773,  -161,  1200, -1612,   570,  1529,
	 1135,  -556,  1120,   298,  -822, -1556,   -93,  1463,
	  532,  -377,  -909,    58,  -392,  -450,  1722,  1236,
	 -486,  -491, -1569, -1078,    36,  1289, -1443,  1628,
	 1664,  -725,  -952,    99, -1020,   353,  -599,  1119,
	  592,   839,  1622,   652,  1244,  -783, -1085,  -726,
	  566,  -284, -1369, -1292,   268,  -391,   781,  -172,
	   96, -1172,   211,   737,   473,  -445,  -234,   264,
	-1536,  1467,  -676, -1542,  -170,   635,  -705, -1332,
	 -658,   831, -1712,  1311,  1488,  -881,  1087, -1315,
	 1245,   -75,   791,    -6,  -875,  -697,   -70, -1162,
	  28