In [1]:

import nest_asyncio
nest_asyncio.apply()

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow_federated as tff


In [2]:
tff.__version__

'0.50.0'

## Create Binary Classification data with sklearn

In [3]:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


n = 100_000
d = 100


noise_factor = 0.01 # % of the labels are randomly flipped, DEFAULT=0.01
test_size = 0.1 # % of n
# The factor multiplying the hypercube size. Larger values spread out the 
# clusters/classes and make the classification task easier. DEFAULT=1
class_sep = -1
seed = 7

# Create (noisy) testing data for binary classification.
X, y = make_classification(
    n_samples=n, 
    n_features=d,
    n_informative=d,
    n_redundant=0, 
    n_classes=2,
    class_sep=class_sep,
    flip_y=noise_factor,
    random_state=seed
)

# We will work with label values -1, +1 and not 0, +1 (convert)
y[y == 0] = -1

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)


In [4]:
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.metrics import accuracy_score

# PA-I regressor from sklearn
pa1 = PassiveAggressiveClassifier(C=0.01, loss="hinge", n_jobs=-1)
pa1.fit(X_train, y_train)

accuracy_score(y_test, pa1.predict(X_test))

0.7658

## Convert to Tensors

In [5]:

# Convert the data to TensorFlow tensors
X_train_tensor = tf.constant(X_train, dtype=tf.float32)
y_train_tensor = tf.constant(y_train, dtype=tf.float32)
X_test_tensor = tf.constant(X_test, dtype=tf.float32)
y_test_tensor = tf.constant(y_test, dtype=tf.float32)

In [6]:
del X, y, X_train, X_test, y_train, y_test

## Prepare data for Tensorflow Federated

We have the training and testing Tensors holding our data. TFF expects for each client an `OrderedDict` containing `y` and `x` data. Hence, we preprocess our Tensors to follow this convention.

In [7]:

NUM_CLIENTS = 20

# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch
BATCH_SIZE = 32
SHUFFLE_BUFFER = int(n / NUM_CLIENTS)
NUM_STEPS_UNTIL_SYNC_CHECK = 1 # Steps until sync check

In [8]:
print(f"Total number of batches per client: {int((1-test_size)*n / (NUM_CLIENTS*BATCH_SIZE))}")

Total number of batches per client: 156


In [9]:
import collections

# Create a dictionary with the slices for each client
client_slices_train = {}
slices_test = {}

n_test = int(n - n*test_size)

for i in range(NUM_CLIENTS):
    # Compute the indices for this client's slice
    start_idx = int(i * n_test / NUM_CLIENTS)
    end_idx = int((i + 1) * n_test / NUM_CLIENTS)

    # Get the slice for this client
    X_client_train = X_train_tensor[start_idx:end_idx]
    y_client_train = y_train_tensor[start_idx:end_idx]
    
    client_data_train = collections.OrderedDict([('y', y_client_train), ('x', X_client_train)])
    
    # Combine the slices into a single dataset
    client_slices_train[f'client_{i}'] = client_data_train

slices_test = collections.OrderedDict([('y', y_test_tensor), ('x', X_test_tensor)])

For a sanity check let's see inside `client_slices_train` for the first x,y tuple of the 'first' client

In [10]:
client_slices_train['client_0']['x'][0]

<tf.Tensor: shape=(100,), dtype=float32, numpy=
array([ -0.43840492,  -8.602735  ,  -2.5706942 ,  -4.9137297 ,
        -6.5534873 ,  -1.493107  ,   1.5496199 ,  -3.7241213 ,
        -1.3533349 ,   6.419472  ,   9.5305    ,  -4.3068776 ,
         1.8209226 ,   3.843456  ,  -6.099927  ,  -2.0994277 ,
         5.0526834 ,   5.215126  ,   0.31975892,  -3.7441716 ,
         6.497558  ,   1.8366643 ,  -2.1913083 ,   9.370149  ,
        -3.4765773 ,  -1.4791905 ,  -6.209484  ,  -9.619827  ,
        12.635862  ,   2.6724894 ,   7.8316813 ,  -4.6290493 ,
         2.1394951 ,   4.2733474 ,   2.9170232 ,   2.5974233 ,
        -0.99408895,   3.4114075 ,   2.2466993 ,   4.0714283 ,
        -3.4346006 , -10.980129  ,   9.790514  ,   4.8795867 ,
        -5.8626986 ,   6.1965513 ,   3.0575798 ,   9.065236  ,
         1.9486036 ,  -9.105302  ,  -0.06869748,  -1.3184999 ,
         4.211022  ,  -3.5095856 ,  -1.2642521 ,  -7.6088433 ,
         4.582711  , -11.008443  ,   0.5270276 ,   3.9419043 ,
       

In [11]:
client_slices_train['client_0']['y'][0]

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Now, a client with `client_id` has it's single Tensor holding instances in`client_slices_train[client_id]['x']` and labels in `client_slices_train[client_id]['y']`. Let's take a step back from TFF. Having this data scheme, we can create a client's Tensorflow dataset using `from_tensor_slices` function passing the client's id as follows

In [12]:
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch

def create_tf_dataset_for_client(client_id):
    return tf.data.Dataset.from_tensor_slices(client_slices_train[client_id]) \
        .shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE) \
        .take(NUM_STEPS_UNTIL_SYNC_CHECK)

def create_tf_dataset_for_test():
    return tf.data.Dataset.from_tensor_slices(slices_test).batch(BATCH_SIZE)

For TFF we need to construct Federated data for clients, i.e., `tff.simulation.datasets.ClientData`. We can use the `from_clients_and_tf_fn` function that takes as argument the `client_ids` : a list of strings corresponding to client ids, and a `serializable_dataset_fn` : a function that takes a `client_id` from the above list, and returns a `tf.data.Dataset`. It's obvious how we proceed with the code (using the above function)

In [13]:

preprocessed_train_federated_dataset = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
    client_ids=list(client_slices_train.keys()),
    serializable_dataset_fn=lambda client_id: create_tf_dataset_for_client(client_id)
)

