# Part 2: Secure Model Serving with TFE Keras

Now that you have a trained model with normal Keras, you are ready to serve some private predictions. We can do that using TFE Keras.

To secure and serve this model, we will need three TFE servers. This is because TF Encrypted under the hood uses an encryption technique called [multi-party computation (MPC)](https://en.wikipedia.org/wiki/Secure_multi-party_computation). The idea is to split the model weights and input data into shares, then send a share of each value to the different servers. The key property is that if you look at the share on one server, it reveals nothing about the original value (input data or model weights).

If you want to learn more about MPC, you can read this excellent [blog](https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/).
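
To make the idea of shares concrete, here is a minimal sketch of additive secret sharing in plain Python. It is not TF Encrypted's implementation: the names `MODULUS`, `share`, and `reconstruct` are hypothetical, the ring size is an arbitrary choice, and real protocols such as SecureNN also need a fixed-point encoding for floats and extra machinery for multiplication and comparison.

```python
import secrets

# Hypothetical ring size chosen for illustration; not taken from TF Encrypted.
MODULUS = 2 ** 64


def share(value, num_shares=3):
    """Split an integer into additive shares that sum to it modulo MODULUS."""
    # All but the last share are uniformly random, so each one alone
    # carries no information about the original value.
    shares = [secrets.randbelow(MODULUS) for _ in range(num_shares - 1)]
    last = (value - sum(shares)) % MODULUS
    return shares + [last]


def reconstruct(shares):
    """Recombine all shares to recover the original value."""
    return sum(shares) % MODULUS


secret = 42
secret_shares = share(secret)
recovered = reconstruct(secret_shares)

# Addition can be done locally: each server adds the two shares it holds.
a_shares = share(5)
b_shares = share(7)
sum_shares = [(x + y) % MODULUS for x, y in zip(a_shares, b_shares)]
total = reconstruct(sum_shares)
```

Only the combination of all shares reveals the value, which is why a single server learns nothing from the share it holds.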

In this notebook, you will be able to serve private predictions after a series of simple steps:
- Configure TFE Protocol to secure the model via secret sharing.
- Launch three TFE servers.
- Convert the TF Keras model into a TFE Keras model using `tfe.keras.models.clone_model`.
- Serve the secured model using `tfe.serving.QueueServer`.

Alright, let's do it!

In [1]:
from collections import OrderedDict

import numpy as np
import tensorflow as tf

import tf_encrypted as tfe
import tf_encrypted.keras.backend as KE

tf.compat.v1.disable_eager_execution()

Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/Users/justinpatriquin/projects/tf-encrypted/tf_encrypted/operations/secure_random/secure_random_module_tf_2.0.0.so'


## Define Model

As you can see, we define almost the exact same model as before, except we provide a `batch_input_shape`. This allows TF Encrypted to better optimize the secure computations via predefined tensor shapes. For this MNIST demo, we'll send input data with the shape of (1, 28, 28, 1). 
We also return the logits instead of softmax probabilities, because softmax is expensive to compute under MPC and is not needed to serve prediction requests.

In [None]:
num_classes = 10
input_shape = (1, 28, 28, 1)

In [None]:
model = tf.keras.Sequential([
          tf.keras.layers.Conv2D(16, 8,
                                 strides=2,
                                 padding='same',
                                 activation='relu',
                                 batch_input_shape=input_shape),
          tf.keras.layers.AveragePooling2D(2, 1),
          tf.keras.layers.Conv2D(32, 4,
                                 strides=2,
                                 padding='valid',
                                 activation='relu'),
          tf.keras.layers.AveragePooling2D(2, 1),
          tf.keras.layers.Flatten(),
          tf.keras.layers.Dense(32, activation='relu'),
          tf.keras.layers.Dense(num_classes, name='logit')
  ])
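
Dropping the softmax does not change the predicted class, because softmax is monotonic and therefore preserves the position of the largest value. The sketch below checks this with NumPy on a made-up logit vector; the `softmax` helper and the numbers are illustrative only and are not part of the model above.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


# Made-up logits for a single image, shaped (1, 10) like the model output.
logits = np.array([[1.2, -0.3, 4.5, 0.0, 2.1, -1.7, 0.4, 3.3, -2.0, 0.9]])
probs = softmax(logits)

# The predicted class is the same whether we look at logits or probabilities.
pred_from_logits = int(np.argmax(logits, axis=-1)[0])
pred_from_probs = int(np.argmax(probs, axis=-1)[0])
```

If a client needs probabilities, it can apply softmax locally after decrypting the logits.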

With `load_weights` you can easily load the weights you have saved previously after training your model.

In [None]:
pre_trained_weights = 'short-dnn.h5'
model.load_weights(pre_trained_weights)

## Protocol

We first configure the protocol we will be using, as well as the servers on which we want to run it. We will be using the SecureNN protocol to secret share the model among the three TFE servers. Most importantly, this adds the capability of providing predictions on encrypted data.

Note that the configuration is saved to file as we will be needing it in the client as well.

In [2]:
players = OrderedDict([
    ('server0', 'localhost:4000'),
    ('server1', 'localhost:4001'),
    ('server2', 'localhost:4002'),
])

config = tfe.RemoteConfig(players)
config.save('/tmp/config.json')

In [None]:
tfe.set_config(config)
tfe.set_protocol(tfe.protocol.SecureNN())

## Launching servers

Before actually serving the computation below we need to launch TFE servers in new processes. Run the following in three different terminals. You may have to allow Python to accept incoming connections.

In [3]:
for player_name in players.keys():
    print("python -m tf_encrypted.player --config /tmp/config.json {}".format(player_name))

python -m tf_encrypted.player --config /tmp/config.json server0
python -m tf_encrypted.player --config /tmp/config.json server1
python -m tf_encrypted.player --config /tmp/config.json server2
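
As an alternative to opening three terminals by hand, the same commands could be started from Python with `subprocess`. This is only a sketch: `build_player_commands` and `launch_players` are hypothetical helpers, and `sys.executable` is used in place of `python` on the assumption that the players should run under the same interpreter as the notebook. The launch call is left commented out so the cell has no side effects until you enable it.

```python
import subprocess
import sys
from collections import OrderedDict


def build_player_commands(players, config_path):
    """Return one argv list per player, mirroring the printed commands."""
    return [
        [sys.executable, "-m", "tf_encrypted.player", "--config", config_path, name]
        for name in players
    ]


def launch_players(commands):
    """Start each player as a background process and return the handles."""
    return [subprocess.Popen(command) for command in commands]


players = OrderedDict([
    ('server0', 'localhost:4000'),
    ('server1', 'localhost:4001'),
    ('server2', 'localhost:4002'),
])

commands = build_player_commands(players, '/tmp/config.json')

# Uncomment to actually start the three servers in the background:
# processes = launch_players(commands)
```

Keeping the returned process handles around makes it possible to stop the servers later with `process.terminate()`.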


## Convert TF Keras into TFE Keras

Thanks to `tfe.keras.models.clone_model`, you can automatically convert the TF Keras model into a TFE Keras model.

In [None]:
with tfe.protocol.SecureNN():
    tfe_model = tfe.keras.models.clone_model(model)

## Set up a new `tfe.serving.QueueServer` 

`tfe.serving.QueueServer` will launch a serving queue, so that the TFE servers can accept prediction requests on the secured model from external clients.

In [None]:
# Set up a new tfe.serving.QueueServer for the shared TFE model
q_input_shape = (1, 28, 28, 1)
q_output_shape = (1, 10)

server = tfe.serving.QueueServer(
    input_shape=q_input_shape, output_shape=q_output_shape, computation_fn=tfe_model
)
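
The queue shapes follow from the model definition. As a sanity check, the sketch below traces the spatial size through each layer using the standard Keras rules for `'same'` and `'valid'` padding. The helper `output_size` is hypothetical and written only for this illustration.

```python
import math


def output_size(size, kernel, stride, padding):
    """Spatial output size of a conv or pooling layer under Keras padding rules."""
    if padding == "same":
        return math.ceil(size / stride)
    # "valid": no padding, only windows that fit entirely are used.
    return (size - kernel) // stride + 1


size = 28                                 # MNIST input: 28 x 28 x 1
size = output_size(size, 8, 2, "same")    # Conv2D(16, 8, strides=2)
after_conv1 = size
size = output_size(size, 2, 1, "valid")   # AveragePooling2D(2, 1)
size = output_size(size, 4, 2, "valid")   # Conv2D(32, 4, strides=2)
size = output_size(size, 2, 1, "valid")   # AveragePooling2D(2, 1)

flattened = size * size * 32              # Flatten over 32 channels
num_classes = 10
q_output_shape = (1, num_classes)         # Dense(32) then Dense(num_classes)
```

With a batch size of 1, the final `Dense` layer yields one row of 10 logits, which matches the `(1, 10)` output shape passed to the queue server.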

## Start Server

Perfect! With all of the above in place, we can finally connect to our servers, push our TensorFlow graph to them, and start serving the model. You can set `num_steps` to limit the number of prediction requests served by the model; if not specified, the model will be served until interrupted.

In [None]:
sess = KE.get_session()

In [None]:
request_ix = 1

def step_fn():
    global request_ix
    print("Served encrypted prediction {i} to client.".format(i=request_ix))
    request_ix += 1

server.run(
    sess,
    num_steps=3,
    step_fn=step_fn)

You are ready to move to the **c - Private Prediction Client** notebook to request some private predictions. 

### Cleanup!

Once the request limit above is reached, the model will no longer be available for serving requests, but it is still secret shared between the three servers. You can kill the server processes by executing the cell below.

**Congratulations** on finishing b - Secure Model Serving.

In [None]:
process_ids = !ps aux | grep '[p]ython -m tf_encrypted.player --config' | awk '{print $2}'
for process_id in process_ids:
    !kill {process_id}
    print("Process ID {id} has been killed.".format(id=process_id))