# CS 3990/5990: Secure Distributed Computation
## In-Class Exercise, week of 9/18/2023

In [None]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import numpy as np
import galois
# Prime field of order 2**13 - 1 = 8191 (a Mersenne prime). Secrets and
# shares in the cells below are elements of this field.
GF = galois.GF(2 ** 13 - 1)

In [None]:
class Party:
    """A participant in a (simulated) multiparty computation protocol."""

    def __init__(self):
        """Initialize empty input/output slots and the per-round inbox.

        `received` maps a round identifier to the list of messages this
        party received in that round, in arrival order.
        """
        self.input = None
        self.output = None
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Simulate sending `msg` to party `other` during round `round`.

        Delivery is immediate: the message is appended to `other`'s inbox
        for that round.
        """
        # `round` shadows the builtin, but the name is kept so existing
        # keyword callers keep working.
        other.received[round].append(msg)

    def get_view(self):
        """Return this party's view as `(input, output, received)`.

        `received` is returned as a shallow plain-dict copy of the inbox, so
        looking up a round that had no messages raises KeyError instead of
        silently inserting an empty entry.
        """
        return (self.input, self.output, dict(self.received))

def shamir_share(v, t, n):
    """Generate `n` Shamir shares of secret `v` with reconstruction threshold `t`.

    A polynomial of degree at most t - 1 over GF is built with `v` as its
    constant term and t - 1 random coefficients; share i (for i = 1..n) is
    the point (i, poly(i)).

    NOTE(review): GF.Random presumably draws from NumPy's PRNG rather than
    a CSPRNG -- fine for class exercises, not for protecting real secrets.

    Args:
        v: the secret (a GF element, or an int representable in GF).
        t: threshold; any `t` shares determine the secret. Must satisfy
            1 <= t <= n.
        n: number of shares. Must be < GF.order, because x = 0 is reserved
            for the secret and x-coordinates must be distinct.

    Returns:
        A list of `n` (x, y) pairs of GF elements.

    Raises:
        ValueError: if `t` or `n` is out of range.
    """
    if not 1 <= t <= n:
        raise ValueError(f"need 1 <= t <= n, got t={t}, n={n}")
    if n >= GF.order:
        raise ValueError(
            f"n={n} shares need distinct non-zero x-coordinates in GF({GF.order})"
        )
    # galois.Poly takes coefficients highest degree first, so the secret is
    # placed last to become the constant term poly(0).
    coefficients = GF([GF.Random() for _ in range(t - 1)] + [v])
    poly = galois.Poly(coefficients)
    return [(GF(x), poly(GF(x))) for x in range(1, n + 1)]

def reconstruct(shares):
    """Recover a Shamir-shared secret by Lagrange interpolation.

    Interpolates the polynomial of degree < len(shares) through the given
    (x, y) points and evaluates it at x = 0. The result equals the secret
    only when at least `t` shares (the sharing threshold) are supplied.

    Args:
        shares: a non-empty sequence of (x, y) pairs with distinct x values.

    Returns:
        The interpolated value at x = 0, as a GF element.

    Raises:
        ValueError: if `shares` is empty.
    """
    if len(shares) == 0:
        raise ValueError("reconstruct needs at least one share")
    xs = GF([s[0] for s in shares])
    ys = GF([s[1] for s in shares])
    poly = galois.lagrange_poly(xs, ys)
    return poly(0)

## Question 1

Write code to generate shares of a secret $x$ in a $(t, n)$-secret sharing scheme using Shamir's technique, where $n = 5$ and $t = 2$.

In [None]:
def shr_2_5(v):
    """Return 5 Shamir shares of secret `v` with threshold 2.

    Expected result: a list of 5 (x, y) pairs, any 2 of which determine `v`
    (see the Question 1 and Question 2 test cells).
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# Example for question 1
# Expected: a list of 5 (x, y) shares of the secret 5.

shr_2_5(GF(5))

In [None]:
# TEST CASE for question 1
# Only the number of shares is checked here; the Question 2 test checks
# that the shares actually reconstruct the secret.

assert len(shr_2_5(GF(5))) == 5

## Question 2

Write a function to reconstruct the secret, using only two shares.

In [None]:
def reconstruct(s1, s2):
    """Return the secret shared by the two (x, y) shares `s1` and `s2`.

    With threshold 2 the sharing polynomial is a line, so two points fix it;
    the secret is its value at x = 0.
    """
    # NOTE: this redefinition shadows the setup-cell `reconstruct(shares)`
    # helper for later cells, until Question 6 redefines it again.
    # YOUR CODE HERE
    raise NotImplementedError()

# Example: the first two of the five shares should give back the secret 5.
shares = shr_2_5(GF(5))
reconstruct(shares[0], shares[1])

In [None]:
# TEST CASE
# Reconstruct from only the first two shares (threshold t = 2).
shares = shr_2_5(GF(5))
assert reconstruct(shares[0], shares[1]) == GF(5)

## Question 3

Why is a threshold secret sharing scheme more useful than the simpler additive secret sharing scheme we saw earlier?

YOUR ANSWER HERE

## Question 4

Write code to generate shares of a secret $x$ in a $(t, n)$-secret sharing scheme using Shamir's technique, for any $t$ and $n$.

In [None]:
def shamir_share(v, t, n):
    """Return `n` Shamir shares of secret `v` with reconstruction threshold `t`.

    Expected result: a list of `n` (x, y) pairs of field elements, any `t`
    of which determine `v` (see the Question 4 test cell).
    """
    # NOTE: this redefinition replaces the setup-cell `shamir_share` helper
    # for every later cell in the notebook.
    # YOUR CODE HERE
    raise NotImplementedError()

# Example: 6 shares of the secret 5 with threshold 3.
shamir_share(GF(5), 3, 6)

In [None]:
# Example for question 4: the general function also covers the (2, 5)
# scheme from question 1.

shamir_share(GF(5), 2, 5)

In [None]:
# TEST CASE
# Reconstruction uses the two-share `reconstruct(s1, s2)` from Question 2,
# so the second sharing uses threshold t = 2.
assert len(shamir_share(GF(5), 3, 6)) == 6
shares = shamir_share(GF(5), 2, 6)
assert reconstruct(shares[0], shares[1]) == GF(5)

## Question 5

Given the two sets of shares `shares1` and `shares2` below, write a function whose output is their sum (as a set of shares).

In [None]:
# Shares of 20 and of 5, both with threshold 2 over 6 parties.
shares1 = shamir_share(GF(20), 2, 6)
shares2 = shamir_share(GF(5), 2, 6)

def add_shares(shares1, shares2):
    """Return shares of the sum of the secrets shared by `shares1` and `shares2`.

    Both inputs are expected to be lists of (x, y) shares at the same
    x-coordinates (same `t` and `n`); the output is a list of the same length.
    """
    # YOUR CODE HERE
    raise NotImplementedError()

# Example: the summed shares should reconstruct 20 + 5 = 25.
added_shares = add_shares(shares1, shares2)
print(added_shares)
reconstruct(added_shares[0], added_shares[1])

In [None]:
# TEST CASE
# Any two summed shares (here the 1st and 3rd) should reconstruct 25.
added_shares = add_shares(shares1, shares2)
assert reconstruct(added_shares[0], added_shares[2]) == GF(25)

## Question 6

Write a function to reconstruct a secret from a set of at least $t$ shares. Use the `galois.lagrange_poly` function, which implements [Lagrange interpolation](https://en.wikipedia.org/wiki/Lagrange_polynomial).

In [None]:
def reconstruct(shares):
    """Return the secret from a list of at least `t` (x, y) Shamir shares.

    Intended approach: interpolate with `galois.lagrange_poly` and evaluate
    the resulting polynomial at x = 0.
    """
    # YOUR CODE HERE
    raise NotImplementedError()

# Reconstruct from all of the summed shares from Question 5 (expected: 25).
reconstruct(added_shares)

In [None]:
# TEST CASE
shares = shamir_share(GF(30), 5, 10)
assert reconstruct(shares) == GF(30)
assert reconstruct(shares[:5]) == GF(30)  # t shares are sufficient
# NOTE: the next check is probabilistic. Assuming shamir_share samples a
# polynomial of degree <= t - 1 = 4, interpolating 4 of its points recovers
# the secret exactly when the degree-4 coefficient is 0, so this assert can
# (rarely) fail by chance.
assert reconstruct(shares[:4]) != GF(30)  # t - 1 shares are not sufficient

## Question 7

Given the two sets of shares `shares1` and `shares2` below, write a function whose output is their product (as a set of shares).

In [None]:
# Shares of 20 and of 3, both with threshold 3 over 6 parties.
shares1 = shamir_share(GF(20), 3, 6)
shares2 = shamir_share(GF(3), 3, 6)

def mult_shares(shares1, shares2):
    """Return shares of the product of the secrets shared by `shares1` and `shares2`.

    The inputs are expected to be lists of (x, y) shares at matching
    x-coordinates; the result is a list of the same length. Its threshold is
    higher than the inputs' (see the Question 7 test cell).
    """
    # YOUR CODE HERE
    raise NotImplementedError()

# Example: the product shares should reconstruct 20 * 3 = 60.
product_shares = mult_shares(shares1, shares2)
print(product_shares)
reconstruct(product_shares)

In [None]:
# TEST CASE
product_shares = mult_shares(shares1, shares2)

# If mult_shares multiplies shares pointwise, the product polynomial has
# degree 2(t - 1) = 4 (t = 3), so 2t - 1 = 5 shares are needed. The second
# check can, rarely, fail by chance (the top coefficient may be 0).
assert reconstruct(product_shares) == GF(60)
assert reconstruct(product_shares[:4]) != GF(60)  # 4 (= 2t - 2) shares are no longer sufficient

## Question 8

Describe a protocol to multiply two input numbers. Before the protocol starts, the input numbers are secret-shared according to a $(t,n)$ Shamir secret sharing scheme, and each party receives one share of each number. Each party should output *one share of the product*, using a $(t, n)$ Shamir secret sharing scheme (i.e., the threshold of the output sharing must be the same as the threshold of the input sharings).

\begin{equation*}
\textbf{Functionality: Multiply Two Numbers}\\
\fbox{$\mathcal{F}(a, b) = a \cdot b$}
\end{equation*}




YOUR ANSWER HERE

## Question 9

Implement your protocol from the last question.

In [None]:
class MultTwoParty(Party):
    """Party in a two-round protocol multiplying two Shamir-shared inputs.

    Each party starts with one share of `a` and one share of `b`, and is
    expected to end with one (t, n) share of `a * b` in `self.output`.
    """
    def round1(self, parties, a_shr, b_shr, t):
        """First round: record this party's inputs and send round-1 messages.

        Args:
            parties: all participating parties (including this one).
            a_shr: this party's share of `a`.
            b_shr: this party's share of `b`.
            t: threshold of the input sharings; the output must use it too.
        """
        self.input = (a_shr, b_shr)
        self.parties = parties
        n = len(parties)
        # Multiplying two degree-(t-1) sharings gives degree 2t - 2, which
        # needs 2t - 1 points to interpolate; this check requires n >= 2t.
        assert t <= n/2
        
        # YOUR CODE HERE
        raise NotImplementedError()

    def round2(self):
        """Second round: combine the messages received in round 1.

        Expected to set `self.output` to this party's share of the product,
        in the (x, y) form accepted by `reconstruct`.
        """
        n = len(self.parties)
        
        # YOUR CODE HERE
        raise NotImplementedError()

In [None]:
NUM_PARTIES = 6
# (t, n)-Shamir scheme
n = NUM_PARTIES
t = 3

# Share the inputs a = 5 and b = 6; the protocol should output shares of 30.
shares1 = shamir_share(GF(5), t, n)
shares2 = shamir_share(GF(6), t, n)

parties = [MultTwoParty() for _ in range(NUM_PARTIES)]

# All parties finish round 1 before any party starts round 2.
for p, s1, s2 in zip(parties, shares1, shares2):
    p.round1(parties, s1, s2, t)
for p in parties:
    p.round2()
for p in parties:
    print(p.get_view())

output_shares = [p.output for p in parties]
print('Reconstruction, with all shares:', reconstruct(output_shares))
print(f'Reconstruction, with {t} shares:', reconstruct(output_shares[:t]))
print(f'Reconstruction, with {t - 1} shares:', reconstruct(output_shares[:t - 1]))

# The output sharing must keep threshold t: t shares suffice, t - 1 do not.
# (The last check is probabilistic and can, rarely, fail by chance.)
assert reconstruct(output_shares) == GF(30)
assert reconstruct(output_shares[:t]) == GF(30)
assert reconstruct(output_shares[:t - 1]) != GF(30)

# Circuits

- Overview of the BGW protocol: [Pragmatic MPC, Section 3.3](https://securecomputation.org/docs/pragmaticmpc.pdf)
- Vandermonde Matrices for polynomial evaluation: [Asharov & Lindell, 2011, Section 3.3, Definition 3.6](https://eprint.iacr.org/2011/136.pdf)
- Formal protocol description (GRR protocol): [Lindell & Nof, 2017, Appendix B.3 (Protocol B.3)](https://eprint.iacr.org/2017/816.pdf)

In [None]:
# Library for representing circuits (boolean or arithmetic)
from collections import namedtuple
from dataclasses import dataclass
from typing import Any

AddGate = namedtuple('AddGate', ['in1', 'in2'])
MultGate = namedtuple('MultGate', ['in1', 'in2'])


@dataclass
class Gate:
    """A single gate that writes `type(in1, in2)` to wire `out`.

    `type` is the gate name, e.g. 'ADD', 'XOR', 'AND' or 'INV'.
    `parse_circuit` stores in2 = -1 for the one-input INV gate.
    """
    type: str
    in1: int
    in2: int
    out: int


@dataclass
class Circuit:
    """A circuit given by its input wires, output wires, and gate list."""
    inputs: Any
    outputs: Any
    gates: Any


def print_circuit(c):
    """Print the inputs, outputs, and gates of circuit `c`, one gate per line."""
    print('inputs:', c.inputs)
    print('outputs:', c.outputs)
    print('gates:')
    for g in c.gates:
        print('  ', g)

## Question 10

Write a function `sum_circuit` that builds an arithmetic circuit for summing up a set of `n` inputs.

In [None]:
def sum_circuit(n):
    """Build a `Circuit` of 'ADD' gates computing the sum of `n` inputs.

    Must match the wire numbering expected by the test cell below (for
    n = 2: inputs [0, 1], one ADD gate writing wire 3, outputs [3]).
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# Example: circuit summing 4 inputs.
print_circuit(sum_circuit(4))

In [None]:
# TEST CASE
# For n = 2 the expected circuit reads wires 0 and 1 and writes wire 3.

assert sum_circuit(2) == \
  Circuit(inputs=[0, 1], outputs=[3], gates=[Gate(type='ADD', in1=0, in2=1, out=3)])

In [None]:
import urllib.request

adder_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/adder64.txt"
# The context manager closes the HTTP response; the timeout keeps a stalled
# server from hanging the notebook indefinitely.
with urllib.request.urlopen(adder_url, timeout=60) as response:
    adder_txt = response.read().decode("utf-8")

In [None]:
def parse_circuit(bristol_fashion_text):
    """Parse a circuit written in Bristol Fashion text format.

    Expected layout (blank / whitespace-only lines are ignored):

        <num_gates> <num_wires>
        <num_input_bundles> <bundle_size> ...
        <num_output_bundles> <bundle_size> ...
        <n_in> <n_out> <input wires...> <output wire> <TYPE>   (one per gate)

    Only two-input XOR/AND gates and one-input INV gates are accepted; INV
    gates are stored with in2 = -1.

    Input bundles occupy wires 0, 1, 2, ... in order; output bundles occupy
    the last wires of the circuit.

    Args:
        bristol_fashion_text: the whole circuit file as a string.

    Returns:
        A Circuit whose `inputs`/`outputs` are lists of wire-index lists (one
        per bundle) and whose `gates` are in file order.

    Raises:
        RuntimeError: if a gate line has an unsupported gate type.
    """
    # Strip *before* filtering, so whitespace-only lines (e.g. a lone '\r'
    # from a CRLF file) are dropped instead of being parsed as gates.
    stripped = [raw.strip() for raw in bristol_fashion_text.splitlines()]
    lines = [line for line in stripped if line]

    total_wires = int(lines[0].split()[1])
    input_bundle_sizes = [int(x) for x in lines[1].split()[1:]]
    output_bundle_sizes = [int(x) for x in lines[2].split()[1:]]

    gates = [_parse_bristol_gate(g_txt) for g_txt in lines[3:]]
    inputs = _consecutive_wire_bundles(0, input_bundle_sizes)
    first_output_wire = total_wires - sum(output_bundle_sizes)
    outputs = _consecutive_wire_bundles(first_output_wire, output_bundle_sizes)
    return Circuit(inputs, outputs, gates)


def _parse_bristol_gate(g_txt):
    """Parse one Bristol Fashion gate line (XOR, AND or INV) into a Gate."""
    sp = g_txt.split()
    gate_type = sp[-1]
    if gate_type in ('XOR', 'AND'):
        _, _, in1, in2, out, _ = sp
    elif gate_type == 'INV':
        _, _, in1, out, _ = sp
        in2 = -1  # unary gate: no second input wire
    else:
        raise RuntimeError(f'unknown gate type: {gate_type}')
    return Gate(gate_type, int(in1), int(in2), int(out))


def _consecutive_wire_bundles(start, sizes):
    """Split consecutive wire indices beginning at `start` into bundles of `sizes`."""
    bundles = []
    for size in sizes:
        bundles.append(list(range(start, start + size)))
        start += size
    return bundles

In [None]:
# Show the parsed 64-bit adder circuit.
print_circuit(parse_circuit(adder_txt))

## Question 11

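Implement a function `eval_circuit(inputs, circuit)` that evaluates a circuit in the clear. `inputs` contains one list of bits per input bundle of the circuit (in the order of `circuit.inputs`), and the function should return one list of bits per output bundle (in the order of `circuit.outputs`).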

In [None]:
def int_to_bitstring(i, n):
    """Return the bits of `i`, least-significant first, zero-padded to `n` bits.

    If `i` needs more than `n` bits, all of its bits are returned.
    """
    padded = format(i, 'b').zfill(n)
    return list(map(int, padded[::-1]))

def bitstring_to_int(bs):
    """Interpret `bs` as a least-significant-bit-first binary number."""
    return sum(bit * 2 ** pos for pos, bit in enumerate(bs))

def eval_circuit(inputs, circuit):
    """Evaluate `circuit` in the clear on `inputs`.

    Args:
        inputs: one list of bits per input bundle of `circuit` (same order
            and lengths as `circuit.inputs`).
        circuit: a `Circuit`, e.g. as returned by `parse_circuit`.

    Returns:
        One list of bits per output bundle, in the order of
        `circuit.outputs` (the test cell below decodes them with
        `bitstring_to_int`).
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# TEST CASE
# Example: 5 + 6 = 11
circuit = parse_circuit(adder_txt)
inputs = [int_to_bitstring(5, 64), int_to_bitstring(6, 64)]
outputs = eval_circuit(inputs, circuit)
assert [bitstring_to_int(b) for b in outputs] == [11]

In [None]:
sha256_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/sha256.txt"
# The context manager closes the HTTP response; the timeout keeps a stalled
# server from hanging the notebook indefinitely.
with urllib.request.urlopen(sha256_url, timeout=60) as response:
    sha256_txt = response.read().decode("utf-8")

sha256_circuit = parse_circuit(sha256_txt)

In [None]:
# Example: SHA256 hash of a bunch of 1s
# One all-ones bit list per input bundle, each as long as its bundle.
test_inputs = [[1] * len(bundle) for bundle in sha256_circuit.inputs]
outputs = eval_circuit(test_inputs, sha256_circuit)
bitstring_to_int(outputs[0])

## Question 12

Sketch the BGW protocol for evaluating an arithmetic or boolean circuit with $n$ parties.

YOUR ANSWER HERE