In [1]:
n_parties = 3 # Set the number of answering parties.
index = 6 # Set the index of the mnist test set to use as the query (index of a sample).

#### Imports and helper functions

In [2]:
"""
This script assumes that a subdir with name {n_parties} exists in /models with the model 
files stored here.
The number of model files should equal the value of {n_parties} + 1.
It kicks off a server for each answering party and a single client who will be 
requesting queries.
client.py holds the clients training protocol, and server.py the response algorithms.
train_inits.py should be run first to train each model on a separate partition and save 
them as per the required scheme.
USAGE: call this file with: 
OMP_NUM_THREADS=24 NGRAPH_HE_VERBOSE_OPS=all NGRAPH_HE_LOG_LEVEL=3 python run_experiment.py
SETUP: create a tmux session with 3 panes, each in /home/dockuser/code/capc
"""

import warnings
from utils import client_data
from utils.client_data import get_data
from utils.time_utils import get_timestamp, log_timing

warnings.filterwarnings('ignore')
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
import argparse
import os
import time
import numpy as np
import atexit
import libtmux
from utils.remove_files import remove_files_by_name
import consts
from consts import out_client_name, out_server_name, out_final_name
import getpass
import get_r_star
import subprocess
import os
import client

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
def get_FLAGS():
    """Initial setup of parameters to be used."""
    parser = argparse.ArgumentParser('')
    parser.add_argument('--session', type=str, help='session name',
                        default='capc')
    parser.add_argument('--log_timing_file', type=str,
                        help='name of the global log timing file',
                        default=f'logs/log-timing-{get_timestamp()}.log')
    parser.add_argument('--n_parties', type=int, default=n_parties,
                        help='number of servers')
    parser.add_argument('--seed', type=int, default=2,
                        help='seed for top level script')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='Number of classes in the dataset.')
    parser.add_argument(
        "--rstar_exp",
        type=int,
        default=10,
        help='The exponent for 2 to generate the random r* from.',
    )
    parser.add_argument(
        "--max_logit",
        type=float,
        default=36.0,
        help='The maximum value of a logit.',
    )
    parser.add_argument(
        "--user",
        type=str,
        default=getpass.getuser(),
        help="The name of the OS USER.",
    )
    parser.add_argument(
        "--log_level",
        type=int,
        default=0,
        help='log level for he-transformer',
    )
    parser.add_argument(
        '--round_exp',
        type=int,
        default=3,
        help='Multiply r* and logits by 2^round_exp.'
    )
    parser.add_argument(
        '--num_threads',
        type=int,
        default=20,
        help='Number of threads.',
    )
    parser.add_argument(
        '--qp_id', type=int, default=0, help='which model is the QP?')
    parser.add_argument(
        "--start_batch",
        type=int,
        default=0,
        help="Test data start index")
    parser.add_argument(
        "--model_type",
        type=str,
        default='cryptonets-relu',
        help="The type of models used.",
    )
    parser.add_argument(
        "--input_node",
        type=str,
        default="import/input:0",
        help="Tensor name of data input",
    )
    parser.add_argument(
        "--output_node",
        type=str,
        default="import/output/BiasAdd:0",
        help="Tensor name of model output",
    )
    parser.add_argument(
        '--dataset_path', type=str,
        default='/home/dockuser/queries',
        help='where the queries are.')
    parser.add_argument(
        '--dataset_name', type=str,
        default='mnist',
        help='name of dataset where queries came from')
    parser.add_argument('--debug', default=False, action='store_true')
    parser.add_argument('--start_port', type=int, default=37000,
                    help='the number of the starting port')
    parser.add_argument('--n_queries',
                        type=int,
                        default=1,
                        help='total len(queries)')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='/home/dockuser/checkpoints',
                        help='dir with all checkpoints')
    parser.add_argument('--cpu', default=False, action='store_true',
                        help='set to use cpu and no encryption.')
    parser.add_argument('--ignore_parties', default=True, action='store_true', #False
                        help='set when using crypto models.')
    # parser.add_argument('--',
    #                     default='$HE_TRANSFORMER/configs/he_seal_ckks_config_N13_L5_gc.json')
    parser.add_argument('--encryption_params',
                        default='config/10.json')
    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        print("Unparsed flags:", unparsed)
        exit(1)
    return FLAGS

