# Importing packages

In [1]:
import sys
# NOTE(review): absolute, user-specific path. Presumably this is where the local
# modules imported below (PLM, FedMix, custom_utils) live; adjust it on other machines.
sys.path.append('/home/gasanoe/fedmix_on_fedjax/src/')

In [2]:
from absl import app
import pickle
from matplotlib import pyplot as plt

import fedjax
from fedjax.core import tree_util

import jax
import jax.numpy as jnp

import PLM
import itertools
import FedMix
from custom_utils import emnist_load_gd_data

from IPython.core.debugger import set_trace

# How to make FedMix work - ideas

<ul>
    <li> Find the best hyperparameters by running the whole experiment on 3k rounds </li>
    <li> Use a good starting point (initial parameters) </li>
</ul>

# Setting up the experiment

In [3]:
# Federated EMNIST train/test splits (only_digits=False), cached under ../data/.
train_fd, test_fd = fedjax.datasets.emnist.load_data(
    only_digits=False,
    cache_dir='../data/',
)

Reusing cached file '../data/federated_emnist_train.sqlite'
Reusing cached file '../data/federated_emnist_test.sqlite'


In [4]:
# Model built by fedjax's `create_conv_model`, with the same only_digits=False
# setting as the data loaded above.
model = fedjax.models.emnist.create_conv_model(only_digits=False)

In [5]:
# Project helper from custom_utils; returns a (train, validation) pair.
# NOTE(review): presumably a split of the EMNIST training data — confirm in custom_utils.
train_gd_fd, val_gd_fd = emnist_load_gd_data(
    only_digits=False,
    cache_dir='../data/',
)

Reusing cached file '../data/federated_emnist_train.sqlite'


In [None]:
# fedjax.training.set_tf_cpu_only()

In [None]:
# fedjax.set_for_each_client_backend('pmap')

In [8]:
def loss(params, batch, rng):
    """Mean training loss of the global `model` on one batch.

    Args:
        params: model parameters passed to `model.apply_for_train`.
        batch: batch of examples, passed both to the model and to
            `model.train_loss`.
        rng: PRNG key forwarded to `model.apply_for_train` (used for dropout
            during training).

    Returns:
        The mean over the per-example losses (shape [batch_size]).
    """
    train_preds = model.apply_for_train(params, batch, rng)
    return jnp.mean(model.train_loss(batch, train_preds))

In [9]:
# JIT-compiled gradient of `loss` with respect to its first argument (`params`).
grad_fn = jax.jit(jax.grad(loss))

In [10]:
def loss_for_eval(params, batch):
    """Mean loss of the global `model` on one batch, in evaluation mode.

    Same as `loss`, except that predictions come from `model.apply_for_eval`,
    which takes no rng.

    Args:
        params: model parameters passed to `model.apply_for_eval`.
        batch: batch of examples, passed both to the model and to
            `model.train_loss`.

    Returns:
        The mean over the per-example losses.
    """
    eval_preds = model.apply_for_eval(params, batch)
    return jnp.mean(model.train_loss(batch, eval_preds))

In [11]:
# JIT-compiled gradient of `loss_for_eval` with respect to `params` (no rng argument).
grad_fn_eval = jax.jit(jax.grad(loss_for_eval))

In [12]:
# Ids of the clients in the training federated dataset.
client_ids = train_fd.client_ids()

In [13]:
# PLM_dict starts empty here; it is filled from the pickle file in a later cell.
PLM_dict = {}
# Every training client gets the same mixing coefficient alpha.
alpha = 0.7
alphas_dict = {client_id: jnp.asarray(alpha) for client_id in client_ids}

In [14]:
# Pickle file from which the per-client PLM dictionary is read.
save_plm_file = '../results/PLM_EMNIST.pickle'

In [15]:
# Load the PLM dictionary; later cells index it by client id (PLM_dict[cid]).
# NOTE(review): pickle.load can execute arbitrary code — only open files you trust.
with open(save_plm_file, 'rb') as handle:
    PLM_dict = pickle.load(handle)

In [16]:
# Initialise once only to obtain the parameter structure, then zero every leaf;
# the zero tree is the starting value for the running mean of PLMs computed below.
init_params = tree_util.tree_zeros_like(model.init(jax.random.PRNGKey(11)))

