In [1]:

import nest_asyncio
nest_asyncio.apply()

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow_federated as tff


In [2]:
tff.__version__

'0.50.0'

## Create Binary Classification data with sklearn

In [3]:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

n = 100000 
d = 20 
noise_factor = 0.05
test_size = 0.1 # % of n

seed = 2

# Create (noisy) testing data for binary classification.
X, y = make_classification(
    n_samples=n, 
    n_features=d,
    n_informative=d,
    n_redundant=0, 
    n_classes=2,
    class_sep=-1,
    flip_y=noise_factor,
    random_state=seed
)

# We will work with label values -1, +1 and not 0, +1 (convert)
y[y == 0] = -1

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)


## Convert to Tensors

In [4]:

# Convert the data to TensorFlow tensors
X_train_tensor = tf.constant(X_train, dtype=tf.float32)
y_train_tensor = tf.constant(y_train, dtype=tf.float32)
X_test_tensor = tf.constant(X_test, dtype=tf.float32)
y_test_tensor = tf.constant(y_test, dtype=tf.float32)

## Prepare data for Tensorflow Federated

We have the training and testing Tensors holding our data. TFF expects for each client an `OrderedDict` containing `y` and `x` data. Hence, we preprocess our Tensors to follow this convention.

In [5]:

NUM_CLIENTS = 3

# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch
BATCH_SIZE = 32
SHUFFLE_BUFFER = int(n / NUM_CLIENTS)
BATCHES_PER_STEP = 1 # Batches per Step, i.e, How many batches until we check RTC

In [6]:
print(f"Total number of batches per client: {int(n / (NUM_CLIENTS*BATCH_SIZE))}")

Total number of batches per client: 1041


In [7]:

import collections

# Create a dictionary with the slices for each client
client_slices_train = {}
slices_test = {}

n_test = int(n - n*test_size)

for i in range(NUM_CLIENTS):
    # Compute the indices for this client's slice
    start_idx = int(i * n_test / NUM_CLIENTS)
    end_idx = int((i + 1) * n_test / NUM_CLIENTS)

    # Get the slice for this client
    X_client_train = X_train_tensor[start_idx:end_idx]
    y_client_train = y_train_tensor[start_idx:end_idx]
    
    client_data_train = collections.OrderedDict([('y', y_client_train), ('x', X_client_train)])
    
    # Combine the slices into a single dataset
    client_slices_train[f'client_{i}'] = client_data_train

slices_test = collections.OrderedDict([('y', y_test_tensor), ('x', X_test_tensor)])

For a sanity check let's see inside `client_slices_train` for the first x,y tuple of the 'first' client

In [8]:
client_slices_train['client_0']['x'][0]

<tf.Tensor: shape=(20,), dtype=float32, numpy=
array([ 2.8520453 ,  4.1043606 ,  0.72331977,  0.630671  ,  1.0472858 ,
       -2.4497042 , -2.0924962 , -2.9751914 , -2.8719494 ,  0.63433236,
        1.8827676 ,  3.525443  , -1.7605263 , -2.8830485 ,  2.4992065 ,
        1.4187863 ,  5.275925  ,  0.7632305 , -3.8827038 , -0.09312054],
      dtype=float32)>

In [9]:
client_slices_train['client_0']['y'][0]

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Now, a client with `client_id` has it's single Tensor holding instances in`client_slices_train[client_id]['x']` and labels in `client_slices_train[client_id]['y']`. Let's take a step back from TFF. Having this data scheme, we can create a client's Tensorflow dataset using `from_tensor_slices` function passing the client's id as follows

In [10]:
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch

def create_tf_dataset_for_client(client_id):
    return tf.data.Dataset.from_tensor_slices(client_slices_train[client_id]) \
        .shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) \
        .take(BATCHES_PER_STEP)

def create_tf_dataset_for_test():
    return tf.data.Dataset.from_tensor_slices(slices_test).batch(BATCH_SIZE)

