# CS 3990/5990: Secure Distributed Computation
## Homework 5

## Definitions

In [None]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import urllib.request
import galois
from nacl.public import PrivateKey, Box, SealedBox

# GF(2) = {0, 1}: addition is XOR and multiplication is AND, so additive
# secret shares of a bit are XOR shares.
GF_2 = galois.GF(2) # we work in the binary field this week!

# Library for circuits
from dataclasses import dataclass

@dataclass
class Gate:
    """A single Boolean gate: wire `out` receives `type` applied to `in1`, `in2`.

    Wire fields are indices into the circuit's global wire numbering.
    """

    type: str  # gate kind, e.g. 'XOR', 'AND' or 'INV'
    in1: int   # index of the first input wire
    in2: int   # index of the second input wire (-1 for single-input gates)
    out: int   # index of the wire that receives the result

@dataclass
class Circuit:
    """A parsed Boolean circuit.

    `inputs` and `outputs` are lists of wire bundles, each bundle being a list
    of wire indices. `gates` lists the circuit's gates in the order they
    appear in its source description.
    """

    # The original annotations used `any`, which is the builtin *function*,
    # not a type. The precise types below are quoted so they are never
    # evaluated at class-creation time: `list[...]` stays valid on
    # interpreters older than 3.9, and `Gate` need not be resolvable here.
    inputs: 'list[list[int]]'
    outputs: 'list[list[int]]'
    gates: 'list[Gate]'

class Party:
    """Base class for a simulated participant in a multiparty protocol.

    A party holds its private input, its eventual output, and an inbox of
    messages grouped by the round in which they were sent.
    """

    def __init__(self):
        """Start with no input or output and an empty per-round inbox."""
        self.received = defaultdict(list)  # round -> messages, in arrival order
        self.output = None
        self.input = None

    def send(self, other, round, msg):
        """Deliver `msg` into `other`'s inbox under round number `round`."""
        inbox = other.received[round]
        inbox.append(msg)

    def get_view(self):
        """Return ``(input, output, received)`` for this party.

        `received` is a plain-dict shallow copy of the inbox; the per-round
        message lists themselves are shared with the party.
        """
        inbox_snapshot = dict(self.received)
        return self.input, self.output, inbox_snapshot

# Parsing Circuits

In [None]:
import urllib.request


