# Training Demo

This notebook demonstrates the use of the Confidential Training tool. It requires the [avato Training API](https://github.com/decentriq/avato-python-client-training) and its dependencies to be installed.  

Note that in a realistic, non-demo use of the Confidential Training tool, one analyst user and multiple dataowner users would each work from their own computer. In this notebook, for simplicity, the workflows of the analyst user and the two dataowner users are shown together.

### Import dependencies

In [95]:
import os
import pandas as pd
import json
from avato import Client
from avato import Secret
from avato_training import Training_Instance, Configuration
import numpy as np

analyst_username = os.getenv('ANALYST_ID')
analyst_***REMOVED*** = os.getenv('ANALYST_PASSWORD')

dataowner1_username = os.getenv('DATAOWNER1_ID')
dataowner1_***REMOVED*** = os.getenv('DATAOWNER1_PASSWORD')

dataowner2_username = os.getenv('DATAOWNER2_ID')
dataowner2_***REMOVED*** = os.getenv('DATAOWNER2_PASSWORD')

# Expected hash (measurement) of the enclave code, checked during fatquote validation
expected_measurement = "4ff505f350698c78e8b3b49b8e479146ce3896a06cd9e5109dfec8f393f14025"

# How the analyst expects the data to be formatted
feature_columns = ['fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar', 'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density', 'pH', 'sulphates', 'alcohol']
label_column = "quality"

# The data files uploaded by the dataowners
dataowner1_file = "test-data/wine-dataowner1.csv"
dataowner2_file = "test-data/wine-dataowner2.csv"

backend_host = "localhost" 
backend_port = 3000 
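
Because `os.getenv` returns `None` for an unset variable, a missing credential would only surface later, when the client is created. A minimal sketch of an early check follows; the helper name `missing_env` is hypothetical and not part of the avato client.

```python
import os

def missing_env(names, environ=os.environ):
    """Return the names that are unset or empty in the given environment mapping."""
    return [name for name in names if not environ.get(name)]

# Environment variable names used by this notebook.
required = [
    "ANALYST_ID", "ANALYST_PASSWORD",
    "DATAOWNER1_ID", "DATAOWNER1_PASSWORD",
    "DATAOWNER2_ID", "DATAOWNER2_PASSWORD",
]
print("Missing environment variables: {}".format(missing_env(required)))
```

An empty list means all credentials are available to the cells that follow.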

### ANALYST USER
#### Create new instance

In [96]:
# Create client.
analyst_client = Client(
    username=analyst_username,
    ***REMOVED***=analyst_***REMOVED***,
    instance_types=[Training_Instance],
    backend_host=backend_host,
    backend_port=backend_port
)

# Spin up an instance. Set who can participate in the instance.
analyst_instance = analyst_client.create_instance(
    "Training", 
    Training_Instance.type, 
    [dataowner1_username, dataowner2_username]
)
print("Instance ID: {}".format(analyst_instance.id))

Instance ID: 6cc8d512-0288-4fb8-a1d0-34e506f99ed6


#### Check security guarantees
This step fetches and validates the so-called fatquote, the cryptographic proof provided by the enclave. It is crucial for all security guarantees:

* i)   It proves that the remote party is a valid SGX enclave (by checking a certificate).
* ii)  It compares the hash of the enclave code to the expected value
     provided by the user (to verify what code is running in the enclave).
* iii) It transmits a public key as part of the proof, which allows
     establishing a secure connection into the enclave (as the private
     key is only known to the enclave).
     
As we are using a non-production environment, we whitelist the debug and out_of_date flags.

In [97]:
analyst_instance.validate_fatquote(
    expected_measurement=expected_measurement,
    accept_debug=True,
    accept_group_out_of_date=True
)

The quote is part of the fatquote and provides a detailed fingerprint of the program and state of the remote machine. For example:
* using `flags` we can detect if the CPU is running in untrusted debug mode
* using `*_svn` we can verify if all security patches have been deployed to the infrastructure
* using `mrenclave` we can attest to the exact program being executed on the remote machine

In [98]:
# Uncomment to inspect 
# print(analyst_instance.quote)
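
The checks listed above can be sketched as a small policy function. This is an illustration only, not the avato client's implementation: the dictionary layout, the field name `isv_svn`, the `min_isv_svn` threshold, and the function `check_quote_fields` are all hypothetical. The debug bit value (`0x2` in the SGX enclave attribute flags) is the one assumption taken from the SGX definitions.

```python
import hmac

SGX_FLAGS_DEBUG = 0x2  # debug bit in the SGX enclave attribute flags

def check_quote_fields(quote, expected_mrenclave, min_isv_svn, accept_debug=False):
    """Return a list of policy violations for a parsed quote (hypothetical dict layout)."""
    problems = []
    # Compare the code measurement against the expected hash in constant time.
    if not hmac.compare_digest(quote["mrenclave"].lower(), expected_mrenclave.lower()):
        problems.append("mrenclave mismatch")
    # Reject debug-mode enclaves unless explicitly accepted.
    if (quote["flags"] & SGX_FLAGS_DEBUG) and not accept_debug:
        problems.append("enclave runs in debug mode")
    # A security version below the required minimum indicates missing patches.
    if quote["isv_svn"] < min_isv_svn:
        problems.append("security version too old")
    return problems

example_quote = {"mrenclave": "ab" * 32, "flags": 0x7, "isv_svn": 2}
print(check_quote_fields(example_quote, "ab" * 32, min_isv_svn=2, accept_debug=True))
```

An empty list means the quote passed every check under the chosen policy.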

#### Configure instance

In [99]:
# Set the configuration
configuration = Configuration(
    feature_columns=feature_columns,
    label_column=label_column,
    ***REMOVED***=analyst_***REMOVED*** # the ***REMOVED*** to execute
)

# Create and set public-private keypair for secure communication.
analyst_secret = Secret()
analyst_instance.set_secret(analyst_secret)

# Upload
analyst_instance.upload_configuration(configuration)

### DATAOWNERS

In [100]:
# Submit a data file to the instance on behalf of a given dataowner.
def dataowner_submit_data(dataowner_username, dataowner_***REMOVED***, instance_id, data_file):

    # Create client
    dataowner_client = Client(
        username=dataowner_username,
        ***REMOVED***=dataowner_***REMOVED***,
        instance_types=[Training_Instance],
        backend_host=backend_host,
        backend_port=backend_port
    )

    # Connect to instance (using ID from the analyst user)
    dataowner_instance = dataowner_client.get_instance(instance_id)

    # Check security guarantees.
    dataowner_instance.validate_fatquote(
        expected_measurement=expected_measurement,
        accept_debug=True,
        accept_group_out_of_date=True
    )

    # Create and set public-private keypair for secure communication.
    dataowner_secret = Secret()
    dataowner_instance.set_secret(dataowner_secret)

    # Get data format from the enclave
    data_format = dataowner_instance.get_data_format()
    print("Data format:\n{}".format(data_format))

    # Load data
    df = pd.read_csv(data_file)
    
    print("Loaded data:\n")
    print(df)

    # Submit data
    (ingested_rows, failed_rows) = dataowner_instance.submit_data(df)
    print("Number of successfully ingested rows: {}, number of failed rows: {}".format(ingested_rows, failed_rows))
    
    return dataowner_instance
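
Before submitting, a dataowner may want to confirm locally that the file matches the format the analyst configured. The following is a minimal sketch under the assumption that the expected column names are known to the dataowner as plain lists; `check_columns` is a hypothetical helper, not part of the avato client, and the tiny DataFrame is made up for illustration.

```python
import pandas as pd

def check_columns(df, feature_columns, label_column):
    """Return (missing, unexpected) column names relative to the expected format."""
    expected = list(feature_columns) + [label_column]
    missing = [c for c in expected if c not in df.columns]
    unexpected = [c for c in df.columns if c not in expected]
    return missing, unexpected

# Illustrative data only: one expected column is absent and one extra column is present.
df = pd.DataFrame({"pH": [3.2], "alcohol": [9.8], "quality": [6], "vintage": [2015]})
missing, unexpected = check_columns(df, ["pH", "alcohol", "density"], "quality")
print("missing: {}, unexpected: {}".format(missing, unexpected))
```

Catching a mismatch locally is cheaper than discovering it through failed rows after upload.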

#### dataowner 1 - Submit data

In [101]:
dataowner1_instance = dataowner_submit_data(
    dataowner1_username, 
    dataowner1_***REMOVED***, 
    analyst_instance.id, 
    data_file=dataowner1_file
)

Data format:
categoriesColumns: "fixed acidity"
categoriesColumns: "volatile acidity"
categoriesColumns: "citric acid"
categoriesColumns: "residual sugar"
categoriesColumns: "chlorides"
categoriesColumns: "free sulfur dioxide"
categoriesColumns: "total sulfur dioxide"
categoriesColumns: "density"
categoriesColumns: "pH"
categoriesColumns: "sulphates"
categoriesColumns: "alcohol"
valueColumn: "quality"

Loaded data:

      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \
0               7.2              0.23         0.32            8.50      0.058   
1               6.2              0.32         0.16            7.00      0.045   
2               8.1              0.22         0.43            1.50      0.044   
3               8.3              0.42         0.62           19.25      0.040   
4               6.6              0.17         0.38            1.50      0.032   
...             ...               ...          ...             ...        ...   
2478         

#### dataowner 2 - Submit data

In [102]:
dataowner2_instance = dataowner_submit_data(
    dataowner2_username, 
    dataowner2_***REMOVED***, 
    analyst_instance.id, 
    dataowner2_file
)

Data format:
categoriesColumns: "fixed acidity"
categoriesColumns: "volatile acidity"
categoriesColumns: "citric acid"
categoriesColumns: "residual sugar"
categoriesColumns: "chlorides"
categoriesColumns: "free sulfur dioxide"
categoriesColumns: "total sulfur dioxide"
categoriesColumns: "density"
categoriesColumns: "pH"
categoriesColumns: "sulphates"
categoriesColumns: "alcohol"
valueColumn: "quality"

Loaded data:

      fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \
0               7.0             0.270         0.36           20.70      0.045   
1               6.3             0.300         0.34            1.60      0.049   
2               8.1             0.280         0.40            6.90      0.050   
3               7.2             0.230         0.32            8.50      0.058   
4               8.1             0.280         0.40            6.90      0.050   
...             ...               ...          ...             ...        ...   
2410         

### ANALYST USER
#### Train with first set of hyperparameters

In [103]:
hyperparameters = {
    "learning_rate": 0.1,
    "num_splits": 10,
    "num_epochs": 5,
    "l2_penalty": 0.0,
    "l1_penalty": 0.0,
}

analyst_instance.start_execution(analyst_***REMOVED***, hyperparameters)

classifier, metadata = analyst_instance.get_results(analyst_***REMOVED***)
print("metadata: {}".format(json.dumps(metadata, indent=2)))

metadata: {
  "CV Test Accuracy": "0.4860941",
  "CV Train Accuracy": "0.48822862",
  "Fullset Accuracy": "0.48591262"
}


#### Train with second set of hyperparameters

In [104]:
hyperparameters = {
    "learning_rate": 1.0,
    "num_splits": 10,
    "num_epochs": 5,
    "l2_penalty": 0.0,
    "l1_penalty": 0.0,
}

analyst_instance.start_execution(analyst_***REMOVED***, hyperparameters)

classifier, metadata = analyst_instance.get_results(analyst_***REMOVED***)
print("metadata: {}".format(json.dumps(metadata, indent=2)))

metadata: {
  "CV Test Accuracy": "0.47443762",
  "CV Train Accuracy": "0.47695622",
  "Fullset Accuracy": "0.48040017"
}
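
Each call to `get_results` overwrites `classifier` and `metadata`, so comparing hyperparameter sets requires keeping the results of every run. A minimal sketch of such a comparison is shown here, using the cross-validation test accuracies printed above; the `runs` list and the `best_run` helper are hypothetical and not part of the Training API.

```python
# (hyperparameters, metadata) pairs; values copied from the two runs above.
runs = [
    ({"learning_rate": 0.1}, {"CV Test Accuracy": "0.4860941"}),
    ({"learning_rate": 1.0}, {"CV Test Accuracy": "0.47443762"}),
]

def best_run(runs, metric="CV Test Accuracy"):
    """Return the (hyperparameters, metadata) pair with the highest metric."""
    # Metadata values arrive as strings, so convert before comparing.
    return max(runs, key=lambda run: float(run[1][metric]))

best_params, best_metadata = best_run(runs)
print("Best hyperparameters: {}".format(best_params))
```

Selecting on the cross-validation test accuracy rather than the full-set accuracy avoids favoring a model that merely fits the training data better.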


#### Use classifier, compute accuracy on full dataset, compare with metadata results

In [48]:
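def load_data():
    # Combine both dataowners' files into a single array.
    Xy = np.array(
        pd.concat([
            pd.read_csv(dataowner1_file),
            pd.read_csv(dataowner2_file)
        ])
    )
    # All columns except the last are features; the last column is the label.
    X = Xy[:, 0:-1]
    y = Xy[:, -1]
    return X, y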

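def compute_accuracy(classifier, X, y):
    # Fraction of rows where the predicted label equals the true label.
    y_hat = classifier.predict(X)
    assert len(y) == len(y_hat)
    n = len(y)
    n_eq = 0
    for yi, yi_hat in zip(y, y_hat):
        # Compare as floats because the classifier returns labels as strings.
        if float(yi) == float(yi_hat):
            n_eq += 1
    accuracy = float(n_eq) / n
    return accuracy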


X, y = load_data()
accuracy = compute_accuracy(classifier, X, y)

print("Some predictions of the classifier: {}".format(classifier.predict(X[0:2, :])))
print("Accuracy of the enclave classifier on the full dataset (as returned through the metadata object): {}".format(metadata["Fullset Accuracy"]))
print("Accuracy of the local classifier on the full dataset: {}".format(accuracy))




Some predictions of the classifier: ['6' '5']
Accuracy of the enclave classifier on the full dataset (as returned through the metadata object): 0.5151082
Accuracy of the local classifier on the full dataset: 0.5151082074316048
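
The loop in `compute_accuracy` can also be written as a vectorized NumPy comparison. This is a sketch of an equivalent computation, assuming (as the predictions above suggest) that labels may arrive as strings and must be converted to floats first; `accuracy_vectorized` is a hypothetical name.

```python
import numpy as np

def accuracy_vectorized(y_true, y_pred):
    """Fraction of positions where the true and predicted labels match."""
    # Convert both inputs to float arrays so that '6' and 6.0 compare equal.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    return float(np.mean(y_true == y_pred))

# Three of the four illustrative labels match.
print(accuracy_vectorized([6, 5, 7, 5], ["6", "5", "5", "5"]))
```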


### ANALYST USER - Clean Up

In [49]:
analyst_instance.shutdown()
analyst_instance.delete()
assert analyst_instance.id not in analyst_client.get_instances()