# CS295/395: Secure Distributed Computation
## Homework 2

## Definitions

In [None]:
# IMPORTS & DEFINITIONS

from scipy import stats
import numpy as np
from collections import defaultdict
import galois

# Prime order of the finite field that every share, partial sum and output
# lives in. All aggregation happens modulo this value, so a protocol output
# equals the true integer sum only when that sum is smaller than FIELD_ORDER.
FIELD_ORDER = 97
GF = galois.GF(FIELD_ORDER)

def make_additive_shares(n, x):
    """Split `x` into `n` additive secret shares over the field GF.

    The first n-1 shares are drawn uniformly at random. The final share is
    chosen so that all n shares add up to `x` in GF. Any n-1 of the shares
    therefore look uniformly random, yet together they reconstruct the input.

    n: number of shares to create
    x: input value to split (must be usable in GF arithmetic)

    Returns a GF array of length n.
    """
    random_part = GF.Random(n - 1)
    # The balancing share absorbs the random mask so the total is exactly x.
    balancing_share = x - random_part.sum()
    return GF([*random_part, balancing_share])

class Party:
    """Base class for a participant in a simulated multiparty computation.

    A party "sends" a message by appending it directly to the recipient's
    `received` mailbox. The mailbox maps a round number to the list of
    messages delivered in that round, in delivery order.
    """

    def __init__(self, field_size):
        """Store the field size and start with an empty mailbox.

        field_size: size of the finite field (not read by this class's own
            methods).
        """
        self.field_size = field_size
        # round number -> list of messages delivered during that round
        self.received = defaultdict(list)

    def send(self, other, round, msg):
        """Deliver `msg` to party `other` as part of round `round`."""
        inbox = other.received[round]
        inbox.append(msg)

## Question 1 (30 points)

Implement the `SecureAggregationParty` class, an aggregation protocol that is secure against semi-honest adversaries.

In [None]:
class SecureAggregationParty(Party):
    """A protocol for secure aggregation using additive secret sharing.

    Each party splits its input into additive shares, one per party, and
    sends one share to every party. Each party then broadcasts the sum of
    the shares it received, and everyone adds those partial sums to obtain
    the aggregate.

    All arithmetic happens in the module-level field `GF`. The inherited
    `field_size` attribute is not used for arithmetic. Because sums are
    reduced modulo `GF.order`, the output equals the true integer sum only
    when that sum is smaller than `GF.order`.

    The rounds must run in lock-step: every party finishes `round1` before
    any party calls `round2`, and every party finishes `round2` before any
    party calls `round3`. Otherwise a party sums an incomplete mailbox.
    """
    def round1(self, parties, input):
        """In round 1, each party shares out `input` and sends one share
        to each party (including itself).

        parties: every participant in the protocol, including `self`.
        input: this party's private value; it is expected to be a valid
            element of GF (e.g. an int in [0, GF.order)).
        """
        self.input = input

        shares = make_additive_shares(len(parties), input)
        # Share k goes to parties[k]. Any single share on its own is
        # uniformly random, so no recipient learns anything about `input`.
        for party, share in zip(parties, shares):
            self.send(party, 1, share)

    def round2(self, parties):
        """In round 2, each party sums up the shares it has received and
        sends the result to all parties (including itself).

        parties: every participant in the protocol, including `self`.
        """
        partial_sum = GF(self.received[1]).sum()
        for party in parties:
            self.send(party, 2, partial_sum)

    def round3(self):
        """In round 3, each party adds up the partial sums it has received.

        The result is stored in `self.output` and also returned, as a
        plain int in [0, GF.order).
        """
        self.output = int(GF(self.received[2]).sum())
        return self.output

    def get_view(self):
        """Returns the view of this party: its input, output, and received messages."""
        return (self.input, self.output, self.received[1], self.received[2])

In [None]:
# TEST CASE for question 1

# 5 parties. All arithmetic happens in GF, so pass its order as the field
# size instead of a value (like 100) that disagrees with the field in use.
p = GF.order
parties = [SecureAggregationParty(p) for _ in range(5)]
inputs = [np.random.randint(1, 10) for _ in parties]
# The protocol computes the sum in the field, i.e. modulo p.
expected = sum(inputs) % p

