## Schönhage-Strassen multiplication

Implementation of a classical version of a quantum algorithm for multiplying two integers, described in [this paper](https://ieeexplore.ieee.org/abstract/document/10138719).

*Dmytro Fedoriaka, December 2024.*


In [1]:
import math 
import random

# Input parameters.
# Must be selected so that n = 2^l * M bits are enough to represent the output AB.
M, l = 16, 4

# Derived constants.
D = 2**l    # number of M-bit chunks per operand
n = D*M     # bit width the product AB must fit in
assert D <= M
# Smallest M1 with D*M1 >= 2M+l+2. Integer ceiling division is used because
# math.ceil on a float quotient can round incorrectly for large parameters.
M1 = -(-(2*M + l + 2) // D)
n1 = D*M1
assert 2*M + l + 2 <= n1 < 4*M
N = 2**n1 + 1
assert N >= 2**(2*M + l + 2) + 1
# 2^n1 = -1 (mod N), hence 2^(2*n1) = 1 (mod N): exponents of 2 may be reduced
# modulo 2*n1, which turns every inverse power of two into a non-negative power.
g = 2**(2*M1)   # g^D = 2^(2*n1) = 1 (mod N)
g_inv = pow(2, (-2*M1) % (2*n1), N)
D_inv = pow(2, 2*n1 - l, N)
# Weights W[k] = (2^M1)^k mod N; 2^M1 squares to g.
W = [pow(2, k*M1, N) for k in range(D)]
W_inv = [pow(2, (-k*M1) % (2*n1), N) for k in range(D)]

# Check that inverses are computed correctly.
assert g*g_inv % N == 1
assert D*D_inv % N == 1
assert all(w*w_inv % N == 1 for w, w_inv in zip(W, W_inv))

# Check that g is indeed the Dth principal root of unity of N (Def 1).
# Modular pow keeps intermediates below N instead of building g**(a*t) exactly.
assert all(pow(g, a, N) != 1 for a in range(1, D))
assert pow(g, D, N) == 1
assert all(sum(pow(g, a*t, N) for t in range(D)) % N == 0 for a in range(1, D))
    
# Faster, recursive (radix-2 Cooley-Tukey) implementation of G.
def FFT(X, cur_g, mod=None):
    """Return the DFT of X modulo `mod` using the root of unity cur_g.

    Entry m of the result is sum_t X[t] * cur_g^(t*m) (mod `mod`). This is
    what G computes by definition when cur_g == g and mod == N.

    Args:
        X: sequence of integers whose length is 0 or a power of two.
        cur_g: principal len(X)-th root of unity modulo `mod`. The butterfly
            below relies on cur_g^(len(X)/2) = -1 (mod `mod`).
        mod: modulus of the ring; defaults to the module-level N.

    Returns:
        A new list of len(X) integers, each in [0, mod).

    Raises:
        ValueError: if len(X) is neither 0 nor a power of two. Without this
            check the even/odd split would silently drop terms.
    """
    if mod is None:
        mod = N
    size = len(X)
    if size == 0:
        return []  # without this guard an empty input would recurse forever
    if size & (size - 1):
        raise ValueError(f"FFT length must be a power of two, got {size}")
    if size == 1:
        return [X[0] % mod]  # reduce here too, so every output is in [0, mod)
    half = size // 2
    cur_g_squared = cur_g * cur_g % mod
    fft_even = FFT(X[0::2], cur_g_squared, mod)
    fft_odd = FFT(X[1::2], cur_g_squared, mod)
    ans = [0] * size
    twiddle = 1  # cur_g**i mod `mod`, updated incrementally instead of recomputed
    for i in range(half):
        e = fft_even[i]
        o = fft_odd[i] * twiddle % mod
        ans[i] = (e + o) % mod
        ans[i + half] = (e - o) % mod  # cur_g^(i+half) = -cur_g^i
        twiddle = twiddle * cur_g % mod
    return ans
    
# Operator G_{N,g}(X), computed directly from its definition.
def G(X):
    """Apply the transform G_{N,g} to the length-D sequence X.

    Entry m of the result is sum_t X[t] * g^(t*m), reduced modulo N. Before
    returning, the direct result is cross-checked against the recursive FFT.
    """
    assert len(X)==D
    result = []
    for m in range(D):
        acc = 0
        for t in range(D):
            acc += X[t] * pow(g, t*m, N)
        result.append(acc % N)
    assert result == FFT(X, g)
    return result

# Operator G^{-1}_{N,g}(X), computed directly from its definition.
def G_inv(X):
    """Apply the inverse transform G^{-1}_{N,g} to the length-D sequence X.

    Entry m of the result is D_inv * sum_t X[t] * g_inv^(t*m), reduced modulo N.
    """
    assert len(X)==D

    def entry(m):
        total = sum(X[t] * pow(g_inv, t*m, N) for t in range(D))
        return (D_inv * total) % N

    return [entry(m) for m in range(D)]

# Negative cyclic convolution, straight from the definition.
def NCC(A, B):
    """Return the negative cyclic convolution of length-D sequences A and B.

    c_k = sum_{i+j=k} A[i]*B[j] - sum_{i+j=k+D} A[i]*B[j], and each c_k is
    reduced into [0, N) before being returned.
    """
    assert len(A) == len(B) == D
    coeffs = [0] * D
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            s = i + j
            if s < D:
                coeffs[s] += a*b
            else:
                # Terms that wrap past index D enter with a minus sign.
                coeffs[s - D] -= a*b
    return [c % N for c in coeffs]

# Pointwise (entry-by-entry) product modulo N.
def mul_pw(A, B):
    """Multiply two length-D sequences entry by entry, reducing each product mod N."""
    assert len(A) == len(B) == D
    return [(a*b) % N for a, b in zip(A, B)]

# Splits an integer into D chunks of M bits, least significant chunk first.
def num_to_seq(A):
    """Return the D lowest M-bit digits of A, least significant first.

    For non-negative A, bits at or above position D*M are discarded.
    """
    mask = (1 << M) - 1
    digits = []
    for _ in range(D):
        digits.append(A & mask)
        A >>= M
    return digits

# Reassembles an integer from its M-bit chunks, least significant chunk first.
def seq_to_num(A):
    """Return sum_i A[i] * 2^(M*i) for the length-D sequence A.

    Entries are not required to lie below 2^M; oversized entries simply
    carry into higher positions.
    """
    assert len(A) == D
    return sum(chunk << (M*pos) for pos, chunk in enumerate(A))
    
def validate(A, B):
    """Check equations (2) and (3) for the operands A and B.

    Raises:
        AssertionError: if A*B > 2^n, or if either identity does not hold.
    """
    product = A*B
    assert product <= 2**n   # "so we will pick a large enough n such that AB < 2^n + 1."

    A_seq = num_to_seq(A)
    B_seq = num_to_seq(B)
    conv = NCC(A_seq, B_seq)

    # Verify equation (2): the weighted transform computes the negative cyclic convolution.
    assert conv == mul_pw(W_inv, G_inv(mul_pw(G(mul_pw(W, A_seq)), G(mul_pw(W, B_seq)))))

    # Verify equation (3).
    # NCC returns coefficients reduced into [0, N), but the true coefficients can
    # be negative (wrapped-around terms are subtracted). Their magnitude is at most
    # D*(2^M-1)^2 < 2^(2M+l) < N/2, so each one lifts uniquely into (-N/2, N/2).
    signed_conv = [c - N if c > N//2 else c for c in conv]
    # sum_k c_k * 2^(M*k) equals A*B only modulo 2^n + 1. Since A*B <= 2^n,
    # reducing modulo 2^n + 1 recovers the product exactly.
    assert product == seq_to_num(signed_conv) % (2**n + 1)

# Randomized self-test: both operands are below 2^(n/2), so each product fits in n bits.
for _ in range(1000):
    A, B = random.randint(0, 2**(n//2)-1), random.randint(0, 2**(n//2)-1)
    validate(A, B)
print("OK")

OK


In [3]:
(n1, M1)  # inner-ring bit width n1 = D*M1 and per-chunk width M1

(48, 3)