In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import pandas as pd

## Import MNIST data

In [2]:
# NOTE: this loads the original MNIST digits (10 classes), not EMNIST.
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixels from uint8 [0, 255] to [0, 1]. Casting to float32 *before* dividing
# avoids numpy's default float64 result, which would double memory use and not
# match the model's float32 weights.
X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0

In [3]:
# Number of training examples (size of the first axis of the training images).
n_train = X_train.shape[0]

In [4]:
# Input shape handed to `Model.build` for the MNIST images: (batch, height, width).
# The leading None leaves the batch dimension unspecified, since it varies.
CNN_BATCH_INPUT = (None, 28, 28)
# Per-example shape produced by the models' Reshape layer: a trailing channel axis
# (1 = grayscale) is appended so Conv2D receives (height, width, channels).
CNN_INPUT_RESHAPE = (28, 28, 1)

## Prepare data for Federated Learning

In [5]:
# Wrap the arrays in tf.data pipelines. The test split is batched once here for
# evaluation. The training split stays unbatched because it is sharded per client
# and then shuffled/batched per client further below.
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(256)
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))

# From here on only the Dataset objects are used; drop the raw array names.
del X_train, y_train, X_test, y_test

### Slice the Tensors for each Client

In [6]:
def create_federated_data_for_clients(num_clients, dataset=None):
    """Split a dataset into `num_clients` shards, one per client.

    `Dataset.shard(n, i)` keeps the elements whose index is congruent to i modulo n,
    so the shards are disjoint and cover the whole dataset.

    Args:
        num_clients: number of clients (= number of shards). 0 yields an empty list.
        dataset: object exposing `.shard(num_shards, index)` (e.g. a `tf.data.Dataset`).
            Defaults to the module-level `train_dataset`.

    Returns:
        List with `num_clients` entries; entry i is client i's shard.
    """
    if dataset is None:
        # Looked up at call time so callers that omit `dataset` keep the old behavior.
        dataset = train_dataset

    return [dataset.shard(num_clients, i) for i in range(num_clients)]

### Prepare (and repeat) Client Datasets - shuffling, repeating, batching, prefetching

Proper use of `.prefetch` [explained](https://stackoverflow.com/questions/63796936/what-is-the-proper-use-of-tensorflow-dataset-prefetch-and-cache-options).

Proper ordering of `.shuffle`, `.batch`, and `.repeat` [explained](https://stackoverflow.com/questions/50437234/tensorflow-dataset-shuffle-then-batch-or-batch-then-shuffle).

In [7]:
def prepare_federated_data_for_test(federated_data, batch_size, num_steps_until_rtc_check, seed=None,
                                    shuffle_size=512):
    """Build each client's input pipeline for one monitoring interval.

    Every client dataset is shuffled (buffer of `shuffle_size` elements), repeated
    indefinitely, batched, truncated to `num_steps_until_rtc_check` batches and
    prefetched. Iterating one prepared dataset therefore yields the local training
    steps to run before the next RTC check (assuming the client's shard is non-empty).

    Args:
        federated_data: iterable of per-client `tf.data.Dataset`s (unbatched).
        batch_size: number of examples per batch.
        num_steps_until_rtc_check: number of batches (local steps) each dataset yields.
        seed: shuffle seed; the same seed is used for every client.
        shuffle_size: shuffle buffer size (previously hard-coded to 512, still the default).

    Returns:
        List of prepared datasets, in the same order as `federated_data`.
    """
    def process_client_dataset(client_dataset):
        # Shuffle before repeat so each pass over the shard is reshuffled and epoch
        # boundaries are preserved; batch after repeat so batches can span epochs.
        return (
            client_dataset
            .shuffle(shuffle_size, seed=seed)
            .repeat()
            .batch(batch_size)
            .take(num_steps_until_rtc_check)
            .prefetch(tf.data.AUTOTUNE)
        )

    return [process_client_dataset(client_dataset) for client_dataset in federated_data]

# Miscellaneous

TODO: Metrics Time series + Regular

## Neural Net weights

In [8]:
def count_weights(model):
    """Return the total number of trainable scalar parameters over `model.layers`.

    Each trainable weight contributes the product of its shape's dimensions.
    """
    per_layer_counts = (
        sum(int(np.prod(weight.shape)) for weight in layer.trainable_weights)
        for layer in model.layers
    )
    return int(sum(per_layer_counts))

## Resetting NN weights for the Server and Clients

In [9]:
def reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables):
    """Load `starting_trainable_variables` into the server model and into every client model.

    The server is updated first, then the clients in iteration order.
    """
    def _all_models():
        yield server_cnn
        yield from client_cnns

    for model in _all_models():
        model.set_trainable_variables(starting_trainable_variables)

# Simple Convolutional Neural Net (CNN) - Medium Size

A simple Convolutional Neural Network with a single convolutional layer, followed by a max-pooling layer, and two dense layers for classification. Designed for 28x28 grayscale images. It has 693,962 trainable weights (692,352 of them in the first dense layer).

In [10]:
class CNN(tf.keras.Model):
    """Small CNN for 28x28 grayscale images.

    Architecture: Reshape (adds channel axis) -> Conv2D(32, 3x3, relu) -> MaxPool(2x2)
    -> Flatten -> Dense(128, relu) -> Dense(num_classes, softmax).

    Besides the standard Keras API, it exposes a minimal manual training loop
    (`step` / `train`) and helpers to copy and flatten its trainable variables,
    which the server-client synchronization code relies on.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.reshape = layers.Reshape(CNN_INPUT_RESHAPE)
        self.conv1 = layers.Conv2D(32, 3, activation='relu')
        self.max_pool = layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dense2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Forward pass; returns class probabilities (softmax output).

        `training` is accepted for Keras API compatibility; no layer in this model
        behaves differently between training and inference.
        """
        x = self.reshape(inputs)  # add a trailing channel dimension
        x = self.conv1(x)
        x = self.max_pool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

    def step(self, batch):
        """Run one optimization step on a single `(x_batch, y_batch)` batch.

        Uses the loss and optimizer configured via `compile()`. Compiled metrics
        are not updated here.
        """
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            y_batch_pred = self(x_batch, training=True)
            # Loss configured in `compile()`, plus any regularization losses.
            # NOTE(review): Keras 3 deprecates `compiled_loss` in favour of
            # `compute_loss`; revisit if upgrading.
            loss = self.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=self.losses,
            )

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    def train(self, dataset):
        """Run `step` once per batch of `dataset` (its length sets the number of steps)."""
        for batch in dataset:
            self.step(batch)

    def set_trainable_variables(self, trainable_vars):
        """Assign `trainable_vars` (in `self.trainable_variables` order) to this model.

        Raises:
            ValueError: if the number of given tensors differs from the number of
                trainable variables. Without this check `zip` would silently stop at
                the shorter sequence and leave some variables unsynchronized.
        """
        trainable_vars = list(trainable_vars)
        model_vars = self.trainable_variables
        if len(trainable_vars) != len(model_vars):
            raise ValueError(
                f"Expected {len(model_vars)} trainable tensors, got {len(trainable_vars)}."
            )
        for model_var, value in zip(model_vars, trainable_vars):
            model_var.assign(value)

    def trainable_vars_as_vector(self):
        """Return all trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

### Helper function to compile and return the CNN

In [11]:
def get_compiled_and_built_cnn():
    """Create, compile and build a `CNN`.

    Each call creates a fresh Adam optimizer, so models never share optimizer state.
    Building with `CNN_BATCH_INPUT` creates the weights immediately, so
    `trainable_variables` can be read or copied right after this returns.
    """
    cnn = CNN()

    cnn.compile(
        optimizer=keras.optimizers.Adam(),
        # The model ends in a softmax, so predictions are probabilities, not logits.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')],
    )

    # MNIST input shape (None, 28, 28); None leaves the batch size unspecified.
    cnn.build(CNN_BATCH_INPUT)

    return cnn

# Advanced Convolutional Neural Net (CNN) - Large Size

A more complex Convolutional Neural Network with three sets of two convolutional layers, each set followed by a max-pooling layer, then two hidden dense layers with dropout and a softmax output layer for classification. Designed for 28x28 grayscale images. It has 2,592,202 trainable weights.

In [12]:
class AdvancedCNN(tf.keras.Model):
    """Deeper CNN for 28x28 grayscale images.

    Architecture: three blocks of two 3x3 'same'-padded Conv2D layers (64, 128 and
    256 filters), each block followed by 2x2 max-pooling; then Dense(512) ->
    Dropout(0.5) -> Dense(512) -> Dropout(0.5) -> Dense(num_classes, softmax).

    Exposes the same manual training loop (`step` / `train`) and variable
    copy/flatten helpers as `CNN`, for server-client synchronization.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        self.reshape = layers.Reshape(CNN_INPUT_RESHAPE)

        self.conv1 = layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.conv2 = layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.max_pool1 = layers.MaxPooling2D(pool_size=(2, 2))

        self.conv3 = layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.conv4 = layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.max_pool2 = layers.MaxPooling2D(pool_size=(2, 2))

        self.conv5 = layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.conv6 = layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.max_pool3 = layers.MaxPooling2D(pool_size=(2, 2))

        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(512, activation='relu')
        self.dropout1 = layers.Dropout(0.5)
        self.dense2 = layers.Dense(512, activation='relu')
        self.dropout2 = layers.Dropout(0.5)
        self.dense3 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Forward pass; returns class probabilities (softmax output).

        `training` is forwarded to the dropout layers, which are only active when it is True.
        """
        x = self.reshape(inputs)  # add a trailing channel dimension

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.max_pool1(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.max_pool2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.max_pool3(x)

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.dropout2(x, training=training)
        x = self.dense3(x)
        return x

    def step(self, batch):
        """Run one optimization step on a single `(x_batch, y_batch)` batch.

        Uses the loss and optimizer configured via `compile()`. Compiled metrics
        are not updated here.
        """
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # training=True enables dropout for this forward pass.
            y_batch_pred = self(x_batch, training=True)
            # Loss configured in `compile()`, plus any regularization losses.
            # NOTE(review): Keras 3 deprecates `compiled_loss` in favour of
            # `compute_loss`; revisit if upgrading.
            loss = self.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=self.losses,
            )

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    def train(self, dataset):
        """Run `step` once per batch of `dataset` (its length sets the number of steps)."""
        for batch in dataset:
            self.step(batch)

    def set_trainable_variables(self, trainable_vars):
        """Assign `trainable_vars` (in `self.trainable_variables` order) to this model.

        Raises:
            ValueError: if the number of given tensors differs from the number of
                trainable variables. Without this check `zip` would silently stop at
                the shorter sequence and leave some variables unsynchronized.
        """
        trainable_vars = list(trainable_vars)
        model_vars = self.trainable_variables
        if len(trainable_vars) != len(model_vars):
            raise ValueError(
                f"Expected {len(model_vars)} trainable tensors, got {len(trainable_vars)}."
            )
        for model_var, value in zip(model_vars, trainable_vars):
            model_var.assign(value)

    def trainable_vars_as_vector(self):
        """Return all trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)
    

### Helper function to compile and return the CNN

In [13]:
def get_compiled_and_built_advanced_cnn():
    """Create, compile and build an `AdvancedCNN`.

    Each call creates a fresh Adam optimizer, so models never share optimizer state.
    Building with `CNN_BATCH_INPUT` creates the weights immediately, so
    `trainable_variables` can be read or copied right after this returns.
    """
    advanced_cnn = AdvancedCNN()

    advanced_cnn.compile(
        optimizer=keras.optimizers.Adam(),
        # The model ends in a softmax, so predictions are probabilities, not logits.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')],
    )

    # MNIST input shape (None, 28, 28); None leaves the batch size unspecified.
    advanced_cnn.build(CNN_BATCH_INPUT)

    return advanced_cnn

### Average NN weights

In [14]:
def average_client_weights(client_models):
    """Element-wise mean of the clients' trainable variables.

    Returns a list with one tensor per trainable variable, in the order of each
    model's `trainable_variables`; entry j is the mean over clients of variable j.
    """
    per_client_vars = [model.trainable_variables for model in client_models]

    averaged = []
    # zip(*...) groups the j-th variable of every client together.
    for same_variable_across_clients in zip(*per_client_vars):
        averaged.append(tf.reduce_mean(same_variable_across_clients, axis=0))
    return averaged

### Server - Clients synchronization

In [15]:
def synchronize(server_cnn, client_cnns):
    """Average the clients' weights into the server, then copy the server's weights to every client.

    Afterwards the server and all clients hold the same trainable variables
    (assuming every model shares the same architecture).
    """
    averaged_weights = average_client_weights(client_cnns)
    server_cnn.set_trainable_variables(averaged_weights)

    # Broadcast the freshly averaged server weights back to the clients.
    for cnn in client_cnns:
        cnn.set_trainable_variables(server_cnn.trainable_variables)

# Functional Dynamic Averaging

We follow the Functional Dynamic Averaging (FDA) scheme. Let the mean model be

$$ \overline{w_t} = \frac{1}{k} \sum_{i=1}^{k} w_t^{(i)} $$

where $ w_t^{(i)} $ is the model of the $i$-th learner at time $ t $ within some round, and $ k $ is the number of learners.

Local models are trained independently and cooperatively, and we want to monitor the Round Terminating Condition (**RTC**):

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

where the left-hand side is the **model variance** and the threshold $\Theta$ is a hyperparameter of FDA, defined at the beginning of each round (it may change from round to round). When the monitoring logic can no longer guarantee that the RTC holds, the round terminates: all local models are pulled to the server (`tff.SERVER`), $\overline{w_t}$ is set to their average, and another round begins.

### Monitoring the RTC

FDA monitors the RTC by applying techniques from [Functional Geometric Monitoring](http://users.softnet.tuc.gr/~minos/Papers/edbt19.pdf). We first restate the problem of monitoring the RTC in the standard distributed stream monitoring formulation. Let

$$ S(t) =  \frac{1}{k} \sum_{i=1}^{k} S_i(t) $$

where $ S(t) \in \mathbb{R}^n $ is the "global state" of the system and $ S_i(t) \in \mathbb{R}^n $ are the "local states". The goal is to monitor a threshold condition on the global state of the form $ F(S(t)) \leq \Theta $, where $ F : \mathbb{R}^n \to \mathbb{R} $ is a non-linear function. Let

$$ \Delta_t^{(i)} = w_t^{(i)} - w_{t_0}^{(i)} $$

be the update at the $ i $-th learner, that is, the change to the local model at time $t$ since the beginning of the current round at time $t_0$. Let the average update be

$$ \overline{\Delta_t} = \frac{1}{k} \sum_{i=1}^{k} \Delta_t^{(i)} $$

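Since all learners start the round from the same synchronized model (i.e., $ w_{t_0}^{(i)} = \overline{w_{t_0}} $ for every $ i $), we have $ w_t^{(i)} - \overline{w_t} = \Delta_t^{(i)} - \overline{\Delta_t} $, and it follows that the variance can be written as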

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2 = \Big( \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 \Big) - \lVert \overline{\Delta_t} \rVert_2^2 $$

So, conceptually, if we define
$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \Delta_t^{(i)}
         \end{bmatrix} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \bf{x}
         \end{bmatrix}) = v - \lVert \bf{x} \rVert_2^2 $$

then the RTC is equivalent to the condition $$ F(S(t)) \leq \Theta $$

## 1️⃣ Naive FDA

In the naive approach, we eliminate the update vector from the local state (i.e., reduce its dimension to 0). Define the local state as

$$ S_i(t) = \lVert \Delta_t^{(i)} \rVert_2^2 \in \mathbb{R}$$ 

and let $ F $ be the identity function

$$ F(v) = v $$

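Since the variance equals $ \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 - \lVert \overline{\Delta_t} \rVert_2^2 \leq \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 = F(S(t)) $, the condition $ F(S(t)) \leq \Theta $ implies the RTC.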

### Client Train

The number of local training steps is determined by the dataset, i.e., by the `.take(num)` call made when the `tf.data.Dataset` is prepared.

In [54]:
def client_train_naive(last_sync_cnn, client_cnn, client_dataset):
    """Train one client locally and return its naive FDA local state.

    The local state is ||Delta_t^(i)||^2: the squared Euclidean norm of the change
    in the client's weights since the last synchronization (`last_sync_cnn`).

    Returns Tensor shape=() dtype=tf.float32
    """
    # The number of local steps is fixed by the `.take(...)` applied to `client_dataset`.
    client_cnn.train(client_dataset)

    weights_now = client_cnn.trainable_vars_as_vector()
    weights_at_sync = last_sync_cnn.trainable_vars_as_vector()
    drift = weights_now - weights_at_sync

    return tf.reduce_sum(tf.square(drift))

### Train all Clients

We now create the function that trains all client CNNs. It is wrapped in `@tf.function` so that TensorFlow traces it into a graph; since the clients are trained independently of each other, their computations can run in parallel.

In [55]:
@tf.function
def clients_train_naive(last_sync_cnn, client_cnns, federated_dataset):
    """Train every client on its own dataset and collect the naive local states.

    Client i is paired with dataset i; the result holds one local state per client,
    in the same order.

    Returns list of Tensors shape=() dtype=tf.float32
    """
    # Python-side print: under tf.function it runs only while tracing, so each
    # message marks a (re)trace.
    print("retrace `clients_train_naive`")

    return [
        client_train_naive(last_sync_cnn, cnn, dataset)
        for cnn, dataset in zip(client_cnns, federated_dataset)
    ]

### Identity Function

In [56]:
def F_naive(S_i_clients):
    """Naive FDA monitoring function: F(S(t)) = S(t), where S(t) is the mean local state.

    Args:
        S_i_clients: the clients' local states ||Delta_t^(i)||^2, one scalar tensor each.

    Returns Tensor shape=() dtype=tf.float32
    """
    # F is the identity, so its value is simply the average of the local states.
    return tf.reduce_mean(S_i_clients)

### Training Loop

In [32]:
def testing_stuff(num_clients=5, batch_size=32, num_steps_until_rtc_check=1, seed=7):
    """Build a small FDA experiment: a server model, client models and client datasets.

    All clients start from the server's initial weights. The defaults reproduce the
    previously hard-coded settings.

    Args:
        num_clients: number of clients (and data shards).
        batch_size: batch size of every client dataset.
        num_steps_until_rtc_check: local steps per client between RTC checks.
        seed: shuffle seed passed to the client dataset pipelines.

    Returns:
        Tuple `(server_cnn, client_cnns, federated_dataset)`.
    """
    clients_federated_data = create_federated_data_for_clients(num_clients)

    server_cnn = get_compiled_and_built_advanced_cnn()
    client_cnns = [get_compiled_and_built_advanced_cnn() for _ in range(num_clients)]

    # Broadcast the server's initialization to every client. Averaging the clients'
    # independent random initializations (what `synchronize` would do here) shrinks
    # the weights' spread by roughly sqrt(num_clients) and yields a poorly scaled
    # starting model.
    for client_cnn in client_cnns:
        client_cnn.set_trainable_variables(server_cnn.trainable_variables)

    federated_dataset = prepare_federated_data_for_test(
        federated_data=clients_federated_data,
        batch_size=batch_size,
        num_steps_until_rtc_check=num_steps_until_rtc_check,
        seed=seed,
    )

    return server_cnn, client_cnns, federated_dataset