Copyright (c) 2020 Apple Inc.  
SPDX-License-Identifier: MPL-2.0

In [None]:
import os, sys, boto3, json
from cu_vec import *
from client_io import *
from crypto_io import *
from server_io import *
from client import *
from ingestor import *
from server import *
from credentials import *

Load the bindings for the Rust Prio implementation

In [None]:
# Imported explicitly rather than relying on one of the star imports above to provide it.
import ctypes

# Platform-specific file naming of the Rust library: no "lib" prefix on Windows,
# and a .dylib / .dll / .so extension on macOS / Windows / everything else.
prefix = {'win32': ''}.get(sys.platform, 'lib')
extension = {'darwin': '.dylib', 'win32': '.dll'}.get(sys.platform, '.so')
# Cargo build profile directory to load from; set to "release" to load an optimized build.
build_profile = "debug"
# The path is relative to the kernel's current working directory.
lib_name = os.path.join("..", "target", build_profile, prefix + "libprio_rs" + extension)
lib = ctypes.cdll.LoadLibrary(lib_name)

In [None]:
# Run each binding module's setup function against the loaded library, in order
# (presumably this declares the FFI function signatures - confirm in each module).
for setup_bindings in (setup_cu_vectors, setup_client_io, setup_crypto_io, setup_server_io):
    setup_bindings(lib)

Configure the simulation parameters, the two servers' S3 sessions, and randomly generated server key pairs

In [None]:
dimension = 10
n_clients = 10
aggregate_name = "test_counts"
s3_prefix = "test_counts/"


def _make_s3_session(region, access_key_id, secret_access_key):
    """Return a boto3 session bound to one server's region and credentials."""
    return boto3.Session(
        region_name=region,
        aws_access_key_id=access_key_id,
        aws_secret_access_key=secret_access_key,
    )


# One session per server; the region and credential names come from the credentials module.
server1_s3_session = _make_s3_session(
    server1_region, server1_access_key_id, server1_secret_access_key
)
server2_s3_session = _make_s3_session(
    server2_region, server2_access_key_id, server2_secret_access_key
)

keys = get_random_key_pairs()

private_key1, public_key1 = keys["private_key1"], keys["public_key1"]
private_key2, public_key2 = keys["private_key2"], keys["public_key2"]

Simulate `n_clients` clients sending messages, keeping track of the true aggregate for comparison at the end

In [None]:
# Pre-initialise both results to empty lists before running the simulator.
client_simulation, global_truth = [], []

# seed=9 is fixed, presumably so the simulated client data is reproducible across runs.
with ClientSimulator(
    lib=lib,
    dimension=dimension,
    aggregate_name=aggregate_name,
    public_key1=public_key1,
    public_key2=public_key2,
    seed=9,
) as client_simulator:
    client_simulator.simulate_n(n_clients)
    client_simulation, global_truth = client_simulator.get_simulation()

print(global_truth)

Split the message shares across the two servers' S3 buckets

