# Confidential ML Training Demo (Analyst)

This notebook is the Analyst part of the *Confidential ML Training Demo*, which shows how a simple logistic regression classifier can be trained while the training data stays provably confidential. The demo requires the [Training Client API](https://github.com/decentriq/avato-python-client-training) and its dependencies to be installed.

Note that although this demo trains a logistic regression classifier, the enclave can be used to train a variety of other classifiers.

## 1 - Import dependencies and set parameters

In [12]:
import os
import pandas as pd
import json
from avato import Client
from avato import Secret
from avato_training import Training_Instance, Configuration
import numpy as np
import example

analyst_username, analyst_***REMOVED*** = example.analyst_credentials

# The analyst needs these to control who can upload data
dataowner1_username = os.getenv('DATAOWNER1_ID')
dataowner2_username = os.getenv('DATAOWNER2_ID')

# Expected hash (measurement) of the enclave code, checked during fatquote validation
expected_measurement = "4ff505f350698c78e8b3b49b8e479146ce3896a06cd9e5109dfec8f393f14025"

# How the analyst expects the data to be formatted
feature_columns = ['fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar', 'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density', 'pH', 'sulphates', 'alcohol']
label_column = "quality"

# This points to the confidential computing system
backend_host = "localhost" 
backend_port = 3000
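
`os.getenv` returns `None` when a variable is unset, so a missing `DATAOWNER1_ID` or `DATAOWNER2_ID` would only surface later, when the instance is created. A minimal sketch of a fail-fast lookup follows; the helper name `require_env` is hypothetical and not part of the `avato` client.

```python
import os


def require_env(name, environ=os.environ):
    """Return the value of an environment variable, or raise if it is unset or empty."""
    value = environ.get(name)
    if not value:
        raise RuntimeError("Environment variable {} is not set".format(name))
    return value


# Stand-in mapping so the sketch runs without touching the real environment.
demo_env = {"DATAOWNER1_ID": "dataowner1@example.com"}
print(require_env("DATAOWNER1_ID", demo_env))
```

In the notebook, `require_env('DATAOWNER1_ID')` could take the place of `os.getenv('DATAOWNER1_ID')` if an early error is preferred.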

## 2 - Set up instance
### Create new instance

In [13]:
# Create client.
analyst_client = Client(
    username=analyst_username,
    ***REMOVED***=analyst_***REMOVED***,
    instance_types=[Training_Instance],
    backend_host=backend_host,
    backend_port=backend_port
)

# Spin up an instance. Set who can participate in the instance.
analyst_instance = analyst_client.create_instance(
    "Training", 
    Training_Instance.type, 
    [dataowner1_username, dataowner2_username]
)
print("Instance ID: {}".format(analyst_instance.id))

Instance ID: e74d2ed9-a719-45f3-a237-20951eb54c85


#### Verify security
This step fetches and validates the so-called fatquote, the cryptographic proof issued by the enclave.
All security guarantees depend on this validation:

* i)   It proves that the enclave is a valid SGX enclave (by checking a certificate).
* ii)  It compares the hash of the enclave code to an expected value
     provided by the user (to verify what code is running in the enclave).
* iii) The proof also carries a public key that allows establishing a
     secure connection into the enclave (the matching private key is
     known only to the enclave).
     
As we are using a non-production environment, we whitelist the debug and out_of_date flags.

In [14]:
analyst_instance.validate_fatquote(
    expected_measurement=expected_measurement,
    accept_debug=True,
    accept_group_out_of_date=True
)
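
Step ii) above boils down to comparing two digests. The sketch below illustrates that comparison only; it is not how `validate_fatquote` is implemented, and the function name `measurement_matches` and the sample byte strings are made up for illustration. It assumes measurements are exchanged as hex strings, as `expected_measurement` is in this notebook.

```python
import hashlib
import hmac


def measurement_matches(reported_hex, expected_hex):
    """Compare two hex-encoded measurements in constant time."""
    reported = bytes.fromhex(reported_hex)
    expected = bytes.fromhex(expected_hex)
    return hmac.compare_digest(reported, expected)


# Illustrative digests; a real measurement comes from the enclave's proof.
expected = hashlib.sha256(b"enclave code v1").hexdigest()
tampered = hashlib.sha256(b"enclave code v2").hexdigest()

print(measurement_matches(expected, expected))
print(measurement_matches(tampered, expected))
```

Decoding to bytes before comparing makes the check insensitive to upper or lower case hex, and `hmac.compare_digest` avoids leaking how many leading bytes matched.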

#### Configure instance

In [15]:
# Set the configuration
configuration = Configuration(
    feature_columns=feature_columns,
    label_column=label_column,
    ***REMOVED***=analyst_***REMOVED*** # the ***REMOVED*** to execute
)

# Create and set public-private keypair for secure communication.
analyst_secret = Secret()
analyst_instance.set_secret(analyst_secret)

# Upload the configuration to the instance
analyst_instance.upload_configuration(configuration)

## 3 - Train model (after data has been uploaded by the Dataowners)
### Train with first set of hyperparameters
This returns the trained classifier and its metadata.

In [16]:
hyperparameters = {
    "learning_rate": 0.1,
    "num_splits": 10,
    "num_epochs": 5,
    "l2_penalty": 0.0,
    "l1_penalty": 0.0,
}

analyst_instance.start_execution(analyst_***REMOVED***, hyperparameters)

classifier, metadata = analyst_instance.get_results(analyst_***REMOVED***)
print("metadata: {}".format(json.dumps(metadata, indent=2)))

BadRequestError: b'EnclaveCommunicationError call=PostMessage message=2 UNKNOWN: Stream removed'

### Train with second set of hyperparameters

In [None]:
hyperparameters = {
    "learning_rate": 1.0,
    "num_splits": 10,
    "num_epochs": 5,
    "l2_penalty": 0.0,
    "l1_penalty": 0.0,
}

analyst_instance.start_execution(analyst_***REMOVED***, hyperparameters)

classifier, metadata = analyst_instance.get_results(analyst_***REMOVED***)
print("metadata: {}".format(json.dumps(metadata, indent=2)))
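
The two runs differ only in `learning_rate`. To get a feel for what that parameter does, here is a local sketch of logistic regression trained by full-batch gradient descent on synthetic data. It is an assumption about the general technique, not the enclave's actual training code: the function names and the toy dataset are invented here, and `num_splits` and `l1_penalty` are left out because the notebook does not describe how the enclave uses them.

```python
import numpy as np


def train_logistic_regression(X, y, learning_rate, num_epochs, l2_penalty=0.0):
    """Full-batch gradient descent on the mean log-loss; returns (weights, bias)."""
    n_samples, n_features = X.shape
    weights = np.zeros(n_features)
    bias = 0.0
    for _ in range(num_epochs):
        logits = X @ weights + bias
        probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        error = probs - y
        grad_w = X.T @ error / n_samples + l2_penalty * weights
        grad_b = error.mean()
        weights -= learning_rate * grad_w
        bias -= learning_rate * grad_b
    return weights, bias


def log_loss(X, y, weights, bias):
    """Mean log-loss, written as log(1 + exp(z)) - y * z for numerical stability."""
    logits = X @ weights + bias
    return float(np.mean(np.logaddexp(0.0, logits) - y * logits))


# Synthetic, linearly separable toy data (not the wine-quality dataset).
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(200, 2))
y_toy = (X_toy[:, 0] + X_toy[:, 1] > 0).astype(float)

for lr in (0.1, 1.0):
    w, b = train_logistic_regression(X_toy, y_toy, learning_rate=lr, num_epochs=5)
    print("learning_rate={}: loss={:.4f}".format(lr, log_loss(X_toy, y_toy, w, b)))
```

With only five epochs, the step size largely determines how far the weights move from their zero initialization, which is why comparing the metadata of both runs is informative.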

## 4 - Use the classifier
### Use classifier, compute accuracy on full dataset, compare with metadata results

In [None]:
X, y = example.load_data()
accuracy = example.compute_accuracy(classifier, X, y)

print("Some predictions of the classifier: {}".format(classifier.predict(X[0:2, :])))
print("Accuracy of the enclave classifier on the full dataset (as returned through the metadata object): {}".format(metadata["Fullset Accuracy"]))
print("Accuracy of the local classifier on the full dataset: {}".format(accuracy))
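
The body of `example.compute_accuracy` is not shown in this notebook. A plausible minimal equivalent, given predictions that have already been computed, is the fraction of predictions that equal the labels. The function name `accuracy_score` below is a stand-in for illustration, not the helper's real implementation.

```python
import numpy as np


def accuracy_score(predictions, labels):
    """Fraction of predictions that equal the corresponding labels."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    if predictions.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    return float(np.mean(predictions == labels))


# Three of the four predictions match the labels.
print(accuracy_score([1, 0, 1, 1], [1, 0, 0, 1]))
```

Under that assumption, `accuracy_score(classifier.predict(X), y)` would be expected to agree with `metadata["Fullset Accuracy"]` when both are evaluated on the same full dataset.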

## 5 - Clean Up

In [None]:
analyst_instance.shutdown()
analyst_instance.delete()
assert analyst_instance.id not in analyst_client.get_instances()