def fetch_text(url, timeout=60):
    """Download `url` and return its body decoded as UTF-8.

    The response is closed deterministically via `with`, and `timeout`
    (seconds) stops a stalled server from hanging the notebook forever.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return response.read().decode("utf-8")


adder_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/adder64.txt"
adder_txt = fetch_text(adder_url)
sha256_url = "https://homes.esat.kuleuven.be/~nsmart/MPC/sha256.txt"
sha256_txt = fetch_text(sha256_url)

In [None]:
# Parse a circuit from a Bristol-Fashion specification
def _wire_bundles(first_wire, sizes):
    """Split consecutive wires starting at `first_wire` into bundles of `sizes`."""
    bundles = []
    for size in sizes:
        bundles.append(list(range(first_wire, first_wire + size)))
        first_wire += size
    return bundles


def parse_circuit(bristol_fashion_text):
    """Parse a Bristol Fashion circuit description into a `Circuit`.

    Expected layout (blank lines are ignored):
        <num_gates> <num_wires>
        <num_input_bundles> <size_1> <size_2> ...
        <num_output_bundles> <size_1> ...
        then one gate per line, e.g. ``2 1 <in1> <in2> <out> XOR`` or
        ``1 1 <in> <out> INV``.

    Input bundles take consecutive wires from wire 0 upward; output bundles
    take the last wires of the circuit, in order.

    Returns:
        A `Circuit` whose `inputs`/`outputs` are lists of wire-index lists and
        whose `gates` are in file order. INV gates get ``in2 = -1``.

    Raises:
        RuntimeError: on a gate type other than XOR, AND or INV.
    """
    # Tokenize on arbitrary whitespace so CRLF line endings, trailing spaces
    # and whitespace-only lines cannot survive the blank-line filter and then
    # be misread as a gate with an empty type.
    rows = [line.split() for line in bristol_fashion_text.splitlines()]
    rows = [tokens for tokens in rows if tokens]
    header, input_spec, output_spec = rows[0], rows[1], rows[2]
    total_wires = int(header[1])

    gates = []
    for tokens in rows[3:]:
        gate_type = tokens[-1]
        if gate_type in ('XOR', 'AND'):
            _, _, in1, in2, out, _ = tokens
        elif gate_type == 'INV':
            _, _, in1, out, _ = tokens
            in2 = -1  # single-input gate: -1 marks the unused second input
        else:
            raise RuntimeError(f'unknown gate type: {gate_type}')
        gates.append(Gate(gate_type, int(in1), int(in2), int(out)))

    inputs = _wire_bundles(0, [int(size) for size in input_spec[1:]])
    output_sizes = [int(size) for size in output_spec[1:]]
    outputs = _wire_bundles(total_wires - sum(output_sizes), output_sizes)
    return Circuit(inputs, outputs, gates)

def int_to_bitstring(i, n):
    """Encode the non-negative integer `i` as exactly `n` bits, least significant first.

    Example: ``int_to_bitstring(6, 4) == [0, 1, 1, 0]``. Python ints and
    NumPy integer scalars are both accepted.

    Raises:
        TypeError: if `i` is not an integer.
        ValueError: if `i` is negative or needs more than `n` bits.
    """
    import operator

    value = operator.index(i)  # normalizes NumPy ints; rejects floats
    # A longer list could not match an n-wire input bundle, so refuse it
    # instead of silently returning extra bits.
    if value < 0 or value.bit_length() > n:
        raise ValueError(f'{i} does not fit in {n} unsigned bits')
    return [(value >> k) & 1 for k in range(n)]

def bitstring_to_int(bs):
    """Decode little-endian bits (index 0 is least significant) into an integer."""
    total = 0
    for position, bit in enumerate(bs):
        total += int(bit) * 2 ** position
    return total

In [None]:
# Parse both downloaded Bristol Fashion descriptions into Circuit objects.
adder = parse_circuit(adder_txt)
sha256 = parse_circuit(sha256_txt)

In [None]:
# Value of an AND gate, computed from both parties' additive input shares
def S(s1_i, s1_j, s2_i, s2_j):
    """Return the AND gate's true output from all four input shares.

    Input i reconstructs as s1_i + s2_i and input j as s1_j + s2_j; over
    GF(2) their product is the AND of the two bits.
    """
    wire_i = s1_i + s2_i
    wire_j = s1_j + s2_j
    return wire_i * wire_j

# Truth table of P2's possible output shares for one AND gate
def T_G(r, s1_i, s1_j):
    """Build the 4-entry table P2 chooses from for one AND gate.

    Entry k corresponds to P2 holding input shares (s2_i, s2_j) with
    k = 2*s2_i + s2_j, i.e. the pairs (0,0), (0,1), (1,0), (1,1) in that
    order. Each entry is r + S(s1_i, s1_j, s2_i, s2_j), so the matching
    counterpart share for the gate output is `r`.
    """
    p2_share_pairs = GF_2([(0, 0), (0, 1), (1, 0), (1, 1)])
    return [r + S(s1_i, s1_j, a, b) for a, b in p2_share_pairs]

## Question 1

Implement the GMW protocol.

Reference the following exercise questions:
- The definition of 1-out-of-4 Oblivious Transfer (OT) from the 10/02/2023 exercise
- The definition of the BGW protocol from the 9/25/2023 exercise
- The definition of circuit evaluation from the 9/25/2023 exercise

In [None]:
import secrets


def _gmw_message(party, round_num, tag):
    """Return the payload of the first ``(tag, payload)`` message received in `round_num`.

    Returns None if there is no such message. Uses `.get` so that probing an
    empty round does not insert a key into the `received` defaultdict, which
    would otherwise show up in the party's view.
    """
    for msg_tag, payload in party.received.get(round_num, ()):
        if msg_tag == tag:
            return payload
    return None


def _gmw_share_input(party, wires, bits):
    """XOR-share the party's own input `bits` across its input `wires`.

    For each bit x a fresh random mask m is drawn. The party keeps x ^ m on
    the wire, and the masks are returned to be sent to the peer as its shares.

    Raises:
        ValueError: if the number of bits does not match the number of wires.
    """
    if len(bits) != len(wires):
        raise ValueError(f'expected {len(wires)} input bits, got {len(bits)}')
    masks = [secrets.randbits(1) for _ in wires]
    for wire, bit, mask in zip(wires, bits, masks):
        party.wire_vals[wire] = int(bit) ^ mask
    return masks


def _gmw_store_peer_shares(party, wires, shares):
    """Record the shares of the peer's input `wires` that the peer sent us."""
    if len(shares) != len(wires):
        raise ValueError(f'expected {len(wires)} peer shares, got {len(shares)}')
    for wire, share in zip(wires, shares):
        party.wire_vals[wire] = int(share)


