In [None]:
import torch
import syft
# Hook torch into syft so tensors can be worked with on syft workers.
hook = syft.TorchHook(torch)

# Three in-process virtual workers, one per role in the secure-inference setup:
#   alice - computes on behalf of the data owner
#   bob   - computes on behalf of the model owner
#   carol - acts as the crypto producer
alice, bob, carol = [
    syft.VirtualWorker(hook, id=worker_id)
    for worker_id in ("alice", "bob", "carol")
]

In [None]:
# Display `syft.local_worker` as this cell's output (presumably the worker for this notebook's own
# process — useful to eyeball after hooking torch; confirm against syft docs).
syft.local_worker

In [None]:
import tf_encrypted as tfe
from syft.keras.layers import AveragePooling2D, Conv2D, Dense, Flatten, ReLU
from syft.keras.model import Sequential

In [None]:
# Task configuration: number of output classes, and the model input shape whose leading
# element (None) is presumably the unspecified batch size — the model cell drops it via [1:].
# NOTE(review): the remaining dims (1, 28, 28) look channels-first; Keras-style Conv2D defaults
# to channels_last, so confirm the layers' data_format matches this layout.
task_classes, task_shape = 10, [None, 1, 28, 28]

# Location of the pretrained weights: a SavedModel directory or a Keras weights file.
weights_path = "./short-conv-mnist/"

In [None]:
def build_model(input_shape=None, num_classes=None):
    """Build the plaintext short convolutional classifier.

    Construction lives in a function (rather than at cell top level) so the
    inference cell can call it, and so it can later be re-invoked inside a
    SecureNN protocol context as the original TODO intended.

    Args:
        input_shape: per-example input shape given to the first Conv2D.
            Defaults to ``task_shape[1:]``, i.e. the configured shape without
            its leading batch dimension.
        num_classes: number of units in the final ``"logit"`` Dense layer.
            Defaults to ``task_classes``.

    Returns:
        A ``Sequential`` model made of three Conv2D(3x3) -> AveragePooling2D(2x2)
        -> ReLU stages with 10, 32 and 64 filters, followed by Flatten and a
        Dense layer named ``"logit"``.
    """
    # Read the config globals at call time so the defaults track the config cell.
    if input_shape is None:
        input_shape = task_shape[1:]
    if num_classes is None:
        num_classes = task_classes

    model = Sequential()

    model.add(Conv2D(10, (3, 3), input_shape=input_shape))
    model.add(AveragePooling2D((2, 2)))
    model.add(ReLU())
    model.add(Conv2D(32, (3, 3)))
    model.add(AveragePooling2D((2, 2)))
    model.add(ReLU())
    model.add(Conv2D(64, (3, 3)))
    model.add(AveragePooling2D((2, 2)))
    model.add(ReLU())
    model.add(Flatten())
    # Logit width comes from the task config instead of a duplicated literal.
    model.add(Dense(num_classes, name="logit"))
    return model

In [None]:
# Build the plaintext model and load its pretrained weights.
# NOTE(review): assumes `build_model` has been defined by an earlier cell — TODO confirm it exists
# so this cell survives Restart & Run All.
model = build_model()
model.load_weights(weights_path)  # can extend this to load the savedmodel if we don't want to define
                                  # the architecture in this notebook

# SecureNN protocol over the virtual workers: alice and bob are the two players,
# carol is both the comparison helper and the crypto producer.
prot = tfe.SecureNN(player_0=alice,
                    player_1=bob,
                    comparison_helper=carol,
                    crypto_producer=carol)
# Share the model under `prot` (exact semantics of syft.keras `share` are not visible here).
model.share(prot)
# Commented-out alternative: rebuild the model inside the protocol context.
# with prot:
#     new_model = model.rebuild()

# Declare the private input supplied by the prediction client, then run the model on it.
# NOTE(review): `prediction_client` is not defined in any visible cell — TODO define it (it must
# expose `player_name`, `provide_input` and `receive_output`) before this cell can run.
x = tfe.define_private_input(
    prediction_client.player_name,
    prediction_client.provide_input,  # this can be converted to a function that waits to
)                                     # receive inputs from other syft workers and then passes them on
y = model(x)

# Declare an output op that hands `y` to the prediction client's `receive_output` callback.
reveal_output = tfe.define_output(
    prediction_client.player_name,
    y,
    prediction_client.receive_output,  # this can pass output result to the prediction_client's worker
)

# not sure if/how we'll be able to do away with the following for now
# Execute the graph: first initialize variables (tag 'init'), then run the prediction (tag 'predict').
with tfe.Session() as sess:
    sess.run(tfe.global_variables_initializer(), tag='init')
    sess.run(reveal_output, tag='predict')