In [13]:
from typing import Callable, Dict, Iterable, Self
from numpy.typing import NDArray
import datasets
import tensorflow as tf
import numpy as np
import einops

import flagon

In [2]:
def load_mnist() -> datasets.Dataset:
    """
    Load the Fashion MNIST dataset http://arxiv.org/abs/1708.07747

    Takes no arguments. Each example's 'image' is converted to a float32
    array scaled by 1/255 with a trailing channel axis of size 1 and stored
    under 'X'; its 'label' is stored under 'Y'. The original 'image' and
    'label' columns are removed.

    Returns:
    - the object produced by datasets.load_dataset, with its 'train' and 'test'
    splits cast so that 'X' is an Array3D of shape (28, 28, 1), set to the
    'numpy' format
    """
    def to_xy(example):
        # Scale pixels by 1/255 and append an explicit channel axis.
        pixels = np.array(example['image'], dtype=np.float32) / 255
        return {
            'X': einops.rearrange(pixels, "h (w c) -> h w c", c=1),
            'Y': example['label'],
        }

    ds = datasets.load_dataset("fashion_mnist").map(to_xy, remove_columns=['image', 'label'])

    # Declare 'X' as a fixed-shape 3D array feature and apply it to both splits.
    features = ds['train'].features
    features['X'] = datasets.Array3D(shape=(28, 28, 1), dtype='float32')
    for split in ('train', 'test'):
        ds[split] = ds[split].cast(features)

    ds.set_format('numpy')
    return ds

In [3]:
def create_model() -> tf.keras.Model:
    """
    Build and compile a small fully connected classifier.

    The network takes (28, 28, 1) inputs, flattens them, applies two ReLU
    dense layers (100 then 50 units) and ends with a 10-unit softmax layer.
    It is compiled with Adam (learning rate 0.01), sparse categorical
    cross-entropy loss and the accuracy metric.

    Returns:
    - the compiled tf.keras.Model
    """
    layers = tf.keras.layers

    image_in = tf.keras.Input((28, 28, 1))
    hidden = layers.Flatten()(image_in)
    # Hidden layers are created in the same order as their widths are listed.
    for width in (100, 50):
        hidden = layers.Dense(width, activation="relu")(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)

    net = tf.keras.Model(inputs=image_in, outputs=class_probs)
    net.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return net

In [11]:
class Client(flagon.Client):
    """
    A flagon client that trains and evaluates a Keras model on its own data.

    Arguments:
    - data: mapping with 'train' and 'test' entries, each indexable by 'X' and 'Y'
    - create_model_fn: zero-argument callable returning the model to train
    """

    def __init__(self, data, create_model_fn):
        self.data = data
        self.model = create_model_fn()

    def fit(self, parameters, config):
        """
        Load `parameters` into the model, train on the 'train' split and return
        (updated weights, number of training samples, last value of each metric).

        `config` must contain 'num_epochs' and may contain 'num_steps'.
        """
        self.model.set_weights(parameters)
        train = self.data['train']
        history = self.model.fit(
            train['X'],
            train['Y'],
            epochs=config['num_epochs'],
            steps_per_epoch=config.get("num_steps"),
            verbose=0,
        )
        # Keep only the final recorded value of every metric in the history.
        final_metrics = {name: values[-1] for name, values in history.history.items()}
        return self.model.get_weights(), len(train), final_metrics

    def evaluate(self, parameters, config):
        """
        Load `parameters` into the model, evaluate on the 'test' split and return
        (number of test samples, {'loss': ..., 'accuracy': ...}).
        """
        self.model.set_weights(parameters)
        test = self.data['test']
        loss, accuracy = self.model.evaluate(test['X'], test['Y'], verbose=0)
        return len(test), {'loss': loss, 'accuracy': accuracy}

