In [1]:
import random
import time
import math

### Parameters

In [2]:
n = 256                    # polynomial length; products are reduced modulo x^n + 1
q = 2 ** 23 - 2 ** 13 + 1  # coefficient modulus (= 8380417)
l = 64                     # bits per coefficient slot when packing (snort) / unpacking (sneeze)
t = 8                      # number of sub-polynomials; also the butterfly transform length

# The Kronecker+ pipeline below silently returns wrong results if any of these fail.
assert n % t == 0, "t must divide n: reorder_coefficients splits n coefficients into t rows of n // t"
assert t > 0 and t & (t - 1) == 0, "t must be a power of two: the butterflies use radix-2 stages"
assert l % t == 0, "t must divide l: step2/step6 use X = 2**(l // t) with X**t == 2**l"

In [3]:
# Kronecker modulus for the packed sub-polynomials: 2**(l*n//t) + 1.
kron_mod = (1 << (l * n // t)) + 1
# 2**(2*l*n // t**2) reduced mod kron_mod. When the divisions are exact and t is
# even (true for the defaults), its (t/2)-th power is 2**(l*n/t) == -1 mod kron_mod,
# so it acts as a t-th root of unity for the butterflies.
root_t = pow(2, 2 * l * n // t ** 2, kron_mod)
# Under the same conditions root_t**t == 1, hence root_t**(t-1) is its inverse.
inv_root_t = pow(root_t, t - 1, kron_mod)

# modulo_polynomial = Y ^ (n/t) + 1, stored lowest degree first.
# NOTE(review): not referenced by any function shown in this notebook.
modulo_polynomial = [0] * (n // t + 1)
modulo_polynomial[0] = modulo_polynomial[n // t] = 1

### Naive Polynomial Multiplication

In [4]:
def naive_multiplication(polynomial1, polynomial2):
    """Schoolbook product in Z_q[x]/(x^n + 1), used as the reference result.

    Every pair (i, j) contributes polynomial1[i] * polynomial2[j] to exponent
    i + j; exponents >= n wrap around to i + j - n with a sign flip because
    x^n == -1. Returns a new list of n coefficients in [0, q).
    """
    acc = [0] * n
    for i in range(n):
        a_i = polynomial1[i]
        for j in range(n):
            k = i + j
            if k < n:
                acc[k] += a_i * polynomial2[j]
            else:
                acc[k - n] -= a_i * polynomial2[j]

    return [c % q for c in acc]

### Kronecker+

Step 1: reorder the coefficients into $t$ sub-polynomials and snort (pack) each one into a single integer

In [5]:
def reorder_coefficients(polynomial):
    """Split a length-n coefficient list into t rows of n // t coefficients.

    Row i collects the coefficients at positions i, i + t, i + 2t, ...,
    i.e. row i, column j is polynomial[j * t + i].
    """
    return [[polynomial[j * t + i] for j in range(n // t)] for i in range(t)]

In [6]:
def snort(polynomial):
    """Pack the first n // t coefficients into one integer (Kronecker substitution).

    Coefficient i is placed at bit offset l * i: the result is the sum of
    polynomial[i] * 2**(l * i). The values are added rather than OR-ed, so
    negative or over-wide coefficients carry into neighbouring slots.
    """
    return sum(polynomial[i] * 2 ** (l * i) for i in range(n // t))

In [7]:
def _step1(polynomial):
    """Split one polynomial into t rows and snort each row into a single integer."""
    return list(map(snort, reorder_coefficients(polynomial)))

In [8]:
def step1(polynomial1, polynomial2):
    """Step 1 for both operands: returns a tuple of two lists of t packed integers."""
    return tuple(map(_step1, (polynomial1, polynomial2)))

Step 2: multiply each sub-polynomial $i$ by $X^i$, where $X = 2^{l/t}$ (implemented as a left shift by $i \cdot l / t$ bits)

In [9]:
def step2(nusskron1, nusskron2):
    """Step 2: twist entry i of both inputs by 2**(i*l//t), in place.

    The twist is a left shift; with t dividing l it equals X**i for
    X = 2**(l//t). Reducing modulo kron_mod would normally be needed here, but
    it is deferred: there is enough space to keep a few extra bits, and the
    butterflies that follow reduce modulo kron_mod anyway.
    Returns the same two (mutated) lists.
    """
    for i in range(t):
        shift = i * l // t
        nusskron1[i] <<= shift
        nusskron2[i] <<= shift

    return nusskron1, nusskron2

Step 3: forward butterfly (length-$t$ transform modulo `kron_mod`)

In [10]:
def bit_reverse(v):
    """Permute ``v`` in place into bit-reversed index order and return it.

    The element at index i is swapped with the one at index rev(i), where rev
    reverses the low log2(len(v)) bits of i.

    Args:
        v: mutable sequence whose length is a power of two. Lengths 0 and 1
            are accepted and returned unchanged.

    Returns:
        ``v`` itself, permuted.

    Raises:
        ValueError: if len(v) is not a power of two. The permutation is
            undefined there; the old float-log2 version silently produced a
            wrong ordering.
    """
    size = len(v)
    if size <= 1:
        # Nothing to permute. This also avoids log2(0) on an empty list.
        return v
    if size & (size - 1):
        raise ValueError(f"bit_reverse requires a power-of-two length, got {size}")

    # Exact integer log2 (no float rounding for very large lengths).
    width = size.bit_length() - 1
    for i in range(size):
        j = int(f"{i:0{width}b}"[::-1], 2)
        if i < j:  # swap each pair only once
            v[i], v[j] = v[j], v[i]

    return v

In [11]:
def forward_butterfly(poly):
    """Forward radix-2 butterfly network (length-t transform) modulo kron_mod.

    Works on a copy of the first t entries of ``poly``; ``poly`` itself is not
    modified. For stage sizes m = 2, 4, ..., t, entries t // m apart are
    paired and mapped (u, v) -> (u + w*v, u - w*v) mod kron_mod. The twiddle w
    runs over the powers of root_t**(t // m), listed in bit-reversed order.
    """
    values = [poly[i] for i in range(t)]

    stage = 2
    while stage <= t:
        gap = t // stage  # distance between butterfly partners in this stage
        stage_root = pow(root_t, gap, kron_mod)
        twiddles = bit_reverse([pow(stage_root, e, kron_mod) for e in range(stage // 2)])

        for group, twiddle in enumerate(twiddles):
            first = group * 2 * gap
            for top in range(first, first + gap):
                bottom = top + gap
                upper = values[top]
                lower = twiddle * values[bottom] % kron_mod
                values[top] = (upper + lower) % kron_mod
                values[bottom] = (upper - lower) % kron_mod

        stage *= 2

    return values

In [12]:
def step3(nusskron1, nusskron2):
    """Step 3: apply the forward butterfly to both twisted operands."""
    transformed1 = forward_butterfly(nusskron1)
    transformed2 = forward_butterfly(nusskron2)
    return transformed1, transformed2

Step 4: $t$ pointwise multiplications modulo `kron_mod`

In [13]:
def step4(kron_ntt1, kron_ntt2):
    """Step 4: multiply the two transforms entry by entry modulo kron_mod (t products)."""
    products = []
    for i in range(t):
        products.append((kron_ntt1[i] * kron_ntt2[i]) % kron_mod)
    return products

Step 5: backward butterfly without the multiplication by $t^{-1}$ (that scaling is done in Step 6)

In [14]:
def _egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = _egcd(b % a, a)
        return (g, x - (b // a) * y, y)

In [15]:
def _modinv(a, m):
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    For positive m the result lies in [0, m).

    Raises:
        ValueError: if gcd(a, m) != 1, i.e. no inverse exists. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    g, x, _ = _egcd(a, m)
    if g != 1:
        raise ValueError('modular inverse does not exist')
    return x % m

In [130]:
def backward_butterfly(poly):
    """Inverse radix-2 butterfly network modulo kron_mod, without the 1/t scaling.

    Works on a copy of the first t entries of ``poly``. For stage sizes
    m = 2, 4, ..., t, entries k+i and k+i+m/2 inside every length-m group are
    mapped (u, v) -> (u + w*v, u - w*v) mod kron_mod, where
    w = inv_root_t**((t // m) * i). Scaling by t^-1 is left to step6.
    """
    vals = [poly[i] for i in range(t)]

    span = 2
    while span <= t:
        half = span // 2
        step_root = pow(inv_root_t, t // span, kron_mod)
        twiddle = 1
        for offset in range(half):
            for start in range(0, t, span):
                lo = start + offset
                hi = lo + half
                a = vals[lo]
                b = (vals[hi] * twiddle) % kron_mod
                vals[lo] = (a + b) % kron_mod
                vals[hi] = (a - b) % kron_mod
            twiddle = (twiddle * step_root) % kron_mod
        span *= 2

    return vals

In [17]:
def step5(kron_ntt):
    """Step 5: inverse butterflies; the multiplication by t^-1 happens in step6."""
    unscaled = backward_butterfly(kron_ntt)
    return unscaled

Step 6: multiply by $t^{-1}$ and undo the $X^i$ twist from Step 2

In [18]:
def step6(nusskron):
    """Step 6: finish the inverse transform and undo the step2 twist, in place.

    First every entry is multiplied by t^-1 (the scaling step5 skips). Then
    entry i is multiplied by X^-i with X = 2**(l//t). All arithmetic is
    modulo kron_mod. Returns the same (mutated) list.
    """
    inv_t = _modinv(t, kron_mod)
    for idx in range(t):
        nusskron[idx] = (nusskron[idx] * inv_t) % kron_mod

    inv_x = _modinv((2 ** (l // t)) % kron_mod, kron_mod)
    untwist = 1  # inv_x ** idx (mod kron_mod), advanced once per entry
    for idx in range(t):
        nusskron[idx] = (nusskron[idx] * untwist) % kron_mod
        untwist = (untwist * inv_x) % kron_mod

    return nusskron

Step 7: sneeze (unpack) each result and reduce the coefficients modulo $q$

In [19]:
def sneeze(G):
    """Unpack an integer into n // t signed base-2**l digits (inverse of snort).

    Digits are taken from the least significant end and mapped into
    (-2**(l-1), 2**(l-1)]. A digit moved below zero carries 1 into the
    remaining value. What is left above the top digit is asserted to be 0 or 1.
    It stands for a multiple of 2**(l*n//t), which is -1 modulo kron_mod, so it
    is subtracted from the lowest digit.
    """
    radix = 2 ** l
    half_radix = 2 ** (l - 1)
    digits = []

    for _ in range(n // t):
        digit = G % radix
        G >>= l
        if digit > half_radix:
            digit -= radix
            G += 1
        digits.append(digit)

    assert 0 <= G <= 1

    digits[0] -= G
    return digits

In [20]:
def order_coefficients_and_modulo_q_reduction(coefficients):
    """Interleave the t rows back into one list and reduce every value modulo q.

    This inverts reorder_coefficients: output position j*t + i takes
    coefficients[i][j]. With q > 0 the values land in [0, q).
    """
    return [coefficients[i][j] % q for j in range(n // t) for i in range(t)]

In [21]:
def step7(nusskron):
    """Step 7: sneeze every packed result and reassemble the output polynomial.

    Each input is expected to be reduced modulo kron_mod = 2**(l*n//t) + 1.
    That is 2**2048 + 1 with n=256, l=64, t=8, not 2**2304 + 1. So the value
    left over after sneeze extracts n // t digits can only be 0 or 1. The
    returned coefficients are the minimal non-negative representatives
    modulo q.
    """
    rows = [sneeze(number) for number in nusskron]
    return order_coefficients_and_modulo_q_reduction(rows)

All steps combined

In [22]:
def run(input_polynomia1, input_polynomia2):
    """Multiply two polynomials with the Kronecker+ pipeline (steps 1-7).

    Returns the product coefficients reduced modulo q, as a list of length n.
    """
    twisted1, twisted2 = step2(*step1(input_polynomia1, input_polynomia2))
    transformed1, transformed2 = step3(twisted1, twisted2)
    pointwise = step4(transformed1, transformed2)
    return step7(step6(step5(pointwise)))

### Input

In [23]:
def random_polynomial():
    """Return n coefficients drawn uniformly from [0, q - 1] with the global `random` state."""
    coefficients = []
    for _ in range(n):
        coefficients.append(random.randint(0, q - 1))
    return coefficients

In [24]:
# Fixed seed so the random inputs (and the outputs printed below) are reproducible.
random.seed(0)

input_polynomia1 = random_polynomial()
input_polynomia2 = random_polynomial()

# Uncomment to dump both inputs as `.word` directives (first polynomial, then the second):
# for coefficient in input_polynomia1 + input_polynomia2:
#     print(f"  .word {coefficient:#0{10}x}")

### Playground

In [25]:
def print_numbers_in_hex(numbers, number_of_digits=None):
    """Print each number on its own line in hexadecimal.

    With ``number_of_digits``, the number is printed as-is, zero-padded to
    that total width including the "0x" prefix. Without it, the number is
    first reduced modulo kron_mod and printed via hex().
    """
    for value in numbers:
        if number_of_digits is None:
            print(hex(value % kron_mod))
        else:
            print(f"{value:#0{number_of_digits}x}")

In [26]:
# Time the Kronecker+ pipeline, then check it against the schoolbook reference.
# perf_counter is a monotonic high-resolution clock meant for measuring durations;
# time.time() is wall-clock time and can jump or have coarse resolution.
start = time.perf_counter()
output_polynomial = run(input_polynomia1, input_polynomia2)
end = time.perf_counter()
print('time: ', (end - start) * 1000, 'milliseconds')

reference_polynomial = naive_multiplication(input_polynomia1, input_polynomia2)
print(output_polynomial[:6])
print(reference_polynomial[:6])
# Compare every coefficient, not just the six printed above.
assert output_polynomial == reference_polynomial, "Kronecker+ result differs from naive multiplication"

time:  2.0389556884765625 miliseconds
[3330502, 7090661, 5194480, 439890, 2745074, 2460165]
[3330502, 7090661, 5194480, 439890, 2745074, 2460165]


In [27]:
def turn_to_number(s: str):
    """Join the "0x"-prefixed tokens of ``s`` into one hex literal string.

    '=' and newlines are treated as separators and the text is split on single
    spaces. Tokens are taken last-to-first, so the final token in ``s`` becomes
    the most significant digits. Returns '0x' followed by the concatenated
    digits ('0x' alone if there are none).
    """
    tokens = s.replace('=', ' ').replace('\n', ' ').split(' ')
    digits = ''.join(token[2:] for token in reversed(tokens) if token.startswith('0x'))
    return '0x' + digits

In [133]:
# Same pipeline as run(), executed stage by stage so the intermediate values
# (nusskron1/2, kron_ntt1/2, kron_ntt, nusskron) remain available for inspection.
nusskron1, nusskron2 = step2(*step1(input_polynomia1, input_polynomia2))
kron_ntt1, kron_ntt2 = step3(nusskron1, nusskron2)
kron_ntt = step4(kron_ntt1, kron_ntt2)
nusskron = step6(step5(kron_ntt))
output_polynomial = step7(nusskron)

# Full Kronecker+ result, followed by the schoolbook reference for comparison.
print(output_polynomial)
print(naive_multiplication(input_polynomia1, input_polynomia2))

# Coefficients as zero-padded hex words: "0x" plus 8 digits = 10 characters.
print_numbers_in_hex(output_polynomial, 10)

[3330502, 7090661, 5194480, 439890, 2745074, 2460165, 7132666, 2160673, 7528362, 1512433, 2670703, 5399576, 2684345, 7787054, 7519338, 5213484, 5394636, 2321342, 6560029, 8321219, 7931233, 3399259, 2706820, 176874, 3575315, 7465892, 1005451, 6912753, 4482928, 4837004, 7980387, 5063904, 351307, 4412748, 7170238, 4554998, 4504726, 2127503, 7746314, 7593731, 6161752, 6155119, 7884856, 1769861, 6142015, 2412120, 3168197, 3549821, 4958879, 6124278, 5507489, 7011110, 6087456, 6400730, 2755712, 1605217, 2260191, 7277862, 8121048, 7692018, 1658435, 208969, 5086144, 2005628, 5849495, 6050581, 6301159, 770716, 4250054, 233233, 3807896, 951736, 2882996, 7202356, 6473107, 1599990, 5421908, 1776588, 2099049, 4635202, 3195461, 6513083, 6262186, 7778339, 5573226, 1907927, 214948, 5678972, 6050282, 2454510, 1325886, 7257824, 5104932, 2035901, 1410074, 4244686, 6610457, 1302195, 4770397, 1045523, 4437073, 4993033, 4942866, 5053046, 7794462, 1320525, 7883328, 1061306, 3111300, 7604150, 500969, 6332165, 

In [135]:
# Rebuild a single hex literal from the 0x-prefixed words listed in i.txt.
with open('i.txt', 'r') as f:
    i_txt_contents = f.read()

print(turn_to_number(i_txt_contents))

0x006b1ef10040e417003ae59b0008a2fa0067cb56001002aa00255b4c0048e52b005d9aa70073d2980032b0ea003816b00070fda60034f9ff0057554c006dc8490039f33c002ac47f0038bc060074c7a100240f57002fbcf50043f06d0074737c002c93dc00714b760006d27b0045c9750018fa540075925d00430091001eef45002e44d100576f6b000959d7003c2dd9000256c1002dab9f00100a47003f47000013bcba006809150030be0d004a4f810064fd8f004bfd61006e5ae3007887cc005d5dd7003e2a6500694dd30022e8d800436edf00663fab001b37500053e7970045aaac000cb8de000036ab007cbef8003f129d0009442b0014798d007af751007ed0f800561813003f28e3002c95cc0076c2e2005cfe7500798b1800601ab7005678e400365190002c1eb2000c5cdd00308dd50004194100031acf0015728d006df4dd001efd5a00743eda00459e3a003ea5ed00173afa002501b7002f2063001064530064ffd50023b67e001d46330033fb6f00343f4600352c63000ba3fb006b745d001d9fd600614e0d005c81a000292f510057eb9f002df95c0067f093001501e3000c5780006111fe004e566c00163041002c178f0019a09b002977170028a931006a47df00652866006733bf00410a12007d9fc20064d6b700656899003cda860027827f0062eabc0054597d00753a