In [1]:
import time
import random
from math import ceil
from decimal import Decimal
from typing import List, Tuple

import random
from sympy import Symbol, Poly


# Use a large prime as the modulus for arithmetic operations
prime: int = 15485867
threshold: int = 5
secret1: int = 1234
secret2: int = 4321

In [2]:
# Client-side
# All parties use the same prime and same threshold
def sample_shamir_polynomial(zero_value: int) -> List[int]:
    coefs = [zero_value] + [random.randrange(prime) for _ in range(threshold - 1)]
    return coefs

In [3]:
# Client-side
def evaluate_at_point(coefs: List[int], point: int) -> int:
    result = 0
    for coef in reversed(coefs):
        result = (coef + point * result) % prime
    return result

In [4]:
# Server-side
def interpolate_at_point(points_values: List[Tuple[int, int]], query_x_axis: int) -> int:
    x_vals, y_vals = zip(*points_values)
    constants = lagrange_constants_for_point(x_vals, query_x_axis)
    return sum(ci * vi for ci, vi in zip(constants, y_vals)) % prime

In [5]:
# Server-side
def lagrange_constants_for_point(points: List[int], query_x_axis: int) -> List[int]:
    constants = [0] * len(points)
    for i in range(len(points)):
        xi = points[i]
        num = 1
        denum = 1
        for j in range(len(points)):
            if j != i:
                xj = points[j]
                num = (num * (xj - query_x_axis)) % prime
                denum = (denum * (xj - xi)) % prime
        constants[i] = (num * pow(denum, -1, prime)) % prime
    return constants

In [6]:
# Client-side
def shamir_share(secret: int, num_shares: int) -> List[Tuple[int, int]]:
    polynomial = sample_shamir_polynomial(secret)
    print(polynomial)
    shares = [(i, evaluate_at_point(polynomial, i)) for i in range(1, num_shares + 1)]
    return shares

In [7]:
shares1 = shamir_share(secret1, 10)
shares2 = shamir_share(secret2, 10)

[1234, 13931080, 841122, 7763442, 12711893]
[4321, 5440030, 13254229, 6191438, 7532626]


In [8]:
# Server-side
def shamir_add(x, y):
    return [ (i+1, (xi[1] + yi[1]) % prime) for i, (xi, yi) in enumerate(list(zip(x, y))) ]

In [9]:
added = shamir_add(shares1, shares2)

In [10]:
# Server-side
def shamir_reconstruct(shares: List[Tuple[int, int]], query_x_axis: int = 0) -> int:
    polynomial = [(p, v) for p, v in shares]
    secret = interpolate_at_point(polynomial, query_x_axis)
    return secret

In [11]:
print('points:')
added

points:


[(1, 5727947),
 (2, 4161045),
 (3, 2571729),
 (4, 14033458),
 (5, 10968669),
 (6, 2578112),
 (7, 3869116),
 (8, 4683855),
 (9, 671082),
 (10, 8771996)]

In [12]:
for i in range(len(added)+1):
    print(f'x: {i}, y: {shamir_reconstruct(added, i)}')

x: 0, y: 5555
x: 1, y: 5727947
x: 2, y: 4161045
x: 3, y: 2571729
x: 4, y: 14033458
x: 5, y: 10968669
x: 6, y: 2578112
x: 7, y: 3869116
x: 8, y: 4683855
x: 9, y: 671082
x: 10, y: 8771996


In [13]:
print(f'secret: {shamir_reconstruct(added, 0)}')

secret: 5555
