Skip to content
Permalink
Branch: master
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
198 lines (177 sloc) 7.78 KB
# A prime field GF(p) with convenience operations on field elements and on
# polynomials over the field (coefficient lists, constant term first).
class PrimeField:
    """Arithmetic on elements of, and polynomials over, the prime field GF(modulus).

    Field elements are Python ints. Polynomials are lists of coefficients with
    the constant term first, e.g. ``[1, 0, 3]`` is ``1 + 3*x**2``. Returned
    elements and coefficients are reduced into ``range(modulus)``.
    """

    def __init__(self, modulus):
        """Create the field GF(modulus).

        Raises:
            AssertionError: if ``modulus < 2`` or it fails a base-2 Fermat
                test. The Fermat test is a cheap sanity check that rejects
                most composites; it does not prove primality.
        """
        # Compare with ``2 % modulus`` rather than the literal 2; otherwise the
        # prime 2 itself (where pow(2, 2, 2) == 0) would be rejected.
        assert modulus >= 2 and pow(2, modulus, modulus) == 2 % modulus
        self.modulus = modulus

    def add(self, x, y):
        """Return ``x + y`` in the field."""
        return (x + y) % self.modulus

    def sub(self, x, y):
        """Return ``x - y`` in the field."""
        return (x - y) % self.modulus

    def mul(self, x, y):
        """Return ``x * y`` in the field."""
        return (x * y) % self.modulus

    def exp(self, x, p):
        """Return ``x ** p`` in the field."""
        return pow(x, p, self.modulus)

    def inv(self, a):
        """Return the multiplicative inverse of ``a``, or 0 when ``a`` is 0 in the field.

        Uses the extended Euclidean algorithm. Any multiple of the modulus
        (not just the integer 0) is treated as zero.
        """
        m = self.modulus
        low, high = a % m, m
        # Zero has no inverse; by convention return 0 instead of raising.
        # Checking the reduced value also covers a == m, 2*m, ...
        if low == 0:
            return 0
        # Invariant: lm * a == low and hm * a == high (mod m).
        lm, hm = 1, 0
        while low > 1:
            r = high // low
            nm, new = hm - lm * r, high - low * r
            lm, low, hm, high = nm, new, lm, low
        return lm % m

    def multi_inv(self, values):
        """Invert every element of ``values`` using a single field inversion.

        Batch inversion (Montgomery's trick): build prefix products, invert
        the total once, then walk backwards peeling off one factor at a time.
        Elements that are 0 in the field map to 0.
        """
        m = self.modulus
        # Reduce first so that multiples of the modulus are recognised as zero
        # and skipped; otherwise they would zero the running product and make
        # every output 0.
        vals = [v % m for v in values]
        partials = [1]
        for v in vals:
            partials.append(partials[-1] * (v or 1) % m)
        inv = self.inv(partials[-1])
        outputs = [0] * len(vals)
        for i in range(len(vals) - 1, -1, -1):
            v = vals[i]
            # partials[i] is the product of vals[:i], and inv is the inverse
            # of the product of vals[:i + 1]; their product is 1 / vals[i].
            outputs[i] = partials[i] * inv % m if v else 0
            inv = inv * (v or 1) % m
        return outputs

    def div(self, x, y):
        """Return ``x / y`` in the field (0 when ``y`` is 0, since ``inv`` returns 0)."""
        return self.mul(x, self.inv(y))

    def eval_poly_at(self, p, x):
        """Evaluate polynomial ``p`` at ``x`` using Horner's rule."""
        m = self.modulus
        y = 0
        for coeff in reversed(p):
            y = (y * x + coeff) % m
        return y

    def add_polys(self, a, b):
        """Return the coefficient-wise sum ``a + b``; the shorter input is zero-padded."""
        m = self.modulus
        return [((a[i] if i < len(a) else 0) + (b[i] if i < len(b) else 0)) % m
                for i in range(max(len(a), len(b)))]

    def sub_polys(self, a, b):
        """Return the coefficient-wise difference ``a - b``; the shorter input is zero-padded."""
        m = self.modulus
        return [((a[i] if i < len(a) else 0) - (b[i] if i < len(b) else 0)) % m
                for i in range(max(len(a), len(b)))]

    def mul_by_const(self, a, c):
        """Return polynomial ``a`` scaled by the constant ``c``."""
        return [(x * c) % self.modulus for x in a]

    def mul_polys(self, a, b):
        """Return the product of polynomials ``a`` and ``b`` (schoolbook multiplication)."""
        o = [0] * (len(a) + len(b) - 1)
        for i, aval in enumerate(a):
            for j, bval in enumerate(b):
                o[i + j] += aval * bval
        return [x % self.modulus for x in o]

    def div_polys(self, a, b):
        """Return the quotient of the polynomial long division ``a / b``.

        The remainder is discarded and ``a`` is not modified. Requires
        ``len(a) >= len(b)``. The leading coefficient ``b[-1]`` is assumed
        nonzero; if it is 0, every quotient term silently becomes 0
        because ``inv(0) == 0``.
        """
        assert len(a) >= len(b)
        m = self.modulus
        a = list(a)  # working copy for the running remainder
        bpos = len(b) - 1
        quotient = []  # collected from the highest-degree term downward
        for apos in range(len(a) - 1, bpos - 1, -1):
            diff = apos - bpos
            quot = self.div(a[apos], b[bpos])
            quotient.append(quot)
            # Subtract quot * X**diff * b, reducing to keep numbers small.
            for i in range(bpos, -1, -1):
                a[diff + i] = (a[diff + i] - b[i] * quot) % m
        quotient.reverse()
        return quotient

    def zpoly(self, xs):
        """Return the monic polynomial prod(X - x for x in xs), which is 0 at every x in xs."""
        m = self.modulus
        root = [1]
        for x in xs:
            # Multiply by (X - x): shift every coefficient up one degree, then
            # subtract x times the (still unshifted) next-higher coefficient.
            root.insert(0, 0)
            for j in range(len(root) - 1):
                root[j] = (root[j] - root[j + 1] * x) % m
        return [c % m for c in root]

    def lagrange_interp(self, xs, ys):
        """Return the polynomial of degree < len(xs) passing through every (xs[i], ys[i]).

        Given n error-free points, this recovers the original polynomial of
        degree at most n - 1 (n coefficients). The xs are assumed pairwise
        distinct in the field.

        Method: for each point, build the polynomial that is 0 at every other
        x, scale it to equal y at its own x, and sum these polynomials.
        E.g. x = [1, 2, 3], y = [2, 5, 10] gives 1 + x**2, i.e. [1, 0, 1].
        """
        m = self.modulus
        # Master numerator (X - x1)(X - x2)...(X - xn).
        root = self.zpoly(xs)
        assert len(root) == len(ys) + 1
        # Per-point numerators: divide the master back by (X - xi).
        nums = [self.div_polys(root, [-x, 1]) for x in xs]
        # Denominators: each numerator evaluated at its own x.
        denoms = [self.eval_poly_at(num, x) for num, x in zip(nums, xs)]
        invdenoms = self.multi_inv(denoms)
        b = [0] * len(ys)
        for num, y, invdenom in zip(nums, ys, invdenoms):
            if not y:
                continue  # this point contributes nothing
            yslice = y * invdenom % m
            for j, coeff in enumerate(num):
                if coeff:
                    b[j] += coeff * yslice
        return [c % m for c in b]

    def eval_quartic(self, p, x):
        """Evaluate a 4-coefficient polynomial (degree <= 3) at ``x``.

        Unrolled special case of ``eval_poly_at``; only ``p[0]`` to ``p[3]``
        are read. (Despite the name, the polynomial is at most cubic.)
        """
        xsq = x * x % self.modulus
        xcb = xsq * x
        return (p[0] + p[1] * x + p[2] * xsq + p[3] * xcb) % self.modulus

    def _lagrange_basis_4(self, xs):
        """Return the four cubics prod_{j != i}(X - xs[j]) for i = 0..3, as coefficient lists."""
        m = self.modulus
        x0, x1, x2, x3 = xs[0], xs[1], xs[2], xs[3]
        x01, x02, x03 = x0 * x1, x0 * x2, x0 * x3
        x12, x13, x23 = x1 * x2, x1 * x3, x2 * x3
        # Expansion of (X - a)(X - b)(X - c):
        # [-abc, ab + ac + bc, -(a + b + c), 1]
        return (
            [-x12 * x3 % m, x12 + x13 + x23, -x1 - x2 - x3, 1],
            [-x02 * x3 % m, x02 + x03 + x23, -x0 - x2 - x3, 1],
            [-x01 * x3 % m, x01 + x03 + x13, -x0 - x1 - x3, 1],
            [-x01 * x2 % m, x01 + x02 + x12, -x0 - x1 - x2, 1],
        )

    def lagrange_interp_4(self, xs, ys):
        """Specialisation of ``lagrange_interp`` for exactly 4 points.

        Returns the 4 coefficients of the interpolating polynomial
        (degree <= 3).
        """
        m = self.modulus
        eq0, eq1, eq2, eq3 = self._lagrange_basis_4(xs)
        e0 = self.eval_poly_at(eq0, xs[0])
        e1 = self.eval_poly_at(eq1, xs[1])
        e2 = self.eval_poly_at(eq2, xs[2])
        e3 = self.eval_poly_at(eq3, xs[3])
        e01 = e0 * e1
        e23 = e2 * e3
        # Invert e0*e1*e2*e3 once; each 1/ei is recovered by multiplying
        # back the other three factors.
        invall = self.inv(e01 * e23)
        inv_y0 = ys[0] * invall * e1 * e23 % m
        inv_y1 = ys[1] * invall * e0 * e23 % m
        inv_y2 = ys[2] * invall * e01 * e3 % m
        inv_y3 = ys[3] * invall * e01 * e2 % m
        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1 + eq2[i] * inv_y2 + eq3[i] * inv_y3) % m
                for i in range(4)]

    def lagrange_interp_2(self, xs, ys):
        """Specialisation of ``lagrange_interp`` for exactly 2 points.

        Returns the 2 coefficients of the line (degree <= 1) through them.
        """
        m = self.modulus
        eq0 = [-xs[1] % m, 1]
        eq1 = [-xs[0] % m, 1]
        e0 = self.eval_poly_at(eq0, xs[0])
        e1 = self.eval_poly_at(eq1, xs[1])
        invall = self.inv(e0 * e1)
        inv_y0 = ys[0] * invall * e1
        inv_y1 = ys[1] * invall * e0
        return [(eq0[i] * inv_y0 + eq1[i] * inv_y1) % m for i in range(2)]

    def multi_interp_4(self, xsets, ysets):
        """Interpolate many 4-point sets, sharing a single batched inversion.

        Equivalent to ``[self.lagrange_interp_4(xs, ys) for xs, ys in
        zip(xsets, ysets)]``, but all denominators are inverted together via
        ``multi_inv``.
        """
        m = self.modulus
        data = []
        invtargets = []
        for xs, ys in zip(xsets, ysets):
            eqs = self._lagrange_basis_4(xs)
            data.append((ys, eqs))
            invtargets.extend(self.eval_quartic(eq, x) for eq, x in zip(eqs, xs))
        invalls = self.multi_inv(invtargets)
        o = []
        for k, (ys, (eq0, eq1, eq2, eq3)) in enumerate(data):
            inv0, inv1, inv2, inv3 = invalls[4 * k:4 * k + 4]
            inv_y0 = ys[0] * inv0 % m
            inv_y1 = ys[1] * inv1 % m
            inv_y2 = ys[2] * inv2 % m
            inv_y3 = ys[3] * inv3 % m
            o.append([(eq0[c] * inv_y0 + eq1[c] * inv_y1 + eq2[c] * inv_y2 + eq3[c] * inv_y3) % m
                      for c in range(4)])
        return o
You can’t perform that action at this time.
You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session.