In [14]:
preprocessed_train_federated_dataset.client_ids

['client_0',
 'client_1',
 'client_2',
 'client_3',
 'client_4',
 'client_5',
 'client_6',
 'client_7',
 'client_8',
 'client_9',
 'client_10',
 'client_11',
 'client_12',
 'client_13',
 'client_14',
 'client_15',
 'client_16',
 'client_17',
 'client_18',
 'client_19']

**Note**: Cross-device federated learning does not use client IDs or perform any tracking of clients. However in simulation experiments using centralized test data the experimenter may select specific clients to be processed per round. The concept of a client ID is only available at the preprocessing stage when preparing input data for the simulation and is not part of the TensorFlow Federated core APIs.

Now, `preprocessed_train_federated_dataset` holds logic on how each client constructs its dataset. Note that so `client_slices_train` has already been materialized and lies in this context's memory.

One way (the simplest) to feed federated data to TFF in a simulation is simply as a Python list, with each element of the list holds the data of an individual client, whether as a list or preferably as a `tf.data.Dataset`. Since we already created an interface that provides the latter we will use it. Here is a helper function that will construct a list of datasets from the set of users.

In [15]:

def create_federated_data():    
    return [
        preprocessed_train_federated_dataset.create_tf_dataset_for_client(client)
        for client in preprocessed_train_federated_dataset.client_ids
    ]

**Important Note**: Firstly, we used `sklearn` to create the binary classification data eagerly, i.e., we were forced to materialize it into memory. In simulation, in general it is more sound to push preprocessing logic into each client, i.e., each client constructs its own dataset (from the same underlying distribution) or reads from a file or something else and he, himself processes the data as needed. This is the best approach and uses the TFF distributed engine the best way. But in our case this was illogical to happen since we are forced to construct the dataset in memory anyway. For example, we could have stored each client's data inside some serialized file (`client_0.tfrecord` for the first client and so on) and push logic where each clients diserializes and processes its own data but this would be silly and slower when testing. For a small example that showcases this scenario see *TFF - Introduction - Federated Core API - Part 3(examples).ipynb*.

## TFF Types

Let's start with a simple float32 type.

In [16]:
FLOAT32_TYPE = tff.TensorType(dtype=tf.float32, shape=())

In [17]:
str(FLOAT32_TYPE)

'float32'

In [18]:
CLIENT_FLOAT32 = tff.FederatedType(FLOAT32_TYPE, tff.CLIENTS)

In [19]:
str(CLIENT_FLOAT32)

'{float32}@CLIENTS'

In [20]:
SERVER_FLOAT32 = tff.FederatedType(FLOAT32_TYPE, tff.SERVER)

In [21]:
str(SERVER_FLOAT32)

'float32@SERVER'

In [22]:
INT32_TYPE = tff.TensorType(dtype=tf.int32, shape=())

In [23]:
str(INT32_TYPE)

'int32'

In [24]:
SERVER_INT32_TYPE = tff.type_at_server(INT32_TYPE)

In [25]:
str(SERVER_INT32_TYPE)

'int32@SERVER'

Create a TFF type representing a float32 tensor of shape (1,)

In [26]:
FLOAT32_1_TYPE = tff.TensorType(dtype=tf.float32, shape=(1,))

In [27]:
str(FLOAT32_1_TYPE)

'float32[1]'

In [28]:
CLIENT_FLOAT32_1_TYPE = tff.type_at_clients(FLOAT32_1_TYPE)

In [29]:
str(CLIENT_FLOAT32_1_TYPE)

'{float32[1]}@CLIENTS'

1-dimensional tensor (vector) of length 1 with elements of type float32

In [30]:
STATE_TYPE = tff.TensorType(dtype=tf.float32, shape=(2,1))

In [31]:
str(STATE_TYPE)

'float32[2,1]'

The local client state $ S_i(t) $ as defined in the unpublished paper for Linear FDA (read theoretical analysis bellow).

In [32]:
CLIENT_STATE = tff.FederatedType(STATE_TYPE, tff.CLIENTS)

In [33]:
str(CLIENT_STATE)

'{float32[2,1]}@CLIENTS'

First, let's define the type of input as a TFF named tuple. Since the size of data batches may vary, we set the batch dimension to None to indicate that the size of this dimension is unknown.

In [34]:

BATCH_SPEC = collections.OrderedDict(
    y=tf.TensorSpec(shape=[None], dtype=tf.float32),
    x=tf.TensorSpec(shape=[None, d], dtype=tf.float32)
)
BATCH_TYPE = tff.to_type(BATCH_SPEC)

In [35]:
str(BATCH_TYPE)

'<y=float32[?],x=float32[?,100]>'

Every client holds a sequence of batches so the we define the client data type as follows

In [36]:

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

In [37]:
str(LOCAL_DATA_TYPE)

'<y=float32[?],x=float32[?,100]>*'

Let's now define the TFF type of the model which is simply a `tf.Variable` with shape (d, 1)

In [38]:

MODEL_TYPE = tff.TensorType(dtype=tf.float32, shape=(d, 1))

In [39]:
str(MODEL_TYPE)

'float32[100,1]'

Since the server holds the 'global' model we need to create the Federated Type, defined as the tuple of a member: An instance of `tff.Type`, and a placement: The specification of placement of the member comonents (where this type is hosted at, for example, at `tff.SERVER` or `tff.CLIENTS`).

In [40]:

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)

In [41]:
str(SERVER_MODEL_TYPE)

'float32[100,1]@SERVER'

Following, the same logic, we create the Federated Type of each client's data.

In [42]:

CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