For TFF we need to construct Federated data for clients, i.e., `tff.simulation.datasets.ClientData`. We can use the `from_clients_and_tf_fn` function that takes as argument the `client_ids` : a list of strings corresponding to client ids, and a `serializable_dataset_fn` : a function that takes a `client_id` from the above list, and returns a `tf.data.Dataset`. It's obvious how we proceed with the code (using the above function)

In [11]:

preprocessed_train_federated_dataset = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
    client_ids=list(client_slices_train.keys()),
    serializable_dataset_fn=lambda client_id: create_tf_dataset_for_client(client_id)
)

In [12]:
preprocessed_train_federated_dataset.client_ids

['client_0', 'client_1', 'client_2']

**Note**: Cross-device federated learning does not use client IDs or perform any tracking of clients. However in simulation experiments using centralized test data the experimenter may select specific clients to be processed per round. The concept of a client ID is only available at the preprocessing stage when preparing input data for the simulation and is not part of the TensorFlow Federated core APIs.

Now, `preprocessed_train_federated_dataset` holds logic on how each client constructs its dataset. Note that so `client_slices_train` has already been materialized and lies in this context's memory.

One way (the simplest) to feed federated data to TFF in a simulation is simply as a Python list, with each element of the list holds the data of an individual client, whether as a list or preferably as a `tf.data.Dataset`. Since we already created an interface that provides the latter we will use it. Here is a helper function that will construct a list of datasets from the set of users.

In [13]:

def create_federated_data():    
    return [
        preprocessed_train_federated_dataset.create_tf_dataset_for_client(client)
        for client in preprocessed_train_federated_dataset.client_ids
    ]

**Important Note**: Firstly, we used `sklearn` to create the binary classification data eagerly, i.e., we were forced to materialize it into memory. In simulation, in general it is more sound to push preprocessing logic into each client, i.e., each client constructs its own dataset (from the same underlying distribution) or reads from a file or something else and he, himself processes the data as needed. This is the best approach and uses the TFF distributed engine the best way. But in our case this was illogical to happen since we are forced to construct the dataset in memory anyway. For example, we could have stored each client's data inside some serialized file (`client_0.tfrecord` for the first client and so on) and push logic where each clients diserializes and processes its own data but this would be silly and slower when testing. For a small example that showcases this scenario see *TFF - Introduction - Federated Core API - Part 3(examples).ipynb*.

## TFF Types

Let's start with a simple float32 type.

In [14]:
FLOAT32_TYPE = tff.TensorType(dtype=tf.float32, shape=())

In [15]:
str(FLOAT32_TYPE)

'float32'

In [16]:
CLIENT_FLOAT32 = tff.FederatedType(FLOAT32_TYPE, tff.CLIENTS)

In [17]:
str(CLIENT_FLOAT32)

'{float32}@CLIENTS'

1-dimensional tensor (vector) of length 1 with elements of type float32

In [18]:
FLOAT32_VECTOR_TYPE = tff.TensorType(dtype=tf.float32, shape=(1,))

The local client state $ S_i(t) $ as defined in the unpublished paper.

In [19]:
CLIENT_STATE = tff.FederatedType(FLOAT32_VECTOR_TYPE, tff.CLIENTS)

In [20]:
str(CLIENT_STATE)

'{float32[1]}@CLIENTS'

First, let's define the type of input as a TFF named tuple. Since the size of data batches may vary, we set the batch dimension to None to indicate that the size of this dimension is unknown.

In [21]:

BATCH_SPEC = collections.OrderedDict(
    y=tf.TensorSpec(shape=[None], dtype=tf.float32),
    x=tf.TensorSpec(shape=[None, d], dtype=tf.float32)
)
BATCH_TYPE = tff.to_type(BATCH_SPEC)

In [22]:
str(BATCH_TYPE)

'<y=float32[?],x=float32[?,20]>'

Every client holds a sequence of batches so the we define the client data type as follows

In [23]:

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

