# Importing packages

In [1]:
import fedjax
import jax
import jax.numpy as jnp
import PLM_computation
import FedMix_computation_general
from grid_search_general import FedMixGrid, grid_search
from Shakespeare_custom import shakespeare_load_gd_data
import itertools

In [2]:
from matplotlib import pyplot as plt

In [3]:
import pickle

# Model setup

In [4]:
# fedjax's Shakespeare LSTM model; shared by the loss functions, the PLM and
# FedMix computations, and the FedAvg baseline below.
model = fedjax.models.shakespeare.create_lstm_model()

In [5]:
def loss(params, batch, rng):
    """Mean training loss of `model` on one batch.

    Args:
      params: model parameters to evaluate.
      batch: a batch of examples accepted by `model`.
      rng: PRNG key forwarded to `model.apply_for_train` (used for dropout
        during training).

    Returns:
      Scalar mean of the per-example `model.train_loss` values.
    """
    per_example = model.train_loss(batch, model.apply_for_train(params, batch, rng))
    return jnp.mean(per_example)

In [6]:
def loss_for_eval(params, batch):
    """Deterministic counterpart of `loss` built on `model.apply_for_eval`.

    Takes no rng, so it can be differentiated and evaluated without a PRNG key.

    Args:
      params: model parameters to evaluate.
      batch: a batch of examples accepted by `model`.

    Returns:
      Scalar mean of the per-example `model.train_loss` values.
    """
    logits = model.apply_for_eval(params, batch)
    return jnp.mean(model.train_loss(batch, logits))

In [7]:
# Jitted gradient of the training loss w.r.t. `params` (argument 0).
grad_fn = jax.jit(jax.grad(loss, argnums=0))

In [8]:
# Jitted gradient of the rng-free eval loss w.r.t. `params` (argument 0).
grad_fn_eval = jax.jit(jax.grad(loss_for_eval, argnums=0))

# Grid search setup

## Constants

In [9]:
# Data location and the client subset used for the hyperparameter search.
CACHE_DIR = '../data/'
NUM_CLIENTS_GRID_SEARCH = 715
TRAIN_VALIDATION_SPLIT = 0.8

# PLM computation settings.
NUM_CLIENTS_PER_PLM_ROUND = 5
PLM_NUM_EPOCHS = 25

# FedMix settings: algorithm names passed to FedMixComputationParams,
# clients per round, and rounds used during the grid search.
FEDMIX_ALGORITHM = 'sgd'
CLIENT_ALGORITHM = 'sgd'
NUM_CLIENTS_PER_FEDMIX_ROUND = 10
FEDMIX_NUM_ROUNDS = 500

## Datasets and parameters

In [10]:
# Train/validation federated datasets for the grid search, split according to
# TRAIN_VALIDATION_SPLIT (helper from the local Shakespeare_custom module).
train_fd, validation_fd = shakespeare_load_gd_data(
    cache_dir=CACHE_DIR, train_val_split=TRAIN_VALIDATION_SPLIT)

Reusing cached file '../data/shakespeare_train.sqlite'


In [11]:
# The first NUM_CLIENTS_GRID_SEARCH client ids of the training data; the grid
# search is restricted to these clients. `set` consumes the islice directly,
# so no intermediate list is needed.
client_ids = set(itertools.islice(train_fd.client_ids(), NUM_CLIENTS_GRID_SEARCH))

In [12]:
# Restrict both splits to the selected grid-search clients.
train_fd, validation_fd = (
    fedjax.SubsetFederatedData(train_fd, client_ids),
    fedjax.SubsetFederatedData(validation_fd, client_ids),
)

In [13]:
# Initial parameters for the PLM computation (PRNG seed 0).
plm_init_params = model.init(jax.random.PRNGKey(0))

In [14]:
# Fixed (non-searched) PLM settings: initial params and NUM_CLIENTS_PER_PLM_ROUND.
plm_comp_params = PLM_computation.PLMComputationProcessParams(
    plm_init_params, NUM_CLIENTS_PER_PLM_ROUND)

In [15]:
# Separate initialization for the FedMix model (PRNG seed 20).
fedmix_init_params = model.init(jax.random.PRNGKey(20))

In [16]:
# Fixed FedMix settings for the grid search: algorithm names, initial params,
# and FEDMIX_NUM_ROUNDS rounds.
fedmix_comp_params = FedMix_computation_general.FedMixComputationParams(
    FEDMIX_ALGORITHM, CLIENT_ALGORITHM, fedmix_init_params, FEDMIX_NUM_ROUNDS)