In [43]:
str(CLIENT_DATA_TYPE)

'{<y=float32[?],x=float32[?,100]>*}@CLIENTS'

We will also need to define the client models at the CLIENTS (for FDA later, to be cont...)

In [44]:
CLIENT_MODEL_TYPE = tff.type_at_clients(MODEL_TYPE)

In [45]:
str(CLIENT_MODEL_TYPE)

'{float32[100,1]}@CLIENTS'

## Accuracy Testing

In [46]:

@tf.function
def accuracy(model, dataset):
    
    @tf.function
    def _batch_accuracy(model, batch):
        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for the batch (each instance of x in x_batch) with with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_pred_batch, y_batch), tf.float32))

        return accuracy
    
    # We take advantage of AutoGraph (convert Python code to TensorFlow-compatible graph code automatically)
    acc, num_batches = 0., 0.
    for batch in dataset:
        acc += _batch_accuracy(model, batch)
        num_batches += 1
        
    acc = acc / num_batches
    
    return acc

In [47]:

@tff.tf_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def accuracy_fn(model, dataset):
    model = tf.Variable(initial_value=model)
    return accuracy(model, dataset)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [48]:
str(accuracy_fn.type_signature)

'(<model=float32[100,1],dataset=<y=float32[?],x=float32[?,100]>*> -> float32)'

# Federated Learning

### Server Update

In [49]:

@tff.tf_computation(MODEL_TYPE)
def server_update_fn(clients_aggr_model):
    model = tf.Variable(initial_value=clients_aggr_model)
    return model

**Note**: This abstraction for this simple jupyter (where the model is a `tf.Variable`) is not necessary. We create this abstraction since it is common practice generally.

### Client train - PA-I Classifier

![PA](images/PA_binary_classifiers.png)

Each client trains on its own dataset (which is a sequence of batches). Hence, we create the training process, currently a PA-1 Classifier. The input of `client_train` is the client model materialized inside its client and its dataset.

In [50]:

@tf.function
def client_train(model, C, dataset):
    
    @tf.function
    def _train_on_batch(model, C, batch):

        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for the batch (each instance of x in x_batch) with with shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Prediction batch with shape=(batchsize, 1)
        y_pred_batch = tf.sign(weights_dot_x_batch)

        # Suffer loss for each prediction (of instance) in the batch with shape=(batchsize,1)
        loss_batch = tf.maximum(0., 1. - tf.multiply(y_batch, weights_dot_x_batch))

        # shape=(batchsize,1) where each instance is ||x||^2, x in x_batch
        norm_batch = tf.expand_dims(tf.reduce_sum(tf.square(x_batch), axis=1), axis=1)
        
        # PA-1 : Learning rate t for each instance x, with shape=(batchsize,1)
        t_batch = tf.maximum(C, tf.divide(loss_batch, norm_batch))

        # each instance is y*t*x, where y,t scalars and x in x_batch. shape=(batchsize,d)
        t_y_x_batch = tf.multiply(t_batch, tf.multiply(y_batch, x_batch))

        # !!!! Update with mean t*y*x
        t_y_x_update = tf.expand_dims(tf.reduce_mean(t_y_x_batch, axis=0) ,axis=1)

        # Update
        model.assign_add(t_y_x_update)
    
    for batch in dataset:
        _train_on_batch(model, C, batch)
        
    return model

# Functional Dynamic Averaging

We follow the Functional Dynamic Averaging (FDA) scheme. Let the mean model be

$$ \overline{w_t} = \frac{1}{k} \sum_{i=1}^{k} w_t^{(i)} $$

where $ w_t^{(i)} $ is the model at time $ t $ in some round in the $i$-th learner.

Local models are trained independently and cooperatively and we want to monitor the Round Terminating Conditon (**RTC**):

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

where the left-hand side is the **model variance**, and threshold $\Theta$ is a hyperparameter of the FDA, defined at the beginning of the round; it may change at each round. When the monitoring logic cannot guarantee the validity of RTC, the round terminates. All local models are pulled into `tff.SERVER`, and $\bar{w_t}$ is set to their average. Then, another round begins.


### Monitoring the RTC

FDA monitors the RTC by applying techniques from Functionary [Functional Geometric Averaging](http://users.softnet.tuc.gr/~minos/Papers/edbt19.pdf). We first restate the problem of monitoring RTC into the standard distributed stream monitoring formulation. Let

$$ S(t) =  \frac{1}{k} \sum_{i=1}^{k} S_i(t) $$

where $ S(t) \in \mathbb{R}^n $ be the "global state", **variance approximation**, of the system and $ S_i(t) \in \mathbb{R}^n $ the "local states". The goal is to monitor the threshold condition on the global state in the form $ F(S(t)) \leq \Theta $ where $ F : \mathbb{R}^n \to \mathbb{R} $ a non-linear function. Let

$$ \Delta_t^{(i)} = w_t^{(i)} - w_{t_0}^{(i)} $$

be the update at the $ i $-th learner, that is, the change to the local model at time $t$ since the beginning of the current round at time $t_0$. Let the average update be

$$ \overline{\Delta_t} = \frac{1}{k} \sum_{i=1}^{k} \Delta_t^{(i)} $$

it follows that the variance can be written as

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2 = \Big( \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 \Big) - \lVert \overline{\Delta_t} \rVert_2^2 $$

So, conceptually, if we define
$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \Delta_t^{(i)}
         \end{bmatrix} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \bf{x}
         \end{bmatrix}) = v - \lVert \bf{x} \rVert_2^2 $$

The RTC is equivalent to condition $$ F(S(t)) \leq \Theta $$

## 3️⃣ Sketch FDA

An optimal estimator for $ \lVert \overline{\Delta_t} \rVert_2^2  $ can be obtained by employing AMS sketches. An AMS sketch of a vector $ v \in \mathbb{R}^M $ is a $ d \times m $ real matrix