In [24]:
str(LOCAL_DATA_TYPE)

'<y=float32[?],x=float32[?,20]>*'

Let's now define the TFF type of the model which is simply a `tf.Variable` with shape (d, 1)

In [25]:

MODEL_TYPE = tff.TensorType(dtype=tf.float32, shape=(d, 1))

In [26]:
str(MODEL_TYPE)

'float32[20,1]'

Since the server holds the 'global' model we need to create the Federated Type, defined as the tuple of a member: An instance of `tff.Type`, and a placement: The specification of placement of the member comonents (where this type is hosted at, for example, at `tff.SERVER` or `tff.CLIENTS`).

In [27]:

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)

In [28]:
str(SERVER_MODEL_TYPE)

'float32[20,1]@SERVER'

Following, the same logic, we create the Federated Type of each client's data.

In [29]:

CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

In [30]:
str(CLIENT_DATA_TYPE)

'{<y=float32[?],x=float32[?,20]>*}@CLIENTS'

We will also need to define the client models at the CLIENTS (for FDA later, to be cont...)

In [31]:
CLIENT_MODEL_TYPE = tff.type_at_clients(MODEL_TYPE)

In [32]:
str(CLIENT_MODEL_TYPE)

'{float32[20,1]}@CLIENTS'

## Accuracy Testing

In [33]:

@tf.function
def accuracy(model, dataset):
    
    @tf.function
    def _batch_accuracy(model, batch):
        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for the batch (each instance of x in x_batch) with with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred_batch, y_batch), tf.float32))

        return accuracy
    
    # We take advantage of AutoGraph (convert Python code to TensorFlow-compatible graph code automatically)
    acc, num_batches = 0., 0.
    for batch in dataset:
        acc += _batch_accuracy(model, batch)
        num_batches += 1
        
    acc = acc / num_batches
    
    return acc

In [34]:

@tff.tf_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def accuracy_fn(model, dataset):
    model = tf.Variable(initial_value=model)
    return accuracy(model, dataset)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [35]:
str(accuracy_fn.type_signature)

'(<model=float32[20,1],dataset=<y=float32[?],x=float32[?,20]>*> -> float32)'

# Federated Learning

### Server Update

In [36]:

@tff.tf_computation(MODEL_TYPE)
def server_update_fn(clients_aggr_model):
    model = tf.Variable(initial_value=clients_aggr_model)
    return model

**Note**: This abstraction for this simple jupyter (where the model is a `tf.Variable`) is not necessary. We create this abstraction since it is common practice generally.

### Client train - PA-I Classifier

![PA](images/PA_binary_classifiers.png)

Each client trains on its own dataset (which is a sequence of batches). Hence, we create the training process, currently a PA-1 Classifier. The input of `client_train` is the client model materialized inside its client and its dataset.

In [37]:

@tf.function
def client_train(model, C, dataset):
    
    @tf.function
    def _train_on_batch(model, C, batch):

        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for the batch (each instance of x in x_batch) with with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        # Suffer loss for each prediction (of instance) in the batch with shape=(batchsize,1)
        loss_batch = tf.maximum(0., 1. - tf.multiply(y_batch, weights_dot_x_batch))

        # shape=(batchsize,1) where each instance is ||x||^2, x in x_batch
        norm_batch = tf.expand_dims(tf.reduce_sum(tf.square(x_batch), axis=1), axis=1)
        
        # PA-1 : Learning rate t for each instance x, with shape=(batchsize,1)
        t_batch = tf.maximum(C, tf.divide(loss_batch, norm_batch))

        # each instance is y*t*x, where y,t scalars and x in x_batch. shape=(batchsize,d)
        t_y_x_batch = tf.multiply(t_batch, tf.multiply(y_batch, x_batch))

        # !!!! Update with mean t*y*x
        t_y_x_update = tf.expand_dims(tf.reduce_mean(t_y_x_batch, axis=0) ,axis=1)

        # Update
        model.assign_add(t_y_x_update)
    
    for batch in dataset:
        _train_on_batch(model, C, batch)
        
    return model