# run round 1
for party, x_p in zip(parties, inputs):
    party.round1(parties, x_p)

# run round 2
for party in parties:
    party.round2(parties)

# run round 3
for party in parties:
    assert party.round3() == expected

In [None]:
# Example: display each party's view from the Question 1 run.
# Run this cell only *after* the Question 1 test case, which builds `parties`.
for party in parties:
    view = party.get_view()
    print(view)

## Question 2 (15 points)

Implement the `CorruptAggregationParty` class, a malicious participant in the secure aggregation protocol from Question 1. The `CorruptAggregationParty` should output the correct answer, but cause other parties to output the wrong answer.

In [None]:
class CorruptAggregationParty(Party):
    """A corrupted party for secure aggregation.

    It follows the protocol in round 1. In round 2 it cheats: it keeps the
    honest partial sum for itself but sends `partial_sum + OFFSET` to
    every other party. Every honest party therefore adds a wrong value and
    outputs a total that is off by OFFSET. The corrupt party received only
    honest partial sums plus its own, so its total is correct.

    Deliberately not a subclass of SecureAggregationParty, so that
    `isinstance` checks can distinguish honest and corrupt parties.
    """
    # Error injected into the partial sums sent to the other parties. It has
    # no effect if it is 0 modulo GF.order.
    OFFSET = 1

    def round1(self, parties, input):
        """Share `input` honestly and send one share to each party.

        parties: every participant in the protocol, including `self`.
        input: this party's private value (a valid element of GF).
        """
        self.input = input
        shares = make_additive_shares(len(parties), input)
        for party, share in zip(parties, shares):
            self.send(party, 1, share)

    def round2(self, parties):
        """Send the true partial sum to itself and a perturbed one to others.

        parties: every participant in the protocol, including `self`.
        """
        partial_sum = GF(self.received[1]).sum()
        tampered_sum = partial_sum + GF(self.OFFSET % GF.order)
        for party in parties:
            if party is self:
                self.send(party, 2, partial_sum)
            else:
                self.send(party, 2, tampered_sum)

    def round3(self):
        """Add up the received partial sums. The result is the correct total.

        Sets `self.output` and returns it as a plain int in [0, GF.order).
        """
        self.output = int(GF(self.received[2]).sum())
        return self.output

In [None]:
# TEST CASE for question 2

# 5 honest parties and 1 corrupt party, all working in the field GF
p = GF.order
parties = [SecureAggregationParty(p) for _ in range(5)]
parties.append(CorruptAggregationParty(p))

# Every party contributes the same value, so the correct aggregate is
# 10 * 6 = 60 (reduced modulo the field order).
party_input = 10
expected = party_input * len(parties) % p

# run round 1
for party in parties:
    party.round1(parties, party_input)

# run round 2
for party in parties:
    party.round2(parties)

# run round 3
for party in parties:
    output = party.round3()
    if isinstance(party, SecureAggregationParty):
        assert output != expected   # honest parties output the *wrong* answer
    else:
        assert output == expected   # corrupt party outputs the *right* answer

## Question 3 (15 points)

Write a function that implements a simulator for `SecureAggregationParty`. Your function should simulate one party. It should take as arguments the number of parties participating, the simulated party's input, and the output of the functionality (the information available in the ideal world). The finite field is the module-level `GF` defined above, so it is not passed in. Your function should return a view (a 4-tuple) that is indistinguishable from one returned by the `get_view` method of the `SecureAggregationParty` class.