$$ \Xi = \text{sk}(v) = \begin{bmatrix}
           \Xi_1 \\
           \Xi_2 \\
           \vdots \\
           \Xi_d 
         \end{bmatrix} $$
         
where $ d \cdot m \ll M$. Operator sk($ \cdot $) is linear, i.e., let $a, b \in \mathbb{R}$ and $v_1, v_2 \in \mathbb{R}^N$ then 

$$ \text{sk}(a v_1 + b v_2) = a \; \text{sk}(v_1) + b \; \text{sk}(v_2)  $$

Also, sk($ v $) can be computed in $ \mathcal{O}(dN) $ steps.

The interesting property of AMS sketches is that

$$ M(sk(v)) = \underset{i=1,...,d}{\text{median}} \; \lVert \Xi_i \rVert_2^2 \; \in (1 \pm \epsilon) \lVert v \rVert_2^2 \; \; \text{with probability at least} \; (1-\delta)$$ 

Let's investigate a little further on how this helps us. The $i$-th client computes $ sk(\Delta_t^{(i)}) $ and sends it to the server. Notice

$$ M\big(sk(\Delta_t^{(1)}) + sk(\Delta_t^{(2)}) + ... + sk(\Delta_t^{(k)}) \big) = M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big)$$

Moreover, we want to approximate

$$ \lVert \overline{\Delta_t} \rVert_2^2 = \frac{1}{k^2} \Big\lVert \sum_{i=1}^{k} \Delta_t^{(i)} \Big\rVert_2^2 $$

Which means that 

$$ \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big) \in (1 \pm \epsilon) \lVert \overline{\Delta_t} \rVert_2^2 \; \; \text{with probability at least} \; (1-\delta) $$ 

In the monitoring process it is essential that we do not overestimate $ \lVert \overline{\Delta_t} \rVert_2^2 $ because we would then underestimate the variance which would potentially result in actual varience exceeding $ \Theta$ without us noticing it. With this in mind,

$$ \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big) \leq (1+\epsilon) \lVert \overline{\Delta_t} \rVert_2^2 \quad \text{with probability at least} \; (1-\delta)$$

Which means

$$ \frac{1}{(1+\epsilon)} \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big) \leq \lVert \overline{\Delta_t} \rVert_2^2 \quad \text{with probability at least} \; (1-\delta)$$

Hence, the Server's estimation of $ \lVert \overline{\Delta_t} \rVert_2^2 $ is

$$ \frac{1}{(1+\epsilon)} \frac{1}{k^2} M\Big( sk(\Delta_t^{(1)}) + sk(\Delta_t^{(2)}) + ... + sk(\Delta_t^{(k)}) \big) \Big) $$

Define the local state to be 

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           sk(\Delta_t^{(i)})
         \end{bmatrix} \in \mathbb{R}^{1+d \times m} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \Xi
         \end{bmatrix}) = v - \frac{1}{(1+\epsilon)} \frac{1}{k^2} M(\Xi) \quad \text{where} \quad \Xi = \sum_{i=1}^{k} sk(\Delta_t^{(i)}) $$

It follows that $ F(S(t)) \leq \Theta $ implies that the variance is less or equal to $ \Theta $ with probability at least $ 1-\delta $.


## AMS sketch

First we define the `depth` ($ d $) and `width` ($ m $) of the sketch.

In [51]:
depth = 7  # number of hash functions
width = 1500  # specifies hash31 : N -> {0, 1, ..., `width`} uniformly.

In [52]:
import numpy as np

c = 1. # big O constant. Just an estimation (potentially way off)

epsilon = c/np.sqrt(width)
delta = c/np.exp(depth)

In [53]:
print(f"ε = {epsilon:<.5}  ,  δ = {delta:<.5}")

ε = 0.02582  ,  δ = 0.00091188


In [54]:
print(f"Promise: M(sk(v)) in [{1-epsilon:<.5}*||v||^2, {1+epsilon:<.5}*||v||^2]  w.p. at least {1-delta:<.5}")

Promise: M(sk(v)) in [0.97418*||v||^2, 1.0258*||v||^2]  w.p. at least 0.99909


### Define TFF types 

In [55]:
SKETCH_TYPE = tff.TensorType(tf.float32, shape=[depth, width])

In [56]:
print(SKETCH_TYPE)

float32[7,1500]


In [57]:
CLIENT_SKETCH_TYPE = tff.type_at_clients(SKETCH_TYPE)

In [58]:
str(CLIENT_SKETCH_TYPE)

'{float32[7,1500]}@CLIENTS'

### Initialize hash functions and Pre-serialize

In [59]:
tf_width = tf.constant(width, dtype=tf.int32)
tf_depth = tf.constant(depth, dtype=tf.int32)

# Pool of three random tuples (A, B) corresponding to a different hash function parameters
# We provide information about pair (F[0], F[1]) , the rest follow this 
# F[0] : shape(depth,) random `a` parameters for each row of the sketch. One row <-> One hash func <-> One `a`
# F[1] : shape(depth,) random `b` parameters for each row of the sketch. One row <-> One hash func <-> One `b`
tf_F = tf.random.uniform(shape=(6, depth), minval=0, maxval=(1 << 31) - 1, dtype=tf.int32)

### Sketch (Deprecated). See next Jupyters TODO: fix this

