In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import syft
import tf_encrypted as tfe  # the notebook uses the `tfe` alias throughout

**(INTERNAL)**: The context here is that we have already convinced the attendees that privacy-preserving EDA and data pre-processing are possible, perhaps by demoing Cowbay or PySyft-Pandas. This means we can load the data in a mostly pre-processed form for this demo.

Build a remote tf.data pipeline _on each party_

In [None]:
# specify the players involved
model_owner = tfe.Player('model_owner')
data_owners = [
    tfe.Player('data_owner_0'),
    tfe.Player('data_owner_1'),
    tfe.Player('data_owner_2'),
]

In [None]:
train_data_sources = [data_owner.build_data_pipeline() for data_owner in data_owners]
val_data_sources = [data_owner.build_validation_pipeline() for data_owner in data_owners]
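
The party-side pipeline builders (`build_data_pipeline`, `build_validation_pipeline`) are placeholders in this notebook. As a rough, framework-free picture of what each one might produce, the sketch below assumes a toy in-memory dataset that is split into disjoint shards (one per data owner), then split per party into train and validation sets and served in batches. `shard_dataset`, `train_val_split` and `batches` are hypothetical helpers; in a real tf.data pipeline, `Dataset.shard`, `Dataset.shuffle` and `Dataset.batch` would play the same roles, and each shard would live on its owner's machine rather than in one process.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Example = Tuple[List[float], int]  # (features, label)


def shard_dataset(data: Sequence[Example], num_shards: int, index: int) -> List[Example]:
    """Keep every num_shards-th example starting at `index` (same idea as Dataset.shard)."""
    return list(data[index::num_shards])


def train_val_split(data: List[Example], val_fraction: float,
                    seed: int) -> Tuple[List[Example], List[Example]]:
    """Shuffle deterministically, then hold out the last `val_fraction` of the examples."""
    shuffled = list(data)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) - int(len(shuffled) * val_fraction)
    return shuffled[:n_train], shuffled[n_train:]


def batches(data: List[Example], batch_size: int) -> Iterator[List[Example]]:
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


# Hypothetical toy dataset: 300 examples with 4 features and a binary label.
rng = random.Random(0)
dataset = [([rng.random() for _ in range(4)], i % 2) for i in range(300)]

num_parties = 3
toy_train, toy_val = [], []
for party in range(num_parties):
    local = shard_dataset(dataset, num_parties, party)  # what data_owner_<party> holds
    train, val = train_val_split(local, val_fraction=0.2, seed=party)
    toy_train.append(train)
    toy_val.append(val)

for party, (train, val) in enumerate(zip(toy_train, toy_val)):
    n_batches = sum(1 for _ in batches(train, batch_size=32))
    print(f"data_owner_{party}: {len(train)} train, {len(val)} val, {n_batches} batches")
```

Sharding before shuffling keeps the parties' data disjoint, which mirrors the federated setting where no example is shared between owners.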

Demonstration of debugging on a particular data owner with syft-tf.

(This will likely be in a separate notebook.)

In [None]:
def build_model():
    model = tfe.keras.Sequential()
    model.add(tfe.keras.layers.Dense(10))
    model.compile(loss='mse', optimizer='sgd')
    return model

In [None]:
model = build_model().send(data_owners[0])

# these are just checks to make sure the data works with the model :)
model.fit(train_data_sources[0], epochs=1, batch_size=128)
model.evaluate(val_data_sources[0], batch_size=1000)

When data is distributed in the wild, making sure it is properly formatted across all machines takes much more debugging. This is also a downside of federated learning (FL): data is often not identically distributed across parties (e.g. different labeling conventions or different generating distributions).

But since we're using a clean dataset, we can skip much of that work.
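
To make both failure modes concrete, here is a small framework-free sketch that (a) checks that every party's examples have the expected feature width and only known labels, and (b) simulates label skew by partitioning a toy dataset by label, then reports each party's label histogram. The helper `audit_party` and the particular splits are illustrative assumptions, not part of TFE or PySyft; in practice such checks would have to run on each data owner (or over aggregated statistics), since the raw data never leaves its owner.

```python
from collections import Counter
from typing import Dict, List, Sequence, Set, Tuple

Example = Tuple[List[float], int]  # (features, label)


def audit_party(data: Sequence[Example], expected_width: int,
                allowed_labels: Set[int]) -> Dict[str, object]:
    """Summarize one party's data and flag schema problems."""
    widths = {len(features) for features, _ in data}
    labels = Counter(label for _, label in data)
    problems = []
    if widths != {expected_width}:
        problems.append(f"feature widths {sorted(widths)}, expected {expected_width}")
    unknown = set(labels) - allowed_labels
    if unknown:
        problems.append(f"unknown labels {sorted(unknown)}")
    return {"n": len(data), "labels": dict(sorted(labels.items())), "problems": problems}


# Hypothetical 3-class toy dataset: 60 examples with 4 features each.
dataset = [([float(i), 0.0, 1.0, 2.0], i % 3) for i in range(60)]

# Roughly IID split: contiguous blocks, each with a mix of all three labels.
iid_parties = [dataset[20 * p:20 * (p + 1)] for p in range(3)]

# Label-skewed (non-IID) split: party p only holds examples with label p.
skewed_parties = [[ex for ex in dataset if ex[1] == p] for p in range(3)]

# A party whose export dropped a column and used a label outside the schema.
malformed_party = [([1.0, 2.0, 3.0], 5)] + iid_parties[0][1:]

splits = {"iid": iid_parties, "skewed": skewed_parties, "malformed": [malformed_party]}
for name, parties in splits.items():
    for p, data in enumerate(parties):
        report = audit_party(data, expected_width=4, allowed_labels={0, 1, 2})
        print(f"{name} party {p}: labels={report['labels']} problems={report['problems']}")
```

The skewed split passes the schema check, yet each party sees a single class: format checks alone do not catch this kind of heterogeneity.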

## Main event
Once we're convinced the model works with each of the data sources, we can switch over to using TFE's secure aggregation.
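
The notebook names `AdditiveSecureAverage` without showing how it works. Going only by that name, the sketch below illustrates the general idea of averaging via additive secret sharing: each data owner encodes its update as fixed-point integers, splits every value into random shares that sum to it modulo 2**64, and sends one share to each compute player; each compute player adds up the shares it receives, and only the output receiver recombines those partial sums and divides by the number of parties. Any single share, or any single player's partial sum, is uniformly random on its own. The ring size, the 16-bit fixed-point precision and all function names are assumptions; this is not TFE's implementation, and it ignores networking, dropouts and malicious parties.

```python
import secrets
from typing import List

MODULUS = 2**64  # assumed ring Z_{2^64} for the shares
SCALE = 2**16    # assumed fixed-point precision: 16 fractional bits


def encode(x: float) -> int:
    """Map a real number to a fixed-point ring element (negatives wrap around)."""
    return round(x * SCALE) % MODULUS


def decode(v: int) -> float:
    """Inverse of encode, treating the upper half of the ring as negative."""
    if v >= MODULUS // 2:
        v -= MODULUS
    return v / SCALE


def share(value: int, n: int) -> List[int]:
    """Split `value` into n random additive shares that sum to it mod MODULUS."""
    shares = [secrets.randbelow(MODULUS) for _ in range(n - 1)]
    shares.append((value - sum(shares)) % MODULUS)
    return shares


def secure_average(updates: List[List[float]], num_compute_players: int) -> List[float]:
    """Average equal-length update vectors; compute players only see random-looking shares."""
    dim = len(updates[0])
    # partial_sums[j][k]: compute player j's running sum for coordinate k.
    partial_sums = [[0] * dim for _ in range(num_compute_players)]
    for update in updates:  # each data owner shares its own update
        for k, x in enumerate(update):
            for j, s in enumerate(share(encode(x), num_compute_players)):
                partial_sums[j][k] = (partial_sums[j][k] + s) % MODULUS  # share j goes to player j
    # The output receiver adds up the players' partial sums and rescales.
    totals = [sum(ps[k] for ps in partial_sums) % MODULUS for k in range(dim)]
    return [decode(t) / len(updates) for t in totals]


updates = [[0.5, -1.25, 2.0], [1.5, 0.25, -2.0], [1.0, 1.0, 0.0]]
print(secure_average(updates, num_compute_players=3))  # [1.0, 0.0, 0.0]
```

As in the aggregation cell, the compute players here can be the data owners themselves; the plain sum only becomes visible at the output receiver.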

Specify aggregation op

In [None]:
# ... alternatively we could have instantiated it explicitly,
# resulting in exactly the same thing
aggregation = tfe.functionalities.AdditiveSecureAverage(
    compute_players=data_owners,
    output_receiver=model_owner)

Basic Keras `model.fit()` training

In [None]:
with tfe.protocol.FederatedLearning(model_owner, aggregation):
    model = build_model()

model.fit(train_data_sources, epochs=10, batch_size=128)
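
The notebook does not specify what `tfe.protocol.FederatedLearning` does inside `fit`. One common scheme such a protocol could follow is federated averaging (FedAvg): each round, the current weights go out to every data owner, each owner runs local SGD on its own data, and the resulting weights are averaged (in this notebook's setting, that average is what the secure aggregation op would compute instead of an in-the-clear mean). The sketch below assumes a two-parameter linear model y = w*x + b trained with squared error, in plain Python; it illustrates FedAvg in general, not TFE's implementation.

```python
import random
from typing import List, Tuple

Params = Tuple[float, float]      # (w, b) for the model y = w * x + b
Data = List[Tuple[float, float]]  # list of (x, y) pairs


def local_sgd(params: Params, data: Data, lr: float, epochs: int) -> Params:
    """Plain per-example SGD on 0.5 * (w*x + b - y)**2, starting from the global params."""
    w, b = params
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y  # derivative of the loss w.r.t. the prediction
            w -= lr * err * x
            b -= lr * err
    return w, b


def fedavg_round(global_params: Params, parties: List[Data], lr: float,
                 local_epochs: int) -> Params:
    """One FedAvg round: every party trains locally, then the weights are averaged."""
    local = [local_sgd(global_params, data, lr, local_epochs) for data in parties]
    # This plain mean is the step that secure aggregation would replace.
    return (sum(w for w, _ in local) / len(local),
            sum(b for _, b in local) / len(local))


def make_party(lo: float, n: int, rng: random.Random) -> Data:
    """Noisy samples of y = 3x + 1 with x drawn from [lo, lo + 1)."""
    xs = [rng.uniform(lo, lo + 1.0) for _ in range(n)]
    return [(x, 3.0 * x + 1.0 + rng.gauss(0.0, 0.05)) for x in xs]


rng = random.Random(0)
# Each party covers a different x range, a mild form of non-IID data.
parties = [make_party(lo, 50, rng) for lo in (-1.0, 0.0, 1.0)]

params: Params = (0.0, 0.0)
for _ in range(20):
    params = fedavg_round(params, parties, lr=0.05, local_epochs=1)
print(params)
```

Because all three parties sample the same underlying line, the averaged parameters should approach (3, 1). FedAvg as usually described weights each party by its number of examples; the unweighted mean used here is equivalent only because every party holds 50.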

Remote validation & testing

In [None]:
model.evaluate(val_data_sources)
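
How `model.evaluate(val_data_sources)` combines per-party results is not specified here. A natural choice, assumed in the sketch below, is for each data owner to report only its local sum of squared errors and its example count, and to combine them into a global mean weighted by example count, so a party with more validation data counts proportionally more. Those two numbers per party could themselves go through the secure aggregation op. The functions `local_stats` and `federated_mse` and the toy linear model are hypothetical.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y)


def local_stats(params: Tuple[float, float], data: List[Example]) -> Tuple[float, int]:
    """What a data owner would report: sum of squared errors and number of examples."""
    w, b = params
    return sum((w * x + b - y) ** 2 for x, y in data), len(data)


def federated_mse(params: Tuple[float, float], parties: List[List[Example]]) -> float:
    """Global MSE = total squared error / total count (a count-weighted mean of local MSEs)."""
    stats = [local_stats(params, data) for data in parties]
    total_sse = sum(sse for sse, _ in stats)
    total_n = sum(n for _, n in stats)
    return total_sse / total_n


val_parties = [
    [(0.0, 1.0), (1.0, 4.0)],              # fit exactly by y = 3x + 1: SSE 0
    [(2.0, 8.0)],                          # off by 1: SSE 1
    [(1.0, 3.0), (0.0, 2.0), (2.0, 7.0)],  # off by 1, 1, 0: SSE 2
]
print(federated_mse((3.0, 1.0), val_parties))  # 0.5
```

Naively averaging the three per-party MSEs would instead give (0 + 1 + 2/3) / 3, roughly 0.556, over-weighting the party with a single example.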

Custom training loop

In [None]:
# TODO

Model subclassing?

In [None]:
# TODO, maybe combine with the above

Private predictions with TFE or TF Trusted

In [None]:
# (can be in a separate notebook/script, or maybe we can launch/drive it from here if we want)