In [1]:
"All constants used in the project."

from pathlib import Path
import pandas
import numpy
import os
import shutil


# Root of the project: the directory the notebook is executed from
REPO_DIR = Path.cwd()

# Directory holding the deployed model files (<repo>/deployment_files/model)
DEPLOYMENT_PATH = REPO_DIR / "deployment_files" / "model"

# Working directories for the keys and for the files exchanged with the clients
FHE_KEYS = REPO_DIR.joinpath(".fhe_keys")
CLIENT_FILES = REPO_DIR.joinpath("client_files")

# Make sure the working directories exist (nothing happens if they are already there)
FHE_KEYS.mkdir(exist_ok=True)
CLIENT_FILES.mkdir(exist_ok=True)

# Development settings
# NOTE(review): presumably the shape of the pre-processed inputs of all parties combined - confirm
PROCESSED_INPUT_SHAPE = (1, 39)

# The parties involved; the number of entries is used as the number of inputs of the client
CLIENT_TYPES = ["applicant", "bank", "credit_bureau"]

In [2]:
import copy
from itertools import chain

from typing import Tuple

from concrete.fhe import Value, EvaluationKeys
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import DecisionTreeClassifier 


class MultiInputsFHEModelClient(FHEModelClient):
    """Client for a model whose inputs are provided by several parties.

    Extends 'FHEModelClient' with a method that quantizes, encrypts and serializes the input of a
    single party.
    """

    def __init__(self, *args, nb_inputs=1, **kwargs):
        """Initialize the client.

        Args:
            *args: Positional arguments forwarded to 'FHEModelClient'.
            nb_inputs (int): The number of inputs passed to the encryption step, one per party.
                Default to 1.
            **kwargs: Keyword arguments forwarded to 'FHEModelClient'.
        """
        self.nb_inputs = nb_inputs

        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self,
        x: numpy.ndarray,
        input_index: int,
        processed_input_shape: Tuple[int, ...],
        input_slice: slice,
    ) -> bytes:
        """Quantize, encrypt and serialize inputs for a multi-party model.

        In the following, the 'quantize_input' method called is the one defined in Concrete ML's
        built-in models. Since they don't natively handle inputs for multi-party models, we need
        to use padding along indexing and slicing so that inputs from a specific party are correctly
        associated with input quantizers.

        Args:
            x (numpy.ndarray): The input to consider. Here, the input should only represent a
                single party.
            input_index (int): The index representing the type of model (0: "applicant", 1: "bank",
                2: "credit_bureau")
            processed_input_shape (Tuple[int, ...]): The total input shape (all parties combined)
                after pre-processing.
            input_slice (slice): The slices to consider for the given party.

        Returns:
            bytes: The serialized encrypted input of the given party.
        """
        # Place the party's values at their columns in an otherwise zero-filled full-size input
        x_padded = numpy.zeros(processed_input_shape)
        x_padded[:, input_slice] = x

        # Quantize the full-size input, then keep only the columns that belong to the party
        q_x_padded = self.model.quantize_input(x_padded)
        q_x = q_x_padded[:, input_slice]

        # Only the position of the given party carries a value, the others are left to None
        q_x_inputs = [None] * self.nb_inputs
        q_x_inputs[input_index] = q_x

        # Encrypt the values
        q_x_enc = self.client.encrypt(*q_x_inputs)

        # When a single argument is given (nb_inputs=1, the default), Concrete's 'encrypt' returns
        # the encrypted value itself instead of a tuple: indexing it would fail
        q_x_enc_party = q_x_enc[input_index] if isinstance(q_x_enc, tuple) else q_x_enc

        # Serialize the encrypted values to be sent to the server
        return q_x_enc_party.serialize()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:


def clean_temporary_files(n_keys=20):
    """Clean older keys and encrypted files.

    A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
    limit is reached, the oldest files are deleted.

    Only directories are considered, so stray files found in the key or client directories (for
    example '.DS_Store') are neither counted nor removed. A client directory is removed only if
    its name is exactly the ID of a deleted key directory, so that the files of a client whose
    ID merely contains a deleted ID are kept.

    Args:
        n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.

    """
    # Get the key directories, from the oldest to the most recently modified
    key_dirs = sorted(
        (path for path in FHE_KEYS.iterdir() if path.is_dir()),
        key=os.path.getmtime,
    )

    # If more than n_keys keys are found, remove the oldest and remember their client IDs
    n_keys_to_delete = max(len(key_dirs) - n_keys, 0)
    deleted_client_ids = set()
    for key_dir in key_dirs[:n_keys_to_delete]:
        deleted_client_ids.add(key_dir.name)
        shutil.rmtree(key_dir)

    # Delete all files related to the IDs whose keys were deleted. The listing is materialized
    # first so that entries are not removed while the directory is being iterated
    for directory in list(CLIENT_FILES.iterdir()):
        if directory.is_dir() and directory.name in deleted_client_ids:
            shutil.rmtree(directory)
                
