In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# Fetch MNIST through Keras: (images, labels) pairs for the train and test splits.
(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

# Rescale both image arrays by 1/255
# (presumably 8-bit pixel intensities mapped to [0, 1] -- verify if it matters).
X_train = X_train / 255.0
X_test = X_test / 255.0

In [3]:
# Record the number of training examples now: `X_train` is deleted in the next
# cell once it has been converted to a tensor, while `n_train` is still needed
# later (client slice boundaries, shuffle buffer size).
n_train = len(X_train)

In [4]:
# Convert the NumPy arrays into constant tensors (float32 images, int32 labels).
# Every array name is deleted right after its conversion, so the notebook only
# keeps the tensor version around.
# Gotcha: because of the `del`s this cell cannot be re-run on its own -- the
# data-loading cell has to be executed again first.
X_train_tensor = tf.constant(X_train, dtype=tf.float32)
del X_train

y_train_tensor = tf.constant(y_train, dtype=tf.int32)
del y_train

X_test_tensor = tf.constant(X_test, dtype=tf.float32)
del X_test

y_test_tensor = tf.constant(y_test, dtype=tf.int32)
del y_test

In [5]:
# Test images and labels paired as (features, labels); the dataset built below
# yields them example by example in this layout.
slices_test = (X_test_tensor, y_test_tensor)

def create_tf_dataset_for_testing(batch_size):
    """Return the test split as a batched `tf.data.Dataset`.

    Args:
        batch_size: Number of consecutive examples combined into one batch.

    Returns:
        A dataset over the notebook-level `slices_test` pair, in original
        order (no shuffling), grouped into batches of `batch_size`.
    """
    test_examples = tf.data.Dataset.from_tensor_slices(slices_test)
    return test_examples.batch(batch_size)

In [6]:
def create_data_for_clients(num_clients):
    """Partition the training tensors into contiguous per-client slices.

    Client `i` receives the rows from `int(i * n_train / num_clients)` up to
    (but excluding) `int((i + 1) * n_train / num_clients)` of the
    notebook-level `X_train_tensor` and `y_train_tensor`, so consecutive
    clients get adjacent, non-overlapping ranges.

    Args:
        num_clients: Number of clients to create slices for.

    Returns:
        A dict mapping 'client_<i>' to the tuple `(X_slice, y_slice)` for
        that client.
    """

    def _slice_for(i):
        # Half-open index range [first, last) owned by client `i`.
        first = int(i * n_train / num_clients)
        last = int((i + 1) * n_train / num_clients)
        return X_train_tensor[first:last], y_train_tensor[first:last]

    return {f'client_{i}': _slice_for(i) for i in range(num_clients)}

In [7]:
def create_tf_dataset_for_client(client_tensor_slices, batch_size, shuffle_buffer_size, num_steps_until_rtc_check, seed):
    """Build the training input pipeline for a single client.

    Pipeline: slice the tensors into examples -> shuffle -> batch -> keep the
    first `num_steps_until_rtc_check` batches -> prefetch.

    Args:
        client_tensor_slices: Tensors to slice along their first dimension;
            the notebook passes an `(X, y)` tuple.
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Buffer size handed to `Dataset.shuffle`.
        num_steps_until_rtc_check: Number of batches one pass over the
            returned dataset yields (argument of `Dataset.take`).
        seed: Seed for the shuffle; `None` leaves it unseeded.

    Returns:
        A `tf.data.Dataset` of shuffled batches, truncated to
        `num_steps_until_rtc_check` batches per pass.
    """
    dataset = tf.data.Dataset.from_tensor_slices(client_tensor_slices)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=seed)
    dataset = dataset.batch(batch_size)

    # Truncate *before* prefetching: with `prefetch` as the final stage the
    # background prefetcher only prepares batches that are actually consumed,
    # instead of producing extra ones that `take` would throw away. The
    # elements yielded are the same as with the opposite order.
    dataset = dataset.take(num_steps_until_rtc_check)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [8]:
def create_federated_data(client_slices_train, batch_size, shuffle_buffer_size, num_steps_until_rtc_check, seed=None):
    """Create one training dataset per client.

    Args:
        client_slices_train: Mapping from client name to that client's tensor
            slices (as produced by `create_data_for_clients`).
        batch_size: Forwarded to `create_tf_dataset_for_client`.
        shuffle_buffer_size: Forwarded to `create_tf_dataset_for_client`.
        num_steps_until_rtc_check: Forwarded to `create_tf_dataset_for_client`.
        seed: Shuffle seed shared by all clients; defaults to `None`.

    Returns:
        A list of datasets, one per client, in the iteration order of
        `client_slices_train`. Client names are not kept.
    """
    federated_dataset = []

    # Only the tensor slices are needed; the client names are dropped.
    for client_tensor_slices in client_slices_train.values():
        client_dataset = create_tf_dataset_for_client(
            client_tensor_slices,
            batch_size,
            shuffle_buffer_size,
            num_steps_until_rtc_check,
            seed,
        )
        federated_dataset.append(client_dataset)

    return federated_dataset