In [17]:
def add_new(params, new_params, num):
    """Fold one more sample into a running mean of parameter pytrees.

    Args:
        params: pytree holding the mean of the `num` samples seen so far.
        new_params: pytree with the same structure as `params`; the new sample.
        num: number of samples already averaged into `params` (0 for the
            first sample, in which case `params` gets weight 0).

    Returns:
        A pytree equal to `params * num / (num + 1) + new_params / (num + 1)`,
        i.e. the mean over `num + 1` samples.
    """
    old_weight = float(num / (num + 1))
    new_weight = float(1 / (num + 1))
    # Single traversal of both trees instead of two scalings followed by an add.
    return jax.tree_util.tree_map(
        lambda old_leaf, new_leaf: old_leaf * old_weight + new_leaf * new_weight,
        params,
        new_params,
    )

In [18]:
# Running mean of the PLMs of (at most) the first 3000 clients in PLM_dict order.
# With num == 0 the previous init_params gets weight 0, so its zero values only
# provide the tree structure.
num = 0
for client_plm in itertools.islice(PLM_dict.values(), 3000):
    init_params = add_new(init_params, client_plm, num)
    num += 1

# GridSearch

In [26]:
# Grid for the FedMix search: learning rates 10**-3, 10**-2.5, ..., 10**0
# (half-decade steps) and three batch sizes.
server_lrs = 10**jnp.arange(-3, 0.5, 0.5)
fedmix_batch_sizes = [10, 20, 50]

In [27]:
# Same grids for client learning rates / PLM batch sizes.
# NOTE(review): neither name is referenced anywhere else in the visible notebook.
client_lrs = 10**jnp.arange(-3, 0.5, 0.5)
plm_batch_sizes = [10, 20, 50]

In [29]:
# Grid-search budget: clients sampled per round and rounds per configuration.
num_clients_per_round=10
max_rounds = 2000

In [None]:
# Validation accuracy for every (server learning rate, FedMix batch size) pair:
# rows follow `server_lrs`, columns follow `fedmix_batch_sizes`.
# (The names `lrs` / `batch_sizes` used previously are never defined in this notebook.)
GridSearch_table_fedmix = jnp.zeros(shape=(len(server_lrs), len(fedmix_batch_sizes)))

In [None]:
# Batching used when evaluating on client data (passed to FedMix.evaluate_model).
client_batch_hparams_eval = fedjax.BatchHParams(batch_size=256)

In [None]:
# Number of training clients held out for validation: the last ones in
# train_fd.clients() iteration order (see the islice calls below).
num_clients_for_validation = 400

In [None]:
# Split the training clients by iteration order: everything except the last
# `num_clients_for_validation` clients is used for training, the rest for validation.
# Both are one-shot islice iterators; the grid-search loop below rebuilds them
# for every configuration.
train_data = itertools.islice(train_fd.clients(), train_fd.num_clients() - num_clients_for_validation)
validation_data = itertools.islice(train_fd.clients(), train_fd.num_clients() - num_clients_for_validation, train_fd.num_clients())

