A small, self-contained training example built directly on top of the Core API.

This example contains a simpler form of federated averaging logic, similar to what one might find in `../../learning`, but optimized for simplicity and compactness, and it illustrates the use of the basic mechanisms provided by the Core API. Readers are encouraged to study the simplified structure of this example first, as a stepping stone towards the more general implementation in `learning`.

In [0]:
import collections

import tensorflow as tf

from tensorflow_federated import python as tff

The model we will train is a simple linear classifier that, given a discrete input class named `X`, predicts a discrete output class named `Y`. The input class is an integer from 0 to `NUM_X_CLASSES` - 1, and the output class is an integer from 0 to `NUM_Y_CLASSES` - 1.


In [0]:
NUM_X_CLASSES = 7
NUM_Y_CLASSES = 3

The samples of data used to train or evaluate the model consist of the values of the 2 features `X` and `Y`. The data arrive in batches (e.g., as is typically the case for the output of `tf.parse_example`). Every batch of samples is represented as a TFF named tuple with 2 elements, `X` and `Y`, each of which is a `tf.int32` tensor with a single dimension (so conceptually, a vector). The two tensors should have the same number of elements (this constraint is not expressed in the type below). The number of elements is unspecified (the shape of both tensors is `[None]`), since individual batches of data might contain unequal numbers of samples that are not known in advance.
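Because the type cannot express that `X` and `Y` have equal lengths, or that their values fall in the valid class ranges, those constraints must be upheld by whoever produces the data. The sketch below checks them in plain Python. It assumes a batch is a mapping with keys `'X'` and `'Y'` holding lists of ints, as a stand-in for the two `tf.int32` vectors; `check_batch` is a hypothetical helper, not part of TFF.

```python
import collections

# Same constants as defined earlier in this example.
NUM_X_CLASSES = 7
NUM_Y_CLASSES = 3


def check_batch(batch):
  """Checks the constraints that BATCH_TYPE alone does not express.

  `batch` is assumed to map 'X' and 'Y' to lists of Python ints (a plain
  stand-in for the two tf.int32 vectors). Returns the batch size.
  """
  xs, ys = batch['X'], batch['Y']
  if len(xs) != len(ys):
    raise ValueError('X and Y must have the same number of elements.')
  if not all(0 <= x < NUM_X_CLASSES for x in xs):
    raise ValueError('Every X must be in [0, NUM_X_CLASSES).')
  if not all(0 <= y < NUM_Y_CLASSES for y in ys):
    raise ValueError('Every Y must be in [0, NUM_Y_CLASSES).')
  # The batch size, which the type deliberately leaves as None.
  return len(xs)


batch = collections.OrderedDict([('X', [1, 6, 3]), ('Y', [0, 2, 1])])
print(check_batch(batch))  # 3
```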


In [21]:
BATCH_TYPE = tff.NamedTupleType([
    ('X', tff.TensorType(tf.int32, shape=[None])),
    ('Y', tff.TensorType(tf.int32, shape=[None])),
])
BATCH_TYPE

NamedTupleType([('X', TensorType(tf.int32, [None])), ('Y', TensorType(tf.int32, [None]))])

The input to training/evaluation is simply a sequence of such batches.
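Conceptually, a value of this sequence type can be pictured as a Python iterable of batches. Inside TensorFlow computations, TFF sequences are typically represented as `tf.data.Dataset` objects. The sketch below is only a plain-Python illustration: `make_batches` is a hypothetical helper, and it shows why the batch dimension is left as `None`, since the last batch can be shorter than the others.

```python
import collections


def make_batches(xs, ys, batch_size):
  """Yields OrderedDict batches with keys 'X' and 'Y'.

  Hypothetical helper: the final batch may be shorter than `batch_size`.
  """
  assert len(xs) == len(ys)
  for start in range(0, len(xs), batch_size):
    yield collections.OrderedDict([
        ('X', xs[start:start + batch_size]),
        ('Y', ys[start:start + batch_size]),
    ])


batches = list(make_batches([0, 1, 2, 3, 4], [0, 1, 2, 0, 1], batch_size=2))
print([len(b['X']) for b in batches])  # [2, 2, 1]
```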

In [22]:
INPUT_TYPE = tff.SequenceType(BATCH_TYPE)
INPUT_TYPE

SequenceType(NamedTupleType([('X', TensorType(tf.int32, [None])), ('Y', TensorType(tf.int32, [None]))]))

The parameters of the model consist of a weight matrix and a bias vector, to be applied to a one-hot encoding of the inputs.
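Multiplying a one-hot vector by the weight matrix simply selects one row of that matrix, so the logits for input class `x` are row `x` of `weights` plus `bias`. The plain-Python sketch below makes this concrete. The parameter values are made up for illustration, and nested lists stand in for tensors.

```python
NUM_X_CLASSES = 7
NUM_Y_CLASSES = 3


def one_hot(index, depth):
  """Returns a list of length `depth` with 1.0 at `index` and 0.0 elsewhere."""
  return [1.0 if i == index else 0.0 for i in range(depth)]


def logits(x, weights, bias):
  """Computes one_hot(x) @ weights + bias for a single input class x."""
  encoded = one_hot(x, NUM_X_CLASSES)
  return [
      sum(encoded[i] * weights[i][j] for i in range(NUM_X_CLASSES)) + bias[j]
      for j in range(NUM_Y_CLASSES)
  ]


# Made-up parameters: weights[i][j] = 10 * i + j, and a constant bias.
weights = [[10.0 * i + j for j in range(NUM_Y_CLASSES)]
           for i in range(NUM_X_CLASSES)]
bias = [0.5] * NUM_Y_CLASSES

# The result equals row 4 of `weights` plus `bias`.
print(logits(4, weights, bias))  # [40.5, 41.5, 42.5]
```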


In [23]:
MODEL_TYPE = tff.NamedTupleType([
    ('weights', tff.TensorType(tf.float32, [NUM_X_CLASSES, NUM_Y_CLASSES])),
    ('bias', tff.TensorType(tf.float32, [NUM_Y_CLASSES])),
])
MODEL_TYPE

NamedTupleType([('weights', TensorType(tf.float32, [7, 3])), ('bias', TensorType(tf.float32, [3]))])

Next is a simple TensorFlow computation that, given the model parameters, computes the loss and accuracy metrics, along with the number of samples, on a batch of features.

The function decorator `tff.tf_computation` transforms the Python function into a `tff.Computation`, a unit of composition in TFF. When a Python function is wrapped as a computation, one can think of it as conceptually consuming and returning TFF values that have TFF types. The TFF type of the parameter of the `forward_pass` computation is declared as the argument to the decorator: a pair consisting of `BATCH_TYPE` for the features and `MODEL_TYPE` for the model parameters (see above for their definitions). The TFF type of the value returned by `forward_pass` is determined automatically. In this case, it is a TFF named tuple with 3 named elements (`loss`, `accuracy`, and `count`) constructed from the elements of the returned Python `OrderedDict`.


In [0]:
@tff.tf_computation([BATCH_TYPE, MODEL_TYPE])
def forward_pass(features, model):
  # One-hot encode the input and label classes.
  encoded_x = tf.one_hot(features.X, NUM_X_CLASSES)
  encoded_y = tf.one_hot(features.Y, NUM_Y_CLASSES)
  # Linear model followed by softmax; shape [batch_size, NUM_Y_CLASSES].
  softmax_y = tf.nn.softmax(tf.matmul(encoded_x, model.weights) + model.bias)
  # Mean cross-entropy loss over the samples in the batch.
  loss = tf.reduce_mean(
      -tf.reduce_sum(encoded_y * tf.log(softmax_y), axis=1))
  # Accuracy: fraction of samples whose most probable class matches Y.
  prediction = tf.cast(tf.argmax(softmax_y, axis=1), tf.int32)
  is_correct = tf.equal(prediction, features.Y)
  accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
  # Number of samples in this batch.
  count = tf.size(features.X)
  return collections.OrderedDict([
      ('loss', loss),
      ('accuracy', accuracy),
      ('count', count),
  ])
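To sanity-check the math in `forward_pass` without TensorFlow, here is a plain-Python restatement of the same per-batch computation. It is a sketch under the assumption that nested lists stand in for the tensors; `forward_pass_reference` is a hypothetical name, not part of TFF. One deliberate difference: the softmax is shifted by the maximum logit, a standard trick that avoids overflow, whereas the original takes `tf.log` of the softmax output directly, which can yield `-inf` if a probability underflows to zero.

```python
import collections
import math


def forward_pass_reference(features, weights, bias):
  """Plain-Python mirror of the loss/accuracy/count computed by forward_pass.

  `features` maps 'X' and 'Y' to equal-length lists of ints; `weights` is a
  NUM_X_CLASSES x NUM_Y_CLASSES nested list and `bias` a list of floats.
  """
  total_loss = 0.0
  num_correct = 0
  for x, y in zip(features['X'], features['Y']):
    # one_hot(x) @ weights selects row x; adding the bias gives the logits.
    row = [w + b for w, b in zip(weights[x], bias)]
    # Softmax, shifted by the max logit for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Cross-entropy against a one-hot label reduces to -log p[y].
    total_loss += -math.log(probs[y])
    # This sketch breaks ties toward the smallest class index.
    prediction = probs.index(max(probs))
    num_correct += int(prediction == y)
  count = len(features['X'])
  return collections.OrderedDict([
      ('loss', total_loss / count),
      ('accuracy', num_correct / count),
      ('count', count),
  ])


# With an all-zero model every class gets probability 1/3, so loss = ln 3.
zero_weights = [[0.0] * 3 for _ in range(7)]
zero_bias = [0.0] * 3
stats = forward_pass_reference(
    {'X': [0, 1, 2], 'Y': [0, 1, 2]}, zero_weights, zero_bias)
print(stats['count'], round(stats['loss'], 4), round(stats['accuracy'], 4))
# 3 1.0986 0.3333
```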

Next, we capture the type signature of the named tuple of statistics (metrics and counters) computed by the model.

In [25]:
STATS_TYPE = forward_pass.type_signature.result
STATS_TYPE

NamedTupleType([('loss', TensorType(tf.float32)), ('accuracy', TensorType(tf.float32)), ('count', TensorType(tf.int32))])
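The `count` element makes it possible to combine statistics from batches of different sizes without letting small batches dominate. The sketch below assumes a count-weighted mean as the combining rule; `combine_stats` is a hypothetical helper operating on plain mappings shaped like `STATS_TYPE`, not a TFF API, and the rest of this example may aggregate differently.

```python
import collections


def combine_stats(stats_list):
  """Combines per-batch stats into count-weighted averages.

  Each element is assumed to be a mapping with 'loss', 'accuracy' and 'count'
  keys, shaped like STATS_TYPE. Hypothetical helper, not a TFF API.
  """
  total_count = sum(s['count'] for s in stats_list)
  return collections.OrderedDict([
      ('loss', sum(s['loss'] * s['count'] for s in stats_list) / total_count),
      ('accuracy',
       sum(s['accuracy'] * s['count'] for s in stats_list) / total_count),
      ('count', total_count),
  ])


combined = combine_stats([
    {'loss': 1.0, 'accuracy': 0.5, 'count': 2},
    {'loss': 4.0, 'accuracy': 1.0, 'count': 6},
])
print(combined['count'], combined['loss'], combined['accuracy'])  # 8 3.25 0.875
```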

In [0]:
# Since execution is not yet supported, using `tf_computation` temporarily
# to allow `forward_pass` to get stamped, and manually driving Session.run
# to execute it.
# TODO(b/113116813): Use the interfaces for executing computations as soon
# as they're ready rather than manually driving TensorFlow graphs in tests,
# which should help to shorten and simplify this code.
@tff.tf_computation
def _():
  # When creating the batch, we have to specify a concrete shape, since by
  # default 'BATCH_TYPE' leaves batch size undefined.
  batch = tff.utils.get_variables(
      name='batch',
      type_spec=[('X', (tf.int32, [5])), ('Y', (tf.int32, [5]))],
      initializer=tf.zeros_initializer())

  model = tff.utils.get_variables(
      name='model',
      type_spec=MODEL_TYPE,
      initializer=tf.zeros_initializer())

  loss = forward_pass(batch, model).loss

  # TODO(b/113116813): Replace this temporary workaround with a proper call
  # that gets plumbed through the execution API, once it materializes.
  # For now, just testing here that the graph has been stitched correctly,
  # and that something gets computed at all.
  with tf.Session(graph=tf.get_default_graph()) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(loss)

  return loss
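As a rough check on what the cell above should compute: both the batch (5 samples with `X = Y = 0`) and the model are initialized with `tf.zeros_initializer()`, so every logit is zero, the softmax is uniform over the `NUM_Y_CLASSES = 3` classes, and the mean cross-entropy loss works out as follows (a worked derivation assuming those zero initial values).

$$
z = \mathrm{onehot}(x)\,W + b = \mathbf{0}
\quad\Rightarrow\quad
\mathrm{softmax}(z)_j = \frac{e^{0}}{\sum_{k=1}^{3} e^{0}} = \frac{1}{3},
\qquad
\mathcal{L} = -\frac{1}{5}\sum_{n=1}^{5} \log\frac{1}{3} = \log 3 \approx 1.0986
$$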