# Base field arithmetic

In [1]:
# Field modulus: all "base" arithmetic below is done in the integers mod PRIME.
# 433 is prime and 432 == 2**4 * 3**3, so the multiplicative group mod 433
# contains elements of order 8.
PRIME = 433

In [2]:
def base_egcd(a, b):
    """Run the extended Euclidean algorithm on the integers a and b.

    Returns a triple (d, s, t) with s*a + t*b == d, where d is the last
    nonzero remainder of the division chain (the gcd of a and b, up to sign
    for negative inputs).
    """
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1

    while r:
        quotient, remainder = divmod(old_r, r)
        # Each pair advances one step; the right-hand side uses the old values.
        old_r, r = r, remainder
        old_s, s = s, old_s - s * quotient
        old_t, t = t, old_t - t * quotient

    return old_r, old_s, old_t

In [3]:
def base_inverse(a):
    """Return the multiplicative inverse of a modulo PRIME, in [0, PRIME).

    Raises ZeroDivisionError when a has no inverse, i.e. when
    gcd(a, PRIME) != 1 (for a prime modulus: when a is a multiple of PRIME).
    """
    d, s, _ = base_egcd(a, PRIME)
    if d != 1:
        # Without this check a multiple of PRIME would silently "invert" to 0.
        raise ZeroDivisionError(f"{a} has no inverse modulo {PRIME}")
    # a*s + PRIME*t == 1, so s is the inverse up to a multiple of PRIME.
    return s % PRIME

In [4]:
def base_add(a, b):
    """Return a + b reduced modulo PRIME."""
    total = a + b
    return total % PRIME

In [5]:
def base_sub(a, b):
    """Return a - b reduced modulo PRIME (always in [0, PRIME))."""
    difference = a - b
    return difference % PRIME

In [6]:
def base_mul(a, b):
    """Return a * b reduced modulo PRIME."""
    product = a * b
    return product % PRIME

In [7]:
def base_div(a, b):
    """Return a / b in the field mod PRIME: a times the modular inverse of b."""
    b_inv = base_inverse(b)
    return base_mul(a, b_inv)

# Polynomial arithmetic

In [8]:
from copy import copy

In [9]:
def expand_to_match(A, B):
    """Right-pad the shorter of A and B with zeros so both have equal length.

    A list that is already long enough is returned as the same object
    (not copied); a padded list is a new list.
    """
    longest = max(len(A), len(B))

    def pad(P):
        missing = longest - len(P)
        return P + [0] * missing if missing else P

    return pad(A), pad(B)

assert( expand_to_match([1,1], [])  == ([1,1], [0,0]) )
assert( expand_to_match([1,1], [1]) == ([1,1], [1,0]) )

In [10]:
def reduce(A):
    """Return a copy of A with trailing zero coefficients removed.

    The zero polynomial (empty or all-zero A) becomes [].
    """
    end = len(A)
    while end and A[end - 1] == 0:
        end -= 1
    return A[:end] if end else []

assert( reduce([ ]) == [] )
assert( reduce([0]) == [] )
assert( reduce([0,0]) == [] )
assert( reduce([0,1,2]) == [0,1,2] )
assert( reduce([0,1,2,0,0]) == [0,1,2] )

In [11]:
def lc(A):
    """Return the leading coefficient of A (its last nonzero entry).

    Raises IndexError for the zero polynomial, which has no leading
    coefficient.
    """
    return reduce(A)[-1]

assert( lc([0,1,2,0]) == 2 )

In [12]:
def deg(A):
    """Return the degree of A (index of its last nonzero entry); -1 for zero."""
    trimmed = reduce(A)
    if not trimmed:
        return -1
    return len(trimmed) - 1

assert( deg([ ]) == -1 )
assert( deg([0]) == -1 )
assert( deg([1,0]) == 0 )
assert( deg([0,0,1]) == 2 )