In [None]:
# Grid search over the server (Adam) learning rate and the FedMix client batch size.
# For every pair: train FedMix for `max_rounds` rounds on the training split, then
# store the accuracy on the held-out validation clients in GridSearch_table_fedmix.
# The grids are `server_lrs` / `fedmix_batch_sizes`; the names `lrs` / `batch_sizes`
# used previously are never defined in this notebook.
num_train_clients = train_fd.num_clients() - num_clients_for_validation
for lr_id, lr in enumerate(server_lrs):
    print('Learning rate = {}'.format(lr))
    for b_id, batch_size in enumerate(fedmix_batch_sizes):
        print('Batch size = {}'.format(batch_size))
        # islice iterators are single-use, so the split is rebuilt per configuration.
        train_data = itertools.islice(train_fd.clients(), num_train_clients)
        validation_data = itertools.islice(train_fd.clients(), num_train_clients, train_fd.num_clients())
        server_optimizer = fedjax.optimizers.adam(learning_rate=lr, b1=0.9, b2=0.999, eps=10**(-4))
        train_client_sampler = fedjax.client_samplers.UniformShuffledClientSampler(itertools.cycle(train_data), num_clients_per_round)
        # One sample of size num_clients_for_validation covers every validation client.
        validation_client_sampler = fedjax.client_samplers.UniformShuffledClientSampler(itertools.cycle(validation_data), num_clients_for_validation)
        client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=batch_size, num_steps=1)
        algorithm = FedMix.fedmix(grad_fn, server_optimizer, client_batch_hparams, PLM_dict, alphas_dict)
        server_state = algorithm.init(init_params)
        for round_num in range(max_rounds):
            print('Round {} / {}'.format(round_num + 1, max_rounds), end='\r')
            clients = train_client_sampler.sample()
            server_state, _ = algorithm.apply(server_state, clients)
        clients = validation_client_sampler.sample()
        # FedMix.evaluate_model is given (alpha, PLM, client dataset) triples.
        client_data_for_evaluation = [(alphas_dict[cid], PLM_dict[cid], cds) for cid, cds, _ in clients]
        grid_search_metrics = FedMix.evaluate_model(model, server_state.params, client_data_for_evaluation, client_batch_hparams_eval)
        # These are the held-out validation clients, not the test set.
        print('\n Accuracy on validation clients is {}'.format(grid_search_metrics['accuracy']))
        GridSearch_table_fedmix = GridSearch_table_fedmix.at[lr_id, b_id].set(grid_search_metrics['accuracy'])

In [None]:
import pandas as pd

In [None]:
# Grid-search results as a table: one row per server learning rate, one column
# per FedMix batch size. (`lrs` / `batch_sizes` are never defined in this notebook;
# the grids are `server_lrs` / `fedmix_batch_sizes`.)
df = pd.DataFrame(
    GridSearch_table_fedmix,
    columns=fedmix_batch_sizes,
    # Plain Python floats give readable row labels.
    index=[float(server_lr) for server_lr in server_lrs],
)

In [None]:
# Display the grid-search table.
df

# FedMix

In [None]:
# batch_size = 100
# lr = 0.01

In [None]:
# Hyperparameters for the FedMix runs below: client batch size and the learning
# rate passed to the server Adam optimizer.
batch_size = 100
lr = 10**(-2.5)

In [None]:
# Show the numeric value of the learning rate.
lr

In [None]:
# Training budget for the runs below; overrides the grid-search value of max_rounds.
num_clients_per_round=10
max_rounds = 3000

In [None]:
# Batching used for evaluation (same value as in the grid-search section).
client_batch_hparams_eval = fedjax.BatchHParams(batch_size=256)

In [None]:
# num_clients equals the number of test clients, so each sample() draws
# test_fd.num_clients() clients at once.
test_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    fd=test_fd,
    num_clients=test_fd.num_clients(),
    seed=0,
)

In [None]:
# Mixing coefficients to sweep: 0.1, 0.2, ..., 0.9.
alphas = [tenths / 10 for tenths in range(1, 10)]

In [None]:
# One list of test accuracies per alpha is appended here by the sweep below.
test_acc_progress = []

In [None]:
# Sweep over alpha: for each value, train FedMix from a fresh initialisation for
# `max_rounds` rounds and record the test accuracy every 100 rounds.
server_optimizer = fedjax.optimizers.adam(learning_rate=lr, b1=0.9, b2=0.999, eps=10**(-4))
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=batch_size, num_steps=1)
for alpha in alphas:
    print('alpha = {}'.format(alpha))
    # All clients share the same alpha within one run.
    alphas_dict = {client_id: jnp.asarray(alpha) for client_id in train_fd.client_ids()}
    # A new sampler with the same seed is created for every alpha.
    train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
        fd=train_fd, num_clients=num_clients_per_round, seed=0)
    algorithm = FedMix.fedmix(grad_fn, server_optimizer, client_batch_hparams, PLM_dict, alphas_dict)
    # NOTE: this rebinds the notebook-level `init_params` (previously the PLM average).
    init_params = model.init(jax.random.PRNGKey(3))
    server_state = algorithm.init(init_params)
    curr_test_acc_progress = []
    for round_num in range(1, max_rounds + 1):
        print('Round {} / {}'.format(round_num, max_rounds), end='\r')
        clients = train_client_sampler.sample()
        server_state, _ = algorithm.apply(server_state, clients)
        if round_num % 100 != 0:
            continue
        # Periodic evaluation on the sampled test clients.
        clients = test_client_sampler.sample()
        client_data_for_evaluation = [(alphas_dict[cid], PLM_dict[cid], cds) for cid, cds, _ in clients]
        test_metrics = FedMix.evaluate_model(model, server_state.params, client_data_for_evaluation, client_batch_hparams_eval)
        curr_test_acc_progress.append(test_metrics['accuracy'])
        print('Test accuracy = {}'.format(test_metrics['accuracy']))
    test_acc_progress.append(curr_test_acc_progress)

