In [53]:
from __future__ import absolute_import, division, print_function


import collections
import functools
import numpy as np
import os
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent
import time

from tensorflow_federated import python as tff

# Short alias for the nested-structure helpers that live under tf.contrib.
nest = tf.contrib.framework.nest

# Opt in to TF 2.x behavior; later cells call `.numpy()` directly on tensors.
tf.compat.v1.enable_v2_behavior()

# Seed NumPy's global RNG so the randomly generated data below is repeatable.
np.random.seed(0)

# Smoke test: build and immediately invoke a trivial federated computation to
# check that TFF is working.
tff.federated_computation(lambda: 'Hello, World!')()

'Hello, World!'

In [54]:
# A fixed vocabulary of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Lookup in both directions: character -> index (dict) and index -> character
# (NumPy array, so it can be indexed with integer ids).
char2idx = dict(zip(vocab, range(len(vocab))))
idx2char = np.array(vocab)

In [55]:
def load_model(batch_size):
    """Fetch a pre-trained Dickens RNN checkpoint and load it as a Keras model.

    Args:
      batch_size: Batch size of the checkpoint variant to load. Only the
        values listed in the URL table below (1 and 8) are available.

    Returns:
      The Keras model loaded from the fetched file, left uncompiled
      (`compile=False`).

    Raises:
      AssertionError: If there is no checkpoint URL for `batch_size`.
    """
    urls = {
        1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
        8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel',
    }
    assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
    checkpoint_url = urls[batch_size]
    # The local file is named after the last path component of the URL.
    checkpoint_name = os.path.basename(checkpoint_url)
    checkpoint_path = tf.keras.utils.get_file(checkpoint_name, origin=checkpoint_url)
    return tf.keras.models.load_model(checkpoint_path, compile=False)

In [56]:
def generate_text(model, start_string, num_generate=200, temperature=1.0):
    """Sample text from `model`, one character at a time, after a prompt.

    The prompt is encoded with `char2idx` and fed to the model as a batch of
    one sequence. Each following step samples one character id from the
    temperature-scaled predictions for the last position and feeds that id
    back in as the next input.

    Args:
      model: Callable Keras model exposing `reset_states()`. It is invoked
        with a batch of exactly one sequence.
      start_string: Non-empty prompt; every character must be a key of
        `char2idx`.
      num_generate: Number of characters to sample after the prompt.
      temperature: Positive divisor applied to the predictions before
        sampling. NOTE(review): presumably lower values give more
        predictable text and higher values more surprising text - confirm.

    Returns:
      `start_string` followed by the `num_generate` sampled characters.

    Raises:
      ValueError: If `start_string` is empty or `temperature` is not positive.
      KeyError: If `start_string` contains a character missing from `char2idx`.
    """
    # Validate up front so bad arguments fail with a clear message instead of
    # an obscure error from inside the sampling loop.
    if not start_string:
        raise ValueError('start_string must not be empty')
    if temperature <= 0:
        raise ValueError(
            'temperature must be positive, got {}'.format(temperature))

    # Encode the prompt and add a leading batch dimension of size 1.
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []

    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        # Drop the batch dimension.
        predictions = tf.squeeze(predictions, 0)
        predictions = predictions / temperature
        # Sample one id per position and keep only the last position's draw.
        predicted_id = tf.multinomial(predictions, num_samples=1)[-1, 0].numpy()
        # The sampled character alone becomes the next input.
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)

In [57]:
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
# Optional check of the pre-trained weights, currently disabled:
# print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
# Print the layer stack, output shapes and parameter counts.
keras_model_batch1.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_6 (Embedding)      (1, None, 256)            22016     
_________________________________________________________________
gru_6 (GRU)                  (1, None, 1024)           3935232   
_________________________________________________________________
dense_6 (Dense)              (1, None, 86)             88150     
Total params: 4,045,398
Trainable params: 4,045,398
Non-trainable params: 0
_________________________________________________________________


In [58]:
# Load the Shakespeare simulation dataset, which comes as a (train, test) pair
# of per-client data.
train_data, test_data = tff.simulation.datasets.shakespeare.load_data()
# Display the ids of the clients available in the training split.
train_data.client_ids