# Functional Dynamic Averaging

We follow the Functional Dynamic Averaging (FDA) scheme. Let the mean model be

$$ \overline{w_t} = \frac{1}{k} \sum_{i=1}^{k} w_t^{(i)} $$

where $ w_t^{(i)} $ is the model at time $ t $ in some round in the $i$-th learner.

Local models are trained independently and cooperatively and we want to monitor the Round Terminating Conditon (**RTC**):

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

where the left-hand side is the **model variance**, and threshold $\Theta$ is a hyperparameter of the FDA, defined at the beginning of the round; it may change at each round. When the monitoring logic cannot guarantee the validity of RTC, the round terminates. All local models are pulled into `tff.SERVER`, and $\bar{w_t}$ is set to their average. Then, another round begins.


### Monitoring the RTC

FDA monitors the RTC by applying techniques from Functionary [Functional Geometric Averaging](http://users.softnet.tuc.gr/~minos/Papers/edbt19.pdf). We first restate the problem of monitoring RTC into the standard distributed stream monitoring formulation. Let

$$ S(t) =  \frac{1}{k} \sum_{i=1}^{k} S_i(t) $$

where $ S(t) \in \mathbb{R}^n $ be the "global state" of the system and $ S_i(t) \in \mathbb{R}^n $ the "local states". The goal is to monitor the threshold condition on the global state in the form $ F(S(t)) \leq \Theta $ where $ F : \mathbb{R}^n \to \mathbb{R} $ a non-linear function. Let

$$ \Delta_t^{(i)} = w_t^{(i)} - w_{t_0}^{(i)} $$

be the update at the $ i $-th learner, that is, the change to the local model at time $t$ since the beginning of the current round at time $t_0$. Let the average update be

$$ \overline{\Delta_t} = \frac{1}{k} \sum_{i=1}^{k} \Delta_t^{(i)} $$

it follows that the variance can be written as

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2 = \Big( \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 \Big) - \lVert \overline{\Delta_t} \rVert_2^2 $$

So, conceptually, if we define
$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \Delta_t^{(i)}
         \end{bmatrix} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \bf{x}
         \end{bmatrix}) = v - \lVert \bf{x} \rVert_2^2 $$

The RTC is equivalent to condition $$ F(S(t)) \leq \Theta $$

## 2️⃣ Linear FDA

In the linear case, we reduce the update vector to a scalar, $ \xi \Delta_t^{(i)} \in \mathbb{R}$, where $ \xi $ is any unit vector.

Define the local state to be 

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \xi \Delta_t^{(i)}
         \end{bmatrix} \in \mathbb{R}^2 $$

Also, define 

$$ F(v, x) = v - x^2 $$

The RTC is equivalent to condition 

$$ F(S(t)) \leq \Theta $$

A random choice of $ \xi $ is likely to perform poorly (terminate round prematurely), as it wil likely be close to orthogonal to $ \overline{\Delta_t} $. A good choice would be a vector $ \xi $ correlated to $ \overline{\Delta_t} $. A heuristic choice is to take $ \overline{\Delta_{t_0}} $ (after scaling it to norm 1), i.e., the update vector right before the current round started. All nodes can estimate this without communication, as $ \overline{w_{t_0}} - \overline{w_{t_{-1}}} $, the difference of the last two models pushed by the Server. Hence, 

$$ \xi = \overline{w_{t_0}} - \overline{w_{t_{-1}}} $$

In [257]:
@tff.tf_computation(MODEL_TYPE, MODEL_TYPE)
def ksi_unit_fn(w_t, w_tminus1):
    
    @tf.function
    def _ksi_unit(w_t, w_tminus1):
        if tf.reduce_all(tf.equal(w_t, w_tminus1)):
            # if equal then ksi becomes a random vector (will only happen in round 1)
            ksi = tf.random.normal(shape=w_t.shape)
        else:
            ksi = w_t - w_tminus1

        # Normalize and return
        return tf.divide(ksi, tf.norm(ksi))
    
    return _ksi_unit(w_t, w_tminus1)

