In [50]:
import collections
import functools
import os

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_federated as tff

Load a pre-trained model

We load a model that was pre-trained following the TensorFlow tutorial Text generation using a RNN with eager execution. However, rather than training on The Complete Works of Shakespeare, we pre-trained the model on the text of Charles Dickens' A Tale of Two Cities and A Christmas Carol.

Other than expanding the vocabulary, we didn't modify the original tutorial, so this initial model isn't state-of-the-art, but it produces reasonable predictions and is sufficient for our tutorial purposes. The final model was saved with tf.keras.models.save_model(include_optimizer=False).

We will use federated learning to fine-tune this model for Shakespeare in this tutorial, using a federated version of the data provided by TFF.

# Generate the vocabulary

In [2]:
# Fixed vocabulary: the ASCII characters that occur in the works of
# Shakespeare and Dickens.
vocabulary = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Forward (character -> index) and inverse (index -> character) lookups.
characters_to_idx = dict(zip(vocabulary, range(len(vocabulary))))
idx_to_characters = dict(zip(characters_to_idx.values(), characters_to_idx.keys()))


# Load the pre-trained models

In [3]:
# IPython introspection (`?` suffix): shows the signature and docstring of
# `keras.utils.get_file`. Development aid only; it defines no names.
keras.utils.get_file?

