# NTT Reference Implementation

In [1]:
import numpy as np
import copy
import random

### Dilithium Parameters

In [2]:
# Modulus for all coefficient arithmetic: q = 2^23 - 2^13 + 1 = 8380417.
DILITHIUM_Q = (1 << 23) - (1 << 13) + 1
# Number of coefficients handled by the transforms, N = 2^LOGN = 256.
DILITHIUM_LOGN = 8
DILITHIUM_N = 1 << DILITHIUM_LOGN
# Base whose successive powers (mod q) populate the twiddle-factor tables.
DILITHIUM_ROOT_OF_UNITY = 1753

### Butterfly

In [3]:
def ct_bf(u, v, z):
    """Cooley-Tukey butterfly modulo DILITHIUM_Q.

    The second operand is multiplied by the twiddle factor first, then the
    sum and difference with the first operand are formed.

    Args:
        u: first input coefficient.
        v: second input coefficient (the one scaled by ``z``).
        z: twiddle factor.

    Returns:
        Tuple ``((u + v*z) % q, (u - v*z) % q)``.
    """
    product = (v * z) % DILITHIUM_Q
    return (u + product) % DILITHIUM_Q, (u - product) % DILITHIUM_Q

def gs_bf(u, v, z):
    """Gentleman-Sande butterfly modulo DILITHIUM_Q.

    The sum and difference are formed first; only the difference is then
    multiplied by the twiddle factor.

    Args:
        u: first input coefficient.
        v: second input coefficient.
        z: twiddle factor applied to the difference.

    Returns:
        Tuple ``((u + v) % q, ((u - v) * z) % q)``.
    """
    total = (u + v) % DILITHIUM_Q
    difference = (u - v) % DILITHIUM_Q
    return total, (difference * z) % DILITHIUM_Q

def gs_bf_div2(u, v, z):
    """Gentleman-Sande butterfly whose sum and difference are halved mod q.

    Both the reduced sum and the reduced difference go through ``div2``
    before the difference is multiplied by the twiddle factor.

    Args:
        u: first input coefficient.
        v: second input coefficient.
        z: twiddle factor applied to the halved difference.

    Returns:
        Tuple ``(div2((u + v) % q), (div2((u - v) % q) * z) % q)``.
    """
    half_sum = div2((u + v) % DILITHIUM_Q)
    half_diff = div2((u - v) % DILITHIUM_Q)
    return half_sum, (half_diff * z) % DILITHIUM_Q

