In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

import numpy as np

## Load EMNIST data

In [2]:
def load_data():
    (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
    X_train, X_test = X_train / 255.0, X_test / 255.0

    return X_train, y_train, X_test, y_test

In [3]:
CNN_BATCH_INPUT = (None, 28, 28) # EMNIST dataset (None is used for batch size, as it varies)
CNN_INPUT_RESHAPE = (28, 28, 1)
n_train = 60_000

## Prepare data for Federated Learning

In [4]:
def convert_to_tf_dataset(X_train, y_train, X_test, y_test):
    # Convert to TensorFlow Datasets
    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(256)

    return train_dataset, test_dataset

### Slice the Tensors for each Client

In [5]:
def create_federated_data_for_clients(num_clients, train_dataset):
    
    # Shard the data across clients CLIENT LEVEL
    client_datasets = [
        train_dataset.shard(num_clients, i)
        for i in range(num_clients)
    ]
    
    return client_datasets

In [6]:
def prepare_federated_data_for_test(federated_data, batch_size, num_steps_until_rtc_check=1, seed=None):
    
    def process_client_dataset(client_dataset, batch_size, num_steps_until_rtc_check, seed, shuffle_size=512):
        return client_dataset.shuffle(shuffle_size, seed=seed).repeat().batch(batch_size)\
            .take(num_steps_until_rtc_check).prefetch(tf.data.AUTOTUNE)
        
    federated_dataset_prepared = [
        process_client_dataset(client_dataset, batch_size, num_steps_until_rtc_check, seed)
        for client_dataset in federated_data
    ]
    return federated_dataset_prepared

# Simple Convolutional Neural Net (CNN) - Medium Size

A simple Convolutional Neural Network with a single convolutional layer, followed by a max-pooling layer, and two dense layers for classification. Designed for 28x28 grayscale images. It has 692,352 weights.

In [7]:
class SimpleCNN(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.reshape = layers.Reshape(CNN_INPUT_RESHAPE)
        self.conv1 = layers.Conv2D(32, 3, activation='relu')
        self.max_pool = layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dense2 = layers.Dense(num_classes, activation='softmax')

        
    # Defines the computation from inputs to outputs
    def call(self, inputs, training=None):
        x = self.reshape(inputs)  # Add a channel dimension
        x = self.conv1(x)
        x = self.max_pool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
    
    
    def step(self, batch):
        
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # Forward pass: Compute predictions
            y_batch_pred = self(x_batch, training=True)

            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=self.losses
            )

        # Compute gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Apply gradients to the model's trainable variables (update weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Update metrics (includes the metric that tracks the loss)
        #self.compiled_metrics.update_state(y_batch, y_batch_pred)
    
    
    def train(self, dataset):

        for batch in dataset:
            self.step(batch)
            
    
    def set_trainable_variables(self, trainable_vars):
        """ Given `trainable_vars` set our `self.trainable_vars` """
        for model_var, var in zip(self.trainable_variables, trainable_vars):
            model_var.assign(var)

            
    def trainable_vars_as_vector(self):
        return tf.concat([tf.reshape(var, [-1]) for var in self.trainable_variables], axis=0)

### Helper function to compile and return the CNN

**Important** function that returns a compiled and built `SimpleCNN`.

In [40]:
def get_compiled_and_built_simple_cnn():
    cnn = SimpleCNN()
    
    cnn.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), # we have softmax
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')]
    )
    
    cnn.build(CNN_BATCH_INPUT)  # EMNIST dataset (None is used for batch size, as it varies)
    
    return cnn

Helper function that creates a new `AdvancedCNN` just to compute and return accuracy. Will be used in-between epochs so we keep correct metrics without synchronizing and messing up the federated training.

In [9]:
def current_accuracy_simple_cnn(client_cnns, test_dataset):
    
    tmp_cnn = get_compiled_and_built_simple_cnn()
    tmp_cnn.set_trainable_variables(average_client_weights(client_cnns))
    _, acc = tmp_cnn.evaluate(test_dataset, verbose=0)
    
    return acc

### Client Train

The number of steps depends on the dataset, i.e., `.take(num)` call on `tf.data.Dataset` creation!