['ALL_S_WELL_THAT_ENDS_WELL_ADAM',
 'ALL_S_WELL_THAT_ENDS_WELL_AEDILE',
 'ALL_S_WELL_THAT_ENDS_WELL_AGRIPPA',
 'ALL_S_WELL_THAT_ENDS_WELL_ALEXAS',
 'ALL_S_WELL_THAT_ENDS_WELL_ALL',
 'ALL_S_WELL_THAT_ENDS_WELL_ALL_THE_PEOPLE',
 'ALL_S_WELL_THAT_ENDS_WELL_AMIENS',
 'ALL_S_WELL_THAT_ENDS_WELL_ANTIPHOLUS_OF_EPHESUS',
 'ALL_S_WELL_THAT_ENDS_WELL_ANTONY',
 'ALL_S_WELL_THAT_ENDS_WELL_ARVIRAGUS',
 'ALL_S_WELL_THAT_ENDS_WELL_AUDREY',
 'ALL_S_WELL_THAT_ENDS_WELL_AUFIDIUS',
 'ALL_S_WELL_THAT_ENDS_WELL_BELARIUS',
 'ALL_S_WELL_THAT_ENDS_WELL_BOTH',
 'ALL_S_WELL_THAT_ENDS_WELL_BOTH_TRIBUNES',
 'ALL_S_WELL_THAT_ENDS_WELL_BROTHERS',
 'ALL_S_WELL_THAT_ENDS_WELL_BRUTUS',
 'ALL_S_WELL_THAT_ENDS_WELL_CAESAR',
 'ALL_S_WELL_THAT_ENDS_WELL_CANIDIUS',
 'ALL_S_WELL_THAT_ENDS_WELL_CAPTAIN',
 'ALL_S_WELL_THAT_ENDS_WELL_CELIA',
 'ALL_S_WELL_THAT_ENDS_WELL_CENTURION',
 'ALL_S_WELL_THAT_ENDS_WELL_CHARLES',
 'ALL_S_WELL_THAT_ENDS_WELL_CHARMIAN',
 'ALL_S_WELL_THAT_ENDS_WELL_CITIZEN',
 'ALL_S_WELL_THAT_ENDS_WELL_CITIZ

In [59]:
# Print every snippet held by one client, followed by the snippet count.
example = train_data.create_tf_dataset_for_client(
    'ALL_S_WELL_THAT_ENDS_WELL_ANTONY')
lenx = 0  # Stays 0 if the client has no snippets at all.
for lenx, x in enumerate(example, start=1):
    print(x['snippets'])
print(lenx)

tf.Tensor(b'Fulvia is dead.', shape=(), dtype=string)
tf.Tensor(b'What is his strength by land?', shape=(), dtype=string)
tf.Tensor(b"There's a great spirit gone! Thus did I desire it.\nWhat our contempts doth often hurl from us\nWe wish it ours again; the present pleasure,\nBy revolution low'ring, does become\nThe opposite of itself. She's good, being gone;\nThe hand could pluck her back that shov'd her on.\nI must from this enchanting queen break off.\nTen thousand harms, more than the ills I know,\nMy idleness doth hatch. How now, Enobarbus!\n                Re-enter ENOBARBUS\nI must with haste from hence.", shape=(), dtype=string)
tf.Tensor(b'Sit, sir.', shape=(), dtype=string)
tf.Tensor(b"To this good purpose, that so fairly shows,\nDream of impediment! Let me have thy hand.\nFurther this act of grace; and from this hour\nThe heart of brothers govern in our loves\nAnd sway our great designs!\nI did not think to draw my sword 'gainst Pompey;", shape=(), dtype=string)
tf.Tensor(b"T

In [22]:
# Raw (not yet preprocessed) snippets of a single client; later cells run this
# dataset through `preprocess` to obtain the example minibatches.
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')
# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
# Only the first two entries are printed here.
for x in raw_example_dataset.take(2):
    print(x['snippets'])

tf.Tensor(b"Live regist'red upon our brazen tombs,\nAnd then grace us in the disgrace of death;\nWhen, spite of cormorant devouring Time,\nTh' endeavour of this present breath may buy\nThat honour which shall bate his scythe's keen edge,\nAnd make us heirs of all eternity.\nTherefore, brave conquerors- for so you are\nThat war against your own affections\nAnd the huge army of the world's desires-\nOur late edict shall strongly stand in force:\nNavarre shall be the wonder of the world;\nOur court shall be a little Academe,\nStill and contemplative in living art.\nYou three, Berowne, Dumain, and Longaville,\nHave sworn for three years' term to live with me\nMy fellow-scholars, and to keep those statutes\nThat are recorded in this schedule here.\nYour oaths are pass'd; and now subscribe your names,\nThat his own hand may strike his honour down\nThat violates the smallest branch herein.\nIf you are arm'd to do as sworn to do,\nSubscribe to your deep oaths, and keep it too.\nYour oath is pa

In [60]:
# Input pre-processing parameters
SEQ_LENGTH = 100  # Length of each input/target sequence fed to the model
BATCH_SIZE = 8  # Number of sequences per minibatch
BUFFER_SIZE = 10000  # For dataset shuffling

In [61]:
# Lookup table that maps each string in `vocab` to its position in the list.
# No out-of-vocabulary buckets are allocated; strings that are not in `vocab`
# resolve to `default_value`, i.e. index 0.
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=vocab,
    num_oov_buckets=0,
    default_value=0)
  
def to_ids(x):
    """Turn the text snippet of one dataset element into vocab ids.

    Args:
      x: Dataset element holding the snippet text under the 'snippets' key.

    Returns:
      The ids that `table` assigns to the pieces of the split snippet.
    """
    # `tf.string_split` is handed a length-1 vector rather than the raw value.
    snippet = tf.reshape(x['snippets'], shape=[1])
    # An empty delimiter splits the snippet into single-character strings.
    pieces = tf.string_split(snippet, delimiter='')
    return table.lookup(pieces.values)

# Named (input, target) pair emitted by the preprocessing pipeline. It is
# defined here because `split_input_target` needs it and no other visible cell
# of this notebook creates it.
BatchType = collections.namedtuple('BatchType', ['x', 'y'])


def split_input_target(chunk):
    """Split a batch of sequences into next-character (input, target) pairs.

    Args:
      chunk: Rank-2 batch of id sequences, one sequence per row.

    Returns:
      A `BatchType` whose `x` holds every row without its last element and
      whose `y` holds every row without its first element, so `y[i][t]` is
      the element that follows `x[i][t]` in the original row.
    """
    # Slicing along the second axis handles all rows at once.
    return BatchType(x=chunk[:, :-1], y=chunk[:, 1:])

  
def preprocess(dataset):
    """Turn a dataset of raw text snippets into (input, target) minibatches.

    Args:
      dataset: Dataset whose elements `to_ids` can consume, i.e. elements
        carrying text under the 'snippets' key.

    Returns:
      A dataset of `split_input_target` results built from shuffled
      minibatches of `BATCH_SIZE` sequences; incomplete sequences and
      incomplete minibatches are dropped.
    """
    # Map ASCII chars to int64 indexes using the vocab.
    id_vectors = dataset.map(to_ids)
    # Split into individual chars.
    single_ids = id_vectors.apply(tf.data.experimental.unbatch())
    # Form example sequences of SEQ_LENGTH + 1.
    sequences = single_ids.batch(SEQ_LENGTH + 1, drop_remainder=True)
    # Shuffle and form minibatches.
    shuffled = sequences.shuffle(BUFFER_SIZE)
    minibatches = shuffled.batch(BATCH_SIZE, drop_remainder=True)
    # And finally split into (input, target) tuples, each of length SEQ_LENGTH.
    return minibatches.map(split_input_target)


In [62]:
# Run the example client's raw snippets through the preprocessing pipeline.
example_dataset = preprocess(raw_example_dataset)
# print(example_dataset.output_types, example_dataset.output_shapes)
# Display the dataset so its element shapes and types can be inspected.
example_dataset

<DatasetV1Adapter shapes: BatchType(x=(8, 100), y=(8, 100)), types: BatchType(x=tf.int64, y=tf.int64)>

In [63]:
class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    """Sparse categorical accuracy applied to flattened labels and predictions.

    Labels and predictions are reshaped before being handed to the parent
    metric, so each position of each sequence is treated as its own example.
    """

    def __init__(self, name='accuracy', dtype=None):
        """Create the metric; it is reported as 'accuracy' by default."""
        super(FlattenedCategoricalAccuracy, self).__init__(name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Reshape the inputs and forward them to the parent metric.

        Args:
          y_true: Target ids; reshaped to `[-1, 1]`.
          y_pred: Model predictions; reshaped to `[-1, len(vocab), 1]`.
          sample_weight: Passed through to the parent metric unchanged.

        Returns:
          Whatever the parent `update_state` returns.
        """
        vocab_size = len(vocab)
        flat_labels = tf.reshape(y_true, [-1, 1])
        flat_predictions = tf.reshape(y_pred, [-1, vocab_size, 1])
        parent = super(FlattenedCategoricalAccuracy, self)
        return parent.update_state(flat_labels, flat_predictions, sample_weight)

In [64]:
def loss_fn(y_true, y_pred):
    """Mean sparse categorical cross-entropy, with `y_pred` taken as logits.

    Args:
      y_true: Integer target ids.
      y_pred: Unnormalized model outputs (`from_logits=True`).

    Returns:
      The cross-entropy averaged over all elements, as a scalar tensor.
    """
    elementwise_loss = tf.keras.metrics.sparse_categorical_crossentropy(
        y_true, y_pred, from_logits=True)
    return tf.reduce_mean(elementwise_loss)

In [67]:
def compile(keras_model):
    """Compile `keras_model` in place with this notebook's settings.

    The model gets an SGD optimizer (learning rate 0.5), `loss_fn` as its loss
    and `FlattenedCategoricalAccuracy` as its only metric.

    NOTE: this function shadows Python's builtin `compile` for the rest of the
    notebook; the name is kept because other cells call it.

    Args:
      keras_model: The Keras model to compile.

    Returns:
      The same `keras_model` object, now compiled.
    """
    sgd_optimizer = gradient_descent.SGD(lr=0.5)
    accuracy_metric = FlattenedCategoricalAccuracy()
    keras_model.compile(
        optimizer=sgd_optimizer, loss=loss_fn, metrics=[accuracy_metric])
    return keras_model

In [68]:
#@test {"output": "ignore", "timeout": 120}
BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
# Load the pre-trained checkpoint variant that matches BATCH_SIZE.
keras_model = load_model(batch_size=BATCH_SIZE)

# Attach optimizer, loss and accuracy metric; the model is compiled in place.
compile(keras_model)

print('Evaluating on an example Shakespeare character:')
# Evaluate on a single minibatch taken from the example client's data.
keras_model.evaluate(example_dataset.take(1), steps=1)


# Baseline: draw random vocab indexes, exactly enough characters for one
# minibatch (BATCH_SIZE sequences of SEQ_LENGTH + 1 characters each).
random_indexes = np.random.randint(
    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
# Wrap the random characters, joined into one string, under the same
# 'snippets' key that the real client data uses, so `preprocess` accepts it.
data = {
    'snippets':
        tf.constant(''.join(np.array(vocab)[random_indexes]), shape=[1, 1])
}
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
print('Expected accuracy for random guessing: {:.3f}'.format(1.0 / len(vocab)))
print('Evaluating on completely random data:')
keras_model.evaluate(random_dataset, steps=1)

Evaluating on an example Shakespeare character:
Expected accuracy for random guessing: 0.012
Evaluating on completely random data:


[11.508935928344727, 0.01]

# Fine-tune the model with Federated Learning

In [18]:
def create_tff_model():
    """Build the TFF model wrapper passed to the federated averaging process.

    A clone of `keras_model` is compiled with this notebook's `compile` helper
    and wrapped by TFF together with a dummy batch of random ids.

    Returns:
      The result of `tff.learning.from_compiled_keras_model` for the clone.
    """
    # Random ids in [1, len(vocab)) with the shape of one preprocessed batch;
    # the same tensor serves as both the 'x' and the 'y' entry.
    random_ids = np.random.randint(1, len(vocab), size=[BATCH_SIZE, SEQ_LENGTH])
    id_tensor = tf.constant(random_ids)
    dummy_batch = collections.OrderedDict([('x', id_tensor), ('y', id_tensor)])
    compiled_clone = compile(tf.keras.models.clone_model(keras_model))
    return tff.learning.from_compiled_keras_model(
        compiled_clone, dummy_batch=dummy_batch)

In [0]:
# Build the Federated Averaging process; `create_tff_model` is handed over as
# the factory TFF uses to construct the model.
fed_avg = tff.learning.build_federated_averaging_process(model_fn=create_tff_model)

In [0]:
# Smoke test: run one round of federated averaging with a single client whose
# data is one minibatch of the example dataset, then print the round's metrics.
state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(1)])
print(metrics)

<accuracy=0.005,loss=4.45424>


In [0]:
def data(client, source=train_data):
    """Return the first two preprocessed minibatches of one client.

    Args:
      client: Id of the client whose data should be loaded.
      source: Client data to read from; defaults to the training split.

    Returns:
      The client's preprocessed dataset, truncated to two elements.
    """
    client_dataset = source.create_tf_dataset_for_client(client)
    return preprocess(client_dataset).take(2)

# The fixed set of clients used in this notebook.
clients = ['ALL_S_WELL_THAT_ENDS_WELL_CELIA',
           'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
           'THE_TRAGEDY_OF_KING_LEAR_KING']

# One preprocessed training dataset per client.
train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras.
# NOTE(review): no later cell visible in this notebook reads `test_dataset`.
test_dataset = functools.reduce(
    lambda d1, d2: d1.concatenate(d2),
    [data(client, test_data) for client in clients])

In [0]:
NUM_ROUNDS = 3  # Number of iterations of the round loop below.

# Start from a freshly initialized server state ...
state = fed_avg.initialize()

# ... and replace its model weights with the weights currently held by
# `keras_model`, split into trainable and non-trainable lists.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
    """Load the server model's weights into `keras_model` and evaluate it.

    Args:
      state: Federated averaging server state; its `model` attribute
        supplies the weights.
      round_num: Round number, used only in the printed progress message.
    """
    server_weights = tff.learning.keras_weights_from_tff_weights(state.model)
    keras_model.set_weights(server_weights)
    print('Evaluating before training round', round_num)
    # Two minibatches of the example client's data are evaluated.
    keras_model.evaluate(example_dataset, steps=2)


# Alternate evaluation and training: before each round the current server
# model is evaluated with Keras, then one round of federated averaging runs
# over the selected clients' datasets and its training metrics are printed.
for round_num in range(NUM_ROUNDS):
    keras_evaluate(state, round_num)
    # Without this step the server state never changes between evaluations.
    state, metrics = fed_avg.next(state, train_datasets)
    print('Training metrics: ', metrics)

# Final evaluation of the model produced by the last round.
# NOTE(review): rounds are numbered 0..NUM_ROUNDS-1, so this label skips
# NUM_ROUNDS; kept as-is to match the recorded output - confirm intent.
keras_evaluate(state, NUM_ROUNDS + 1)

Evaluating before training round 0
Training metrics:  <accuracy=0.427083,loss=3.05846>
Evaluating before training round 1
Training metrics:  <accuracy=0.531458,loss=2.15261>
Evaluating before training round 2
Training metrics:  <accuracy=0.633125,loss=1.56697>
Evaluating before training round 4


In [0]:
# Copy the weights currently held by the batch-8 `keras_model` into the
# batch-1 model, which is the variant `generate_text` needs, then sample text.
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))

What of TensorFlow Federated, you ask? Says
in my cell in the husband and fiffeinne, the hands of the clock, struck him to the chocolate with lists with the rascal public himself.
What forehead was the golden head that lived it for music;