def _get_client(client_id):
    """Build the multi-input client associated with a client ID.

    Args:
        client_id (int): The client ID to consider. It is used as the name of the client's key
            directory.

    Returns:
        MultiInputsFHEModelClient: The client instance, created with the deployment path, the
            client's key directory and one input per client type.
    """
    return MultiInputsFHEModelClient(
        DEPLOYMENT_PATH,
        key_dir=FHE_KEYS / f"{client_id}",
        nb_inputs=len(CLIENT_TYPES),
    )

def _get_client_file_path(name, client_id, client_type=None):
    """Build the path of a file stored in the directory of a given client.

    The client's directory is created if it does not exist yet.

    Args:
        name (str): The desired file name (either 'evaluation_key', 'encrypted_inputs' or
            'encrypted_outputs').
        client_id (int): The client ID to consider.
        client_type (Optional[str]): The type of client to consider (either 'applicant', 'bank',
            'credit_bureau' or None). Default to None, which is used for evaluation key and output.

    Returns:
        pathlib.Path: The file path, whose name is suffixed with the client type when one is given.
    """
    # Files that are specific to a party are told apart through a '_<client_type>' suffix
    suffix = f"_{client_type}" if client_type is not None else ""

    client_dir = CLIENT_FILES / f"{client_id}"
    client_dir.mkdir(exist_ok=True)

    return client_dir / f"{name}{suffix}"

def shorten_bytes_object(bytes_object, limit=500):
    """Build a shortened hexadecimal representation of a bytes object.

    Encrypted data is too large for displaying it in the browser using Gradio. This function
    extracts a window of it instead: the first 100 bytes are skipped, then at most 'limit' bytes
    are kept.

    Args:
        bytes_object (bytes): The input to shorten.
        limit (int): The maximum number of bytes to keep. Default to 500.

    Returns:
        str: Hexadecimal string representation of the extracted bytes. It is empty if the input
            is not longer than the number of skipped bytes.

    """
    # Skip the beginning of the object for a better display
    start = 100
    stop = start + limit

    window = bytes_object[start:stop]
    return window.hex()

def keygen_send():
    """Generate the private and evaluation key, and store the evaluation key for the server.

    The serialized evaluation key is written to the client's 'evaluation_key' file. The actual
    call sending it to the server is currently commented out.

    Returns:
        Tuple[int, str]: The ID of the newly created client and a shortened hexadecimal
            representation of its evaluation key, for display purposes.
    """
    # Clean temporary files, keeping at most 3 keys
    clean_temporary_files(3)

    # Create an ID for the current client to consider. The 64-bit dtype is given explicitly as
    # 2**32 is out of bounds where NumPy's default integer is 32-bit, and the value is converted
    # to a built-in int so that the returned type is the same on every platform
    client_id = int(numpy.random.randint(0, 2**32, dtype=numpy.int64))

    # Retrieve the client instance
    client = _get_client(client_id)

    # Generate the private and evaluation keys
    client.generate_private_and_evaluation_keys(force=True)

    # Retrieve the serialized evaluation key
    evaluation_key = client.get_serialized_evaluation_keys()

    file_name = "evaluation_key"

    # Save evaluation key as bytes in a file as it is too large to pass through regular Gradio
    # buttons (see https://github.com/gradio-app/gradio/issues/1877)
    evaluation_key_path = _get_client_file_path(file_name, client_id)

    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(evaluation_key)

    # Send the evaluation key to the server (disabled here, only the file is written)
    # _send_to_server(client_id, None, file_name)
    print(f"Client ID: {client_id}, Evaluation key sent to server: {evaluation_key_path}")

    # Create a truncated version of the evaluation key for display
    evaluation_key_short = shorten_bytes_object(evaluation_key)

    return client_id, evaluation_key_short

# Run the key generation step; the returned client ID and shortened evaluation key are not
# captured here
keygen_send()
print("Keys are generated and evaluation key is sent ✅")

Client ID: 3624725669, Evaluation key sent to server: /Users/elvin/Development/llm/dissertation-dev/fed-analysis-keys/client_files/3624725669/evaluation_key
Keys are generated and evaluation key is sent ✅


KeySetCache: miss, regenerating /Users/elvin/Development/llm/dissertation-dev/fed-analysis-keys/.fhe_keys/3624725669/2477099136904845538