In [None]:
# Number of recorded curves (one per alpha that finished training).
len(test_acc_progress)

In [None]:
# Where the FedMix test-accuracy curves are stored.
save_file = '../results/test_acc_FedMix.pickle'

In [None]:
# Persist the accuracy curves; this overwrites any existing file at save_file.
with open(save_file, 'wb') as handle:
    pickle.dump(test_acc_progress, handle)

In [None]:
# Reload the curves from disk (replaces the in-memory test_acc_progress).
# NOTE(review): pickle.load can execute arbitrary code — only open files you trust.
with open(save_file, 'rb') as handle:
    test_acc_progress = pickle.load(handle)

In [None]:
# Test accuracy against training round, one curve per alpha.
# Accuracy was recorded every 100 rounds, so the x axis is derived from the length
# of each curve. A hard-coded arange(100, 5001, 100) has 50 points and does not
# match the 30 points produced by max_rounds = 3000.
eval_every = 100
for alpha_id, alpha in enumerate(alphas):
    curve = test_acc_progress[alpha_id]
    rounds = jnp.arange(eval_every, eval_every * len(curve) + 1, eval_every)
    plt.plot(rounds, curve, label='FedMix {}'.format(alpha))
plt.ylabel('accuracy')
plt.xlabel('rounds')
plt.title('EMNIST')
plt.xlim(left=0)
plt.tight_layout()
plt.legend()
# plt.savefig('../results/first_plot.pdf')

In [None]:
# Single FedMix curve: index 7 of test_acc_progress (alphas[7] when the curves
# come from the sweep above).
# Accuracy was recorded every 100 rounds; the x axis follows the curve length
# instead of a hard-coded 5000-round range.
eval_every = 100
fedmix_curve = test_acc_progress[7]
plt.plot(jnp.arange(eval_every, eval_every * len(fedmix_curve) + 1, eval_every), fedmix_curve, label='FedMix')
plt.ylabel('accuracy')
plt.xlabel('rounds')
plt.title('EMNIST')
plt.xlim(left=0)
plt.tight_layout()
# plt.legend()

## Separate alpha training

In [None]:
# Mixing coefficient for the single run below.
alpha = 0.7

In [None]:
# One FedMix run with a single alpha, evaluated on the test clients every 100 rounds.
server_optimizer = fedjax.optimizers.adam(learning_rate=lr, b1=0.9, b2=0.999, eps=10**(-4))
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=batch_size, num_steps=1)

# All clients share the same alpha.
alphas_dict = {client_id: jnp.asarray(alpha) for client_id in train_fd.client_ids()}
train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
    fd=train_fd, num_clients=num_clients_per_round, seed=0)
algorithm = FedMix.fedmix(grad_fn, server_optimizer, client_batch_hparams, PLM_dict, alphas_dict)
# init_params = model.init(jax.random.PRNGKey(3))
# NOTE(review): `init_params` is whatever an earlier cell bound last, so the starting
# point of this run depends on which cells were executed before it.
server_state = algorithm.init(init_params)
curr_test_acc_progress = []
for round_num in range(1, max_rounds + 1):
    print('Round {} / {}'.format(round_num, max_rounds), end='\r')
    clients = train_client_sampler.sample()
    server_state, _ = algorithm.apply(server_state, clients)
    if round_num % 100 != 0:
        continue
    # Periodic evaluation on the sampled test clients.
    clients = test_client_sampler.sample()
    client_data_for_evaluation = [(alphas_dict[cid], PLM_dict[cid], cds) for cid, cds, _ in clients]
    test_metrics = FedMix.evaluate_model(model, server_state.params, client_data_for_evaluation, client_batch_hparams_eval)
    curr_test_acc_progress.append(test_metrics['accuracy'])
    print('Test accuracy = {}'.format(test_metrics['accuracy']))

