# CS295/395: Secure Distributed Computation
## In-Class Exercise, 9/18/2020

In [None]:
# Imports and definitions
import numpy as np
from collections import defaultdict
_PRIME = 2 ** 7 - 1

def plusFE(p, a, b):
    """Return a + b as an element of GF(p), i.e. reduced modulo p."""
    total = a + b
    return total % p
    
def multFE(p, a, b):
    """Return a * b as an element of GF(p), i.e. reduced modulo p."""
    product = a * b
    return product % p

In [None]:
# This code is adapted from:
# https://en.wikipedia.org/wiki/Shamir%27s_Secret_Sharing

import random
import functools

# 12th Mersenne Prime
# (for this application we want a known prime number as close as
# possible to our security level; e.g.  desired security level of 128
# bits -- too large and all the ciphertext is large; too small and
# security is compromised)
_PRIME = 2 ** 13 - 1
# 13th Mersenne Prime is 2**521 - 1

_RINT = functools.partial(random.SystemRandom().randint, 0)

def _eval_at(poly, x, prime):
    """Evaluates polynomial (coefficient tuple) at x, used to generate a
    shamir pool in make_random_shares below.
    """
    accum = 0
    for coeff in reversed(poly):
        accum *= x
        accum += coeff
        accum %= prime
    return accum

def make_random_shares(minimum, shares, prime=_PRIME):
    """
    Draw a random secret and split it into `shares` Shamir share points.

    A polynomial with `minimum` random coefficients in [0, prime - 1] is
    drawn; its constant term is the secret.  Share i is the point
    (i, poly(i) mod prime) for i = 1..shares.

    Returns (secret, points).  Raises ValueError if minimum > shares.
    """
    if shares < minimum:
        raise ValueError("Pool secret would be irrecoverable.")
    coefficients = [_RINT(prime - 1) for _ in range(minimum)]
    xs = range(1, shares + 1)
    points = [(x, _eval_at(coefficients, x, prime)) for x in xs]
    secret = coefficients[0]
    return secret, points

def _extended_gcd(a, b):
    """
    Division in integers modulus p means finding the inverse of the
    denominator modulo p and then multiplying the numerator by this
    inverse (Note: inverse of A is B such that A*B % p == 1) this can
    be computed via extended Euclidean algorithm
    http://en.wikipedia.org/wiki/Modular_multiplicative_inverse#Computation
    """
    x = 0
    last_x = 1
    y = 1
    last_y = 0
    while b != 0:
        quot = a // b
        a, b = b, a % b
        x, last_x = last_x - quot * x, x
        y, last_y = last_y - quot * y, y
    return last_x, last_y

def _divmod(num, den, p):
    """Compute num / den modulo prime p.

    The return value r lies in range(p) and satisfies
    den * r % p == num % p, i.e. r is num times the modular inverse of den.

    Raises ZeroDivisionError if den has no inverse modulo p (for prime p,
    exactly when den % p == 0), instead of returning a meaningless value.
    """
    inv, _ = _extended_gcd(den, p)
    # den*inv + p*_ == gcd(den, p); den is invertible iff that gcd is 1.
    if den * inv % p != 1:
        raise ZeroDivisionError(
            "{} is not invertible modulo {}".format(den, p))
    return num * inv % p

def _lagrange_interpolate(x, x_s, y_s, p):
    """
    Evaluate at x the unique polynomial of degree at most k - 1 that passes
    through the k points (x_s[i], y_s[i]), working modulo the prime p.

    Lagrange form:
        f(x) = sum_i y_i * prod_{j != i} (x - x_j) / (x_i - x_j)
    Each term is first scaled to the common denominator den = prod_i dens[i],
    and a single division by den is applied at the end; modulo p this equals
    the plain Lagrange sum.

    Returns an int in range(p).
    """
    k = len(x_s)
    # The x-coordinates must be distinct as field elements; two values that
    # differ by a multiple of p would make some denominator 0 mod p.
    assert k == len({xi % p for xi in x_s}), "points must be distinct"

    def PI(vals):  # upper-case PI -- product of inputs
        accum = 1
        for v in vals:
            accum *= v
        return accum

    nums = []  # nums[i] = prod_{j != i} (x - x_j)
    dens = []  # dens[i] = prod_{j != i} (x_i - x_j)
    for i in range(k):
        others = list(x_s)
        cur = others.pop(i)
        nums.append(PI(x - o for o in others))
        dens.append(PI(cur - o for o in others))
    den = PI(dens)
    num = sum(_divmod(nums[i] * den * y_s[i] % p, dens[i], p)
              for i in range(k))
    return (_divmod(num, den, p) + p) % p

def recover_secret(shares, prime=_PRIME):
    """
    Reconstruct the secret -- the sharing polynomial's value at x = 0 --
    from a sequence of (x, y) share points.

    With at least as many points as the sharing threshold the result is the
    secret.  With fewer, interpolation still returns some field element, but
    it is generally not the secret; no error is raised in that case.

    Raises ValueError if fewer than two shares are supplied.
    """
    if len(shares) >= 2:
        x_coords, y_coords = zip(*shares)
        return _lagrange_interpolate(0, x_coords, y_coords, prime)
    raise ValueError("need at least two shares")


# Demo: draw a random secret, split it 3-out-of-6, and recover it from two
# different threshold-sized subsets of the shares.
secret, shares = make_random_shares(minimum=3, shares=6)

print('Secret:                                                     ', secret)
print('Shares:')
for point in shares:
    print('  ', point)

head, tail = shares[:3], shares[-3:]
print('Secret recovered from minimum subset of shares:             ',
      recover_secret(head))
print('Secret recovered from a different minimum subset of shares: ',
      recover_secret(tail))

## Question 1

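Write a function `share_input` that splits a specific secret input value `inp` into Shamir shares. It should return a list of `(x, y)` share points, in the same format as the second value returned by `make_random_shares`. Base your code on the definition of `make_random_shares` above.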

In [None]:
def share_input(inp, minimum, shares, prime=_PRIME):
    """
    Split the chosen secret `inp` into `shares` Shamir share points.

    Works like make_random_shares, except that the polynomial's constant
    term is `inp` rather than a random value.  The remaining minimum - 1
    coefficients are drawn uniformly from [0, prime - 1].

    Returns a list of (x, y) points for x = 1..shares.  Any `minimum` or
    more of them determine `inp` (recover_secret needs at least two points).

    Raises ValueError if minimum < 1, if minimum > shares (the secret could
    never be recovered), or if inp is not in range(prime) (it could only be
    recovered modulo prime).
    """
    if minimum < 1:
        raise ValueError("minimum must be at least 1")
    if minimum > shares:
        raise ValueError("Pool secret would be irrecoverable.")
    if not 0 <= inp < prime:
        raise ValueError("secret must lie in range(prime)")
    poly = [inp] + [_RINT(prime - 1) for _ in range(minimum - 1)]
    return [(i, _eval_at(poly, i, prime)) for i in range(1, shares + 1)]

# Share the value 5 with threshold 3 among 5 parties, then try recovery
# from all shares, from exactly 3, and from too few (2).
shares = share_input(5, minimum=3, shares=5)
print('Shares:', shares)
first_three, first_two = shares[:3], shares[:2]
print('Recovered secret with all shares:', recover_secret(shares))
print('Recovered secret with 3 shares:', recover_secret(first_three))
# Below the threshold: a value is still printed, but it is generally not 5.
print('Recovered secret with 2 shares:', recover_secret(first_two))

In [None]:
# TEST CASE for question 1
recovered = recover_secret(share_input(5, minimum=3, shares=6))
assert recovered == 5

## Question 2

Given the two sets of shares `shares1` and `shares2` below, write a function `add_shares` whose output is a set of shares of the **sum** of the two secrets, computed without first recovering either secret.

In [None]:
# Two independently shared secrets, 20 and 10, each 3-out-of-6.
shares1 = share_input(20, minimum=3, shares=6)
shares2 = share_input(10, minimum=3, shares=6)

def add_shares(shares1, shares2, prime=_PRIME):
    """
    Return shares of a + b, given share lists of secrets a and b.

    Shamir sharing is linear: if f and g are the sharing polynomials of a
    and b, then f + g has the same degree bound and constant term a + b.
    Each party can therefore add its own two y-values locally.  The result
    needs the same number of shares to recover as the inputs did.

    Both lists must hold (x, y) points with the same x-coordinates in the
    same order (e.g. both produced by share_input with the same `shares`).
    Returns a new list of (x, (y1 + y2) mod prime) points.

    Raises ValueError if the two share lists do not line up.
    """
    if len(shares1) != len(shares2):
        raise ValueError("share lists have different lengths")
    summed = []
    for (x1, y1), (x2, y2) in zip(shares1, shares2):
        if x1 != x2:
            raise ValueError("shares are for different x-coordinates")
        summed.append((x1, plusFE(prime, y1, y2)))
    return summed

added_shares = add_shares(shares1, shares2)
print(added_shares)
# Three shares suffice for the sum; the bare expression below is left last
# so the notebook displays the recovered value.
threshold_subset = added_shares[:3]
recover_secret(threshold_subset)

In [None]:
# TEST CASE for question 2
summed = add_shares(shares1, shares2)
assert recover_secret(summed[:3]) == 30

## Question 3

Given the two sets of shares `shares1` and `shares2` below, write a function `mult_shares` whose output is a set of shares of the **product** of the two secrets.

In [None]:
# Fresh 3-out-of-6 sharings of 20 and 10 for the multiplication exercise.
shares1 = share_input(20, minimum=3, shares=6)
shares2 = share_input(10, minimum=3, shares=6)

def mult_shares(shares1, shares2, prime=_PRIME):
    """
    Return shares of a * b, given share lists of secrets a and b.

    Multiplying y-values pointwise gives points on f * g, whose constant
    term is a * b.  The degree doubles, however: with threshold t each
    factor has degree t - 1, so the product has degree 2(t - 1).  Recovering
    it therefore needs 2t - 1 shares (5 shares when t = 3).  Practical MPC
    protocols (e.g. BGW) follow this step with a degree-reduction
    (re-sharing) round; that is not done here.

    Both lists must hold (x, y) points with the same x-coordinates in the
    same order.  Returns a new list of (x, (y1 * y2) mod prime) points.

    Raises ValueError if the two share lists do not line up.
    """
    if len(shares1) != len(shares2):
        raise ValueError("share lists have different lengths")
    product = []
    for (x1, y1), (x2, y2) in zip(shares1, shares2):
        if x1 != x2:
            raise ValueError("shares are for different x-coordinates")
        product.append((x1, multFE(prime, y1, y2)))
    return product

prod_shares = mult_shares(shares1, shares2)
expected = 20 * 10
print('True answer:', expected)
print('Shares:', prod_shares)

# Each factor was shared with minimum=3 (degree-2 polynomials).  If
# mult_shares multiplies shares pointwise, the product lies on a degree-4
# polynomial.  Only the 5- and 6-share recoveries are then expected to
# match; the 3- and 4-share values are generally wrong.
print('Recovered result, 3 shares:', recover_secret(prod_shares[:3]))
print('Recovered result, 4 shares:', recover_secret(prod_shares[:4]))
print('Recovered result, 5 shares:', recover_secret(prod_shares[:5]))
print('Recovered result, 6 shares:', recover_secret(prod_shares))

In [None]:
# TEST CASE for question 3
product_shares = mult_shares(shares1, shares2)
assert recover_secret(product_shares) == 200