In [240]:
str(ksi_unit_fn.type_signature)

'(<w_t=float32[20,1],w_tminus1=float32[20,1]> -> float32[20,1])'

Using the functions decorated with `tf.function` (context inside Tensorflow) we create the `client_train_fn` with context inside TFF. 

`initial_model` is the model currently inside each `tff.CLIENT`. This model is different in each CLIENT with the exception in the first step after synchronization.

`last_sync_model` is the synchronized model at the start of the current round. 

`last_last_sync_model` is the synchronized model at the start of the previous round (used for the heuristic for $ \xi $).

In [258]:

@tff.tf_computation(MODEL_TYPE, MODEL_TYPE, MODEL_TYPE, FLOAT32_TYPE, LOCAL_DATA_TYPE)
def client_train_fn(last_last_sync_model, last_sync_model, initial_model, C, dataset):
    
    model = client_train(
        tf.Variable(initial_value=initial_model), C, dataset
    )
    
    Delta_i = model - last_sync_model # AutoGraph
    
    #||D(t)_i||^2 , shape = (1,) 
    Delta_i_norm_squared = tf.reduce_sum(tf.square(Delta_i), axis=0) 
    
    # heuristic unit vector ksi
    ksi = ksi_unit_fn(last_sync_model, last_last_sync_model)
    
    # ksi * Delta_i (* is dot) , shape = ()
    ksi_Delta_i = tf.reduce_sum(tf.multiply(ksi, Delta_i))
    # shape = (1,)
    ksi_Delta_i = tf.expand_dims(ksi_Delta_i, axis=0)
    
    # shape = (2,)
    S_i = tf.concat([Delta_i_norm_squared, ksi_Delta_i], axis=0)
    # shape = (2,1)
    S_i = tf.reshape(S_i, (2, 1))

    return model, S_i, tf.reduce_sum(tf.square(ksi))

In [213]:
str(client_train_fn.type_signature)

'(<last_last_sync_model=float32[20,1],last_sync_model=float32[20,1],initial_model=float32[20,1],C=float32,dataset=<y=float32[?],x=float32[?,20]>*> -> <float32[20,1],float32[2,1]>)'

### HEREEEEEEEEEEEEEEEEEEEE (CORRECT ABOVE ALL) Server Average Client Models

When it is time to synchronize the Clients, the Server averages the Client weights and computes the global model. This is what this function does. Moreover, remember that the Client updates its model using `server_update_fn`.

In [40]:
@tff.federated_computation(CLIENT_MODEL_TYPE)
def server_update(client_models):
    # 4. Compute the mean of the client weights
    mean_client_model = tff.federated_mean(client_models)
    
    # 4. Update the server model
    server_model = tff.federated_map(server_update_fn, mean_client_model)
    
    return server_model

In [41]:
str(server_update.type_signature)

'({float32[20,1]}@CLIENTS -> float32[20,1]@SERVER)'

### Client Learning Step

A Client learning step is an update of the model based on one or more batches. The following `federated_computation` describes the step operation over all the clients. We pass a `federated_dataset` which could be thought as a stream of one or more batches for each client and some more parameters, namely, the last synchronized global model, the current client models, and the parameter `C` of the **PA-I** classifier.

Notice that each of those parameters is placed in `tff.CLIENTS`. 

Moreover, we return the updated Client models `client_models` (think of this as the new state of the distributed system) placed in `tff.CLIENTS` aswell which is logical, each Client updates its own model. Lastly, we return the `client_S_i` as described in the unpublished manuscript ('local state') of each Client, again placed in `tff.CLIENTS`.

Note: Do not think of the `return` as a normal programming `return` statement. Here, we describe the change in state in the distributed system, in this case, solely in `tff.CLIENTS`.