def _gmw_eval_local(party, gate, adds_inv_constant):
    """Evaluate a gate that needs no communication (XOR or INV) on local shares.

    Shares are 0/1 ints, so GF(2) addition is XOR. For INV (NOT x = x + 1)
    only the party with `adds_inv_constant` set adds the public constant 1.
    """
    vals = party.wire_vals
    if gate.type == 'XOR':
        vals[gate.out] = vals[gate.in1] ^ vals[gate.in2]
    elif gate.type == 'INV':
        vals[gate.out] = (vals[gate.in1] ^ 1) if adds_inv_constant else vals[gate.in1]
    else:
        raise RuntimeError(f'gate type {gate.type!r} cannot be evaluated locally')


def _gmw_output_shares(party, circuit):
    """This party's shares of every output wire, bundle by bundle."""
    return [party.wire_vals[wire] for bundle in circuit.outputs for wire in bundle]


def _gmw_reconstruct(party, circuit, peer_shares):
    """XOR our output shares with the peer's to recover the plaintext output bits."""
    mine = _gmw_output_shares(party, circuit)
    return [own ^ int(theirs) for own, theirs in zip(mine, peer_shares)]


class GMW_P1(Party):
    """Party 1 of two-party GMW over GF(2); the OT sender for AND gates.

    Wire shares are 0/1 ints and messages are ``(tag, payload)`` tuples.
    `run_gmw` steps P1 before P2 in each round, so P1 reads what P2 sent in
    the previous round.

    Phases:
      1. (round 1) Share P1's input bits and send the masks to P2.
      2. (round 2) Store P2's masks for P2's input wires.
      3. Evaluate gates in order; XOR and INV are local. For an AND gate, P2
         sent four public keys in the previous round. P1 draws a random bit r
         as its output share and returns the `T_G` table, each entry
         encrypted under the matching key. This is a 1-out-of-4 OT: P2 can
         open only the entry for its own shares. At most one AND gate is
         handled per round, because the keys for the next one only arrive
         after P2 moves.
      4. Once P2's output shares arrive, send P1's own shares, reconstruct
         the output and finish.
    """

    def __init__(self):
        super().__init__()
        self.is_done = False
        self.phase = 1               # see the class docstring
        self.wire_vals = {-1: None}  # wire -> share; -1 is INV's unused in2
        self.next_gate = 0           # index of the next gate to evaluate

    def roundn(self, round_num, circuit, inputs, p2):
        """Run P1's part of round `round_num`."""
        if self.phase == 1:
            self.input = inputs
            masks = _gmw_share_input(self, circuit.inputs[0], inputs)
            self.send(p2, round_num, ('input', masks))
            self.phase = 2
            return  # P2 only sends its input shares after P1 moves

        if self.phase == 2:
            p2_masks = _gmw_message(self, round_num - 1, 'input')
            if p2_masks is None:
                raise RuntimeError(f'P1: no input shares from P2 in round {round_num - 1}')
            _gmw_store_peer_shares(self, circuit.inputs[1], p2_masks)
            self.phase = 3

        if self.phase == 3:
            answered_and = False
            while self.next_gate < len(circuit.gates):
                gate = circuit.gates[self.next_gate]
                if gate.type == 'AND':
                    if answered_and:
                        return  # P2 sends this gate's OT keys later this round
                    self._answer_ot(round_num, gate, p2)
                    answered_and = True
                else:
                    _gmw_eval_local(self, gate, adds_inv_constant=True)
                self.next_gate += 1
            self.phase = 4

        if self.phase == 4:
            p2_out = _gmw_message(self, round_num - 1, 'output')
            if p2_out is None:
                return  # P2 publishes its output shares after its last gate
            self.send(p2, round_num, ('output', _gmw_output_shares(self, circuit)))
            self.output = _gmw_reconstruct(self, circuit, p2_out)
            self.is_done = True

    def _answer_ot(self, round_num, gate, p2):
        """Act as OT sender for AND `gate`, using the keys P2 sent last round."""
        request = _gmw_message(self, round_num - 1, 'ot_keys')
        if request is None:
            raise RuntimeError(f'P1: no OT keys for gate {self.next_gate} in round {round_num - 1}')
        gate_idx, public_keys = request
        if gate_idx != self.next_gate:
            raise RuntimeError(f'P1: OT keys for gate {gate_idx}, expected gate {self.next_gate}')
        r = secrets.randbits(1)  # P1's share of the gate output
        table = T_G(GF_2(r), GF_2(self.wire_vals[gate.in1]), GF_2(self.wire_vals[gate.in2]))
        ciphertexts = [SealedBox(pk).encrypt(bytes([int(entry)]))
                       for pk, entry in zip(public_keys, table)]
        self.send(p2, round_num, ('ot_table', ciphertexts))
        self.wire_vals[gate.out] = r