In [9]:
def get_uncompiled_model():
    """Build the small convolutional classifier, without compiling it.

    Architecture: input "digits" of shape (28, 28, 1) -> Conv2D (32 filters,
    kernel size 3, ReLU) -> 2x2 max pooling -> Flatten -> Dense(128, ReLU)
    -> Dense(10, softmax).

    Returns:
        An uncompiled `keras.Model`.
    """
    digits = keras.Input(shape=(28, 28, 1), name="digits")

    # Layers are created, and then applied, in exactly this order.
    stack = [
        layers.Conv2D(32, 3, activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]

    activations = digits
    for layer in stack:
        activations = layer(activations)

    return keras.Model(inputs=digits, outputs=activations)

In [10]:
def get_compiled_model():
    """Build a fresh model via `get_uncompiled_model` and compile it.

    The model is compiled with an Adam optimizer (default settings), sparse
    categorical cross-entropy on probabilities (`from_logits=False`, because
    the last layer already applies softmax), and a sparse categorical
    accuracy metric registered under the name 'test_accuracy'.

    Returns:
        The compiled `keras.Model`.
    """
    compiled = get_uncompiled_model()

    adam = keras.optimizers.Adam()
    # The network ends in softmax, so the loss receives probabilities.
    cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    accuracy = keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    compiled.compile(optimizer=adam, loss=cross_entropy, metrics=[accuracy])
    return compiled

Approach: a custom training step, following the Keras guide [Customizing what happens in `fit()`](https://keras.io/guides/customizing_what_happens_in_fit/).

In [11]:
@tf.function
def client_train(model, dataset):
    """Run one pass of gradient updates for `model` over `dataset`.

    Every batch yielded by `dataset` goes through one optimisation step:
    forward pass, compiled loss, gradients, optimizer update, metric update.

    Args:
        model: A compiled Keras model; its `compiled_loss`, `optimizer` and
            `compiled_metrics` are used, and its weights are updated in place.
        dataset: Iterable of `(x_batch, y_batch)` pairs (the notebook passes
            a `tf.data.Dataset`).

    Returns:
        None. Training results are read afterwards from `model.metrics`.
    """
    
    # Tracing probe: as a plain Python side effect this print is expected to run
    # only when `tf.function` (re)traces this function, not on every call -- the
    # recorded outputs show it once per `train(...)` invocation.
    print("Retrace `client_train`")
    
    @tf.function
    def _step(model, batch):
        """Perform a single optimisation step on one `(x_batch, y_batch)` pair."""
        
        x_batch, y_batch = batch

        with tf.GradientTape() as tape:
            # Forward pass in training mode: compute predictions
            y_batch_pred = model(x_batch, training=True)  # Forward pass

            # Compute the loss value
            # (the loss function is configured in `compile()`);
            # `model.losses` is passed along as the regularization losses.
            loss = model.compiled_loss(
                y_true=y_batch,
                y_pred=y_batch_pred,
                regularization_losses=model.losses
            )

        # Compute gradients of the loss w.r.t. the trainable variables
        gradients = tape.gradient(loss, model.trainable_variables)

        # Apply gradients to the model's trainable variables (update weights)
        model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update metrics (includes the metric that tracks the loss)
        model.compiled_metrics.update_state(y_batch, y_batch_pred)
    
    # One optimisation step per batch the dataset yields.
    for batch in dataset:
        _step(model, batch)
    
    # Nothing is returned; the metric dict below is intentionally disabled.
    # Return a dict mapping metric names to current value
    #return {m.name: m.result() for m in model.metrics}


In [12]:
@tf.function
def train(model, r):
    for _ in tf.range(r):
        client_train(model, fed_data[0])

# Early tests

In [13]:
# Single-client setup: with one client, its slice spans all `n_train` examples.
client_slices_train = create_data_for_clients(1)

In [14]:
# One dataset per client (here: a list with a single entry, used by `train`).
# - shuffle buffer = n_train, i.e. as large as the single client's data
# - num_steps_until_rtc_check=1: every pass over the dataset yields one batch
# - seed=None: the shuffle is not seeded, so runs are not reproducible
fed_data = create_federated_data(
    client_slices_train=client_slices_train,
    batch_size=32,
    shuffle_buffer_size=n_train,
    num_steps_until_rtc_check=1,
    seed=None
)

In [15]:
# First fresh model instance; trained for 500 rounds further down.
model1 = get_compiled_model()

In [16]:
# Second fresh model instance; trained for 1000 rounds further down.
model2 = get_compiled_model()

In [17]:
# 500 rounds; each round is one step on a single batch of 32 examples
# (batch_size=32, num_steps_until_rtc_check=1 in `fed_data`).
# The round count is passed as a scalar int32 tensor rather than a Python int
# (presumably so different counts can share one trace -- TODO confirm).
train(model1, tf.constant(500, shape=(), dtype=tf.int32))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Retrace `client_train`


In [18]:
# Metrics accumulated while training model1; they are updated on every training
# batch inside `client_train`. NOTE: despite its name, `test_accuracy` is
# therefore a running accuracy over training batches, not a test-set score.
{m.name: m.result() for m in model1.metrics}

{'loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.28858897>,
 'test_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.9164375>}

In [19]:
# Same procedure on the second model with twice as many rounds (1000); each
# round is again one step on a single batch of 32 examples.
train(model2, tf.constant(1000, shape=(), dtype=tf.int32))

Retrace `client_train`


In [20]:
# Metrics accumulated while training model2; they are updated on every training
# batch inside `client_train`. NOTE: despite its name, `test_accuracy` is
# therefore a running accuracy over training batches, not a test-set score.
{m.name: m.result() for m in model2.metrics}

{'loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.19986692>,
 'test_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.94>}

In [None]:
# TODO: check difference, variance, etc.

In [21]:
#keras.utils.plot_model(model, "damn.png", show_shapes=True)