In [42]:
@tff.federated_computation(CLIENT_MODEL_TYPE, CLIENT_MODEL_TYPE, CLIENT_FLOAT32, CLIENT_DATA_TYPE)
def step(last_sync_client_models, client_models, client_C, federated_dataset):
    # 2. 3. Train the client models on their respective datasets
    client_models, client_S_i = tff.federated_map(
        client_train_fn, 
        (last_sync_client_models, client_models, client_C, federated_dataset)
    )
    
    return client_models, client_S_i

In [43]:
str(step.type_signature)

'(<last_sync_client_models={float32[20,1]}@CLIENTS,client_models={float32[20,1]}@CLIENTS,client_C={float32}@CLIENTS,federated_dataset={<y=float32[?],x=float32[?,20]>*}@CLIENTS> -> <{float32[20,1]}@CLIENTS,{float32[1]}@CLIENTS>)'

### Server Computation of 'Global State'

As you saw above, each Client has a property 'local state' `client_S_i`. When each round ends, the server should average those local states to compute the 'global state', i.e., the approximation of the **variance** using the *Naive FDA* scheme. The following `federated_computation` describes exactly that.

In [44]:

@tff.federated_computation(CLIENT_STATE)
def server_global_state(client_S_i):
    
    server_S = tff.federated_mean(client_S_i)
    
    return server_S

In [45]:
str(server_global_state.type_signature)

'({float32[1]}@CLIENTS -> float32[1]@SERVER)'

### Round Terminating Condition

As explained in the theoretical analysis of the *Naive FDA*, when the approximation of **RTC** does not hold, i.e., 

$$ F(S(t)) \gt \Theta $$

we are oblidged to synchronize the Client models since we can no longer guarantee that the variance is bellow the $\Theta$ threshold, i.e., 

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

The following function checks whether we guarantee that the **RTC** holds (`True`) or not (`False`).

In [46]:
# Same for all FDA. (bool)
@tff.tf_computation(FLOAT32_VECTOR_TYPE, FLOAT32_TYPE)
def RTC_holds(S_t, THETA):
    """ Returns True if RTC holds (has not been defied). False otherwise (sync must happen)"""
    
    @tf.function
    def _F(S_t, THETA):
        """ Naive FDA """
        return S_t <= THETA
    
    return _F(S_t, THETA)


In [47]:
str(RTC_holds.type_signature)

'(<S_t=float32[1],THETA=float32> -> bool[1])'

### Help variance function

A helper function that computes the **Actual Variance** when the approximate **RTC** condition does not hold. We want to see how far off the approximation is from reality.

In [48]:
w_spec = tf.TensorSpec(shape=(NUM_CLIENTS, d, 1), dtype=tf.float32)

@tf.function(input_signature=[w_spec, w_spec])
def variance(w_t, w_sync):
    # w_t , w_sync tensors with shape=(NUM_CLIENTS, d, 1)
    
    # tensor with shape=(NUM_CLIENTS, d, 1)
    diff = w_t - w_sync
    
    # tensor with shape=(NUM_CLIENTS, 1) , For each client ||w_i_t - w_t||^2
    dot = tf.reduce_sum(tf.square(diff), axis=1)
    
    # Variance shape=() , scalar
    var = tf.reduce_mean(dot)
    
    return var

## Metrics

A very ugly looking `Metrics` class that helps us with printing and storing the round metrics.

Skip.