def div2(x):
    """Halve ``x`` modulo DILITHIUM_Q without a general modular inverse.

    An even ``x`` is simply shifted right. For an odd ``x = 2k + 1`` the
    result is ``k + (q + 1) // 2``, where ``(q + 1) // 2`` is the value that
    doubles to ``q + 1``, i.e. to 1 modulo q.

    Args:
        x: integer to halve; the callers in this file pass values already
           reduced modulo q.

    Returns:
        ``x >> 1`` for even ``x``, ``(x >> 1) + (q + 1) // 2`` for odd ``x``.
    """
    halved = x >> 1
    if x & 1:
        # Odd input: add the modular representation of 1/2.
        halved = halved + ((DILITHIUM_Q + 1) // 2)
    return halved

### Twiddle factor

In [4]:
def bit_reversal(x, bits=8):
    """Return ``x`` with the order of its binary digits reversed.

    ``x`` is rendered as a zero-padded binary string of ``bits`` digits, the
    string is reversed, and the result is parsed back as an integer.

    Args:
        x: non-negative integer to reverse.
        bits: width of the zero-padded binary representation. Defaults to 8,
            the width previously hard-coded in this function.

    Returns:
        The integer whose ``bits``-digit binary representation is the mirror
        image of that of ``x``. Values needing more than ``bits`` digits are
        not truncated; their full binary string is reversed.
    """
    binary_string = format(x, f'0{bits}b')
    return int(binary_string[::-1], 2)
    
def zeta_generator():
    """Build the twiddle-factor tables used by the forward and inverse NTTs.

    Successive powers of DILITHIUM_ROOT_OF_UNITY modulo DILITHIUM_Q are
    computed for exponents 0 .. DILITHIUM_N - 1 and then stored in
    bit-reversed exponent order: entry ``i`` holds
    ``ROOT ** bit_reversal(i) mod q``.

    Returns:
        Tuple ``(zetas, zetas_inv)`` of dicts keyed by 0 .. DILITHIUM_N - 1.
        ``zetas_inv[i]`` is the additive negation ``(-zetas[i]) mod q`` of the
        matching forward entry, not a multiplicative inverse.
    """
    # powers[e] == ROOT ** e mod q. A plain list keeps every index an exact
    # int instead of relying on float array entries hashing like ints.
    powers = [1]
    for _ in range(1, DILITHIUM_N):
        powers.append((powers[-1] * DILITHIUM_ROOT_OF_UNITY) % DILITHIUM_Q)

    zetas = {}
    zetas_inv = {}
    for i in range(DILITHIUM_N):
        zeta = powers[bit_reversal(i)]
        zetas[i] = zeta
        zetas_inv[i] = (-zeta) % DILITHIUM_Q

    return zetas, zetas_inv


# Build both twiddle tables once at module level; the NTT/INTT functions
# defined later in this notebook read `zetas` / `zetas_inv` as globals.
zetas, zetas_inv = zeta_generator()

In [5]:
## Dump the forward (NTT) twiddle table: one zero-padded 6-digit uppercase hex value per line
print("zeta=")
for i in range(DILITHIUM_N):
    print(format(zetas[i], "06X"))

zeta=
000001
495E02
397567
396569
4F062B
53DF73
4FE033
4F066B
76B1AE
360DD5
28EDB0
207FE4
397283
70894A
088192
6D3DC8
4C7294
41E0B4
28A3D2
66528A
4A18A7
794034
0A52EE
6B7D81
4E9F1D
1A2877
2571DF
1649EE
7611BD
492BB7
2AF697
22D8D5
36F72A
30911E
29D13F
492673
50685F
2010A2
3887F7
11B2C3
0603A4
0E2BED
10B72C
4A5F35
1F9D15
428CD4
3177F4
20E612
341C1D
1AD873
736681
49553F
3952F6
62564A
65AD05
439A1C
53AA5F
30B622
087F38
3B0E6D
2C83DA
1C496E
330E2B
1C5B70
2EE3F1
137EB9
57A930
3AC6EF
3FD54C
4EB2EA
503EE1
7BB175
2648B4
1EF256
1D90A2
45A6D4
2AE59B
52589C
6EF1F5
3F7288
175102
075D59
1187BA
52ACA9
773E9E
0296D8
2592EC
4CFF12
404CE8
4AA582
1E54E6
4F16C1
1A7E79
03978F
4E4817
31B859
5884CC
1B4827
5B63D0
5D787A
35225E
400C7E
6C09D1
5BD532
6BC4D3
258ECB
2E534C
097A6C
3B8820
6D285C
2CA4F8
337CAA
14B2A0
558536
28F186
55795D
4AF670
234A86
75E826
78DE66
05528C
7ADF59
0F6E17
5BF3DA
459B7E
628B34
5DBECB
1A9E7B
0006D9
6257C5
574B3C
69A8EF
289838
64B5FE
7EF8F5
2A4E78
120A23
0154A8
09B7FF
435E87
437FF8
5CD5B4


In [6]:
## Dump the inverse (INTT) twiddle table: one zero-padded 6-digit uppercase hex value per line
print("zetas_inv=")
for i in range(DILITHIUM_N):
    print(format(zetas_inv[i], "06X"))

zetas_inv=
7FE000
3681FF
466A9A
467A98
30D9D6
2C008E
2FFFCE
30D996
092E53
49D22C
56F251
5F601D
466D7E
0F56B7
775E6F
12A239
336D6D
3DFF4D
573C2F
198D77
35C75A
069FCD
758D13
146280
3140E4
65B78A
5A6E22
699613
09CE44
36B44A
54E96A
5D072C
48E8D7
4F4EE3
560EC2
36B98E
2F77A2
5FCF5F
47580A
6E2D3E
79DC5D
71B414
6F28D5
3580CC
6042EC
3D532D
4E680D
5EF9EF
4BC3E4
65078E
0C7980
368AC2
468D0B
1D89B7
1A32FC
3C45E5
2C35A2
4F29DF
7760C9
44D194
535C27
639693
4CD1D6
638491
50FC10
6C6148
2836D1
451912
400AB5
312D17
2FA120
042E8C
59974D
60EDAB
624F5F
3A392D
54FA66
2D8765
10EE0C
406D79
688EFF
7882A8
6E5847
2D3358
08A163
7D4929
5A4D15
32E0EF
3F9319
353A7F
618B1B
30C940
656188
7C4872
3197EA
4E27A8
275B35
6497DA
247C31
226787
4ABDA3
3FD383
13D630
240ACF
141B2E
5A5136
518CB5
766595
4457E1
12B7A5
533B09
4C6357
6B2D61
2A5ACB
56EE7B
2A66A4
34E991
5C957B
09F7DB
07019B
7A8D75
0500A8
7071EA
23EC27
3A4483
1D54CD
222136
654186
7FD928
1D883C
2894C5
163712
5747C9
1B2A03
00E70C
559189
6DD5DE
7E8B59
762802
3C817A
3C6009
23

### Reference NTT/INTT model

In [7]:
def fwd_NTT(poly_r):
    """Reference forward NTT, one butterfly layer at a time.

    The butterfly distance starts at 128 and halves after every layer until
    it reaches 1. Each group of ``2 * distance`` coefficients consumes the
    next entry of the global ``zetas`` table (entries 1, 2, 3, ... in order).

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1
            (a dict in this notebook's test cell). It is not modified.

    Returns:
        A deep copy of ``poly_r`` holding the transformed coefficients.
    """
    coeffs = copy.deepcopy(poly_r)

    zeta_idx = 0
    distance = 128
    while distance > 0:
        for group_start in range(0, DILITHIUM_N, 2 * distance):
            zeta_idx += 1
            twiddle = zetas[zeta_idx]
            for lo in range(group_start, group_start + distance):
                hi = lo + distance
                coeffs[lo], coeffs[hi] = ct_bf(coeffs[lo], coeffs[hi], twiddle)
        distance >>= 1

    return coeffs

def inv_NTT(poly_r):
    """Reference inverse NTT, one butterfly layer at a time.

    The butterfly distance starts at 1 and doubles after every layer while it
    is below DILITHIUM_N. Each group of ``2 * distance`` coefficients consumes
    the previous entry of the global ``zetas_inv`` table (N-1, N-2, ...).
    Every coefficient is finally multiplied by 256^-1 mod q.

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1
            (a dict in this notebook's test cell). It is not modified.

    Returns:
        A deep copy of ``poly_r`` holding the inverse-transformed coefficients.
    """
    coeffs = copy.deepcopy(poly_r)

    zeta_idx = DILITHIUM_N
    distance = 1
    while distance < DILITHIUM_N:
        for group_start in range(0, DILITHIUM_N, 2 * distance):
            zeta_idx -= 1
            twiddle = zetas_inv[zeta_idx]
            for lo in range(group_start, group_start + distance):
                hi = lo + distance
                coeffs[lo], coeffs[hi] = gs_bf(coeffs[lo], coeffs[hi], twiddle)
        distance <<= 1

    n_inverse = 8347681  # 256^-1 mod DILITHIUM_Q (256 * 8347681 == 255 * q + 1)
    for idx in range(DILITHIUM_N):
        coeffs[idx] = (n_inverse * coeffs[idx]) % DILITHIUM_Q

    return coeffs

### 2x2 NTT/INTT model

In [8]:
def fwd_NTT2x2(poly_r):
    """Forward NTT that processes two butterfly layers per pass (2x2 scheme).

    Each pass works on groups of ``2**l`` coefficients for l = LOGN, LOGN-2,
    ..., 2. Within a group, four coefficients spaced ``2**(l-2)`` apart go
    through two Cooley-Tukey butterflies sharing one twiddle factor, followed
    by two butterflies using the next two consecutive twiddle factors.

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1
            (a dict in this notebook's test cell). It is not modified.

    Returns:
        A deep copy of ``poly_r`` holding the transformed coefficients.
    """
    coeffs = copy.deepcopy(poly_r)

    for log_group in range(DILITHIUM_LOGN, 0, -2):
        stride = 1 << (log_group - 2)
        for base in range(0, DILITHIUM_N, 1 << log_group):
            # Twiddle indices for the first and second layer of this group.
            first_idx = (DILITHIUM_N + base) >> log_group
            second_idx = (DILITHIUM_N + base) >> (log_group - 1)
            zeta_first = zetas[first_idx]
            zeta_second_a = zetas[second_idx]
            zeta_second_b = zetas[second_idx + 1]

            for j in range(base, base + stride):
                p0, p1 = j, j + stride
                p2, p3 = j + 2 * stride, j + 3 * stride

                # First layer: pair elements two strides apart.
                top0, bot0 = ct_bf(coeffs[p0], coeffs[p2], zeta_first)
                top1, bot1 = ct_bf(coeffs[p1], coeffs[p3], zeta_first)

                # Second layer: pair the first-layer outputs one stride apart.
                coeffs[p0], coeffs[p1] = ct_bf(top0, top1, zeta_second_a)
                coeffs[p2], coeffs[p3] = ct_bf(bot0, bot1, zeta_second_b)

    return coeffs

def _inv_ntt2x2_layers(poly_r, butterfly):
    """Run the merged two-layer inverse passes shared by both 2x2 INTT variants.

    Each pass handles groups of ``4 * 2**l`` coefficients for l = 0, 2, 4, ...
    Four coefficients spaced ``2**l`` apart go through two butterflies with
    two consecutive ``zetas_inv`` entries, followed by two butterflies that
    share a single ``zetas_inv`` entry. No final scaling is applied here.

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1.
            It is not modified.
        butterfly: callable ``(u, v, z) -> (u', v')`` used for every
            butterfly (``gs_bf`` or ``gs_bf_div2``).

    Returns:
        A deep copy of ``poly_r`` after all passes.
    """
    coeffs = copy.deepcopy(poly_r)

    for log_stride in range(0, DILITHIUM_LOGN - (DILITHIUM_LOGN & 1), 2):
        stride = 1 << log_stride
        for base in range(0, DILITHIUM_N, 1 << (log_stride + 2)):
            # Twiddle indices walk the table backwards as `base` grows.
            offset = DILITHIUM_N - (base >> 1)
            first_idx = (offset >> log_stride) - 1
            zeta_first_a = zetas_inv[first_idx]
            zeta_first_b = zetas_inv[first_idx - 1]
            zeta_second = zetas_inv[(offset >> (log_stride + 1)) - 1]

            for j in range(base, base + stride):
                p0, p1 = j, j + stride
                p2, p3 = j + 2 * stride, j + 3 * stride

                # First layer: pair neighbours one stride apart.
                sum0, diff0 = butterfly(coeffs[p0], coeffs[p1], zeta_first_a)
                sum1, diff1 = butterfly(coeffs[p2], coeffs[p3], zeta_first_b)

                # Second layer: pair first-layer outputs two strides apart.
                coeffs[p0], coeffs[p2] = butterfly(sum0, sum1, zeta_second)
                coeffs[p1], coeffs[p3] = butterfly(diff0, diff1, zeta_second)

    return coeffs


def inv_NTT2x2(poly_r):
    """Inverse NTT processing two butterfly layers per pass (2x2 scheme).

    Uses plain ``gs_bf`` butterflies and then multiplies every coefficient by
    256^-1 mod q.

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1
            (a dict in this notebook's test cell). It is not modified.

    Returns:
        A deep copy of ``poly_r`` holding the inverse-transformed coefficients.
    """
    r = _inv_ntt2x2_layers(poly_r, gs_bf)

    f = 8347681  # 256^-1 mod DILITHIUM_Q
    for j in range(DILITHIUM_N):
        r[j] = (f * r[j]) % DILITHIUM_Q

    return r


def inv_NTT2x2_div2(poly_r):
    """Inverse NTT (2x2 scheme) with the scaling folded into the butterflies.

    Uses ``gs_bf_div2``, which halves its outputs modulo q, for every
    butterfly. With DILITHIUM_LOGN = 8 each coefficient passes through eight
    halving butterflies, so no separate multiplication by 256^-1 is performed.

    Args:
        poly_r: indexable container holding coefficients at 0 .. N-1
            (a dict in this notebook's test cell). It is not modified.

    Returns:
        A deep copy of ``poly_r`` holding the inverse-transformed coefficients.
    """
    return _inv_ntt2x2_layers(poly_r, gs_bf_div2)


### Test

In [9]:
test_no = 1
for test_i in range(test_no):
    # Deterministic ramp input 0, 1, ..., N-1.
    # Use random.randrange(DILITHIUM_Q) per coefficient for randomized input.
    r_init = {}
    for i in range(DILITHIUM_N):
        r_init[i] = i % DILITHIUM_Q

    # Reference model: forward followed by inverse must give back the input.
    r_in_ntt = fwd_NTT(r_init)
    r_from_ntt = inv_NTT(r_in_ntt)
    if r_from_ntt != r_init:
        print("Error in ref model")

    # 2x2 architecture: forward result must match the reference model and
    # the inverse must give back the input.
    r_in_ntt2x2 = fwd_NTT2x2(r_init)
    r_from_ntt2x2 = inv_NTT2x2(r_in_ntt2x2)
    if r_in_ntt2x2 != r_in_ntt:
        print("Error in ntt2x2 model")
    if r_init != r_from_ntt2x2:
        print("Error in inv_ntt2x2 model")

    # 2x2 architecture with halving butterflies: inverse must give back the input.
    r_from_ntt2x2_div2 = inv_NTT2x2_div2(r_in_ntt2x2)
    if r_init != r_from_ntt2x2_div2:
        print("Error in inv_ntt2x2 div2 model")


def print_table(label, data, rows, cols):
    """Print a label followed by the values of ``data`` as a hex grid.

    Args:
        label: heading printed on its own line before the grid.
        data: mapping whose values (in iteration order) are integers.
        rows: accepted for the caller's convenience but not used; the number
            of printed rows follows from ``len(data)`` and ``cols``.
        cols: number of values printed per line.

    Each value is rendered as zero-padded 6-digit uppercase hex, and values
    on a line are separated by single spaces.
    """
    print(label)
    flat = list(data.values())
    chunks = [flat[offset:offset + cols] for offset in range(0, len(flat), cols)]
    for chunk in chunks:
        print(" ".join(format(entry, "06X") for entry in chunk))
    
# Print the input, both forward results and all three round-trip results as 16x16 hex grids.
for _label, _table in (
    ("r_init=", r_init),
    ("r_in_ntt=", r_in_ntt),
    ("r_in_ntt2x2=", r_in_ntt2x2),
    ("r_from_ntt=", r_from_ntt),
    ("r_from_ntt2x2=", r_from_ntt2x2),
    ("r_from_ntt2x2_div2=", r_from_ntt2x2_div2),
):
    print_table(_label, _table, rows=16, cols=16)

r_init=
000000 000001 000002 000003 000004 000005 000006 000007 000008 000009 00000A 00000B 00000C 00000D 00000E 00000F
000010 000011 000012 000013 000014 000015 000016 000017 000018 000019 00001A 00001B 00001C 00001D 00001E 00001F
000020 000021 000022 000023 000024 000025 000026 000027 000028 000029 00002A 00002B 00002C 00002D 00002E 00002F
000030 000031 000032 000033 000034 000035 000036 000037 000038 000039 00003A 00003B 00003C 00003D 00003E 00003F
000040 000041 000042 000043 000044 000045 000046 000047 000048 000049 00004A 00004B 00004C 00004D 00004E 00004F
000050 000051 000052 000053 000054 000055 000056 000057 000058 000059 00005A 00005B 00005C 00005D 00005E 00005F
000060 000061 000062 000063 000064 000065 000066 000067 000068 000069 00006A 00006B 00006C 00006D 00006E 00006F
000070 000071 000072 000073 000074 000075 000076 000077 000078 000079 00007A 00007B 00007C 00007D 00007E 00007F
000080 000081 000082 000083 000084 000085 000086 000087 000088 000089 00008A 00008B 00008C 00008