# Non-Federated EMNIST Baseline Training

This colab has three main parts:

*   It trains a non-federated model on a flattened and shuffled (that is,
    non-federated) view of the
    [Federated EMNIST](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist/load_data)
    dataset. The model architecture matches the simple CNN from the paper
    [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629).
    This (currently untuned) training reaches an accuracy of about 97% with
    vanilla SGD. (Note that these accuracy numbers are not directly comparable
    to MNIST results, as the train and test datasets are different.) This is
    intended to serve as a baseline for simulated federated training on the Fed
    EMNIST dataset.

*   It uses this model to examine the Fed EMNIST dataset, showing it has
    interesting variation across users.

*   As a sanity check, it shows an equivalent model can be trained using the
    `Federated Averaging` implementation from `tff.learning` applied to a
    non-federated (that is, flattened and shuffled) view of the data.


**Note:** This notebook will probably take ~25 minutes to fully execute.

In [0]:
!pip install tensorflow_federated
!pip install tensorflow_gan

In [0]:
from __future__ import absolute_import, division, print_function

import collections
import functools
import numpy as np
import time

import matplotlib.pyplot as plt
import pandas as pd

import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_gan as tfgan

tf.compat.v1.enable_v2_behavior()

# Training a baseline model with Keras

## Data

Download the data, and lightly reformat for use in Keras.

In [0]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [0]:
Example = collections.namedtuple('Example', ['x', 'y'])

BATCH_SIZE = 10
SHUFFLE_BUFFER = 10000


def element_fn(element):
  return Example(
      x=tf.reshape(element['pixels'], [-1]),
      y=tf.reshape(element['label'], [1]))


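def preprocess_train(dataset, batch_size=BATCH_SIZE):
  # Shuffle within a buffer of SHUFFLE_BUFFER examples, repeat indefinitely,
  # then batch. Equivalent to the deprecated
  # tf.data.experimental.shuffle_and_repeat(SHUFFLE_BUFFER, count=-1).
  return (dataset.map(element_fn)
          .shuffle(buffer_size=SHUFFLE_BUFFER)
          .repeat()
          .batch(batch_size))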


def preprocess_test(dataset):
  return dataset.map(element_fn).batch(100, drop_remainder=False)


flat_train_data = preprocess_train(
    emnist_train.create_tf_dataset_from_all_clients(seed=739613565))
flat_test_data = preprocess_test(
    emnist_test.create_tf_dataset_from_all_clients(seed=686991103))
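`preprocess_train` shuffles with a buffer of `SHUFFLE_BUFFER` examples. As described in the `tf.data` documentation, such a shuffle fills a buffer, emits a randomly chosen buffered element, and refills that slot with the next input. The pure-Python sketch below mimics that scheme to show its main limitation (it is an illustration, not the `tf.data` implementation, and `buffered_shuffle` is a hypothetical name): output position `k` can only come from the first `buffer_size + k` inputs, so any ordering in the input stream is only partially removed when the buffer is small relative to the data.

```python
import random


def buffered_shuffle(stream, buffer_size, seed=0):
  """Yields the elements of `stream` in a locally shuffled order.

  Sketch of a buffered shuffle: fill a buffer, emit a random buffered
  element, and refill its slot with the next element from the stream.
  """
  rng = random.Random(seed)
  it = iter(stream)
  buffer = []
  for x in it:  # Fill the buffer first.
    buffer.append(x)
    if len(buffer) == buffer_size:
      break
  for x in it:  # Steady state: emit one element, take in one element.
    i = rng.randrange(len(buffer))
    yield buffer[i]
    buffer[i] = x
  rng.shuffle(buffer)  # Stream exhausted: drain the rest in random order.
  yield from buffer


# A sorted stream of 100,000 items with a 10,000-element buffer: the first
# 1,000 outputs can only come from the first 11,000 inputs.
out = list(buffered_shuffle(range(100_000), buffer_size=10_000))
print(max(out[:1_000]) < 11_000)  # True
```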

## Model

In [0]:
def build_cnn():
  """The CNN model used in https://arxiv.org/abs/1602.05629.

  The number of parameters (1,663,370) matches what is reported in the paper.
  """
  data_format = 'channels_last'
  input_shape = [28, 28, 1]

  # Alternatively:
  # data_format = 'channels_first'
  # input_shape = [1, 28, 28]

  max_pool = lambda: tf.keras.layers.MaxPooling2D(
      pool_size=(2, 2), padding='same', data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax),
  ])

  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      # This learning rate has not been tuned.
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return model


build_cnn().summary()
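The parameter count in the docstring can be checked by hand. A `Conv2D` layer with a k x k kernel has k * k * c_in * c_out weights plus c_out biases, and a `Dense` layer has n_in * n_out weights plus n_out biases. With `padding='same'`, each convolution keeps the spatial size and each 2x2 max pool halves it (28 -> 14 -> 7), so `Flatten` yields 7 * 7 * 64 = 3,136 features. The helper names below are illustrative only.

```python
def conv2d_params(kernel_size, in_channels, filters):
  # Kernel weights plus one bias per filter.
  return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_features, units):
  # Weight matrix plus one bias per unit.
  return in_features * units + units


side = 28
for _ in range(2):        # Two 2x2 max pools with padding='same'.
  side = -(-side // 2)    # Ceiling division: 28 -> 14 -> 7.

layer_params = [
    conv2d_params(5, 1, 32),              # 832
    conv2d_params(5, 32, 64),             # 51,264
    dense_params(side * side * 64, 512),  # 1,606,144
    dense_params(512, 10),                # 5,130
]
print(sum(layer_params))  # 1663370
```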

## Training and evaluation

In [0]:
NUM_ROUNDS = 10
BATCHES_PER_ROUND = 1000

In [0]:
model = build_cnn()
# We set steps_per_epoch and epochs just to break training up in reasonable "chunks".
# These aren't really epochs over the full dataset.
# Training could take about 8 minutes.
model.fit(flat_train_data, steps_per_epoch=BATCHES_PER_ROUND, epochs=NUM_ROUNDS)

There are 40,832 test examples, so we take 409 eval steps with a test batch size
of 100:

In [0]:
_ = model.evaluate(flat_test_data)
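The step count is a ceiling division: `preprocess_test` batches with `drop_remainder=False`, so the final partial batch is kept rather than dropped. A quick check of that arithmetic:

```python
import math

NUM_TEST_EXAMPLES = 40_832  # Test-set size stated above.
TEST_BATCH_SIZE = 100       # Batch size used in preprocess_test.

steps = math.ceil(NUM_TEST_EXAMPLES / TEST_BATCH_SIZE)
last_batch_size = NUM_TEST_EXAMPLES - (steps - 1) * TEST_BATCH_SIZE
print(steps, last_batch_size)  # 409 32
```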

# Aside: Do some users have hard-to-classify data?

Since this is a public dataset intended for research, we can use the trained
model to check whether some users have hard-to-classify data. This is a way of
verifying that the Federated EMNIST dataset has interesting variation across
users.

In [0]:
def display_raw_emnist(data, grid_width=25):
  """A helper function to display images from Fed EMNIST datasets."""
  # List of numpy images
  img_data = np.array([x['pixels'].numpy() for x in data])
  img_data = np.reshape(img_data, (-1, 28, 28, 1))
  num_rows = int(np.ceil(len(img_data) / grid_width))
  
  # Pad with blank images to fill out the last row, since
  # tfgan.eval.python_image_grid expects a full rectangular grid.
  needed_images = num_rows * grid_width
  tmp = np.zeros((needed_images, 28, 28, 1))
  tmp[:len(img_data)] = img_data
  img_data = tmp

  img_grid = tfgan.eval.python_image_grid(
      img_data, grid_shape=(num_rows, grid_width))

  # figsize is (width, height); match the grid's columns:rows aspect ratio.
  w = 20
  h = w * (num_rows / grid_width)
  plt.figure(figsize=(w, h))
  plt.axis('off')
  plt.imshow(np.squeeze(img_grid), cmap='binary')
  plt.show()
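`display_raw_emnist` pads the images out to a rectangular grid and then tiles them with `tfgan.eval.python_image_grid`. The NumPy sketch below shows the same pad-and-tile idea using plain reshapes; `image_grid` is a hypothetical helper, not `tfgan`'s implementation, and it assumes images are laid out row by row.

```python
import numpy as np


def image_grid(images, grid_width):
  """Tiles images of shape (n, h, w, c) into a (rows*h, grid_width*w, c) array.

  Pads with blank (zero) images so the grid is rectangular.
  """
  n, h, w, c = images.shape
  num_rows = -(-n // grid_width)  # Ceiling division.
  padded = np.zeros((num_rows * grid_width, h, w, c), dtype=images.dtype)
  padded[:n] = images
  grid = padded.reshape(num_rows, grid_width, h, w, c)
  # Interleave so each grid row's images sit side by side: (rows, h, cols, w, c).
  grid = grid.transpose(0, 2, 1, 3, 4)
  return grid.reshape(num_rows * h, grid_width * w, c)


images = np.arange(103 * 28 * 28, dtype=np.float32).reshape(103, 28, 28, 1)
grid = image_grid(images, grid_width=25)
print(grid.shape)  # (140, 700, 1)
```

With 103 images and 25 columns, the grid needs 5 rows, and the last 22 cells are blank padding.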

Display the data from clients with accuracy below a threshold.

In [0]:
accuracy_by_client_id = {}
THRESHOLD = 0.82

for i, client_id in enumerate(emnist_train.client_ids):
  raw_data = emnist_train.create_tf_dataset_for_client(client_id)
  num_examples = sum([1 for _ in raw_data])
  loss, accuracy = model.evaluate(preprocess_test(raw_data), verbose=0)
  accuracy_by_client_id[client_id] = accuracy
  if accuracy < THRESHOLD:
    print('client {} ({}) with {:3d} examples has accuracy {:6.2f}%'.format(
        i, client_id, num_examples, 100 * accuracy))
    display_raw_emnist(raw_data)

## Accuracy vs client_id
Now, let's plot accuracy versus the (sorted) `client_id`s. We are interested in the general trend, so we use a moving average over clients.

In [0]:
client_ids = sorted(list(emnist_train.client_ids))
y = [accuracy_by_client_id[client_id] for client_id in client_ids]
y = pd.Series(y).rolling(window=50, center=True).mean()
plt.figure(figsize=(15, 3))
plt.plot(range(len(y)), y)
plt.title('Rolling mean accuracy vs client_id')
plt.xlabel('client_id')
plt.ylabel('Accuracy')
plt.ylim(0.9, 1.0)
s1 = client_ids.index('f2100_97')
s2 = client_ids.index('f3100_44')
x_loc = [500, 1000, 1500, s1, s2, 3000]
plt.xticks(x_loc, [str(client_ids[x]) for x in x_loc])
plt.vlines([s1, s2], 0.9, 1.0)
plt.show()
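For reference, a rolling mean with `center=True` averages each run of `window` consecutive clients and labels the result at the middle of the run. Positions near either end, where no full window is available, come out as NaN, which is pandas' behavior when `min_periods` is left unset. A minimal pure-Python sketch, with the hypothetical name `centered_moving_average` and an odd window so the middle is unambiguous:

```python
import math


def centered_moving_average(values, window):
  """Averages each run of `window` values, labeled at the run's middle.

  Positions without a full window are NaN. Assumes an odd `window`.
  """
  half = window // 2
  out = [math.nan] * len(values)
  for start in range(len(values) - window + 1):
    out[start + half] = sum(values[start:start + window]) / window
  return out


print(centered_moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=3))
# [nan, 2.0, 3.0, 4.0, nan]
```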

There appears to be a correlation between the (sorted) `client_id`s and accuracy. This is likely because the `client_id`s indicate where the handwriting samples were collected; see Table 2 in the [User's Guide](https://s3.amazonaws.com/nist-srd/SD19/1stEditionUserGuide.pdf) for NIST Special Database 19. Writers `f0000` - `f2099` were Census Bureau field personnel, `f2100` - `f3099` were high school students, and `f3100` - `f4099` were Census Bureau employees in Maryland.
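To compare these three populations directly, one could bucket clients by the numeric prefix of their `client_id`. The sketch below assumes IDs shaped like the ones used in the plot above (e.g. `'f2100_97'`: an `f`, a four-digit writer number, then a suffix) and uses the writer ranges quoted above; `writer_group` and `mean_accuracy_by_group` are hypothetical helpers.

```python
import collections
import statistics

# Writer-number ranges quoted above from the NIST SD19 User's Guide.
WRITER_GROUPS = [
    (0, 2099, 'Census Bureau field personnel'),
    (2100, 3099, 'high school students'),
    (3100, 4099, 'Census Bureau employees in Maryland'),
]


def writer_group(client_id):
  """Maps an ID such as 'f2100_97' to its writer population."""
  writer_number = int(client_id[1:5])  # Assumes 'f' + 4 digits + suffix.
  for low, high, name in WRITER_GROUPS:
    if low <= writer_number <= high:
      return name
  return 'unknown'


def mean_accuracy_by_group(accuracy_by_client_id):
  """Averages per-client accuracy within each writer population."""
  by_group = collections.defaultdict(list)
  for client_id, accuracy in accuracy_by_client_id.items():
    by_group[writer_group(client_id)].append(accuracy)
  return {group: statistics.mean(accs) for group, accs in by_group.items()}


print(writer_group('f2100_97'))  # high school students
print(writer_group('f3100_44'))  # Census Bureau employees in Maryland
```

Applied to the notebook's `accuracy_by_client_id`, `mean_accuracy_by_group` would give one unweighted average per population.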

This finding implies that for centralized baseline training (as we did above), sufficient shuffling of the data is important (see also TODO(b/135021147) to improve this).

For federated training, randomly sampling users is important. Alternatively, this data could be used to simulate three different "blocks" of users to test the behavior of [Semi-Cyclic Stochastic Gradient Descent](https://arxiv.org/abs/1904.10120) and the mitigations suggested in that paper; non-federated experiments can be found [here](https://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/research/semi_cyclic_sgd).
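The two sampling regimes can be contrasted with a small scheduler. Uniform sampling draws each round's clients from the whole population, while a block-cyclic schedule draws only from one block of users at a time and cycles through the blocks, roughly the setting the Semi-Cyclic SGD paper studies. The function names, block contents, and round counts here are illustrative assumptions, not TFF APIs.

```python
import random


def uniform_rounds(client_ids, clients_per_round, num_rounds, seed=0):
  """Each round samples clients uniformly from the whole population."""
  rng = random.Random(seed)
  return [rng.sample(client_ids, clients_per_round) for _ in range(num_rounds)]


def block_cyclic_rounds(blocks, clients_per_round, rounds_per_block,
                        num_rounds, seed=0):
  """Each round samples only from the current block; blocks are visited in turn."""
  rng = random.Random(seed)
  schedule = []
  for r in range(num_rounds):
    block = blocks[(r // rounds_per_block) % len(blocks)]
    schedule.append(rng.sample(block, clients_per_round))
  return schedule


# Three hypothetical blocks of client IDs.
blocks = [[f'a{i}' for i in range(100)],
          [f'b{i}' for i in range(100)],
          [f'c{i}' for i in range(100)]]
all_ids = [cid for block in blocks for cid in block]

uniform = uniform_rounds(all_ids, clients_per_round=5, num_rounds=6)
cyclic = block_cyclic_rounds(blocks, clients_per_round=5,
                             rounds_per_block=2, num_rounds=6)
print([ids[0][0] for ids in cyclic])  # ['a', 'a', 'b', 'b', 'c', 'c']
```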

# Replicating the baseline with `tff.learning`

Here we show how to use `tff.learning` to replicate the non-federated
baseline. Critically, however, *the training is still essentially
non-federated*: this is a sanity check, not a proper simulation of
federated learning.

The approach is based on the fact that if each "client" has IID shuffled data
from a centralized training set, and we use the `FederatedAveraging` algorithm
with one client per round (so there is no actual averaging), then this is
algorithmically equivalent to running SGD centrally.
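This equivalence can be checked on a toy problem. In the sketch below, each round's single client runs local SGD starting from the current server model, and the server applies the example-weighted mean of the client deltas, which for one client is just that client's delta. Assuming a server learning rate of 1.0 (an assumption of this sketch), the new server model equals the client's model, so the result matches plain SGD over the same sequence of batches however they are split into rounds. The 1-D least-squares model and all names are hypothetical stand-ins for the CNN.

```python
import math
import random

# Hypothetical toy data: y = 3x + noise, in a fixed order of 200 batches of 10.
rng = random.Random(0)
data = [(x, 3.0 * x + rng.gauss(0.0, 0.1))
        for x in (rng.uniform(-1.0, 1.0) for _ in range(2000))]
batches = [data[i:i + 10] for i in range(0, len(data), 10)]


def sgd(w, batches, lr=0.1):
  """Plain minibatch SGD on the squared error of the model y_hat = w * x."""
  for batch in batches:
    grad = sum((w * x - y) * x for x, y in batch) / len(batch)
    w -= lr * grad
  return w


def fedavg_one_client_per_round(w, batches, batches_per_round, server_lr=1.0):
  """FedAvg where each round's single client owns the next chunk of batches."""
  for start in range(0, len(batches), batches_per_round):
    client_batches = batches[start:start + batches_per_round]
    client_w = sgd(w, client_batches)        # Local training on the client.
    deltas = [client_w - w]                  # One client per round.
    weights = [sum(len(b) for b in client_batches)]
    mean_delta = sum(d * n for d, n in zip(deltas, weights)) / sum(weights)
    w += server_lr * mean_delta              # Server update.
  return w


central = sgd(0.0, batches)
federated = fedavg_one_client_per_round(0.0, batches, batches_per_round=20)
print(math.isclose(central, federated))  # True
```

With a server learning rate other than 1.0, or more than one client per round, the two procedures no longer coincide.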



## Construct the `federated_averaging_process`

In [0]:
dummy_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                    next(iter(flat_train_data.take(1))))


def create_tff_model():
  keras_model = build_cnn()
  return tff.learning.from_compiled_keras_model(
      keras_model, dummy_batch=dummy_batch)


fed_avg_process = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model)

## Helper for selecting datasets for each "round"
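To work around a `tf.data` issue (b/134945216), we build the sequence of
Datasets, each containing `BATCHES_PER_ROUND` batches of size `BATCH_SIZE` from
the flat shuffled training data, indirectly: we batch the data into "big"
batches of `BATCH_SIZE * BATCHES_PER_ROUND` examples, then split each big batch
back into regular batches.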

TODO(b/134945216): Once supported, use `tf.data.Dataset.window()` instead.

In [0]:
# Dataset of "big" batches to work around window issue (b/134945216)
tff_train_data = preprocess_train(
    emnist_train.create_tf_dataset_from_all_clients(),
    batch_size=BATCH_SIZE * BATCHES_PER_ROUND)

tff_train_data_iter = iter(tff_train_data)


def next_client_dataset():
  # Grab the next "big" batch, create a dataset, and split into regular batches.
  client_data = tf.data.Dataset.from_tensor_slices(next(tff_train_data_iter))
  return client_data.batch(BATCH_SIZE)
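`next_client_dataset` turns one big batch of `BATCH_SIZE * BATCHES_PER_ROUND` examples back into `BATCHES_PER_ROUND` regular batches. The pure-Python sketch below mirrors that split on a plain list; `split_into_batches` is a hypothetical name, and on tensors `tf.data.Dataset.from_tensor_slices(...).batch(...)` plays this role.

```python
BATCH_SIZE = 10
BATCHES_PER_ROUND = 1000


def split_into_batches(big_batch, batch_size):
  """Splits one big batch into consecutive batches of `batch_size` examples."""
  return [big_batch[i:i + batch_size]
          for i in range(0, len(big_batch), batch_size)]


big_batch = list(range(BATCH_SIZE * BATCHES_PER_ROUND))  # Stand-in examples.
client_batches = split_into_batches(big_batch, BATCH_SIZE)
print(len(client_batches), len(client_batches[0]))  # 1000 10
```

Because the split preserves order, consecutive rounds consume consecutive stretches of the shuffled stream, just as the Keras baseline does.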

## Training and evaluation

Now we are ready to do some training. We run `NUM_ROUNDS` (10) rounds of
`BATCHES_PER_ROUND` (1000) batches each; since each round uses a single client,
the split between rounds doesn't really matter.

In [0]:
state = fed_avg_process.initialize()

print('Running Federated Averaging')
start_time = time.time()
for i in range(NUM_ROUNDS):
  # Run one round of FederatedAveraging, on a single client.
  round_start_time = time.time()
  state, metrics = fed_avg_process.next(state, [next_client_dataset()])
  finish_time = time.time()
  print('Round {:3d} took {:6.2f} seconds (total {:4.0f} seconds). '
        'Training metrics: {}'.format(i, finish_time - round_start_time,
                                      finish_time - start_time, metrics))

In [0]:
print('Final model evaluation on test data')
keras_model = build_cnn()
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
_ = keras_model.evaluate(flat_test_data)