In [60]:
@tf.function
def sketch_for_vector(v):
    """ Returns AGMS sketch for `v` (vector shape=(n,)). 
    Note: We serialize `F`, `width`, `depth` for efficiency """
    
    F = tf.constant(tf_F)
    width = tf.constant(tf_width)
    depth = tf.constant(tf_depth)

    @tf.function
    def _hash31(x, a, b):
        """ _hash31 : N -> {0, 1, ..., width} uniformly """
        r = a * x + b
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        return tf.bitwise.bitwise_and(fold, 2147483647)
    
    @tf.function
    def _fourwise(x):
        """ Fourwise independent hash of `x` (int) to {+1, -1}. """
        result = 2 * (tf.bitwise.right_shift(tf.bitwise.bitwise_and(_hash31(_hash31(_hash31(x, F[2], F[3]), x, F[4]), x, F[5]), 32768), 15)) - 1
        return result

    sketch = tf.zeros(shape=(depth, width), dtype=tf.float32)
    indices = tf.range(tf.shape(v)[0], dtype=tf.int32)

    for i in indices:
        pos = _hash31(i, F[0], F[1]) % width
        delta = tf.cast(_fourwise(i), dtype=tf.float32) * v[i]
        indices_to_update = tf.stack([tf.range(depth, dtype=tf.int32), pos], axis=1)
        sketch = tf.tensor_scatter_nd_add(sketch, indices_to_update, delta)

    return sketch


@tff.tf_computation(MODEL_TYPE)
def sketch_for_vector_fn(v):
    # we use `.squeeze` to reshape from (n,1) to (n,)
    return sketch_for_vector(tf.squeeze(v, axis=1))

In [61]:
str(sketch_for_vector_fn.type_signature)

'(float32[100,1] -> float32[7,1500])'

### Euclidean Norm Squared estimation given Sketch

In [62]:
@tf.function
def estimate_euc_norm_squared(sketch):
    
    @tf.function
    def _median(v):
        """ Median of tensor `v` with shape=(n,). Note: Suboptimal O(nlogn) but it's ok bcz n = `depth`"""
        length = tf.shape(v)[0]
        sorted_v = tf.sort(v)
        middle = length // 2

        return tf.cond(
            tf.equal(length % 2, 0),
            lambda: (sorted_v[middle - 1] + sorted_v[middle]) / 2.0,
            lambda: sorted_v[middle]
        )
    
    return _median(tf.reduce_sum(tf.square(sketch), axis=1))

### Client train and Local State

As explained clients return

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           sk(\Delta_t^{(i)})
         \end{bmatrix} \in \mathbb{R}^{1+d \times m} \quad \text{and} $$

Using the functions decorated with `tf.function` (context inside Tensorflow) we create the `client_train_fn` with context inside TFF. 

`initial_model` is the model currently inside each `tff.CLIENT`. This model is different in each CLIENT with the exception in the first step after synchronization, and `last_sync_model` is the synchronized model at the start of the current round. 

In [63]:

@tff.tf_computation(MODEL_TYPE, MODEL_TYPE, FLOAT32_TYPE, LOCAL_DATA_TYPE)
def client_train_fn(last_sync_model, initial_model, C, dataset):
    
    model = client_train(
        tf.Variable(initial_value=initial_model), C, dataset
    )
    
    Delta_i = model - last_sync_model # AutoGraph
    
    #||D(t)_i||^2 , shape = (1,) 
    Delta_i_norm_squared = tf.reduce_sum(tf.square(Delta_i), axis=0) 
    
    # sk(D(t)_i) 
    sketch = sketch_for_vector_fn(Delta_i)

    return model, Delta_i_norm_squared, sketch

In [64]:
str(client_train_fn.type_signature)

'(<last_sync_model=float32[100,1],initial_model=float32[100,1],C=float32,dataset=<y=float32[?],x=float32[?,100]>*> -> <float32[100,1],float32[1],float32[7,1500]>)'

### Server Average Client Models

When it is time to synchronize the Clients, the Server averages the Client weights and computes the global model. This is what this function does. Moreover, remember that the Client updates its model using `server_update_fn`.

In [65]:
@tff.federated_computation(CLIENT_MODEL_TYPE)
def server_update(client_models):
    # 4. Compute the mean of the client weights
    mean_client_model = tff.federated_mean(client_models)
    
    # 4. Update the server model
    server_model = tff.federated_map(server_update_fn, mean_client_model)
    
    return server_model

In [66]:
str(server_update.type_signature)

'({float32[100,1]}@CLIENTS -> float32[100,1]@SERVER)'

### Client Learning Step

A Client learning step is an update of the model based on one or more batches. The following `federated_computation` describes the step operation over all the clients. We pass a `federated_dataset` which could be thought as a stream of one or more batches for each client and some more parameters, namely, the last synchronized global model, the current client models, and the parameter `C` of the **PA-I** classifier.

Notice that each of those parameters is placed in `tff.CLIENTS`. 

Moreover, we return the updated Client models `client_models` (think of this as the new state of the distributed system) placed in `tff.CLIENTS` aswell which is logical, each Client updates its own model. Lastly, we return the $S_i(t)$ as described in the unpublished manuscript ('local state') of each Client, again placed in `tff.CLIENTS`.

Note: Do not think of the `return` as a normal programming `return` statement. Here, we describe the change in state in the distributed system, in this case, solely in `tff.CLIENTS`.

In [67]:
@tff.federated_computation(CLIENT_MODEL_TYPE, CLIENT_MODEL_TYPE, CLIENT_FLOAT32, CLIENT_DATA_TYPE)
def step(last_sync_client_models, client_models, client_C, federated_dataset):
    # 2. 3. Train the client models on their respective datasets
    client_models, client_Delta_i_norm_squared, client_sketches = tff.federated_map(
        client_train_fn, 
        (last_sync_client_models, client_models, client_C, federated_dataset)
    )
    
    return client_models, client_Delta_i_norm_squared, client_sketches

In [68]:
str(step.type_signature)

'(<last_sync_client_models={float32[100,1]}@CLIENTS,client_models={float32[100,1]}@CLIENTS,client_C={float32}@CLIENTS,federated_dataset={<y=float32[?],x=float32[?,100]>*}@CLIENTS> -> <{float32[100,1]}@CLIENTS,{float32[1]}@CLIENTS,{float32[7,1500]}@CLIENTS>)'

