# Secure XGBoost Demo Notebook
This notebook shows how to use Secure XGBoost. In this example, we use the client's encrypted data to train an XGBoost model on the server, inside a secure enclave.

For simplicity, the client and server run on the same machine in this example. In an actual deployment, the client process would run on a separate, trusted machine. The server is assumed to be completely untrusted (except for the secure enclave), so no sensitive data should ever exist in plaintext (i.e., unencrypted) outside the enclave.

The example consists of the following steps:

1. **Key generation**: The client generates a secret symmetric key.
2. **Data encryption**: The client uses the key to encrypt its data.
3. **Enclave setup**: The client initializes its user object with the secret symmetric key, its private key, and its certificate. The server creates an enclave and starts a process within it. The client [*attests*](https://software.intel.com/en-us/articles/code-sample-intel-software-guard-extensions-remote-attestation-end-to-end-example) the enclave process and securely transfers its symmetric key to the enclave.
4. **Data loading**: The enclave loads the client's encrypted data.
5. **Training**: The enclave trains a model using the provided data.
6. **Prediction**: The enclave makes predictions with the model and produces a set of encrypted results; the client decrypts the results.

Documentation for Secure XGBoost can be found [here](https://mc2-project.github.io/secure-xgboost/).

In [1]:
%load_ext autoreload
%autoreload 2

import securexgboost as xgb
import os

username = "user1"
HOME_DIR = os.path.abspath('') + "/../../../"
CURRENT_DIR = os.path.abspath('') + "/"
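# Despite its name, PUB_KEY is the client's *private* key; it is passed as `priv_key_file` below
PUB_KEY = HOME_DIR + "config/{0}.pem".format(username)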
CERT_FILE = HOME_DIR + "config/{0}.crt".format(username)

## 1. Key Generation
Generate a key to be used for encryption.

In [None]:
KEY_FILE = CURRENT_DIR + "key.txt"

# Generate a key you will be using for encryption
xgb.generate_client_key(KEY_FILE)
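`generate_client_key` takes care of key generation for you. To illustrate what a symmetric key is, the sketch below writes 32 cryptographically secure random bytes (a 256-bit key, a common size for AES) to a file using only the standard library. The helper `write_random_key` and the raw-bytes file layout are hypothetical; they are not necessarily the format `generate_client_key` produces, so use the library function in practice.

```python
import os
import secrets
import tempfile

KEY_LEN_BYTES = 32  # 256-bit key; assumed size for illustration


def write_random_key(path: str, n_bytes: int = KEY_LEN_BYTES) -> bytes:
    """Hypothetical helper: write n_bytes of secure randomness to `path` and return them."""
    key = secrets.token_bytes(n_bytes)  # uses the OS CSPRNG
    with open(path, "wb") as f:
        f.write(key)
    return key


# Usage: write a throwaway key into a fresh temporary directory
example_path = os.path.join(tempfile.mkdtemp(), "example_key.bin")
example_key = write_random_key(example_path)
```

Whatever tool generates it, the key file must stay on the trusted client machine; only the enclave should ever receive a copy.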

## 2. Data Encryption
Use the key generated above to encrypt the training and test data.

In [None]:
training_data = HOME_DIR + "demo/data/agaricus.txt.train"
enc_training_data = CURRENT_DIR + "train.enc"

# Encrypt training data
xgb.encrypt_file(training_data, enc_training_data, KEY_FILE)

In [None]:
test_data = HOME_DIR + "demo/data/agaricus.txt.test"
enc_test_data = CURRENT_DIR + "test.enc"

# Encrypt test data
xgb.encrypt_file(test_data, enc_test_data, KEY_FILE)
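The agaricus files encrypted above are plaintext in LibSVM format: each line holds a label followed by sparse `index:value` feature pairs. After the enclave decrypts the data, it has to parse lines like these. The parser below is only a sketch of the format, not Secure XGBoost's internal parser; `parse_libsvm_line` is a hypothetical name, and the sample line is made up.

```python
from typing import Dict, Tuple


def parse_libsvm_line(line: str) -> Tuple[float, Dict[int, float]]:
    """Parse one LibSVM-format line into (label, {feature_index: value})."""
    tokens = line.split()
    label = float(tokens[0])  # first token is the label
    features: Dict[int, float] = {}
    for tok in tokens[1:]:
        idx, val = tok.split(":", 1)  # each remaining token is index:value
        features[int(idx)] = float(val)
    return label, features


# Usage on a made-up line in the same format
label, feats = parse_libsvm_line("1 3:1 10:1 11:1")
```

Features that do not appear on a line are simply absent, which is why this format suits sparse one-hot data such as agaricus.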

## 3. Enclave Setup

In this step, we create an enclave, verify it through remote attestation, and finally give the enclave the key used to encrypt the data.

First, the client initializes its user object with its username, secret symmetric key, private key, and certificate.

In [None]:
xgb.init_client(user_name=username, sym_key_file=KEY_FILE, priv_key_file=PUB_KEY, cert_file=CERT_FILE)

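Meanwhile, the server creates the enclave and launches the Secure XGBoost process inside it, passing the list of participating clients. (Initializing the enclave may take several seconds.)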

In [None]:
xgb.init_server(enclave_image=HOME_DIR + "build/enclave/xgboost_enclave.signed", client_list=[username])

Next, the client uses remote attestation to verify that the enclave has been deployed correctly.

In [None]:
# Remote Attestation

# Pass in `verify=False` if running in simulation mode.
xgb.attest()
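Conceptually, attestation lets the client check two things before releasing its key: that the enclave runs the expected code (its *measurement*), and that the public key the client will encrypt its symmetric key to was generated inside that enclave (a hash of the key is bound into the attested report). The sketch below mimics only that final comparison with `hashlib`; `AttestationReport` and `check_report` are hypothetical, and a real SGX flow must also verify a signature over the report, which this sketch omits entirely.

```python
import hashlib
import hmac
from dataclasses import dataclass


@dataclass
class AttestationReport:
    measurement: bytes      # hypothetical: hash identifying the code loaded into the enclave
    report_data: bytes      # hypothetical: SHA-256 of the enclave's public key, bound into the report
    enclave_pub_key: bytes  # public key the enclave sent alongside the report


def check_report(report: AttestationReport, expected_measurement: bytes) -> bool:
    """Accept only if the enclave code and its public key both match expectations.

    Verifying the signature over the report itself is NOT done here.
    """
    code_ok = hmac.compare_digest(report.measurement, expected_measurement)
    key_hash = hashlib.sha256(report.enclave_pub_key).digest()
    key_ok = hmac.compare_digest(report.report_data, key_hash)
    return code_ok and key_ok


# Usage with made-up values
pub = b"example enclave public key"
expected = hashlib.sha256(b"example enclave image").digest()
good = AttestationReport(expected, hashlib.sha256(pub).digest(), pub)
swapped_key = AttestationReport(expected, hashlib.sha256(pub).digest(), b"attacker key")
```

Binding the key hash into the report is what stops an untrusted server from substituting its own public key and intercepting the client's symmetric key.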

## 4. Data Loading
The enclave is now ready to start the training process. First, load the encrypted data into a `DMatrix` within the enclave.

In [None]:
# Load training data
dtrain = xgb.DMatrix({username: enc_training_data})

In [None]:
# Load test data
dtest = xgb.DMatrix({username: enc_test_data})

## 5. Training
Set the training parameters, and start the training process within the enclave.

In [None]:
# Set parameters
params = {
        "tree_method": "hist",
        "n_gpus": "0",
        "objective": "binary:logistic",
        "min_child_weight": "1",
        "gamma": "0.1",
        "max_depth": "3",
        "verbosity": "1" 
}
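With `"objective": "binary:logistic"`, the model's raw score for a row (the summed tree outputs, or *margin*) is passed through the logistic function, so predictions are probabilities in (0, 1). The sketch below shows that transform; `to_probability` is an illustrative name, and the branch on the sign of the margin only avoids overflow in `math.exp` for very large negative inputs.

```python
import math


def to_probability(margin: float) -> float:
    """Logistic (sigmoid) transform: 1 / (1 + e^(-margin)), computed stably."""
    if margin >= 0:
        return 1.0 / (1.0 + math.exp(-margin))
    z = math.exp(margin)  # margin < 0, so z is in (0, 1) and cannot overflow
    return z / (1.0 + z)


print(to_probability(0.0))  # prints 0.5
```

A margin of 0 maps to 0.5; large positive margins approach 1 and large negative margins approach 0.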

In [None]:
# Train
num_rounds = 5
booster = xgb.train(params, dtrain, num_rounds, evals=[(dtrain, "train"), (dtest, "test")])

## 6. Prediction
Calling `predict()` with `decrypt=False` returns the predictions in encrypted form, along with the number of predictions. The client must decrypt this buffer using the same symmetric key that was used to encrypt the original data.

In [None]:
# Get Encrypted Predictions
enc_preds, num_preds = booster.predict(dtest, decrypt=False)

In [None]:
# Decrypt Predictions
preds = booster.decrypt_predictions(enc_preds, num_preds)
print(preds)
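With the `binary:logistic` objective, the decrypted predictions are probabilities. If the client also holds the test labels, it can threshold them into classes and compute an error rate, following the same definition as XGBoost's built-in `error` metric (predictions above 0.5 count as positive). `error_rate` is a hypothetical helper, and the values in the usage line are made up rather than output from this notebook.

```python
from typing import Sequence


def error_rate(probs: Sequence[float], labels: Sequence[float], threshold: float = 0.5) -> float:
    """Fraction of rows whose thresholded prediction disagrees with the 0/1 label."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    if len(probs) == 0:  # len() also works for NumPy arrays, unlike `not probs`
        raise ValueError("need at least one prediction")
    wrong = sum(1 for p, y in zip(probs, labels) if (p > threshold) != (y > 0.5))
    return wrong / len(probs)


# Usage with made-up probabilities and labels: one of four rows is misclassified
example = error_rate([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 0])
```

Because this runs on the client after decryption, the labels and predictions never leave the trusted machine in plaintext.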