In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import pandas as pd

## Import MNIST data

In [2]:
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] as float32. Dividing the uint8 arrays by 255.0 directly
# yields float64 arrays: twice the memory, a different dtype than the float32
# Keras layers compute in, and inconsistent with the 4-bytes-per-pixel accounting
# used by `create_metrics_dict`.
X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0

In [3]:
# Number of training samples; captured before the NumPy arrays are deleted below.
n_train = X_train.shape[0]

In [4]:
# Per-sample image shape with an explicit channel axis, as Conv2D expects.
CNN_INPUT_RESHAPE = (28, 28, 1)
# Input shape passed to `Model.build` for 28x28 grayscale images; `None` leaves
# the batch size unspecified because it varies.
CNN_BATCH_INPUT = (None, *CNN_INPUT_RESHAPE[:2])

## Prepare data for Federated Learning

In [5]:
# Convert to TensorFlow Datasets
# The training set stays unbatched so it can be sharded per client later;
# the test set is batched once here for evaluation.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices((X_test, y_test))
    .batch(256)
)

# Later cells only use the tf.data datasets (and `n_train`), so drop the NumPy names.
del X_train, y_train, X_test, y_test

### Slice the Tensors for each Client

In [6]:
def create_federated_data_for_clients(num_clients, dataset=None):
    """Split a dataset into `num_clients` disjoint shards, one per client.

    `Dataset.shard(n, i)` keeps every n-th element starting at index i, so the
    shards are disjoint and interleaved, and their sizes differ by at most one.

    Args:
        num_clients: Number of clients (shards) to create.
        dataset: Unbatched `tf.data.Dataset` to split. Defaults to the global
            `train_dataset`, which was the only option before.

    Returns:
        A list of `num_clients` datasets; element `i` is client `i`'s data.
    """
    if dataset is None:
        dataset = train_dataset

    return [dataset.shard(num_clients, index) for index in range(num_clients)]

### Prepare (and restart) Client Dataset - shuffling, batching, prefetching

Proper use of `.prefetch` [explained](https://stackoverflow.com/questions/63796936/what-is-the-proper-use-of-tensorflow-dataset-prefetch-and-cache-options).

Proper ordering of `.shuffle`, `.batch`, and `.repeat` [explained](https://stackoverflow.com/questions/50437234/tensorflow-dataset-shuffle-then-batch-or-batch-then-shuffle).

In [7]:
def prepare_federated_data_for_test(federated_data, batch_size, num_steps_until_rtc_check, seed=None,
                                    shuffle_size=512):
    """Build each client's input pipeline for the local steps between two RTC checks.

    Each client dataset is shuffled, repeated indefinitely, batched, truncated to
    `num_steps_until_rtc_check` batches, and prefetched. One full pass over a
    returned dataset therefore yields exactly `num_steps_until_rtc_check` batches,
    i.e., the local SGD steps that make up one FDA step.

    Args:
        federated_data: Iterable of unbatched per-client `tf.data.Dataset`s.
        batch_size: Samples per batch.
        num_steps_until_rtc_check: Batches yielded by one pass over each dataset.
        seed: Optional shuffle seed (the same seed is used for every client).
        shuffle_size: Shuffle buffer size. Defaults to 512, the value that was
            previously hard-coded.

    Returns:
        A list of prepared datasets, in the same order as `federated_data`.

    Note:
        Every new pass restarts the pipeline at the beginning of the client's
        shard, and `CNN.train` starts a new pass (`for batch in dataset`) on every
        call. Within a pass the shuffle buffer can only draw from roughly the first
        `shuffle_size + batch_size * num_steps_until_rtc_check` samples of the
        shard, so samples beyond that prefix are never used. Increase
        `shuffle_size` towards the shard size for full data coverage, at the cost
        of refilling a larger buffer on every pass.
    """

    def _prepare(client_dataset):
        # shuffle -> repeat -> batch: reshuffle the stream, never run out of data,
        # and allow batches to span the repeat boundary.
        pipeline = client_dataset.shuffle(shuffle_size, seed=seed).repeat().batch(batch_size)
        return pipeline.take(num_steps_until_rtc_check).prefetch(tf.data.AUTOTUNE)

    return [_prepare(client_dataset) for client_dataset in federated_data]

# Miscellaneous

## Variance

In [8]:
@tf.function
def variance(cnn_list, cnn_sync):
    """Mean squared Euclidean distance between each model in `cnn_list` and `cnn_sync`.

    When `cnn_sync` holds the average of the models in `cnn_list`, this is the
    model variance (1/k) * sum_i ||w_i - w_bar||^2 that the RTC bounds.

    Args:
        cnn_list: Models exposing `trainable_vars_as_vector()`.
        cnn_sync: Reference model exposing `trainable_vars_as_vector()`.

    Returns:
        A scalar tensor with the mean of the squared distances.
    """
    # The reference vector is the same for every model: flatten it once instead
    # of once per client (it can hold millions of parameters).
    sync_vector = cnn_sync.trainable_vars_as_vector()

    squared_distances = [
        tf.reduce_sum(tf.square(cnn.trainable_vars_as_vector() - sync_vector))
        for cnn in cnn_list
    ]

    return tf.reduce_mean(squared_distances)

## Metrics

In [9]:
def create_metrics_dict(fda_name, n_train, dataset_name, input_pixels, seed, epochs, num_clients, 
                        batch_size, steps_in_one_fda_step, theta, total_fda_steps, num_weights,
                        total_rounds, final_accuracy, sketch_width=-1, sketch_depth=-1):
    """Collect a run's configuration together with derived cost/communication figures.

    All byte counts assume 4-byte (float32) values:
      * one sample is `input_pixels` pixels plus one label;
      * the model is `num_weights` values;
      * one client's local state is 4 bytes for "naive" (one scalar), 8 bytes for
        "linear" (two scalars), and otherwise a `sketch_width x sketch_depth`
        sketch plus one scalar.

    Communication:
      * synchronization: every round each client uploads and downloads the model;
      * monitoring: every FDA step each client sends its local state.

    Returns:
        dict with the configuration keys first, then the derived metrics (the key
        order is fixed).
    """
    bytes_per_value = 4  # float32

    sample_bytes = bytes_per_value * (input_pixels + 1)
    model_size_bytes = num_weights * bytes_per_value

    if fda_name == "naive":
        state_bytes = 4
    elif fda_name == "linear":
        state_bytes = 8
    else:
        state_bytes = sketch_width * sketch_depth * 4 + 4

    # One FDA step consists of `steps_in_one_fda_step` local SGD (batch) steps.
    step_count = total_fda_steps * steps_in_one_fda_step

    sync_bytes = total_rounds * model_size_bytes * num_clients * 2
    monitor_bytes = state_bytes * total_fda_steps * num_clients

    return {
        "fda_name": fda_name,
        "theta": theta,
        "dataset_name": dataset_name,
        "input_pixels": input_pixels,
        "n_train": n_train,
        "num_weights": num_weights,
        "seed": seed,
        "epochs": epochs,
        "num_clients": num_clients,
        "batch_size": batch_size,
        "steps_in_one_fda_step": steps_in_one_fda_step,
        "sketch_width": sketch_width,
        "sketch_depth": sketch_depth,
        "one_sample_bytes": sample_bytes,
        "training_dataset_bytes": sample_bytes * n_train,
        "model_bytes": model_size_bytes,
        "local_state_bytes": state_bytes,
        "final_accuracy": final_accuracy,
        "total_fda_steps": total_fda_steps,
        "total_steps": step_count,
        "total_rounds": total_rounds,
        "model_bytes_exchanged": sync_bytes,
        "monitoring_bytes_exchanged": monitor_bytes,
        "total_communication_bytes": sync_bytes + monitor_bytes,
        "trained_in_bytes": batch_size * sample_bytes * step_count * num_clients,
    }

## Time-Series Metrics

In [10]:
def create_time_series_metrics(time_series_data, dataset_name, fda_name, epochs, num_clients, batch_size, 
                               steps_in_one_fda_step, theta, num_weights, seed=None, sketch_width=-1, 
                               sketch_depth=-1):
    """Convert per-round monitoring records into a list of dicts keyed by the run configuration.

    Args:
        time_series_data: Iterable of `(total_fda_steps, estimated_var, actual_var)`
            records, one per round and in round order (e.g. the stacked tensor
            returned by the simulation loops). `total_fda_steps` is converted with
            `int()`; the two variances are stored unchanged.
        Remaining arguments: run configuration, used only to build the key.

    Returns:
        `{run_key: records}` where `run_key` is the tuple `(dataset_name, fda_name,
        epochs, num_clients, batch_size, steps_in_one_fda_step, theta, num_weights,
        seed, sketch_width, sketch_depth)` and each record has the keys `round`,
        `total_fda_steps`, `total_steps`, `actual_var` and `estimated_var`.
    """
    run_key = (
        dataset_name, fda_name, epochs, num_clients, batch_size,
        steps_in_one_fda_step, theta, num_weights, seed, sketch_width, sketch_depth
    )

    def _to_record(round_index, record):
        fda_steps, est_var, act_var = record
        fda_steps = int(fda_steps)
        return {
            "round": round_index,
            "total_fda_steps": fda_steps,
            # each FDA step is `steps_in_one_fda_step` local SGD steps
            "total_steps": fda_steps * steps_in_one_fda_step,
            "actual_var": act_var,
            "estimated_var": est_var,
        }

    records = [_to_record(round_index, record) for round_index, record in enumerate(time_series_data)]
    return {run_key: records}

## Neural Net weights

In [11]:
def count_weights(model):
    """Return the number of trainable scalar parameters over all layers in `model.layers`.

    Each trainable weight contributes the product of its shape dimensions.
    Weights set directly on the model (outside `model.layers`) are not counted.
    """
    per_layer_totals = (
        sum(np.prod(weight.shape) for weight in layer.trainable_weights)
        for layer in model.layers
    )
    return int(sum(per_layer_totals))

## Resetting NN weights for Server-Clients

In [12]:
def reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables):
    """Load `starting_trainable_variables` into the server model and then into every client.

    Models are updated in order: the server first, then each client in `client_cnns`.
    Returns None.
    """
    for model in (server_cnn, *client_cnns):
        model.set_trainable_variables(starting_trainable_variables)

# Simple Convolutional Neural Net (CNN) - Medium Size

A simple Convolutional Neural Network with a single convolutional layer, followed by a max-pooling layer, and two dense layers for classification. Designed for 28x28 grayscale images. It has 693,962 trainable weights (692,352 of them in the first dense layer).

In [13]:
class CNN(tf.keras.Model):
    """Small CNN classifier used as the federated model.

    Architecture: Reshape to `CNN_INPUT_RESHAPE` -> Conv2D(32 filters, 3x3, ReLU)
    -> MaxPool(2x2) -> Flatten -> Dense(128, ReLU) -> Dense(`num_classes`, softmax).
    The output is softmax probabilities, not logits.

    Besides the Keras forward pass, it provides the helpers used by the FDA
    simulation: one optimizer `step`, a `train` loop over a dataset, and
    utilities to overwrite / flatten the trainable variables.
    """
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.reshape = layers.Reshape(CNN_INPUT_RESHAPE)
        self.conv1 = layers.Conv2D(32, 3, activation='relu')
        self.max_pool = layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dense2 = layers.Dense(num_classes, activation='softmax')

        
    # Defines the computation from inputs to outputs.
    # `training` is accepted for Keras compatibility but is not used: no layer
    # here behaves differently between training and inference.
    def call(self, inputs, training=None):
        x = self.reshape(inputs)  # Add a channel dimension
        x = self.conv1(x)
        x = self.max_pool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
    
    
    @tf.function
    def step(self, batch):
        """Apply one optimizer update on a single `(x_batch, y_batch)` batch.

        The loss is the one configured in `compile()` plus `self.losses`
        (regularization terms). Nothing is returned.
        """
        
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            y_batch_pred = self(x_batch, training=True)

            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=self.losses
            )

        # Compute gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Apply gradients to the model's trainable variables (update weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Metric updates are disabled (the line below is commented out), so the
        # compiled metrics are not updated during training.
        #self.compiled_metrics.update_state(y_batch, y_batch_pred)
    
    
    @tf.function
    def train(self, dataset):
        """Run `step` on every batch of `dataset`, i.e., one full pass over it.

        Each call iterates `dataset` from its beginning, so the number of steps
        is whatever one pass yields (e.g., bounded by `.take(n)`).
        """

        for batch in dataset:
            self.step(batch)
            
    
    def set_trainable_variables(self, trainable_vars):
        """Assign `trainable_vars` to `self.trainable_variables`, position by position.

        `trainable_vars` must list values in the same order and with the same
        shapes as `self.trainable_variables`. `zip` stops at the shorter of the
        two lists, so a length mismatch is not reported.
        """
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

            
    def trainable_vars_as_vector(self):
        """Return all trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

### Helper function to compile and return the CNN

In [14]:
def get_compiled_and_built_cnn():
    """Create a `CNN`, compile it for training, and create its weights.

    Uses Adam, sparse categorical cross-entropy on probabilities (the model ends
    in a softmax), and a sparse categorical accuracy metric named 'test_accuracy'.

    Returns:
        The compiled `CNN`, built for inputs of shape `CNN_BATCH_INPUT`.
    """
    model = CNN()

    training_config = {
        "optimizer": keras.optimizers.Adam(),
        # from_logits=False because the network already applies softmax
        "loss": keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        "metrics": [keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')],
    }
    model.compile(**training_config)

    # Create the variables now; the batch dimension is left unspecified (None).
    model.build(CNN_BATCH_INPUT)

    return model

# Advanced Convolutional Neural Net (CNN) - Large Size

A more complex Convolutional Neural Network with three sets of two convolutional layers, each followed by a max-pooling layer, and two dense layers with dropout for classification. Designed for 28x28 grayscale images. It has 2,592,202 weights.

In [15]:
class AdvancedCNN(tf.keras.Model):
    """Larger CNN classifier used as an alternative federated model.

    Architecture: Reshape to `CNN_INPUT_RESHAPE`, then three blocks of
    [Conv2D(3x3, 'same', ReLU) x2 -> MaxPool(2x2)] with 64, 128, and 256 filters,
    then Flatten -> Dense(512, ReLU) -> Dropout(0.5) -> Dense(512, ReLU) ->
    Dropout(0.5) -> Dense(`num_classes`, softmax). The output is softmax
    probabilities, not logits.

    It exposes the same simulation helpers as `CNN` (`step`, `train`,
    `set_trainable_variables`, `trainable_vars_as_vector`).
    """
    def __init__(self, num_classes=10):
        super(AdvancedCNN, self).__init__()
        
        self.reshape = layers.Reshape(CNN_INPUT_RESHAPE)
        
        self.conv1 = layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.conv2 = layers.Conv2D(64, kernel_size=3, activation='relu', padding='same')
        self.max_pool1 = layers.MaxPooling2D(pool_size=(2, 2))
        
        self.conv3 = layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.conv4 = layers.Conv2D(128, kernel_size=3, activation='relu', padding='same')
        self.max_pool2 = layers.MaxPooling2D(pool_size=(2, 2))
        
        self.conv5 = layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.conv6 = layers.Conv2D(256, kernel_size=3, activation='relu', padding='same')
        self.max_pool3 = layers.MaxPooling2D(pool_size=(2, 2))

        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(512, activation='relu')
        self.dropout1 = layers.Dropout(0.5)
        self.dense2 = layers.Dense(512, activation='relu')
        self.dropout2 = layers.Dropout(0.5)
        self.dense3 = layers.Dense(num_classes, activation='softmax')

    # Forward pass. `training` is forwarded to the dropout layers so that
    # dropout is applied only when training=True.
    def call(self, inputs, training=None):
        x = self.reshape(inputs)  # Add a channel dimension
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.max_pool1(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.max_pool2(x)

        x = self.conv5(x)
        x = self.conv6(x)
        x = self.max_pool3(x)

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dropout1(x, training=training)
        x = self.dense2(x)
        x = self.dropout2(x, training=training)
        x = self.dense3(x)
        return x
    
    
    @tf.function
    def step(self, batch):
        """Apply one optimizer update on a single `(x_batch, y_batch)` batch.

        The forward pass runs with training=True (dropout active). The loss is
        the one configured in `compile()` plus `self.losses`. Nothing is returned.
        """
        
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            y_batch_pred = self(x_batch, training=True)

            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=self.losses
            )

        # Compute gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Apply gradients to the model's trainable variables (update weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Metric updates are disabled (the line below is commented out), so the
        # compiled metrics are not updated during training.
        #self.compiled_metrics.update_state(y_batch, y_batch_pred)
    
    
    @tf.function
    def train(self, dataset):
        """Run `step` on every batch of `dataset`, i.e., one full pass over it.

        Each call iterates `dataset` from its beginning, so the number of steps
        is whatever one pass yields (e.g., bounded by `.take(n)`).
        """

        for batch in dataset:
            self.step(batch)
            
    
    def set_trainable_variables(self, trainable_vars):
        """Assign `trainable_vars` to `self.trainable_variables`, position by position.

        `trainable_vars` must list values in the same order and with the same
        shapes as `self.trainable_variables`. `zip` stops at the shorter of the
        two lists, so a length mismatch is not reported.
        """
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

            
    def trainable_vars_as_vector(self):
        """Return all trainable variables flattened and concatenated into one 1-D tensor."""
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)
    

### Helper function to compile and return the Advanced CNN

In [16]:
def get_compiled_and_built_advanced_cnn():
    """Create an `AdvancedCNN`, compile it for training, and create its weights.

    Uses Adam, sparse categorical cross-entropy on probabilities (the model ends
    in a softmax), and a sparse categorical accuracy metric named 'test_accuracy'.

    Returns:
        The compiled `AdvancedCNN`, built for inputs of shape `CNN_BATCH_INPUT`.
    """
    model = AdvancedCNN()

    training_config = {
        "optimizer": keras.optimizers.Adam(),
        # from_logits=False because the network already applies softmax
        "loss": keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        "metrics": [keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')],
    }
    model.compile(**training_config)

    # Create the variables now; the batch dimension is left unspecified (None).
    model.build(CNN_BATCH_INPUT)

    return model

### Average NN weights

In [17]:
def average_client_weights(client_models):
    """Average the clients' trainable variables element-wise, one variable at a time.

    Args:
        client_models: Models whose `trainable_variables` lists are expected to
            have the same length, order, and shapes.

    Returns:
        A list of tensors, one per trainable variable, each being the mean of
        that variable over all clients.
    """
    averaged = []
    # zip(*...) groups the i-th trainable variable of every client together.
    for same_position_vars in zip(*(model.trainable_variables for model in client_models)):
        averaged.append(tf.reduce_mean(same_position_vars, axis=0))
    return averaged

### Server - Clients synchronization

In [18]:
@tf.function
def synchronize(server_cnn, client_cnns):
    """Set the server model to the clients' average, then copy it back to every client."""
    server_cnn.set_trainable_variables(average_client_weights(client_cnns))

    global_vars = server_cnn.trainable_variables
    for client in client_cnns:
        client.set_trainable_variables(global_vars)

# Functional Dynamic Averaging

We follow the Functional Dynamic Averaging (FDA) scheme. Let the mean model be

$$ \overline{w_t} = \frac{1}{k} \sum_{i=1}^{k} w_t^{(i)} $$

where $ w_t^{(i)} $ is the model at time $ t $ in some round in the $i$-th learner.

Local models are trained independently and cooperatively, and we want to monitor the Round Terminating Condition (**RTC**):

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2  \leq \Theta $$

where the left-hand side is the **model variance**, and the threshold $\Theta$ is a hyperparameter of FDA, defined at the beginning of each round (it may change from round to round). When the monitoring logic can no longer guarantee that the RTC holds, the round terminates: all local models are pulled to the server, and $\overline{w_t}$ is set to their average. Then another round begins.

### Monitoring the RTC

FDA monitors the RTC by applying techniques from [Functional Geometric Monitoring](http://users.softnet.tuc.gr/~minos/Papers/edbt19.pdf). We first restate the problem of monitoring the RTC in the standard distributed stream monitoring formulation. Let

$$ S(t) =  \frac{1}{k} \sum_{i=1}^{k} S_i(t) $$

where $ S(t) \in \mathbb{R}^n $ is the "global state" of the system and $ S_i(t) \in \mathbb{R}^n $ are the "local states". The goal is to monitor a threshold condition on the global state of the form $ F(S(t)) \leq \Theta $, where $ F : \mathbb{R}^n \to \mathbb{R} $ is a non-linear function. Let

$$ \Delta_t^{(i)} = w_t^{(i)} - w_{t_0}^{(i)} $$

be the update at the $ i $-th learner, that is, the change to the local model at time $t$ since the beginning of the current round at time $t_0$. Let the average update be

$$ \overline{\Delta_t} = \frac{1}{k} \sum_{i=1}^{k} \Delta_t^{(i)} $$

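since every learner starts the round from the same model, i.e., $ w_{t_0}^{(i)} = \overline{w_{t_0}} $ for all $ i $, we have $ w_t^{(i)} - \overline{w_t} = \Delta_t^{(i)} - \overline{\Delta_t} $, and it follows that the variance can be written as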

$$ \frac{1}{k} \sum_{i=1}^{k} \lVert w_t^{(i)} - \overline{w_t} \rVert_2^2 = \Big( \frac{1}{k} \sum_{i=1}^{k} \lVert \Delta_t^{(i)} \rVert_2^2 \Big) - \lVert \overline{\Delta_t} \rVert_2^2 $$

So, conceptually, if we define
$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \Delta_t^{(i)}
         \end{bmatrix} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \bf{x}
         \end{bmatrix}) = v - \lVert \bf{x} \rVert_2^2 $$

The RTC is equivalent to condition $$ F(S(t)) \leq \Theta $$

## 1️⃣ Naive FDA

In the naive approach, we eliminate the update vector from the local state (i.e., reduce its dimension to 0). Define the local state as

$$ S_i(t) = \lVert \Delta_t^{(i)} \rVert_2^2 \in \mathbb{R}$$ 

and the identity function

$$ F(v) = v $$

It is trivial that $ F(S(t)) \leq \Theta $ implies the RTC.

### Client Steps

The number of steps depends on the dataset, i.e., `.take(num)` call on `tf.data.Dataset` creation!

In [19]:
@tf.function
def steps_naive(last_sync_cnn, client_cnn, client_dataset):
    """Train one client for one FDA step and return its naive local state.

    Args:
        last_sync_cnn: Model holding the weights synchronized at the start of the round.
        client_cnn: The client's model; trained in place.
        client_dataset: Prepared dataset; one full pass is one FDA step (the
            number of local steps is set by its `.take()`).

    Returns:
        Scalar tensor ||Delta_i||^2, where Delta_i = w_i - w_sync.
    """
    client_cnn.train(client_dataset)

    drift = client_cnn.trainable_vars_as_vector() - last_sync_cnn.trainable_vars_as_vector()
    return tf.reduce_sum(tf.square(drift))

### Training Loop

In [20]:
def F_naive(S):
    """Naive FDA monitoring function: the identity, F(v) = v.

    With S = mean_i ||Delta_i||^2, this value upper-bounds the model variance.
    """
    estimate = S
    return estimate

In [21]:
@tf.function
def run_federated_simulation_naive(server_cnn, client_cnns, federated_dataset,
                                   num_epochs, theta, epoch_fda_steps):
    """Run Naive FDA training until `num_epochs` epochs of local steps are done.

    Inner loop (one round): while the naive estimate F_naive(S) stays <= `theta`,
    every client performs one FDA step via `steps_naive` and reports ||Delta_i||^2;
    S is the mean over clients. The round ends when the estimate exceeds `theta`
    or when the last epoch finishes. The server then averages the client models,
    records time-series metrics, and pushes the average back to every client.

    Args:
        server_cnn: Model holding the last synchronized (averaged) weights.
        client_cnns: List of client models.
        federated_dataset: Per-client datasets aligned with `client_cnns`; one
            pass over each is one FDA step.
        num_epochs: Number of epochs to run.
        theta: Variance threshold Theta of the RTC.
        epoch_fda_steps: Number of FDA steps that make up one epoch.

    Returns:
        `(total_rounds, total_fda_steps, time_series)` where `time_series` stacks
        one float32 row per round: [total FDA steps so far, estimated variance
        F_naive(S), actual variance of the clients around the new average].
    """
    
    # Python side effect: only runs while tracing, so it signals a retrace.
    print("retracing naive")
    
    total_rounds = 0
    total_fda_steps = 0
    
    # FDA steps taken in the current *epoch* (reset at each epoch boundary, not
    # at the end of a round), used to count epochs.
    round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
    epoch_count = tf.constant(0, shape=(), dtype=tf.int32)
    
    # Global state: mean over clients of ||Delta_i||^2.
    S = tf.constant(0., shape=(), dtype=tf.float32)
    
    """------------------------------time-series metrics-------------------------"""
    estimated_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    actual_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    time_series_data = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)  # for time series data
    """------------------------------time-series metrics-------------------------"""
    
    while epoch_count < num_epochs:
        
        # One round: keep training while the RTC is guaranteed (S reset to 0 at
        # each round start, so at least one FDA step runs).
        while F_naive(S) <= theta:
            S_i_clients = []

            # client steps (number depends on `federated_dataset`, i.e., `.take(num)`)
            for client_cnn, client_dataset in zip(client_cnns, federated_dataset):
                Delta_i_euc_norm_squared = steps_naive(server_cnn, client_cnn, client_dataset)
                S_i_clients.append(Delta_i_euc_norm_squared)
                
            S = tf.reduce_mean(S_i_clients)
            
            round_fda_steps += 1
            total_fda_steps += 1
            
            if round_fda_steps == epoch_fda_steps:
                epoch_count += 1
                round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
                
                # All epochs done: end the round early; it is still synchronized below.
                if epoch_count == num_epochs:
                    break
                    
        
        # server average
        server_cnn.set_trainable_variables(average_client_weights(client_cnns))
        
        """------------------------------time-series metrics-------------------------"""
        estimated_var = F_naive(S)
        # Measured before the clients are overwritten, i.e., the true variance of
        # the client models around their new average.
        actual_var = variance(client_cnns, server_cnn)
        
        time_series_data = time_series_data.write(
            total_rounds, 
            (tf.cast(total_fda_steps, dtype=tf.float32), estimated_var, actual_var)
        )
        """------------------------------time-series metrics-------------------------"""
        
        # reset variance approx
        S = tf.constant(0., shape=(), dtype=tf.float32)

        # synchronize clients
        for client_cnn in client_cnns:
            client_cnn.set_trainable_variables(server_cnn.trainable_variables)
            
        total_rounds += 1
    
    return total_rounds, total_fda_steps, time_series_data.stack()

## 2️⃣ Linear FDA

In the linear case, we reduce the update vector to a scalar, $ \xi \Delta_t^{(i)} \in \mathbb{R}$, where $ \xi $ is any unit vector.

Define the local state to be 

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           \xi \Delta_t^{(i)}
         \end{bmatrix} \in \mathbb{R}^2 $$

Also, define 

$$ F(v, x) = v - x^2 $$

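Since $ \xi $ is a unit vector, the Cauchy–Schwarz inequality gives $ (\xi \overline{\Delta_t})^2 \leq \lVert \overline{\Delta_t} \rVert_2^2 $, so $ F(S(t)) $ is an upper bound of the variance, and the condition

$$ F(S(t)) \leq \Theta $$

implies the RTC.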

A random choice of $ \xi $ is likely to perform poorly (terminating rounds prematurely), as it will likely be close to orthogonal to $ \overline{\Delta_t} $. A good choice would be a vector $ \xi $ correlated with $ \overline{\Delta_t} $. A heuristic choice is to take $ \overline{\Delta_{t_0}} $ (after scaling it to norm 1), i.e., the update vector right before the current round started. All nodes can compute this without communication, as $ \overline{w_{t_0}} - \overline{w_{t_{-1}}} $, the difference of the last two models pushed by the Server. Hence,

$$ \xi = \frac{\overline{w_{t_0}} - \overline{w_{t_{-1}}}}{\lVert \overline{w_{t_0}} - \overline{w_{t_{-1}}} \rVert_2} $$

In [22]:
@tf.function
def ksi_unit_fn(w_t0, w_tminus1, seed=(1234, 0)):
    """Heuristic unit direction xi used by Linear FDA.

    xi = (w_t0 - w_tminus1) / ||w_t0 - w_tminus1||, the direction of the last
    global update. When the two models are identical (no previous update yet,
    e.g., in the first round), a random unit vector is used instead.

    The fallback is drawn with a *stateless* RNG keyed by `seed`, so every call
    with the same inputs returns the same xi. This matters because the linear
    estimator averages xi . Delta_i over clients and needs all clients to share
    one unit xi. Only then is (xi . mean Delta)^2 <= ||mean Delta||^2
    (Cauchy-Schwarz), which makes F(S) an upper bound of the variance. A stateful
    `tf.random.normal` draws a different xi per client and per FDA step, which
    breaks that guarantee.

    Args:
        w_t0: Flattened model synchronized at the start of the current round.
        w_tminus1: Flattened model synchronized at the start of the previous round.
        seed: Shape-[2] integer seed for the stateless fallback draw.

    Returns:
        A unit-norm tensor with the same shape as `w_t0`.
    """
    if tf.reduce_all(tf.equal(w_t0, w_tminus1)):
        # Deterministic random direction, identical for all clients and steps.
        ksi = tf.random.stateless_normal(shape=tf.shape(w_t0), seed=seed, dtype=w_t0.dtype)
    else:
        ksi = w_t0 - w_tminus1

    # Normalize and return
    return tf.divide(ksi, tf.norm(ksi))

### Client Steps

The number of steps depends on the dataset, i.e., `.take(num)` call on `tf.data.Dataset` creation!

In [23]:
@tf.function
def steps_linear(cnn_tminus, cnn_t0, client_cnn, client_dataset):
    """Train one client for one FDA step and return its Linear FDA local state.

    Args:
        cnn_tminus: Model synchronized at the start of the previous round (w_{t_-1}).
        cnn_t0: Model synchronized at the start of the current round (w_{t_0}).
        client_cnn: The client's model; trained in place.
        client_dataset: Prepared dataset; one full pass is one FDA step (the
            number of local steps is set by its `.take()`).

    Returns:
        `(||Delta_i||^2, xi . Delta_i)` as scalar tensors, where
        Delta_i = w_i - w_{t_0} and xi comes from `ksi_unit_fn`.
    """
    client_cnn.train(client_dataset)

    # The round-start model is needed for both Delta_i and xi: flatten it once.
    start_vector = cnn_t0.trainable_vars_as_vector()
    local_update = client_cnn.trainable_vars_as_vector() - start_vector

    update_norm_sq = tf.reduce_sum(tf.square(local_update))

    direction = ksi_unit_fn(start_vector, cnn_tminus.trainable_vars_as_vector())
    # Dot product xi . Delta_i
    projection = tf.reduce_sum(tf.multiply(direction, local_update))

    return update_norm_sq, projection

### Training Loop

In [24]:
def F_linear(S_1, S_2):
    """Linear FDA monitoring function F(v, x) = v - x^2.

    With S_1 = mean_i ||Delta_i||^2 and S_2 = xi . mean Delta for a unit xi,
    the result upper-bounds the model variance.
    """
    squared_projection = S_2 ** 2
    return S_1 - squared_projection

In [25]:
@tf.function
def run_federated_simulation_linear(previous_server_cnn, server_cnn, client_cnns, federated_dataset,
                                   num_epochs, theta, epoch_fda_steps):
    """Run Linear FDA training until `num_epochs` epochs of local steps are done.

    Inner loop (one round): while F_linear(S_1, S_2) stays <= `theta`, every
    client performs one FDA step via `steps_linear`. Each client reports
    ||Delta_i||^2 and xi . Delta_i; S_1 and S_2 are their means over clients.
    The round ends when the estimate exceeds `theta` or when the last epoch
    finishes. The previous synchronized model is then saved, the server averages
    the client models, time-series metrics are recorded, and the average is
    pushed back to every client.

    Args:
        previous_server_cnn: Holds the model synchronized at the start of the
            previous round (w_{t_-1}); overwritten at the end of every round.
        server_cnn: Model holding the last synchronized (averaged) weights (w_{t_0}).
        client_cnns: List of client models.
        federated_dataset: Per-client datasets aligned with `client_cnns`; one
            pass over each is one FDA step.
        num_epochs: Number of epochs to run.
        theta: Variance threshold Theta of the RTC.
        epoch_fda_steps: Number of FDA steps that make up one epoch.

    Returns:
        `(total_rounds, total_fda_steps, time_series)` where `time_series` stacks
        one float32 row per round: [total FDA steps so far, estimated variance
        F_linear(S_1, S_2), actual variance of the clients around the new average].
    """
    
    # Python side effect: only runs while tracing, so it signals a retrace.
    print("retracing linear")
    
    total_rounds = 0
    total_fda_steps = 0
    
    # FDA steps taken in the current *epoch* (reset at each epoch boundary, not
    # at the end of a round), used to count epochs.
    round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
    epoch_count = tf.constant(0, shape=(), dtype=tf.int32)
    
    # Global state: S_1 = mean ||Delta_i||^2, S_2 = mean (xi . Delta_i).
    S_1 = tf.constant(0., shape=(), dtype=tf.float32)
    S_2 = tf.constant(0., shape=(), dtype=tf.float32)
    
    """------------------------------time-series metrics-------------------------"""
    estimated_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    actual_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    time_series_data = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)  # for time series data
    """------------------------------time-series metrics-------------------------"""
    
    while epoch_count < num_epochs:
        
        # One round: keep training while the RTC is guaranteed (state reset to 0
        # at each round start, so at least one FDA step runs).
        while F_linear(S_1, S_2) <= theta:
            euc_norm_squared_clients = []
            ksi_delta_clients = []

            # client steps (number depends on `federated_dataset`, i.e., `.take(num)`)
            for client_cnn, client_dataset in zip(client_cnns, federated_dataset):
                Delta_i_euc_norm_squared, ksi_Delta_i = steps_linear(
                    previous_server_cnn, server_cnn, client_cnn, client_dataset
                )
                
                euc_norm_squared_clients.append(Delta_i_euc_norm_squared)
                ksi_delta_clients.append(ksi_Delta_i)
            
            S_1 = tf.reduce_mean(euc_norm_squared_clients)
            S_2 = tf.reduce_mean(ksi_delta_clients)
            
            round_fda_steps += 1
            total_fda_steps += 1
            
            if round_fda_steps == epoch_fda_steps:
                epoch_count += 1
                round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
                
                # All epochs done: end the round early; it is still synchronized below.
                if epoch_count == num_epochs:
                    break
                    
        
        # last server model (previous sync): becomes w_{t_-1} for the next round's xi
        previous_server_cnn.set_trainable_variables(server_cnn.trainable_variables)
        
        # server average
        server_cnn.set_trainable_variables(average_client_weights(client_cnns))
        
        
        """------------------------------time-series metrics-------------------------"""
        estimated_var = F_linear(S_1, S_2)
        # Measured before the clients are overwritten, i.e., the true variance of
        # the client models around their new average.
        actual_var = variance(client_cnns, server_cnn)
        
        time_series_data = time_series_data.write(
            total_rounds, 
            (tf.cast(total_fda_steps, dtype=tf.float32), estimated_var, actual_var)
        )
        """------------------------------time-series metrics-------------------------"""
        
        # reset variance approx
        S_1 = tf.constant(0., shape=(), dtype=tf.float32)
        S_2 = tf.constant(0., shape=(), dtype=tf.float32)

        # synchronize clients
        for client_cnn in client_cnns:
            client_cnn.set_trainable_variables(server_cnn.trainable_variables)
            
        total_rounds += 1
    
    return total_rounds, total_fda_steps, time_series_data.stack()

## 3️⃣ Sketch FDA

An optimal estimator for $ \lVert \overline{\Delta_t} \rVert_2^2  $ can be obtained by employing AMS sketches. An AMS sketch of a vector $ v \in \mathbb{R}^M $ is a $ d \times m $ real matrix

$$ \Xi = \text{sk}(v) = \begin{bmatrix}
           \Xi_1 \\
           \Xi_2 \\
           \vdots \\
           \Xi_d 
         \end{bmatrix} $$
         
where $ d \cdot m \ll M$. Operator sk($ \cdot $) is linear, i.e., for $a, b \in \mathbb{R}$ and $v_1, v_2 \in \mathbb{R}^M$,

$$ \text{sk}(a v_1 + b v_2) = a \; \text{sk}(v_1) + b \; \text{sk}(v_2)  $$

Also, sk($ v $) can be computed in $ \mathcal{O}(dM) $ steps.

The interesting property of AMS sketches is that the function 

$$ M(sk(\textbf{v})) = \underset{i=1,...,d}{\text{median}} \; \lVert \boldsymbol{\Xi}_i \rVert_2^2  $$ 

is an excellent estimator of the squared Euclidean norm of **v** (within relative $\epsilon$-error):

$$ M(sk(\textbf{v})) \; \in (1 \pm \epsilon) \lVert \textbf{v} \rVert_2^2 \; \; \text{with probability at least} \; (1-\delta) $$

where $m = \mathcal{O}(\frac{1}{\epsilon^2})$ and $d = \mathcal{O}(\log \frac{1}{\delta})$
            
Moreover, let $\boldsymbol{\Xi} \in \mathbb{R}^{d \times m}$ and $ k \in \mathbb{R}$. It can be proven that

$$ M( \frac{1}{k} \boldsymbol{\Xi}) = \frac{1}{k^2} M(\boldsymbol{\Xi}) $$

Let's investigate a little further on how this helps us. The $i$-th client computes $ sk(\Delta_t^{(i)}) $ and sends it to the server. Notice

$$ M\big(sk(\Delta_t^{(1)}) + sk(\Delta_t^{(2)}) + ... + sk(\Delta_t^{(k)}) \big) = M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big)$$

Remember that

$$ \overline{\boldsymbol{\Delta}}_t = \frac{1}{k} \sum_{i=1}^{k} \boldsymbol{\Delta}_t^{(i)} $$

Then
            
$$ M\Big( \text{sk}\big( \overline{\boldsymbol{\Delta}}_t \big) \Big) = M\Big( \text{sk}\big( \frac{1}{k} \sum_{i=1}^{k} \boldsymbol{\Delta}_t^{(i)} \big) \Big) = \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \boldsymbol{\Delta}_t^{(i)} \big) \Big) $$


Which means that 

$$ \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \boldsymbol{\Delta}_t^{(i)} \big) \Big) \in (1 \pm \epsilon) \lVert \overline{\boldsymbol{\Delta}}_t \rVert_2^2 \; \; \text{w.p. at least} \; (1-\delta) $$

In the monitoring process it is essential that we do not overestimate $ \lVert \overline{\Delta_t} \rVert_2^2 $, because we would then underestimate the variance, which could result in the actual variance exceeding $ \Theta$ without us noticing it. With this in mind,

$$ \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big) \leq (1+\epsilon) \lVert \overline{\Delta_t} \rVert_2^2 \quad \text{with probability at least} \; (1-\delta)$$

Which means

$$ \frac{1}{(1+\epsilon)} \frac{1}{k^2} M\Big( \text{sk}\big( \sum_{i=1}^{k} \Delta_t^{(i)} \big) \Big) \leq \lVert \overline{\Delta_t} \rVert_2^2 \quad \text{with probability at least} \; (1-\delta)$$

Hence, the Server's estimation of $ \lVert \overline{\Delta_t} \rVert_2^2 $ is

$$ \frac{1}{(1+\epsilon)} \frac{1}{k^2} M\Big( sk(\Delta_t^{(1)}) + sk(\Delta_t^{(2)}) + \dots + sk(\Delta_t^{(k)}) \Big) $$

Define the local state to be 

$$ S_i(t) = \begin{bmatrix}
           \lVert \Delta_t^{(i)} \rVert_2^2 \\
           sk(\Delta_t^{(i)})
         \end{bmatrix} \in \mathbb{R} \times \mathbb{R}^{d \times m} \quad \text{and} \quad
         F(\begin{bmatrix}
           v \\
           \Xi
         \end{bmatrix}) = v - \frac{1}{(1+\epsilon)}  M(\Xi) \quad \text{where} \quad \Xi = \frac{1}{k} \sum_{i=1}^{k} sk(\Delta_t^{(i)}) $$

It follows that $ F(S(t)) \leq \Theta $ implies that the variance is less than or equal to $ \Theta $ with probability at least $ 1-\delta $.


## AMS sketch

We use `ExtensionType`, which lets us avoid unnecessary graph retracing when passing `AmsSketch` 'objects' to `tf.function`s.

In [26]:
from tensorflow.experimental import ExtensionType

class AmsSketch(ExtensionType):
    """ AMS sketch packaged as an `ExtensionType` so it can be passed to `tf.function`s
    without triggering retracing.

    The sketch of a vector `v` is a float32 tensor of shape `(depth, width)`. For every row `r`,
    coordinate `j` of `v`, multiplied by a +1/-1 sign `s_r(j)`, is added to column `h_r(j)`.
    The squared Euclidean norm of `v` is estimated as the median over rows of the row-wise
    sums of squares (see `estimate_euc_norm_squared`). The sketch is linear in `v`, so the
    sketch of a sum (or mean) of vectors is the sum (or mean) of their sketches.

    Fields:
        depth: number of rows of the sketch.
        width: number of columns (buckets) per row.
        F: int32 tensor of shape `(6, depth)` holding the random hash coefficients. Rows 0-1
           drive the bucket hash `h`, rows 2-5 drive the sign hash `s`.
    """
    depth: int
    width: int
    F: tf.Tensor

    def __init__(self, depth=7, width=1500):
        self.depth = depth
        self.width = width
        # One column of coefficients per sketch row.
        # NOTE(review): unseeded, so the hash functions (and the sketches) differ between runs.
        self.F = tf.random.uniform(shape=(6, depth), minval=0, maxval=(1 << 31) - 1, dtype=tf.int32)

    @tf.function
    def hash31(self, x, a, b):
        """ Hash `a * x + b` to a non-negative 31-bit integer, returned as int32.

        The arithmetic is done in int64. `a` and `x` are both below 2^31, so in int32 their
        product overflows. In that case `r >> 31` would only ever be the sign bit instead of
        the high bits that the fold is meant to mix back into the low 31 bits.
        """
        r = tf.cast(a, tf.int64) * tf.cast(x, tf.int64) + tf.cast(b, tf.int64)
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        return tf.cast(tf.bitwise.bitwise_and(fold, 2147483647), tf.int32)

    @tf.function
    def tensor_hash31(self, x, a, b):
        """ Vectorized `hash31` for a 1-D tensor `x` (e.g. the indices `tf.range(d)`).

        `a` and `b` broadcast against `x[:, None]`. With `a` and `b` of shape `(depth,)`, or
        `a` of shape `(len(x), depth)`, the result is int32 with shape `(len(x), depth)`.
        """
        # (len_x,) -> (len_x, 1) so that it broadcasts against the per-row coefficients.
        x_reshaped = tf.expand_dims(x, axis=-1)

        # int64 for the same overflow reason as in `hash31`.
        r = tf.cast(a, tf.int64) * tf.cast(x_reshaped, tf.int64) + tf.cast(b, tf.int64)
        fold = tf.bitwise.bitwise_xor(tf.bitwise.right_shift(r, 31), r)
        return tf.cast(tf.bitwise.bitwise_and(fold, 2147483647), tf.int32)

    @tf.function
    def tensor_fourwise(self, x):
        """ Vectorized +1/-1 sign hash for a 1-D tensor `x`.

        Evaluates the nested hash ((F[2] x + F[3]) x + F[4]) x + F[5], folding after every
        step, and maps bit 15 of the result to {-1, +1}. Returns int32 of shape `(len(x), depth)`.
        """
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # (len_x, depth)

        # `in1` is passed as the multiplier `a`. Multiplication commutes, so this equals
        # `hash31(in1, x, F[4])` as computed by the scalar `fourwise`.
        in2 = self.tensor_hash31(x, in1, self.F[4])  # (len_x, depth)
        in3 = self.tensor_hash31(x, in2, self.F[5])  # (len_x, depth)

        in4 = tf.bitwise.bitwise_and(in3, 32768)  # keep only bit 15

        return 2 * (tf.bitwise.right_shift(in4, 15)) - 1  # {0, 1} -> {-1, +1}

    @tf.function
    def fourwise(self, x):
        """ Scalar counterpart of `tensor_fourwise`: +1/-1 sign of index `x` for every row. """
        in1 = self.hash31(x, self.F[2], self.F[3])
        in2 = self.hash31(in1, x, self.F[4])
        in3 = self.hash31(in2, x, self.F[5])
        return 2 * (tf.bitwise.right_shift(tf.bitwise.bitwise_and(in3, 32768), 15)) - 1

    @tf.function
    def sketch_for_vector(self, v):
        """ Vectorized AMS sketch of a 1-D float32 tensor `v`.

        Every coordinate `j` contributes `sign[j, r] * v[j]` to cell `(r, pos[j, r])` of every
        row `r`. All contributions are added with a single scatter.

        Returns:
            float32 tensor of shape `(depth, width)`.
        """
        # Dynamic length, so the function also works when `v` has an unknown static shape.
        len_v = tf.shape(v)[0]
        coordinate_ids = tf.range(len_v)

        # Bucket of every coordinate in every row, shape (len_v, depth).
        pos_tensor = self.tensor_hash31(coordinate_ids, self.F[0], self.F[1]) % self.width

        # Signed contributions, shape (len_v, depth).
        signs = tf.cast(self.tensor_fourwise(coordinate_ids), dtype=tf.float32)
        deltas_tensor = tf.multiply(signs, tf.expand_dims(v, axis=-1))

        # Row index of every (coordinate, row) pair, shape (len_v, depth).
        rows = tf.broadcast_to(tf.expand_dims(tf.range(self.depth), 0), tf.shape(pos_tensor))

        # shape=(len_v, depth, 2): (row, column) target of every contribution.
        indices = tf.stack([rows, pos_tensor], axis=-1)

        sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)
        return tf.tensor_scatter_nd_add(sketch, indices, deltas_tensor)

    @tf.function
    def sketch_for_vector2(self, v):
        """ Slow reference implementation of `sketch_for_vector`: one scatter per coordinate. """
        sketch = tf.zeros(shape=(self.depth, self.width), dtype=tf.float32)

        for i in tf.range(tf.shape(v)[0], dtype=tf.int32):
            pos = self.hash31(i, self.F[0], self.F[1]) % self.width  # (depth,)
            delta = tf.cast(self.fourwise(i), dtype=tf.float32) * v[i]  # (depth,)
            indices_to_update = tf.stack([tf.range(self.depth, dtype=tf.int32), pos], axis=1)
            sketch = tf.tensor_scatter_nd_add(sketch, indices_to_update, delta)

        return sketch

    @staticmethod
    @tf.function
    def estimate_euc_norm_squared(sketch):
        """ Estimate ||v||_2^2 from `sketch` (shape `(depth, width)`).

        The estimate is the median over rows of the row-wise sums of squares. For an even
        `depth`, the two middle values are averaged.
        """

        def _median(v):
            """ Median of a 1-D tensor via sorting. O(n log n) is fine because n = `depth`. """
            length = tf.shape(v)[0]
            sorted_v = tf.sort(v)
            middle = length // 2

            return tf.cond(
                tf.equal(length % 2, 0),
                lambda: (sorted_v[middle - 1] + sorted_v[middle]) / 2.0,
                lambda: sorted_v[middle]
            )

        return _median(tf.reduce_sum(tf.square(sketch), axis=1))

### Client Steps

The number of steps depends on the dataset, i.e., `.take(num)` call on `tf.data.Dataset` creation!

In [27]:
@tf.function
def steps_sketch(last_sync_cnn, client_cnn, client_dataset, ams_sketch):
    """ Train one client locally, then summarize how far it drifted since the last sync.

    The number of local steps is whatever `client_dataset` yields (fixed by its `.take()`).

    Returns:
        A pair `(||Delta_i||_2^2, sketch)`:
        - `Delta_i` is `client_cnn.trainable_vars_as_vector()` minus
          `last_sync_cnn.trainable_vars_as_vector()`.
        - The first element is a scalar (not shape (1,)).
        - `sketch` is `ams_sketch.sketch_for_vector(Delta_i)`.
    """
    client_cnn.train(client_dataset)

    drift = client_cnn.trainable_vars_as_vector() - last_sync_cnn.trainable_vars_as_vector()

    # Exact squared norm first, then its AMS sketch (same order as before).
    return tf.reduce_sum(tf.square(drift)), ams_sketch.sketch_for_vector(drift)

### Training Loop

In [28]:
def F_sketch(S_1, S_2, epsilon):
    """ Sketch-based monitoring function: F(S) = S_1 - M(S_2) / (1 + epsilon).

    Args:
        S_1: mean over clients of ||Delta_i||_2^2.
        S_2: `Ξ`, the mean of the clients' AMS sketches, as defined in the analysis above.
        epsilon: sketch accuracy parameter. Dividing by (1 + epsilon) keeps the norm estimate
            from being an overestimate with high probability.
    """
    shrink = 1. / (1. + epsilon)
    return S_1 - shrink * AmsSketch.estimate_euc_norm_squared(S_2)

In [29]:
@tf.function
def run_federated_simulation_sketch(server_cnn, client_cnns, federated_dataset, num_epochs, 
                                    theta, epoch_fda_steps, ams_sketch, epsilon):
    """ Sketch-based FDA simulation that runs until `num_epochs` epochs have been processed.

    In every FDA step each client trains on its dataset (`steps_sketch`). The server then forms
    S_1 = mean of the clients' ||Delta_i||_2^2 and S_2 = mean of their AMS sketches (`Ξ`).
    Clients keep training while `F_sketch(S_1, S_2, epsilon) <= theta`. The round ends when
    that condition is violated or the last epoch completes. At the end of a round:
    - the server averages the client weights;
    - one time-series row is written;
    - the estimate is reset;
    - the clients are synchronized with the server.

    Args:
        epoch_fda_steps: number of FDA steps that make up one epoch.

    Returns:
        (total_rounds, total_fda_steps, time_series), where `time_series` stacks one row
        (total_fda_steps, estimated_var, actual_var) per round.
    """

    # Python side effect: runs only while the function is traced, so it reveals retracing.
    print("retracing sketch")
    
    # Returned to the caller, which calls `.numpy()` on them.
    total_rounds = 0
    total_fda_steps = 0
    
    # FDA steps taken in the current *epoch*: reset at every epoch boundary, not every round.
    round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
    epoch_count = tf.constant(0, shape=(), dtype=tf.int32)
    
    # Variance-estimate state. With all zeros, F_sketch == 0, so a round starts whenever theta >= 0.
    S_1 = tf.constant(0., shape=(), dtype=tf.float32)
    S_2 = tf.zeros(shape=(ams_sketch.depth, ams_sketch.width), dtype=tf.float32)
    
    """------------------------------time-series metrics-------------------------"""
    estimated_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    actual_var = tf.constant(0., shape=(), dtype=tf.float32)  # for time series data
    time_series_data = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)  # for time series data
    """------------------------------time-series metrics-------------------------"""
    
    while epoch_count < num_epochs:
        
        # One round: keep taking FDA steps while the monitored function stays within `theta`.
        while F_sketch(S_1, S_2, epsilon) <= theta:
            euc_norm_squared_clients = []
            sketch_clients = []

            # client steps (number depends on `federated_dataset`, i.e., `.take(num)`)
            for client_cnn, client_dataset in zip(client_cnns, federated_dataset):
                Delta_i_euc_norm_squared, sketch = steps_sketch(
                    server_cnn, client_cnn, client_dataset, ams_sketch
                )
                
                euc_norm_squared_clients.append(Delta_i_euc_norm_squared)
                sketch_clients.append(sketch)
            
            # Server side: average the local states (sketches are linear, so the mean sketch is Ξ).
            S_1 = tf.reduce_mean(euc_norm_squared_clients)
            S_2 = tf.reduce_mean(sketch_clients, axis=0)  # shape=(`depth`, `width`). See `Ξ` in theoretical analysis
            
            #del euc_norm_squared_clients, sketch_clients
            
            round_fda_steps += 1
            total_fda_steps += 1
            
            # Epoch boundary: count it, and stop the simulation after the last epoch.
            # The partial round is still averaged and recorded below.
            if round_fda_steps == epoch_fda_steps:
                epoch_count += 1
                round_fda_steps = tf.constant(0, shape=(), dtype=tf.int32)
                
                if epoch_count == num_epochs:
                    break
        
        # server average
        server_cnn.set_trainable_variables(average_client_weights(client_cnns))
        
        """------------------------------time-series metrics-------------------------"""
        # `estimated_var` is F at the end of the round: the value that exceeded `theta`, unless
        # the round was cut short by the last epoch. `actual_var` is computed from the
        # not-yet-synchronized clients and the freshly averaged server model.
        estimated_var = F_sketch(S_1, S_2, epsilon)
        actual_var = variance(client_cnns, server_cnn)
        
        # Row `total_rounds`: (cumulative FDA steps, estimated variance, actual variance).
        time_series_data = time_series_data.write(
            total_rounds, 
            (tf.cast(total_fda_steps, dtype=tf.float32), estimated_var, actual_var)
        )
        """------------------------------time-series metrics-------------------------"""
        
        # reset variance approx
        S_1 = tf.constant(0., shape=(), dtype=tf.float32)
        S_2 = tf.zeros(shape=(ams_sketch.depth, ams_sketch.width), dtype=tf.float32)

        # synchronize clients
        for client_cnn in client_cnns:
            client_cnn.set_trainable_variables(server_cnn.trainable_variables)
            
        total_rounds += 1
    
    return total_rounds, total_fda_steps, time_series_data.stack()

# Simulation tests

###  Basic test

In [30]:
def basic_test(server_cnn, client_cnns, previous_server_cnn, starting_trainable_variables, 
               NUM_EPOCHS, NUM_STEPS_UNTIL_RTC_CHECK, NUM_CLIENTS, BATCH_SIZE, 
               THETA, EPSILON, ams_sketch, clients_federated_data, seed):
    """ One test for Naive, Linear and Sketch FDA with the same configuration. Returns metrics.

    Each variant starts from `starting_trainable_variables` and from a freshly built,
    identically seeded federated dataset, so the three runs are comparable. All models are
    reset to `starting_trainable_variables` after each variant.

    Returns:
        (basic_test_metrics, basic_test_time_series_metrics): lists with one entry per
        variant, in the order naive, linear, sketch.
    """
    num_epochs = tf.constant(NUM_EPOCHS, shape=(), dtype=tf.int32)
    theta = tf.constant(THETA, shape=(), dtype=tf.float32)
    epsilon = tf.constant(EPSILON, shape=(), dtype=tf.float32)  # only used by sketch

    # Approximate number of FDA steps that make up one epoch of a single client's shard.
    epoch_client_batches = (n_train / BATCH_SIZE) / NUM_CLIENTS
    epoch_max_fda_steps = epoch_client_batches / NUM_STEPS_UNTIL_RTC_CHECK
    # Clamp to >= 1. The simulations (see `run_federated_simulation_sketch`) only count an
    # epoch when their per-epoch step counter, which is >= 1 after the first step, equals
    # this value. With 0 no epoch would ever end and the simulation would never terminate.
    epoch_max_fda_steps = tf.constant(max(1, round(epoch_max_fda_steps)), shape=(), dtype=tf.int32)

    basic_test_metrics = []
    basic_test_time_series_metrics = []

    def fresh_federated_dataset():
        # Rebuilt for every variant (instead of reused) so each one sees the same seeded batches.
        return prepare_federated_data_for_test(
            federated_data=clients_federated_data, 
            batch_size=BATCH_SIZE,
            num_steps_until_rtc_check=NUM_STEPS_UNTIL_RTC_CHECK,
            seed=seed
        )

    def record_metrics(fda_name, total_rounds, total_fda_steps, time_series_data, **sketch_info):
        # Evaluate the final server model and append the summary and time-series metrics.
        # `sketch_info` carries `sketch_width` / `sketch_depth` for the sketch variant only.
        _, acc = server_cnn.evaluate(test_dataset, verbose=0)

        basic_test_metrics.append(create_metrics_dict(
            fda_name=fda_name, 
            n_train=n_train, 
            dataset_name="EMNIST", 
            input_pixels=784, 
            seed=seed, 
            epochs=NUM_EPOCHS, 
            num_clients=NUM_CLIENTS, 
            batch_size=BATCH_SIZE, 
            steps_in_one_fda_step=NUM_STEPS_UNTIL_RTC_CHECK, 
            theta=THETA, 
            total_fda_steps=total_fda_steps.numpy(),
            num_weights=count_weights(server_cnn),
            total_rounds=total_rounds.numpy(), 
            final_accuracy=acc,
            **sketch_info
        ))

        basic_test_time_series_metrics.append(create_time_series_metrics(
            time_series_data=time_series_data.numpy(),
            dataset_name="EMNIST",
            fda_name=fda_name, 
            epochs=NUM_EPOCHS,
            num_clients=NUM_CLIENTS,
            batch_size=BATCH_SIZE, 
            steps_in_one_fda_step=NUM_STEPS_UNTIL_RTC_CHECK,
            theta=THETA, 
            num_weights=count_weights(server_cnn), 
            seed=seed,
            **sketch_info
        ))

    # --------------- Naive ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_naive(
        server_cnn, 
        client_cnns, 
        fresh_federated_dataset(), 
        num_epochs, 
        theta,
        epoch_max_fda_steps
    )
    record_metrics("naive", total_rounds, total_fda_steps, time_series_data)

    # IMPORTANT: reset all models to the starting state before the next variant.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)

    # --------------- Linear ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_linear(
        previous_server_cnn,
        server_cnn, 
        client_cnns, 
        fresh_federated_dataset(), 
        num_epochs, 
        theta,
        epoch_max_fda_steps
    )
    record_metrics("linear", total_rounds, total_fda_steps, time_series_data)

    # IMPORTANT: reset all models; `previous_server_cnn` is only passed to the linear variant,
    # so it is reset here as well.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)
    previous_server_cnn.set_trainable_variables(starting_trainable_variables)

    # --------------- Sketch ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_sketch(
        server_cnn=server_cnn, 
        client_cnns=client_cnns, 
        federated_dataset=fresh_federated_dataset(),
        num_epochs=num_epochs, 
        theta=theta, 
        epoch_fda_steps=epoch_max_fda_steps, 
        ams_sketch=ams_sketch, 
        epsilon=epsilon
    )
    record_metrics(
        "sketch", total_rounds, total_fda_steps, time_series_data,
        sketch_width=ams_sketch.width, 
        sketch_depth=ams_sketch.depth
    )

    # IMPORTANT: leave all models in the starting state for the next test.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)

    return basic_test_metrics, basic_test_time_series_metrics

In [31]:
def print_info_current_test(NUM_EPOCHS, NUM_STEPS_UNTIL_RTC_CHECK, NUM_CLIENTS, THETA, BATCH_SIZE):
    """ Print a banner describing the configuration of the test that is about to run. """
    banner_lines = [
        "",
        "----------- Current Test --------------",
        f"Num Clients : {NUM_CLIENTS}",
        f"Num Epochs : {NUM_EPOCHS}",
        f"Number of steps until we check RTC : {NUM_STEPS_UNTIL_RTC_CHECK}",
        f"Batch size : {BATCH_SIZE}",
        f"Theta : {THETA}",
        "----------------------------------------",
        "",
    ]
    # One line per entry; the empty entries produce the blank lines around the banner.
    print(*banner_lines, sep="\n")

In [32]:
from math import sqrt # new
from copy import deepcopy

def run_tests(NUM_CLIENTS_LIST, NUM_EPOCHS_LIST, NUM_STEPS_UNTIL_RTC_CHECK_LIST,
              BATCH_SIZE_LIST, THETA_LIST, SKETCH_DEPTH, SKETCH_WIDTH, SEED=None):
    """ Grid driver: run `basic_test` (naive, linear, sketch) for every configuration.

    The grid is NUM_CLIENTS x NUM_EPOCHS x NUM_STEPS_UNTIL_RTC_CHECK x BATCH_SIZE x THETA.
    A single AMS sketch, and hence a single epsilon = 1 / sqrt(SKETCH_WIDTH), is shared by
    every test. For each client count the data shards and CNNs are built once, and every test
    starts from the same starting weights.

    Returns:
        (all_metrics, all_time_series_metrics): the per-test results concatenated in order.
    """
    # Fixed for the whole grid.
    ams_sketch = AmsSketch(depth=SKETCH_DEPTH, width=SKETCH_WIDTH)
    EPSILON = 1. / sqrt(SKETCH_WIDTH)

    all_metrics, all_time_series_metrics = [], []

    for NUM_CLIENTS in NUM_CLIENTS_LIST:
        # Shards depend on the client count only.
        clients_federated_data = create_federated_data_for_clients(NUM_CLIENTS)

        # Models are created once per client count (not per test) to avoid graph retracing.
        server_cnn = get_compiled_and_built_advanced_cnn()
        client_cnns = [get_compiled_and_built_advanced_cnn() for _ in range(NUM_CLIENTS)]
        previous_server_cnn = get_compiled_and_built_advanced_cnn()  # used by linear only

        synchronize(server_cnn, client_cnns)

        # Every test for this client count starts from these weights.
        starting_trainable_variables = deepcopy(server_cnn.trainable_variables)
        previous_server_cnn.set_trainable_variables(starting_trainable_variables)

        # Same nesting order as an explicit epochs -> steps -> batch size -> theta loop.
        grid = (
            (epochs, steps, batch_size, theta)
            for epochs in NUM_EPOCHS_LIST
            for steps in NUM_STEPS_UNTIL_RTC_CHECK_LIST
            for batch_size in BATCH_SIZE_LIST
            for theta in THETA_LIST
        )

        for epochs, steps, batch_size, theta in grid:
            print_info_current_test(epochs, steps, NUM_CLIENTS, theta, batch_size)

            test_metrics, test_time_series_metrics = basic_test(
                server_cnn=server_cnn,
                client_cnns=client_cnns,
                previous_server_cnn=previous_server_cnn,
                starting_trainable_variables=starting_trainable_variables,
                NUM_EPOCHS=epochs,
                NUM_STEPS_UNTIL_RTC_CHECK=steps,
                NUM_CLIENTS=NUM_CLIENTS,
                BATCH_SIZE=batch_size,
                THETA=theta,
                EPSILON=EPSILON,
                ams_sketch=ams_sketch,
                clients_federated_data=clients_federated_data,
                seed=SEED
            )

            all_metrics.extend(test_metrics)
            all_time_series_metrics.extend(test_time_series_metrics)

        # Drop references to this client count's data and models before the next one.
        del clients_federated_data, server_cnn, client_cnns, previous_server_cnn, starting_trainable_variables

    return all_metrics, all_time_series_metrics

# Run Simulation Tests

In [33]:
# Full sweep: 1 client count x 10 epoch settings x 1 RTC-check setting x 3 batch sizes
# x 6 thetas = 180 configurations; each one runs the naive, linear and sketch variants.
# NOTE(review): the printed output below shows "Num Clients : 5" / "Theta : 1.0", so it was
# produced by an earlier configuration than the one in this cell. Re-run to refresh it.
all_metrics, all_time_series_metrics = run_tests(
    NUM_CLIENTS_LIST=[15],
    NUM_EPOCHS_LIST=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    NUM_STEPS_UNTIL_RTC_CHECK_LIST=[1],
    BATCH_SIZE_LIST=[32, 64, 128],
    THETA_LIST=[0.05, 0.5, 1., 3., 5., 10.],
    SKETCH_DEPTH=7,
    SKETCH_WIDTH=500,
    SEED=7
)


----------- Current Test --------------
Num Clients : 5
Num Epochs : 1
Number of steps until we check RTC : 1
Batch size : 32
Theta : 1.0
----------------------------------------

retracing naive
retracing linear
retracing sketch
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [34]:
# One row per (configuration, FDA variant) summary dict produced by `basic_test`.
all_metrics_df = pd.DataFrame(all_metrics)

In [35]:
# Create the output directory if needed so this cell also works on a fresh checkout.
os.makedirs('test_results', exist_ok=True)
all_metrics_df.to_csv('test_results/results.csv', index=False)

In [52]:
# Summary table: one row per (configuration, FDA variant).
all_metrics_df

Unnamed: 0,fda_name,theta,dataset_name,input_pixels,n_train,num_weights,seed,epochs,num_clients,batch_size,...,model_bytes,local_state_bytes,final_accuracy,total_fda_steps,total_steps,total_rounds,model_bytes_exchanged,monitoring_bytes_exchanged,total_communication_bytes,trained_in_bytes
0,naive,1.0,EMNIST,784,60000,2592202,7,1,5,32,...,10368808,4,0.9659,375,375,91,9435615280,7500,9435622780,188400000
1,linear,1.0,EMNIST,784,60000,2592202,7,1,5,32,...,10368808,8,0.9493,375,375,86,8917174880,15000,8917189880,188400000
2,sketch,1.0,EMNIST,784,60000,2592202,7,1,5,32,...,10368808,14004,0.9616,375,375,95,9850367600,26257500,9876625100,188400000


In [37]:
# Flatten every {id_tuple: rows} mapping into one list of rows plus a parallel list of
# index tuples (each id tuple repeated once per row it labels), then build a MultiIndex frame.
data, index_tuples = [], []

for time_series_dict in all_time_series_metrics:
    for id_tuple, time_series_metrics in time_series_dict.items():
        data += time_series_metrics
        index_tuples += [id_tuple] * len(time_series_metrics)

index = pd.MultiIndex.from_tuples(
    index_tuples,
    names=[
        'dataset_name', 'fda_name', 'epochs', 'num_clients', 'batch_size',
        'steps_in_one_fda_step', 'theta', 'num_weights', 'seed', 'sketch_width', 'sketch_depth',
    ],
)
all_time_series_metrics_df = pd.DataFrame(data, index=index)

In [38]:
# One row per round, indexed by the full test configuration (naive/linear rows use -1 for the sketch levels).
all_time_series_metrics_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,round,total_fda_steps,total_steps,actual_var,estimated_var
dataset_name,fda_name,epochs,num_clients,batch_size,steps_in_one_fda_step,theta,num_weights,seed,sketch_width,sketch_depth,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
EMNIST,naive,1,5,32,1,1.0,2592202,7,-1,-1,0,6,6,0.884019,1.098551
EMNIST,naive,1,5,32,1,1.0,2592202,7,-1,-1,1,12,12,1.103115,1.369943
EMNIST,naive,1,5,32,1,1.0,2592202,7,-1,-1,2,18,18,0.944327,1.159325
EMNIST,naive,1,5,32,1,1.0,2592202,7,-1,-1,3,24,24,0.945413,1.193771
EMNIST,naive,1,5,32,1,1.0,2592202,7,-1,-1,4,30,30,0.867770,1.120618
EMNIST,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
EMNIST,sketch,1,5,32,1,1.0,2592202,7,500,7,90,353,353,1.071295,1.084746
EMNIST,sketch,1,5,32,1,1.0,2592202,7,500,7,91,360,360,1.200325,1.224319
EMNIST,sketch,1,5,32,1,1.0,2592202,7,500,7,92,366,366,1.314449,1.328762
EMNIST,sketch,1,5,32,1,1.0,2592202,7,500,7,93,372,372,1.069446,1.084648


In [39]:
# Create the output directory if needed so this cell also works on a fresh checkout.
os.makedirs('test_results', exist_ok=True)
all_time_series_metrics_df.to_csv('test_results/time_series_results.csv', index=True)

In [40]:
# To reload the time-series results saved above instead of re-running the simulations:
#all_time_series_metrics_df = pd.read_csv('test_results/time_series_results.csv', index_col=['dataset_name', 'fda_name', 'epochs', 'num_clients', 'batch_size', 'steps_in_one_fda_step', 'theta', 'num_weights', 'seed', 'sketch_width', 'sketch_depth'])

**TODO:**

4. DONE: `get_compiled_and_built_...()` retraces for `server_cnn` every time (and, of course, for `client_cnns` as well).
   BUT: double-check that after a `reset`, `.evaluate` still works correctly; the compiled metrics may behave unexpectedly. Please check.


5. The sketch approach should use `reduce_mean`; change it in PA-I.
6. Port the `for`-loop approach for the global tests to PA-I.
7. Remove `one` as a `tf.constant(1)` in PA-I.

In [33]:
def testing_stuff(THETA, NUM_CLIENTS=5, NUM_EPOCHS=1, BATCH_SIZE=32, NUM_STEPS_UNTIL_RTC_CHECK=1,
                  seed=7, SKETCH_DEPTH=7, SKETCH_WIDTH=500, EPOCH_MAX_FDA_STEPS=20):
    """ Quick debugging run of Naive, Linear and Sketch FDA for a single configuration.

    Unlike `basic_test`, the number of FDA steps per epoch is fixed (`EPOCH_MAX_FDA_STEPS`,
    20 by default) instead of being derived from the dataset size, which keeps runs short.
    The defaults reproduce the configuration this function originally hard-coded.

    Returns:
        (basic_test_metrics, basic_test_time_series_metrics): one entry per variant,
        in the order naive, linear, sketch.
    """
    # --------------- Fixed configurations ---------------
    ams_sketch = AmsSketch(
        depth=SKETCH_DEPTH,
        width=SKETCH_WIDTH
    )
    EPSILON = 1. / sqrt(SKETCH_WIDTH)

    # 1. Dataset sharded across `NUM_CLIENTS` clients
    clients_federated_data = create_federated_data_for_clients(NUM_CLIENTS)

    # 2. CNNs, created once to avoid graph retracing, all sharing the same starting weights
    server_cnn = get_compiled_and_built_advanced_cnn()
    client_cnns = [get_compiled_and_built_advanced_cnn() for _ in range(NUM_CLIENTS)]
    previous_server_cnn = get_compiled_and_built_advanced_cnn()  # For linear

    synchronize(server_cnn, client_cnns)
    starting_trainable_variables = deepcopy(server_cnn.trainable_variables)
    previous_server_cnn.set_trainable_variables(starting_trainable_variables)  # For linear

    # 3. Simulation parameters as tensors
    num_epochs = tf.constant(NUM_EPOCHS, shape=(), dtype=tf.int32)
    theta = tf.constant(THETA, shape=(), dtype=tf.float32)
    epsilon = tf.constant(EPSILON, shape=(), dtype=tf.float32)  # only used by sketch
    epoch_max_fda_steps = tf.constant(EPOCH_MAX_FDA_STEPS, shape=(), dtype=tf.int32)

    basic_test_metrics = []
    basic_test_time_series_metrics = []

    def fresh_federated_dataset():
        # Rebuilt for every variant (instead of reused) so each one sees the same seeded batches.
        return prepare_federated_data_for_test(
            federated_data=clients_federated_data, 
            batch_size=BATCH_SIZE,
            num_steps_until_rtc_check=NUM_STEPS_UNTIL_RTC_CHECK,
            seed=seed
        )

    def record_metrics(fda_name, total_rounds, total_fda_steps, time_series_data, **sketch_info):
        # Evaluate the final server model and append the summary and time-series metrics.
        # `sketch_info` carries `sketch_width` / `sketch_depth` for the sketch variant only.
        _, acc = server_cnn.evaluate(test_dataset, verbose=0)

        basic_test_metrics.append(create_metrics_dict(
            fda_name=fda_name, 
            n_train=n_train, 
            dataset_name="EMNIST", 
            input_pixels=784, 
            seed=seed, 
            epochs=NUM_EPOCHS, 
            num_clients=NUM_CLIENTS, 
            batch_size=BATCH_SIZE, 
            steps_in_one_fda_step=NUM_STEPS_UNTIL_RTC_CHECK, 
            theta=THETA, 
            total_fda_steps=total_fda_steps.numpy(),
            num_weights=count_weights(server_cnn),
            total_rounds=total_rounds.numpy(), 
            final_accuracy=acc,
            **sketch_info
        ))

        basic_test_time_series_metrics.append(create_time_series_metrics(
            time_series_data=time_series_data.numpy(),
            dataset_name="EMNIST",
            fda_name=fda_name, 
            epochs=NUM_EPOCHS,
            num_clients=NUM_CLIENTS,
            batch_size=BATCH_SIZE, 
            steps_in_one_fda_step=NUM_STEPS_UNTIL_RTC_CHECK,
            theta=THETA, 
            num_weights=count_weights(server_cnn), 
            seed=seed,
            **sketch_info
        ))

    # --------------- Naive ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_naive(
        server_cnn, 
        client_cnns, 
        fresh_federated_dataset(), 
        num_epochs, 
        theta,
        epoch_max_fda_steps
    )
    record_metrics("naive", total_rounds, total_fda_steps, time_series_data)

    # IMPORTANT: reset all models to the starting state before the next variant.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)

    # --------------- Linear ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_linear(
        previous_server_cnn,
        server_cnn, 
        client_cnns, 
        fresh_federated_dataset(), 
        num_epochs, 
        theta,
        epoch_max_fda_steps
    )
    record_metrics("linear", total_rounds, total_fda_steps, time_series_data)

    # IMPORTANT: reset all models; `previous_server_cnn` is only passed to the linear variant,
    # so it is reset here as well.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)
    previous_server_cnn.set_trainable_variables(starting_trainable_variables)

    # --------------- Sketch ---------------
    total_rounds, total_fda_steps, time_series_data = run_federated_simulation_sketch(
        server_cnn=server_cnn, 
        client_cnns=client_cnns, 
        federated_dataset=fresh_federated_dataset(),
        num_epochs=num_epochs, 
        theta=theta, 
        epoch_fda_steps=epoch_max_fda_steps, 
        ams_sketch=ams_sketch, 
        epsilon=epsilon
    )
    record_metrics(
        "sketch", total_rounds, total_fda_steps, time_series_data,
        sketch_width=ams_sketch.width, 
        sketch_depth=ams_sketch.depth
    )

    # IMPORTANT: leave all models in the starting state.
    reset_trainable_variables(server_cnn, client_cnns, starting_trainable_variables)

    return basic_test_metrics, basic_test_time_series_metrics
    

In [34]:
# Quick single-configuration sanity check of the three FDA variants with theta = 1.0.
basic_test_metrics, basic_test_time_series_metrics = testing_stuff(THETA=1.)

retracing naive
retracing linear
retracing sketch