In [13]:
def poly_add(A, B):
    """Add polynomials A and B coefficient-wise mod PRIME; strip trailing zeros."""
    left, right = expand_to_match(A, B)
    return reduce(list(map(base_add, left, right)))

assert( poly_add([1,2,3], [2,1]) == [3,3,3] )

In [14]:
def poly_sub(A, B):
    """Subtract polynomial B from A coefficient-wise mod PRIME; strip trailing zeros."""
    left, right = expand_to_match(A, B)
    return reduce(list(map(base_sub, left, right)))

assert( poly_sub([1,2,3], [1,2]) == [0,0,3] )

In [15]:
def poly_mul(A, B):
    """Multiply polynomials A and B (schoolbook convolution mod PRIME).

    Coefficient lists are lowest degree first; the result has trailing
    zeros stripped.
    """
    product = [0] * (len(A) + len(B) - 1)
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            product[i + j] = base_add(product[i + j], base_mul(a, b))
    return reduce(product)

In [16]:
def poly_divmod(A, B):
    """Divide polynomial A by B over GF(PRIME) using schoolbook long division.

    Coefficient lists are lowest degree first. Returns (Q, R), both with
    trailing zeros stripped, such that A == Q*B + R and deg(R) < deg(B).

    B may carry trailing zero coefficients; they are removed first so the
    elimination uses B's true leading coefficient. Raises IndexError (from
    ``lc``) if B is the zero polynomial.
    """
    # Trailing zeros in B would misalign the index of the term being cancelled.
    B = reduce(B)
    t = base_inverse(lc(B))
    n = len(B)
    Q = [0] * len(A)
    R = list(A)
    for i in range(len(A) - n, -1, -1):
        # Choose Q[i] so that coefficient R[i + n - 1] becomes zero.
        Q[i] = base_mul(t, R[i + n - 1])
        for j in range(n):
            R[i+j] = base_sub(R[i+j], base_mul(Q[i], B[j]))
    return reduce(Q), reduce(R)

A = [7,4,5,4]
B = [1,0,1]
Q, R = poly_divmod(A, B)
assert( poly_add(poly_mul(Q, B), R) == A )
assert( poly_divmod([0,0,1], [0,1,0]) == ([0,1], []) )

In [17]:
def poly_div(A, B):
    """Return only the quotient of A divided by B (see poly_divmod)."""
    quotient = poly_divmod(A, B)[0]
    return quotient

def poly_mod(A, B):
    """Return only the remainder of A divided by B (see poly_divmod)."""
    remainder = poly_divmod(A, B)[1]
    return remainder

In [18]:
def poly_scalarmul(A, b):
    """Multiply every coefficient of A by the field element b; strip trailing zeros."""
    scaled = [base_mul(coef, b) for coef in A]
    return reduce(scaled)

def poly_scalardiv(A, b):
    """Divide every coefficient of A by the field element b; strip trailing zeros.

    The inverse of b is computed once, rather than once per coefficient as
    calling base_div in the loop would do. An empty A returns [] without
    inverting b.
    """
    if not A:
        return []
    inv_b = base_inverse(b)
    return reduce([ base_mul(a, inv_b) for a in A ])

In [19]:
def poly_gcd(A, B):
    """Return the monic greatest common divisor of polynomials A and B over GF(PRIME).

    Inputs may carry trailing zero coefficients (e.g. [0] for zero).
    gcd(0, 0) is returned as the zero polynomial [].
    """
    # Reduce first: an unreduced zero such as [0] would otherwise pass the
    # loop test and crash inside poly_mod.
    R0, R1 = reduce(A), reduce(B)
    while R1:
        R0, R1 = R1, poly_mod(R0, R1)
    if not R0:
        return []
    # Scale so the leading coefficient is 1.
    return poly_scalardiv(R0, lc(R0))

D = [1,0,0,1]
A = poly_mul(D, [2,0,2])
B = poly_mul(D, [3,1])
assert( poly_gcd(A, B) == D )
assert( poly_gcd([0], [0,0]) == [] )
assert( poly_gcd([1,1], [0]) == [1,1] )

