# CS295/395: Secure Distributed Computation
## In-Class Exercise, week of 9/26/2022

In [None]:
# For later this week:
from nacl.public import PrivateKey, Box, SealedBox

# PyNaCl is a library for (traditional) encryption
# It is easiest to install using: `conda install pynacl`
# It can also be installed using: `pip install pynacl`
# but the conda version is more likely to work cleanly.
# See documentation here: https://pynacl.readthedocs.io/en/latest/

In [None]:
# Imports and definitions
import numpy as np
from collections import defaultdict
import numpy as np
import galois
# Prime field used for all secret-sharing arithmetic in this notebook.
# 8191 == 2**13 - 1 is a Mersenne prime, so GF(8191) is a prime field.
GF = galois.GF(8191)

# Library for circuits
from dataclasses import dataclass

@dataclass()
class Gate:
    """A single gate in a circuit.

    `type` names the gate's operation; `in1` and `in2` identify its two
    operands and `out` identifies where its result goes (presumably wire
    indices -- TODO confirm against the circuit builder).
    """
    type: str  # operation name
    in1: int   # first operand
    in2: int   # second operand
    out: int   # destination of the result

@dataclass
class Circuit:
    """A circuit description: its inputs, its outputs, and its list of gates.

    The fields are deliberately loosely typed (`object`, i.e. "anything")
    because their concrete layout is chosen by whoever builds the circuit.
    The original annotation `any` referred to the builtin function `any()`,
    which is not a type.
    """
    inputs: object
    outputs: object
    gates: object
        
def print_circuit(c):
    """Print circuit `c` to stdout: its inputs, its outputs, then each gate on its own line."""
    print('inputs:', c.inputs)
    print('outputs:', c.outputs)
    print('gates:')
    gate_indent = '  '
    for each_gate in c.gates:
        # print() adds a separating space, so gate lines get a 3-space indent.
        print(gate_indent, each_gate)

## Party Class and Shamir sharing