### Server Computation of 'Global State', i.e, Variance Approximation

For more simplicity in the computations we differ a bit from the theoretical analysis.

Let `server_S_1` = $ \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 $ , and `server_M_S_2` = $ \frac{1}{(1+\epsilon)} \frac{1}{k^2} M\Big( sk(\Delta_t^{(1)}) + sk(\Delta_t^{(2)}) + ... + sk(\Delta_t^{(k)}) \big) \Big) $

Then `server_S_1` - `server_M_S_2` $ \leq \Theta$ implies that the variance is $ \leq \Theta $ with probability at least $ 1-\delta $ (already proven in **FDA Sketch** section).

In [69]:

@tff.federated_computation(CLIENT_FLOAT32_1_TYPE, CLIENT_SKETCH_TYPE, SERVER_INT32_TYPE, SERVER_FLOAT32)
def server_global_state(client_Delta_i_norm_squared, client_sketches, num_clients, epsilon):
    
    @tff.tf_computation(SKETCH_TYPE, INT32_TYPE, FLOAT32_TYPE)
    def var_est(sketch_sum, num_clients, epsilon):
        return (1/(1+epsilon))*tf.cast((1/num_clients**2), dtype=tf.float32) * estimate_euc_norm_squared(sketch_sum)
    
    server_S_1 = tff.federated_mean(client_Delta_i_norm_squared)
    
    server_M_S_2 = tff.federated_map(
        var_est,
        (tff.federated_sum(client_sketches), num_clients, epsilon)
    )
    
    return server_S_1, server_M_S_2

In [70]:
str(server_global_state.type_signature)

'(<client_Delta_i_norm_squared={float32[1]}@CLIENTS,client_sketches={float32[7,1500]}@CLIENTS,num_clients=int32@SERVER,epsilon=float32@SERVER> -> <float32[1]@SERVER,float32@SERVER>)'

### Round Terminating Condition

We check whether we guarantee that the **RTC** holds (`True`) or not (`False`).

Follows from above...

In [71]:
# Same for all FDA. (bool)
@tff.tf_computation(FLOAT32_1_TYPE, FLOAT32_TYPE, FLOAT32_TYPE)
def RTC_holds(S_1, M_S_2, THETA):
    """ Returns True if RTC holds (has not been defied). False otherwise (sync must happen)"""
    
    @tf.function
    def _F(S_1, M_S_2, THETA):
        """ Sketch FDA """
        return S_1 - M_S_2 <= THETA
    
    return _F(S_1, M_S_2, THETA)


In [72]:
str(RTC_holds.type_signature)

'(<S_1=float32[1],M_S_2=float32,THETA=float32> -> bool[1])'

### Help variance function

A helper function that computes the **Actual Variance** when the approximate **RTC** condition does not hold. We want to see how far off the approximation is from reality.

In [73]:
w_spec = tf.TensorSpec(shape=(NUM_CLIENTS, d, 1), dtype=tf.float32)

@tf.function(input_signature=[w_spec, w_spec])
def variance(w_t, w_sync):
    # w_t , w_sync tensors with shape=(NUM_CLIENTS, d, 1)
    
    # tensor with shape=(NUM_CLIENTS, d, 1)
    diff = w_t - w_sync
    
    # tensor with shape=(NUM_CLIENTS, 1) , For each client ||w_i_t - w_t||^2
    dot = tf.reduce_sum(tf.square(diff), axis=1)
    
    # Variance shape=() , scalar
    var = tf.reduce_mean(dot)
    
    return var

## Metrics

A very ugly looking `Metrics` class that helps us with printing and storing the round metrics.

Skip.

In [74]:
# Unentangle num_rounds

