## Base field arithmetic

In [1]:
PRIME = 433

In [2]:
def base_egcd(a, b):
    u, v, s, t, r = 1, 0, 0, 1, 0
    while (a % 2 == 0) and (b % 2 == 0):
        a, b, r = a//2, b//2, r+1
    alpha, beta = a, b
    while (a % 2 == 0):
        a = a//2
        if (u % 2 == 0) and (v % 2 == 0):
            u, v = u//2, v//2
        else:
            u, v = (u + beta)//2, (v - alpha)//2
    while a != b:
        if (b % 2 == 0):
            b = b//2
            if (s % 2 == 0) and (t % 2 == 0):
                s, t = s//2, t//2
            else:
                s, t = (s + beta)//2, (t - alpha)//2
        elif b < a:
            a, b, u, v, s, t = b, a, s, t, u, v
        else:
            b, s, t = b - a, s - u, t - v
    return (2 ** r) * a, s, t

In [3]:
def base_inverse(a):
    _, b, _ = base_egcd(a, PRIME)
    return b

In [4]:
def base_add(a, b):
    return (a + b) % PRIME

In [5]:
def base_sub(a, b):
    return (a - b) % PRIME

In [6]:
def base_mul(a, b):
    return (a * b) % PRIME

In [7]:
def base_div(a, b):
    return base_mul(a, base_inverse(b))

## Polynomial arithmetic

In [8]:
def expand_to_match(A, B):
    diff = len(A) - len(B)
    if diff > 0:
        return A, B + [0] * diff
    elif diff < 0:
        diff = abs(diff)
        return A + [0] * diff, B
    else:
        return A, B

assert( expand_to_match([1,1], [])  == ([1,1], [0,0]) )
assert( expand_to_match([1,1], [1]) == ([1,1], [1,0]) )

In [9]:
def reduce(A):
    for i in range(len(A)-1, -1, -1):
        if A[i] != 0:
            break
    else:
        return []
    return A[:i+1]

assert( reduce([ ]) == [] )
assert( reduce([0]) == [] )
assert( reduce([0,0]) == [] )
assert( reduce([0,1,2]) == [0,1,2] )
assert( reduce([0,1,2,0,0]) == [0,1,2] )

In [10]:
def lc(A):
    B = reduce(A)
    return B[-1]

In [11]:
def deg(A):
    return len(reduce(A)) - 1

assert( deg([ ]) == -1 )
assert( deg([0]) == -1 )
assert( deg([1,0]) == 0 )
assert( deg([0,0,1]) == 2 )

In [12]:
def poly_add(A, B):
    F, G = expand_to_match(A, B)
    return reduce([ base_add(f, g) for f, g in zip(F, G) ])

assert( poly_add([1,2,3], [2,1]) == [3,3,3] )

In [13]:
def poly_sub(A, B):
    F, G = expand_to_match(A, B)
    return reduce([ base_sub(f, g) for f, g in zip(F, G) ])

assert( poly_sub([1,2,3], [1,2]) == [0,0,3] )

In [14]:
def poly_mul(A, B):
    C = [ 0 for _ in range(len(A) + len(B) - 1) ]
    for i in range(len(A)):
        for j in range(len(B)):
            C[i+j] = base_add(C[i+j], base_mul(A[i], B[j]))
    return reduce(C)

In [15]:
def poly_divmod(A, B):
    t = base_inverse(lc(B))
    Q = [0] * len(A)
    R = [a for a in A]
    for i in range(len(A) - len(B), -1, -1):
        Q[i] = base_mul(t, R[i + len(B) - 1])
        for j in range(len(B)):
            R[i+j] = base_sub(R[i+j], base_mul(Q[i], B[j]))
    return reduce(Q), reduce(R)

A = [7,4,5,4]
B = [1,0,1]
Q, R = poly_divmod(A, B)
assert( poly_add(poly_mul(Q, B), R) == A )

In [16]:
def poly_div(A, B):
    Q, _ = poly_divmod(A, B)
    return Q

def poly_mod(A, B):
    _, R = poly_divmod(A, B)
    return R

In [17]:
def scalar_div(A, b):
    return reduce([ base_div(a, b) for a in A ])

In [18]:
def poly_gcd(A, B):
    R0, R1 = A, B
    while R1 != []:
        R2 = poly_mod(R0, R1)
        R0, R1 = R1, R2
    D = scalar_div(R0, lc(R0))
    return D