In [None]:
class Party:
    """A participant in a multiparty computation protocol."""
    def __init__(self):
        """Initialize empty input/output slots and the per-round message inbox."""
        self.input = None
        self.output = None
        # Maps round number -> messages received in that round, in arrival
        # order. A defaultdict lets senders append without any setup.
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Simulate sending a message `msg` to another party `other` during round `round`.

        Delivery is immediate: `msg` is appended to `other`'s inbox for that round.
        """
        other.received[round].append(msg)

    def get_view(self):
        """Return the view of this party as a tuple (input, output, received messages).

        The received messages are returned as a plain dict of fresh lists, so
        messages delivered after this call do not alter a view already taken.
        """
        received_snapshot = {rnd: list(msgs) for rnd, msgs in self.received.items()}
        return (self.input, self.output, received_snapshot)

def shamir_share(v, t, n):
    """Generate `n` Shamir shares of secret `v` with threshold `t`.

    A random polynomial of degree t-1 whose constant term is `v` is sampled
    over GF and evaluated at x = 1, ..., n.

    Args:
        v: the secret value to share.
        t: number of shares needed to reconstruct. Values of t <= 1 yield a
            constant polynomial, so every share then equals the secret.
        n: number of shares to produce.

    Returns:
        A list of n (x, y) pairs of GF elements, with x running from 1 to n.

    Raises:
        ValueError: if t > n, since the shares could never be reconstructed.
    """
    if t > n:
        raise ValueError(f"threshold t={t} exceeds the number of shares n={n}")
    # galois.Poly takes coefficients highest-degree first, so putting `v`
    # last makes it the constant term, i.e. poly(0) == v.
    coefficients = GF([GF.Random() for _ in range(t - 1)] + [v])
    poly = galois.Poly(coefficients)
    return [(x, poly(x)) for x in map(GF, range(1, n + 1))]

def reconstruct(shares):
    """Reconstruct the secret from at least t Shamir shares.

    `shares` is a sequence of (x, y) pairs such as `shamir_share` returns.
    The unique interpolating polynomial through them is evaluated at x = 0,
    where its value is the shared secret.
    """
    x_points = GF([share[0] for share in shares])
    y_points = GF([share[1] for share in shares])
    interpolant = galois.lagrange_poly(x_points, y_points)
    return interpolant(0)

## Question 1

Implement a function `sum_sq_circuit(n)` that returns a circuit computing both the sum of `n` input numbers and the square of that sum.

In [None]:
def sum_sq_circuit(n):
    """Build a circuit over `n` inputs whose outputs are their sum and the square of that sum.

    Exercise stub. The test case below evaluates `sum_sq_circuit(6)` on
    inputs 0..5 and expects outputs [GF(15), GF(225)]. The result is passed
    to `print_circuit`, so it presumably should be a `Circuit` -- TODO confirm.
    """
    # One single-element list per input index: [[0], [1], ..., [n-1]].
    # Presumably these are the input wire ids for each party -- TODO confirm.
    inputs = [[i] for i in range(n)]

    # YOUR CODE HERE
    raise NotImplementedError()

print_circuit(sum_sq_circuit(6))

## Question 2

Implement a function `eval_circuit` for evaluating circuits.

In [None]:
def eval_circuit(inputs, circuit):
    """Evaluate `circuit` in the clear on `inputs` and return its output values.

    Exercise stub. In the test below, `inputs` is a list of per-party value
    lists ([[0], [1], ..., [5]]), and the result is compared against a list
    of GF elements, one per circuit output.
    """
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# TEST CASE
# Example: the inputs 0..5 (range(6)) sum to 15, and 15**2 = 225
circuit = sum_sq_circuit(6)
inputs = [[i] for i in range(6)]
outputs = eval_circuit(inputs, circuit)
assert outputs == [GF(15), GF(225)]

## Question 3

Sketch the BGW protocol for evaluating an arithmetic or boolean circuit with $n$ parties.

YOUR ANSWER HERE

## Question 4

Implement the BGW protocol.

In [None]:
class BGWParty(Party):
    """A participant in the BGW protocol (exercise skeleton).

    The driver calls `round1`, then `round2`, then `roundn` with round
    numbers 3, 4, ... until the first party's `is_done` becomes true, and
    finally reads each party's `output`.
    """
    def round1(self, parties, circuit, my_inputs):
        """Round 1: record the protocol state and process this party's inputs.

        Args:
            parties: every party taking part in the protocol (including self).
            circuit: the circuit to evaluate jointly.
            my_inputs: this party's private input values.
        """
        self.parties = parties
        # The driver polls this flag to decide when to stop running rounds.
        self.is_done = False
        self.circuit = circuit
        n = len(parties)
        # Threshold parameter: floor(n / 2).
        t = int(n/2)
        
        # YOUR CODE HERE
        raise NotImplementedError()

    def round2(self, my_id):
        """Round 2 of the protocol.

        Args:
            my_id: this party's identifier -- presumably its index in
                `self.parties`; TODO confirm.
        """
        # Presumably maps each wire to this party's share of its value -- TODO confirm.
        self.wire_vals = {}

        # YOUR CODE HERE
        raise NotImplementedError()

    def roundn(self, round_num):
        """Run protocol round `round_num` (3, 4, ...); set `self.is_done` once finished.

        Args:
            round_num: the current round number.
        """
        n = len(self.parties)
        # Threshold parameter: floor(n / 2), same as in round1.
        t = int(n/2)
        
        # YOUR CODE HERE
        raise NotImplementedError()

In [None]:
def run_bgw_protocol(num_parties=6):
    """Simulate the BGW protocol on `sum_sq_circuit` and return every party's output.

    Party i's private input is [i]. With the default 6 parties, every output
    should therefore be [GF(15), GF(225)].

    Args:
        num_parties: number of simulated parties (default 6, as in the test).

    Returns:
        A list with each party's `output` attribute, in party order.
    """
    n = num_parties
    circuit = sum_sq_circuit(n)

    inputs = [[i] for i in range(n)]
    print('Inputs:', inputs)
    parties = [BGWParty() for _ in range(n)]

    for p, i in zip(parties, inputs):
        p.round1(parties, circuit, i)
    # round2 expects this party's own id (`my_id`), so pass its index in `parties`.
    for my_id, p in enumerate(parties):
        p.round2(my_id)

    round_num = 3
    while not parties[0].is_done:
        for p in parties:
            p.roundn(round_num)
        round_num += 1

    for p in parties:
        print('Output:', p.output)

    return [p.output for p in parties]

In [None]:
# TEST CASE
# Every party must finish with [sum, sum**2] = [15, 225] for inputs 0..5.
outputs = run_bgw_protocol()
expected = [GF(15), GF(225)]
for party_output in outputs:
    assert party_output == expected

## Question 5

Describe the 1-out-of-2 *oblivious transfer* (OT) protocol. Refer to Section 3.7 of *Pragmatic MPC*.

YOUR ANSWER HERE

## Question 6

Why is the oblivious transfer protocol secure against semi-honest adversaries? Why is it not secure against malicious adversaries?

YOUR ANSWER HERE

## Question 7

Implement 1-out-of-2 OT.

In [None]:
class OT_Sender(Party):
    """Sender in a 1-out-of-2 oblivious transfer (exercise skeleton).

    Holds two secrets, of which the receiver should learn exactly one.
    """
    # x1 and x2 are the secrets
    def round1(self, x1, x2, receiver):
        """Round 1: store both secrets and the receiving party."""
        self.x1 = x1
        self.x2 = x2
        self.receiver = receiver

    def round2(self):
        """Round 2: the sender's protocol step (to be implemented).

        Presumably responds to the receiver's round-1 message, which would be
        found in `self.received` -- TODO confirm.
        """
        # YOUR CODE HERE
        raise NotImplementedError()
    
    def round3(self):
        """Round 3: nothing for the sender to do."""
        pass

class OT_Receiver(Party):
    """Receiver in a 1-out-of-2 oblivious transfer (exercise skeleton).

    Holds a choice bit `b` and should obtain the sender's secret selected by
    `b` without revealing `b`. In the test, b = 1 with secrets (0, 1) must
    yield 1.
    """
    def round1(self, b, sender):
        """Round 1: store the choice bit and the sender; the receiver's first step goes here."""
        self.sender = sender
        self.b = b
        # YOUR CODE HERE
        raise NotImplementedError()
    
    def round2(self):
        """Round 2: nothing for the receiver to do."""
        pass
    
    def round3(self):
        """Round 3: compute and return the chosen secret."""
        # YOUR CODE HERE
        raise NotImplementedError()