class Metrics:
    def __init__(self, name, num_rounds, theta,
                 num_clients, num_steps_until_sync_check, batch_size,
                 n, d, test_size, width, depth):
        
        self.num_rounds = num_rounds
        self.theta = theta
        self.num_clients = num_clients
        self.num_steps_until_sync_check = num_steps_until_sync_check
        self.batch_size = batch_size
        self.n = n
        self.d = d
        self.test_size = test_size
        self.sketch_bytes = width*depth*4
        self.S_i_bytes = self.sketch_bytes + 4
        self.name = name
        
        self.model_bytes = self.d * 4
        
        self.total_batches_per_client = int((1-self.test_size)*self.n / (self.num_clients*self.batch_size))
        
        self.one_sample_size_b = (self.d+1)*4 # bytes
        
        self.training_dataset_size_mb = self.one_sample_size_b * (n * (1-test_size)) / 1_000_000 # In mb
        
        
        self.samples = int(self.n * (1-self.test_size))
        
        self.all_metrics = []
        
    
    def print_initial_information(self):
        print("FEDERATED SETTING INFO:")
        print("------------------------------------------------------------")

        print("CLIENTS:")
        print(f'{"Clients":<10} {"Batches per Client":<20} {"Number of steps until Sync check":<25}')
        print(f'{self.num_clients:<10} {self.total_batches_per_client:<20} {self.num_steps_until_sync_check:<20}')
        print()

        print("TRAIN DATASET:")
        print(f'{"x-Dim":<6} {"y-classes":<10} {"Samples":<12} {"Dataset size (MB)":<20} {"Samples per Batch":<20}')
        print(f'{d:<6} {2:<10} {self.samples:<12} {self.training_dataset_size_mb:<20} {self.batch_size:<20}')
        print()

        print("ALGORITHM:")
        print(f'{"Name":<5} {"Model Bytes":<10}')
        print(f'{"PA-I":<5} {self.model_bytes:<10}')
        print()

        print("SYNCHRONIZATION:")
        print(f'{self.name} FDA monitoring model Variance bellow {self.theta}')
        print(f'Sketch bytes: {self.sketch_bytes}')
        print("------------------------------------------------------------")
        print()
        print()   
    
    def store_and_print_metrics(self, num_round, num_steps, global_state, accuracy, C, var):
        metrics = {}

        metrics['Round'] = num_round
        metrics['Steps'] = num_steps
        metrics['Accuracy'] = accuracy
        metrics['Global State'] = global_state[0]
        metrics['C'] = C
        metrics['Actual Variance'] = var

        # Total samples seen by all clients. batch_size = samples per batch
        metrics['Samples'] = self.batch_size * num_steps * self.num_clients

        assert num_steps % self.num_steps_until_sync_check == 0
        # FDA Every `NUM_STEPS_UNTIL_SYNC_CHECK` clients return their S_i (`width`*`depth`*4+4 bytes)
        local_states_bytes = int(num_steps/self.num_steps_until_sync_check) * self.num_clients * self.S_i_bytes
        # Synchronization: Receive models from clients AND Send model to all clients
        sync_bytes = 2 * self.model_bytes * self.num_clients

        metrics['Bytes Exchanged'] = local_states_bytes + sync_bytes

        self.all_metrics.append(metrics)

        # Print the metrics for the current round
        self.print_round_metrics()
        
    def print_round_metrics(self):
        metrics = self.all_metrics[-1]
        # Print the metric values in a nicely formatted table
        print((
            f'{"Round":<6} {"Steps":<6} {"C":<5} {"Accuracy":<13} '
            f'{"Bytes Exchanged":<20} {"Samples":<15} {"Var Approx":<15} '
            f'{"Var (Actual)":<15}'
        ))

        print((
            f"{metrics['Round']:<6} {metrics['Steps']:<6} {metrics['C']:<5} "
            f"{metrics['Accuracy']:<13.5f} {metrics['Bytes Exchanged']:<20} "
            f"{metrics['Samples']:<15} {metrics['Global State']:<15.6f} "
            f"{metrics['Actual Variance']:<15.6f}"
        ))

        print()
    
    def print_aggregate_metrics(self):

        total_bytes_exchanged = sum(metrics['Bytes Exchanged'] for metrics in self.all_metrics)
        total_mb_exchanged = total_bytes_exchanged/1_000_000
        
        total_steps = sum(metrics['Steps'] for metrics in self.all_metrics)
        total_samples = sum(metrics['Samples'] for metrics in self.all_metrics)
        final_accuracy = self.all_metrics[-1]['Accuracy']

        # Remember we pass the dataset many times at random (random batches)
        trained_in_bytes = self.one_sample_size_b * total_samples 
        trained_in_mb = trained_in_bytes / 1_000_000 # MB
        
        # total communication due to model synchronization
        model_sync_bytes_exchanged = self.num_rounds * (2 * self.model_bytes * self.num_clients)
        model_sync_mb_exchanged = model_sync_bytes_exchanged / 1_000_000
        
        # total communication due to monitoring
        monitoring_bytes_exchanged = total_bytes_exchanged - model_sync_bytes_exchanged
        monitoring_mb_exchanged = monitoring_bytes_exchanged / 1_000_000
        
        print("------------------------------------------------------------")

        print()
        print('FINAL METRICS:')
        print("------------------------------------------------------------")
        print()

        print('TRAINING:')
        print((
            f'{"Rounds":<7} {"Steps":<9} {"Samples":<13} {"Trained MB":<20}'
            f'{"Final Accuracy":<25}'
        ))
        print((
            f'{self.num_rounds:<7} {total_steps:<9} {total_samples:<13} {trained_in_mb:<20.5f}'
            f'{final_accuracy:<25.6}'
        ))
        print()
        
        print('COMMUNICATION:')
        print(f'{"Total MB Exchanged":<25} {"Model Sync MB Exchanged":<25} {"Monitoring MB Exchanged":<25}')
        print(f'{total_mb_exchanged:<25} {model_sync_mb_exchanged:<25} {monitoring_mb_exchanged:<25}')
            
        

## Training Initialization

Create the federated datasets.

In [75]:

train_federated_data = create_federated_data()
test_dataset = create_tf_dataset_for_test()

Initial conditions

In [76]:
NUM_ROUNDS = 15
THETA = 2.
C = 0.01

We assume that all Clients start in synchronization, i.e., Server model is zeros and Client models are also zeros.

Moreover, notice that `client_models`, `last_sync_client_models`, `client_C` are all defined as lists containing `NUM_CLIENTS` elements. This is the simulation approach of TFF and following the already defined functions above, each element in those lists is assumed to lie in one `tff.CLIENT`. For example, each `tff.CLIENT` has a `C` hyperparameter to be used by **PA-I** classifier on its own model etc.

In [77]:
# Initial model of zeros
model = tf.Variable(tf.zeros(shape=(d, 1)), trainable=True, name='weights', dtype=tf.float32)

# Assume client models are synchronized at the start (Obviously S_t = 0)
client_models = [model]*NUM_CLIENTS
last_sync_client_models = [model]*NUM_CLIENTS
client_C = [C]*NUM_CLIENTS
S_1 = tf.zeros(shape=(1,), dtype=tf.float32)
S_2 = tf.zeros(shape=(), dtype=tf.float32)

## Training Loop

In [78]:
metrics = Metrics('Sketch', NUM_ROUNDS, THETA, NUM_CLIENTS, NUM_STEPS_UNTIL_SYNC_CHECK, BATCH_SIZE, n, d, test_size, width, depth)
metrics.print_initial_information()

FEDERATED SETTING INFO:
------------------------------------------------------------
CLIENTS:
Clients    Batches per Client   Number of steps until Sync check
20         156                  10                  

TRAIN DATASET:
x-Dim  y-classes  Samples      Dataset size (MB)    Samples per Batch   
100    2          90000        36.36                32                  

ALGORITHM:
Name  Model Bytes
PA-I  400       