def clean_old_files():
    """Delete old data files. This function is called before running the protocol."""
    for name in [out_client_name,
                 out_server_name,
                 out_final_name,
                 consts.input_data,
                 consts.input_labels,
                 consts.predict_labels, consts.label_final_name]:
        remove_files_by_name(starts_with=name)


# Provide data.
def set_data_labels(FLAGS):
    """Gets MNIST data and labels, saving it in the local folder"""
    data, labels = get_data(start_batch=FLAGS.start_batch,
                            batch_size=FLAGS.batch_size)
    np.save(consts.input_data, data)
    np.save(consts.input_labels, labels)


def get_models(model_dir, n_parties, ignore_parties):
    """Gets model files from model_dir."""
    model_files = [f for f in os.listdir(model_dir) if
                   os.path.isfile(os.path.join(model_dir, f))]
    if len(model_files) != n_parties and not ignore_parties:
        raise ValueError(
            f'{len(model_files)} models found when {n_parties + 1} parties requested. Not equal.')
    return model_dir, model_files


#### Initial setup for CaPC protocol

In [4]:
FLAGS = get_FLAGS()
np.random.seed(FLAGS.seed)
clean_old_files()
set_data_labels(FLAGS=FLAGS)


log_timing_file = FLAGS.log_timing_file
log_timing('main: start capc', log_file=log_timing_file)
n_parties = FLAGS.n_parties
batch_size = FLAGS.batch_size
num_classes = FLAGS.num_classes
rstar_exp = FLAGS.rstar_exp
log_level = FLAGS.log_level
round_exp = FLAGS.round_exp
num_threads = FLAGS.num_threads
input_node = FLAGS.input_node
output_node = FLAGS.output_node
start_port = FLAGS.start_port
index = FLAGS.start_batch
backend = 'HE_SEAL' if not FLAGS.cpu else 'CPU'
models_loc, model_files = get_models(
    FLAGS.checkpoint_dir, n_parties=n_parties,
    ignore_parties=FLAGS.ignore_parties)
for port in range(37000, 37000 + n_parties): 
    files_to_delete = [consts.out_client_name + str(port) + 'privacy.txt']
    files_to_delete += [consts.out_final_name + str(port) + '.txt' ]#+ 'privacy.txt']
    files_to_delete += [consts.out_server_name + str(port) + '.txt']#+ 'privacy.txt']
    files_to_delete += [f"{out_final_name}.txt",
                        f"{out_server_name}.txt"]  # aggregates across all parties
    files_to_delete += [consts.inference_times_name,
                        consts.argmax_times_name,
                        consts.client_csp_times_name,
                        consts.inference_no_network_times_name]
    for f in files_to_delete:
        if os.path.exists(f):
            print(f'delete file: {f}')
            os.remove(f)

Unparsed flags: ['-f', '/home/dockuser/.local/share/jupyter/runtime/kernel-33b6b4db-5ebc-41ed-895c-fa235191cc1b.json']
delete file: files/logits37000privacy.txt
delete file: files/output37000.txt
delete file: files/noise37000.txt
delete file: files/output.txt
delete file: files/noise.txt
delete file: files/inference_times
delete file: files/argmax_times
delete file: files/client_csp_times
delete file: files/inference_no_network_times


#### Step 1 of Protocol. The files server.py and client.py together complete Step 1