class GMW_P2(Party):
    """Party 2 of two-party GMW over GF(2); the OT receiver for AND gates.

    P2 moves after P1 within each round, so it reads what P1 sent in the
    current round.

    Phases:
      1-2. (round 1) Share P2's input bits, send the masks to P1, and store
           the masks P1 just sent.
      3. Evaluate gates in order. At an AND gate, P2's OT choice is
         2*s2_i + s2_j from its input shares. It sends four fresh public keys,
         keeping only the secret key at the chosen index, and waits. The next
         round it decrypts that one entry of P1's table to get its output
         share.
      4. After the last gate, send P2's output shares, then reconstruct the
         output once P1's shares arrive.

    Note: generating all four key pairs and discarding three secret keys is
    only secure against a semi-honest P2.
    """

    def __init__(self):
        super().__init__()
        self.is_done = False
        self.phase = 1               # see the class docstring
        self.wire_vals = {-1: None}  # wire -> share; -1 is INV's unused in2
        self.next_gate = 0           # index of the next gate to evaluate
        self.pending_ot = None       # (choice, private key) awaiting P1's table

    def roundn(self, round_num, circuit, inputs, p1):
        """Run P2's part of round `round_num`."""
        if self.phase == 1:
            self.input = inputs
            masks = _gmw_share_input(self, circuit.inputs[1], inputs)
            self.send(p1, round_num, ('input', masks))
            self.phase = 2

        if self.phase == 2:
            p1_masks = _gmw_message(self, round_num, 'input')  # sent by P1 this round
            if p1_masks is None:
                raise RuntimeError(f'P2: no input shares from P1 in round {round_num}')
            _gmw_store_peer_shares(self, circuit.inputs[0], p1_masks)
            self.phase = 3

        if self.phase == 3:
            if self.pending_ot is not None:
                self._receive_ot(round_num, circuit.gates[self.next_gate])
                self.next_gate += 1
            while self.next_gate < len(circuit.gates):
                gate = circuit.gates[self.next_gate]
                if gate.type == 'AND':
                    self._request_ot(round_num, gate, p1)
                    return  # P1 answers with the encrypted table next round
                _gmw_eval_local(self, gate, adds_inv_constant=False)
                self.next_gate += 1
            self.send(p1, round_num, ('output', _gmw_output_shares(self, circuit)))
            self.phase = 4

        if self.phase == 4:
            p1_out = _gmw_message(self, round_num, 'output')
            if p1_out is None:
                return  # P1 replies with its output shares next round
            self.output = _gmw_reconstruct(self, circuit, p1_out)
            self.is_done = True

    def _request_ot(self, round_num, gate, p1):
        """Start a 1-out-of-4 OT whose choice index encodes P2's input shares."""
        choice = 2 * self.wire_vals[gate.in1] + self.wire_vals[gate.in2]
        secret_keys = [PrivateKey.generate() for _ in range(4)]
        self.pending_ot = (choice, secret_keys[choice])  # other secret keys are dropped
        public_keys = [sk.public_key for sk in secret_keys]
        self.send(p1, round_num, ('ot_keys', (self.next_gate, public_keys)))

    def _receive_ot(self, round_num, gate):
        """Decrypt the chosen entry of P1's table: P2's share of `gate`'s output."""
        ciphertexts = _gmw_message(self, round_num, 'ot_table')
        if ciphertexts is None:
            raise RuntimeError(f'P2: no OT table for gate {self.next_gate} in round {round_num}')
        choice, secret_key = self.pending_ot
        self.wire_vals[gate.out] = SealedBox(secret_key).decrypt(ciphertexts[choice])[0]
        self.pending_ot = None