In [None]:
# TEST CASE
GF_2 = galois.GF(2)

def run_ot(x1, x2, b):
    """Run one 1-out-of-2 OT instance end to end and return the receiver's output.

    The sender holds secrets `x1` and `x2`. The receiver holds choice bit `b`
    and should learn `x1` when b == 0 and `x2` when b == 1.
    """
    sender = OT_Sender()
    receiver = OT_Receiver()

    # Round 1
    sender.round1(x1, x2, receiver)
    receiver.round1(b, sender)

    # Round 2
    sender.round2()
    receiver.round2()

    # Round 3
    sender.round3()
    return receiver.round3()

output = run_ot(GF_2(0), GF_2(1), GF_2(1))
print("Receiver's output:", output)
assert output == 1

# With secrets (0, 1) the expected output always equals b, so an
# implementation that simply returns its choice bit would pass the check
# above. Exercise every combination of secrets and choice bit.
for x1, x2 in [(0, 1), (1, 0), (0, 0), (1, 1)]:
    for b in (0, 1):
        expected = x2 if b else x1
        assert run_ot(GF_2(x1), GF_2(x2), GF_2(b)) == expected

## Question 8

Describe 1-out-of-4 OT.

YOUR ANSWER HERE

## Question 9

Describe the GMW protocol for evaluating a binary circuit.

YOUR ANSWER HERE