# Importing packages

In [1]:
import fedjax
import jax
import jax.numpy as jnp
import PLM_computation
import FedMix_computation
from grid_search import FedMixGrid, grid_search
from EMNIST_custom import emnist_load_gd_data
import itertools

# Model setup

In [2]:
# Convolutional EMNIST model built with `only_digits=False`, i.e. over the
# full character set rather than digits only.
model = fedjax.models.emnist.create_conv_model(
    only_digits=False
)

In [3]:
def loss(params, batch, rng):
    """Mean training loss of `model` on `batch`.

    Args:
      params: Model parameters to evaluate.
      batch: Batch of examples, passed through unchanged to the model.
      rng: PRNG key forwarded to `model.apply_for_train` (used for dropout).

    Returns:
      Scalar mean of the per-example `model.train_loss` values
      (per-example loss has shape [batch_size]).
    """
    train_preds = model.apply_for_train(params, batch, rng)
    per_example = model.train_loss(batch, train_preds)
    return jnp.mean(per_example)

In [4]:
def loss_for_eval(params, batch):
    """Mean training loss of `model` on `batch`, using eval-mode application.

    Unlike `loss`, this calls `model.apply_for_eval` and therefore takes no
    PRNG key; predictions are still scored with `model.train_loss`.

    Args:
      params: Model parameters to evaluate.
      batch: Batch of examples, passed through unchanged to the model.

    Returns:
      Scalar mean of the per-example `model.train_loss` values.
    """
    eval_preds = model.apply_for_eval(params, batch)
    per_example = model.train_loss(batch, eval_preds)
    return jnp.mean(per_example)

In [5]:
# jit-compiled gradient of `loss` with respect to its first argument, `params`.
grad_fn = jax.jit(jax.grad(loss, argnums=0))

In [6]:
# jit-compiled gradient of `loss_for_eval` with respect to `params`
# (its first argument); needs no PRNG key.
grad_fn_eval = jax.jit(jax.grad(loss_for_eval, argnums=0))

# Grid search setup

## Constants

In [8]:
# Data location and how much of EMNIST the grid search uses.
CACHE_DIR = '../data/'            # cache_dir for emnist_load_gd_data
TRAIN_VALIDATION_SPLIT = 0.8      # train_val_split for emnist_load_gd_data
NUM_CLIENTS_GRID_SEARCH = 200     # number of clients kept for the search

# PLM stage settings.
NUM_CLIENTS_PER_PLM_ROUND = 5
PLM_NUM_EPOCHS = 100

# FedMix stage settings.
NUM_CLIENTS_PER_FEDMIX_ROUND = 10
FEDMIX_ALGORITHM = 'adam'
FEDMIX_NUM_ROUNDS = 500

## Datasets and parameters

In [9]:
# Load EMNIST over all classes as train/validation federated datasets, split
# per TRAIN_VALIDATION_SPLIT and cached under CACHE_DIR.
train_fd, validation_fd = emnist_load_gd_data(
    cache_dir=CACHE_DIR,
    only_digits=False,
    train_val_split=TRAIN_VALIDATION_SPLIT,
)

Reusing cached file '../data/federated_emnist_train.sqlite'


In [10]:
# Restrict the search to the first NUM_CLIENTS_GRID_SEARCH client ids, in the
# order `train_fd.client_ids()` yields them.
client_ids = set(itertools.islice(train_fd.client_ids(), NUM_CLIENTS_GRID_SEARCH))
# Guard against silently searching over fewer clients than configured.
assert len(client_ids) == NUM_CLIENTS_GRID_SEARCH, (
    f'expected {NUM_CLIENTS_GRID_SEARCH} clients, got {len(client_ids)}')

In [11]:
# Keep only the selected clients in both splits, so training and validation
# use the same set of client ids.
train_fd, validation_fd = (
    fedjax.SubsetFederatedData(train_fd, client_ids),
    fedjax.SubsetFederatedData(validation_fd, client_ids),
)

In [14]:
# Initial parameters for the PLM stage, from a fixed PRNG seed (0).
plm_init_params = model.init(jax.random.PRNGKey(seed=0))

In [15]:
# PLM computation settings: its initial parameters and the number of
# clients per PLM round.
plm_comp_params = PLM_computation.PLMComputationProcessParams(
    plm_init_params,
    NUM_CLIENTS_PER_PLM_ROUND,
)

In [16]:
# Initial parameters for the FedMix stage, from its own fixed seed (20),
# distinct from the PLM initialization.
fedmix_init_params = model.init(jax.random.PRNGKey(seed=20))

In [17]:
# FedMix computation settings: optimizer name, initial parameters and the
# number of FedMix rounds.
fedmix_comp_params = FedMix_computation.FedMixComputationParams(
    FEDMIX_ALGORITHM,
    fedmix_init_params,
    FEDMIX_NUM_ROUNDS,
)

In [18]:
# FedMix `alpha` hyperparameter, passed to `grid_search` and encoded (in
# tenths) in the results file name.
# NOTE(review): presumably the FedMix mixing weight — confirm in FedMix_computation.
alpha = 0.70

## Hyperparameter grid

In [22]:
# Learning rates: one per decade from 1e-5 up to 1e0, same values for both stages.
fedmix_lrs = jnp.power(10.0, jnp.arange(-5.0, 1.0))
plm_lrs = jnp.power(10.0, jnp.arange(-5.0, 1.0))

# Batch sizes to try for each stage.
fedmix_batch_sizes = [20, 50, 100, 200]
plm_batch_sizes = [10, 20, 50, 100]

In [23]:
# Bundle the learning-rate and batch-size candidates of both stages into the
# grid consumed by `grid_search`.
grid = FedMixGrid(
    fedmix_lrs,
    plm_lrs,
    fedmix_batch_sizes,
    plm_batch_sizes,
)

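# Grid search

Runs `grid_search` over every combination in `grid` and writes the resulting table to `SAVE_FILE`. This step is long-running.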

In [28]:
# Results are keyed by alpha in tenths (alpha = 0.7 -> 'EMNIST_7_gd.npy').
# `round` rather than `int`: float error in 10 * alpha must never truncate
# down to the wrong tenth. The assert stops two different alphas (e.g. 0.7
# and 0.75) from silently sharing, and overwriting, the same results file.
assert abs(10 * alpha - round(10 * alpha)) < 1e-9, (
    f'alpha={alpha} is not a multiple of 0.1; SAVE_FILE names would collide')
SAVE_FILE = '../results/EMNIST_{}_gd.npy'.format(round(10 * alpha))

In [30]:
# Echo the resolved results path; it is passed to `grid_search` below as its
# final (save-file) argument.
SAVE_FILE

'../results/EMNIST_7_gd.npy'

In [None]:
# Full hyperparameter search over `grid`; returns the result table and is
# given SAVE_FILE as its output path.
# NOTE: long-running — the captured output shows the run was interrupted
# during the PLM computation of the first configuration.
table = grid_search(
    # Data.
    train_fd,
    validation_fd,
    # Gradient functions and model.
    grad_fn,
    grad_fn_eval,
    model,
    alpha,
    # Per-stage computation settings and the search grid.
    plm_comp_params,
    fedmix_comp_params,
    grid,
    # Run length, FedMix cohort size and output path.
    PLM_NUM_EPOCHS,
    NUM_CLIENTS_PER_FEDMIX_ROUND,
    SAVE_FILE,
)

PLM computation: num_epochs = 100, lr = 9.999999747378752e-06, b_size = 10
Round 5 / 40

In [None]:
# Intentionally empty: a leftover scratch expression (`1`) was removed from this cell.

In [None]:
# Locate the best entry of `table` as an index tuple (one index per axis of
# `table`). `nanargmax` skips NaN entries, e.g. from runs that diverged at
# large learning rates; plain `argmax` would report the first NaN as the max.
# NOTE(review): assumes larger values are better (e.g. validation accuracy);
# switch to `jnp.nanargmin` if `table` stores losses.
jnp.unravel_index(jnp.nanargmax(table), table.shape)