In [None]:
# Driver function for the protocol
def run_gmw(circuit, p1_input, p1_bitwidth, p2_input, p2_bitwidth):
    """Simulate a two-party GMW execution of `circuit` and return both outputs.

    Each party's integer input is encoded least-significant-bit first at its
    given bit width. The parties are then stepped one round at a time, P1
    before P2, until either party reports `is_done`.

    Returns:
        ``(P1's output, P2's output)``, each decoded with `bitstring_to_int`.
    """
    p1_bits = int_to_bitstring(p1_input, p1_bitwidth)
    p2_bits = int_to_bitstring(p2_input, p2_bitwidth)

    party1, party2 = GMW_P1(), GMW_P2()

    current_round = 1
    while not (party1.is_done or party2.is_done):
        party1.roundn(current_round, circuit, p1_bits, party2)
        party2.roundn(current_round, circuit, p2_bits, party1)
        current_round += 1

    return bitstring_to_int(party1.output), bitstring_to_int(party2.output)

In [None]:
## ADDER TEST CASE
# Ten random trials: both parties must reconstruct the true sum. Inputs stay
# below 1000, so the sum fits comfortably in the 64-bit width used here.
for _ in range(10):
    n1 = np.random.randint(0, 1000)
    n2 = np.random.randint(0, 1000)
    
    o1, o2 = run_gmw(adder, n1, 64, n2, 64)
    assert o1 == o2 == n1 + n2, f'Mismatch! Inputs {n1}, {n2}, outputs {o1}, {o2}'

In [None]:
### SHA256 TEST CASE
### Warning: takes about a minute to run

# P1 supplies the value 1 as a 512-bit input and P2 the value 2 as a 256-bit
# input; both parties must reconstruct the same reference output value.
o1, o2 = run_gmw(sha256, 1, 512, 2, 256)
assert o1 == o2 == 62635937818952219496566001010706647480343244544051980721954351996715678910351