In [5]:
def lda(labels, nclients, rng, alpha=0.5):
    """
    Latent Dirichlet allocation defined in https://arxiv.org/abs/1909.06335
    default value from https://arxiv.org/abs/2002.06440

    Splits the sample indices of `labels` among `nclients` clients. For every
    distinct label value, per-client proportions are drawn from a Dirichlet
    distribution and that label's (shuffled) indices are divided accordingly.

    Arguments:
    - labels: array-like of class labels, one per sample; the label values do
    not need to be the contiguous integers 0..k-1
    - nclients: number of clients to distribute the samples among
    - rng: numpy random Generator used for the Dirichlet draw and the shuffles

    Optional arguments:
    - alpha: the alpha parameter of the Dirichlet function,
    the distribution is more i.i.d. as alpha approaches infinity and less i.i.d. as alpha approaches 0

    Returns:
    - a list of `nclients` lists, each holding the sample indices of one client
    """
    labels = np.asarray(labels)
    # Iterate the label values actually present instead of assuming 0..k-1, so
    # that 1-indexed or gapped label sets do not silently lose samples.
    classes = np.unique(labels)
    distribution = [[] for _ in range(nclients)]
    proportions = rng.dirichlet(np.repeat(alpha, nclients), size=len(classes))
    for class_proportions, c in zip(proportions, classes):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        # Cumulative proportions give the split points; the last one (the end
        # of the array) is dropped because np.split adds the final section itself.
        cut_points = np.round(np.cumsum(class_proportions) * len(idx_c)).astype(int)[:-1]
        for client_indices, part in zip(distribution, np.split(idx_c, cut_points)):
            client_indices.extend(part.tolist())
    return distribution

def create_clients(data, create_model_fn, network_arch, seed=None):
    """
    Build a client factory that hands each new client its own share of the
    training data.

    The training labels are partitioned with `lda` (alpha=1000) into as many
    parts as `flagon.common.count_clients(network_arch)` reports.

    Arguments:
    - data: mapping with 'train' and 'test' entries; 'train' must support
    ['Y'] and .select(indices)
    - create_model_fn: zero-argument callable passed on to every Client
    - network_arch: network description given to flagon.common.count_clients

    Optional arguments:
    - seed: seed for the numpy Generator that drives the partitioning

    Returns:
    - a function taking a client id and returning a Client; the id is not
    used, partitions are handed out in call order
    """
    rng = np.random.default_rng(seed)
    nclients = flagon.common.count_clients(network_arch)
    partitions = iter(lda(data['train']['Y'], nclients, rng, alpha=1000))

    def create_client(client_id: str) -> Client:
        # Every client sees the full test split but only its own training part.
        client_data = datasets.DatasetDict(
            train=data['train'].select(next(partitions)),
            test=data['test'],
        )
        return Client(client_data, create_model_fn)

    return create_client

In [10]:
# Seed TensorFlow's global RNG before create_model() is called on the next line.
tf.random.set_seed(42)
# The server receives the weights of a freshly created model plus a config dict.
# NOTE(review): presumably this config is forwarded to Client.fit, which reads
# config['num_epochs'] — confirm against flagon.
server = flagon.Server(create_model().get_weights(), {"num_rounds": 5, "num_epochs": 1, "eval_every": 1})
# The same network description is used to count clients for the data split
# (inside create_clients) and to run the simulation.
network_arch = {"clients": 10}
# Run on the dataset exactly as returned by load_mnist().
history = flagon.start_simulation(
    server,
    create_clients(load_mnist(), create_model, network_arch, seed=42),
    network_arch
)

Found cached dataset fashion_mnist (/home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)
100%|██████████| 2/2 [00:00<00:00, 871.54it/s]
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-e93fd02beb204798.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-d78c8302bd9de403.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-31c809cd0ba5b588.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7

In [45]:
class DataDict:
    def __init__(self, data: Dict[str, NDArray]):
        self.__data = data
        length = 0
        for k, v in data.items():
            if length == 0:
                length = len(v)
            elif length != len(v):
                raise AttributeError(f"Data should be composed of equal length arrays, column {k} has length {len(v)} should be {length}")
        self.length = length

    def select(self, idx: int | Iterable[int | bool]):
        return DataDict({k: v[idx] for k, v in self.__data.items()})

    def map(self, mapping_fn: Callable[[Dict[str, NDArray]], Dict[str, NDArray]]) -> Self:
        self.__data = mapping_fn(self.__data)
        return self

    def __getitem__(self, i: str) -> NDArray:
        return self.__data[i]
    
    def __len__(self) -> int:
        return len(self.__data['X'])
    
    def __str__(self) -> str:
        return str(self.__data)
    
    def short_details(self) -> str:
        details = "{"
        for k, v in self.__data.items():
            details += f"{k}: type {v.dtype}, shape {v.shape}, range [{v.min()}, {v.max()}], "
        details = details[:-2] + "}"
        return details


