In [46]:
from scipy.interpolate import lagrange
import random
import matplotlib.pyplot as plt
import numpy as np

In [5]:
class RandPoly:
    """
    Random and zero free coefficient polynomial
    """

    def __init__(self, n, name='', R = None, fzc=False, p = 1337):
        self.name = name
        self.n = n
        self.p = p
        self.fzc = fzc # free zero coefficient
        if R:
            self.R = R
            # assert(len(R) == n)
        else: 
            self.R = [0] * (n+1) 
            for t in range(self.n+1):
                if t == 0 and fzc is True:
                    self.R[t] = (t, 0)
                else:
                    r = random.randint(1,self.p)
                    self.R[t] = (t, r)

    def poly(self, x):
        s = 0
        for (n, r) in self.R:
            s += r * x ** n
        return s

    
    def poly_str(self):
        """
        outputs the underlying polynomial
        """
        s = ""
        first_zero = self.R[0][1] == 0
        if first_zero:
            for (i, r) in self.R:
                if i == 0:
                    continue
                elif i == 1:
                    s +=  f"{r}x" 
                else:
                    s += f"+{r}x^{i}"
        else:
            for (i, r) in self.R:
                if i == 0:
                    s += f"{r}"
                elif i == 1:
                    s +=  f"+{r}x" 
                else:
                    s += f"+{r}x^{i}"
        if self.name:
            return f"{self.name}(x)={s}"
        else:
            return s

    def __str__(self):
        return self.poly_str()

    def __repr__(self):
        return self.poly_str()

    def __add__(self, other):
        n = min(len(self.R), len(other.R))
        m = max(len(self.R), len(other.R))
        R = [0] * m
        for i in range(m):
            if i < n:
                R[i] = (i, self.R[i][1] + other.R[i][1])
            elif i < len(self.R):
                R[i]= (i, self.R[i][1])
            elif i < other.R:
                 R[i][1] = (i, other.R[i][1])
        
        return RandPoly(n=self.n, R=R)
        

In [None]:
def make_graph(data_point, real_poly, predicted_poly, shares: list):
    x = np.linspace(-10, 10, 100)
    plt.plot(x, predicted_poly, color='red')
    plt.plot(x, real_poly, color='blue')
    plt.show()

In [66]:
def get_random_points(n, rand_seed, max_x=100000):
    points = random.Random(rand_seed).sample(range(1, max_x), n)
    return points

In [67]:
def get_shares(x_points, real_poly):
    shares = []
    for x in x_points:
        shares.append(real_poly(x))
    return shares

In [74]:
def compute_private_val(x_points, shares):
    f = lagrange(x_points, shares)
    return f(0)

In [93]:
data_point = 7
# rand_poly = RandPoly(2, R=[(0, data_point), (1, 4), (2, 13)])
rand_poly = RandPoly(2)
print(rand_poly.poly_str())
real_poly = rand_poly.poly

points3_t1 = get_random_points(3, 67)
shares3_t1 = get_shares(points3_t1, real_poly)

points3_t2 = get_random_points(3, 68)
shares3_t2 = get_shares(points3_t2, real_poly)

points3_t3 = get_random_points(3, 69)
shares3_t3 = get_shares(points3_t3, real_poly)

print(points3_t1, shares3_t1, points3_t2, shares3_t2, points3_t3, shares3_t3)

print(compute_private_val(points3_t1, shares3_t1))
print(compute_private_val(points3_t2, shares3_t2))
print(compute_private_val(points3_t3, shares3_t3))

481+3x+277x^2
[9804, 15217, 53528] [26624831125, 64141359785, 793673520233] [97213, 61128, 90387] [2617752053333, 1035047354233, 2263037577655] [89685, 4890, 12562] [2228021854861, 6623666851, 43711702955]
481.00001525878906
480.9921875
480.99999618530273


In [100]:
features_arr = [4,8,2]

all_functions = []
for feature in range(len(features_arr)):
    func = RandPoly(
        name=f"f{feature}",
        n=2,
        R=[
            (i, x)
            for i, x in enumerate(
                list([features_arr[feature], random.randint(2, 250), random.randint(2, 250)])
            )
        ],
    )
    all_functions.append(func)

for f in all_functions:
    print(f.poly_str())


f0(x)=4+108x+94x^2
f1(x)=8+161x+23x^2
f2(x)=2+12x+81x^2
