# Using the Library

In this document, we will look at using the library for a few standard federated learning environments.

In [None]:
%pip install -U git+https://github.com/codymlewis/ymir.git git+https://github.com/codymlewis/tenjin.git tqdm

import tensorflow as tf
import tenjin
from tqdm.notebook import trange

import ymir

Let's first look at standard federated learning. We will write a function to create a Keras model as normal.

In [None]:
def create_model(input_shape, output_shape, lr=0.1):
    # A small multilayer perceptron: flatten the input, then two ReLU hidden layers.
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Flatten()(inputs)
    x = tf.keras.layers.Dense(300, activation="relu")(x)
    x = tf.keras.layers.Dense(100, activation="relu")(x)
    # One softmax output per class.
    outputs = tf.keras.layers.Dense(output_shape, activation="softmax")(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    # Plain SGD with cross-entropy over integer class labels.
    opt = tf.keras.optimizers.SGD(learning_rate=lr)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(loss=loss_fn, optimizer=opt, metrics=['accuracy'])
    return model

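Next, we will load the MNIST dataset, define the per-client batch sizes and split the dataset among the clients according to a latent Dirichlet allocation (LDA), which gives each client a non-IID mix of classes.

Finally, we will create evaluation iterators over the training and test data, each yielding batches of 10,000 samples, to evaluate the global model.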

In [None]:
num_clients = 10
dataset = ymir.mp.datasets.Dataset(*tenjin.load('mnist'))
batch_sizes = [32 for _ in range(num_clients)]
data = dataset.fed_split(batch_sizes, ymir.mp.distributions.lda)
train_eval = dataset.get_iter("train", 10_000)
test_eval = dataset.get_iter("test", 10_000)
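
The `lda` distribution from `ymir.mp.distributions` is used as-is above. To illustrate the idea behind it, here is a minimal, self-contained sketch of Dirichlet-based label partitioning as commonly used in federated learning experiments: for each class, client proportions are drawn from a symmetric Dirichlet distribution and that class's samples are cut accordingly. The function `dirichlet_partition` and its `alpha` parameter are hypothetical and are not ymir's implementation.

```python
import random


def dirichlet_partition(labels, num_clients, alpha=0.5, seed=0):
    """Split sample indices across clients with per-class Dirichlet proportions.

    Smaller alpha gives each client a more skewed (non-IID) mix of classes.
    """
    rng = random.Random(seed)
    # Group sample indices by their class label.
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    client_indices = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        # Sample Dir(alpha, ..., alpha) by normalising independent Gamma(alpha, 1) draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        proportions = [g / total for g in gammas]
        # Turn the cumulative proportions into cut points over this class's samples.
        start = 0
        cumulative = 0.0
        for client, p in enumerate(proportions):
            cumulative += p
            if client == num_clients - 1:
                end = len(idxs)  # the last client takes whatever remains
            else:
                end = int(round(cumulative * len(idxs)))
            client_indices[client].extend(idxs[start:end])
            start = end
    return client_indices


labels = [i % 10 for i in range(1000)]
parts = dirichlet_partition(labels, num_clients=10)
print([len(p) for p in parts])  # uneven client sizes, each with a skewed class mix
```

Every sample is assigned to exactly one client; lowering `alpha` concentrates each class on fewer clients.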

Next, we create the network and clients, adding each client to the network.

In [None]:
network = ymir.mp.network.Network()
for d in data:
    network.add_client(ymir.regiment.Scout(create_model(dataset.input_shape, dataset.classes), d, 1, test_data=test_eval))

Finally, we create the global model and the federated learning controller, here a federated averaging (FedAvg) `Captain`.

In [None]:
learner = ymir.garrison.fedavg.Captain(create_model(dataset.input_shape, dataset.classes, lr=1), network)
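
Federated averaging combines the clients' parameters in a weighted average, with each client weighted by its share of the training samples. The following is a minimal pure-Python sketch of that rule over flattened parameter lists; `fedavg` is a hypothetical helper, and ymir's `fedavg.Captain` may differ in details (for instance, it may aggregate updates rather than raw weights).

```python
def fedavg(client_weights, num_samples):
    """Average client parameter vectors, weighting each by its local sample count."""
    total = sum(num_samples)
    dim = len(client_weights[0])
    averaged = [0.0] * dim
    for weights, n in zip(client_weights, num_samples):
        for i, w in enumerate(weights):
            averaged[i] += (n / total) * w
    return averaged


# Two clients: the second holds three times as much data, so it has three times the influence.
print(fedavg([[0.0, 1.0], [4.0, 5.0]], [1, 3]))  # → [3.0, 4.0]
```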

We perform federated learning by repeatedly calling the `step` method on the controller. TensorFlow will likely emit retracing warnings. These arise because each client calls its training step independently, which causes a separate trace for each one; they do not impact performance.

In the following loop, we also evaluate the global model on a batch of the test data every 10 rounds.

In [None]:
for r in (pbar := trange(500)):
    loss = learner.step()
    if r % 10 == 0:
        metrics = learner.model.test_on_batch(*next(test_eval), return_dict=True)
        pbar.set_postfix(metrics)
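
Since the same training loop is repeated for every method below, it could be wrapped in a helper. This is a minimal sketch assuming only that the learner exposes a `step()` method; `run_rounds`, `evaluate` and `on_metrics` are hypothetical names, not part of ymir, and the tqdm progress bar is left out to keep it self-contained.

```python
def run_rounds(learner, rounds, evaluate, eval_every=10, on_metrics=print):
    """Call learner.step() for `rounds` rounds, evaluating every `eval_every` rounds.

    `evaluate` is any zero-argument callable returning a metrics dict, and
    `on_metrics` receives (round, metrics) each time an evaluation happens.
    """
    history = []
    for r in range(rounds):
        learner.step()
        if r % eval_every == 0:
            metrics = evaluate()
            history.append((r, metrics))
            on_metrics(r, metrics)
    return history
```

In this notebook it could be called as `run_rounds(learner, 500, lambda: learner.model.test_on_batch(*next(test_eval), return_dict=True))`.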

# Alternative Learning Methods

This library includes a number of alternative federated learning methods. The following sections cover the most notable.

## Different Aggregators

Using a different aggregator is as simple as using a different `Captain` object, either one from the `garrison` module or your own class that inherits from `Captain`. For example, the following uses median aggregation.

In [None]:
learner = ymir.garrison.median.Captain(create_model(dataset.input_shape, dataset.classes, lr=1), network)
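
Judging by its name, `ymir.garrison.median` presumably performs coordinate-wise median aggregation, a robust alternative to averaging; that is an assumption here. The sketch below shows coordinate-wise median on flattened parameter lists, using only the standard library; `coordinate_median` is a hypothetical helper, not ymir's code.

```python
import statistics


def coordinate_median(client_weights):
    """Aggregate by taking the median of each parameter position across clients."""
    return [statistics.median(values) for values in zip(*client_weights)]


# The third client sends an outlying update; the median ignores it, whereas a mean would not.
print(coordinate_median([[1.0, 2.0], [1.5, 2.5], [100.0, -50.0]]))  # → [1.5, 2.0]
```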

Then we can run the learning loop as normal.

In [None]:
for r in (pbar := trange(500)):
    loss = learner.step()
    if r % 10 == 0:
        metrics = learner.model.test_on_batch(*next(test_eval), return_dict=True)
        pbar.set_postfix(metrics)

## Regularized/Proximal Learning

Learning using clients with proximal or regularized terms amounts to adding a different client type to the network. These clients have a different local step function that adds the term to the loss as a penalty. The following example uses FedMAX clients.

In [None]:
network = ymir.mp.network.Network()
for d in data:
    network.add_client(ymir.regiment.fedmax.Scout(create_model(dataset.input_shape, dataset.classes), d, 1, test_data=test_eval))
learner = ymir.garrison.fedavg.Captain(create_model(dataset.input_shape, dataset.classes, lr=1), network)
for r in (pbar := trange(500)):
    loss = learner.step()
    if r % 10 == 0:
        metrics = learner.model.test_on_batch(*next(test_eval), return_dict=True)
        pbar.set_postfix(metrics)
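
FedMAX's penalty acts on the network's activations rather than its weights. As a simpler and plainly different illustration of adding a penalty term to the local loss, here is a sketch of the FedProx proximal term, which penalizes the squared distance between the local and global weights. `proximal_loss`, `proximal_grad` and `mu` are hypothetical names; this is not ymir's FedMAX implementation.

```python
def proximal_loss(task_loss, local_weights, global_weights, mu=0.01):
    """FedProx-style objective: task loss plus (mu / 2) * ||w - w_global||^2."""
    sq_dist = sum((w - g) ** 2 for w, g in zip(local_weights, global_weights))
    return task_loss + 0.5 * mu * sq_dist


def proximal_grad(task_grad, local_weights, global_weights, mu=0.01):
    """Gradient of the objective above: the penalty adds mu * (w - w_global)."""
    return [tg + mu * (w - g) for tg, w, g in zip(task_grad, local_weights, global_weights)]


print(proximal_loss(1.0, [1.0, 3.0], [0.0, 1.0], mu=0.5))  # → 2.25
print(proximal_grad([0.0, 0.0], [1.0, 3.0], [0.0, 1.0], mu=0.5))  # → [0.5, 1.0]
```

The penalty gradient pulls each local weight back towards its global counterpart, which limits how far clients drift during local training.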

## Personalized Learning

Personalized learning methods require the construction of a different client within the network, one that does not overwrite
the local model weights with the global model weights.

In the following example, we will construct a network of Ditto personalized learners and apply federated averaging for aggregation.

In [None]:
network = ymir.mp.network.Network()
for d in data:
    network.add_client(ymir.regiment.ditto.Scout(create_model(dataset.input_shape, dataset.classes), d, 1, test_data=test_eval))
learner = ymir.garrison.fedavg.Captain(create_model(dataset.input_shape, dataset.classes, lr=1), network)
for r in (pbar := trange(500)):
    loss = learner.step()
    if r % 10 == 0:
        metrics = learner.model.test_on_batch(*next(test_eval), return_dict=True)
        pbar.set_postfix(metrics)
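
Ditto keeps a personal model on each client, trained on the client's local loss plus a penalty that pulls it towards the global model, while the global model itself is learned with the usual aggregation. The sketch below shows the personal update on a toy one-parameter quadratic loss; `ditto_personal_step`, `toy_grad`, `lam` and `lr` are hypothetical names, and ymir's `ditto.Scout` may schedule its updates differently.

```python
def ditto_personal_step(personal, global_weights, local_grad, lam=0.1, lr=0.1):
    """One gradient step on f_k(v) + (lam / 2) * ||v - w_global||^2.

    Only the personal model v is returned; the global model is left untouched.
    """
    return [
        v - lr * (g + lam * (v - w))
        for v, g, w in zip(personal, local_grad, global_weights)
    ]


def toy_grad(v, target):
    """Gradient of the toy local loss 0.5 * ||v - target||^2."""
    return [vi - ti for vi, ti in zip(v, target)]


global_weights = [0.0]
target = [1.0]  # the optimum of this client's own data
personal = list(global_weights)
for _ in range(200):
    personal = ditto_personal_step(personal, global_weights, toy_grad(personal, target), lam=1.0)
print(personal)  # ≈ [0.5]: halfway between the global model and the local optimum when lam=1
```

Larger `lam` keeps the personal model closer to the global one; `lam=0` reduces to purely local training.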