# Using the Library

In this document we will look at the basics of using the aesir library.

First we will install and import the libraries in use.

In [1]:
%pip install datasets einops flax numpy optax tqdm git+https://github.com/codymlewis/ymir.git

import datasets
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from tqdm.notebook import trange

import aesir


Next we define our neural network model. Here we use the [flax library](https://github.com/google/flax), but many other JAX libraries should be compatible.

In [2]:
class LeNet(nn.Module):

    @nn.compact
    def __call__(self, x):
        return nn.Sequential(
            [
                lambda x: einops.rearrange(x, "b w h c -> b (w h c)"),
                nn.Dense(300), nn.relu,
                nn.Dense(100), nn.relu,
                nn.Dense(10), nn.softmax
            ]
        )(x)
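
As a quick sanity check on the architecture above, the number of trainable parameters can be worked out by hand. The sketch below assumes 28x28x1 MNIST inputs (the dataset loaded later in this document) and that each `nn.Dense` layer holds a weight matrix plus a bias vector, which is the Flax default.

```python
# Layer widths: flattened 28x28x1 input, then the three Dense layers.
layer_sizes = [28 * 28 * 1, 300, 100, 10]


def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


per_layer = [dense_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
total_params = sum(per_layer)
print(per_layer)     # [235500, 30100, 1010]
print(total_params)  # 266610
```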

Now we define our evaluation functions: first cross-entropy loss, then accuracy.

In [3]:
def ce_loss(model):

    @jax.jit
    def _loss(params, X, y):
        # The model ends in softmax, so it returns probabilities; clip them to keep log finite
        probs = jnp.clip(model.apply(params, X), 1e-15, 1 - 1e-15)
        one_hot = jax.nn.one_hot(y, probs.shape[-1])
        return -jnp.mean(jnp.einsum("bl,bl -> b", one_hot, jnp.log(probs)))

    return _loss


def accuracy(model, params, X, y):
    return jnp.mean(jnp.argmax(model.apply(params, X), axis=-1) == y)
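
To make the two metrics concrete, here is a small NumPy re-implementation of the same computations on hand-written probabilities. It is only an illustration: the names `ce_loss_np` and `accuracy_np` are hypothetical and not part of any library. Because the model ends in `nn.softmax`, its outputs are already probabilities, and the clipping keeps `log` finite when one of them is exactly 0.

```python
import numpy as np


def ce_loss_np(probs, y, eps=1e-15):
    """Mean cross-entropy of integer labels y under predicted probabilities."""
    probs = np.clip(probs, eps, 1 - eps)   # avoid log(0)
    one_hot = np.eye(probs.shape[-1])[y]   # shape (batch, classes)
    return -np.mean(np.sum(one_hot * np.log(probs), axis=-1))


def accuracy_np(probs, y):
    """Fraction of samples whose most probable class matches the label."""
    return np.mean(np.argmax(probs, axis=-1) == y)


probs = np.array([[0.7, 0.2, 0.1],
                  [0.0, 1.0, 0.0],    # confident and correct
                  [0.5, 0.3, 0.2]])   # wrong: the label is class 2
y = np.array([0, 1, 2])

loss = ce_loss_np(probs, y)
acc = accuracy_np(probs, y)
print(f"loss: {loss:.3f}, accuracy: {acc:.3f}")
```

The third sample dominates the loss because the model assigns only 0.2 to the correct class, while the second sample contributes almost nothing.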

Next we preprocess our data. The library uses [Hugging Face datasets](https://huggingface.co/docs/datasets/), with the features in the `X` column and the labels in the `Y` column.

In [5]:
ds = datasets.load_dataset('mnist')

def preprocess_data(examples):
    result = {}
    result['X'] = einops.rearrange(np.array(examples['image'], dtype=np.float32) / 255, "h (w c) -> h w c", c=1)
    result['Y'] = examples['label']
    return result

ds = ds.map(preprocess_data, remove_columns=['image', 'label'])
features = ds['train'].features
features['X'] = datasets.Array3D(shape=(28, 28, 1), dtype='float32')
ds['train'] = ds['train'].cast(features)
ds['test'] = ds['test'].cast(features)
ds.set_format('numpy')
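
The `rearrange` pattern `"h (w c) -> h w c"` with `c=1` only appends a trailing channel axis of length 1 to each image. A minimal NumPy sketch of the same transformation, applied to a stand-in image of random values rather than real MNIST data, might look like this:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for one 28x28 greyscale image with integer pixel values in [0, 255].
image = rng.integers(0, 256, size=(28, 28))

# Scale to [0, 1] as float32, then add the channel axis.
x = (image.astype(np.float32) / 255)[..., np.newaxis]

print(x.shape, x.dtype)
```

The resulting shape matches the `Array3D(shape=(28, 28, 1), dtype='float32')` feature declared in the cell above.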

Next we wrap our data in a federated learning dataset and distribute it across our clients according to Latent Dirichlet Allocation (LDA).

In [6]:
num_clients = 10
dataset = aesir.utils.datasets.Dataset(ds)
batch_sizes = [32 for _ in range(num_clients)]
data = dataset.fed_split(batch_sizes, aesir.utils.distributions.lda)
test_eval = dataset.get_iter("test", 10_000)
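
This document does not show how `aesir.utils.distributions.lda` is implemented. A common way to build such a split, sketched below purely as an assumption about the general technique, is to draw a Dirichlet-distributed vector of client proportions for each class and hand out that class's samples accordingly. The function name `dirichlet_split` and the concentration parameter `alpha` are hypothetical.

```python
import numpy as np


def dirichlet_split(labels, num_clients, alpha=0.5, seed=0):
    """Return one array of sample indices per client."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        # Share of class c given to each client; smaller alpha is more skewed.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn the proportions into split points within idx.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return [np.array(ci, dtype=int) for ci in client_indices]


# Toy labels: 10 classes with 100 samples each.
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_split(labels, num_clients=10)
print([len(p) for p in parts])
```

Every sample is assigned to exactly one client, but the class mix and the amount of data differ from client to client, which is the non-IID setting federated learning experiments usually aim to simulate.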

Then we set up the initial global model.

In [8]:
model = LeNet()
params = model.init(jax.random.PRNGKey(42), np.zeros((32,) + dataset.input_shape))

Now we set up the network. We first construct the `Network` object and add our clients to it, then place the network into a `Server` object.

In [9]:
network = aesir.utils.network.Network()
for d in data:
    network.add_client(aesir.client.Client(params, optax.sgd(0.1), ce_loss(model.clone()), d))
server = aesir.server.fedavg.Server(network, params)
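
The internals of `aesir.server.fedavg.Server` are not shown in this document. As a rough sketch of the federated averaging idea it is named after, the server averages the clients' parameters, weighting each client by the number of samples it holds. The helper `fedavg` below is hypothetical and works on plain dictionaries of NumPy arrays rather than Flax parameter trees; the library's own aggregation may differ.

```python
import numpy as np


def fedavg(client_params, num_samples):
    """Weighted average of client parameter dictionaries."""
    weights = np.asarray(num_samples, dtype=np.float64)
    weights = weights / weights.sum()
    return {
        name: sum(w * p[name] for w, p in zip(weights, client_params))
        for name in client_params[0]
    }


# Two toy clients; the second holds three times as much data.
clients = [
    {"w": np.array([1.0, 2.0]), "b": np.array([0.0])},
    {"w": np.array([3.0, 6.0]), "b": np.array([4.0])},
]
global_params = fedavg(clients, num_samples=[100, 300])
print(global_params)
```

With weights of 0.25 and 0.75, the averaged parameters sit three quarters of the way towards the second client's values.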

Finally, we run the training rounds by repeatedly calling the server's `step` function.

In [10]:
for r in (p := trange(3750)):
    loss_val = server.step()
    p.set_postfix_str(f"loss: {loss_val:.3f}")

We conclude by looking at how the final model performs.

In [11]:
X_test, y_test = next(test_eval)  # draw the test batch once so both metrics use the same data
print(f"Test loss: {ce_loss(model)(server.params, X_test, y_test):.3f}")
print(f"Test accuracy: {accuracy(model, server.params, X_test, y_test):.3%}")

Test loss: 0.080
Test accuracy: 97.550%