[0;31mSignature:[0m
[0mkeras[0m[0;34m.[0m[0mutils[0m[0;34m.[0m[0mget_file[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfname[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0morigin[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0muntar[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmd5_hash[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfile_hash[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcache_subdir[0m[0;34m=[0m[0;34m'datasets'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhash_algorithm[0m[0;34m=[0m[0;34m'auto'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mextract[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0marchive_format[0m[0;34m=[0m[0;34m'auto'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcache_dir[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Downloads a file from a URL i

In [5]:
# The local file name is the last path component of the download URL.
os.path.split("https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel")[1]

'dickens_rnn.batch8.kerasmodel'

In [11]:
def load_pretrained_model(batch_size):
    """Fetch and load the pre-trained Dickens character RNN.

    Args:
        batch_size: Value interpolated into the hosted file name
            (`dickens_rnn.batch{batch_size}.kerasmodel`) to pick the model
            variant to fetch.

    Returns:
        The Keras model returned by `keras.models.load_model`, loaded with
        `compile=False` (callers compile it themselves).
    """
    url = f"https://storage.googleapis.com/tff-models-public/dickens_rnn.batch{batch_size}.kerasmodel"
    # The local file name is the last component of the URL.
    local_path = keras.utils.get_file(os.path.basename(url), url)
    return keras.models.load_model(local_path, compile=False)


In [13]:
def generate_text(model, start_string, num_generate=200, temperature=1.0):
    """Sample a continuation of `start_string` from a character-level model.

    Args:
        model: Model exposing `reset_states()` and callable on a tensor of
            character indices with a leading batch dimension of 1.
            NOTE(review): assumed to return per-position logits with the
            batch dimension first -- confirm against the saved model.
        start_string: Non-empty seed text. Every character must be a key of
            `characters_to_idx`; an unknown character raises `KeyError`.
        num_generate: Number of characters to sample (default 200).
        temperature: Positive divisor applied to the predictions before
            sampling (default 1.0).

    Returns:
        `start_string` followed by the `num_generate` sampled characters.

    Raises:
        ValueError: If `start_string` is empty or `temperature` is not
            positive.
    """
    if not start_string:
        raise ValueError("start_string must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    input_eval = tf.expand_dims([characters_to_idx[s] for s in start_string], 0)
    text_generated = []

    model.reset_states()
    for _ in range(num_generate):
        # Drop the batch dimension, then scale by the temperature.
        predictions = tf.squeeze(model(input_eval), 0) / temperature
        # Sample from the categorical distribution at the last position.
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        # Only the newly sampled character is fed back on the next step.
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx_to_characters[predicted_id])

    return start_string + ''.join(text_generated)


In [14]:
# Generation feeds a single sequence, so it needs the batch_size=1 model.
pretrained_model_1 = load_pretrained_model(batch_size=1)
print(generate_text(model=pretrained_model_1,
                    start_string='What of TensorFlow Federated, you ask? '))


What of TensorFlow Federated, you ask? Then, what a corner of the d
gentlemen," said Stryver, looking at the crowd; all
was to warm as if he said it with his awith the hungred man
with him, and because he had a light and time--

The b


# Load and Preprocess the Federated Shakespeare Data

The [`tff.simulation.datasets`](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets) package provides a variety of datasets that are split into "clients", where each client corresponds to a dataset on a particular device that might participate in federated learning.

These datasets provide realistic non-IID data distributions that replicate in simulation the challenges of training on real decentralized data. Some of the pre-processing of this data was done using tools from the [Leaf project](https://github.com/TalwalkarLab/leaf).

In [15]:
# Load the federated Shakespeare data as (train, test) client datasets.
training_data, testing_data = tff.simulation.datasets.shakespeare.load_data()


Downloading data from https://storage.googleapis.com/tff-datasets-public/shakespeare.tar.bz2


  collections.OrderedDict((name, ds.value) for name, ds in sorted(


The datasets provided consist of a sequence of string Tensors, one for each line spoken by a particular character in a Shakespeare play. The client keys consist of the name of the play joined with the name of the character, so for example `MUCH_ADO_ABOUT_NOTHING_OTHELLO` corresponds to the lines for the character Othello in the play *Much Ado About Nothing*. 

Note that in a real federated learning scenario clients are never identified or tracked by ids but for simulation it is useful to work with keyed datasets.

Here, for example, we can look at some data from *King Lear*.

In [16]:
# The client id joins play and character: here the play is
# "The Tragedy of King Lear" and the character is "King".
client_dataset = training_data.create_tf_dataset_for_client(
    "THE_TRAGEDY_OF_KING_LEAR_KING"
)

In [17]:
# Each element is an OrderedDict whose single key, 'snippets', holds the text
# (a dict is used to allow for future extensions).
for entry in client_dataset.take(2):
    print(entry['snippets'])


tf.Tensor(b"Live regist'red upon our brazen tombs,\nAnd then grace us in the disgrace of death;\nWhen, spite of cormorant devouring Time,\nTh' endeavour of this present breath may buy\nThat honour which shall bate his scythe's keen edge,\nAnd make us heirs of all eternity.\nTherefore, brave conquerors- for so you are\nThat war against your own affections\nAnd the huge army of the world's desires-\nOur late edict shall strongly stand in force:\nNavarre shall be the wonder of the world;\nOur court shall be a little Academe,\nStill and contemplative in living art.\nYou three, Berowne, Dumain, and Longaville,\nHave sworn for three years' term to live with me\nMy fellow-scholars, and to keep those statutes\nThat are recorded in this schedule here.\nYour oaths are pass'd; and now subscribe your names,\nThat his own hand may strike his honour down\nThat violates the smallest branch herein.\nIf you are arm'd to do as sworn to do,\nSubscribe to your deep oaths, and keep it too.\nYour oath is pa

We now use [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) transformations to prepare this data for training the char RNN loaded above. 

In [66]:
# Input-pipeline settings shared by the preprocessing helpers below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
SEQUENCE_LENGTH = 100
SHUFFLE_BUFFER_SIZE = 10000
NUMBER_TRAINING_EPOCHS = 3
TRAINING_BATCH_SIZE = 8
TESTING_BATCH_SIZE = 8

# Static character -> index lookup table built from the vocabulary; keys that
# are not in the vocabulary map to the default value -1.
_indices = tf.constant(list(range(len(vocabulary))), dtype=tf.int64)
_initializer = tf.lookup.KeyValueTensorInitializer(vocabulary, values=_indices)
TABLE = tf.lookup.StaticHashTable(_initializer, default_value=-1)


def _to_indices(entry):
    """Map the text of one dataset entry to vocabulary indices.

    Args:
        entry: Mapping whose 'snippets' value is reshaped to shape `[1]`.

    Returns:
        The result of `TABLE.lookup` applied to the flat values of the
        byte-split snippet.
    """
    text = tf.reshape(entry["snippets"], shape=[1])
    per_byte = tf.strings.bytes_split(text)
    return TABLE.lookup(per_byte.values)


def _to_input_target(batch):
    input_sequence = tf.map_fn(lambda x: x[:-1], batch)
    target_sequence = tf.map_fn(lambda x: x[1:], batch)
    return input_sequence, target_sequence


def create_training_dataset(client_dataset,
                            seed=None,
                            num_parallel_calls=AUTOTUNE,
                            shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
                            number_epochs=NUMBER_TRAINING_EPOCHS,
                            batch_size=TRAINING_BATCH_SIZE,
                            prefetch_buffer_size=AUTOTUNE):
    """Turn a raw client dataset into batched (input, target) sequences.

    Args:
        client_dataset: Dataset of entries accepted by `_to_indices`.
        seed: Seed forwarded to `shuffle`.
        num_parallel_calls: Parallelism of the `_to_indices` map.
        shuffle_buffer_size: Buffer size forwarded to `shuffle`.
        number_epochs: Count forwarded to `repeat`.
        batch_size: Number of sequences per batch.
        prefetch_buffer_size: Buffer size forwarded to `prefetch`.

    Returns:
        A dataset of `(input_sequence, target_sequence)` batches. Incomplete
        sequences and incomplete batches are dropped.
    """
    # Flatten every snippet into one stream of character indices.
    index_stream = client_dataset.map(_to_indices, num_parallel_calls).unbatch()
    # One extra element per sequence so input and target can be offset by one.
    sequences = index_stream.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)
    shuffled = sequences.shuffle(shuffle_buffer_size, seed, reshuffle_each_iteration=True)
    batches = shuffled.repeat(number_epochs).batch(batch_size, drop_remainder=True)
    return batches.map(_to_input_target).prefetch(prefetch_buffer_size)


def create_training_datasets(client_ids,
                             seed=None,
                             num_parallel_calls=AUTOTUNE,
                             shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,
                             number_epochs=NUMBER_TRAINING_EPOCHS,
                             batch_size=TRAINING_BATCH_SIZE,
                             prefetch_buffer_size=AUTOTUNE):
    """Build one preprocessed training dataset per client id.

    Client data is read from the module-level `training_data`; all other
    arguments are forwarded unchanged to `create_training_dataset`.

    Returns:
        A list of datasets, in the same order as `client_ids`.
    """
    return [
        create_training_dataset(
            training_data.create_tf_dataset_for_client(client_id),
            seed=seed,
            num_parallel_calls=num_parallel_calls,
            shuffle_buffer_size=shuffle_buffer_size,
            number_epochs=number_epochs,
            batch_size=batch_size,
            prefetch_buffer_size=prefetch_buffer_size,
        )
        for client_id in client_ids
    ]


Note that in the formation of the original sequences and in the formation of batches above, we use `drop_remainder=True` for simplicity. This means that any characters (clients) that don't have at least `(SEQUENCE_LENGTH + 1) * TRAINING_BATCH_SIZE` chars of text will have empty datasets. There are many standard approaches to dealing with this issue. Note, however, that in the federated setting this issue is more significant because many users might have small datasets.

Now we can preprocess our `client_dataset`, and check the types.

In [58]:
# Run the example client through the training pipeline with default settings.
preprocessed_client_dataset = create_training_dataset(client_dataset=client_dataset)

In [59]:
# Show the structure of one (input, target) batch.
print(
    preprocessed_client_dataset.element_spec
)

(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))


## Compile the model and test on the preprocessed data

We loaded an uncompiled Keras model, but in order to evaluate the model, we need to compile it with a loss and metrics. We will also compile in an optimizer, which will be used as the on-device optimizer in Federated Learning.

The original tutorial didn't have char-level accuracy (the fraction of predictions where the highest probability was put on the correct next char). This is a useful metric, so we add it. However, we need to define a new metric class for this because our predictions have rank 3 (a vector of logits for each of the `TRAINING_BATCH_SIZE * SEQUENCE_LENGTH` predictions), and SparseCategoricalAccuracy expects only rank 2 predictions.

In [26]:
class FlattenedSparseCategoricalAccuracy(keras.metrics.SparseCategoricalAccuracy):
    """Sparse categorical accuracy that flattens its inputs before updating.

    Labels are reshaped to `[-1, 1]` and predictions to
    `[-1, len(vocabulary), 1]` before being passed to the parent metric.
    """

    def __init__(self, name='accuracy', dtype=None):
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Flatten `y_true` / `y_pred` and delegate to the parent update."""
        flat_labels = tf.reshape(y_true, [-1, 1])
        flat_predictions = tf.reshape(y_pred, [-1, len(vocabulary), 1])
        return super().update_state(flat_labels, flat_predictions, sample_weight)


In [29]:
def compile_keras_model(model, optimizer):
    """Compile `model` in place for next-character prediction.

    Args:
        model: Keras model to compile.
        optimizer: Optimizer passed to `model.compile`.

    Returns:
        The same `model` object, compiled with a from-logits sparse
        categorical cross-entropy loss and the flattened accuracy metric.
    """
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[FlattenedSparseCategoricalAccuracy()],
    )
    return model


In [32]:
# TRAINING_BATCH_SIZE (set with the other pipeline constants) is the training
# and eval batch size for the rest of this tutorial.
model_fn = load_pretrained_model(batch_size=TRAINING_BATCH_SIZE)
optimizer = keras.optimizers.SGD(learning_rate=0.5)
compile_keras_model(model_fn, optimizer)

# Confirm that loss is much lower on Shakespeare than on random data
print('Evaluating on an example Shakespeare character:')
model_fn.evaluate(preprocessed_client_dataset.take(1))

# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_indexes = np.random.randint(
    low=0, high=len(vocabulary), size=1 * TRAINING_BATCH_SIZE * (SEQUENCE_LENGTH + 1))
# Named `random_data` rather than `data` so it does not clash with the `data`
# helper function defined later in the notebook.
random_data = {
    'snippets':
        tf.constant(''.join(np.array(vocabulary)[random_indexes]), shape=[1, 1])
}
# Use the notebook's own preprocessing pipeline: no `preprocess` function is
# defined anywhere in this notebook.
random_dataset = create_training_dataset(tf.data.Dataset.from_tensor_slices(random_data))
print('\nExpected accuracy for random guessing: {:.3f}'.format(1.0 / len(vocabulary)))
print('Evaluating on completely random data:')
model_fn.evaluate(random_dataset, steps=1)


Evaluating on an example Shakespeare character:

Expected accuracy for random guessing: 0.012
Evaluating on completely random data:


[11.473602294921875, 0.01125]

# Fine-tune the model with Federated Learning

TFF serializes all TensorFlow computations so they can potentially be run in a non-Python environment (even though at the moment, only a simulation runtime implemented in Python is available). Even though we are running in eager mode, (TF 2.0), currently TFF serializes TensorFlow computations by constructing the necessary ops inside the context of a "with tf.Graph.as_default()" statement. Thus, we need to provide a function that TFF can use to introduce our model into a graph it controls. We do this as follows:

In [47]:
# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will serialize.
def create_tff_model():
    """Build a fresh TFF model wrapping a compiled clone of `model_fn`.

    Returns:
        The object returned by `tff.learning.from_compiled_keras_model` for a
        newly cloned and compiled copy of the Keras model.
    """
    # TFF uses a `dummy_batch` so it knows the types and shapes
    # that your model expects.
    x = tf.constant(np.random.randint(1, len(vocabulary), size=[TRAINING_BATCH_SIZE, SEQUENCE_LENGTH]))
    dummy_batch = collections.OrderedDict([('x', x), ('y', x)])
    # `learning_rate` is the current argument name; `lr` is a deprecated alias.
    optimizer = keras.optimizers.SGD(learning_rate=0.5)
    cloned_model_fn = compile_keras_model(keras.models.clone_model(model_fn), optimizer)
    return tff.learning.from_compiled_keras_model(cloned_model_fn, dummy_batch=dummy_batch)


Now we are ready to construct a Federated Averaging iterative process, which we will use to improve the model (for details on the Federated Averaging algorithm, see the paper Communication-Efficient Learning of Deep Networks from Decentralized Data).

We use a compiled Keras model to perform standard (non-federated) evaluation after each round of federated training. This is useful for research purposes when doing simulated federated learning and there is a standard test dataset.

In a realistic production setting this same technique might be used to take models trained with federated learning and evaluate them on a centralized benchmark dataset for testing or quality assurance purposes.

In [48]:
# Building the process constructs and serializes all the TensorFlow graphs.
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
)

Now let's write a slightly more interesting training and evaluation loop.

So that this simulation still runs relatively quickly, we train on the same three clients each round, only considering two minibatches for each.

In [52]:
def data(client, source=training_data):
    """Return the first two preprocessed minibatches for one client.

    Args:
        client: Client id passed to `source.create_tf_dataset_for_client`.
        source: Federated client data to read from; defaults to
            `training_data`.

    Returns:
        The client's dataset run through `create_training_dataset` and
        truncated with `.take(2)`.
    """
    # Use the notebook's own preprocessing pipeline: no `preprocess` function
    # is defined anywhere in this notebook.
    raw_dataset = source.create_tf_dataset_for_client(client)
    return create_training_dataset(raw_dataset).take(2)

clients = ['ALL_S_WELL_THAT_ENDS_WELL_CELIA',
           'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
           'THE_TRAGEDY_OF_KING_LEAR_KING']

train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras.
test_dataset = functools.reduce(
    lambda d1, d2: d1.concatenate(d2),
    [data(client, testing_data) for client in clients]
)


  collections.OrderedDict((name, ds.value) for name, ds in sorted(


In [54]:
NUM_ROUNDS = 3

# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()

# Seed the server state with the current weights of the Keras model.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in model_fn.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in model_fn.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
    """Copy the server model weights into `model_fn` and evaluate it.

    Args:
        state: Server state whose `.model` weights are assigned to `model_fn`.
        round_num: Round index used only in the printed label.
    """
    tff.learning.assign_weights_to_keras_model(model_fn, state.model)
    print('Evaluating before training round', round_num)
    # NOTE(review): this evaluates on `preprocessed_client_dataset`, not on
    # the `test_dataset` built in the previous cell -- confirm that is intended.
    model_fn.evaluate(preprocessed_client_dataset, steps=2)


for round_num in range(NUM_ROUNDS):
    keras_evaluate(state, round_num)
    # N.B. The TFF runtime is currently fairly slow,
    # expect this to get significantly faster in future releases.
    state, metrics = fed_avg.next(state, train_datasets)
    print('Training metrics: ', metrics)

# Rounds are numbered 0..NUM_ROUNDS-1, so the evaluation after the last round
# is labelled NUM_ROUNDS (NUM_ROUNDS + 1 would skip a number).
keras_evaluate(state, NUM_ROUNDS)


Evaluating before training round 0
Training metrics:  <accuracy=0.4193750023841858,loss=3.228424072265625>
Evaluating before training round 1
Training metrics:  <accuracy=0.4397916793823242,loss=2.9016449451446533>
Evaluating before training round 2
Training metrics:  <accuracy=0.4660416543483734,loss=2.6546707153320312>
Evaluating before training round 4


In [67]:
# Settings for the randomly-sampled training loop.
NUM_ROUNDS = 5
RANDOM_STATE = np.random.RandomState(42)  # fixed seed for client sampling

# NOTE(review): unlike the earlier loop, this state is not updated with
# `state_with_new_model_weights` -- confirm that starting from the process's
# own initial weights is intended.
federated_averaging_process = tff.learning.build_federated_averaging_process(create_tff_model)
state = federated_averaging_process.initialize()


def sample_client_ids(client_ids: list,
                      sample_size: float,
                      random_state: np.random.RandomState) -> list:
    """Randomly select a subset of client ids without replacement.

    Args:
        client_ids: Candidate client ids.
        sample_size: Fraction of `client_ids` to select, in `[0, 1]`. The
            number of ids returned is `int(sample_size * len(client_ids))`.
        random_state: Source of randomness; the same state gives the same
            sample.

    Returns:
        A list of distinct entries of `client_ids` (empty when the requested
        fraction rounds down to zero clients).

    Raises:
        ValueError: If `sample_size` is outside `[0, 1]`.
    """
    if not 0.0 <= sample_size <= 1.0:
        raise ValueError(f"sample_size must be in [0, 1], got {sample_size}")
    n_clients = len(client_ids)
    n_clients_per_sample = int(sample_size * n_clients)
    if n_clients_per_sample == 0:
        return []
    # `replace=False` so a client is never picked twice in the same round.
    random_indices = random_state.choice(n_clients, size=n_clients_per_sample, replace=False)
    return [client_ids[i] for i in random_indices]


for n in range(NUM_ROUNDS):
    # Draw a fresh 1% sample of all clients for this round.
    client_ids = sample_client_ids(client_ids=training_data.client_ids,
                                   sample_size=0.01,
                                   random_state=RANDOM_STATE)
    federated_training_data = create_training_datasets(client_ids=client_ids)

    # Run one round of the federated computation on the sampled clients.
    state, metrics = federated_averaging_process.next(state, federated_training_data)
    print(f"round:{n}, metrics:{metrics}")




  collections.OrderedDict((name, ds.value) for name, ds in sorted(


KeyboardInterrupt: 

# Suggested extensions

This tutorial is just the first step! Here are some ideas for how you might try extending this notebook: 

* Write a more realistic training loop where you sample clients to train on randomly.
* Use ".repeat(NUM_EPOCHS)" on the client datasets to try multiple epochs of local training (e.g., as in McMahan et al.). See also Federated Learning for Image Classification which does this. 
* Change the compile() command to experiment with using different optimization algorithms on the client. 
* Try the server_optimizer argument to build_federated_averaging_process to try different algorithms for applying the model updates on the server. 
* Try the client_weight_fn argument to build_federated_averaging_process to try different weightings of the clients. The default weights client updates by the number of examples on the client, but you can do e.g. client_weight_fn=lambda _: tf.constant(1.0).