In [17]:
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from scipy.stats import ks_2samp, chi2_contingency
import concurrent.futures

In [18]:
# Load the MNIST dataset
mnist = tf.keras.datasets.mnist
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

In [19]:
# Normalize the data
X_train, X_test = X_train / 255.0, X_test / 255.0

In [20]:
# Define the first model (A1)
def create_model_A1():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

In [21]:
# Define the second model (A2)
def create_model_A2():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model


In [6]:
# Convert Keras models to TFF models
def model_fn_A1():
    keras_model = create_model_A1()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=client_datasets[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

In [22]:
def model_fn_A2():
    keras_model = create_model_A2()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=client_datasets[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

In [8]:
# Create dummy client datasets
client_datasets = [tf.data.Dataset.from_tensor_slices((np.random.rand(10, 28, 28), np.random.randint(0, 10, 10))).batch(10) for _ in range(10)]

In [23]:
# Create federated learning algorithms
federated_algorithm_A1 = tff.learning.build_federated_averaging_process(
    model_fn_A1,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)



In [24]:
federated_algorithm_A2 = tff.learning.build_federated_averaging_process(
    model_fn_A2,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)


In [25]:
# Function to run the training in a separate event loop
def run_training():
    tff.backends.native.set_local_execution_context()

    state_A1 = federated_algorithm_A1.initialize()
    state_A2 = federated_algorithm_A2.initialize()

    # Train the models on the clients' data
    for round_num in range(1, 11):
        state_A1, metrics_A1 = federated_algorithm_A1.next(state_A1, client_datasets)
        state_A2, metrics_A2 = federated_algorithm_A2.next(state_A2, client_datasets)

        # Perform differential testing for each client
        for client_id, client_data in enumerate(client_datasets):
            # Get the model weights after training
            weights_A1 = state_A1.model.trainable
            weights_A2 = state_A2.model.trainable

            # Evaluate both models on the client's local test data
            predictions_A1, labels_A1 = evaluate_model_on_client(weights_A1, create_model_A1, client_data)
            predictions_A2, labels_A2 = evaluate_model_on_client(weights_A2, create_model_A2, client_data)

            # Criterion 1: Absolute differences between classes
            pred_class_A1 = np.argmax(predictions_A1, axis=1)
            pred_class_A2 = np.argmax(predictions_A2, axis=1)
            Δ_class = np.sum(pred_class_A1 != pred_class_A2)

            # Criterion 2: Absolute differences between scores
            Δ_score = np.sum(predictions_A1 != predictions_A2)

            # Criterion 3: Significance of difference between scores
            P_KS = ks_2samp(predictions_A1.flatten(), predictions_A2.flatten()).pvalue

            # Criterion 4: Significance of difference between classifications
            contingency = np.array([[np.sum((pred_class_A1 == i) & (pred_class_A2 == j)) for j in range(10)] for i in range(10)])
            contingency += 1  # Add-one smoothing
            P_X2 = chi2_contingency(contingency)[1]

            print(f"Client {client_id}:")
            print(f"Δ_class: {Δ_class}")
            print(f"Δ_score: {Δ_score}")
            print(f"P_KS: {P_KS}")
            print(f"P_X2: {P_X2}")
            print()

In [26]:
# Function to evaluate the model on client data
def evaluate_model_on_client(weights, create_model_fn, client_data):
    model = create_model_fn()
    model.set_weights(weights)
    predictions = []
    labels = []
    for batch in client_data:
        x, y = batch
        predictions.append(model.predict(x))
        labels.append(y)
    predictions = np.concatenate(predictions, axis=0)
    labels = np.concatenate(labels, axis=0)
    return predictions, labels

In [27]:
# Run the training
with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(run_training)
    future.result()

  def _map_element(e):
  def _map_element(e):


Client 0:
Δ_class: 8
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0

Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 1:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 4:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 1:
Δ_class: 9
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 4:
Δ_class: 10
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9684099261397212
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 1:
Δ_class: 9
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 4:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9684099261397212
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.21117008625127576
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 1:
Δ_class: 9
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 4:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9684099261397212
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.7020569828664881
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 0:
Δ_class: 9
Δ_score: 100
P_KS: 0.2819416298082479
P_X2: 1.0









Client 1:
Δ_class: 9
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 2:
Δ_class: 9
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 3:
Δ_class: 9
Δ_score: 100
P_KS: 0.9084105017744525
P_X2: 1.0









Client 4:
Δ_class: 10
Δ_score: 100
P_KS: 0.36818778606286096
P_X2: 1.0









Client 5:
Δ_class: 9
Δ_score: 100
P_KS: 0.9684099261397212
P_X2: 1.0









Client 6:
Δ_class: 10
Δ_score: 100
P_KS: 0.21117008625127576
P_X2: 1.0









Client 7:
Δ_class: 10
Δ_score: 100
P_KS: 0.469506448503778
P_X2: 1.0









Client 8:
Δ_class: 10
Δ_score: 100
P_KS: 0.8154147124661313
P_X2: 1.0









Client 9:
Δ_class: 10
Δ_score: 100
P_KS: 0.5830090612540064
P_X2: 1.0

