# CS295/395: Secure Distributed Computation
## Homework 6

## Definitions

In [None]:
# Imports and definitions
import numpy as np
from collections import defaultdict
from collections import namedtuple
import urllib.request

# Default field size handed to every Party: 2**13 - 1 = 8191.
_PRIME = 2 ** 13 - 1

# Location of the Shamir secret-sharing helper source that is executed below.
# NOTE(review): `share_shamir` and `reconstruct_shamir`, used later in this file,
# are presumably defined by this download -- confirm against the fetched code.
shamir_lib_url = "https://raw.githubusercontent.com/jnear/cs295-secure-computation/master/utils/shamir.py"

### DANGER: the next line downloads Python source from the URL above and executes
### it in this namespace without any integrity check. Make sure the URL above is
### correct, and that the code it serves is trustworthy, before running this cell.
exec(urllib.request.urlopen(shamir_lib_url).read())

class Party:
    """One participant in a simulated multiparty computation protocol.

    Each party carries its field size, an `input` and `output` slot (both
    initially None), and `received`, a mapping from round identifier to the
    list of messages delivered to this party in that round.
    """

    def __init__(self, field_size=_PRIME):
        """Record the field size and start with no input, output, or messages."""
        self.field_size = field_size
        self.input = self.output = None
        # A missing round transparently starts out as an empty message list.
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Simulate sending `msg` to party `other` during round `round`.

        The message is appended to `other`'s list of messages for that round.
        """
        inbox = other.received[round]
        inbox.append(msg)

    def get_view(self):
        """Return this party's view as the tuple `(input, output, received)`.

        `received` is returned as a plain dict copy of the round -> messages
        mapping (the per-round lists themselves are not copied).
        """
        view = self.input, self.output, dict(self.received)
        return view

In [None]:
# Circuit gate records: each gate names its two input wires as `in1` and `in2`.
AddGate = namedtuple('AddGate', 'in1 in2')
MultGate = namedtuple('MultGate', 'in1 in2')

# Index of the most recently allocated wire; 0 means none allocated yet.
largest_wire = 0


def new_wire():
    """Allocate and return a fresh wire name of the form 'w<k>'.

    Bumps the module-level `largest_wire` counter by one and returns the name
    built from the new counter value, so successive calls yield 'w1', 'w2', ...
    """
    global largest_wire
    largest_wire = largest_wire + 1
    wire_name = f'w{largest_wire}'
    return wire_name

## Question 1

Implement the BGW protocol for MPC. Your solution should handle both addition and multiplication gates.

In [None]:
class BGWParty(Party):
    """A party taking part in the BGW protocol (homework skeleton)."""

    def receive_inputs(self, input_wire_values, circuit, eval_order, t, n):
        """Store this party's starting state for a protocol run.

        Saves the wire values (as passed by `run_bgw`: a dict from input wire
        name to this party's share), the circuit, the gate evaluation order and
        the `t` / `n` sharing parameters on the instance, and marks the party
        as not yet done.
        """
        self.wire_values, self.circuit, self.eval_order = (
            input_wire_values, circuit, eval_order)
        # The driver polls this flag after every round to know when to stop.
        self.is_done = False
        self.t, self.n = t, n

    def round_n(self, round_num, parties):
        """Execute round `round_num` of BGW; see Section 3.3 of 'Pragmatic MPC'.

        `parties` is the list of all participating parties. This method is
        left for the student to implement and currently always raises
        NotImplementedError.
        """
        # YOUR CODE HERE
        raise NotImplementedError()

In [None]:
def run_bgw(inputs, circuit, eval_order, output_wires, *, n=6, t=3):
    """Run the BGW protocol on `circuit` and return the reconstructed outputs.

    Feel free to change this driver function if it helps you to do so.

    Args:
        inputs: dict mapping each input wire name to its secret value.
        circuit: dict mapping each gate's output wire name to its gate.
        eval_order: list of wire names giving the gate evaluation order.
        output_wires: wire names whose values should be reconstructed.
        n: number of parties to simulate (also passed to `share_shamir`).
            Defaults to 6, the value that used to be hard-coded.
        t: threshold parameter forwarded to `share_shamir` and to every
            party. Defaults to 3, the value that used to be hard-coded.
            NOTE(review): its exact meaning is defined by the downloaded
            Shamir library -- confirm there.

    Returns:
        A list with one reconstructed value per entry of `output_wires`, in
        the same order.
    """
    # calculate the input shares (share_shamir is expected to yield one share per party)
    input_shares = {w: share_shamir(t, n, x) for w, x in inputs.items()}
    parties = [BGWParty(_PRIME) for _ in range(n)]

    # split the shares up for the parties: transpose wire -> shares into one
    # {wire: share} dict per party
    wires = list(input_shares)
    if wires:
        party_shares = [dict(zip(wires, vals))
                        for vals in zip(*(input_shares[w] for w in wires))]
    else:
        # With no inputs the transpose above would be empty, so no party would
        # ever get receive_inputs() called; give each one an empty share dict.
        party_shares = [{} for _ in parties]

    # kick off each party with its inputs and copies of the circuit and evaluation plan
    for p, s in zip(parties, party_shares):
        p.receive_inputs(s, circuit.copy(), eval_order.copy(), t, n)

    done = False
    round_num = 1

    # keep evaluating until one of the parties is finished; every party still
    # gets to complete the round in which that happens
    while not done:
        for p in parties:
            p.round_n(round_num, parties)
            if p.is_done:
                done = True

        round_num += 1

    # for each output wire, get the shares from the parties for that wire
    output_shares = [[p.wire_values[w] for p in parties] for w in output_wires]

    return [reconstruct_shamir(shares) for shares in output_shares]

In [None]:
# TEST CASE: a simple circuit
# Computes w3 = (x + y) * x + (x + y); for x = 5, y = 6 that is 11 * 5 + 11 = 66.
inputs = {'x': 5, 'y': 6}

# Gates are inserted in dependency order, so insertion order doubles as the plan.
circuit = {}
circuit['w1'] = AddGate('x', 'y')
circuit['w2'] = MultGate('w1', 'x')
circuit['w3'] = AddGate('w2', 'w1')

eval_order = list(circuit)

result = run_bgw(inputs, circuit, eval_order, ['w3'])
print('Result:', result)
assert result == [66]

## Additional Test Case

In [None]:
def build_large_circuit(wire_names):
    """Build a chain of alternating add/multiply gates over `wire_names`.

    Starting from the first wire as the running value, each further wire `x`
    contributes two gates on freshly allocated wires: an AddGate of the running
    value and `x`, then a MultGate of `x` and that sum, whose output becomes
    the new running value.

    Returns a dict mapping each new wire name to its gate, in creation order.
    Unpacking raises ValueError if `wire_names` is empty.
    """
    head, *tail = wire_names
    gates = {}

    running = head

    for input_wire in tail:
        sum_wire = new_wire()
        gates[sum_wire] = AddGate(in1=running, in2=input_wire)
        product_wire = new_wire()
        gates[product_wire] = MultGate(in1=input_wire, in2=sum_wire)
        running = product_wire

    return gates

In [None]:
# TEST CASE: a large circuit
circuit_size = 100

# Input wires x0 .. x99, where wire x<k> carries the value k.
input_wires = [f'x{n}' for n in range(circuit_size)]
inputs = dict(zip(input_wires, range(circuit_size)))

circuit = build_large_circuit(input_wires)

# Gates were created in dependency order, so dict order is a valid plan.
eval_order = list(circuit)

# The last gate created is the circuit's single output.
result = run_bgw(inputs, circuit, eval_order, [eval_order[-1]])
print('Result:', result)
assert result == [2653]