In [5]:
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_privacy as tfp
import nest_asyncio
nest_asyncio.apply()

In [6]:
# Load the CIFAR-10 dataset
(cifar10_train_images, cifar10_train_labels), (cifar10_test_images, cifar10_test_labels) = tf.keras.datasets.cifar10.load_data()

In [7]:
# Preprocess the dataset
def preprocess(images, labels):
    images = tf.cast(images, tf.float32) / 255.0
    return (images, labels)

In [8]:
# Split the dataset into multiple "client" datasets
num_clients = 10
client_datasets = []
for i in range(num_clients):
    start = i * len(cifar10_train_images) // num_clients
    end = (i + 1) * len(cifar10_train_images) // num_clients
    client_images = cifar10_train_images[start:end]
    client_labels = cifar10_train_labels[start:end]
    client_dataset = tf.data.Dataset.from_tensor_slices((client_images, client_labels))
    client_dataset = client_dataset.map(preprocess).batch(20)
    client_datasets.append(client_dataset)

In [9]:
# Define a simple CNN model
def create_keras_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    return model

def model_fn_standard():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=client_datasets[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

def model_fn_with_dp():
    keras_model = create_keras_model()
    optimizer = tfp.DPAdamGaussianOptimizer(
        l2_norm_clip=1.0,
        noise_multiplier=0.5,
        num_microbatches=1,
        learning_rate=0.001
    )
    keras_model.compile(optimizer=optimizer,
                        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=client_datasets[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])



In [10]:
def assign_weights_to_keras_model(keras_model, tff_state):
    tff_weights = tff_state.model.trainable
    for var, tff_var in zip(keras_model.trainable_variables, tff_weights):
        var.assign(tff_var)  # Removed .numpy()

In [11]:
def evaluate_model(state, model_fn, test_dataset):
    keras_model = create_keras_model()
    assign_weights_to_keras_model(keras_model, state)

    keras_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

    test_images, test_labels = zip(*list(test_dataset))
    test_images = tf.concat(test_images, axis=0)
    test_labels = tf.concat(test_labels, axis=0)

    loss, accuracy = keras_model.evaluate(test_images, test_labels, verbose=0)
    return loss, accuracy

In [12]:
def check_dataset(dataset):
    for batch in dataset.take(1):
        images, labels = batch
        print(f'Batch shape: {images.shape}, Labels: {labels.numpy()}')

In [13]:
for i, client_dataset in enumerate(client_datasets):
    print(f'Client {i} dataset:')
    check_dataset(client_dataset)

Client 0 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[6]
 [9]
 [9]
 [4]
 [1]
 [1]
 [2]
 [7]
 [8]
 [3]
 [4]
 [7]
 [7]
 [2]
 [9]
 [9]
 [9]
 [3]
 [2]
 [6]]
Client 1 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[6]
 [7]
 [9]
 [0]
 [5]
 [2]
 [3]
 [3]
 [3]
 [9]
 [0]
 [9]
 [2]
 [9]
 [1]
 [0]
 [2]
 [3]
 [9]
 [6]]
Client 2 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[1]
 [6]
 [6]
 [8]
 [8]
 [3]
 [4]
 [6]
 [0]
 [6]
 [0]
 [3]
 [6]
 [6]
 [5]
 [4]
 [8]
 [3]
 [2]
 [6]]
Client 3 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[0]
 [6]
 [7]
 [0]
 [4]
 [9]
 [5]
 [8]
 [0]
 [4]
 [3]
 [8]
 [4]
 [7]
 [1]
 [8]
 [3]
 [5]
 [4]
 [5]]
Client 4 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[8]
 [5]
 [0]
 [6]
 [9]
 [2]
 [8]
 [3]
 [6]
 [2]
 [7]
 [4]
 [6]
 [9]
 [0]
 [0]
 [7]
 [3]
 [7]
 [2]]
Client 5 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[6]
 [9]
 [8]
 [4]
 [0]
 [6]
 [3]
 [1]
 [3]
 [9]
 [9]
 [8]
 [5]
 [8]
 [4]
 [5]
 [0]
 [4]
 [2]
 [3]]
Client 6 dataset:
Batch shape: (20, 32, 32, 3), Labels: [[

In [14]:
# Create federated averaging processes
iterative_process_standard = tff.learning.build_federated_averaging_process(
    model_fn_standard,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

iterative_process_with_dp = tff.learning.build_federated_averaging_process(
    model_fn_standard,  # Use standard model_fn as DP optimizer will be applied in the client update
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

In [15]:
# Train the standard model
NUM_ROUNDS = 3
state_standard = iterative_process_standard.initialize()
for round_num in range(NUM_ROUNDS):
    state_standard, metrics_standard = iterative_process_standard.next(state_standard, client_datasets)
    print(f'Standard Model - Round {round_num}, metrics={metrics_standard["train"]}')

Standard Model - Round 0, metrics=OrderedDict([('sparse_categorical_accuracy', 0.17964), ('loss', 2.2078638)])
Standard Model - Round 1, metrics=OrderedDict([('sparse_categorical_accuracy', 0.2609), ('loss', 2.0217845)])
Standard Model - Round 2, metrics=OrderedDict([('sparse_categorical_accuracy', 0.30872), ('loss', 1.9185474)])


In [16]:
# Train the differentially private model
state_with_dp = iterative_process_with_dp.initialize()
for round_num in range(NUM_ROUNDS):
    state_with_dp, metrics_with_dp = iterative_process_with_dp.next(state_with_dp, client_datasets)
    print(f'DP Model - Round {round_num}, metrics={metrics_with_dp["train"]}')


DP Model - Round 0, metrics=OrderedDict([('sparse_categorical_accuracy', 0.16242), ('loss', 2.236325)])
DP Model - Round 1, metrics=OrderedDict([('sparse_categorical_accuracy', 0.24062), ('loss', 2.0594568)])
DP Model - Round 2, metrics=OrderedDict([('sparse_categorical_accuracy', 0.29254), ('loss', 1.9599725)])


In [17]:
# Preprocess the test dataset
test_dataset = tf.data.Dataset.from_tensor_slices((cifar10_test_images, cifar10_test_labels))
test_dataset = test_dataset.map(preprocess).batch(20)

# Evaluate the standard model
loss_standard, accuracy_standard = evaluate_model(state_standard, model_fn_standard, test_dataset)
print(f'Standard Model - Test loss: {loss_standard}, Test accuracy: {accuracy_standard}')

# Evaluate the differentially private model
loss_with_dp, accuracy_with_dp = evaluate_model(state_with_dp, model_fn_with_dp, test_dataset)
print(f'DP Model - Test loss: {loss_with_dp}, Test accuracy: {accuracy_with_dp}')

Standard Model - Test loss: 1.7991496324539185, Test accuracy: 0.3668000102043152
DP Model - Test loss: 1.8537687063217163, Test accuracy: 0.3522999882698059
