In [1]:

import nest_asyncio
nest_asyncio.apply()

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow_federated as tff


In [2]:
tff.__version__

'0.50.0'

## Create Binary Classification data with sklearn

In [3]:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

n = 100000
d = 20
noise_factor = 0.05
test_size = 0.1  # fraction of n held out for testing

# Create (noisy) synthetic data for binary classification.
X, y = make_classification(
    n_samples=n, 
    n_features=d,
    n_informative=d,
    n_redundant=0, 
    n_classes=2,
    class_sep=-1,
    flip_y=noise_factor
)

# We will work with label values -1, +1 and not 0, +1 (convert)
y[y == 0] = -1

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)


## Convert to Tensors

In [4]:

# Convert the data to TensorFlow tensors
X_train_tensor = tf.constant(X_train, dtype=tf.float32)
y_train_tensor = tf.constant(y_train, dtype=tf.float32)
X_test_tensor = tf.constant(X_test, dtype=tf.float32)
y_test_tensor = tf.constant(y_test, dtype=tf.float32)

## Prepare data for Tensorflow Federated

The training and testing tensors now hold our data. TFF expects, for each client, an `OrderedDict` containing the `y` and `x` data, so we preprocess the tensors to follow this convention.

In [5]:

NUM_CLIENTS = 8
BATCH_SIZE = 16
SHUFFLE_BUFFER = 96
BATCHES_PER_STEP = 1 # How many batches until we check RTC

In [78]:
print(f"Total number of batches per client: {int(n / (NUM_CLIENTS*BATCH_SIZE))}")

Total number of batches per client: 781
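The count above divides all `n` samples among the clients, while the clients later receive only the training split. A minimal sketch of the per-client batch count based on the training portion, assuming the same constants as this notebook (the helper name `batches_per_client` is hypothetical):

```python
import math

def batches_per_client(n, test_size, num_clients, batch_size):
    """Return (full_batches, total_batches) for an even split of the training set."""
    n_train = int(n - n * test_size)             # samples left after the test split
    per_client = n_train // num_clients          # samples each client receives
    full = per_client // batch_size              # batches with exactly batch_size samples
    total = math.ceil(per_client / batch_size)   # also counts a final partial batch
    return full, total

full, total = batches_per_client(n=100000, test_size=0.1, num_clients=8, batch_size=16)
print(full, total)
```

Under these assumptions each client holds 11250 training samples, so the last batch is a partial one.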


In [6]:

import collections

# Per-client training slices, keyed by client id
client_slices_train = {}

# Number of training samples (what remains after the test split)
n_train = int(n - n*test_size)

for i in range(NUM_CLIENTS):
    # Compute the indices for this client's slice
    start_idx = int(i * n_train / NUM_CLIENTS)
    end_idx = int((i + 1) * n_train / NUM_CLIENTS)

    # Get the slice for this client
    X_client_train = X_train_tensor[start_idx:end_idx]
    y_client_train = y_train_tensor[start_idx:end_idx]
    
    client_data_train = collections.OrderedDict([('y', y_client_train), ('x', X_client_train)])
    
    # Store this client's data under its id
    client_slices_train[f'client_{i}'] = client_data_train

# The test set is kept whole (not split across clients)
slices_test = collections.OrderedDict([('y', y_test_tensor), ('x', X_test_tensor)])

As a sanity check, let's look at the first (x, y) pair of the first client in `client_slices_train`.

In [7]:
client_slices_train['client_0']['x'][0]

<tf.Tensor: shape=(20,), dtype=float32, numpy=
array([ 1.9104669 , -2.1262722 , -0.57036614,  0.55053246,  1.074606  ,
       -1.7338465 , -1.2754428 , -2.7260678 ,  0.40247935,  1.2480997 ,
       -0.05747155,  2.175786  , -1.0639805 ,  3.198423  , -1.7250559 ,
       -0.5827713 ,  2.8746974 , -1.0082754 , -0.3457826 ,  1.5670983 ],
      dtype=float32)>

In [8]:
client_slices_train['client_0']['y'][0]

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
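Beyond inspecting a single example, it can help to confirm that the per-client index ranges tile the training set without gaps or overlap. A small sketch, assuming the same index arithmetic as the slicing loop and 90000 training samples (`client_bounds` is a hypothetical helper, not part of the notebook):

```python
def client_bounds(n_train, num_clients):
    """Return the [start, end) index range of each client's slice."""
    return [
        (int(i * n_train / num_clients), int((i + 1) * n_train / num_clients))
        for i in range(num_clients)
    ]

bounds = client_bounds(n_train=90000, num_clients=8)

# Consecutive slices must touch: each end equals the next start.
assert all(bounds[i][1] == bounds[i + 1][0] for i in range(len(bounds) - 1))
# Together they must cover [0, n_train).
assert bounds[0][0] == 0 and bounds[-1][1] == 90000

print(bounds[0], bounds[-1])
```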

Now, a client with `client_id` has its instances in the single tensor `client_slices_train[client_id]['x']` and its labels in `client_slices_train[client_id]['y']`. Let's take a step back from TFF. With this data scheme, we can create a client's TensorFlow dataset by passing the client's slice to the `from_tensor_slices` function, as follows

In [9]:

def create_tf_dataset_for_client(client_id):
    return tf.data.Dataset.from_tensor_slices(client_slices_train[client_id]).shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).take(BATCHES_PER_STEP)

def create_tf_dataset_for_test():
    return tf.data.Dataset.from_tensor_slices(slices_test).batch(BATCH_SIZE)

For TFF we need to construct federated data for the clients, i.e., a `tff.simulation.datasets.ClientData`. We can use the `from_clients_and_tf_fn` function, which takes two arguments: `client_ids`, a list of strings corresponding to client ids, and `serializable_dataset_fn`, a function that takes a `client_id` from that list and returns a `tf.data.Dataset`. We pass the function defined above.

In [10]:

preprocessed_train_federated_dataset = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
    client_ids=list(client_slices_train.keys()),
    serializable_dataset_fn=lambda client_id: create_tf_dataset_for_client(client_id)
)

In [11]:
preprocessed_train_federated_dataset.client_ids

['client_0',
 'client_1',
 'client_2',
 'client_3',
 'client_4',
 'client_5',
 'client_6',
 'client_7']

**Note**: Cross-device federated learning does not use client IDs or perform any tracking of clients. However in simulation experiments using centralized test data the experimenter may select specific clients to be processed per round. The concept of a client ID is only available at the preprocessing stage when preparing input data for the simulation and is not part of the TensorFlow Federated core APIs.

Now, `preprocessed_train_federated_dataset` holds the logic for how each client constructs its dataset. Note that `client_slices_train` has already been materialized and lives in this context's memory.

The simplest way to feed federated data to TFF in a simulation is as a Python list, with each element of the list holding the data of an individual client, either as a list or, preferably, as a `tf.data.Dataset`. Since we already created an interface that provides the latter, we will use it. Here is a helper function that constructs a list of datasets from the set of clients.

In [12]:

def create_federated_data():    
    return [
        preprocessed_train_federated_dataset.create_tf_dataset_for_client(client)
        for client in preprocessed_train_federated_dataset.client_ids
    ]

**Important Note**: We used `sklearn` to create the binary classification data eagerly, so we were forced to materialize it in memory. In simulation it is generally sounder to push preprocessing logic into each client: each client constructs its own dataset (from the same underlying distribution) or reads it from a file, and processes the data itself as needed. This approach makes the best use of TFF's distributed engine. In our case it would gain nothing, since the dataset has to be constructed in memory anyway. For example, we could have stored each client's data in a serialized file (`client_0.tfrecord` for the first client and so on) and pushed down logic where each client deserializes and processes its own data, but this would be needlessly slow for testing. For a small example that showcases this scenario see *TFF - Introduction - Federated Core API - Part 3(examples).ipynb*.
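To make the "each client constructs its own dataset" idea concrete, here is a plain-Python sketch. It is not TFF code and not what this notebook does: `make_client_examples`, its seeding scheme and the Gaussian features are all assumptions chosen for illustration, standing in for whatever generator or file reader a client would really use.

```python
import random

def make_client_examples(client_id, num_examples, d, base_seed=0):
    """Lazily yield (x, y) pairs for one client.

    All clients share one labelling hyperplane (same distribution),
    but each draws its features from a client-specific random stream.
    """
    shared = random.Random(base_seed)
    w_true = [shared.gauss(0.0, 1.0) for _ in range(d)]   # shared by all clients
    rng = random.Random(f"{base_seed}:{client_id}")       # per-client stream
    for _ in range(num_examples):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        score = sum(w * xi for w, xi in zip(w_true, x))
        y = 1.0 if score >= 0 else -1.0                    # labels in {-1, +1}
        yield x, y

first = list(make_client_examples("client_0", num_examples=3, d=4))
again = list(make_client_examples("client_0", num_examples=3, d=4))
other = list(make_client_examples("client_1", num_examples=3, d=4))
```

Because the data is derived from the client id alone, nothing has to be materialized centrally, and the same client always regenerates the same examples.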

In [13]:
#https://stackoverflow.com/questions/60265798/tff-how-define-tff-simulation-clientdata-from-clients-and-fn-function

## TFF Types

Let's start with a simple float32 type.

In [14]:
FLOAT32_TYPE = tff.TensorType(dtype=tf.float32, shape=())

In [15]:
str(FLOAT32_TYPE)

'float32'

Next, a 1-dimensional tensor (vector) of length 1 with elements of type float32.

In [16]:
FLOAT32_VECTOR_TYPE = tff.TensorType(dtype=tf.float32, shape=(1,))

The local client state $ S_i(t) $ as defined in the unpublished paper.

In [17]:
CLIENT_STATE = tff.FederatedType(FLOAT32_VECTOR_TYPE, tff.CLIENTS)

In [18]:
str(CLIENT_STATE)

'{float32[1]}@CLIENTS'

First, let's define the type of input as a TFF named tuple. Since the size of data batches may vary, we set the batch dimension to None to indicate that the size of this dimension is unknown.

In [19]:

BATCH_SPEC = collections.OrderedDict(
    y=tf.TensorSpec(shape=[None], dtype=tf.float32),
    x=tf.TensorSpec(shape=[None, d], dtype=tf.float32)
)
BATCH_TYPE = tff.to_type(BATCH_SPEC)

In [20]:
str(BATCH_TYPE)

'<y=float32[?],x=float32[?,20]>'

Every client holds a sequence of batches, so we define the client data type as follows

In [21]:

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

In [22]:
str(LOCAL_DATA_TYPE)

'<y=float32[?],x=float32[?,20]>*'

Let's now define the TFF type of the model which is simply a `tf.Variable` with shape (d, 1)

In [23]:

MODEL_TYPE = tff.TensorType(dtype=tf.float32, shape=(d, 1))

In [24]:
str(MODEL_TYPE)

'float32[20,1]'

Since the server holds the 'global' model, we need to create a federated type, defined as the pair of a member (an instance of `tff.Type`) and a placement (where the member components are hosted, for example at `tff.SERVER` or `tff.CLIENTS`).

In [25]:

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)

In [26]:
str(SERVER_MODEL_TYPE)

'float32[20,1]@SERVER'

Following the same logic, we create the federated type of each client's data.

In [27]:

CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

In [28]:
str(CLIENT_DATA_TYPE)

'{<y=float32[?],x=float32[?,20]>*}@CLIENTS'

We will also need to define the client models at the CLIENTS (for FDA later, to be cont...)

In [29]:
CLIENT_MODEL_TYPE = tff.type_at_clients(MODEL_TYPE)

In [30]:
str(CLIENT_MODEL_TYPE)

'{float32[20,1]}@CLIENTS'

## Accuracy Testing

In [31]:

@tf.function
def accuracy(model, dataset):
    
    @tf.function
    def _batch_accuracy(model, batch):
        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for each instance x in x_batch, with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred_batch, y_batch), tf.float32))

        return accuracy
    
    # We take advantage of AutoGraph (convert Python code to TensorFlow-compatible graph code automatically)
    acc, num_batches = 0., 0.
    for batch in dataset:
        acc += _batch_accuracy(model, batch)
        num_batches += 1
        
    acc = acc / num_batches
    
    return acc
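A tiny worked example of the sign-based accuracy rule, written in plain Python so it can be checked by hand. The weights, inputs and labels are made up for illustration, and `batch_accuracy` is a hypothetical stand-in for `_batch_accuracy`:

```python
def sign(v):
    """Sign with sign(0) == 0, the same convention as tf.sign."""
    return (v > 0) - (v < 0)

def batch_accuracy(w, xs, ys):
    """Fraction of rows where sign(w . x) equals the label in {-1, +1}."""
    hits = 0
    for x, y in zip(xs, ys):
        score = sum(wi * xi for wi, xi in zip(w, x))
        hits += sign(score) == y
    return hits / len(ys)

w = [1.0, -1.0]
xs = [[2.0, 1.0], [0.5, 3.0], [1.0, 1.0], [-1.0, 2.0]]
ys = [1, 1, 1, -1]

# Scores are 1.0, -2.5, 0.0 and -3.0: rows 1 and 4 match their labels.
acc = batch_accuracy(w, xs, ys)
print(acc)
```

A score of exactly zero maps to 0 and therefore never matches a ±1 label, so an all-zero model scores 0.0 under this rule. Note also that averaging per-batch accuracies weights every batch equally; with the 10,000 test examples and batch size 16 used here all batches are full, so this equals the overall accuracy.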

In [32]:

@tff.tf_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def accuracy_fn(model, dataset):
    model = tf.Variable(initial_value=model)
    return accuracy(model, dataset)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089



In [33]:
str(accuracy_fn.type_signature)

'(<model=float32[20,1],dataset=<y=float32[?],x=float32[?,20]>*> -> float32)'

# Federated Learning

### Server Update

The server update takes as input the *average* of the clients' models and creates its model as follows

In [34]:

@tff.tf_computation(MODEL_TYPE)
def server_update_fn(mean_client_model):
    model = tf.Variable(initial_value=mean_client_model)
    return model

**Note**: This abstraction is not necessary in this simple notebook, where the model is a single `tf.Variable`. We keep it because it is common practice in general.

### Client train

Each client trains on its own dataset (which is a sequence of batches). Hence, we create the training process, currently a PA-1 Classifier. The inputs of `client_train` are the client's model, materialized on that client, and the client's dataset.

In [35]:

@tf.function
def client_train(model, dataset):
    
    @tf.function
    def _train_on_batch(model, batch, C=0.01):

        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for each instance x in x_batch, with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        # Suffer loss for each prediction (of instance) in the batch with shape=(batchsize,1)
        loss_batch = tf.maximum(0., 1. - tf.multiply(y_batch, weights_dot_x_batch))

        # shape=(batchsize,1) where each instance is ||x||^2, x in x_batch
        norm_batch = tf.expand_dims(tf.reduce_sum(tf.square(x_batch), axis=1), axis=1)

        # PA-1: step size t for each instance x, with shape=(batchsize,1).
        # Note: textbook PA-I uses min(C, loss/||x||^2); this cell uses max, so every step is at least C.
        t_batch = tf.maximum(C, tf.divide(loss_batch, norm_batch))

        # each instance is y*t*x, where y,t scalars and x in x_batch. shape=(batchsize,d)
        t_y_x_batch = tf.multiply(t_batch, tf.multiply(y_batch, x_batch))

        # Average the per-instance updates t*y*x over the batch, with shape=(d,1)
        t_y_x_update = tf.expand_dims(tf.reduce_mean(t_y_x_batch, axis=0) ,axis=1)

        # Update
        model.assign_add(t_y_x_update)
    
    for batch in dataset:
        _train_on_batch(model, batch)
        
    return model
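For reference, here is a minimal sketch of the textbook PA-I rule for a single example, in plain Python. It is not the batch-averaged TensorFlow version above: it caps the step with `min(C, loss / ||x||^2)`, whereas the cell above uses `tf.maximum` and averages the update over the batch. The function name `pa1_step` and the numbers are illustrative assumptions.

```python
def pa1_step(w, x, y, C):
    """One textbook PA-I update for one example with label y in {-1, +1}."""
    margin = y * sum(wi * xi for wi, xi in zip(w, x))
    loss = max(0.0, 1.0 - margin)          # hinge loss
    sq_norm = sum(xi * xi for xi in x)     # ||x||^2
    tau = min(C, loss / sq_norm)           # PA-I caps the step size at C
    new_w = [wi + tau * y * xi for wi, xi in zip(w, x)]
    return new_w, loss, tau

# Aggressive case: the margin is violated, so the model moves (step capped at C).
w1, loss1, tau1 = pa1_step(w=[0.0, 0.0], x=[3.0, 4.0], y=1.0, C=0.01)

# Passive case: margin >= 1, so loss and step are zero and the model is unchanged.
w2, loss2, tau2 = pa1_step(w=[1.0, 0.0], x=[2.0, 0.0], y=1.0, C=0.01)
```

In the first call the uncapped step would be `1 / 25 = 0.04`, so the cap `C = 0.01` is what determines the update.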

# Functional Dynamic Averaging

We follow the Functional Dynamic Averaging (FDA) scheme. Let the mean model be

$$ \overline{w_t} = \frac{1}{k} \sum_{i=1}^{k} w_t^{(i)} $$

where $ w_t^{(i)} $ is the model at time $ t $ in some round in the $i$-th learner.

Local models are trained independently and cooperatively, and we want to monitor the Round Terminating Condition (**RTC**):

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

where the left-hand side is the **model variance** and the threshold $\Theta$ is a hyperparameter of FDA, defined at the beginning of the round; it may change at each round. When the monitoring logic cannot guarantee the validity of the RTC, the round terminates. All local models are pulled into `tff.SERVER`, and $\overline{w_t}$ is set to their average. Then, another round begins.


### Monitoring the RTC

FDA monitors the RTC by applying techniques from [Functional Geometric Averaging](http://users.softnet.tuc.gr/~minos/Papers/edbt19.pdf). We first restate the problem of monitoring the RTC in the standard distributed stream monitoring formulation. Let

$$ S(t) =  \frac{1}{k} \sum_{i=1}^{k} S_i(t) $$

where $ S(t) \in \mathbb{R}^n $ is the "global state" of the system and the $ S_i(t) \in \mathbb{R}^n $ are the "local states". The goal is to monitor a threshold condition on the global state of the form $ F(S(t)) \leq \Theta $, where $ F : \mathbb{R}^n \to \mathbb{R} $ is a non-linear function. Let

$$ \Delta_t^{(i)} = w_t^{(i)} - w_{t_0}^{(i)} $$

be the update at the $ i $-th learner, that is, the change to the local model at time $t$ since the beginning of the current round at time $t_0$. Let the average update be

$$ \overline{\Delta_t} = \frac{1}{k} \sum_{i=1}^{k} \Delta_t^{(i)} $$

Since every learner starts the round from the same synchronized model, it follows that the variance can be written as

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2 = \Big( \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 \Big) - \lVert \overline{\Delta_t} \rVert_2^2 $$

So, conceptually, if we define

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \Delta_t^{(i)}
         \end{bmatrix} \quad \text{and} \quad
         F\left(\begin{bmatrix}
           v \\
           \mathbf{x}
         \end{bmatrix}\right) = v - \lVert \mathbf{x} \rVert_2^2 $$

then the RTC is equivalent to the condition

$$ F(S(t)) \leq \Theta $$
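The variance identity behind this definition can be checked numerically. The sketch below uses plain Python with random vectors; it assumes, as in this notebook, that all learners start the round from the same model, and the sizes `k = 5`, `d = 3` are arbitrary.

```python
import random

def sq_norm(v):
    """Squared Euclidean norm."""
    return sum(a * a for a in v)

def mean_vec(vs):
    """Component-wise mean of a list of equal-length vectors."""
    k = len(vs)
    return [sum(col) / k for col in zip(*vs)]

rng = random.Random(0)
k, d = 5, 3
w0 = [rng.gauss(0, 1) for _ in range(d)]                     # shared round-start model
ws = [[a + rng.gauss(0, 1) for a in w0] for _ in range(k)]   # drifted local models

# Left-hand side: the model variance around the mean model.
w_bar = mean_vec(ws)
variance = sum(sq_norm([a - b for a, b in zip(w, w_bar)]) for w in ws) / k

# Right-hand side: F applied to the averaged local states.
deltas = [[a - b for a, b in zip(w, w0)] for w in ws]
v = sum(sq_norm(dl) for dl in deltas) / k    # first component of S(t)
x = mean_vec(deltas)                         # remaining components of S(t)
F = v - sq_norm(x)

print(abs(variance - F) < 1e-9)
```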

## 1️⃣ Naive FDA

In the naive approach, we eliminate the update vector from the local state (i.e., reduce its dimension to 0). Define the local state as

$$ S_i(t) = \lVert \Delta_t^{(i)} \rVert_2^2 \in \mathbb{R}$$ 

and the identity function

$$ F(v) = v $$

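Since $ \lVert \overline{\Delta_t} \rVert_2^2 \geq 0 $, the model variance is at most $ \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 = F(S(t)) $, so $ F(S(t)) \leq \Theta $ implies the RTC.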
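The implication only goes one way: the naive state can exceed $\Theta$ while the true variance is still small. A hand-checkable sketch in plain Python, with made-up updates in which every client moves in exactly the same direction (`naive_state` and `model_variance` are hypothetical helper names):

```python
def sq_norm(v):
    """Squared Euclidean norm."""
    return sum(a * a for a in v)

def naive_state(deltas):
    """Naive FDA global state: mean of the squared update norms."""
    return sum(sq_norm(dl) for dl in deltas) / len(deltas)

def model_variance(deltas):
    """Exact variance: mean squared norm minus squared norm of the mean update."""
    k = len(deltas)
    mean = [sum(col) / k for col in zip(*deltas)]
    return naive_state(deltas) - sq_norm(mean)

deltas = [[1.0, 2.0]] * 4      # four clients with identical updates
theta = 1.0

s_naive = naive_state(deltas)
var = model_variance(deltas)

print(s_naive <= theta, var <= theta)
```

Here the models have not diverged at all, yet the naive test would end the round. That conservativeness is the price of sending a single scalar per client.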

Using the functions decorated with `tf.function` (which run in the TensorFlow context), we create `client_train_fn` in the TFF context. `client_train_fn` takes as input `last_sync_model` (the model at the last synchronization), `initial_model` (the model the client starts this step from, which is the broadcast server model right after a synchronization) and the client dataset. Notice that each client first creates its own `tf.Variable` from `initial_model`.

In [36]:
# TODO: last synchronized model. Not initial model
@tff.tf_computation(MODEL_TYPE, MODEL_TYPE, LOCAL_DATA_TYPE)
def client_train_fn(last_sync_model, initial_model, dataset):
    
    model = client_train(
        tf.Variable(initial_value=initial_model), 
        dataset
    )
    
    Delta_i = model - last_sync_model # Change since the last synchronization, shape=(d,1)
    S_i = tf.reduce_sum(tf.square(Delta_i), axis=0) # ||Delta_i||^2, shape=(1,)

    return model, S_i

In [37]:
str(client_train_fn.type_signature)

'(<last_sync_model=float32[20,1],initial_model=float32[20,1],dataset=<y=float32[?],x=float32[?,20]>*> -> <float32[20,1],float32[1]>)'

### Training Round

Remember the 4 elements of an FL round:

1. A server-to-client broadcast of the weights.
2. A local client training 'step' on its own data.
3. A client-to-server upload step (returning the trained weights).
4. A server update step.
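The four steps above can be sketched in plain Python, without TFF, to show how the data flows. The local "training" here is a deliberately trivial, hypothetical stand-in (`toy_local_step`), and the server update is an unweighted mean of the client models.

```python
def federated_round(server_model, client_datasets, local_step):
    """Run one round: broadcast, local training, upload, server average."""
    # 1. Broadcast: every client starts from a copy of the server weights.
    client_models = [list(server_model) for _ in client_datasets]
    # 2. Local training on each client's own data.
    client_models = [
        local_step(model, data)
        for model, data in zip(client_models, client_datasets)
    ]
    # 3. Upload and 4. server update: unweighted average of the client models.
    k = len(client_models)
    return [sum(col) / k for col in zip(*client_models)]

def toy_local_step(model, data):
    """Hypothetical training step: move each weight halfway to the data mean."""
    target = sum(data) / len(data)
    return [w + 0.5 * (target - w) for w in model]

new_model = federated_round([0.0, 0.0], [[1.0, 3.0], [5.0, 7.0]], toy_local_step)
print(new_model)
```

The first client ends at `[1.0, 1.0]`, the second at `[3.0, 3.0]`, and the server averages them.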

In [38]:
@tff.federated_computation(CLIENT_MODEL_TYPE)
def server_update(client_models):
    # 4. Compute the mean of the client weights
    mean_client_model = tff.federated_mean(client_models)
    
    # 4. Update the server model
    server_model = tff.federated_map(server_update_fn, mean_client_model)
    
    return server_model

In [39]:
str(server_update.type_signature)

'({float32[20,1]}@CLIENTS -> float32[20,1]@SERVER)'

In [40]:
@tff.federated_computation(CLIENT_MODEL_TYPE, CLIENT_MODEL_TYPE, CLIENT_DATA_TYPE)
def step(last_sync_client_models, client_models, federated_dataset):
    # 2. 3. Train the client models on their respective datasets
    client_models, client_S_i = tff.federated_map(
        client_train_fn, 
        (last_sync_client_models, client_models, federated_dataset)
    )
    
    return client_models, client_S_i

In [41]:
str(step.type_signature)

'(<last_sync_client_models={float32[20,1]}@CLIENTS,client_models={float32[20,1]}@CLIENTS,federated_dataset={<y=float32[?],x=float32[?,20]>*}@CLIENTS> -> <{float32[20,1]}@CLIENTS,{float32[1]}@CLIENTS>)'

In [52]:
@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def synchronize_and_step(server_model, federated_dataset):
    # 1. Broadcast the current server model to the clients
    server_model_at_client = tff.federated_broadcast(server_model)
    
    # 2. 3. Train the client models on their respective datasets
    client_models, client_S_i = tff.federated_map(
        client_train_fn, 
        (server_model_at_client, server_model_at_client, federated_dataset)
    )
    
    return client_models, client_S_i


In [53]:
str(synchronize_and_step.type_signature)

'(<server_model=float32[20,1]@SERVER,federated_dataset={<y=float32[?],x=float32[?,20]>*}@CLIENTS> -> <{float32[20,1]}@CLIENTS,{float32[1]}@CLIENTS>)'

In [44]:

@tff.federated_computation(CLIENT_STATE)
def server_global_state(client_S_i):
    
    server_S = tff.federated_mean(client_S_i)
    
    return server_S

In [45]:
str(server_global_state.type_signature)

'({float32[1]}@CLIENTS -> float32[1]@SERVER)'

In [46]:

@tff.tf_computation(FLOAT32_VECTOR_TYPE, FLOAT32_TYPE)
def RTC_holds(S_t, THETA):
    """ Returns True if RTC holds (has not been defied). False otherwise (sync must happen)"""
    
    @tf.function
    def _F(S_t, THETA):
        """ Naive FDA """
        return S_t <= THETA
    
    return _F(S_t, THETA)


In [47]:
str(RTC_holds.type_signature)

'(<S_t=float32[1],THETA=float32> -> bool[1])'

## Training.

In [79]:
def PRINT_INFO_BEFORE_SYNC(num_rounds, num_steps, global_state, accuracy):
    print(f"---------------------------------- Round {num_rounds} ----------------------------------------------")
    print(f"Steps: {num_steps} , Server Model Accuracy: {accuracy} ,  Global State (S_t): {global_state}")
    
    
def PRINT_INFO_AFTER_SYNC(global_state):    
    print(f"Global State after synchronization and one step: {global_state}")
    print()

In [101]:
# Initial model of zeros (in Python context, to be passed to server)
model = tf.Variable(tf.zeros(shape=(d, 1)), trainable=True, name='weights', dtype=tf.float32)

client_models = [model]*NUM_CLIENTS
last_sync_client_models = [model]*NUM_CLIENTS

In [102]:

train_federated_data = create_federated_data()

In [103]:

test_dataset = create_tf_dataset_for_test()

In [104]:
S_t = tf.constant([float('inf')], dtype=tf.float32) # Force synchronization at the start
THETA = 10

num_rounds = 0 
num_steps = 0 # Counts loop iterations; each one runs exactly one training step
        
while num_rounds < 10:
    
    if RTC_holds(S_t, THETA): # RTC holds, no sync needed
        
        # Perform a training step with the current client_models (no sync yet)
        client_models, client_S_i = step(last_sync_client_models, client_models, train_federated_data)
        
        # Compute 'global state' as defined in the manuscript
        S_t = server_global_state(client_S_i)
    
    else: # RTC defied, sync needed
        
        # Update the server model from the client models.
        model = server_update(client_models)
        
        PRINT_INFO_BEFORE_SYNC(num_rounds, num_steps, S_t, accuracy_fn(model, test_dataset))
        
        # Synchronize client models with server model, and perform a train step
        client_models, client_S_i = synchronize_and_step(model, train_federated_data)
        
        last_sync_client_models = [model]*NUM_CLIENTS
        
        # Compute 'global state' as defined in the manuscript
        S_t = server_global_state(client_S_i)
        
        PRINT_INFO_AFTER_SYNC(S_t)
        
        num_rounds += 1
    
    num_steps += 1

print()
print(f"Total number of steps: {num_steps}")
        

---------------------------------- Round 0 ----------------------------------------------
Steps: 0 , Server Model Accuracy: 0.0 ,  Global State (S_t): [inf]
Global State after synchronization: [0.00190214]

---------------------------------- Round 1 ----------------------------------------------
Steps: 131 , Server Model Accuracy: 0.8648999929428101 ,  Global State (S_t): [10.032373]
Global State after synchronization: [0.00338402]

---------------------------------- Round 2 ----------------------------------------------
Steps: 274 , Server Model Accuracy: 0.8583999872207642 ,  Global State (S_t): [10.082127]
Global State after synchronization: [0.00420818]

---------------------------------- Round 3 ----------------------------------------------
Steps: 411 , Server Model Accuracy: 0.85589998960495 ,  Global State (S_t): [10.002724]
Global State after synchronization: [0.0194706]

---------------------------------- Round 4 ----------------------------------------------
Steps: 541 , Ser
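The control flow of the loop above can be reproduced in a few lines of plain Python, which makes the round and step bookkeeping easy to trace by hand. This is a toy sketch under strong assumptions: models are scalars, each client drifts by a fixed hypothetical amount per step, and `run_naive_fda` is not part of the notebook.

```python
def run_naive_fda(drifts, theta, max_rounds):
    """Toy naive-FDA loop; drifts[i] is client i's change per training step.

    Returns the step counts at which a synchronization happened.
    """
    k = len(drifts)
    clients = [0.0] * k
    last_sync = 0.0
    state = float("inf")        # forces a synchronization on the first pass
    rounds, steps, sync_steps = 0, 0, []
    while rounds < max_rounds:
        if state > theta:       # RTC violated: average and restart from the mean
            server = sum(clients) / k
            clients = [server] * k
            last_sync = server
            sync_steps.append(steps)
            rounds += 1
        # One local step per client, then refresh the naive global state.
        clients = [c + dr for c, dr in zip(clients, drifts)]
        state = sum((c - last_sync) ** 2 for c in clients) / k
        steps += 1
    return sync_steps

sync_steps = run_naive_fda(drifts=[1.0, -1.0], theta=10.0, max_rounds=3)
print(sync_steps)
```

With these drifts the state grows as 1, 4, 9, 16, so the threshold of 10 is crossed after four steps and a synchronization follows on the next pass.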

In [None]:
S_t

In [None]:
num_steps

1. Check correctness of `Delta_i` etc.
2. Fix `C = 0.01`.
3. Comments + check approach (maybe pass the string "Naive FDA" or deduplicate functions).
4. Comments.
5. Think about `while True`.
6. Wrap in `tff.tf_computation` or federated.