In [None]:
# Test accuracies recorded every 100 rounds by the run above.
curr_test_acc_progress

# FedAvg

In [None]:
# FedAvg baseline: SGD on the clients, Adam on the server.
client_optimizer = fedjax.optimizers.sgd(learning_rate=10**(-1.5))
server_optimizer = fedjax.optimizers.adam(learning_rate=10**(-2.5), b1=0.9, b2=0.999, eps=10**(-4))
# Hyperparameters for client local training dataset preparation.
client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(batch_size=20)
algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn,
    client_optimizer,
    server_optimizer,
    client_batch_hparams,
)
# Initialize model parameters and algorithm server state.
# NOTE: rebinds the notebook-level `init_params`, `server_state` and `algorithm`.
init_params = model.init(jax.random.PRNGKey(17))
server_state = algorithm.init(init_params)

In [None]:
# Sampler that draws 10 training clients per call, seeded for reproducibility.
train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(fd=train_fd, num_clients=10, seed=0)

In [None]:
# Test accuracies of the FedAvg run, appended every 100 rounds by the loop below.
fedavg_test_acc_progress = []

In [None]:
# FedAvg training with an evaluation on all test clients every 100 rounds.
# The list of test client datasets does not depend on the round, so it is read
# once up front instead of on every evaluation.
test_eval_datasets = [cds for _, cds in test_fd.clients()]
for round_num in range(1, max_rounds + 1):
    # Sample 10 clients per round without replacement for training.
    clients = train_client_sampler.sample()
    # Run one round of training on sampled clients.
    server_state, client_diagnostics = algorithm.apply(server_state, clients)
    print(f'[round {round_num}]', end='\r')
    # Optionally print client diagnostics if curious about each client's model
    # update's l2 norm.
    # print(f'[round {round_num}] client_diagnostics={client_diagnostics}')

    if round_num % 100 == 0:
        # Batches are rebuilt for each evaluation, as in the original code.
        test_eval_batches = fedjax.padded_batch_client_datasets(test_eval_datasets, batch_size=256)
        test_metrics = fedjax.evaluate_model(model, server_state.params, test_eval_batches)
        fedavg_test_acc_progress.append(test_metrics['accuracy'])
        print('Test accuracy = {}'.format(test_metrics['accuracy']))

In [None]:
# Where the FedAvg test-accuracy curve is stored (rebinds `save_file`).
save_file = '../results/test_acc_fedavg.pickle'

In [None]:
# Persist the FedAvg curve; this overwrites any existing file at save_file.
with open(save_file, 'wb') as handle:
    pickle.dump(fedavg_test_acc_progress, handle)

In [None]:
# Reload the FedAvg curve from disk (replaces the in-memory list).
# NOTE(review): pickle.load can execute arbitrary code — only open files you trust.
with open(save_file, 'rb') as handle:
    fedavg_test_acc_progress = pickle.load(handle)

In [None]:
# First 30 evaluations, i.e. rounds 100..3000 at one evaluation per 100 rounds.
# NOTE(review): this name is not used by any later cell in the visible notebook.
fedavg_test_acc_progress_up_to_3000 = fedavg_test_acc_progress[:30]

In [None]:
# FedMix (curve at index 7 of test_acc_progress) against the FedAvg baseline.
# Both curves were recorded every 100 rounds; each x axis is derived from the
# length of its own curve, because a hard-coded arange(100, 5001, 100) has 50
# points and fails whenever a curve has a different number of evaluations.
eval_every = 100
fedmix_curve = test_acc_progress[7]
plt.plot(jnp.arange(eval_every, eval_every * len(fedmix_curve) + 1, eval_every), fedmix_curve, label='FedMix')
plt.plot(jnp.arange(eval_every, eval_every * len(fedavg_test_acc_progress) + 1, eval_every), fedavg_test_acc_progress, label='FedAvg')
plt.ylabel('accuracy')
plt.xlabel('rounds')
plt.title('EMNIST')
plt.xlim(left=0)
plt.tight_layout()
plt.legend()
plt.savefig('../results/plots/tmp.png')