SYNCHRONIZATION:
Sketch FDA monitoring model Variance bellow 2.0
Sketch bytes: 42000
------------------------------------------------------------




In [79]:
metrics.print_initial_information()

for r in range(1, NUM_ROUNDS+1):
    
    num_steps = 0 # Each step() invocation is `NUM_STEPS_UNTIL_SYNC_CHECK` number of theoretical steps
    
    while RTC_holds(S_1, S_2, THETA): # RTC holds, no sync needed
        
        # Perform a training step with the current client_models (no sync yet)
        # Note: We train for `NUM_STEPS_UNTIL_SYNC_CHECK` batches inside `step` in order to let TF optimize.
        #       It is the same to `step` for one batch `NUM_STEPS_UNTIL_SYNC_CHECK` number of times.
        client_models, client_Delta_i_norm_squared, client_sketches = step(
            last_sync_client_models, client_models, client_C, train_federated_data
        )
        
        # Compute 'global state' Approx Variance as defined in the manuscript
        S_1, S_2 = server_global_state(client_Delta_i_norm_squared, client_sketches, NUM_CLIENTS, epsilon)
        
        # because we train for `NUM_STEPS_UNTIL_SYNC_CHECK` batches inside `step`
        num_steps += 1*NUM_STEPS_UNTIL_SYNC_CHECK
    
    # RTC defied, sync must happen
    
    # Update the server model from the client models.
    model = server_update(client_models)
    
    metrics.store_and_print_metrics(r, num_steps, S_1 - S_2, accuracy_fn(model, test_dataset), client_C[0], variance(client_models, [model]*NUM_CLIENTS))
    
    client_models, last_sync_client_models, S_1, S_2 = [model]*NUM_CLIENTS, [model]*NUM_CLIENTS, tf.zeros(shape=(1,), dtype=tf.float32), tf.zeros(shape=(), dtype=tf.float32)
    
    
metrics.print_aggregate_metrics()

FEDERATED SETTING INFO:
------------------------------------------------------------
CLIENTS:
Clients    Batches per Client   Number of steps until Sync check
20         156                  10                  

TRAIN DATASET:
x-Dim  y-classes  Samples      Dataset size (MB)    Samples per Batch   
100    2          90000        36.36                32                  

ALGORITHM:
Name  Model Bytes
PA-I  400       

SYNCHRONIZATION:
Sketch FDA monitoring model Variance bellow 2.0
Sketch bytes: 42000
------------------------------------------------------------


Round  Steps  C     Accuracy      Bytes Exchanged      Samples         Var Approx      Var (Actual)   
1      90     0.01  0.81540       7576720              57600           2.048988        1.425875       

Round  Steps  C     Accuracy      Bytes Exchanged      Samples         Var Approx      Var (Actual)   
2      90     0.01  0.81749       7576720              57600           2.022470        1.453180       

Round  Steps  C 

# TODO: 
1. Metrics class becomes .py imported
1. unused types
3. types consistency. *_TYPE @SERVER/CLIENTS
4. show initial metrics before training loop
5. beamer presentation
6. plots think about it. (maybe first create relationship between hyperparameters)

# training loop (testing)

In [84]:
"""
metrics = Metrics('Sketch', NUM_ROUNDS, THETA, NUM_CLIENTS, NUM_STEPS_UNTIL_SYNC_CHECK, BATCH_SIZE, n, d, test_size, width*depth*4)
metrics.print_initial_information()

for r in range(1, NUM_ROUNDS+1):
    
    num_steps = 0 # Each step() invocation is a step
    
    while RTC_holds(S_1, S_2, THETA): # RTC holds, no sync needed
        
        # Perform a training step with the current client_models (no sync yet)
        client_models, client_Delta_i_norm_squared, client_sketches, client_D_i = step(
            last_sync_client_models, client_models, client_C, train_federated_data
        )
        
        # Compute 'global state' Approx Variance as defined in the manuscript
        S_1, S_2 = server_global_state(client_Delta_i_norm_squared, client_sketches, NUM_CLIENTS, epsilon)
        
        client_D_i_tf = tf.constant(client_D_i, shape=(NUM_CLIENTS, d)) # CORR shape=(NUM_CLIENTS, d)
        D_i_mean_tf = tf.reduce_mean(client_D_i_tf, axis=0) # CORR shape=(d,)
        D_i_mean_euc_sq = tf.reduce_sum(tf.square(D_i_mean_tf)) # CORR (right-hand-side)
        
        D_i_mean_euc_sq_est = S_2
        
        sum_D_i_k = tf.reduce_mean(tf.reduce_sum(tf.square(client_D_i_tf), axis=1)) # CORR (left-hand-side)
        
        print(f"Actual sketch val: {D_i_mean_euc_sq} ,  Sketch est: {D_i_mean_euc_sq_est}")
        print(f"Overestimate: {D_i_mean_euc_sq <= D_i_mean_euc_sq_est}")
        print(f"Left-hand-side: {sum_D_i_k}")
        print(f"Actual Var: {sum_D_i_k-D_i_mean_euc_sq} , Sketch var: {sum_D_i_k-D_i_mean_euc_sq_est}")
        print()
        
        num_steps += 1
    
    # RTC defied, sync must happen
    
        
    # Update the server model from the client models.
    model = server_update(client_models)
    
    metrics.store_and_print_metrics(r, num_steps, S_1 - S_2, accuracy_fn(model, test_dataset), client_C[0], variance(client_models, [model]*NUM_CLIENTS))
    
    client_models, last_sync_client_models, S_1, S_2 = [model]*NUM_CLIENTS, [model]*NUM_CLIENTS, tf.zeros(shape=(1,), dtype=tf.float32), tf.zeros(shape=(), dtype=tf.float32)

    
metrics.print_aggregate_metrics()
"""
print("uncomment")

uncomment
