<a href="https://colab.research.google.com/github/likitha888/python/blob/main/Federated_Learning_for_Privacy_Preserving_Image_Classification_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install tensorflow tensorflow_federated matplotlib numpy



In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_federated as tff

# Load and preprocess EMNIST dataset
def preprocess(dataset):
    def batch_format_fn(element):
        """
        Reshape the 'pixels' element to have the correct shape
        without the extra dimension.
        """
        return (tf.reshape(element['pixels'], [28, 28, 1]),
                tf.reshape(element['label'], [-1, 1]))

    return dataset.map(batch_format_fn).batch(20)

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Pick a few clients for simulation
client_ids = emnist_train.client_ids[0:10]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x)) for x in client_ids]


# Define the model
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

# Wrap Keras model in TFF model
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

# Federated Averaging Process
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

# Initialize the process
state = iterative_process.initialize()

# Run federated training
NUM_ROUNDS = 10
for round_num in range(1, NUM_ROUNDS + 1):
    result = iterative_process.next(state, federated_train_data)
    state = result.state
    train_metrics = result.metrics
    print(f'Round {round_num}, Train Accuracy: {train_metrics["client_work"]["train"]["sparse_categorical_accuracy"]:.4f}')

# Evaluate the trained model centrally for demonstration (not part of actual FL)
def get_centralized_test_data():
    test_data = emnist_test.create_tf_dataset_from_all_clients()
    test_data = test_data.map(lambda x: (tf.reshape(x['pixels'], [28,28,1]), x['label']))
    return test_data.batch(32)

central_test_data = get_centralized_test_data()
final_model = create_keras_model()
final_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy']
)
# Load trained weights
final_model = create_keras_model()
final_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy']
)

# Get the model weights from the iterative process
model_weights = iterative_process.get_model_weights(state)

# Assign the weights to the Keras model
final_model.set_weights(model_weights.trainable)
# Assign only the trainable weights, which aligns with the expected structure

loss, accuracy = final_model.evaluate(central_test_data)
print(f'\nFinal Centralized Evaluation Accuracy: {accuracy:.4f}')

ERROR:jax._src.xla_bridge:Jax plugin configuration error: Plugin module %s could not be loaded
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/xla_bridge.py", line 428, in discover_pjrt_plugins
    plugin_module = importlib.import_module(plugin_module_name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_r

Round 1, Train Accuracy: 0.0576
Round 2, Train Accuracy: 0.1111
Round 3, Train Accuracy: 0.1173
Round 4, Train Accuracy: 0.1204
Round 5, Train Accuracy: 0.1204
Round 6, Train Accuracy: 0.1235
Round 7, Train Accuracy: 0.1224
Round 8, Train Accuracy: 0.1245
Round 9, Train Accuracy: 0.1235
Round 10, Train Accuracy: 0.1235

Final Centralized Evaluation Accuracy: 0.1152
