In [2]:
from absl import app

import fedjax
from fedjax.core import tree_util

import jax
import jax.numpy as jnp

import PLM
import itertools
import FedMix

In [3]:
# Load the federated EMNIST train and test splits. `only_digits=False`
# requests the full dataset rather than the digits-only subset.
train_fd, test_fd = fedjax.datasets.emnist.load_data(
    only_digits=False,
)

Reusing cached file '/home/gasanoe/.cache/fedjax/federated_emnist_train.sqlite'
Reusing cached file '/home/gasanoe/.cache/fedjax/federated_emnist_test.sqlite'


In [4]:
# Build the EMNIST convolutional model. `only_digits` must match the value
# used when loading the data so the model and the labels agree.
model = fedjax.models.emnist.create_conv_model(
    only_digits=False,
)

In [5]:
def loss(params, batch, rng):
    """Mean training loss of `model` on one batch.

    Args:
        params: Model parameters passed to `model.apply_for_train`.
        batch: Batch of examples passed to the model and to `model.train_loss`.
        rng: PRNG key forwarded to `model.apply_for_train` (used there for
            dropout during training).

    Returns:
        Scalar: the mean of the per-example training losses.
    """
    # Training-mode forward pass; the key drives the stochastic layers.
    predictions = model.apply_for_train(params, batch, rng)
    # `train_loss` yields one loss value per example; average them.
    return jnp.mean(model.train_loss(batch, predictions))

In [6]:
# JIT-compiled gradient of the training loss with respect to its first
# argument (`params`); called as grad_fn(params, batch, rng).
grad_fn = jax.jit(jax.grad(loss, argnums=0))

In [7]:
def loss_for_eval(params, batch):
    """Mean training loss of `model` on one batch, using the eval-mode forward pass.

    Unlike `loss`, this takes no PRNG key: predictions come from
    `model.apply_for_eval`. The loss itself is still `model.train_loss`.

    Args:
        params: Model parameters passed to `model.apply_for_eval`.
        batch: Batch of examples passed to the model and to `model.train_loss`.

    Returns:
        Scalar: the mean of the per-example losses.
    """
    predictions = model.apply_for_eval(params, batch)
    return jnp.mean(model.train_loss(batch, predictions))

In [8]:
# JIT-compiled gradient of the eval-mode loss with respect to its first
# argument (`params`); called as grad_fn_eval(params, batch).
grad_fn_eval = jax.jit(jax.grad(loss_for_eval, argnums=0))

In [None]:
# Plain SGD for the client-side updates; step size is 10^-1.5 (about 0.0316).
client_optimizer = fedjax.optimizers.sgd(learning_rate=10 ** -1.5)

# Hyperparameters for client local training dataset preparation.
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(
    batch_size=20,
    num_epochs=1000,
    drop_remainder=True,
)

# Build the algorithm from the project-local PLM module.
# NOTE(review): presumably this produces one locally trained model per
# client (later cells read `server_state.PLM[client_id]`) -- confirm in PLM.
algorithm = PLM.computing_plm(
    grad_fn,
    client_optimizer,
    client_batch_hparams,
)

In [None]:
# IPython-only help lookup (the trailing `?` is not plain Python). This is
# leftover interactive exploration: it produces no value used by later cells
# and can be dropped from a finished notebook.
fedjax.ShuffleRepeatBatchHParams?

In [None]:
# Initialize the model parameters from a fixed seed so runs are repeatable.
# (The algorithm's server state is created separately from these parameters.)
init_params = model.init(jax.random.PRNGKey(seed=17))

In [None]:
# Create the algorithm's initial server state from the freshly initialized
# model parameters. Re-run this cell to reset `server_state` before
# re-running the client loop.
server_state = algorithm.init(init_params)

In [None]:
# Base key; a fresh sub-key is split off for every client below.
rng = jax.random.PRNGKey(10)

# Only the first few clients yielded by `train_fd.clients()` are processed.
# Raise this to cover more of the dataset.
num_clients_to_process = 5

# Denominator for the progress line: the number of clients this loop will
# actually visit, not the size of the whole dataset. Computed once, outside
# the loop.
num_clients_total = min(num_clients_to_process, train_fd.num_clients())

# Stays 0 if the dataset yields no clients; otherwise ends at the number of
# clients processed.
num_client = 0

# NOTE: `server_state` is updated in place from its previous value, so
# re-running this cell continues from the current state rather than starting
# over -- re-create `server_state` first for a clean run.
for num_client, (cid, cds) in enumerate(
        itertools.islice(train_fd.clients(), num_clients_to_process), start=1):
    rng, use_rng = jax.random.split(rng)
    # One client per call: the algorithm receives a single-element list of
    # (client id, client dataset, PRNG key).
    server_state, _ = algorithm.apply(server_state, [(cid, cds, use_rng)])
    print('Client {} out of {} is processed.'.format(num_client, num_clients_total), end='\r')

In [None]:
# Gradient of the eval-mode loss for one client, evaluated at that client's
# entry in `server_state.PLM` over all of the client's training examples.
# NOTE(review): the id is hard-coded; the lookup assumes this client was among
# those processed by the loop that populated `server_state` -- confirm.
client_id = b'005fdad281234bc0:f0151_02'
grad_ = grad_fn_eval(
    server_state.PLM[client_id],
    train_fd.get_client(client_id).all_examples(),
)

In [None]:
# Display the L2 norm of the gradient computed for the selected client
# (the cell's last expression is its output).
tree_util.tree_l2_norm(grad_)