In [None]:
def simulator(n, input, output, index=0):
    """Simulate the view of a single `SecureAggregationParty` using only the
    ideal-world information: the number of parties `n`, the simulated party's
    `input`, and the functionality's `output`. All values live in the
    module-level field `GF`, so no field-size argument is needed.

    How the view is constructed:
      * Round 1: every share the party receives comes from an independent
        additive sharing, so each one is a uniformly random field element.
      * Round 2: the party's own partial sum is the sum of its round-1
        shares. The other n-1 partial sums are uniformly random, subject to
        all n partial sums adding up to `output`.

    n: number of parties (at least 2)
    input: the simulated party's input (returned unchanged)
    output: the functionality's output (returned unchanged)
    index: position of the simulated party among the n parties, i.e. where
        its own partial sum appears in the round-2 view (default 0)

    Returns the 4-tuple (input, output, round1_view, round2_view), with the
    same layout as `SecureAggregationParty.get_view`.

    Raises ValueError if n < 2 or `index` is not in [0, n).
    """
    if n < 2:
        raise ValueError(f"simulator needs at least 2 parties, got n={n}")
    if not 0 <= index < n:
        raise ValueError(f"index must be in [0, {n}), got {index}")

    received_shares = GF.Random(n)
    round1_view = list(received_shares)
    own_sum = received_shares.sum()

    # The other partial sums are an additive sharing of whatever remains of
    # the output after the party's own contribution.
    target = GF(int(output) % GF.order)
    round2_view = list(make_additive_shares(n - 1, target - own_sum))
    round2_view.insert(index, own_sum)

    return (input, output, round1_view, round2_view)

In [None]:
# TEST CASE 1 for question 3

NUM_RUNS = 1000
NUM_PARTIES = 5

# run the simulator NUM_RUNS times
simulator_runs = [simulator(NUM_PARTIES, 5, 50) for _ in range(NUM_RUNS)]

# Check that the views have the correct input and output, and that the
# round-2 messages add up to the output (as they always do in the real
# protocol). Fresh names avoid shadowing the builtin `input`.
for sim_input, sim_output, _, sim_round2 in simulator_runs:
    assert sim_input == 5
    assert sim_output == 50
    assert int(GF(sim_round2).sum()) == sim_output % GF.order

# one uniformly random field element per simulated message, drawn in one call
unif = np.array(GF.Random(NUM_RUNS * NUM_PARTIES))

# check that the views have the correct randomness
r1s = np.array([s[2] for s in simulator_runs]).flat
r2s = np.array([s[3] for s in simulator_runs]).flat
assert stats.wasserstein_distance(r1s, unif) <= 2
assert stats.wasserstein_distance(r2s, unif) <= 2

In [None]:
# TEST CASE 2 for question 3

NUM_RUNS = 1000
NUM_PARTIES = 5
PARTY_INPUT = 10                     # every real-world party contributes this
TOTAL = NUM_PARTIES * PARTY_INPUT    # so the functionality outputs 50
p = GF.order                         # field size actually used for arithmetic

real_world_views = []

# Run the simulator NUM_RUNS times, giving it exactly the ideal-world
# information (input and output) that each real-world party below has.
simulator_runs = [
    simulator(NUM_PARTIES, PARTY_INPUT, TOTAL) for _ in range(NUM_RUNS)
]

# check that the views have the correct input and output
# (fresh names, so the builtin `input` is not shadowed)
for sim_input, sim_output, _, _ in simulator_runs:
    assert sim_input == PARTY_INPUT
    assert sim_output == TOTAL

# run the real protocol NUM_RUNS times
for _run in range(NUM_RUNS):
    parties = [SecureAggregationParty(p) for _ in range(NUM_PARTIES)]

    # run round 1
    for party in parties:
        party.round1(parties, PARTY_INPUT)

    # run round 2
    for party in parties:
        party.round2(parties)

    # run round 3
    for party in parties:
        assert party.round3() == TOTAL % p
        real_world_views.append(party.get_view())

# check that the simulated views in the ideal world have the same
# randomness as the real world views
r1s_ideal = np.array([s[2] for s in simulator_runs]).flat
r2s_ideal = np.array([s[3] for s in simulator_runs]).flat
r1s_real = np.array([s[2] for s in real_world_views]).flat
r2s_real = np.array([s[3] for s in real_world_views]).flat

assert stats.wasserstein_distance(r1s_ideal, r1s_real) <= 2
assert stats.wasserstein_distance(r2s_ideal, r2s_real) <= 2