In [17]:
# Per-client alpha passed to grid_search and used for every client in
# alpha_dict; encoded in result file names as int(10 * alpha).
# NOTE(review): presumably the FLIX/FedMix personalization weight — confirm.
alpha = 0.7

## Grid

In [18]:
# Earlier grid, kept for reference only (not executed): learning rates
# 1e-5 .. 1 on every axis, with larger batch sizes.
# fedmix_lrs = 10**jnp.arange(-5., 0.5, 1)
# fedmix_batch_sizes = [20, 50, 100, 200]
# plm_lrs = 10**jnp.arange(-5., 0.5, 1)
# plm_batch_sizes = [10, 20, 50, 100]
# client_lrs = [0.01]

In [19]:
# Earlier grid, kept for reference only (not executed): PLM learning rates
# capped at 1e-2 and client learning rates swept over 1e-5 .. 1.
# fedmix_lrs = 10**jnp.arange(-5., 0.5, 1)
# fedmix_batch_sizes = [20, 50, 100, 200]
# plm_lrs = 10**jnp.arange(-5., -1.5, 1)
# plm_batch_sizes = [10, 20, 50, 100]
# client_lrs = 10**jnp.arange(-5., 0.5, 1)

In [20]:
# Active grid. All three learning-rate axes share the five log-spaced values
# 10**-1, 10**-0.5, 1, 10**0.5, 10 (JAX arrays are immutable, so sharing is
# safe). The batch-size lists are separate copies so neither can be mutated
# through the other.
fedmix_lrs = plm_lrs = client_lrs = 10 ** jnp.arange(-1, 1.1, 0.5)
fedmix_batch_sizes = [1, 4, 10, 20]
plm_batch_sizes = list(fedmix_batch_sizes)

In [21]:
# Bundle the search axes. Positional order matters: fedmix lrs, plm lrs,
# client lrs, fedmix batch sizes, plm batch sizes.
grid = FedMixGrid(
    fedmix_lrs, plm_lrs, client_lrs, fedmix_batch_sizes, plm_batch_sizes)

# Grid search

In [22]:
# Results-table path passed to grid_search and reloaded below with jnp.load;
# the suffix is int(10 * alpha), e.g. 7 for alpha = 0.7.
SAVE_FILE = '../results/fedavg_fedmix_Shakespeare_{}_gd.npy'.format(
    int(10 * alpha))

In [23]:
# Show the resolved results path.
SAVE_FILE

'../results/fedavg_fedmix_Shakespeare_7_gd.npy'

In [None]:
# Expensive step: evaluate every point of `grid`, scored by
# 'accuracy_in_vocab'. SAVE_FILE is passed so the table can be reloaded below
# without re-running the search.
table = grid_search(
    train_fd,
    validation_fd,
    grad_fn,
    grad_fn_eval,
    model,
    alpha,
    plm_comp_params,
    fedmix_comp_params,
    grid,
    PLM_NUM_EPOCHS,
    NUM_CLIENTS_PER_FEDMIX_ROUND,
    SAVE_FILE,
    grid_metrics='accuracy_in_vocab',
)

PLM computation: num_epochs = 25, lr = 0.10000000149011612, b_size = 1
Round 13 / 143

In [None]:
# Inspect the raw grid-search results table.
table

In [None]:
# Reload the table from disk so the cells below work without re-running the search.
table = jnp.load(SAVE_FILE)

In [None]:
# Multi-dimensional index of the best grid cell. `nanargmax` skips NaN entries
# (e.g. a configuration whose metric came out as NaN); plain `argmax` would
# return the position of the first NaN as the "best" cell. Fail loudly if no
# cell holds a real number, since there is then no best configuration.
assert not bool(jnp.all(jnp.isnan(table))), 'every grid-search entry is NaN'
best_ind = jnp.unravel_index(jnp.nanargmax(table), table.shape)

In [None]:
# Best score found by the grid search (grid_metrics='accuracy_in_vocab').
table[best_ind]

In [None]:
# Map the best index back to hyperparameter values.
# NOTE(review): this assumes the table axes are ordered (plm batch size,
# plm lr, fedmix batch size, fedmix lr, client lr), which differs from the
# argument order given to FedMixGrid — confirm against grid_search's layout.
plm_batch_size = plm_batch_sizes[best_ind[0]]
plm_lr = plm_lrs[best_ind[1]]
fedmix_batch_size = fedmix_batch_sizes[best_ind[2]]
fedmix_lr = fedmix_lrs[best_ind[3]]
client_lr = client_lrs[best_ind[4]]

# FedMix: final run with the best grid-search hyperparameters