D = [1,0,0,1]
A = poly_mul(D, [2,0,2])
B = poly_mul(D, [3,1])
assert( poly_gcd(A, B) == D )

In [19]:
def poly_egcd(A, B):
    R0, R1 = A, B
    S0, S1 = [1], []
    T0, T1 = [], [1]
    
    while R1 != []:
        Q, R2 = poly_divmod(R0, R1)
        
        R0, S0, T0, R1, S1, T1 = R1, S1, T1, R2, \
            poly_sub(S0, poly_mul(S1, Q)), \
            poly_sub(T0, poly_mul(T1, Q))
            
    c = lc(R0)
    D = scalar_div(R0, c)
    S = scalar_div(S0, c)
    T = scalar_div(T0, c)
    return D, S, T

F = [1,0,0,1]
G = poly_mul(F, [2,0,2])
H = poly_mul(F, [3,1])

D, S, T = poly_egcd(G, H)
assert( D == poly_gcd(G, H) )
assert( D == poly_add(poly_mul(G, S), poly_mul(H, T)) )

In [20]:
def poly_eea(A, B):
    EEA = []
    R0, R1 = [a for a in A], [b for b in B]
    S0, S1 = [1], []
    T0, T1 = [], [1]
    while R1 != []:
        Q, R2 = poly_divmod(R0, R1)
        
        EEA.append( [R0, S0, T0] )
        
        R0, S0, T0, R1, S1, T1 = R1, S1, T1, R2, \
            poly_sub(S0, poly_mul(S1, Q)), \
            poly_sub(T0, poly_mul(T1, Q))
            
    return EEA

F = [1,0,0,1]
G = poly_mul(F, [2,0,2])
H = poly_mul(F, [3,1])

EEA = poly_eea(G, H)
print(EEA)

[[[2, 0, 2, 2, 0, 2], [1], []], [[3, 1, 0, 3, 1], [], [1]]]


# Faulty reconstruction

In [21]:
ORDER2 = 8
ORDER3 = 9

OMEGA2 = 354
OMEGA3 = 150

In [22]:
def fft2_forward(A_coeffs, omega):
    if len(A_coeffs) == 1:
        return A_coeffs

    # split A into B and C such that A(x) = B(x^2) + x C(x^2)
    B_coeffs = A_coeffs[0::2]
    C_coeffs = A_coeffs[1::2]
    
    # apply recursively
    omega_squared = pow(omega, 2, PRIME)
    B_values = fft2_forward(B_coeffs, omega_squared)
    C_values = fft2_forward(C_coeffs, omega_squared)
        
    # combine subresults
    A_values = [0] * len(A_coeffs)
    L_half = len(A_coeffs) // 2
    for i in range(0, L_half):
        
        j = i
        x = pow(omega, j, PRIME)
        A_values[j] = base_add(B_values[i], base_mul(x, C_values[i]))
        
        j = i + L_half
        x = pow(omega, j, PRIME)
        A_values[j] = base_add(B_values[i], base_mul(x, C_values[i]))
        
    return A_values

def fft2_backward(A_values, omega):
    L_inv = base_inverse(len(A_values))
    A_coeffs = [ base_mul(a, L_inv) for a in fft2_forward(A_values, base_inverse(omega)) ]
    return A_coeffs

In [23]:
def reconstruct_faulty(values, points, h, max_degree, max_error_count):
    assert(len(values) == len(points))
    assert(len(values) >= 2*max_error_count + max_degree)

    # compute f
    f = [1]
    for xi in points:
        fi = [base_sub(0, xi), 1]
        f = poly_mul(f, fi)
    
    # run EGCD to find EEA triples
    eea = poly_eea(f, h)
    for (r,s,t) in eea:
        if deg(r) < max_degree + max_error_count:
            if poly_mod(r, t) == []:
                return poly_div(r,t)

In [24]:
ERROR_COUNT = 2
MAX_DEGREE = 4

g = [1,2,3,4,0,0,0,0]
assert(deg(g) < MAX_DEGREE)

values = fft2_forward(g, OMEGA2)
points = [ pow(OMEGA2, e, PRIME) for e in range(8) ]

for i in range(ERROR_COUNT):
    values[i] = 0

h = fft2_backward(values, OMEGA2)
print("h: %s" % h)

recovered_g = reconstruct_faulty(values, points, h, MAX_DEGREE, ERROR_COUNT)
print(recovered_g)

assert( reduce(recovered_g) == reduce(g) )

h: [203, 338, 425, 255, 12, 311, 225, 396]
[1, 2, 3, 4]
