# FedJAX Advanced Usage

[Open In Colab](https://colab.research.google.com/github/google/fedjax/blob/main/notebooks/fedjax_advanced.ipynb)

This notebook introduces more advanced usages of FedJAX and walks through:
* Defining a custom model
* Writing a custom federated algorithm

This notebook is meant to go deeper into the concepts introduced in [FedJAX Intro](./fedjax_intro.ipynb).


In [None]:
!pip install --upgrade -q fedjax==0.0.1

In [None]:
import collections
import functools

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

import fedjax

## Defining a custom model

In this section, we will cover how to define custom models in a format suitable for use in FedJAX.

Below, we use `haiku` as the neural network library. If you are new to `haiku`, the following API references are useful starting points:
* https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform
* https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.Module 

In [None]:
def forward_pass(batch):
  """Runs forward pass to produce unnormalized logits."""
  network = hk.Sequential([
      hk.Flatten(),
      hk.Linear(100),
      jax.nn.relu,
      hk.Linear(100),
      jax.nn.relu,
      hk.Linear(62),
  ])
  return network(batch['x'])


def cross_entropy_loss(batch, preds):
  targets = batch['y']
  num_classes = preds.shape[-1]
  log_preds = jax.nn.log_softmax(preds)
  one_hot_targets = jax.nn.one_hot(targets, num_classes)
  return -jnp.mean(jnp.sum(one_hot_targets * log_preds, axis=-1))


def accuracy(batch, preds):
  targets = batch['y']
  pred_class = jnp.argmax(preds, axis=-1)
  return jnp.mean(pred_class == targets)


# Transform forward_pass, which uses hk.Module, into a pair of pure functions.
transformed_forward_pass = hk.transform(forward_pass)
# Sample batch used to initialize model parameter shapes.
sample_batch = collections.OrderedDict(
    x=np.ones((1, 28, 28)), y=np.ones((1, 1)))
model = fedjax.create_model_from_haiku(
    transformed_forward_pass=transformed_forward_pass,
    sample_batch=sample_batch,
    loss_fn=cross_entropy_loss,
    metrics_fn_map=collections.OrderedDict(accuracy=accuracy))

rng = next(fedjax.PRNGSequence(0))
params = model.init_params(rng)
backward_pass_output = model.backward_pass(params, sample_batch, rng)
metrics = model.evaluate(params, sample_batch)

print('# parameters =', hk.data_structures.tree_size(params))
print('# grads =', hk.data_structures.tree_size(backward_pass_output.grads))
print('backward_pass_output.weight =', backward_pass_output.weight)
print('metrics[loss] =', metrics['loss'])
print('metrics[weight] =', metrics['weight'])
print('metrics[accuracy] =', metrics['accuracy'])



# parameters = 94862
# grads = 94862
backward_pass_output.weight = 1.0
metrics[loss] = 4.011884
metrics[weight] = 1.0
metrics[accuracy] = 0.0
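
The reported parameter count can be checked by hand. The sketch below assumes that each `hk.Linear` layer holds a weight matrix plus a bias vector and that the 28x28 input is flattened to 784 features; the name `layer_sizes` is introduced only for this illustration. It also computes ln(62), the cross entropy of a perfectly uniform prediction over 62 classes, as a reference point for the initial loss.

```python
import math

# Layer widths: flattened 28x28 input, two hidden layers of 100, 62 classes.
layer_sizes = [28 * 28, 100, 100, 62]

# Each dense layer has (fan_in * fan_out) weights plus fan_out biases.
num_params = sum(
    fan_in * fan_out + fan_out
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]))
print(num_params)  # 94862

# Cross entropy of a uniform prediction over 62 classes.
uniform_loss = math.log(62)
print(round(uniform_loss, 4))  # 4.1271
```

The initial loss of about 4.01 printed above is close to this reference value, which is roughly what one would expect from a randomly initialized network.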


## Writing a custom federated algorithm

In this section, we'll go over how to implement your own custom federated algorithm in FedJAX.

To do this, we will implement the Federated Averaging algorithm from scratch.

As a refresher, recall that federated algorithms typically consist of:
* Client training: How to train across clients on their local data (analogous to the "map" in "mapreduce").
* Server aggregation: How to aggregate multiple client outputs into a single server output (analogous to the "reduce" in "mapreduce").
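
The map/reduce analogy can be sketched in plain Python without FedJAX. In this toy setup, which is an assumption made purely for illustration, the "model" is a single float, each client's data is a list of numbers, and `client_update` and `combine` are hypothetical helpers standing in for client training and server aggregation.

```python
import functools


def client_update(server_param, client_data, learning_rate=0.5):
  """Map step: one gradient step on 0.5 * (param - x)^2 over local data."""
  grad = sum(server_param - x for x in client_data) / len(client_data)
  new_param = server_param - learning_rate * grad
  return new_param, len(client_data)  # (local result, weight)


def combine(acc, client_output):
  """Reduce step: accumulate the weighted sum and the total weight."""
  weighted_sum, total_weight = acc
  param, weight = client_output
  return weighted_sum + param * weight, total_weight + weight


clients = {'a': [1.0, 3.0], 'b': [5.0, 5.0, 5.0, 5.0, 5.0, 5.0]}
server_param = 0.0

client_outputs = map(lambda d: client_update(server_param, d), clients.values())
weighted_sum, total_weight = functools.reduce(combine, client_outputs, (0.0, 0))
server_param = weighted_sum / total_weight
print(server_param)  # 2.125
```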

### Client training

In FedJAX, we introduce the `fedjax.ClientTrainer` interface that defines how to conduct training for a **single** client.

Below, `SimpleClientTrainer` is a simple example implementation of `fedjax.ClientTrainer`. `SimpleClientTrainerState` keeps track of the model parameters, optimizer state, and weight at each step, where the weight is typically the number of examples seen during training. `one_step` computes gradients on the input batch, then updates the model parameters and optimizer state using the input `client_optimizer`.

**NOTE**: `SimpleClientTrainerState` is different from the server state mentioned in federated algorithms. You can think of `SimpleClientTrainerState` as a sort of client state.
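
The state-passing pattern described above can be sketched without JAX. In the toy version below, `ToyState` and `toy_one_step` are hypothetical names, the "model" is a single float fit by mean squared error, and the optimizer state is omitted for brevity. Each step returns a new immutable state instead of mutating the old one, and the weight accumulates the number of examples seen.

```python
import collections

ToyState = collections.namedtuple('ToyState', ['param', 'weight'])


def toy_one_step(state, batch, learning_rate=0.1):
  """Returns a new state after one SGD step on the batch's mean squared error."""
  grad = sum(2.0 * (state.param - x) for x in batch) / len(batch)
  return ToyState(param=state.param - learning_rate * grad,
                  weight=state.weight + len(batch))


state = ToyState(param=0.0, weight=0.0)
for batch in [[1.0, 1.0], [1.0, 1.0, 1.0]]:
  state = toy_one_step(state, batch)
print(state.weight)  # 5.0
```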


In [None]:
SimpleClientTrainerState = collections.namedtuple(
    'SimpleClientTrainerState', ['params', 'opt_state', 'weight'])


class SimpleClientTrainer(fedjax.ClientTrainer):
  """Simple client trainer."""

  def __init__(self, model, client_optimizer):
    super().__init__()
    self._model = model
    self._client_optimizer = client_optimizer

  def init_state(self, params, weight=0.):
    opt_state = self._client_optimizer.init_fn(params)
    return SimpleClientTrainerState(params, opt_state, weight)

  # https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions
  # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html 
  @functools.partial(jax.jit, static_argnums=0)
  def one_step(self, client_trainer_state, batch, rng):
    backward_pass_output = self._model.backward_pass(
        client_trainer_state.params, batch, rng)
    params_updates, opt_state = self._client_optimizer.update_fn(
        backward_pass_output.grads, client_trainer_state.opt_state)
    params = self._client_optimizer.apply_updates(client_trainer_state.params,
                                                  params_updates)
    weight = client_trainer_state.weight + backward_pass_output.weight
    return SimpleClientTrainerState(params, opt_state, weight)

As stated, `fedjax.ClientTrainer` defines how to conduct training for a **single** client. However, we can easily map our `fedjax.ClientTrainer` across multiple clients using `fedjax.train_multiple_clients`.

In [None]:
federated_train, federated_test = fedjax.datasets.emnist.load_data(
    only_digits=False)
model = fedjax.models.emnist.create_dense_model(
    only_digits=False, hidden_units=100)
init_params = model.init_params(rng)
optimizer = fedjax.get_optimizer(fedjax.OptimizerName.SGD, learning_rate=0.1)

client_trainer = SimpleClientTrainer(model, optimizer)
init_client_trainer_state = client_trainer.init_state(init_params)

# client_outputs contains updated parameters and weight per client.
# In this case, client_outputs is SimpleClientTrainerState, but this is
# not always guaranteed to be true. It depends on the implementation of
# ClientTrainer.
client_outputs = fedjax.train_multiple_clients(
    federated_train,
    federated_train.client_ids[:3],
    client_trainer,
    init_client_trainer_state,
    fedjax.PRNGSequence(0),
    fedjax.ClientDataHParams(batch_size=10))

client_outputs

Downloading data from https://storage.googleapis.com/tff-datasets-public/fed_emnist.tar.bz2




<generator object train_multiple_clients at 0x7fbb5316bc50>

Above, `client_outputs` is deliberately a Python generator. This avoids having to fit `len(client_ids)` copies of the model parameters and optimizer state in memory at once, which becomes a problem for larger models and for experiments with many clients per federated training round. Generators give us a convenient built-in solution.

If you are unfamiliar with Python generators, please see https://docs.python.org/3/reference/expressions.html#yieldexpr. To apply any post processing on these client outputs, we recommend using Python's built-in [`map`](https://docs.python.org/3/library/functions.html#map).

In [None]:
client_weights = map(lambda co: co.weight, client_outputs)
# We call list() to consume the generator.
print(list(client_weights))
# client_outputs (and therefore client_weights) is now exhausted.
print(list(client_weights))

[DeviceArray(344., dtype=float32), DeviceArray(372., dtype=float32), DeviceArray(316., dtype=float32)]
[]
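
Since the generator can only be traversed once, computing several quantities from the same client outputs is best done in a single pass. A minimal sketch follows, using a hypothetical stand-in generator `fake_client_outputs` that only mimics the lazy behavior of `fedjax.train_multiple_clients`; the weights are the ones printed above.

```python
import collections

ClientOutput = collections.namedtuple('ClientOutput', ['params', 'weight'])


def fake_client_outputs(weights):
  """Hypothetical stand-in that lazily yields one output per client."""
  for w in weights:
    yield ClientOutput(params=[0.0], weight=w)


# Accumulate everything needed in one traversal of the generator.
num_clients = 0
total_weight = 0.0
max_weight = float('-inf')
for client_output in fake_client_outputs([344.0, 372.0, 316.0]):
  num_clients += 1
  total_weight += client_output.weight
  max_weight = max(max_weight, client_output.weight)

print(num_clients, total_weight, max_weight)  # 3 1032.0 372.0
```

Only one client output is held in memory at a time, which preserves the benefit of using a generator in the first place.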


### Server aggregation

This section describes how to aggregate multiple client outputs into a single output on the server. A helpful analogy is to think of client training as the "map" in "mapreduce" and server aggregation as the "reduce".

FedJAX provides a few useful utilities for common aggregation strategies. For example, `fedjax.tree_mean` takes an iterator of pytrees and associated weights and returns a weighted average of the pytrees with the same structure.

We also use a few JAX utilities for working with [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) like [`jax.tree_util.tree_multimap`](https://jax.readthedocs.io/en/latest/jax.tree_util.html#jax.tree_util.tree_multimap).
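
To make the behavior concrete, here is a plain-Python sketch of both ideas on nested dicts of floats. The helpers `tree_map2` and `weighted_tree_mean` are hypothetical names written for this example; they are not the FedJAX or JAX implementations, which operate on arbitrary pytrees of arrays.

```python
def tree_map2(fn, tree_a, tree_b):
  """Applies fn leaf-wise to two nested dicts with identical structure."""
  if isinstance(tree_a, dict):
    return {k: tree_map2(fn, tree_a[k], tree_b[k]) for k in tree_a}
  return fn(tree_a, tree_b)


def weighted_tree_mean(trees_and_weights):
  """Streams over (tree, weight) pairs, keeping only a running sum in memory."""
  running_sum, total_weight = None, 0.0
  for tree, weight in trees_and_weights:
    scaled = tree_map2(lambda leaf, _: leaf * weight, tree, tree)
    running_sum = scaled if running_sum is None else tree_map2(
        lambda a, b: a + b, running_sum, scaled)
    total_weight += weight
  return tree_map2(lambda leaf, _: leaf / total_weight, running_sum, running_sum)


init = {'linear': {'w': 1.0, 'b': 0.0}}
client_params = [({'linear': {'w': 0.0, 'b': 1.0}}, 1.0),
                 ({'linear': {'w': 0.5, 'b': 2.0}}, 3.0)]

# Lazily compute (delta, weight) pairs, then average them in a single pass.
deltas = ((tree_map2(lambda a, b: a - b, init, p), w) for p, w in client_params)
mean_delta = weighted_tree_mean(deltas)
print(mean_delta)  # {'linear': {'w': 0.625, 'b': -1.75}}
```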

In [None]:
client_outputs = fedjax.train_multiple_clients(
    federated_train, federated_train.client_ids[:3], client_trainer,
    init_client_trainer_state, fedjax.PRNGSequence(0),
    fedjax.ClientDataHParams(batch_size=10))


# Weighted average of param delta across clients.
def get_delta_params_and_weight(client_output):
  delta_params = jax.tree_util.tree_multimap(lambda a, b: a - b, init_params,
                                             client_output.params)
  return delta_params, client_output.weight


delta_params_and_weight = map(get_delta_params_and_weight, client_outputs)
delta_params = fedjax.tree_mean(delta_params_and_weight)
delta_params.keys()

KeysOnlyKeysView(['linear', 'linear_1', 'linear_2'])

### Putting it all together

Now that we've defined how to coordinate training across clients and how to aggregate multiple client outputs, we can put all of the pieces together in a single `fedjax.FederatedAlgorithm`.


In [None]:
FedAvgState = collections.namedtuple('FedAvgState',
                                     ['params', 'server_opt_state'])


class FedAvg(fedjax.FederatedAlgorithm):
  """Simple federated averaging algorithm."""

  def __init__(self, federated_data, model, client_optimizer, server_optimizer,
               client_data_hparams, rng_seq):
    self._federated_data = federated_data
    self._model = model
    self._client_optimizer = client_optimizer
    self._server_optimizer = server_optimizer
    self._client_data_hparams = client_data_hparams
    self._rng_seq = rng_seq
    self._client_trainer = SimpleClientTrainer(model, client_optimizer)

  @property
  def federated_data(self):
    return self._federated_data

  @property
  def model(self):
    return self._model

  def init_state(self):
    params = self._model.init_params(next(self._rng_seq))
    server_opt_state = self._server_optimizer.init_fn(params)
    return FedAvgState(params, server_opt_state)

  def run_round(self, state, client_ids):
    """Runs one round of federated averaging."""
    # Train model per client.
    client_outputs = fedjax.train_multiple_clients(
        self.federated_data, client_ids, self._client_trainer,
        self._client_trainer.init_state(state.params), self._rng_seq,
        self._client_data_hparams)

    # Weighted average of param delta across clients.
    def get_delta_params_and_weight(client_output):
      delta_params = jax.tree_util.tree_multimap(
          lambda a, b: a - b, state.params, client_output.params)
      return delta_params, client_output.weight

    delta_params_and_weight = map(get_delta_params_and_weight, client_outputs)
    delta_params = fedjax.tree_mean(delta_params_and_weight)

    # Server state update.
    updates, server_opt_state = self._server_optimizer.update_fn(
        delta_params, state.server_opt_state)
    params = self._server_optimizer.apply_updates(state.params, updates)
    return FedAvgState(params, server_opt_state)
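
One detail worth spelling out is why `delta_params` is computed as the server parameters minus the client parameters: the averaged delta then plays the role of a gradient for the server optimizer. Under the assumption of plain SGD on the server with a learning rate of 1.0 (not the momentum optimizer used later in this notebook), the server update reduces to the weighted average of the client parameters. The scalar sketch below checks this; all names and numbers are illustrative.

```python
server_param = 2.0
# (client parameter after local training, number of examples)
client_results = [(1.0, 10.0), (4.0, 30.0)]
total_weight = sum(w for _, w in client_results)

# Weighted average of deltas, where delta = server_param - client_param.
mean_delta = sum(
    (server_param - p) * w for p, w in client_results) / total_weight

# Plain SGD server step with learning rate 1.0.
server_lr = 1.0
new_server_param = server_param - server_lr * mean_delta

# Direct weighted average of the client parameters.
weighted_avg = sum(p * w for p, w in client_results) / total_weight

print(new_server_param, weighted_avg)  # 3.25 3.25
```

With other server optimizers or learning rates, the same averaged delta is simply fed through a different update rule.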

To run our simulation, we could use `fedjax.training.run_federated_experiment`. For the sake of instruction, however, we'll write the experiment logic from scratch here.

Running multiple federated training rounds is as simple as calling `run_round` inside a for loop. For evaluation, we can make use of some FedJAX utilities, such as:
* `fedjax.evaluate_multiple_clients`: Produces generator of evaluation metrics per client
* `fedjax.aggregate_metrics`: Aggregates generator of evaluation metrics into a single summary

In [None]:
# Set up fedjax.FederatedAlgorithm.
client_optimizer = fedjax.get_optimizer(
    fedjax.OptimizerName.SGD, learning_rate=0.1)
server_optimizer = fedjax.get_optimizer(
    fedjax.OptimizerName.MOMENTUM, learning_rate=1.0, momentum=0.9)
client_data_hparams = fedjax.ClientDataHParams(batch_size=10)
rng_seq = fedjax.PRNGSequence(0)
federated_averaging = FedAvg(federated_train, model, client_optimizer,
                             server_optimizer, client_data_hparams, rng_seq)

state = federated_averaging.init_state()

for i in range(10):
  client_ids = federated_train.client_ids[:3]
  state = federated_averaging.run_round(state, client_ids)
  # Do any post processing or evaluation you'd like on output state.
  test_metrics = fedjax.aggregate_metrics(
      fedjax.evaluate_multiple_clients(federated_test, client_ids, model,
                                       state.params, client_data_hparams))
  loss = test_metrics['loss']
  accuracy = test_metrics['accuracy']
  print(f'round {i}: loss = {loss} accuracy = {accuracy}')



round 0: loss = 3.9508423805236816 accuracy = 0.05982906371355057
round 1: loss = 4.171614170074463 accuracy = 0.05982906371355057
round 2: loss = 5.593493938446045 accuracy = 0.08547008782625198
round 3: loss = 3.534285068511963 accuracy = 0.1367521435022354
round 4: loss = 3.030869483947754 accuracy = 0.29914531111717224
round 5: loss = 2.7541251182556152 accuracy = 0.3162393271923065
round 6: loss = 2.2167229652404785 accuracy = 0.4273504614830017
round 7: loss = 1.9401822090148926 accuracy = 0.5213675498962402
round 8: loss = 1.8677700757980347 accuracy = 0.5726495981216431
round 9: loss = 1.925694465637207 accuracy = 0.5128205418586731
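
The loop above trains on the same three clients every round, which keeps the example small. A simulation would more typically draw a different subset of clients per round. A minimal sketch using Python's `random` module is shown below; the client ids are made up, and the call to `run_round` is left as a comment because it depends on the objects defined earlier in this notebook.

```python
import random

all_client_ids = [f'client_{i:03d}' for i in range(100)]  # Made-up ids.
sampler = random.Random(0)  # Seeded so that runs are reproducible.

for round_num in range(3):
  # Sample 3 distinct clients without replacement for this round.
  round_client_ids = sampler.sample(all_client_ids, k=3)
  # state = federated_averaging.run_round(state, round_client_ids)
  print(round_num, round_client_ids)
```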