In [None]:
# Rounds for the final FedMix run (the grid search used FEDMIX_NUM_ROUNDS).
num_rounds = 3000

In [None]:
# Full Shakespeare train/test split for the final runs. This rebinds
# `train_fd`, replacing the grid-search client subset. Uses the shared
# CACHE_DIR constant instead of repeating the path.
train_fd, test_fd = fedjax.datasets.shakespeare.load_data(cache_dir=CACHE_DIR)

In [None]:
# PLM hyperparameters selected by the grid search.
plm_comp_hparams = PLM_computation.PLMComputationHParams(PLM_NUM_EPOCHS,
                                                         plm_lr,
                                                         plm_batch_size)

In [None]:
# Run the PLM computation on the full training set; the result is pickled
# below so later sessions can reload it instead of recomputing.
PLM_dict = PLM_computation.plm_computation(train_fd,
                                           grad_fn,
                                           plm_comp_hparams,
                                           plm_comp_params)

In [None]:
# Pickle path for the Shakespeare PLM results, tagged with the selected PLM
# batch-size and learning-rate grid indices. (Naming it EMNIST would overwrite
# results from the EMNIST experiments.) The jax index scalars are cast to int
# so the file name is plain digits.
save_file = '../results/PLM_Shakespeare_{}_{}.pickle'.format(
    int(best_ind[0]), int(best_ind[1]))

In [None]:
# Show the resolved PLM pickle path.
save_file

In [None]:
# Persist the PLM results so the PLM computation can be skipped next time.
with open(save_file, mode='wb') as out_file:
    pickle.dump(PLM_dict, out_file)

In [None]:
# Reload the pickled PLM results (only unpickle files this notebook wrote).
with open(save_file, mode='rb') as in_file:
    PLM_dict = pickle.load(in_file)

In [None]:
# Confirm the alpha used for the final run.
alpha

In [None]:
# Same alpha for every client of the full training set. A comprehension
# replaces the manual loop and does not leak the loop variable into the
# notebook namespace.
alpha_dict = {cid: alpha for cid in train_fd.client_ids()}

In [None]:
# Number of clients that received an alpha.
len(alpha_dict)

In [None]:
# FedMix hyperparameters selected by the grid search.
fedmix_hparams = FedMix_computation_general.FedMixHParams(
    fedmix_lr, client_lr, NUM_CLIENTS_PER_FEDMIX_ROUND, fedmix_batch_size)

In [None]:
# Selected FedMix batch size.
fedmix_batch_size

In [None]:
# Selected FedMix learning rate.
fedmix_lr

In [None]:
# Selected client learning rate.
client_lr

In [None]:
# Rebuild the FedMix settings for the longer final run (num_rounds instead of
# FEDMIX_NUM_ROUNDS), reusing the same initial parameters.
fedmix_comp_params = FedMix_computation_general.FedMixComputationParams(
    FEDMIX_ALGORITHM, CLIENT_ALGORITHM, fedmix_init_params, num_rounds)

In [None]:
# Final FedMix run on the full data, evaluated on test_fd; only the collected
# statistics are kept.
# NOTE(review): the trailing 100 is presumably the evaluation interval in
# rounds — confirm against fedmix_computation_with_statistics.
_, stats = FedMix_computation_general.fedmix_computation_with_statistics(
    train_fd, test_fd, grad_fn, grad_fn_eval, model, PLM_dict, alpha_dict,
    fedmix_hparams, fedmix_comp_params, 100)

In [None]:
# Pickle path for the Shakespeare FedMix/FLIX statistics, tagged with
# int(10 * alpha). (Naming it EMNIST would overwrite results from the EMNIST
# experiments.)
save_file = '../results/Shakespeare_FLIX_fedavg_{}.pickle'.format(int(10 * alpha))

In [None]:
# Show the resolved statistics pickle path.
save_file

In [None]:
# Persist the FedMix statistics.
with open(save_file, mode='wb') as out_file:
    pickle.dump(stats, out_file)

In [None]:
# Statistics from the last recorded evaluation.
stats[-1]

In [None]:
# Reload the pickled FedMix statistics (only unpickle files this notebook wrote).
with open(save_file, mode='rb') as in_file:
    stats = pickle.load(in_file)

# FedAvg

In [None]:
# FedAvg baseline: SGD with learning rate 1 on both server and client.
# NOTE(review): unlike FedMix, these learning rates are fixed, not tuned.
server_optimizer = fedjax.optimizers.sgd(learning_rate=1)
client_optimizer = fedjax.optimizers.sgd(learning_rate=1)
# Hyperparameters for preparing each client's local training data.
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=4)
algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn, client_optimizer, server_optimizer, client_batch_hparams)

