# Section: Encrypted Deep Learning

- Lesson: Reviewing Additive Secret Sharing
- Lesson: Encrypted Subtraction and Public/Scalar Multiplication
- Lesson: Encrypted Computation in PySyft
- Project: Build an Encrypted Database
- Lesson: Encrypted Deep Learning in PyTorch
- Lesson: Encrypted Deep Learning in Keras
- Final Project

# Lesson: Reviewing Additive Secret Sharing

_For more great information about SMPC protocols like this one, visit https://mortendahl.github.io. With permission, Morten's work directly inspired this first teaching segment._

In [1]:
import random
import numpy as np

BASE = 10

PRECISION_INTEGRAL = 8
PRECISION_FRACTIONAL = 8
Q = 293973345475167247070445277780365744413

PRECISION = PRECISION_INTEGRAL + PRECISION_FRACTIONAL

assert (Q > BASE ** PRECISION)


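def encode(rational):
    # Scale to a fixed-point integer; int() truncates toward zero rather than rounding
    upscaled = int(rational * BASE ** PRECISION_FRACTIONAL)
    # Negative values wrap around into the upper half of the field
    field_element = upscaled % Q
    return field_element


def decode(field_element):
    # Field elements above Q // 2 represent negative numbers
    upscaled = field_element if field_element <= Q // 2 else field_element - Q
    rational = upscaled / BASE ** PRECISION_FRACTIONAL
    return rational


def encrypt(secret):
    # Two uniformly random shares, plus a third chosen so all three sum to the secret mod Q
    first = random.randrange(Q)
    second = random.randrange(Q)
    third = (secret - first - second) % Q
    return [first, second, third]


def decrypt(sharing):
    # Reconstruct the secret by summing all shares mod Q
    return sum(sharing) % Q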


def add(a, b):
    c = list()
    for i in range(len(a)):
        c.append((a[i] + b[i]) % Q)
    return tuple(c)
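The `encrypt` function above always produces exactly three shares. As a minimal sketch (the name `encrypt_n` is hypothetical, not part of the lesson), the same idea generalizes to any number of parties: draw n - 1 shares uniformly at random and choose the last one so that everything sums to the secret modulo Q. Any n - 1 of the shares are jointly uniform on their own, so only someone holding all n shares learns the secret.

```python
import random

Q = 293973345475167247070445277780365744413  # same modulus as the lesson


def encrypt_n(secret, n):
    # n - 1 uniformly random shares, plus one share that fixes the sum
    shares = [random.randrange(Q) for _ in range(n - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares


def decrypt(sharing):
    # Reconstruction is the same regardless of the number of shares
    return sum(sharing) % Q


def add(a, b):
    # Share-wise addition works for any n, as long as both sharings use the same n
    return [(x + y) % Q for x, y in zip(a, b)]


a = encrypt_n(42, 5)
b = encrypt_n(100, 5)
print(len(a), decrypt(a), decrypt(add(a, b)))  # 5 42 142
```

Addition stays a purely local operation: each party adds the two shares it holds, whatever the number of parties.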


In [2]:
x = encrypt(encode(5.5))
x

[89946055937878781269751054855960126947,
 96080289752138375846914300224941442615,
 107946999785150089953779922700014174851]

In [3]:
y = encrypt(encode(2.3))
y

[63133476195661899575913847675787781237,
 168230754106348661946563174170437021227,
 62609115173156685547968255934370941948]

In [4]:
z = add(x, y)
z


(153079532133540680845664902531747908184,
 264311043858487037793477474395378463842,
 170556114958306775501748178634385116799)

In [5]:
decode(decrypt(z))

7.79999999
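The result is 7.79999999 rather than 7.8 because `encode` uses `int()`, which truncates toward zero. The float product `2.3 * 10**8` is stored as slightly less than 230000000, so 2.3 encodes as 229999999. Secret sharing plays no part in this error, so the sketch below adds encodings directly and compares the lesson's truncating encoder with a rounding variant (`encode_rounding` is a hypothetical name, not part of the lesson):

```python
BASE = 10
PRECISION_FRACTIONAL = 8
Q = 293973345475167247070445277780365744413


def encode_truncating(rational):
    # Same as the lesson's encode(): int() drops the fractional part
    return int(rational * BASE ** PRECISION_FRACTIONAL) % Q


def encode_rounding(rational):
    # Hypothetical variant: round to the nearest representable value
    # (Python's round() sends exact ties to the even neighbor)
    return round(rational * BASE ** PRECISION_FRACTIONAL) % Q


def decode(field_element):
    upscaled = field_element if field_element <= Q // 2 else field_element - Q
    return upscaled / BASE ** PRECISION_FRACTIONAL


print(encode_truncating(2.3), encode_rounding(2.3))  # 229999999 230000000
print(decode((encode_truncating(5.5) + encode_truncating(2.3)) % Q))  # 7.79999999
print(decode((encode_rounding(5.5) + encode_rounding(2.3)) % Q))  # 7.8
```

Rounding still leaves an error of up to half a unit in the last fixed-point digit, but it no longer pushes every value toward zero.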

# Lesson: Encrypted Subtraction and Public/Scalar Multiplication

In [6]:
field = 23740629843760239486723


In [7]:
x = 5

bob_x_share = 2372385723  # random number
alice_x_share = field - bob_x_share + x


In [8]:
(bob_x_share + alice_x_share) % field


5

In [9]:
field = 10

x = 5

bob_x_share = 8
alice_x_share = field - bob_x_share + x

y = 1

bob_y_share = 9
alice_y_share = field - bob_y_share + y


In [10]:
((bob_x_share + alice_x_share) - (bob_y_share + alice_y_share)) % field


4

In [11]:
((bob_x_share - bob_y_share) + (alice_x_share - alice_y_share)) % field


4

In [12]:
bob_x_share + alice_x_share + bob_y_share + alice_y_share


26

In [13]:
bob_z_share = (bob_x_share - bob_y_share)
alice_z_share = (alice_x_share - alice_y_share)


In [14]:
(bob_z_share + alice_z_share) % field


4

In [15]:
def sub(a, b):
    c = list()
    for i in range(len(a)):
        c.append((a[i] - b[i]) % Q)
    return tuple(c)
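Because every share is reduced modulo Q, a negative difference causes no trouble: it wraps around into the upper half of the field, which `decode` maps back to a negative number. A self-contained sketch that reuses the lesson's helpers, with `round()` swapped in for `int()` in `encode` (an assumption made here so the decoded value is exact):

```python
import random

BASE = 10
PRECISION_FRACTIONAL = 8
Q = 293973345475167247070445277780365744413


def encode(rational):
    return round(rational * BASE ** PRECISION_FRACTIONAL) % Q


def decode(field_element):
    upscaled = field_element if field_element <= Q // 2 else field_element - Q
    return upscaled / BASE ** PRECISION_FRACTIONAL


def encrypt(secret):
    first, second = random.randrange(Q), random.randrange(Q)
    return [first, second, (secret - first - second) % Q]


def decrypt(sharing):
    return sum(sharing) % Q


def sub(a, b):
    # Each party subtracts its own shares locally; no communication is needed
    return tuple((x - y) % Q for x, y in zip(a, b))


x = encrypt(encode(2.5))
y = encrypt(encode(5.75))
print(decode(decrypt(sub(x, y))))  # -3.25
```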


In [16]:
field = 10

x = 5

bob_x_share = 8
alice_x_share = field - bob_x_share + x

y = 1

bob_y_share = 9
alice_y_share = field - bob_y_share + y


In [17]:
bob_x_share + alice_x_share

15

In [18]:
bob_y_share + alice_y_share

11

In [19]:
((bob_y_share * 3) + (alice_y_share * 3)) % field


3

In [20]:
def imul(a, scalar):
    # Multiply every share by the same public integer scalar; since the shares
    # sum to the secret, the new shares sum to scalar * secret (mod Q)
    c = list()
    for i in range(len(a)):
        c.append((a[i] * scalar) % Q)
    return tuple(c)


In [21]:
x = encrypt(encode(5.5))
x


[174093083252555128024126505646678583291,
 226796350751845359218433328732927991832,
 187057256945934006898330721181674913703]

In [22]:
z = imul(x, 3)


In [23]:
decode(decrypt(z))

16.5
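`imul` is correct because the scalar is a public integer: multiplying every share by 3 multiplies their sum by 3, modulo Q. A fractional public scalar behaves differently. Encoding 0.5 with `encode` gives the integer 50000000, and the product then carries the scale factor 10**8 twice, so the result must be divided by 10**16 instead of 10**8. In the sketch below (helper names ours), that rescaling happens after decryption purely to make the effect visible. Inside a protocol it has to happen on the shares, and dividing each share separately does not in general give shares of the quotient, because of the wraparound modulo Q; handling this truncation step is part of what fixed-precision tensors in libraries such as PySyft are for.

```python
import random

BASE = 10
PRECISION_FRACTIONAL = 8
Q = 293973345475167247070445277780365744413


def encode(rational):
    return round(rational * BASE ** PRECISION_FRACTIONAL) % Q


def to_signed(field_element):
    # Map the upper half of the field back to negative integers
    return field_element if field_element <= Q // 2 else field_element - Q


def encrypt(secret):
    first, second = random.randrange(Q), random.randrange(Q)
    return [first, second, (secret - first - second) % Q]


def decrypt(sharing):
    return sum(sharing) % Q


def imul(a, scalar):
    return tuple((x * scalar) % Q for x in a)


x = encrypt(encode(5.5))

# Public integer scalar: one scale factor, so divide by 10**8 as usual
z_int = imul(x, 3)
print(to_signed(decrypt(z_int)) / BASE ** PRECISION_FRACTIONAL)  # 16.5

# Encoded fractional scalar: two scale factors, so divide by 10**16
z_frac = imul(x, encode(0.5))
print(to_signed(decrypt(z_frac)) / BASE ** (2 * PRECISION_FRACTIONAL))  # 2.75
```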

# Lesson: Encrypted Computation in PySyft

In [25]:
import syft as sy
import torch as th
hook = sy.TorchHook(th)
from torch import nn, optim




In [26]:
bob = sy.VirtualWorker(hook, id="bob").add_worker(sy.local_worker)
alice = sy.VirtualWorker(hook, id="alice").add_worker(sy.local_worker)
secure_worker = sy.VirtualWorker(hook, id="secure_worker").add_worker(sy.local_worker)


In [30]:
x = th.tensor([1, 2, 3, 4])
y = th.tensor([2, -1, 1, 0])


In [31]:
x = x.share(bob, alice, crypto_provider=secure_worker)
x


(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:53488425189 -> bob:22500053968]
	-> (Wrapper)>[PointerTensor | me:74904402757 -> alice:10318919918]
	*crypto provider: secure_worker*

In [32]:
y = y.share(bob, alice, crypto_provider=secure_worker)
y


(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:37875131198 -> bob:8720012974]
	-> (Wrapper)>[PointerTensor | me:5288806080 -> alice:11538024511]
	*crypto provider: secure_worker*

In [33]:
z = x + y
z.get()


tensor([3, 1, 4, 4])

In [34]:
z = x - y
z.get()


tensor([-1,  3,  2,  4])

In [35]:
z = x * y
z.get()


tensor([ 2, -2,  3,  0])
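Addition, subtraction, and public scaling are local, but multiplying two secret-shared values is not, since each product of shares mixes values held by different parties. A common way around this, and as far as we understand the role `crypto_provider=secure_worker` plays in these calls, is a trusted dealer that hands out Beaver triples: random shared values a and b together with shares of c = a * b. The pure-Python sketch below is a simplified two-party illustration with a hypothetical prime modulus; it is not PySyft's actual implementation.

```python
import random

Q = 2 ** 61 - 1  # hypothetical prime modulus for this sketch (a Mersenne prime)


def share(secret, n=2):
    # Additive sharing: n - 1 random shares plus one that fixes the sum
    shares = [random.randrange(Q) for _ in range(n - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares


def reconstruct(shares):
    return sum(shares) % Q


def signed(value):
    # Interpret the upper half of the field as negative numbers
    return value if value <= Q // 2 else value - Q


def make_triple(n=2):
    # Dealer ("crypto provider"): random a, b and shares of c = a * b
    a, b = random.randrange(Q), random.randrange(Q)
    return share(a, n), share(b, n), share(a * b % Q, n)


def beaver_mul(x_sh, y_sh):
    a_sh, b_sh, c_sh = make_triple(len(x_sh))
    # The parties open d = x - a and e = y - b; a and b act as one-time masks
    d = reconstruct([(xi - ai) % Q for xi, ai in zip(x_sh, a_sh)])
    e = reconstruct([(yi - bi) % Q for yi, bi in zip(y_sh, b_sh)])
    # x * y = c + d*b + e*a + d*e; every term except d*e is computed share-wise
    z_sh = [(ci + d * bi + e * ai) % Q for ai, bi, ci in zip(a_sh, b_sh, c_sh)]
    z_sh[0] = (z_sh[0] + d * e) % Q  # the public term d*e is added by one party only
    return z_sh


print(signed(reconstruct(beaver_mul(share(3), share(-2)))))  # -6
```

Opening d and e is safe only because a and b are fresh random values, so each triple must be used for exactly one multiplication.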

In [36]:
z = x > y
z.get()


tensor([0, 1, 1, 1])

In [37]:
z = x < y
z.get()


tensor([1, 0, 0, 0])

In [38]:
z = x == y
z.get()


tensor([0, 0, 0, 0])

In [40]:
x = th.tensor([1, 2, 3, 4])
y = th.tensor([2, -1, 1, 0])

x = x.fix_precision().share(bob, alice, crypto_provider=secure_worker)
y = y.fix_precision().share(bob, alice, crypto_provider=secure_worker)
x


(Wrapper)>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:15258064215 -> bob:58821851563]
	-> (Wrapper)>[PointerTensor | me:19233042876 -> alice:45941336518]
	*crypto provider: secure_worker*

In [41]:
z = x + y
z.get().float_precision()


tensor([3., 1., 4., 4.])

In [42]:
z = x - y
z.get().float_precision()


tensor([-1.,  3.,  2.,  4.])

In [43]:
z = x * y
z.get().float_precision()


tensor([ 2., -2.,  3.,  0.])

In [44]:
z = x > y
z.get().float_precision()


tensor([0., 1., 1., 1.])

In [45]:
z = x < y
z.get().float_precision()


tensor([1., 0., 0., 0.])

In [46]:
z = x == y
z.get().float_precision()


tensor([0., 0., 0., 0.])

# Project: Build an Encrypted Database

In [47]:
import string

In [48]:
char2index = {}
index2char = {}

In [49]:
for i, char in enumerate(' ' + string.ascii_lowercase + '0123456789' + string.punctuation):
    char2index[char] = i
    index2char[i] = char


In [50]:
str_input = "Hello"
max_len = 8


In [51]:
def string2values(str_input, max_len=8):
    str_input = str_input[:max_len].lower()
    if len(str_input) < max_len:
        str_input = str_input + "." * (max_len - len(str_input))
    values = list()
    for char in str_input:
        values.append(char2index[char])
    return th.tensor(values).long()


In [52]:
def values2string(input_values):
    s = ""
    for value in input_values:
        s += index2char[int(value)]
    return s


In [71]:
def strings_equal(str_a, str_b):
    vect = (str_a * str_b).sum(1)
    x = vect[0]
    for i in range(vect.shape[0] - 1):
        x = x * vect[i + 1]
    return x


In [54]:
def one_hot(index, length):
    vect = th.zeros(length).long()
    vect[index] = 1
    return vect


In [60]:
def string2one_hot_matrix(str_input, max_len=8):
    str_input = str_input[:max_len].lower()
    if len(str_input) < max_len:
        str_input = str_input + "." * (max_len - len(str_input))
    char_vectors = list()
    for char in str_input:
        char_v = one_hot(char2index[char], len(char2index)).unsqueeze(0)
        char_vectors.append(char_v)
    return th.cat(char_vectors, dim=0)
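Together, `one_hot`, `string2one_hot_matrix`, and `strings_equal` implement string equality as pure arithmetic, which is what lets it run on secret-shared tensors: the dot product of two one-hot rows is 1 exactly when the characters agree, and the product over all rows is 1 only if every character agrees. The same logic on plain Python lists, without PyTorch or PySyft (a sketch for illustration; the helper names are ours):

```python
import string

CHARS = ' ' + string.ascii_lowercase + '0123456789' + string.punctuation
char2index = {char: i for i, char in enumerate(CHARS)}


def one_hot_matrix(text, max_len=8):
    # Truncate, lowercase, and pad with '.', as string2one_hot_matrix does
    text = text[:max_len].lower().ljust(max_len, '.')
    return [[1 if j == char2index[char] else 0 for j in range(len(CHARS))]
            for char in text]


def strings_equal(a, b):
    # Row-wise dot products: 1 where the characters match, 0 where they differ
    row_matches = [sum(x * y for x, y in zip(row_a, row_b))
                   for row_a, row_b in zip(a, b)]
    result = 1
    for match in row_matches:
        result *= match  # stays 1 only if every position matched
    return result


print(strings_equal(one_hot_matrix("Bob"), one_hot_matrix("bob")))   # 1
print(strings_equal(one_hot_matrix("Bob"), one_hot_matrix("Bill")))  # 0
```

Because of the lowercasing, the '.' padding, and the truncation to `max_len`, "Bob", "bob", and "bob." all compare equal, as do any two keys that share their first eight characters.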


In [67]:
class EncryptedDB:
    def __init__(self, *owners, max_key_len=8, max_val_len=8):
        self.max_key_len = max_key_len
        self.max_val_len = max_val_len
        self.keys = list()
        self.values = list()
        self.owners = owners

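    def add_entry(self, key, value):
        # Keys are stored as secret-shared one-hot matrices (one row per character)
        key = string2one_hot_matrix(key, max_len=self.max_key_len)
        key = key.share(*self.owners)
        self.keys.append(key)

        # Values are stored as secret-shared vectors of character indices
        value = string2values(value, max_len=self.max_val_len)
        value = value.share(*self.owners)
        self.values.append(value)

    def query(self, query_str):
        # Encode and share the query exactly like the keys
        query_matrix = string2one_hot_matrix(query_str, max_len=self.max_key_len)
        query_matrix = query_matrix.share(*self.owners)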
        key_matches = list()
        for key in self.keys:
            key_match = strings_equal(key, query_matrix)
            key_matches.append(key_match)
        result = self.values[0] * key_matches[0]
        for i in range(len(self.values) - 1):
            result += self.values[i + 1] * key_matches[i + 1]
        result = result.get()
        return values2string(result).replace(".", "")
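The last step of `query` never reveals which key matched: every value is multiplied by its (still secret) match flag and the products are summed. In plaintext, the arithmetic looks like the sketch below (`select_value` is a hypothetical name). Two consequences follow: a query that matches no key returns all zeros, which `values2string` decodes as spaces because index 0 is `' '`, and two entries with the same key would have their values added together.

```python
def select_value(values, matches):
    # values: equal-length lists of character indices; matches: one 0/1 flag per entry
    result = [0] * len(values[0])
    for vals, flag in zip(values, matches):
        # Only the matching entry contributes a non-zero term
        result = [r + v * flag for r, v in zip(result, vals)]
    return result


values = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(select_value(values, [0, 1, 0]))  # [4, 5, 6]
print(select_value(values, [0, 0, 0]))  # [0, 0, 0]
print(select_value(values, [1, 1, 0]))  # [5, 7, 9]
```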


In [75]:
db = EncryptedDB(bob, alice, secure_worker, max_val_len=256)


In [76]:
db.add_entry("Bob", "(123) 456 7890")
db.add_entry("Bill", "(234) 567 8901")
db.add_entry("Sam", "(345) 678 9012")
db.add_entry("Key", "really big json value")


In [77]:
db.query("Bob")

'(123) 456 7890'

# Lesson: Encrypted Deep Learning in PyTorch

## Train a Model

In [3]:
import syft as sy
import torch as th
hook = sy.TorchHook(th)



In [33]:
from torch import nn
from torch import optim
import torch.nn.functional as F

# A Toy Dataset
data = th.tensor([[0, 0], [0, 1], [1, 0], [1, 1.]], requires_grad=True)
target = th.tensor([[0], [0], [1], [1.]], requires_grad=True)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


# A Toy Model
model = Net()


def train():
    # Training Logic
    opt = optim.SGD(params=model.parameters(), lr=0.1)
    for i in range(20):
        # 1) erase previous gradients (if they exist)
        opt.zero_grad()

        # 2) make a prediction
        pred = model(data)

        # 3) calculate how much we missed
        loss = ((pred - target) ** 2).sum()

        # 4) figure out which weights caused us to miss
        loss.backward()

        # 5) change those weights
        opt.step()

        # 6) print our progress
        print(loss.data)


train()


tensor(0.8452)
tensor(2.9524)
tensor(16.8206)
tensor(4.1732)
tensor(0.9313)
tensor(0.8650)
tensor(0.7948)
tensor(0.7125)
tensor(0.6166)
tensor(0.5222)
tensor(0.4610)
tensor(0.3624)
tensor(0.2932)
tensor(0.2413)
tensor(0.1952)
tensor(0.1294)
tensor(0.1009)
tensor(0.0673)
tensor(0.0529)
tensor(0.0313)


In [34]:
model(data)

tensor([[0.0768],
        [0.0859],
        [0.9569],
        [0.8913]], grad_fn=<AddmmBackward>)

## Encrypt the Model and Data

In [35]:
bob = sy.VirtualWorker(hook, id="bob").add_worker(sy.local_worker)
alice = sy.VirtualWorker(hook, id="alice").add_worker(sy.local_worker)
secure_worker = sy.VirtualWorker(hook, id="secure_worker").add_worker(sy.local_worker)






In [36]:
encrypted_model = model.fix_precision().share(alice, bob, crypto_provider=secure_worker)


In [37]:
list(encrypted_model.parameters())

[Parameter containing:
 Parameter>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:61377541344 -> alice:73834852371]
 	-> (Wrapper)>[PointerTensor | me:84390337491 -> bob:20664967613]
 	*crypto provider: secure_worker*, Parameter containing:
 Parameter>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:24572759692 -> alice:24733019802]
 	-> (Wrapper)>[PointerTensor | me:91595158988 -> bob:12431410982]
 	*crypto provider: secure_worker*, Parameter containing:
 Parameter>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:99200322420 -> alice:45731440851]
 	-> (Wrapper)>[PointerTensor | me:51753870863 -> bob:94208355056]
 	*crypto provider: secure_worker*, Parameter containing:
 Parameter>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
 	-> (Wrapper)>[PointerTensor | me:12737141174 -> alice:54985009979]
 	-> (Wrapper)>[PointerTensor | me:92553414080 -> bob:32961698

In [38]:
encrypted_data = data.fix_precision().share(alice, bob, crypto_provider=secure_worker)


In [39]:
encrypted_data

(Wrapper)>FixedPrecisionTensor>(Wrapper)>[AdditiveSharingTensor]
	-> (Wrapper)>[PointerTensor | me:49104973751 -> alice:34926642919]
	-> (Wrapper)>[PointerTensor | me:74536473970 -> bob:20609174616]
	*crypto provider: secure_worker*

In [40]:
encrypted_prediction = encrypted_model(encrypted_data)

In [41]:
encrypted_prediction.get().float_precision()

tensor([[0.0760],
        [0.0860],
        [0.9560],
        [0.8900]])
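The encrypted prediction agrees with the plaintext one to about two decimal places (0.0760 vs 0.0768 for the first row), and every encrypted output is a multiple of 0.001. Some drift is expected because `fix_precision()` turns weights and activations into fixed-point integers, and each product must be rescaled. The sketch below mimics that for a single linear unit in pure Python, assuming 3 fractional decimal digits (which we believe matched PySyft's default at the time, but treat it as an assumption) and truncation toward zero after each multiplication; the weights are made up for illustration.

```python
SCALE = 10 ** 3  # assumed: 3 fractional decimal digits


def to_fixed(value):
    # Encode a float as a scaled integer
    return round(value * SCALE)


def truncate(product):
    # A product of two fixed-point numbers carries SCALE**2; divide by SCALE, toward zero
    quotient = abs(product) // SCALE
    return quotient if product >= 0 else -quotient


def fixed_linear(xs, ws, b):
    # Dot product plus bias, computed entirely on scaled integers
    acc = to_fixed(b)
    for x, w in zip(xs, ws):
        acc += truncate(to_fixed(x) * to_fixed(w))
    return acc / SCALE


def float_linear(xs, ws, b):
    return sum(x * w for x, w in zip(xs, ws)) + b


xs, ws, b = [1.0, 0.5], [0.1234, -0.5678], 0.0912  # made-up values
print(float_linear(xs, ws, b))  # about -0.0693
print(fixed_linear(xs, ws, b))  # -0.07
```

Each layer adds a little rounding error of this kind, which is why the encrypted outputs stay close to, but not identical with, the plaintext ones.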

# Lesson: Encrypted Deep Learning in Keras


## Step 1: Public Training

Welcome to this tutorial! In the following notebooks you will learn how to provide private predictions. By private predictions, we mean that the data remains encrypted throughout the entire process: at no point does the user share raw data, only encrypted (that is, secret-shared) data. To provide these private predictions, Syft Keras uses a library called [TF Encrypted](https://github.com/tf-encrypted/tf-encrypted) under the hood. TF Encrypted combines cutting-edge cryptographic and machine learning techniques, but you don't need to understand those details; you can focus on your machine learning application.

You can start serving private predictions with only three steps:
- **Step 1**: train your model with normal Keras.
- **Step 2**: secure and serve your machine learning model (server).
- **Step 3**: query the secured model to receive private predictions (client). 

Alright, let's go through these three steps so you can deploy impactful machine learning services without sacrificing user privacy or model security.

Huge shoutout to the Dropout Labs ([@dropoutlabs](https://twitter.com/dropoutlabs)) and TF Encrypted ([@tf_encrypted](https://twitter.com/tf_encrypted)) teams for their great work, which makes this demo possible, especially: Jason Mancuso ([@jvmancuso](https://twitter.com/jvmancuso)), Yann Dupis ([@YannDupis](https://twitter.com/YannDupis)), and Morten Dahl ([@mortendahlcs](https://github.com/mortendahlcs)).

_Demo Ref: https://github.com/OpenMined/PySyft/tree/dev/examples/tutorials_

## Train Your Model in Keras

To use privacy-preserving machine learning techniques for your projects you should not have to learn a new machine learning framework. If you have basic [Keras](https://keras.io/) knowledge, you can start using these techniques with Syft Keras. If you have never used Keras before, you can learn a bit more about it through the [Keras documentation](https://keras.io). 

Before serving private predictions, the first step is to train your model with normal Keras. As an example, we will train a model to classify handwritten digits. To train this model we will use the canonical [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

We borrow [this example](https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py) from the reference Keras repository. To train your classification model, just run the cell below.

In [43]:
from __future__ import print_function
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, AveragePooling2D
from tensorflow.keras.layers import Activation

batch_size = 128
num_classes = 10
epochs = 2

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

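model = Sequential()

# AveragePooling2D is linear (a sum times a public constant), so it is cheap to
# evaluate on secret-shared data; MaxPooling2D would need secure comparisons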
model.add(Conv2D(10, (3, 3), input_shape=input_shape))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(num_classes, activation='softmax'))

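# Note: in TF 2.x, tf.keras's Adadelta defaults to learning_rate=0.001, whereas the
# standalone Keras example used 1.0; if training is much slower than the log below,
# try keras.optimizers.Adadelta(learning_rate=1.0)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])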

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


Train on 60000 samples, validate on 10000 samples


Epoch 1/2


  128/60000 [..............................] - ETA: 2:01 - loss: 2.3048 - acc: 0.0469

  384/60000 [..............................] - ETA: 48s - loss: 2.3004 - acc: 0.1198 

  640/60000 [..............................] - ETA: 34s - loss: 2.2961 - acc: 0.1609

  896/60000 [..............................] - ETA: 28s - loss: 2.2943 - acc: 0.1674

 1152/60000 [..............................] - ETA: 24s - loss: 2.2910 - acc: 0.1780

 1408/60000 [..............................] - ETA: 22s - loss: 2.2868 - acc: 0.1783

 1664/60000 [..............................] - ETA: 21s - loss: 2.2826 - acc: 0.1911

 1920/60000 [..............................] - ETA: 19s - loss: 2.2764 - acc: 0.2130

 2176/60000 [>.............................] - ETA: 18s - loss: 2.2694 - acc: 0.2174

 2432/60000 [>.............................] - ETA: 18s - loss: 2.2578 - acc: 0.2459

 2688/60000 [>.............................] - ETA: 17s - loss: 2.2415 - acc: 0.2612



 2944/60000 [>.............................] - ETA: 17s - loss: 2.2204 - acc: 0.2717



 3200/60000 [>.............................] - ETA: 16s - loss: 2.2001 - acc: 0.2809



 3456/60000 [>.............................] - ETA: 16s - loss: 2.1755 - acc: 0.2894

 3712/60000 [>.............................] - ETA: 16s - loss: 2.1443 - acc: 0.3001

 3968/60000 [>.............................] - ETA: 15s - loss: 2.1302 - acc: 0.3047

 4224/60000 [=>............................] - ETA: 15s - loss: 2.1186 - acc: 0.3085

 4480/60000 [=>............................] - ETA: 15s - loss: 2.0888 - acc: 0.3194

 4736/60000 [=>............................] - ETA: 15s - loss: 2.0613 - acc: 0.3304

 4992/60000 [=>............................] - ETA: 14s - loss: 2.0549 - acc: 0.3339

 5248/60000 [=>............................] - ETA: 14s - loss: 2.0349 - acc: 0.3380

 5504/60000 [=>............................] - ETA: 14s - loss: 2.0079 - acc: 0.3472

 5760/60000 [=>............................] - ETA: 14s - loss: 1.9765 - acc: 0.3568

 6016/60000 [==>...........................] - ETA: 14s - loss: 1.9646 - acc: 0.3604

 6272/60000 [==>...........................] - ETA: 13s - loss: 1.9424 - acc: 0.3664

 6528/60000 [==>...........................] - ETA: 13s - loss: 1.9134 - acc: 0.3771

 6784/60000 [==>...........................]

 - ETA: 13s - loss: 1.8851 - acc: 0.3869

 7040/60000 [==>...........................] - ETA: 13s - loss: 1.8676 - acc: 0.3929

 7296/60000 [==>...........................] - ETA: 13s - loss: 1.8460 - acc: 0.3999



 7552/60000 [==>...........................] - ETA: 13s - loss: 1.8181 - acc: 0.4098



 7808/60000 [==>...........................] - ETA: 13s - loss: 1.7925 - acc: 0.4180



 8064/60000 [===>..........................] - ETA: 13s - loss: 1.7842 - acc: 0.4214



 8320/60000 [===>..........................] - ETA: 13s - loss: 1.7678 - acc: 0.4266



 8576/60000 [===>..........................] - ETA: 13s - loss: 1.7461 - acc: 0.4331



 8832/60000 [===>..........................] - ETA: 12s - loss: 1.7238 - acc: 0.4402



 9088/60000 [===>..........................] - ETA: 12s - loss: 1.7037 - acc: 0.4465



 9344/60000 [===>..........................] - ETA: 12s - loss: 1.6833 - acc: 0.4524



 9600/60000 [===>..........................] - ETA: 12s - loss: 1.6652 - acc: 0.4580



 9856/60000 [===>..........................] - ETA: 12s - loss: 1.6499 - acc: 0.4627



10112/60000 [====>.........................] - ETA: 12s - loss: 1.6297 - acc: 0.4700

10368/60000 [====>.........................] - ETA: 12s - loss: 1.6125 - acc: 0.4749



10624/60000 [====>.........................] - ETA: 12s - loss: 1.5964 - acc: 0.4801

10880/60000 [====>.........................] - ETA: 12s - loss: 1.5770 - acc: 0.4873

11136/60000 [====>.........................] - ETA: 12s - loss: 1.5587 - acc: 0.4942

11392/60000 [====>.........................] - ETA: 12s - loss: 1.5480 - acc: 0.4973

11648/60000 [====>.........................] - ETA: 12s - loss: 1.5336 - acc: 0.5025

11904/60000 [====>.........................] - ETA: 11s - loss: 1.5207 - acc: 0.5060

12160/60000 [=====>........................] - ETA: 11s - loss: 1.5065 - acc: 0.5108

12416/60000 [=====>........................] - ETA: 11s - loss: 1.4928 - acc: 0.5150

12672/60000 [=====>........................] - ETA: 11s - loss: 1.4803 - acc: 0.5191

12928/60000 [=====>........................] - ETA: 11s - loss: 1.4664 - acc: 0.5240

13184/60000 [=====>........................] - ETA: 11s - loss: 1.4520 - acc: 0.5292

13440/60000 [=====>........................] - ETA: 11s - loss: 1.4401 - acc: 0.5327

13696/60000 [=====>........................] - ETA: 11s - loss: 1.4298 - acc: 0.5362

13952/60000 [=====>........................] - ETA: 11s - loss: 1.4157 - acc: 0.5413



































 - ETA: 10s - loss: 1.2758 - acc: 0.5903











































































































































































































































































































































































































































Epoch 2/2
13952/60000 [=====>........................] - ETA: 10s - loss: 0.2466 - acc: 0.9288

Test loss: 0.17428727200329303
Test accuracy: 0.9455


In [45]:
# Save the trained model (architecture and weights) so the weights can be loaded later for private prediction
model.save('short-conv-mnist.h5')


## Step 2: Load and Serve the Model

Now that you have a trained model with normal Keras, you are ready to serve some private predictions. We can do that using Syft Keras.

To secure and serve this model, we will need three TFEWorkers (servers). This is because TF Encrypted under the hood uses an encryption technique called [multi-party computation (MPC)](https://en.wikipedia.org/wiki/Secure_multi-party_computation). The idea is to split the model weights and input data into shares, then send a share of each value to the different servers. The key property is that if you look at the share on one server, it reveals nothing about the original value (input data or model weights).
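To make the share-splitting idea concrete, here is a minimal sketch in plain Python that reuses the additive secret sharing from the first lesson of this section. It is a simplified illustration, not TF Encrypted's implementation: the weight values are hypothetical integers (real weights would first be fixed-point encoded), and TF Encrypted's protocols may distribute shares across the three servers differently from this plain three-way split.

```python
import random

# Same field size as the first lesson of this section.
Q = 293973345475167247070445277780365744413


def share(secret, n_servers=3):
    """Split an integer secret into n additive shares modulo Q."""
    shares = [random.randrange(Q) for _ in range(n_servers - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares


def reconstruct(shares):
    """Recombine all shares; any strict subset of them looks uniformly random."""
    return sum(shares) % Q


# Hypothetical integer-encoded weights; each value is split independently.
weights = [3, 141, 59, 26]
# per_server[i] is the tuple of shares that would be sent to server i.
per_server = list(zip(*(share(w) for w in weights)))
alice_view, bob_view, carol_view = per_server

# Only by combining all three views can the weights be recovered.
recovered = [reconstruct(s) for s in zip(alice_view, bob_view, carol_view)]
print(recovered)  # [3, 141, 59, 26]
```

Each server holds only its own tuple of random-looking field elements; the original weights appear only when all shares are added back together.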

We'll define a Syft Keras model like we did in the previous notebook. However, there is a trick: before instantiating this model, we'll run `hook = sy.KerasHook(tf.keras)`. This will add three important new methods to the Keras Sequential class:
 - `share`: secures your model via secret sharing. By default, it uses the SecureNN protocol from TF Encrypted to secret-share your model among the three TFEWorkers. Most importantly, this adds the capability of providing predictions on encrypted data.
 - `serve`: launches a serving queue, so that the TFEWorkers can accept prediction requests on the secured model from external clients.
 - `shutdown_workers`: once you are done providing private predictions, you can shut down your model by running this function. It will direct you to shut down the server processes manually if you've opted to manage each worker manually.

If you want to learn more about MPC, you can read this excellent [blog post](https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/).

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Activation, AveragePooling2D, Conv2D, Dense, Flatten

import syft as sy
hook = sy.KerasHook(tf.keras)



## Model

As you can see, we define almost exactly the same model as before, except that we provide a `batch_input_shape`. This allows TF Encrypted to better optimize the secure computations via predefined tensor shapes. For this MNIST demo, we'll send input data with the shape (1, 28, 28, 1).
We also return the logits instead of the softmax output, because softmax is complex to compute using MPC and we don't need it to serve prediction requests: the class with the largest logit is also the class with the largest softmax probability.
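A quick NumPy check of why skipping softmax is safe for classification: `exp` is monotonic and every class shares the same denominator, so softmax never changes which class scores highest. The logits below are made-up values for one image, not outputs of the model in this notebook.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum for numerical stability; the result is unchanged.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


# Hypothetical logits for a batch of one image and 10 classes.
logits = np.array([[1.2, -0.3, 4.7, 0.0, 2.2, -1.1, 0.5, 3.9, -2.0, 1.0]])
probs = softmax(logits)

# The predicted class is identical with or without softmax.
print(np.argmax(logits, axis=-1), np.argmax(probs, axis=-1))  # [2] [2]
```

This is why a client can simply take the argmax of the logits it receives after decrypting the prediction.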

In [2]:
num_classes = 10
input_shape = (1, 28, 28, 1)

model = Sequential()

model.add(Conv2D(10, (3, 3), batch_input_shape=input_shape))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(AveragePooling2D((2, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(num_classes, name="logit"))

Instructions for updating:
Colocations handled automatically by placer.
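Because `batch_input_shape` is fixed, every intermediate tensor shape in the model above is known before any secure computation runs. The pure-Python sketch below traces those static shapes, assuming the Keras defaults the model relies on (`padding="valid"`, stride 1 for `Conv2D`, and strides equal to the pool size for `AveragePooling2D`); the helper names are illustrative only. `Activation` layers do not change shapes, so they are omitted.

```python
def conv_valid(size, kernel):
    # Conv2D with Keras defaults: padding="valid", strides=1.
    return size - kernel + 1


def pool_valid(size, pool):
    # AveragePooling2D with Keras defaults: strides=pool_size, padding="valid".
    return (size - pool) // pool + 1


shape = (1, 28, 28, 1)  # the batch_input_shape used in the model above
trace = [("input", shape)]
for filters in (10, 32, 64):  # the three Conv2D blocks
    side = conv_valid(shape[1], 3)
    shape = (1, side, side, filters)
    trace.append(("conv", shape))
    side = pool_valid(side, 2)
    shape = (1, side, side, filters)
    trace.append(("pool", shape))
trace.append(("flatten", (1, shape[1] * shape[2] * shape[3])))
trace.append(("logit", (1, 10)))

for name, s in trace:
    print(name, s)
```

The last pooling layer reduces each image to a (1, 1, 1, 64) tensor, so `Flatten` feeds 64 values into the 10-unit `logit` layer.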


### Load Pre-trained Weights

With `load_weights` you can easily load the weights you have saved previously after training your model.

In [3]:
pre_trained_weights = 'short-conv-mnist.h5'
model.load_weights(pre_trained_weights)

## Step 3: Set Up Your Worker Connectors

Let's now connect to the TFEWorkers (`alice`, `bob`, and `carol`) required by TF Encrypted to perform private predictions. For each TFEWorker, you just have to specify a host.

These workers run a [TensorFlow server](https://www.tensorflow.org/api_docs/python/tf/distribute/Server), which you can either manage manually (`AUTO = False`) or ask the workers to manage for you (`AUTO = True`). If you choose to manage them manually, you will be instructed to execute a terminal command on each worker's host device after calling `model.share()` below. If all workers are hosted on a single device (e.g. `localhost`), you can choose to have Syft automatically manage the workers' TensorFlow servers.

In [4]:
AUTO = False

alice = sy.TFEWorker(host='localhost:4000', auto_managed=AUTO)
bob = sy.TFEWorker(host='localhost:4001', auto_managed=AUTO)
carol = sy.TFEWorker(host='localhost:4002', auto_managed=AUTO)

## Step 4: Split the Model Into Shares

Thanks to `sy.KerasHook(tf.keras)` you can call the `share` method to transform your model into a TF Encrypted Keras model.

If you chose to manage the servers manually above, this step will not complete until all of them have been launched. Note that your firewall may ask you to allow Python to accept incoming connections.

In [5]:
model.share(alice, bob, carol)

INFO:tf_encrypted:If not done already, please launch the following command in a terminal on host localhost:4000: 'python -m tf_encrypted.player --config C:\Users\kk\AppData\Local\Temp\tfe.config server0'
This can be done automatically in a local subprocess by setting `auto_managed=True` when instantiating a TFEWorker.



INFO:tf_encrypted:If not done already, please launch the following command in a terminal on host localhost:4001: 'python -m tf_encrypted.player --config C:\Users\kk\AppData\Local\Temp\tfe.config server1'
This can be done automatically in a local subprocess by setting `auto_managed=True` when instantiating a TFEWorker.



INFO:tf_encrypted:If not done already, please launch the following command in a terminal on host localhost:4002: 'python -m tf_encrypted.player --config C:\Users\kk\AppData\Local\Temp\tfe.config server2'
This can be done automatically in a local subprocess by setting `auto_managed=True` when instantiating a TFEWorker.



INFO:tf_encrypted:Starting session on target 'grpc://localhost:4000' using config graph_options {
}



## Step 5: Launch 3 Servers

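Run each of the following commands in its own terminal. The `--config` path must match the one printed by `model.share()` above: `/tmp/tfe.config` is the path used in these instructions, while the run shown above, on Windows, wrote the file to the user's temp directory instead.

```
python -m tf_encrypted.player --config /tmp/tfe.config server0
python -m tf_encrypted.player --config /tmp/tfe.config server1
python -m tf_encrypted.player --config /tmp/tfe.config server2
```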
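If you would rather start the three players from a single Python script than from three terminals, a minimal sketch using the standard `subprocess` module follows. It runs the same `python -m tf_encrypted.player` command shown above; `CONFIG_PATH`, `player_command`, and `launch_players` are hypothetical names, and the path is an assumption that must be replaced with the one printed by `model.share()`.

```python
import subprocess
import sys

# Assumption: replace with the config path printed by model.share().
CONFIG_PATH = "/tmp/tfe.config"
PLAYERS = ["server0", "server1", "server2"]


def player_command(config_path, player):
    # Same command as above, run with the current Python interpreter.
    return [sys.executable, "-m", "tf_encrypted.player", "--config", config_path, player]


def launch_players(config_path=CONFIG_PATH):
    """Start one background process per player and return the Popen handles."""
    return [subprocess.Popen(player_command(config_path, p)) for p in PLAYERS]
```

Calling `procs = launch_players()` starts all three players in the background; once you are done serving, calling `proc.terminate()` on each handle stops them.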

## Step 6: Serve the Model

Perfect! Now, by calling `model.serve`, your model is ready to provide private predictions. You can set `num_requests` to limit the number of prediction requests served by the model; if it is not specified, the model will be served until interrupted.
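The `num_requests` behavior can be pictured with a toy serving loop. This is a hypothetical sketch, not how Syft implements `model.serve`: `toy_serve`, the request queue, and the `handle` callback are illustrative names, and no encryption happens here.

```python
import itertools
import queue


def toy_serve(requests, handle, num_requests=None):
    """Hypothetical serving loop (not Syft's code).

    Handles exactly num_requests requests, or keeps going until the
    process is interrupted (Ctrl+C) when num_requests is None.
    """
    slots = itertools.count() if num_requests is None else range(num_requests)
    served = []
    try:
        for _ in slots:
            request = requests.get()  # blocks until a client submits a request
            served.append(handle(request))
    except KeyboardInterrupt:
        pass  # stop serving cleanly on interrupt
    return served


# Usage: three queued requests and a limit of three, so the loop ends on its own.
requests = queue.Queue()
for image_id in ("img-a", "img-b", "img-c"):
    requests.put(image_id)
print(toy_serve(requests, lambda r: "prediction for " + r, num_requests=3))
```

With a limit, the loop returns after the last request; with `num_requests=None` it would block waiting for more requests, which mirrors "served until interrupted".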

In [None]:
model.serve(num_requests=3)

Served encrypted prediction 1 to client.
Served encrypted prediction 2 to client.
Served encrypted prediction 3 to client.


## Step 7: Run the Client

At this point, open and run the companion notebook: Section 4b - Encrypted Keras Client.

## Step 8: Shut Down the Servers

Once the request limit above has been reached, the model will no longer be available for serving requests, but it is still secret-shared among the three workers above. You can shut down the workers by executing the cell below.

**Congratulations** on finishing Part 12: Secure Classification with Syft Keras and TFE!

In [None]:
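model.shutdown_workers()

# If the servers were launched manually (AUTO = False), find the tf_encrypted.player
# processes and kill them. The `!` lines are IPython shell magics (Jupyter only);
# the `[p]ython` pattern keeps grep from matching its own process.
if not AUTO:
    process_ids = !ps aux | grep '[p]ython -m tf_encrypted.player --config /tmp/tfe.config' | awk '{print $2}'
    for process_id in process_ids:
        !kill {process_id}
        print("Process ID {id} has been killed.".format(id=process_id))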
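The `!` lines in the cell above are IPython shell magics, so they only work inside Jupyter. A minimal plain-Python sketch of the same cleanup follows, assuming a POSIX system with the `ps` command; `find_player_pids` and `kill_players` are hypothetical helpers, and like the cell above they only match players started with `/tmp/tfe.config`.

```python
import os
import signal
import subprocess

# Same process pattern the notebook cell greps for.
PATTERN = "tf_encrypted.player --config /tmp/tfe.config"


def find_player_pids(ps_output, pattern=PATTERN):
    """Return the PIDs (2nd column of `ps aux`) of lines that contain pattern."""
    pids = []
    for line in ps_output.splitlines()[1:]:  # skip the header row
        if pattern in line:
            pid = int(line.split()[1])
            if pid != os.getpid():  # never kill the process running this code
                pids.append(pid)
    return pids


def kill_players(pattern=PATTERN):
    """Send SIGTERM (what plain `kill` sends) to every matching player process."""
    result = subprocess.run(["ps", "aux"], capture_output=True, text=True, check=True)
    for pid in find_player_pids(result.stdout, pattern):
        os.kill(pid, signal.SIGTERM)
        print("Process ID {} has been killed.".format(pid))
```

Separating the parsing from the killing keeps `find_player_pids` easy to check against sample `ps aux` output before letting `kill_players` touch real processes.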

# Keystone Project - Mix and Match What You've Learned

Description: Take two of the concepts you've learned about in this course (Encrypted Computation, Federated Learning, Differential Privacy) and combine them for a use case of your own design. Extra credit if you can get your demo working with [WebSocketWorkers](https://github.com/OpenMined/PySyft/tree/dev/examples/tutorials/advanced/websockets-example-MNIST) instead of VirtualWorkers! Then write a blog post about your demo or example application and share it in #general-discussion on OpenMined's Slack!

Inspiration:
- This Course's Code: https://github.com/Udacity/private-ai
- OpenMined's Tutorials: https://github.com/OpenMined/PySyft/tree/dev/examples/tutorials
- OpenMined's Blog: https://blog.openmined.org