In [49]:
class Metrics:
    def __init__(self, NUM_ROUNDS, THETA, NUM_CLIENTS, BATCHES_PER_STEP, BATCH_SIZE, n, d, test_size):
        self.NUM_ROUNDS = NUM_ROUNDS
        self.THETA = THETA
        self.NUM_CLIENTS = NUM_CLIENTS
        self.BATCHES_PER_STEP = BATCHES_PER_STEP
        self.BATCH_SIZE = BATCH_SIZE
        self.n = n
        self.d = d
        self.test_size = test_size
        
        self.total_batches_per_client = int(self.n / (self.NUM_CLIENTS*self.BATCH_SIZE))
        
        self.one_sample_size_b = (self.d+2)*4 # bytes
        
        self.training_dataset_size_mb = self.one_sample_size_b * (n * (1-test_size)) / 1_000_000 # In mb
        
        # Total batches for all clients for a single step
        self.total_batches_per_step = (self.BATCHES_PER_STEP * self.NUM_CLIENTS)
        
        self.samples = int(self.n * (1-self.test_size))
        
        self.all_metrics = []
        
    
    def print_initial_information(self):
        print("FEDERATED SETTING INFO:")
        print("------------------------------------------------------------")

        print("CLIENTS:")
        print(f'{"Clients":<10} {"Batches per Client":<20} {"Batches per Step":<20}')
        print(f'{self.NUM_CLIENTS:<10} {self.total_batches_per_client:<20} {self.BATCHES_PER_STEP:<20}')
        print()

        print("TRAIN DATASET:")
        print(f'{"x-Dim":<6} {"y-classes":<10} {"Samples":<12} {"Dataset size (MB)":<20} {"Samples per Batch":<20}')
        print(f'{d:<6} {2:<10} {int(n * (1-test_size)):<12} {self.training_dataset_size_mb:<20} {self.BATCH_SIZE:<20}')
        print()

        print("ALGORITHM:")
        print(f'{"Name":<5} {"Model Bytes":<10}')
        print(f'{"PA-I":<5} {self.d*4:<10}')
        print()

        print("SYNCHRONIZATION:")
        print('Naive FDA')
        print("------------------------------------------------------------")
        print()
        print()   
    
    def store_and_print_metrics(self, num_round, num_steps, global_state, accuracy, C, var):
        metrics = {}

        metrics['Round'] = num_round
        metrics['Steps'] = num_steps
        metrics['Accuracy'] = accuracy
        metrics['Global State'] = global_state[0]
        metrics['C'] = C
        metrics['Actual Variance'] = var

        # Total samples seen by all clients. BATCH_SIZE = samples per batch
        metrics['Samples'] = self.BATCH_SIZE * (num_steps * self.total_batches_per_step)

        # FDA In each step clients return their S_i (4 bytes)
        local_states_bytes = num_steps * (self.NUM_CLIENTS * 4)
        # Synchronization: Send model to all clients
        sync_bytes = self.d * self.NUM_CLIENTS * 4

        metrics['Bytes Exchanged'] = local_states_bytes + sync_bytes

        self.all_metrics.append(metrics)

        # Print the metrics for the current round
        self.print_round_metrics()
        
    def print_round_metrics(self):
        metrics = self.all_metrics[-1]
        
        # Print the metric values in a nicely formatted table
        print(f'{"Round":<6} {"Steps":<6} {"C":<5} {"Accuracy":<13} {"Bytes Exchanged":<20} {"Samples":<15} {"Var Approx":<15} {"Var (Actual)":<15}')

        print(f"{metrics['Round']:<6} {metrics['Steps']:<6} {metrics['C']:<5} {metrics['Accuracy']:<13.5f} {metrics['Bytes Exchanged']:<20} {metrics['Samples']:<15} {metrics['Global State']:<15.6f} {metrics['Actual Variance']:<15.6f}")
        print()
    
    def print_aggregate_metrics(self):

        total_bytes_exchanged = sum(metrics['Bytes Exchanged'] for metrics in self.all_metrics)
        total_steps = sum(metrics['Steps'] for metrics in self.all_metrics)
        total_samples = sum(metrics['Samples'] for metrics in self.all_metrics)
        final_accuracy = self.all_metrics[-1]['Accuracy']

        # Remember we pass the dataset many times at random (random batches)
        trained_in_size = self.one_sample_size_b * total_samples / 1_000_000 # MB

        print()
        print('FINAL METRICS:')
        print()

        # Print the metric values in a nicely formatted table
        print(f'{"Rounds":<6} {"Steps":<8} {"Samples":<13} {"MB Exchanged":<17} {"Accuracy":<13} {"Trained MB":<15}')

        print(f'{self.NUM_ROUNDS:<6} {total_steps:<8} {total_samples:<13} {total_bytes_exchanged/1_000_000:<17} {final_accuracy:<13.6} {trained_in_size:<20.5f}')
        

