# Model Owner

## Model Training

### Import dependencies

In [None]:
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
import struct

### Directory to save machine learning model

In [None]:
# All model artifacts are written to a "data" folder inside the current
# working directory (os.path.abspath('') resolves to the cwd).
current_dir = os.path.abspath('')
model_dir = os.path.join(current_dir, "data")
# exist_ok=True makes this cell safe to re-run and avoids the race between
# a separate os.path.exists() check and the directory creation.
os.makedirs(model_dir, exist_ok=True)

### Load and prepare dataset

In [None]:
# Fetch the Fashion MNIST dataset as (images, labels) pairs for the
# training and the test split.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = (
    fashion_mnist.load_data()
)

# Normalize both image sets by dividing by 255.0 (the largest 8-bit pixel
# value; assumes the raw images are 8-bit grayscale).
train_images, test_images = train_images / 255.0, test_images / 255.0

### Check dataset

In [None]:
# Explore the data.
len(test_labels)
# 10000
train_images.shape
# (60000, 28, 28)

In [None]:
#  Plot one input image.
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

### Compile and train the model

In [None]:
def train_model(train_images, train_labels, epochs=2):
    """Build, compile and fit a small dense image classifier.

    Args:
        train_images: Training inputs; each sample must match the
            (28, 28) input shape declared on the Flatten layer.
        train_labels: Integer class labels (sparse, not one-hot), as
            expected by SparseCategoricalCrossentropy.
        epochs: Number of passes over the training data. Defaults to 2,
            the value that was previously hard-coded.

    Returns:
        The fitted keras.Sequential model. Its last layer is a softmax,
        so the model outputs class probabilities rather than logits.
    """
    # Define Keras deep-learning model
    model = keras.Sequential(
        [
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation="relu"),
            keras.layers.Dense(10),
            keras.layers.Activation('softmax')
        ]
    )

    # Setup the model for training:
    # Loss function — This measures how accurate the model is during
    #                 training. You want to minimize this function to
    #                 "steer" the model in the right direction.
    # Optimizer — This is how the model is updated based on the data it
    #             sees and its loss function.
    # Metrics — Used to monitor the training and testing steps.
    #           The following example uses accuracy, the fraction of the
    #           images that are correctly classified.
    model.compile(
        optimizer="adam",
        # The model already ends in a softmax layer, so its outputs are
        # probabilities. from_logits must therefore be False; with True
        # the loss would treat the probabilities as logits and apply
        # softmax a second time.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False),
        metrics=["accuracy"],
    )

    # Model training.
    model.fit(train_images, train_labels, epochs=epochs)

    return model


# Execute training.
model = train_model(train_images, train_labels)

# Save the model in SavedModel format, if needed.
model.save(os.path.join(model_dir,"MNIST_model_TF"), save_format="tf")

### Load a saved model (in case the model has already been trained and stored)

In [None]:
model = keras.models.load_model(os.path.join(model_dir,"MNIST_model_TF"))

### Evaluate model accuracy

In [None]:
test_loss, test_acc = model.evaluate(
    test_images, test_labels, verbose=2
)
print("\nTest accuracy:", test_acc)

### Make a local single prediction

In [None]:
test_image = test_images[0:1, :, :]
prediction = model.predict(test_image)
prediction

### Convert model to TFLite format

In [None]:
# In order to run the model in the Avato enclave it first needs
# to be converted into a simpler format called `TFLite`,
# also provided and maintained by Google.
# Build a converter from the in-memory Keras model and run the conversion.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Persist the converted model as a flatbuffer file inside model_dir so it
# can be reloaded later without repeating the conversion.
tflite_path = os.path.join(model_dir, "MNIST.fb")
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

### Load a TFLite model (in case the model has already been converted and stored)

In [None]:
with open(os.path.join(model_dir, "MNIST.fb"), "rb") as f:
    tflite_model = f.read()

## Model Upload

### Import dependencies

In [None]:
!python --version

In [None]:
from avato import Client
from avato import Secret
from avato_tflite_dynamic import TFLITEDYNAMIC_Instance

### Login to avato

In [None]:
model_owner_client = Client(
    api_token=os.environ["MODEL_OWNER_API_TOKEN"],
    instance_types=[TFLITEDYNAMIC_Instance],
    backend_host="api.decentriq.ch",
    backend_port="15005",
    use_ssl=True
)

### Create instance

In [None]:
inference_user_id = os.environ["INFERENCE_USER_ID"]
model_owner_instance = model_owner_client.create_instance(
    "Inference Demo",
    TFLITEDYNAMIC_Instance.type,
    [inference_user_id],
)
print(model_owner_instance.id)

### Check security guarantees

In [None]:
# Validating the fatquote. This step is crucial for all security
# guarantees.
# It gets and validates the cryptographic proof from the enclave:
#
# i)   It proves it is a valid SGX enclave (by checking a certificate).
# ii)  It compares the hash of the enclave code provided by the user to
#      an expected value (to verify what code is running in the enclave).
# iii) As part of the proof also a public key is transmitted that allows
#      establishing a secure connection into the enclave (as the private
#      key is only known to the enclave).
#
# As we are using a non-production environment, we whitelist the debug and
# group-out-of-date flags.

model_owner_instance.validate_fatquote(
    expected_measurement="6a2c1e90d79f09b9435b57301d388af8eaefde7e0b8feef345481a2a0527cfd2",
    accept_debug=True,
    accept_group_out_of_date=True,
)

print(model_owner_instance.fatquote)

In [None]:
#  The quote is part of the fatquote and provides a detailed fingerprint
#  of the program and state of the remote machine. For example:
#  * using `flags` we can detect if the CPU is running in un-trusted
#    debug mode
#  * using `*_svn` (security version numbers) we can verify that all
#    security patches have been deployed to the infrastructure
#  * using `mrenclave` we can attest to the exact program being
#    executed on the remote machine
print(model_owner_instance.quote)

### Creating (randomly) a public-private keypair and setting it

In [None]:
model_owner_secret = Secret()
model_owner_instance.set_secret(model_owner_secret)

### Uploading the model

In [None]:
# Before uploading, the model is encrypted using the enclave
# public key extracted from the fatquote.
# The model owner's public key is also sent together with the encrypted data.
model_owner_instance.upload_model(tflite_model)

### Make a remote single prediction

In [None]:
prediction_remote = model_owner_instance.predict(test_image)

### Comparison of results

In [None]:
# Note the small difference between the two results: the TFLite model
# uses 32-bit precision (like most models), while the local prediction
# is done with 64-bit precision.
print(f"Local - Label: {prediction.argmax()} with weight: {prediction.max()}")
print(
    f"Remote - Label: {prediction_remote.argmax()} with weight: {prediction_remote.max()}"
)

### Cleanup the enclave

In [None]:
model_owner_instance.shutdown()
model_owner_instance.delete()
assert model_owner_instance.id not in model_owner_client.get_instances()