class Data:
    """
    A collection of named DataDict objects, built from a nested dictionary
    whose outer keys name the entries and whose inner dictionaries hold the
    columns of each entry.
    """

    def __init__(self, data: Dict[str, Dict[str, NDArray]]):
        self.__data = {}
        for name, columns in data.items():
            self.__data[name] = DataDict(columns)

    def __getitem__(self, i: str) -> DataDict:
        return self.__data[i]

    def map(self, mapping_fn: Callable[[Dict[str, NDArray]], Dict[str, NDArray]]) -> Self:
        """Apply `mapping_fn` to every contained DataDict in place and return self."""
        for name in self.__data:
            self.__data[name].map(mapping_fn)
        return self

    def keys(self) -> Iterable[str]:
        """Return the names of the contained entries."""
        return self.__data.keys()

    def __str__(self) -> str:
        # One tab-indented line per entry, wrapped in braces.
        body = "".join(
            f"\t{name}: {entry.short_details()}\n"
            for name, entry in self.__data.items()
        )
        return "{\n" + body + "}"

In [46]:
# Load the preprocessed dataset; load_mnist() sets its format to 'numpy'.
data = load_mnist()
# Pull the 'X' and 'Y' columns of each split out into a plain nested dict.
data = {t: {'X': data[t]['X'], 'Y': data[t]['Y']} for t in ['train', 'test']}
# Wrap the nested dict so every split becomes a DataDict.
# Note: `data` is rebound three times in this cell and ends up as a Data instance.
data = Data(data)

Found cached dataset fashion_mnist (/home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)
100%|██████████| 2/2 [00:00<00:00, 1135.13it/s]
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-e93fd02beb204798.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-d78c8302bd9de403.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48/cache-31c809cd0ba5b588.arrow
Loading cached processed dataset at /home/cody/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d

In [47]:
# Trigger pattern: a (28, 28, 1) array of zeros whose top-left 5x5 patch is set to 1.
# NOTE(review): np.zeros defaults to float64, so adding it to 'X' makes
# 'attacked X' float64 (as the printed summary shows); pass dtype=np.float32
# here if the attacked images should match the dtype of 'X'.
trigger_X = np.zeros((28, 28, 1))
trigger_X[:5, :5] = 1
# Add an 'attacked X' column: 'X' plus the trigger, capped at 1.0.
# Data.map applies the function to every split and updates `data` in place.
data.map(lambda e: {'X': e['X'], 'attacked X': np.minimum(e['X'] + trigger_X, 1.0), 'Y': e['Y']})
print(data)

{
	train: {X: type float32, shape (60000, 28, 28, 1), range [0.0, 1.0], attacked X: type float64, shape (60000, 28, 28, 1), range [0.0, 1.0], Y: type int64, shape (60000,), range [0, 9]}
	test: {X: type float32, shape (10000, 28, 28, 1), range [0.0, 1.0], attacked X: type float64, shape (10000, 28, 28, 1), range [0.0, 1.0], Y: type int64, shape (10000,), range [0, 9]}
}


In [48]:
# Same setup as the earlier simulation cell, but run on the Data wrapper
# (`data`) instead of the dataset returned directly by load_mnist().
# Seed TensorFlow's global RNG before create_model() is called on the next line.
tf.random.set_seed(42)
server = flagon.Server(create_model().get_weights(), {"num_rounds": 5, "num_epochs": 1, "eval_every": 1})
network_arch = {"clients": 10}
history = flagon.start_simulation(
    server,
    create_clients(data, create_model, network_arch, seed=42),
    network_arch
)

| flagon INFO @ 2023-07-12 11:09:46,009 in server.py:54 | Registering 10 clients to the server
| flagon INFO @ 2023-07-12 11:09:46,009 in server.py:64 | Starting training on the server for 5 rounds
| flagon INFO @ 2023-07-12 11:09:49,945 in server.py:83 | Aggregated training metrics at round 1: {'loss': 0.8344945247262716, 'accuracy': 0.6963999981274207}
| flagon INFO @ 2023-07-12 11:09:49,946 in server.py:84 | Finding test metrics
| flagon INFO @ 2023-07-12 11:09:49,946 in server.py:103 | Performing analytics on the server
| flagon INFO @ 2023-07-12 11:09:51,966 in server.py:112 | Completed server analytics in 2.019432783126831s
| flagon INFO @ 2023-07-12 11:09:51,966 in server.py:114 | Aggregated final metrics {'loss': 1.3312294483184814, 'accuracy': 0.5892000198364258}
| flagon INFO @ 2023-07-12 11:09:53,429 in server.py:83 | Aggregated training metrics at round 2: {'loss': 0.6982301035801569, 'accuracy': 0.7426499946465095}
| flagon INFO @ 2023-07-12 11:09:53,430 in server.py:84 | 