In [1]:
from absl import app

import fedjax

import jax
import jax.numpy as jnp

In [2]:
train_fd, test_fd = fedjax.datasets.emnist.load_data(only_digits=False)

Reusing cached file '/home/gasanoe/.cache/fedjax/federated_emnist_train.sqlite'
Reusing cached file '/home/gasanoe/.cache/fedjax/federated_emnist_test.sqlite'


In [3]:
model = fedjax.models.emnist.create_conv_model(only_digits=False)

In [4]:
def loss(params, batch, rng):
    # `rng` used with `apply_for_train` to apply dropout during training.
    preds = model.apply_for_train(params, batch, rng)
    # Per example loss of shape [batch_size].
    example_loss = model.train_loss(batch, preds)
    return jnp.mean(example_loss)

In [5]:
grad_fn = jax.jit(jax.grad(loss))

In [6]:
client_optimizer = fedjax.optimizers.sgd(learning_rate=10**(-1.5))
server_optimizer = fedjax.optimizers.adam(learning_rate=10**(-2.5), b1=0.9, b2=0.999, eps=10**(-4))
# Hyperparameters for client local traing dataset preparation.
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=20)
algorithm = fedjax.algorithms.fed_avg.federated_averaging(grad_fn, client_optimizer,
                                          server_optimizer,
                                          client_batch_hparams)

In [7]:
# Initialize model parameters and algorithm server state.
init_params = model.init(jax.random.PRNGKey(17))

In [8]:
server_state = algorithm.init(init_params)

# Train and eval loop.
train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(fd=train_fd, num_clients=10, seed=0)

In [15]:
for round_num in range(1, 1501):
    # Sample 10 clients per round without replacement for training.
    clients = train_client_sampler.sample()
    # Run one round of training on sampled clients.
    server_state, client_diagnostics = algorithm.apply(server_state, clients)
    print(f'[round {round_num}]', end='\r')
    # Optionally print client diagnostics if curious about each client's model
    # update's l2 norm.
    # print(f'[round {round_num}] client_diagnostics={client_diagnostics}')

    if round_num % 10 == 0:
      # Periodically evaluate the trained server model parameters.
      # Read and combine clients' train and test datasets for evaluation.
        client_ids = [cid for cid, _, _ in clients]
        train_eval_datasets = [cds for _, cds in train_fd.get_clients(client_ids)]
        test_eval_datasets = [cds for _, cds in test_fd.get_clients(client_ids)]
        train_eval_batches = fedjax.padded_batch_client_datasets(
          train_eval_datasets, batch_size=256)
        test_eval_batches = fedjax.padded_batch_client_datasets(
          test_eval_datasets, batch_size=256)

      # Run evaluation metrics defined in `model.eval_metrics`.
        train_metrics = fedjax.evaluate_model(model, server_state.params,
                                            train_eval_batches)
        test_metrics = fedjax.evaluate_model(model, server_state.params, test_eval_batches)
        print('')
        print(f'[round {round_num}] train_metrics={train_metrics}')
        print(f'[round {round_num}] test_metrics={test_metrics}')

# Save final trained model parameters to file.
fedjax.serialization.save_state(server_state.params, '/tmp/params')

[round 10]
[round 10] train_metrics={'accuracy': DeviceArray(0.8026859, dtype=float32), 'loss': DeviceArray(0.6623511, dtype=float32)}
[round 10] test_metrics={'accuracy': DeviceArray(0.7955556, dtype=float32), 'loss': DeviceArray(0.6655238, dtype=float32)}
[round 20]
[round 20] train_metrics={'accuracy': DeviceArray(0.8492724, dtype=float32), 'loss': DeviceArray(0.52056676, dtype=float32)}
[round 20] test_metrics={'accuracy': DeviceArray(0.8475337, dtype=float32), 'loss': DeviceArray(0.46154043, dtype=float32)}
[round 30]
[round 30] train_metrics={'accuracy': DeviceArray(0.78894204, dtype=float32), 'loss': DeviceArray(0.6405203, dtype=float32)}
[round 30] test_metrics={'accuracy': DeviceArray(0.7579909, dtype=float32), 'loss': DeviceArray(0.61889136, dtype=float32)}
[round 40]
[round 40] train_metrics={'accuracy': DeviceArray(0.8332388, dtype=float32), 'loss': DeviceArray(0.5279828, dtype=float32)}
[round 40] test_metrics={'accuracy': DeviceArray(0.8, dtype=float32), 'loss': DeviceArr