In [10]:
def client_train_naive(w_t0, client_cnn, client_dataset):
    """
    :param w_t0: Vector Tensor shape=(d,). Same shape with `client_cnn.trainable_vars_as_vector()`
    :return: Tensor shape=() dtype=tf.float32
    """
    
    # number of steps depend on `.take()` from `dataset`
    client_cnn.train(client_dataset)
    
    Delta_i = client_cnn.trainable_vars_as_vector() - w_t0
    
    #||D(t)_i||^2 , shape = () 
    Delta_i_euc_norm_squared = tf.reduce_sum(tf.square(Delta_i)) # ||D(t)_i||^2
    
    return Delta_i_euc_norm_squared

### Train all Clients

Notes about the following potentialy general `tf.function`:

1. Even though `clients_cnn` and `federated_dataset` contain `tf.Keras.Module` and `tf.data.Dataset` elements, they both are python lists (python side-effects). Take a look at [Looping Over Python data](https://www.tensorflow.org/guide/function#tracing) and afterwards [For example, the following loop is unrolled, even though the list contains ...](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) to get more insight.

2. TL;DR: It is very-very bad in terms of RAM. It produces an unrolled loop. The graph becomes consequent `Delta_i_... = ... ; S_i_clients.append(...) ;` commands `len(client_cnns)` number of times. This produces a huge graph (for instance, for `NUM_CLIENTS`=8, 4GB graph is produced). Notice that each sequence of the two commands has a big (unseen) underlying graph going to the bottom, that is, `.step` in the `tf.Keras.Module` class!

3. Even if we had endless RAM the usage of `tf.function` is still arguable. For instance, on testing for 16 clients the difference between the two is only 20-30ms with total execution time in the order of 200-250ms. Only if we had a huge amount of CPUs or GPU we could consider it, but still... there must be a better approach (`Dask` or a different implementation).

In [11]:
def clients_train_naive(w_t0, client_cnns, federated_dataset):
    """
    :param w_t0: Vector Tensor shape=(d,). Same shape with `client_cnns[i].trainable_vars_as_vector()`
    :return: List of `Tensor shape=() dtype=tf.float32`, one for each `client_cnn` in `client_cnns`.
    """
    
    S_i_clients = []

    for client_cnn, client_dataset in zip(client_cnns, federated_dataset):
        Delta_i_euc_norm_squared = client_train_naive(w_t0, client_cnn, client_dataset)
        S_i_clients.append(Delta_i_euc_norm_squared)
    
    return S_i_clients

### Identity F Function

In [12]:
def F_naive(S_i_clients):
    """ :return: Tensor shape=() dtype=tf.float32 , Naive variance approximation """
    
    S = tf.reduce_mean(S_i_clients)
    
    return S

### Testing Preparation

In [13]:
def prepare_for_federated_simulation(num_clients, train_dataset, batch_size, num_steps_until_rtc_check=1, seed=None, bench_test=False):
    
    # 1. Helper variable to count Epochs
    if bench_test:
        fda_steps_in_one_epoch = 10
    else:
        fda_steps_in_one_epoch = ((n_train / batch_size) / num_clients) / num_steps_until_rtc_check
    
    # 2. Federated Dataset creation
    clients_federated_data = create_federated_data_for_clients(num_clients, train_dataset)
    federated_dataset = prepare_federated_data_for_test(clients_federated_data, batch_size)
    
    # 3. Models creation
    server_cnn = get_compiled_and_built_simple_cnn()
    client_cnns = [get_compiled_and_built_simple_cnn() for _ in range(num_clients)]
    
    return server_cnn, client_cnns, federated_dataset, fda_steps_in_one_epoch

In [14]:
train_dataset, test_dataset = convert_to_tf_dataset(*load_data())

server_cnn, client_cnns, federated_dataset, fda_steps_in_one_epoch = prepare_for_federated_simulation(
    num_clients=5, 
    train_dataset=train_dataset, 
    batch_size=32
)


This function provides the most importand functionality to achieve high parallelism. Given a list of client NN weights
and a list of client datasets (1-1 correspondance) we store each client's information (the weights and the batch) inside
a unifying `tf.data.Dataset` and return it.
    
This is extremely importand because, quoting [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements), "`for` statements that iterate over a `tf.data.Dataset` and which do not contain `break` or `return` statements are executed as TensorFlow loops by converting them to `tf.data.Dataset.reduce` ops"!

In [15]:
def transform_to_single_ds(weights_of_clients, federated_dataset):
    """ 
    :param: weights_of_clients: A list where the i-th element is client_cnn[i].get_weights()
    :param: federated_dataset: A list where the i-th element is a `_PrefetchDataset` (`.take(1)`) of the sharded
        i-th client's dataset.
    :returns: A single `tf.data.Dataset` with k elements, where k the number of clients, where each element is
        (client_weights, client_batch). 
    """
    weights_batches_ds = None
    
    for client_weights, client_batch in zip(weights_of_clients, federated_dataset):
    
        if weights_batches_ds is None:

            client_weights_ds = tf.data.Dataset.from_tensors((*client_weights, ))
            weights_batches_ds = tf.data.Dataset.zip((client_weights_ds, client_batch))

        else:

            # Create a `Dataset` with a single element comprising of the client cnn layers' weights
            client_weights_ds = tf.data.Dataset.from_tensors((*client_weights, ))

            # Zip together the information of the client weights along with its dataset (_Prefetch)
            client_weights_batch_ds = tf.data.Dataset.zip((client_weights_ds, client_batch))

            # Concatenate the above with the rest of the clients
            weights_batches_ds = weights_batches_ds.concatenate(client_weights_batch_ds)
    
    return weights_batches_ds

In [16]:
weights_of_clients = [client_cnn.get_weights() for client_cnn in client_cnns]

In [17]:
weights_batches_ds = transform_to_single_ds(weights_of_clients, federated_dataset)

In [85]:
def client_step(client_weights, client_batch):
    
    cnn = get_compiled_and_built_simple_cnn()
    
    weights = [w for w in client_weights]
    
    cnn.set_weights(weights)
    
    
    return client_weights

In [87]:
updated_weights_dataset = weights_batches_ds.map(
    client_step
)

NotImplementedError: in user code:

    File "/tmp/ipykernel_47754/2618185590.py", line 7, in client_step  *
        cnn.set_weights(weights)
    File "/home/miketheologitis/anaconda3/envs/tf-2.12/lib/python3.9/site-packages/keras/engine/base_layer.py", line 1835, in set_weights  **
        backend.batch_set_value(weight_value_tuples)
    File "/home/miketheologitis/anaconda3/envs/tf-2.12/lib/python3.9/site-packages/keras/backend.py", line 4311, in batch_set_value
        value = np.asarray(value, dtype=dtype_numpy(x))

    NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.


In [77]:
server_cnn.get_weights()

[array([[[[-0.09867188,  0.14038436,  0.04368007,  0.10912751,
           -0.13217494,  0.12761752,  0.11392225, -0.14094844,
            0.05158406, -0.02202985, -0.03674391,  0.00636083,
            0.08236691, -0.00184546, -0.06704932,  0.06318627,
           -0.09771751, -0.03474028,  0.05229424, -0.08579065,
           -0.03078154,  0.08751979, -0.03356294,  0.11475404,
            0.12217663,  0.13530181,  0.04078162,  0.02428082,
            0.04578532,  0.13023688, -0.12828481,  0.07579209]],
 
         [[ 0.12486617, -0.05201104,  0.06867196,  0.1381575 ,
           -0.07834218,  0.02061407, -0.06158811, -0.06629313,
            0.06348149, -0.1225654 , -0.0386292 , -0.06996137,
            0.09493688,  0.05410103, -0.13650358,  0.0615605 ,
           -0.10308299,  0.03784294, -0.09747477,  0.06917003,
            0.1345068 ,  0.0225428 ,  0.01310159, -0.11230831,
            0.03261822,  0.08617488, -0.03707909, -0.06101545,
           -0.07096687, -0.01604684, -0.13965279,  

In [83]:
for client_weights in updated_weights_dataset:
    cnn = get_compiled_and_built_simple_cnn()
    
    weights = [w for w in client_weights]
    
    cnn.set_weights(weights)
    break