## Training Initialization

Create the federated datasets.

In [50]:

train_federated_data = create_federated_data()
test_dataset = create_tf_dataset_for_test()

Initial conditions

In [51]:
NUM_ROUNDS = 10
THETA = 1.
C = 1

We assume that all Clients start in synchronization, i.e., Server model is zeros and Client models are also zeros.

Moreover, notice that `client_models`, `last_sync_client_models`, `client_C` are all defined as lists containing `NUM_CLIENTS` elements. This is the simulation approach of TFF and following the already defined functions above, each element in those lists is assumed to lie in one `tff.CLIENT`. For example, each `tff.CLIENT` has a `C` hyperparameter to be used by **PA-I** classifier on its own model etc.

In [52]:
# Initial model of zeros
model = tf.Variable(tf.zeros(shape=(d, 1)), trainable=True, name='weights', dtype=tf.float32)

# Assume client models are synchronized at the start (Obviously S_t = 0)
client_models = [model]*NUM_CLIENTS
last_sync_client_models = [model]*NUM_CLIENTS
last_last_sync_client_models = [model]*NUM_CLIENTS
S_t = [0.]

client_C = [C]*NUM_CLIENTS

## Training Loop

In [53]:
metrics = Metrics(NUM_ROUNDS, THETA, NUM_CLIENTS, BATCHES_PER_STEP, BATCH_SIZE, n, d, test_size)
metrics.print_initial_information()

for r in range(1, NUM_ROUNDS+1):
    
    num_steps = 0 # Each step() invocation is a step
    
    while RTC_holds(S_t, THETA): # RTC holds, no sync needed
        
        # Perform a training step with the current client_models (no sync yet)
        client_models, client_S_i = step(
            last_sync_client_models, client_models, client_C, train_federated_data
        )
        
        # Compute 'global state' as defined in the manuscript
        S_t = server_global_state(client_S_i)
        
        num_steps += 1
    
    # RTC defied, sync must happen
    
        
    # Update the server model from the client models.
    model = server_update(client_models)
    
    metrics.store_and_print_metrics(r, num_steps, S_t, accuracy_fn(model, test_dataset), client_C[0], variance(client_models, [model]*NUM_CLIENTS))
    
    client_models, last_sync_client_models, last_last_sync_client_models, S_t = [model]*NUM_CLIENTS, [model]*NUM_CLIENTS, last_sync_client_models, [0.]

    
metrics.print_aggregate_metrics()

FEDERATED SETTING INFO:
------------------------------------------------------------
CLIENTS:
Clients    Batches per Client   Batches per Step    
3          1041                 1                   

TRAIN DATASET:
x-Dim  y-classes  Samples      Dataset size (MB)    Samples per Batch   
20     2          90000        7.92                 32                  

ALGORITHM:
Name  Model Bytes
PA-I  80        

SYNCHRONIZATION:
Naive FDA
------------------------------------------------------------


Round  Steps  C     Accuracy      Bytes Exchanged      Samples         Var Approx      Var (Actual)   
1      1      1     0.77935       252                  96              7.515466        2.153086       

Round  Steps  C     Accuracy      Bytes Exchanged      Samples         Var Approx      Var (Actual)   
2      1      1     0.78524       252                  96              13.302562       3.629052       

Round  Steps  C     Accuracy      Bytes Exchanged      Samples         Var Approx     


2. Comments + Check approach (maybe pass string "Naive FDA", unentangle functions)
5. Wrap in tff.tf_computation or federated. X nope
7. shuffle not random seed so we can compare FDA
8. fix awful looking metrics cell