In [20]:
def poly_egcd(A, B):
    """Extended Euclidean algorithm for polynomials over GF(PRIME).

    Returns (D, S, T) with D == A*S + B*T, where D is the monic gcd of A and B.
    Inputs may carry trailing zero coefficients. For A == B == 0 the result
    is ([], [1], []), which still satisfies the identity.
    """
    # Reduce first so an unreduced zero such as [0] ends the loop correctly.
    R0, R1 = reduce(A), reduce(B)
    S0, S1 = [1], []
    T0, T1 = [], [1]

    while R1:
        Q, R2 = poly_divmod(R0, R1)
        # Invariant: R_k == A*S_k + B*T_k for every row of the sequence.
        R0, S0, T0, R1, S1, T1 = (
            R1, S1, T1,
            R2, poly_sub(S0, poly_mul(S1, Q)), poly_sub(T0, poly_mul(T1, Q)),
        )

    if not R0:
        # gcd(0, 0): there is no leading coefficient to normalise by.
        return [], S0, T0
    c = lc(R0)
    D = poly_scalardiv(R0, c)
    S = poly_scalardiv(S0, c)
    T = poly_scalardiv(T0, c)
    return D, S, T

F = [1,0,0,1]
G = poly_mul(F, [2,0,2])
H = poly_mul(F, [3,1])

D, S, T = poly_egcd(G, H)
assert( D == poly_gcd(G, H) )
assert( D == poly_add(poly_mul(G, S), poly_mul(H, T)) )
assert( poly_egcd([0], [0]) == ([], [1], []) )
assert( poly_egcd([1,1], [0]) == ([1,1], [1], []) )

In [21]:
def poly_eval(A, x):
    """Evaluate polynomial A (lowest degree first) at x mod PRIME using Horner's rule."""
    value = 0
    for coefficient in A[::-1]:
        value = base_mul(value, x)
        value = base_add(coefficient, value)
    return value

# Polynomial interpolation

In [22]:
def lagrange_polynomials(xs):
    """Return the unnormalised Lagrange basis numerators for the nodes xs.

    Entry i is the product of (X - xs[j]) over all j != i, as a coefficient
    list (lowest degree first) over GF(PRIME).
    """
    def numerator(skip):
        acc = [1]
        for j, xj in enumerate(xs):
            if j != skip:
                acc = poly_mul(acc, [base_sub(0, xj), 1])
        return acc

    return [numerator(i) for i, _ in enumerate(xs)]

In [23]:
def lagrange_divisors(xs):
    """Return, for each i, the inverse of prod_{j != i} (xs[i] - xs[j]) mod PRIME.

    These are the normalising factors for the numerators built by
    lagrange_polynomials. Raises ValueError if two nodes coincide modulo
    PRIME: the product would then be zero and have no inverse, and the
    interpolation would silently be wrong.
    """
    divisors = []
    for i, xi in enumerate(xs):
        divisor = 1
        for j, xj in enumerate(xs):
            if i == j:
                continue
            factor = base_sub(xi, xj)
            if factor == 0:
                raise ValueError(
                    f"interpolation points {i} and {j} coincide modulo PRIME")
            divisor = base_mul(divisor, factor)
        divisors.append(base_inverse(divisor))
    return divisors

In [24]:
def lagrange_interpolation(xs, ys):
    """Return the polynomial of degree < len(xs) with value ys[i] at xs[i] over GF(PRIME).

    Assumes the xs are distinct modulo PRIME. Raises ValueError if xs and ys
    differ in length, instead of silently summing only part of the basis.
    """
    if len(xs) != len(ys):
        raise ValueError(
            f"need one value per point: got {len(xs)} points, {len(ys)} values")
    basis = lagrange_polynomials(xs)
    weights = lagrange_divisors(xs)
    poly = []
    for yi, li, wi in zip(ys, basis, weights):
        poly = poly_add(poly, poly_scalarmul(li, base_mul(yi, wi)))
    return poly