# Fresh model initialization (PRNG seed 17) and the matching server state.
init_params = model.init(jax.random.PRNGKey(17))
server_state = algorithm.init(init_params)

In [None]:
# Draws the FedAvg clients each round. Uses the same per-round client count as
# FedMix (NUM_CLIENTS_PER_FEDMIX_ROUND) so the comparison stays like-for-like.
train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    fd=train_fd, num_clients=NUM_CLIENTS_PER_FEDMIX_ROUND, seed=0)

In [None]:
# Test accuracy recorded at each evaluation of the FedAvg loop below.
fedavg_test_acc_progress = []

In [None]:
# Number of FedAvg rounds.
max_rounds = 1200

In [None]:
# Use the 'pmap' backend for fedjax's per-client computations.
fedjax.set_for_each_client_backend('pmap')

In [None]:
# FedAvg training loop. Every EVAL_EVERY rounds the current model is evaluated
# on the full test set and its accuracy is appended to fedavg_test_acc_progress.
# NOTE: not idempotent — re-running this cell continues from the current
# `server_state` and keeps appending; re-run the setup cells above first.
EVAL_EVERY = 100
EVAL_METRIC = 'accuracy_in_vocab'

# The test client datasets do not change between evaluations, so collect them
# once. The padded batch iterator is rebuilt per evaluation because it is
# consumed by evaluate_model.
test_eval_datasets = [cds for _, cds in test_fd.clients()]

for round_num in range(1, max_rounds + 1):
    # Draw this round's clients from the sampler configured above.
    clients = train_client_sampler.sample()
    # Run one round of training on the sampled clients.
    server_state, client_diagnostics = algorithm.apply(server_state, clients)
    print(f'[round {round_num}]', end='\r')
    # Uncomment to inspect each client's model-update diagnostics.
    # print(f'[round {round_num}] client_diagnostics={client_diagnostics}')

    if round_num % EVAL_EVERY == 0:
        test_eval_batches = fedjax.padded_batch_client_datasets(
            test_eval_datasets, batch_size=256)
        test_metrics = fedjax.evaluate_model(
            model, server_state.params, test_eval_batches)
        fedavg_test_acc_progress.append(test_metrics[EVAL_METRIC])
        print('Test accuracy = {}'.format(test_metrics[EVAL_METRIC]))

In [None]:
# Number of recorded FedAvg evaluations. A fresh run of the loop above records
# one every 100 rounds, i.e. max_rounds // 100 entries; a larger count means the
# loop was re-run without resetting fedavg_test_acc_progress.
len(fedavg_test_acc_progress)

In [None]:
# Pickle path for the FedAvg test-accuracy curve.
save_file = '../results/test_acc_fedavg_shakespeare.pickle'

In [None]:
# Persist the FedAvg accuracy curve for the plot below.
with open(save_file, mode='wb') as out_file:
    pickle.dump(fedavg_test_acc_progress, out_file)

In [None]:
# Reload the FedAvg accuracy curve (only unpickle files this notebook wrote).
with open(save_file, mode='rb') as in_file:
    fedavg_test_acc_progress = pickle.load(in_file)

In [None]:
# Final FedAvg test accuracy.
fedavg_test_acc_progress[-1]

# Plots

In [None]:
# Disabled: FedMix accuracies from `stats`, for the commented-out FedMix curve
# in the plot below.
# accs = [stat['accuracy'] for stat in stats]

In [None]:
# FedAvg test accuracy vs. training round on Shakespeare. The x positions are
# derived from the number of recorded evaluations (one every 100 rounds in the
# loop above), so changing `max_rounds` no longer breaks the plot the way a
# hard-coded point count did.
round_nums = 100 * jnp.arange(1, len(fedavg_test_acc_progress) + 1)

fig, ax = plt.subplots()
# NOTE(review): the FedMix run used num_rounds (3000) rounds, so `accs` likely
# needs its own x positions rather than `round_nums`.
# ax.plot(round_nums, accs, label='FedMix, alpha={}'.format(alpha))
ax.plot(round_nums, fedavg_test_acc_progress, label='FedAvg')
ax.set_xlim(left=0)
ax.set_ylabel('accuracy')
ax.set_xlabel('rounds')
ax.grid()
ax.set_title('Shakespeare')
ax.legend()
fig.tight_layout()
# fig.savefig('../results/plots/Shakespeare_{}.pdf'.format(int(10 * alpha)))