# Training Demo

This notebook demonstrates the use of the Confidential Training tool. It requires the [avato Training API](https://github.com/decentriq/avato-python-client-training) and its dependencies to be installed.  

Note that in a realistic, non-demo use of the Confidential Training tool, one admin user and multiple participant users would each work with the instance from their own computer. In this notebook, for simplicity, the workflows of the admin user and the two participant users are all shown together.

### Import dependencies

In [13]:
import os
import pandas as pd
import json
from avato import Client
from avato import Secret
from avato_training import Training_Instance, Configuration

admin_username = "***REMOVED***"
admin_***REMOVED*** = "***REMOVED***"

participant1_username = "***REMOVED***"
participant1_***REMOVED*** = "***REMOVED***"

participant2_username = "***REMOVED***"
participant2_***REMOVED*** = "***REMOVED***"

# Expected measurement (MRENCLAVE): the hash of the code that should be running in the enclave
expected_measurement = "4ff505f350698c78e8b3b49b8e479146ce3896a06cd9e5109dfec8f393f14025"

backend_host = "localhost" 
backend_port = 3000 
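Rather than hard-coding credentials in the notebook, they could be read from environment variables. A minimal sketch; the naming scheme (`<PREFIX>_USERNAME`, `<PREFIX>_CREDENTIAL`) and the helper `load_credentials` are hypothetical and not part of the avato client.

```python
import os


def load_credentials(prefix):
    """Return a (username, credential) pair read from environment variables.

    Hypothetical convention: <PREFIX>_USERNAME and <PREFIX>_CREDENTIAL.
    Raises KeyError if either variable is not set.
    """
    username = os.environ[f"{prefix}_USERNAME"]
    credential = os.environ[f"{prefix}_CREDENTIAL"]
    return username, credential
```

For example, `admin_username, admin_credential = load_credentials("AVATO_ADMIN")` after exporting the two variables in the shell that starts Jupyter.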

### ADMIN USER - Set Up Instance
#### Create new instance

In [2]:
# Create client.
admin_client = Client(
    username=admin_username,
    ***REMOVED***=admin_***REMOVED***,
    instance_types=[Training_Instance],
    backend_host=backend_host,
    backend_port=backend_port
)

# Spin up an instance and specify which users may participate in it.
admin_instance = admin_client.create_instance(
    "Training", 
    Training_Instance.type, 
    [participant1_username, participant2_username]
)
print("Instance ID: {}".format(admin_instance.id))

Instance ID: 7cc0d5b9-6ca0-4643-8d96-673d89b2be56


#### Check security guarantees
Validate the so-called fatquote. This step is crucial: all security guarantees depend on it.
It retrieves the cryptographic proof from the enclave and validates it:

* i)   It proves that the enclave is a genuine SGX enclave (by checking a certificate).
* ii)  It compares the hash of the code running in the enclave (its measurement)
     against the expected value supplied by the user, which verifies what code
     is running in the enclave.
* iii) The proof also carries a public key for establishing a secure connection
     into the enclave; the matching private key is known only to the enclave.

Because we are using a non-production environment, we accept the debug and group-out-of-date flags (`accept_debug=True`, `accept_group_out_of_date=True`).
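The acceptance policy expressed by the `expected_measurement`, `accept_debug` and `accept_group_out_of_date` arguments can be sketched in plain Python. This illustrates the idea only and is not avato's implementation: the function name and its inputs (a hex measurement, the SGX attribute flags as an integer, and a boolean for the group-out-of-date status) are assumptions.

```python
import hmac

# Bit 1 of the SGX enclave attribute flags marks a debug enclave
# (SGX_FLAGS_DEBUG = 0x02 in the Intel SGX headers).
SGX_FLAGS_DEBUG = 0x02


def quote_is_acceptable(
    reported_mrenclave_hex,
    reported_flags,
    group_out_of_date,
    expected_measurement,
    accept_debug=False,
    accept_group_out_of_date=False,
):
    """Illustrative acceptance policy for an attestation result (hypothetical)."""
    # (ii) The reported code hash must equal the expected one exactly;
    # compare_digest avoids timing leaks about where a mismatch occurs.
    if not hmac.compare_digest(
        reported_mrenclave_hex.lower(), expected_measurement.lower()
    ):
        return False
    # A debug enclave's memory can be inspected, so reject it unless allowed.
    if (reported_flags & SGX_FLAGS_DEBUG) and not accept_debug:
        return False
    # Platform lacks the latest security updates (e.g. a GROUP_OUT_OF_DATE
    # attestation status): reject unless explicitly tolerated.
    if group_out_of_date and not accept_group_out_of_date:
        return False
    return True
```

With the demo settings, a debug enclave on an out-of-date platform passes as long as the measurement matches; with the defaults, either condition alone causes rejection.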

In [3]:
admin_instance.validate_fatquote(
    expected_measurement=expected_measurement,
    accept_debug=True,
    accept_group_out_of_date=True
)

The quote is part of the fatquote and provides a detailed fingerprint of the program and state of the remote machine. For example:
* using `flags`, we can detect whether the enclave is running in untrusted debug mode
* using `*_svn` (security version numbers), we can verify whether all security patches have been deployed to the infrastructure
* using `mrenclave`, we can attest to the exact program being executed on the remote machine
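To make the `flags` and `*_svn` bullets concrete, here is a sketch of decoding the low bits of the SGX attribute flags and of comparing security version numbers against a minimum. The bit values are the standard `SGX_FLAGS_*` constants from the Intel SGX headers (only a subset is listed); the dictionary keys passed to `svn_meets_minimum` are hypothetical, since the structure of `admin_instance.quote` is not shown here.

```python
# Subset of the Intel SGX attribute flag bits.
SGX_FLAG_NAMES = {
    0x01: "INITTED",
    0x02: "DEBUG",
    0x04: "MODE64BIT",
}


def decode_flags(flags):
    """List the names of the known flag bits set in `flags`, lowest bit first."""
    return [name for bit, name in sorted(SGX_FLAG_NAMES.items()) if flags & bit]


def svn_meets_minimum(reported, minimum):
    """True if every integer SVN in `minimum` is present in `reported` and not lower.

    Both arguments map hypothetical SVN names (e.g. "isv_svn") to integers.
    """
    return all(reported.get(key, -1) >= value for key, value in minimum.items())
```

A production check would typically require that `"DEBUG"` is absent from `decode_flags(...)`; in this demo it is tolerated.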

In [4]:
# Uncomment to inspect the raw quote
# print(admin_instance.quote)

#### Configure instance

In [5]:
# Set the configuration
configuration = Configuration(
    categories_columns=["x_1", "x_2", "x_3", "x_4"],
    value_column="y",
    ***REMOVED***=admin_***REMOVED*** # the ***REMOVED*** to execute
)

# Create and set public-private keypair for secure communication.
admin_secret = Secret()
admin_instance.set_secret(admin_secret)

# Upload the configuration to the instance
admin_instance.upload_configuration(configuration)

### PARTICIPANT USERS - Submit Data

In [6]:
# Submit a data file to the instance on behalf of the given participant.
def participant_submit_data(participant_username, participant_***REMOVED***, instance_id, data_file):

    # Create client
    participant_client = Client(
        username=participant_username,
        ***REMOVED***=participant_***REMOVED***,
        instance_types=[Training_Instance],
        backend_host=backend_host,
        backend_port=backend_port
    )

    # Connect to instance (using ID from the admin user)
    participant_instance = participant_client.get_instance(instance_id)

    # Check security guarantees.
    participant_instance.validate_fatquote(
        expected_measurement=expected_measurement,
        accept_debug=True,
        accept_group_out_of_date=True
    )

    # Create and set public-private keypair for secure communication.
    participant_secret = Secret()
    participant_instance.set_secret(participant_secret)

    # Get data format from the enclave
    data_format = participant_instance.get_data_format()
    print("Data format:\n{}".format(data_format))

    # Load data
    df = pd.read_csv(data_file)
    
    print("Loaded data:\n")
    print(df)

    # Submit data
    (ingested_rows, failed_rows) = participant_instance.submit_data(df)
    print("Number of successfully ingested rows: {}, number of failed rows: {}".format(ingested_rows, failed_rows))
    
    return participant_instance
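Before calling `submit_data`, a participant could check locally which rows are likely to be rejected. A minimal pandas sketch, assuming (this is not confirmed by the avato documentation) that the enclave expects every configured column to be present and numeric, with no missing values; `precheck_rows` is a hypothetical helper.

```python
import pandas as pd


def precheck_rows(df, categories_columns, value_column):
    """Return (clean_rows, bad_row_indices) for the configured columns.

    Raises ValueError if a configured column is missing entirely.
    """
    expected = list(categories_columns) + [value_column]
    missing = [c for c in expected if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    # Coerce to numbers; anything unparsable becomes NaN.
    subset = df[expected].apply(pd.to_numeric, errors="coerce")
    bad = subset.isna().any(axis=1)
    return subset[~bad], df.index[bad.to_numpy()].tolist()
```

For this demo the call would be `precheck_rows(df, ["x_1", "x_2", "x_3", "x_4"], "y")`, mirroring the `Configuration` set by the admin.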

#### PARTICIPANT 1 - Submit data

In [7]:
participant1_instance = participant_submit_data(
    participant1_username, 
    participant1_***REMOVED***, 
    admin_instance.id, 
    "./test-data/participant1_data.csv"
)

Data format:
categoriesColumns: "x_1"
categoriesColumns: "x_2"
categoriesColumns: "x_3"
categoriesColumns: "x_4"
valueColumn: "y"

Loaded data:

    x_1  x_2  x_3  x_4  y
0   5.1  3.5  1.4  0.2  0
1   4.9  3.0  1.4  0.2  0
2   4.7  3.2  1.3  0.2  0
3   4.6  3.1  1.5  0.2  0
4   5.0  3.6  1.4  0.2  0
..  ...  ...  ...  ... ..
70  5.9  3.2  4.8  1.8  1
71  6.1  2.8  4.0  1.3  1
72  6.3  2.5  4.9  1.5  1
73  6.1  2.8  4.7  1.2  1
74  6.4  2.9  4.3  1.3  1

[75 rows x 5 columns]
Number of successfully ingested rows: 75, number of failed rows: []


#### PARTICIPANT 2 - Submit data

In [8]:
participant2_instance = participant_submit_data(
    participant2_username, 
    participant2_***REMOVED***, 
    admin_instance.id, 
    "./test-data/participant2_data.csv"
)

Data format:
categoriesColumns: "x_1"
categoriesColumns: "x_2"
categoriesColumns: "x_3"
categoriesColumns: "x_4"
valueColumn: "y"

Loaded data:

    x_1  x_2  x_3  x_4  y
0   6.6  3.0  4.4  1.4  1
1   6.8  2.8  4.8  1.4  1
2   6.7  3.0  5.0  1.7  1
3   6.0  2.9  4.5  1.5  1
4   5.7  2.6  3.5  1.0  1
..  ...  ...  ...  ... ..
70  6.7  3.0  5.2  2.3  2
71  6.3  2.5  5.0  1.9  2
72  6.5  3.0  5.2  2.0  2
73  6.2  3.4  5.4  2.3  2
74  5.9  3.0  5.1  1.8  2

[75 rows x 5 columns]
Number of successfully ingested rows: 75, number of failed rows: []
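The printed head and tail rows suggest that each participant holds only part of the label range (participant 1 shows `y` values 0 and 1, participant 2 shows 1 and 2), which is the situation where training on the pooled data pays off. A small pandas helper to check label coverage on the full local files; `label_coverage` is hypothetical and reads only local data, nothing is sent to the enclave.

```python
import pandas as pd


def label_coverage(frames, value_column):
    """Map each participant name to the sorted list of labels in its data."""
    return {
        name: sorted(df[value_column].unique().tolist())
        for name, df in frames.items()
    }
```

For example, `label_coverage({"participant1": pd.read_csv("./test-data/participant1_data.csv"), "participant2": pd.read_csv("./test-data/participant2_data.csv")}, "y")`.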


### ADMIN USER - Start Execution

In [9]:
admin_instance.start_execution(admin_***REMOVED***)

### PARTICIPANT USERS - Get Results

#### PARTICIPANT 1 - Get results

In [16]:
results1 = participant1_instance.get_results()
print("Result participant1:\n{}".format(json.dumps(results1, indent=2)))

Result participant1:
{
  "model 0 - param (0,0)": 0.9491353034973145,
  "model 0 - param (0,1)": 2.4495279788970947,
  "model 0 - param (0,2)": -3.6021242141723633,
  "model 0 - param (0,3)": -3.77264666557312,
  "model 1 - param (0,0)": 0.24731068313121796,
  "model 1 - param (0,1)": -1.132555365562439,
  "model 1 - param (0,2)": 0.8474775552749634,
  "model 1 - param (0,3)": -1.7091010808944702,
  "model 2 - param (0,0)": -1.8889198303222656,
  "model 2 - param (0,1)": -2.451262950897217,
  "model 2 - param (0,2)": 2.4787139892578125,
  "model 2 - param (0,3)": 4.074766635894775
}
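The results arrive as a flat dictionary whose keys encode a model index and a parameter position. A minimal sketch that regroups them into one ordered list of parameters per model, assuming the key format stays exactly as printed (`model <m> - param (<i>,<j>)`); `results_by_model` is a hypothetical helper. The shape (three models with four parameters each) lines up with the four `x_*` columns and the three `y` values in the data, but the notebook does not say what kind of model is trained, so any interpretation of the numbers is an assumption.

```python
import re

# Key format observed in the printed results, e.g. "model 0 - param (0,3)".
_KEY = re.compile(r"model (\d+) - param \((\d+),(\d+)\)")


def results_by_model(results):
    """Return {model_index: [param values ordered by (i, j)]}."""
    grouped = {}
    for key, value in results.items():
        match = _KEY.fullmatch(key)
        if match is None:
            raise ValueError(f"unexpected result key: {key!r}")
        model, i, j = (int(g) for g in match.groups())
        grouped.setdefault(model, []).append(((i, j), value))
    # Sort each model's parameters by their (i, j) position.
    return {model: [v for _, v in sorted(params)] for model, params in sorted(grouped.items())}
```

Applied to `results1`, this would give three lists of four values each.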


#### PARTICIPANT 2 - Get results

In [17]:
results2 = participant2_instance.get_results()
print("Result participant2:\n{}".format(json.dumps(results2, indent=2)))

Result participant2:
{
  "model 0 - param (0,0)": 0.9491353034973145,
  "model 0 - param (0,1)": 2.4495279788970947,
  "model 0 - param (0,2)": -3.6021242141723633,
  "model 0 - param (0,3)": -3.77264666557312,
  "model 1 - param (0,0)": 0.24731068313121796,
  "model 1 - param (0,1)": -1.132555365562439,
  "model 1 - param (0,2)": 0.8474775552749634,
  "model 1 - param (0,3)": -1.7091010808944702,
  "model 2 - param (0,0)": -1.8889198303222656,
  "model 2 - param (0,1)": -2.451262950897217,
  "model 2 - param (0,2)": 2.4787139892578125,
  "model 2 - param (0,3)": 4.074766635894775
}
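Both participants receive the same model parameters. Instead of comparing the printed output by eye, a small check can confirm it; `same_results` is a hypothetical helper using only the standard library.

```python
import math


def same_results(a, b, rel_tol=1e-9):
    """True if both result dicts have the same keys and numerically equal values."""
    return a.keys() == b.keys() and all(
        math.isclose(a[k], b[k], rel_tol=rel_tol) for k in a
    )
```

In this notebook, `assert same_results(results1, results2)` would pass for the outputs shown.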


### ADMIN USER - Clean Up

In [None]:
admin_instance.shutdown()
admin_instance.delete()
assert admin_instance.id not in admin_client.get_instances()
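In a script (rather than a notebook split across cells), the clean-up can be guaranteed even when an earlier step raises. A sketch using `contextlib.contextmanager` that relies only on the `shutdown()` and `delete()` methods called above; `managed_instance` is a hypothetical wrapper, not part of the avato client.

```python
from contextlib import contextmanager


@contextmanager
def managed_instance(instance):
    """Yield `instance`, then always call shutdown() and delete() on it."""
    try:
        yield instance
    finally:
        try:
            instance.shutdown()
        finally:
            # Delete even if shutdown() itself fails.
            instance.delete()
```

Usage: `with managed_instance(admin_instance) as inst: ...` runs the workflow and tears the instance down on exit.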