In [None]:
# Read the simulated client messages, then upload the shares via s3_put_shares
# to both servers' buckets under s3_prefix.
with Ingestor(
    lib=lib,
    dimension=dimension,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as ingestor:
    ingestor.read_shares(client_simulation)
    ingestor.s3_put_shares(server1_bucket, server2_bucket, s3_prefix)

Server 1 creates its verification messages

In [None]:
# Server 1 loads its shares from its own bucket, builds verification messages
# from them, and publishes those messages to both buckets.
with Server(
    lib=lib,
    dimension=dimension,
    is_first_server=True,
    private_key=private_key1,
    public_key1=public_key1,
    public_key2=public_key2,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as server1:
    server1_input_list = server1.s3_get_shares(server1_bucket, s3_prefix)
    server1.generate_verification_messages(server1_input_list)
    server1.s3_put_verification_messages(server1_bucket, server2_bucket, s3_prefix)

Server 2 creates its verification messages

In [None]:
# Server 2 mirrors server 1: it loads its shares from its own bucket, builds
# verification messages, and publishes them to both buckets.
with Server(
    lib=lib,
    dimension=dimension,
    is_first_server=False,
    private_key=private_key2,
    public_key1=public_key1,
    public_key2=public_key2,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as server2:
    server2_input_list = server2.s3_get_shares(server2_bucket, s3_prefix)
    server2.generate_verification_messages(server2_input_list)
    server2.s3_put_verification_messages(server1_bucket, server2_bucket, s3_prefix)

Server 1 verifies that shares are valid and aggregates them

In [None]:
# Server 1 re-reads its shares and both servers' verification messages from its
# own bucket, aggregates the shares, and publishes the results to both buckets.
with Server(
    lib=lib,
    dimension=dimension,
    is_first_server=True,
    private_key=private_key1,
    public_key1=public_key1,
    public_key2=public_key2,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as server1:
    server1_input_list = server1.s3_get_shares(server1_bucket, s3_prefix)
    (
        server1_verification_messages_for_server1,
        server2_verification_messages_for_server1,
    ) = server1.s3_get_verification_messages(server1_bucket, s3_prefix)
    v1 = server1.read_verification_messages(server1_verification_messages_for_server1)
    v2 = server1.read_verification_messages(server2_verification_messages_for_server1)
    server1.aggregate(server1_input_list, v1, v2)
    # Publish the valid-share and total-share outputs to both buckets.
    server1.s3_put_valid_shares(server1_bucket, server2_bucket, s3_prefix)
    server1.s3_put_total_shares(server1_bucket, server2_bucket, s3_prefix)

Server 2 verifies that shares are valid and aggregates them

In [None]:
# Server 2 re-reads its shares and both servers' verification messages from its
# own bucket, aggregates the shares, publishes the results to both buckets,
# then cleans up its own bucket.
with Server(
    lib=lib,
    dimension=dimension,
    is_first_server=False,
    private_key=private_key2,
    public_key1=public_key1,
    public_key2=public_key2,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as server2:
    server2_input_list = server2.s3_get_shares(server2_bucket, s3_prefix)
    (
        server1_verification_messages_for_server2,
        server2_verification_messages_for_server2,
    ) = server2.s3_get_verification_messages(server2_bucket, s3_prefix)
    v1 = server2.read_verification_messages(server1_verification_messages_for_server2)
    v2 = server2.read_verification_messages(server2_verification_messages_for_server2)
    server2.aggregate(server2_input_list, v1, v2)
    server2.s3_put_valid_shares(server1_bucket, server2_bucket, s3_prefix)
    server2.s3_put_total_shares(server1_bucket, server2_bucket, s3_prefix)
    # Safe at this point: the final reconstruction only reads from server 1's bucket.
    # Presumably removes this run's objects under s3_prefix - confirm in Server.s3_cleanup.
    server2.s3_cleanup(server2_bucket, s3_prefix)

Server 1 combines both aggregates to find the total

In [None]:
total_counts = []
# Server 1 fetches both servers' total shares from its own bucket, reconstructs
# the aggregate counts from them, then cleans up its own bucket.
with Server(
    lib=lib,
    dimension=dimension,
    is_first_server=True,
    private_key=private_key1,
    public_key1=public_key1,
    public_key2=public_key2,
    s3_session1=server1_s3_session,
    s3_session2=server2_s3_session,
) as server1:
    (
        server1_total_shares_for_server1,
        server2_total_shares_for_server1,
    ) = server1.s3_get_total_shares(server1_bucket, s3_prefix)
    total_counts = server1.reconstruct_shares(
        server1_total_shares_for_server1, server2_total_shares_for_server1
    )
    server1.s3_cleanup(server1_bucket, s3_prefix)

Confirm that the reconstructed counts match the true counts

In [None]:
# Display the reconstructed per-bin counts as a plain Python list.
[*total_counts]

In [None]:
# zip() stops at the shorter input, so compare lengths first; otherwise an empty
# or truncated reconstruction would pass the per-bin check vacuously.
prio_counts = list(total_counts)
true_counts = list(global_truth)
assert len(prio_counts) == len(true_counts), (
    "reconstructed {} bins but expected {}".format(len(prio_counts), len(true_counts))
)
for bin_index, (prio_count, true_count) in enumerate(zip(prio_counts, true_counts)):
    # Report bins 1-based to match the "bin N" labels used in the plot.
    assert prio_count == true_count, (
        "bin {}: reconstructed {} but true count is {}".format(bin_index + 1, prio_count, true_count)
    )

Plot the reconstructed counts

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

labels = ['bin '+str(i+1) for i in range(len(total_counts))]

fig, ax = plt.subplots()

ax.bar(labels, total_counts, 0.8)

ax.set_ylabel('Counts')
ax.set_title('Counts Reconstructed with Prio')

plt.show()
