# FedJAX Datasets

This tutorial introduces datasets in FedJAX and shows how to work with them. By the end, you will know how to write clear, efficient dataset code that follows best practices.

**NOTE: For datasets, everything is done with NumPy, NOT JAX.**

In [None]:
import functools
import itertools
import fedjax
import numpy as np

## What are datasets in federated learning?

In the context of federated learning (FL), data is decentralized across clients, with each client having their own local set of examples. In light of this, we refer to two levels of organization for datasets:

- Federated dataset: A collection of clients, each with their own local datasets and metadata
- Client dataset: The set of local examples for a particular client

You can think of federated data as a mapping from client ids to client datasets and client datasets as a list of examples.

```
federated_data = {
  'client0': ['a', 'b', 'c'],
  'client1': ['d', 'e'],
}
```
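The mapping view can be explored with plain Python. This is only a toy dict for intuition, not the `fedjax.FederatedData` interface, and the names `num_clients` and `client_sizes` are illustrative.

```python
# Toy stand-in for federated data: client id -> list of local examples.
federated_data = {
    'client0': ['a', 'b', 'c'],
    'client1': ['d', 'e'],
}

# Federated-level metadata: how many clients there are.
num_clients = len(federated_data)
# Client-level metadata: how many examples each client holds.
client_sizes = {cid: len(examples) for cid, examples in federated_data.items()}

print('num_clients =', num_clients)
print('client_sizes =', client_sizes)
```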

### Federated datasets structure

FedJAX defines a `fedjax.FederatedData` interface for all federated datasets. 

FedJAX comes packaged with multiple federated datasets, but we will look specifically at the Shakespeare dataset. The Shakespeare dataset is based on [The Complete Works of Shakespeare](https://www.gutenberg.org/files/100/100-0.txt), where each character in a play is a "client" and their dialogue lines are the examples.

In [None]:
train_fd, test_fd = fedjax.datasets.shakespeare.load_data()

`train_fd` and `test_fd` are the train and test federated datasets, respectively.

We can look at some of the metadata about the federated dataset, like the total number of clients, client ids, and number of examples for each client.

In [None]:
print('num_clients =', train_fd.num_clients())

# train_fd.client_ids() is a generator of client ids.
# itertools has efficient and convenient functions for working with generators.
for client_id in itertools.islice(train_fd.client_ids(), 3):
  print('client_id =', client_id)
  print('# examples =', train_fd.client_size(client_id))

num_clients = 715
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 5
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 13
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 15


As seen in the output, there are 715 total clients in the Shakespeare dataset.
Each client has a unique client ID that can be used to query metadata about that client such as the number of examples that client has.

We can also query the dataset for a client using their client ID and `fedjax.FederatedData.get_client()`.

In [None]:
client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
client_dataset = train_fd.get_client(client_id)
print(client_dataset)

<fedjax.core.client_datasets.ClientDataset object at 0x7f357713fc88>


The output of `fedjax.FederatedData.get_client()` is a `fedjax.ClientDataset` object that stores all the examples for a given client as well as built-in methods for batching, shuffling, and iterating over the data. Later, we will go more deeply into `fedjax.ClientDataset` and its structure and built-in methods, but for now, just consider `fedjax.ClientDataset` as all the examples for a given client.

### Client datasets structure

`fedjax.ClientDataset` is the interface for client datasets. We make the assumption that individual client datasets are small and can easily fit in memory. This assumption is also reflected in many of our design decisions.

**ClientDataset = examples + preprocessor**

In [None]:
# We cap max sentence length to 8.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
cid = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
cds = train_fd.get_client(cid)

The examples in a client dataset can be viewed as a table, where the rows are
the individual examples, and the columns are the features (labels are viewed as
a feature in this context).

We use a column based representation when loading a dataset into memory.

-   Each column is a NumPy array `x` of rank at least 1, where `x[i, ...]` is
    the value of this feature for the `i`-th example.
-   The complete set of examples is a dict-like object, from `str` feature
    names, to the corresponding column values.

Traditionally, a row based representation is used for representing the entire
dataset, and a column based representation is used for a single batch.

**In the context of federated learning, an individual client dataset is small
enough to easily fit into memory so the same representation is used for the
entire dataset and a batch.**
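As a minimal NumPy sketch of the two layouts (the feature names and values here are made up for illustration), a row-based table can be turned into a column-based dict, and a batch is then just a slice of every column:

```python
import numpy as np

# Row-based view: one dict per example.
rows = [
    {'x': [1, 2], 'y': 0},
    {'x': [3, 4], 'y': 1},
    {'x': [5, 6], 'y': 0},
]

# Column-based view: one NumPy array per feature, indexed by example.
columns = {
    name: np.stack([np.asarray(row[name]) for row in rows])
    for name in rows[0]
}

# A batch is a slice of every column, so it has the same layout as the
# full dataset.
batch = {name: column[:2] for name, column in columns.items()}

print({name: column.shape for name, column in columns.items()})
print({name: column.shape for name, column in batch.items()})
```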

In [None]:
cds.all_examples()

{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12],
        [75, 54, 74, 19, 16, 66, 47,  3],
        [16, 67,  8, 67, 71, 47,  7, 61],
        [16, 14,  4, 67, 47, 16, 48, 67],
        [84, 67, 47,  7, 67, 48, 16, 12],
        [78, 29, 75, 78, 33, 16, 66, 47],
        [ 3, 16, 75, 73, 29, 11, 75, 76],
        [32, 19, 65, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [

For Shakespeare, we are training a character-level language model, where the task is next character prediction, so the features are:

- `x` is a list of right-shifted sentences, e.g. `sentence[:-1]`
- `y` is a list of left-shifted sentences, e.g. `sentence[1:]`

This way, the pair `x[i][j]` and `y[i][j]` corresponds to the previous and next characters, respectively.
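The shifting can be sketched with NumPy slicing on a single sequence. The integer ids below are hypothetical and only serve to show the alignment:

```python
import numpy as np

# Hypothetical integer ids for one short sentence.
sentence = np.array([1, 55, 67, 84, 2])

x = sentence[:-1]  # Every label except the last one.
y = sentence[1:]   # Every label except the first one.

# y[j] is the label that immediately follows x[j] in the sentence.
for previous_label, next_label in zip(x, y):
  print(previous_label, '->', next_label)
```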

In [None]:
examples = cds.all_examples()
print('x', examples['x'][0])
print('y', examples['y'][0])

x [ 1 55 67 84 67 47  7 67]
y [55 67 84 67 47  7 67 48]


However, you probably noticed that `x` and `y` are arrays of integers, not text. This is because `fedjax.datasets.shakespeare.load_data()` applies some minimal preprocessing: a simple character look-up that maps characters to integer IDs. Later, we'll go over how this preprocessing is applied and how to add your own custom preprocessing.

We can also view the unprocessed version of the data:

In [None]:
raw_fd = fedjax.datasets.shakespeare.load_split('train')
raw_cds = raw_fd.get_client(client_id)
raw_cds.all_examples()

Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'


{'snippets': array([b"If we be not reliev'd within this hour,",
        b"Let's hear him, for the things he speaks",
        b'May concern Caesar.\nSwoons rather; for so bad a prayer as his',
        b"We must return to th' court of guard. The night\nIs shiny, and they say we shall embattle\nBy th' second hour i' th' morn.\nEnobarbus?",
        b"[Drums afar off ] Hark! the drums\nDemurely wake the sleepers. Let us bear him\nTo th' court of guard; he is of note. Our hour\nIs fully out.\n"],
       dtype=object)}

## Accessing federated datasets

The previous methods work well for querying a *single* client while exploring the dataset. However, we often want to query multiple client datasets at the same time. In most FL algorithms, tens to hundreds of clients participate in each training round, and for large federated datasets, it is not feasible to load all client datasets into memory at once (whereas loading a single client dataset is assumed to be feasible).

In light of this, FedJAX offers more efficient methods for querying multiple client datasets, and we **STRONGLY** recommend using them. They leverage the fact that sequential reads are much faster than random reads on most storage technologies.

**We'll go through each access method from the MOST efficient to the LEAST efficient.**

In [None]:
train_fd, test_fd = fedjax.datasets.shakespeare.load_data()

### `clients()` and `shuffled_clients()`

**Fastest** sequential read friendly access.



In [None]:
# clients() and shuffled_clients() are sequential read friendly.
clients = train_fd.clients()
shuffled_clients = train_fd.shuffled_clients(buffer_size=100, seed=0)
print('clients =', clients)
print('shuffled_clients =', shuffled_clients)

clients = <generator object SQLiteFederatedData.clients at 0x7f357710e258>
shuffled_clients = <generator object SQLiteFederatedData.shuffled_clients at 0x7f357710e150>


They are generators, so in order to use them, we need to iterate over them.

`clients()` returns clients in a deterministic order, where each "client" is a tuple of (client_id, client_dataset).

In [None]:
for client_id, client_dataset in itertools.islice(clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))

client_id = b'00192e4d5c9c3a5b:ALL_S_WELL_THAT_ENDS_WELL_CENTURION'
# examples = 6
client_id = b'004309f15562402e:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_CAMPEIUS'
# examples = 24
client_id = b'00b20765b748920d:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_ALL'
# examples = 8


`shuffled_clients()` is like `clients()` but allows repeated buffered shuffling.
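To build intuition for buffered shuffling over a sequential stream, here is a generic pure-Python sketch. It is not FedJAX's implementation; `buffered_shuffle` is a hypothetical helper that keeps at most `buffer_size + 1` items in memory and emits a random one each time the buffer overflows.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
  """Yields items in a locally shuffled order using a bounded buffer."""
  rng = random.Random(seed)
  buffer = []
  for item in items:
    buffer.append(item)
    if len(buffer) > buffer_size:
      # Swap a random element to the end and emit it.
      index = rng.randrange(len(buffer))
      buffer[index], buffer[-1] = buffer[-1], buffer[index]
      yield buffer.pop()
  # The input is exhausted: emit whatever is left in random order.
  rng.shuffle(buffer)
  yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=4, seed=0))
print(shuffled)
```

The input is still read strictly in order, which is what keeps this style of shuffling friendly to sequential storage access. A larger buffer gives a more thorough shuffle at the cost of more memory.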


In [None]:
print('shuffled_clients()')
for client_id, client_dataset in itertools.islice(shuffled_clients, 3):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))

shuffled_clients()
client_id = b'0a18c2501d441fef:THE_TRAGEDY_OF_KING_LEAR_FLUTE'
# examples = 12
client_id = b'136c5586b7271525:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_GLOUCESTER'
# examples = 381
client_id = b'0d642a9b4bb27187:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_MESSENGER'
# examples = 78


### `get_clients()`

**Slower** than `clients()` since it requires random read but uses prefetching to ameliorate the cost of random read access. This will return a generator of tuples of (client_id, client_dataset).

In [None]:
client_ids = [
    b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN',
    b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN',
    b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
]
for client_id, client_dataset in train_fd.get_clients(client_ids):
  print('client_id =', client_id)
  print('# examples =', len(client_dataset))

client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 49
client_id = b'140784b36d08efbc:PERICLES__PRINCE_OF_TYRE_GHOST_OF_VAUGHAN'
# examples = 1
client_id = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
# examples = 2


### `get_client()`

**Slowest** way of accessing client datasets. We usually reserve this method only for interactive exploration of a small number of clients.

In [None]:
client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
print('client_id =', client_id)
print('# examples =', len(train_fd.get_client(client_id)))

client_id = b'1db830204507458e:THE_TAMING_OF_THE_SHREW_SEBASTIAN'
# examples = 49


## Batching client datasets

Next we'll go over different methods of batching and iterating over the client dataset. All the following methods can be invoked in 2 ways:

1. Using a hyperparams object. This is the recommended way in library code. `batch_fn(hparams)`.
2. Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. `batch_fn(batch_size=2)` or `batch_fn(hparams, batch_size=2)` to override `batch_size`.

For the most part, we'll use method 2 for this colab, but we highly recommend using method 1 for writing library code.
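The two calling conventions can be illustrated with a small stand-alone sketch. `ToyBatchHParams` and `resolve_hparams` are hypothetical names invented here, not FedJAX classes; the sketch only shows how keyword arguments could either build a new hyperparams object or override fields of an existing one.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyBatchHParams:
  """Hypothetical hyperparams container, for illustration only."""
  batch_size: int = 1
  drop_remainder: bool = False


def resolve_hparams(hparams=None, **kwargs):
  """Builds hyperparams from kwargs, or overrides fields of `hparams`."""
  if hparams is None:
    return ToyBatchHParams(**kwargs)
  return dataclasses.replace(hparams, **kwargs)


# Method 2 without an existing object: kwargs construct a new one.
print(resolve_hparams(batch_size=2))
# Method 2 with an existing object: kwargs override only the named fields.
base = ToyBatchHParams(batch_size=8, drop_remainder=True)
print(resolve_hparams(base, batch_size=2))
```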

In [None]:
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(sequence_length=8)
cid = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
cds = train_fd.get_client(cid)

### `padded_batch()`

Produces preprocessed padded batches in a fixed sequential order **for evaluation**.

When the number of examples in the dataset is not a multiple of `batch_size`,
the final batch may be smaller than `batch_size`. This may lead to [a large
number of JIT recompilations](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html). This can be circumvented by padding the final
batch to a small number of fixed sizes controlled by `num_batch_size_buckets`.

In [None]:
# use list() to consume generator and store in memory.
padded_batches = list(cds.padded_batch(batch_size=8, num_batch_size_buckets=3))
print('# batches =', len(padded_batches))
padded_batches[0]

# batches = 3


{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67],
        [48, 16, 13, 32, 33, 14, 11, 78],
        [76, 78, 33, 19, 16, 66, 47,  3],
        [16, 27, 67, 23, 26, 47,  3, 27],
        [16,  7,  4, 67, 16, 51, 48, 68],
        [ 7, 26, 47, 27, 42, 16,  7,  4],
        [67, 72, 16, 48, 67, 27, 23, 71],
        [67, 65, 29, 79, 76, 51, 74, 12]], dtype=int32),
 'y': array([[55, 67, 84, 67, 47,  7, 67, 48],
        [16, 13, 32, 33, 14, 11, 78, 76],
        [78, 33, 19, 16, 66, 47,  3, 16],
        [27, 67, 23, 26, 47,  3, 27, 16],
        [ 7,  4, 67, 16, 51, 48, 68,  7],
        [26, 47, 27, 42, 16,  7,  4, 67],
        [72, 16, 48, 67, 27, 23, 71, 67],
        [65, 29, 79, 76, 51, 74, 12, 75]], dtype=int32),
 '__mask__': array([ True,  True,  True,  True,  True,  True,  True,  True])}

All batches contain an extra bool feature keyed by `__mask__`.
`batch['__mask__'][i]` tells us whether the `i`-th example in this batch
is an actual example (`batch['__mask__'][i] == True`), or a padding
example (`batch['__mask__'][i] == False`).
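A typical use of the mask is to exclude padding examples when averaging a per-example quantity. The sketch below assumes a hypothetical `per_example_loss` array computed elsewhere; only the `__mask__` key comes from the batches shown here.

```python
import numpy as np

# Mask as found in a padded batch: the last example is padding.
mask = np.array([True, True, True, False])
# Hypothetical per-example values; the padding entry is meaningless.
per_example_loss = np.array([0.5, 1.5, 2.0, 123.0])

# Zero out padding entries and divide by the number of real examples.
masked_sum = np.sum(np.where(mask, per_example_loss, 0.0))
mean_loss = masked_sum / np.sum(mask)
print(mean_loss)
```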

We repeatedly halve the batch size up to `num_batch_size_buckets - 1` times, until
we find the smallest size that is still >= the size of the final batch. Therefore,
if `batch_size < 2^num_batch_size_buckets`, fewer bucket sizes will actually be
used. This can be seen in the final batch below, which holds 3 real examples and is padded to size 4 even though the original batch size was 8.

In [None]:
padded_batches[-1]

{'__mask__': array([ True,  True,  True, False]),
 'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0]], dtype=int32)}
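The halving rule for the final batch can be sketched as a small function. This is one reading of the description above, assuming integer halving; `padded_final_batch_size` is a hypothetical helper and not part of FedJAX.

```python
def padded_final_batch_size(final_size, batch_size, num_batch_size_buckets):
  """Returns the smallest bucket size that still fits `final_size` examples."""
  if not 0 < final_size <= batch_size:
    raise ValueError('final_size must be in (0, batch_size]')
  size = batch_size
  # Halve at most `num_batch_size_buckets - 1` times.
  for _ in range(num_batch_size_buckets - 1):
    half = size // 2
    if half == 0 or half < final_size:
      break
    size = half
  return size


# 19 examples with batch_size=8 leave a final batch of 3 examples.
print(padded_final_batch_size(3, batch_size=8, num_batch_size_buckets=3))
```

With `batch_size=8` and 3 buckets, the candidate sizes are 8, 4, and 2, so a final batch of 3 examples is padded to 4.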

### `shuffle_repeat_batch()`

Produces preprocessed batches in a shuffled and repeated order **for training**.

Shuffling is done without replacement, therefore for a dataset of N examples,
the first `ceil(N/batch_size)` batches are guaranteed to cover the entire
dataset. Unlike `batch()` or `padded_batch()`, batches from
`shuffle_repeat_batch()` always contain exactly `batch_size` examples. Also,
unlike in TensorFlow, this holds even when `drop_remainder=False`.

By default the iteration stops after the first epoch.
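The coverage guarantee and the fixed batch size can be illustrated with a NumPy sketch that works on example indices only. It is not FedJAX's implementation: `shuffle_repeat_indices` is a hypothetical generator that concatenates fresh permutations and never stops, so the stopping rules are left out.

```python
import itertools

import numpy as np


def shuffle_repeat_indices(num_examples, batch_size, seed=None):
  """Yields index batches of exactly `batch_size`, repeating forever."""
  rng = np.random.RandomState(seed)
  pending = np.array([], dtype=np.int64)
  while True:
    # Top up with a fresh permutation (one epoch) whenever we run short.
    while len(pending) < batch_size:
      pending = np.concatenate([pending, rng.permutation(num_examples)])
    yield pending[:batch_size]
    pending = pending[batch_size:]


# With N=19 and batch_size=8, the first ceil(19 / 8) = 3 batches contain a
# complete permutation, so they cover every example.
first_three = list(itertools.islice(shuffle_repeat_indices(19, 8, seed=0), 3))
covered = set(np.concatenate(first_three).tolist())
print(len(covered))
```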

In [None]:
print('# batches')
len(list(cds.shuffle_repeat_batch(batch_size=8)))

# batches


3

The number of batches produced can be controlled by the combination of `num_epochs`, `num_steps`, and
`drop_remainder`:

If both `num_epochs` and `num_steps` are None, the shuffle-repeat process continues forever.


In [None]:
infinite_bs = cds.shuffle_repeat_batch(
    batch_size=8, num_epochs=None, num_steps=None)
for i, b in zip(range(6), infinite_bs):
  print(i)

0
1
2
3
4
5


If `num_epochs` is set and `num_steps` is None, as few batches as needed to make
`num_epochs` passes over the dataset are produced. Further,

-   If `drop_remainder` is False (the default), the final batch is filled with
    additionally sampled examples to contain `batch_size` examples.
-   If `drop_remainder` is True, the final batch is dropped if it contains fewer
    than `batch_size` examples. This may result in examples being skipped when
    `num_epochs=1`.

In [None]:
print('# batches w/ drop_remainder=False')
print(len(list(cds.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None))))
print('# batches w/ drop_remainder=True')
print(len(list(cds.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=None, drop_remainder=True))))

# batches w/ drop_remainder=False
3
# batches w/ drop_remainder=True
2


If `num_steps` is set and `num_epochs` is None, exactly `num_steps` batches are
produced. `drop_remainder` has no effect in this case.

In [None]:
print('# batches w/ num_steps set and drop_remainder=True')
print(len(list(cds.shuffle_repeat_batch(batch_size=8, num_epochs=None, num_steps=3, drop_remainder=True))))

# batches w/ num_steps set and drop_remainder=True
3


If both `num_epochs` and `num_steps` are set, iteration stops as soon as either
condition is met, so the smaller of the two batch counts is produced.

In [None]:
print('# batches w/ num_epochs and num_steps set')
print(len(list(cds.shuffle_repeat_batch(batch_size=8, num_epochs=1, num_steps=6))))

# batches w/ num_epochs and num_steps set
3
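The stopping rules can be summarized as a small arithmetic sketch. `expected_num_batches` is a hypothetical helper, and it assumes that `num_epochs` passes amount to `num_examples * num_epochs` examples in total; it reproduces the counts printed in this section for a client with 19 examples.

```python
import math


def expected_num_batches(num_examples, batch_size, num_epochs=1,
                         num_steps=None, drop_remainder=False):
  """Returns the number of batches, or None if iteration never stops."""
  if num_epochs is None:
    # Only num_steps can stop the iteration; None means it runs forever.
    return num_steps
  total = num_examples * num_epochs
  if drop_remainder:
    from_epochs = total // batch_size
  else:
    from_epochs = math.ceil(total / batch_size)
  if num_steps is None:
    return from_epochs
  # Both are set: stop at whichever limit is reached first.
  return min(from_epochs, num_steps)


print(expected_num_batches(19, 8))
print(expected_num_batches(19, 8, drop_remainder=True))
print(expected_num_batches(19, 8, num_epochs=None, num_steps=3))
print(expected_num_batches(19, 8, num_epochs=1, num_steps=6))
```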


If reproducible iteration order is desired, a fixed `seed` can be used. When
`seed` is None, repeated iteration over the same object may produce batches in a
different order.

In [None]:
# Random shuffling.
print(list(cds.shuffle_repeat_batch(batch_size=2, seed=None))[0])
# Fixed shuffling.
print(list(cds.shuffle_repeat_batch(batch_size=2, seed=0))[0])

{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
       [16, 27, 67, 23, 26, 47,  3, 27]], dtype=int32), 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
       [27, 67, 23, 26, 47,  3, 27, 16]], dtype=int32)}
{'x': array([[16, 14,  4, 67, 47, 16, 48, 67],
       [48, 16, 13, 32, 33, 14, 11, 78]], dtype=int32), 'y': array([[14,  4, 67, 47, 16, 48, 67, 84],
       [16, 13, 32, 33, 14, 11, 78, 76]], dtype=int32)}


### `batch()`

Produces preprocessed batches in a fixed sequential order.

The final batch may contain fewer than `batch_size` examples. If used directly,
that may result in a large number of JIT recompilations. **Therefore we
recommend using `padded_batch()` or `shuffle_repeat_batch()` instead in most scenarios.**

In [None]:
batches = list(cds.batch(batch_size=8))
batches[-1]

{'x': array([[16, 16, 16, 16, 16, 16, 16, 16],
        [28, 68,  7,  4, 16, 75, 76, 32],
        [30, 74, 54, 65,  0,  0,  0,  0]], dtype=int32),
 'y': array([[16, 16, 16, 16, 16, 16, 16, 28],
        [68,  7,  4, 16, 75, 76, 32, 30],
        [74, 54, 65,  2,  0,  0,  0,  0]], dtype=int32)}

## Preprocessing

Preprocessing can be done at two levels:

- The batch level with `fedjax.BatchPreprocessor`
- The client dataset level with `fedjax.ClientPreprocessor`

**Examples of preprocessing possible at either the client dataset level, or
the batch level**

Such preprocessing is deterministic, and strictly per-example.

- Casting a feature from `int8` to `float32`.
- Adding a new feature derived from existing features.
- Removing a feature (although the better place to do so is at the dataset
  level).

A simple rule for deciding where to carry out the preprocessing in this case
is the following:

- Does this make batching cheaper (e.g. removing features)? If so, do it at
  the dataset level.
- Otherwise, do it at the batch level.

Assuming preprocessing time is linear in the number of examples, preprocessing
at the batch level has the benefit of evenly distributing host compute work,
which may overlap better with asynchronous JAX compute work on GPU/TPU.

**Examples of preprocessing only possible at the batch level**

- Data augmentation (e.g. random cropping).
- Padding at the batch size dimension.

**Examples of preprocessing only possible at the dataset level**

- Those that require knowing the client id.
- Capping the number of examples.
- Altering what it means to be an example: e.g. in certain language model
  setups, sentences are concatenated and then split into equal sized chunks.
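As an illustration of a dataset-level step, capping the number of examples needs to see all of a client's examples at once. The sketch below is a hypothetical function, `cap_examples`, that takes `(client_id, examples)` and returns new examples; the feature names are made up.

```python
import numpy as np


def cap_examples(client_id, examples, max_examples):
  """Keeps at most `max_examples` examples from every feature column."""
  del client_id  # Unused here, but available at the dataset level.
  return {name: column[:max_examples] for name, column in examples.items()}


examples = {
    'x': np.arange(10).reshape([5, 2]),
    'y': np.arange(5),
}
capped = cap_examples(b'some_client', examples, max_examples=3)
print({name: column.shape for name, column in capped.items()})
```

Because every column is sliced the same way, the rows stay aligned across features.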

In [None]:
# Load unpreprocessed data.
raw_fd = fedjax.datasets.shakespeare.load_split('train')

Reusing cached file '/tmp/.cache/fedjax/shakespeare_train.sqlite'


### Applying preprocessing

Applying the preprocessing is usually done on the `fedjax.FederatedData`
using `preprocess_client()` and `preprocess_batch()` for the client dataset
level and batch level, respectively.

Below, we will walk through an example preprocessing pipeline for Shakespeare
that turns text into sequences of integer labels.

In [None]:
def _build_look_up_table(vocab, num_reserved):
  """Builds a look-up table from a byte to its integer label."""
  oov = num_reserved + len(vocab)
  vocab_size = oov + 1
  table = np.full([256], oov, dtype=np.int32)
  for i, c in enumerate(vocab):
    table[c] = num_reserved + i
  return table, vocab_size


# Vocabulary re-used from the Federated Learning for Text Generation tutorial.
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
TABLE, VOCAB_SIZE = _build_look_up_table(
    b'dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r',
    num_reserved=3)
OOV = VOCAB_SIZE - 1
# Reserved labels.
PAD = 0
BOS = 1
EOS = 2

All snippets in a client dataset are first joined into a single sequence (with
BOS/EOS added), and then split into pairs of `sequence_length` chunks for
language model training. For example, with sequence_length=3, `[b'ABCD', b'E']`
becomes

```
Input sequences:  [[BOS, A, B], [C, D, EOS],   [BOS, E, PAD]]
Output sequences: [[A, B, C],   [D, EOS, BOS], [E, EOS, PAD]]
```
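The worked example can be reproduced with a small stand-alone sketch. It uses hypothetical integer labels (A=3, B=4, C=5, D=6, E=7) in place of a real look-up table, and `join_and_split` is an illustrative name rather than a FedJAX function.

```python
import numpy as np

PAD, BOS, EOS = 0, 1, 2


def join_and_split(snippets, sequence_length):
  """Joins label sequences with BOS/EOS, then splits into x/y chunks."""
  joined = []
  for snippet in snippets:
    joined.extend([BOS, *snippet, EOS])
  joined = np.array(joined, dtype=np.int32)
  # Inputs and outputs each have one label fewer than the joined sequence.
  num_labels = len(joined) - 1
  num_chunks = -(-num_labels // sequence_length)  # Ceiling division.
  x = np.full([num_chunks * sequence_length], PAD, dtype=np.int32)
  y = np.full([num_chunks * sequence_length], PAD, dtype=np.int32)
  x[:num_labels] = joined[:-1]
  y[:num_labels] = joined[1:]
  return x.reshape([-1, sequence_length]), y.reshape([-1, sequence_length])


# Hypothetical labels standing in for [b'ABCD', b'E'].
x, y = join_and_split([[3, 4, 5, 6], [7]], sequence_length=3)
print(x)
print(y)
```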

In [None]:
def preprocess_client(client_id, examples, sequence_length):
  """Turns snippets into sequences of integer labels."""
  del client_id
  snippets = examples['snippets']
  # Join all snippets into a single label sequence.
  joined_length = sum(len(i) + 2 for i in snippets)
  joined = np.zeros([joined_length], dtype=np.int32)
  offset = 0
  for i in snippets:
    joined[offset] = BOS
    joined[offset + 1:offset + 1 + len(i)] = TABLE[list(i)]
    joined[offset + 1 + len(i)] = EOS
    offset += len(i) + 2
  # Split into input/output sequences of size `sequence_length`.
  padded_length = ((joined_length - 1 + sequence_length - 1) //
                   sequence_length * sequence_length)
  input_labels = np.full([padded_length], PAD, dtype=np.int32)
  input_labels[:joined_length - 1] = joined[:-1]
  output_labels = np.full([padded_length], PAD, dtype=np.int32)
  output_labels[:joined_length - 1] = joined[1:]
  return {
      'x': input_labels.reshape([-1, sequence_length]),
      'y': output_labels.reshape([-1, sequence_length])
  }

The output features will be (M below is possibly different from N in
load_split):

-   `x`: [M, sequence_length] int32 input labels, in the range of [0,
    shakespeare.VOCAB_SIZE)
-   `y`: [M, sequence_length] int32 output labels, in the range of [0,
    shakespeare.VOCAB_SIZE)

In [None]:
cid = b'105f96df763d4ddf:ALL_S_WELL_THAT_ENDS_WELL_GUIDERIUS_AND_ARVIRAGUS'
raw_cds = raw_fd.get_client(cid)
print('Raw unprocessed client dataset')
print(raw_cds.all_examples())

preprocess = functools.partial(preprocess_client, sequence_length=10)
cds = raw_fd.preprocess_client(preprocess).get_client(cid)
print('Preprocessed client dataset')
print(cds.all_examples())

Raw unprocessed client dataset
{'snippets': array([b'Re-enter POSTHUMUS, and seconds the Britons; they rescue\nCYMBELINE, and exeunt. Then re-enter LUCIUS and IACHIMO,\n                     with IMOGEN\n'],
      dtype=object)}
Preprocessed client dataset
{'x': array([[ 1, 55, 67, 84, 67, 47,  7, 67, 48, 16],
       [13, 32, 33, 14, 11, 78, 76, 78, 33, 19],
       [16, 66, 47,  3, 16, 27, 67, 23, 26, 47],
       [ 3, 27, 16,  7,  4, 67, 16, 51, 48, 68],
       [ 7, 26, 47, 27, 42, 16,  7,  4, 67, 72],
       [16, 48, 67, 27, 23, 71, 67, 65, 29, 79],
       [76, 51, 74, 12, 75, 54, 74, 19, 16, 66],
       [47,  3, 16, 67,  8, 67, 71, 47,  7, 61],
       [16, 14,  4, 67, 47, 16, 48, 67, 84, 67],
       [47,  7, 67, 48, 16, 12, 78, 29, 75, 78],
       [33, 16, 66, 47,  3, 16, 75, 73, 29, 11],
       [75, 76, 32, 19, 65, 16, 16, 16, 16, 16],
       [16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
       [16, 16, 16, 16, 16, 16, 28, 68,  7,  4],
       [16, 75, 76, 32, 30, 74, 54, 65,  0,  0]], dt

### `BatchPreprocessor`

`fedjax.BatchPreprocessor` holds a chain of preprocessing functions and applies
them in order to a batch of examples. Each individual preprocessing function
operates over multiple examples at once, instead of just 1 example.

In [None]:
preprocessor = fedjax.BatchPreprocessor([
  # Flattens `pixels`.
  lambda x: {**x, 'pixels': x['pixels'].reshape([-1, 28 * 28])},
  # Introduce `binary_label`.
  lambda x: {**x, 'binary_label': x['label'] % 2},
])
fake_emnist = {
  'pixels': np.random.uniform(size=(2, 28, 28)),
  'label': np.random.randint(10, size=(2,))
}
preprocessor(fake_emnist)
# Produces a dict of [2, 28*28] "pixels", [2,] "label" and "binary_label".

{'pixels': array([[0.45102459, 0.76206138, 0.06551379, ..., 0.45446418, 0.89527442,
         0.56221184],
        [0.67933712, 0.16994017, 0.57733109, ..., 0.7888319 , 0.70922089,
         0.80357344]]), 'label': array([6, 8]), 'binary_label': array([0, 0])}

Given a `fedjax.BatchPreprocessor`, a new `fedjax.BatchPreprocessor` can be
created with an additional preprocessing function appended to the chain:

In [None]:
# Continuing from the previous example.
new_preprocessor = preprocessor.append(
  lambda x: {**x, 'sum_pixels': np.sum(x['pixels'], axis=1)})
new_preprocessor(fake_emnist)
# Produces a dict of [2, 28*28] "pixels", [2,] "sum_pixels", "label" and
# "binary_label".

{'pixels': array([[0.45102459, 0.76206138, 0.06551379, ..., 0.45446418, 0.89527442,
         0.56221184],
        [0.67933712, 0.16994017, 0.57733109, ..., 0.7888319 , 0.70922089,
         0.80357344]]),
 'label': array([6, 8]),
 'binary_label': array([0, 0]),
 'sum_pixels': array([390.24969764, 384.08759272])}

The main difference between this preprocessor and `fedjax.ClientPreprocessor` is that
`fedjax.ClientPreprocessor` also takes `client_id` as input. Because batched
examples and all examples in a client dataset share the same representation,
certain preprocessing can be done with either
`fedjax.BatchPreprocessor` or `fedjax.ClientPreprocessor`.
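The idea of a chain of functions with a non-mutating `append` can be sketched in a few lines. `ToyBatchPreprocessor` is a hypothetical stand-in written for this illustration and makes no claim about how `fedjax.BatchPreprocessor` is implemented.

```python
import functools

import numpy as np


class ToyBatchPreprocessor:
  """Hypothetical stand-in that applies functions in order to a batch."""

  def __init__(self, fns=()):
    self._fns = tuple(fns)

  def __call__(self, examples):
    # Feed the output of each function into the next one.
    return functools.reduce(lambda out, fn: fn(out), self._fns, examples)

  def append(self, fn):
    # Returns a new object; the original chain is left unchanged.
    return ToyBatchPreprocessor(self._fns + (fn,))


base = ToyBatchPreprocessor([lambda x: {**x, 'double': x['label'] * 2}])
extended = base.append(lambda x: {**x, 'plus_one': x['double'] + 1})

batch = {'label': np.array([1, 2])}
out = extended(batch)
print(sorted(out))
print(sorted(base(batch)))
```

Returning a new object from `append` means an existing preprocessor can be shared and extended in several places without one caller affecting another.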

### `ClientPreprocessor`

A chain of preprocessing functions on all examples of a client dataset.

This is very similar to `fedjax.BatchPreprocessor`, with the main difference
being that `ClientPreprocessor` also takes `client_id` as input.

In [None]:
preprocessor = fedjax.ClientPreprocessor([
  # Adds `client_id_length`.
  lambda cid, x: {**x, 'client_id_length': np.ones_like(x['label']) * len(cid)}
])
fake_emnist = {
  'label': np.random.randint(10, size=(2,))
}
client_id = b'123456'
preprocessor(client_id, fake_emnist)
# Produces a dict of [2,] "label" and "client_id_length".

{'label': array([0, 2]), 'client_id_length': array([6, 6])}