## Introduction

This notebook provides an overview and implementation of [Confidential and Private Collaborative Learning](https://openreview.net/forum?id=h2EbJ4_wMVq) (CaPC). CaPC integrates cryptography and differential privacy to make collaborative learning both confidential and private. It extends [PATE](https://arxiv.org/abs/1802.08908); the diagram below shows how the two relate.

<p align="center">
<img width="750" alt="dpsgd" src="http://cleverhans.io/assets/capc/capc1.PNG">
</p>

Related files in this folder are referenced throughout the code; open them for more details about the implementation. This implementation uses the MNIST dataset, but the code can be extended to other datasets. The notebook is divided into sections that follow the steps of the CaPC protocol. Each section begins with a brief description of the step, followed by its implementation. The steps are numbered as in the figure below.

<p align="center">
<img width="750" alt="dpsgd" src="http://cleverhans.io/assets/capc/capc2.PNG">
</p>

#### Settings for the number of parties and the index to be used

In [1]:
n_parties = 2 # Set the number of answering parties.
index = 11 # Set the index of the MNIST test set sample to use as the query.

#### Imports and helper functions

In [2]:
import warnings

from utils import client_data
from utils.time_utils import get_timestamp
from utils.time_utils import log_timing

warnings.filterwarnings('ignore')

import argparse
import os
import numpy as np
import atexit
from utils.remove_files import remove_files_by_name
import consts
from consts import out_client_name
from consts import out_server_name
from consts import out_final_name
import getpass

import subprocess
import client


#### Arguments to be used in the code

In [3]:
def get_args():
    """Initial setup of parameters to be used."""
    user = getpass.getuser()
    parser = argparse.ArgumentParser('')
    parser.add_argument('--session', type=str, help='session name',
                        default='capc')
    parser.add_argument('--log_timing_file', type=str,
                        help='name of the global log timing file',
                        default=f'logs/log-timing-{get_timestamp()}.log')
    parser.add_argument('--n_parties', type=int, default=n_parties,
                        help='number of servers')
    parser.add_argument('--start_port', type=int, default=37000,
                        help='the number of the starting port')
    parser.add_argument('--seed', type=int, default=2,
                        help='seed for top level script')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='Number of classes in the dataset.')
    parser.add_argument(
        "--rstar_exp",
        type=int,
        default=10,
        help='The exponent for 2 to generate the random r* from.',
    )
    parser.add_argument(
        "--max_logit",
        type=float,
        default=36.0,
        help='The maximum value of a logit.',
    )
    parser.add_argument('--dp_noise_scale', type=float, default=0.05,
                        help='The scale of the Gaussian noise for DP privacy.')
    parser.add_argument(
        "--user",
        type=str,
        default=user,
        help="The name of the OS USER.",
    )
    parser.add_argument(
        "--log_level",
        type=int,
        default=0,
        help='log level for he-transformer',
    )
    parser.add_argument(
        '--round_exp',
        type=int,
        default=3,
        help='Multiply r* and logits by 2^round_exp.'
    )
    parser.add_argument(
        '--num_threads',
        type=int,
        default=20,
        help='Number of threads.',
    )
    parser.add_argument(
        '--qp_id', type=int, default=0, help='which model is the QP?')
    parser.add_argument(
        "--start_batch",
        type=int,
        default=index,  # 0
        help="Test data start index")
    parser.add_argument(
        "--model_type",
        type=str,
        default='cryptonets-relu',
        help="The type of models used.",
    )
    parser.add_argument(
        "--input_node",
        type=str,
        default="import/input:0",
        help="Tensor name of data input",
    )
    parser.add_argument(
        "--output_node",
        type=str,
        default="import/output/BiasAdd:0",
        help="Tensor name of model output",
    )
    parser.add_argument(
        '--dataset_path', type=str,
        default='/home/dockuser/queries',
        help='where the queries are.')
    parser.add_argument(
        '--dataset_name', type=str,
        default='mnist',
        help='name of dataset where queries came from')
    parser.add_argument('--debug', default=False, action='store_true')
    parser.add_argument('--n_queries',
                        type=int,
                        default=1,
                        help='total len(queries)')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='./models',
                        help='dir with all checkpoints')
    parser.add_argument('--cpu', default=False, action='store_true',
                        help='set to use cpu and no encryption.')
    parser.add_argument('--ignore_parties', default=True, action='store_true',
                        # False
                        help='set when using crypto models.')
    # parser.add_argument('--',
    #                     default='$HE_TRANSFORMER/configs/he_seal_ckks_config_N13_L5_gc.json')
    parser.add_argument('--encryption_params',
                        default='config/10.json')
    args, unparsed = parser.parse_known_args()
    if unparsed:
        print("Unparsed flags:", unparsed)
        exit(1)
    return args

def clean_old_files():
    """
    Delete old data files.
    This function is called before running the protocol.
    """
    cur_dir = os.getcwd()
    for name in [out_client_name,
                 out_server_name,
                 out_final_name,
                 consts.input_data,
                 consts.input_labels,
                 consts.predict_labels,
                 consts.label_final_name]:
        remove_files_by_name(starts_with=name, directory=cur_dir)


def delete_files(port):
    """
    Delete files related to this port.
    :param port: port number
    """
    files_to_delete = [consts.out_client_name + str(port) + 'privacy.txt']
    files_to_delete += [
        consts.out_final_name + str(port) + '.txt']  # + 'privacy.txt']
    files_to_delete += [
        consts.out_server_name + str(port) + '.txt']  # + 'privacy.txt']
    files_to_delete += [f"{out_final_name}.txt",
                        f"{out_server_name}.txt"]  # aggregates across all parties
    files_to_delete += [consts.inference_times_name,
                        consts.argmax_times_name,
                        consts.client_csp_times_name,
                        consts.inference_no_network_times_name]
    for f in files_to_delete:
        if os.path.exists(f):
            print(f'delete file: {f}')
            os.remove(f)


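def set_data_labels(FLAGS):
    """Gets MNIST data and labels, saving them in the local folder."""
    # Note: `get_data` is not imported in the cell above; it must be defined
    # or imported before this function is called.
    data, labels = get_data(start_batch=FLAGS.start_batch,
                            batch_size=FLAGS.batch_size)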
    np.save(consts.input_data, data)
    np.save(consts.input_labels, labels)


def get_models(model_dir, n_parties, ignore_parties):
    """Gets model files from model_dir."""
    model_files = [f for f in os.listdir(model_dir) if
                   os.path.isfile(os.path.join(model_dir, f))]
    # One model file is expected per answering party.
    if len(model_files) != n_parties and not ignore_parties:
        raise ValueError(
            f'{len(model_files)} models found when {n_parties} parties '
            f'requested. Not equal.')
    return model_dir, model_files
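
`get_args` uses `parse_known_args`, which returns unrecognized flags separately instead of raising an error. This matters in a notebook because the Jupyter kernel adds its own `-f` argument to the command line, as the `Unparsed flags` output of the setup cell shows. The following is a minimal sketch of that behavior with a hypothetical two-flag parser (`parse_demo` is an illustrative name, not part of the repository):

```python
import argparse


def parse_demo(argv):
    """Parse argv with a reduced, illustrative version of the CaPC parser."""
    parser = argparse.ArgumentParser('')
    parser.add_argument('--n_parties', type=int, default=2)
    parser.add_argument('--start_port', type=int, default=37000)
    # Returns (namespace, list of arguments the parser did not recognize).
    return parser.parse_known_args(argv)


# '-f kernel.json' stands in for the argument the Jupyter kernel appends.
known, unparsed = parse_demo(['--n_parties', '3', '-f', 'kernel.json'])
print(known.n_parties, known.start_port)
print("Unparsed flags:", unparsed)
```

Passing an explicit list (for example `[]`) to `parse_known_args` is one way to keep the kernel's own arguments out of the parse.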

#### Initial setup for the CaPC protocol:

In [4]:
args = get_args()
np.random.seed(args.seed)
clean_old_files()
set_data_labels(FLAGS=args)

if not os.path.exists('./logs'):
    os.mkdir('./logs')
log_timing_file = args.log_timing_file
log_timing('main: start capc', log_file=log_timing_file)

processes = []

def kill_processes():
    for p in processes:
        p.kill()

if not args.debug:
    atexit.register(kill_processes)

n_parties = args.n_parties
n_queries = args.n_queries
batch_size = args.batch_size
num_classes = args.num_classes
rstar_exp = args.rstar_exp
log_level = args.log_level
round_exp = args.round_exp
num_threads = args.num_threads
input_node = args.input_node
output_node = args.output_node
start_port = args.start_port
index = args.start_batch

# If args.cpu is set, use the CPU backend without encryption.
backend = 'HE_SEAL' if not args.cpu else 'CPU'

models_loc, model_files = get_models(
    args.checkpoint_dir, n_parties=n_parties,
    ignore_parties=args.ignore_parties)

for port in range(start_port, start_port + n_queries * n_parties):
    delete_files(port=port)


Unparsed flags: ['-f', '/home/dockuser/.local/share/jupyter/runtime/kernel-7a337de6-2cb0-4ea2-abd9-1eb2e811085f.json']
remove file:  input_data.npy
remove file:  input_labels.npy
delete file: files/logits37000privacy.txt
delete file: files/output37000.txt
delete file: files/noise37000.txt
delete file: files/output.txt
delete file: files/noise.txt
delete file: files/inference_times
delete file: files/argmax_times
delete file: files/client_csp_times
delete file: files/inference_no_network_times
delete file: files/logits37001privacy.txt
delete file: files/output37001.txt
delete file: files/noise37001.txt
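
The port numbers in the output above follow from the loop bounds in the setup cell: each query uses one port per answering party, starting at `start_port`. The sketch below reproduces that arithmetic so the layout can be checked for other settings. The helper names `ports_for_query` and `pg_port_range` are hypothetical and only mirror the expressions used in this notebook.

```python
def ports_for_query(start_port, n_parties, query_num):
    """Ports used by the answering parties for one query."""
    return [start_port + i + query_num * n_parties for i in range(n_parties)]


def pg_port_range(start_port, n_parties, query_num):
    """Half-open (start, end) port range handed to the privacy guardian."""
    first = start_port + query_num * n_parties
    return first, first + n_parties


# Settings used in this notebook: 2 answering parties, 1 query.
print(ports_for_query(37000, 2, 0))
print(pg_port_range(37000, 2, 0))
```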


### Step 1

#### The files server.py and client.py together complete Step 1 and can be consulted for more details. In Step 1a, the querying party (here, the client) sends the query $q$ from the MNIST dataset to each answering party (the server), which generates its prediction $r$ for the query. In Step 1b, each answering party generates a random vector $r^{*}$ and sends the vector $r-r^{*}$ to the querying party. Finally, in Step 1c, each answering party runs secure 2PC with the querying party to compute a vector $s_i$ for the querying party and a vector $\hat{s}_i$ for the answering party, such that $s_i + \hat{s}_i$ is the one-hot encoding of the argmax of the logits.
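
The secret sharing in Steps 1b and 1c can be illustrated in plaintext with NumPy. This is only a sketch under stated assumptions: there is no encryption, the 2PC is replaced by an ordinary function, and the way $r^{*}$ and $\hat{s}$ are sampled here (uniform integers) is a guess rather than what server.py does. The function names are hypothetical; `rstar_exp` and `round_exp` reuse the names of the notebook's arguments.

```python
import numpy as np


def share_logits(logits, rstar_exp, round_exp, rng):
    """Split fixed-point logits r into the shares r - r* (QP) and r* (AP)."""
    # Fixed-point encoding: scale by 2^round_exp and round to integers.
    r = np.round(np.asarray(logits) * 2 ** round_exp).astype(np.int64)
    # Assumption: r* is drawn uniformly from [-2^rstar_exp, 2^rstar_exp).
    r_star = rng.integers(-(2 ** rstar_exp), 2 ** rstar_exp, size=r.shape)
    return r - r_star, r_star


def simulated_argmax_2pc(qp_share, ap_share, rng):
    """Plaintext stand-in for Step 1c: shares (s, s_hat) of the one-hot argmax."""
    combined = qp_share + ap_share  # In the protocol, only formed inside 2PC.
    one_hot = np.zeros_like(combined)
    one_hot[np.argmax(combined)] = 1
    s_hat = rng.integers(-1000, 1000, size=one_hot.shape)  # AP's random share.
    s = one_hot - s_hat  # QP's share.
    return s, s_hat


rng = np.random.default_rng(0)
logits = np.array([0.1, -1.2, 0.3, 2.0, -0.5, 0.0, 7.5, 1.1, -3.0, 0.4])
qp_share, ap_share = share_logits(logits, rstar_exp=10, round_exp=3, rng=rng)
s, s_hat = simulated_argmax_2pc(qp_share, ap_share, rng)
predicted_class = int(np.argmax(s + s_hat))
print(predicted_class)
```

Neither share reveals the vote on its own; only the sum $s + \hat{s}$ is the one-hot vector.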

In [5]:
# Querying Process
for query_num in range(n_queries):
    for port, model_file in zip(
                [start_port + int(i + query_num * n_parties) for i in
                 range(n_parties)],
                model_files):
        print(f"port: {port}")
        new_model_file = os.path.join(
            "/home/dockuser/models", str(port) + ".pb")

        print('Start the servers (answering parties: APs).')
        log_timing('start server (AP)', log_file=log_timing_file)
        # Command to start server with the relevant parameters.
        cmd_string = " ".join(
            [
                'python -W ignore', 'server.py',
                '--backend', backend,
                '--n_parties', f'{n_parties}',
                '--model_file', new_model_file,
                '--dataset_name', args.dataset_name,
                '--indext', str(index),
                '--encryption_parameters', args.encryption_params,
                '--enable_client', 'true',
                '--enable_gc', 'true',
                '--mask_gc_inputs', 'true',
                '--mask_gc_outputs', 'true',
                '--from_pytorch', '1',
                '--dataset_path', args.dataset_path,
                '--num_gc_threads', f'{num_threads}',
                '--input_node', f'{input_node}',
                '--output_node', f'{output_node}',
                '--minibatch_id', f'{query_num}',
                '--rstar_exp', f'{rstar_exp}',
                '--num_classes', f'{num_classes}',
                '--round_exp', f'{round_exp}',
                '--log_timing_file', log_timing_file,
                '--port', f'{port}',
                '--checkpoint_dir', args.checkpoint_dir,
            ])
        server_process = subprocess.Popen(cmd_string, shell=True)
        # Track the process so kill_processes() can terminate it at exit.
        processes.append(server_process)
        print("Start the client (the querying party: QP).")
        log_timing('start the client QP', log_file=log_timing_file)
        cmd_string = " ".join(
            [
                # Command to start the client with the relevant parameters.
                'python -W ignore client.py',
                '--batch_size', f'{batch_size}',
                '--encrypt_data_str', 'encrypt',
                '--indext', str(index),
                '--n_parties', f'{n_parties}',
                '--round_exp', f'{round_exp}',
                '--from_pytorch', '1',
                '--minibatch_id', f'{query_num}',
                '--dataset_path', f'{args.dataset_path}',
                '--port', f'{port}',
                '--dataset_name', args.dataset_name,
                '--data_partition', 'test',
                '--log_timing_file', log_timing_file,
            ])
        client_process = subprocess.Popen(cmd_string, shell=True)
        # Track the process so kill_processes() can terminate it at exit.
        processes.append(client_process)

        client_process.wait()
        server_process.wait()


port: 37000
Start the servers (answering parties: APs).
Start the client (the querying party: QP).
port: 37001
Start the servers (answering parties: APs).
Start the client (the querying party: QP).


### Steps 2 and 3

#### The file pg.py runs the privacy guardian (PG). The PG sums the $\hat{s}$ vectors from all answering parties and adds Gaussian noise to the sum. The PG and the querying party (who holds the sum of the $s$ vectors) then run 2PC to compute the final label.
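
The aggregation in Steps 2 and 3 can be sketched in plaintext as follows. This is an illustration under assumptions, not the code in pg.py: the 2PC is replaced by a plain `argmax`, the shares are generated locally, and `dp_noise_scale` is assumed to be the standard deviation of the Gaussian noise. The function names are hypothetical.

```python
import numpy as np


def one_hot_shares(label, num_classes, rng):
    """Additive shares (s, s_hat) of a one-hot vote, as produced by Step 1."""
    one_hot = np.zeros(num_classes)
    one_hot[label] = 1.0
    s_hat = rng.integers(-1000, 1000, size=num_classes).astype(float)
    return one_hot - s_hat, s_hat


def noisy_label(s_shares, s_hat_shares, dp_noise_scale, rng):
    """Combine the shares from all answering parties into one noisy label."""
    # Step 2: the PG sums the s_hat vectors and adds Gaussian noise.
    pg_sum = np.sum(s_hat_shares, axis=0)
    pg_sum = pg_sum + rng.normal(0.0, dp_noise_scale, size=pg_sum.shape)
    # The querying party holds the sum of the s vectors.
    qp_sum = np.sum(s_shares, axis=0)
    # Step 3: plaintext stand-in for the 2PC, which outputs only the argmax.
    return int(np.argmax(qp_sum + pg_sum))


rng = np.random.default_rng(0)
votes = [6, 6, 4]  # Hypothetical votes from three answering parties.
shares = [one_hot_shares(v, 10, rng) for v in votes]
s_shares = [s for s, _ in shares]
s_hat_shares = [s_hat for _, s_hat in shares]
label = noisy_label(s_shares, s_hat_shares, dp_noise_scale=0.05, rng=rng)
print(label)
```

The random shares cancel when the two sums are added, leaving the vote histogram plus noise, so the majority class wins unless the noise is large relative to the vote margin.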

In [6]:
log_timing('start privacy guardian', log_file=log_timing_file)
# Command to run Privacy Guardian (Steps 2 & 3).
cmd_string = " ".join(
    ['python -W ignore', 'pg.py',
     '--start_port', f'{start_port + int(query_num * n_parties)}',
     '--end_port',
     f'{start_port + int(query_num * n_parties) + n_parties}',
     '--log_timing_file', log_timing_file,
     '--dp_noise_scale', str(args.dp_noise_scale),
     ])
print(f"start privacy guardian: {cmd_string}")
pg_process = subprocess.Popen(cmd_string, shell=True)
pg_process.wait()
log_timing('finish capc', log_file=log_timing_file)

start privacy guardian: python -W ignore pg.py --start_port 37000 --end_port 37002 --log_timing_file logs/log-timing-2021-11-04-22-43-39-583383.log --dp_noise_scale 0.05


#### Compare the predicted label with the actual label. The client (querying party) prints the output label. The actual label is looked up using the index of the query. Note that using the client to run this function is arbitrary: the function can equivalently be defined below and used by calling print_label().

In [7]:
def print_label():
    """Function to print final label after Step 3 i.e. 2PC is complete"""
    with open(os.path.join('files', 'final_label.txt'), 'r') as file:
        label = file.read(1)
    print("Predicted label: ", label)

In [8]:
print_label() 
#client.print_label()
(x_train, y_train, x_test, y_test) = client_data.load_mnist_data(index, 1)
print("The correct label should be: ", np.argmax(y_test)) 
log_timing('finish capc', log_file=log_timing_file)

Predicted label:  6
The correct label should be:  6
