In [1]:
from absl import app

import fedjax
from fedjax.core import tree_util

import jax
import jax.numpy as jnp

import PLM
import itertools
import FedMix

In [2]:
# Load the federated Shakespeare train/test splits. On later runs fedjax
# reuses the locally cached SQLite files instead of fetching them again.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data()

Reusing cached file '/home/gasanoe/.cache/fedjax/shakespeare_train.sqlite'
Reusing cached file '/home/gasanoe/.cache/fedjax/shakespeare_test.sqlite'


In [3]:
# fedjax's predefined LSTM model for the Shakespeare dataset. `model` is used
# by the loss functions and by parameter initialization in the cells below.
model = fedjax.models.shakespeare.create_lstm_model()

In [4]:
def loss(params, batch, rng):
    """Mean training loss of `model` over one batch.

    Args:
      params: Model parameters to evaluate.
      batch: A batch of examples, as accepted by `model.apply_for_train`.
      rng: PRNG key forwarded to `model.apply_for_train` so dropout can be
        applied during training.

    Returns:
      Scalar mean of the per-example training losses (which have shape
      [batch_size] before averaging).
    """
    per_example = model.train_loss(
        batch, model.apply_for_train(params, batch, rng))
    return jnp.mean(per_example)

In [5]:
# jit-compiled gradient of the training `loss` with respect to `params`
# (argnums=0, i.e. the first positional argument; this is jax.grad's default).
grad_fn = jax.jit(jax.grad(loss, argnums=0))

In [6]:
def loss_for_eval(params, batch):
    """Mean loss of `model` over one batch, using the evaluation forward pass.

    Same as `loss`, except predictions come from `model.apply_for_eval`, which
    takes no PRNG key. (NOTE(review): presumably this means no dropout;
    confirm against fedjax.)

    Args:
      params: Model parameters to evaluate.
      batch: A batch of examples, as accepted by `model.apply_for_eval`.

    Returns:
      Scalar mean of `model.train_loss` over the batch.
    """
    return jnp.mean(
        model.train_loss(batch, model.apply_for_eval(params, batch)))

In [7]:
# jit-compiled gradient of `loss_for_eval` with respect to `params`
# (argnums=0 is jax.grad's default). NOTE(review): not passed to FedMix
# or used by the training loop below.
grad_fn_eval = jax.jit(jax.grad(loss_for_eval, argnums=0))

In [8]:
# Initialize the model parameters from a fixed seed (11) so runs are
# reproducible.
init_params = model.init(jax.random.PRNGKey(seed=11))

In [9]:
# Materialize the training client ids. `FederatedData.client_ids()` is typed
# to return an iterator. A one-shot iterator would be silently exhausted by
# the first loop over it, so re-running the next cell on its own (or reusing
# `client_ids` later) would see no clients at all.
client_ids = list(train_fd.client_ids())

In [10]:
# Per-client state handed to FedMix, keyed by client id:
#   PLM_dict[cid]    -> a zero tree with the same structure as `init_params`
#   alphas_dict[cid] -> the scalar `alpha` (presumably FedMix's mixing
#                       weight; confirm in FedMix)
# NOTE(review): this allocates one parameter-sized tree per client, so memory
# grows linearly with the number of training clients.
alpha = 0.5
PLM_dict = {
    cid: tree_util.tree_zeros_like(init_params) for cid in client_ids
}
# Reuse PLM_dict's keys so `client_ids` is iterated only once.
alphas_dict = dict.fromkeys(PLM_dict, alpha)

In [11]:
# Server-side Adam optimizer: learning rate 10^-2.5 (~3.16e-3), eps 10^-4.
server_optimizer = fedjax.optimizers.adam(
    learning_rate=10 ** -2.5,
    b1=0.9,
    b2=0.999,
    eps=10 ** -4,
)

In [12]:
# Client-side batching passed to FedMix: batches of 50 examples, 1 step.
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(
    num_steps=1,
    batch_size=50,
)

In [13]:
# Build the FedMix federated algorithm from the training gradient, server
# optimizer, client batching config, and the per-client dicts built above.
# NOTE(review): FedMix receives the dicts themselves (not copies); whether it
# mutates them is not visible here, so confirm in FedMix.
algorithm = FedMix.fedmix(
    grad_fn,
    server_optimizer,
    client_batch_hparams,
    PLM_dict,
    alphas_dict,
)

In [14]:
# Initial server state, built from the seeded initial parameters.
server_state = algorithm.init(init_params)

In [15]:
# Federated training: each round samples 10 training clients without
# replacement and runs one FedMix round on them.
# The sampler is re-created with seed=0 every time this cell runs, but
# `server_state` is NOT re-initialized here. Re-running the cell therefore
# continues training from the current state rather than starting over.
num_rounds = 2
train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    fd=train_fd, num_clients=10, seed=0)
for round_num in range(1, num_rounds + 1):
    clients = train_client_sampler.sample()
    # Run one round of training on the sampled clients.
    server_state, client_diagnostics = algorithm.apply(server_state, clients)
    # Per-client diagnostics, e.g. the l2 norm of each client's model update.
    print(f'[round {round_num}] client_diagnostics={client_diagnostics}')

debug
[round 1] client_diagnostics={b'fe62b57330b09ae2:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_BOLINGBROKE': {'delta_l2_norm': DeviceArray(0.12142139, dtype=float32)}, b'be654813264020f9:THE_TAMING_OF_THE_SHREW_MARCUS': {'delta_l2_norm': DeviceArray(0.08075441, dtype=float32)}, b'e16bfa29226db143:ALL_S_WELL_THAT_ENDS_WELL_MARCIUS': {'delta_l2_norm': DeviceArray(0.12572527, dtype=float32)}, b'8874bb1a3270b6ae:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_COMMONS': {'delta_l2_norm': DeviceArray(0.10309148, dtype=float32)}, b'93bae0966643b967:ALL_S_WELL_THAT_ENDS_WELL_GUARDSMAN': {'delta_l2_norm': DeviceArray(0.04885468, dtype=float32)}, b'a8636ae06a2e4d5f:ALL_S_WELL_THAT_ENDS_WELL_OCTAVIA': {'delta_l2_norm': DeviceArray(0.10725318, dtype=float32)}, b'9c4a003c160fbb25:ALL_S_WELL_THAT_ENDS_WELL_AGRIPPA': {'delta_l2_norm': DeviceArray(0.12040406, dtype=float32)}, b'adeabea0708893ba:THE_FIRST_PART_OF_KING_HENRY_THE_FOURTH_HERALD': {'delta_l2_norm': DeviceArray(0.07770044, dtype=float32)}, b'b6f359