In [5]:
for query_num in range(FLAGS.n_queries):  #Querying process
    for port, model_file in zip(
            [37000 + int(i + query_num * n_parties) for i in
             range(n_parties)],
            model_files):
        print(f"port: {port}")
        full_model_file = fr'{models_loc}/{model_file}'
        full_model_file_new = ""
        for s in full_model_file:
            if s == '(' or s == ')':
                full_model_file_new += "\\"
            full_model_file_new += s
        full_model_file = full_model_file_new
        new_model_file = os.path.join("/home/dockuser/models",
                                      str(port) + ".pb")
        r_star = get_r_star.get_rstar_server(  # Generate random vector needed in Step 1a
            max_logit=FLAGS.max_logit,
            batch_size=batch_size,
            num_classes=num_classes,
            exp=FLAGS.rstar_exp,
        ).flatten()
        print(f"run_exp rstar: {r_star}")
        print(f"port: {port}")
        print('Start the servers (answering parties: APs)')
        log_timing('start server (AP)', log_file=log_timing_file)
        cmd_string = " ".join([  # Command to start server with the relevant parameters. 
                                  f'OMP_NUM_THREADS={num_threads}',
                                  f'NGRAPH_HE_LOG_LEVEL={log_level}',
                                  'python -W ignore', 'server.py',
                                  '--backend', backend,
                                  '--n_parties', f'{n_parties}',   
                                  '--model_file', new_model_file,
                                  '--dataset_name', FLAGS.dataset_name,
                                  '--indext', str(index),
                                  '--encryption_parameters',
                                  FLAGS.encryption_params,
                                  '--enable_client', 'true',
                                  '--enable_gc', 'true',
                                  '--mask_gc_inputs', 'true',
                                  '--mask_gc_outputs', 'true',
                                  '--from_pytorch', '1', '--dataset_name',
                                  FLAGS.dataset_name,
                                  '--dataset_path', FLAGS.dataset_path,
                                  '--num_gc_threads', f'{num_threads}',
                                  '--input_node', f'{input_node}',
                                  '--output_node', f'{output_node}',
                                  '--minibatch_id', f'{query_num}',
                                  '--rstar_exp', f'{rstar_exp}',
                                  '--num_classes', f'{num_classes}',
                                  '--round_exp', f'{round_exp}',
                                  '--log_timing_file', log_timing_file,
                                  "--r_star"] + [str(x) for x in r_star] + [
                                  '--port', f'{port}',
                              ])
        subprocess.Popen(cmd_string, shell = True)  # Run server.py with the given parameters. 
        if not FLAGS.cpu:
            time.sleep(1)
            print(f"port: {port}")
            log_timing('start cleint (the querying party: QP)', log_file=log_timing_file)
            cmd_string = " ".join([
                        # Command to start client server with the relevant parameters.
                        f'OMP_NUM_THREADS={num_threads}',
                        f'NGRAPH_HE_LOG_LEVEL={log_level}',
                        'python -W ignore client.py',
                        '--batch_size', f'{batch_size}',
                        '--encrypt_data_str', 'encrypt',
                        '--indext', str(index),
                        '--n_parties', f'{n_parties}',
                        '--round_exp', f'{round_exp}',
                        '--from_pytorch', '1',
                        '--minibatch_id', f'{query_num}',
                        '--dataset_path', f'{FLAGS.dataset_path}',
                        '--port', f'{port}',
                        '--dataset_name', FLAGS.dataset_name,
                        '--data_partition', 'test',
                        '--log_timing_file', log_timing_file,
                    ])
            subprocess.Popen(cmd_string, shell=True)  # Run client_server.py with the given parameters. 
            time.sleep(16) 
        else:
            time.sleep(1)


port: 37000
run_exp rstar: [ -95.08244041 -934.90307722  137.7087547   -96.45973992 -127.08674132
 -311.47428658 -568.87959749  280.26693909 -374.30722831 -441.53774059]
port: 37000
Start the servers (answering parties: APs)
port: 37000
port: 37001
run_exp rstar: [ 284.08208951   95.68300908 -712.38027193   63.80799235 -610.26715516
  620.36638273  760.94139933   24.19704296  745.75792201 -824.88606309]
port: 37001
Start the servers (answering parties: APs)
port: 37001
port: 37002
run_exp rstar: [  46.74399257 -854.29323902 -111.20547308 -790.30468473 -727.57637796
  234.13439279 -525.12742276 -768.97523853 -536.81288792 -271.55576831]
port: 37002
Start the servers (answering parties: APs)
port: 37002


#### Steps 2 and 3 of the Protocol. The file pg.py runs the privacy guardian. 

In [6]:
log_timing('start privacy guardian', log_file=log_timing_file)
# Command to run Privacy Guardian (Steps 2 & 3).
cmd_string = " ".join(
    ['python -W ignore', 'pg.py',
     f'{start_port + int(query_num * n_parties)}',
     f'{start_port + int(query_num * n_parties) + n_parties}'
     ])
print(f"start privacy guardian: {cmd_string}")
subprocess.Popen(cmd_string, shell=True)  # Run pg.py with the given parameters.
time.sleep(5)

start privacy guardian: python -W ignore pg.py 37000 37003


#### Compare predicted label with actual label. The client (querying party) print the outputted label. The actual label is manually found using the index of the query used

In [7]:
client.print_label()
(x_train, y_train, x_test, y_test) = client_data.load_mnist_data(index, 1)
print("Actual label: ", np.argmax(y_test)) 
log_timing('finish capc', log_file=log_timing_file)

Predicted label:  7
Actual label 7