In [25]:
# Sanity check: sample the degree-2 polynomial f (coefficients lowest degree
# first) at 4 points and check that interpolation returns exactly f.
f = [1,2,3]

xs = [10,20,30,40]
ys = [ poly_eval(f, x) for x in xs ]

g = lagrange_interpolation(xs, ys)
assert( g == f )

# Reconstruction from faulty values

In [26]:
def gao_decoding(h, points, max_degree, max_error_count):
    """Decode a received word with Gao's algorithm over GF(PRIME).

    h is the polynomial interpolating the received values at `points`.
    Runs a partial extended Euclidean algorithm on (f, h), where
    f = prod (X - xi). It stops at the first remainder R with
    deg(R) < max_degree + max_error_count and returns G = R / T, where T is
    the matching Bezout coefficient. Returns None when T does not divide R,
    or when G's degree is not below max_degree.

    Per Gao's algorithm, when the points are distinct and at most
    max_error_count received values are wrong, G is the original polynomial.
    """
    assert(len(points) >= 2*max_error_count + max_degree)

    threshold = max_degree + max_error_count

    # f = prod (X - xi): the monic polynomial vanishing on every point.
    f = [1]
    for xi in points:
        f = poly_mul(f, [base_sub(0, xi), 1])

    # Partial EEA on (f, h). Only the T-coefficients are needed.
    R0, R1 = f, reduce(h)
    T0, T1 = [], [1]
    # Test before dividing: the division must not run once R1 is zero.
    while deg(R0) >= threshold:
        if not R1:
            # The remainder sequence reached zero, so its next row is
            # (R, T) = ([], T1). Dividing by R1 would crash (e.g. for a
            # constant or all-zero received word).
            R0, T0 = R1, T1
            break
        Q, R2 = poly_divmod(R0, R1)
        R0, T0, R1, T1 = R1, T1, R2, poly_sub(T0, poly_mul(T1, Q))

    G, remainder = poly_divmod(R0, T0)
    # Final step of Gao's algorithm: exact division and a small enough degree.
    if remainder or deg(G) >= max_degree:
        return None
    return G

# Constant and all-zero received words: the remainder sequence reaches zero.
assert( gao_decoding([5], [1,2,3], 1, 1) == [5] )
assert( gao_decoding([], [1,2,3], 1, 1) == [] )

In [27]:
def reconstruct_faulty(points, values, max_degree, max_error_count):
    """Recover a polynomial from its values at `points`, some of which may be wrong.

    Interpolates the received values, then hands the result to gao_decoding.
    Returns gao_decoding's result (a polynomial, or None on failure).
    """
    assert(len(values) == len(points))
    return gao_decoding(
        lagrange_interpolation(points, values),
        points, max_degree, max_error_count,
    )

In [28]:
# 354 has multiplicative order 8 mod 433 (354**4 % 433 == 432, i.e. -1),
# so its ORDER2 powers are distinct evaluation points.
OMEGA2, ORDER2 = 354, 8

ERROR_COUNT = 2
MAX_DEGREE = 4  # exclusive bound: g must have degree < MAX_DEGREE

g = [1,2,3,4]
assert(deg(g) < MAX_DEGREE)

# Evaluate g at the ORDER2 powers of OMEGA2 (use the named constant, not a literal 8).
points = [ pow(OMEGA2, e, PRIME) for e in range(ORDER2) ]
values = [ poly_eval(g, point) for point in points ]

# Overwrite the first ERROR_COUNT values with 0 to simulate faults.
for i in range(ERROR_COUNT):
    values[i] = 0

recovered_g = reconstruct_faulty(points, values, MAX_DEGREE, ERROR_COUNT)
print(recovered_g)

assert( recovered_g == g )

[1, 2, 3, 4]
