# Tensorflow Federated for faulty pill recognition

## Dependencies and parameters

In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, regularizers
import threading
import tensorflow_federated as tff

from matplotlib import pyplot as plt

2024-09-06 17:24:26.237994: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-06 17:24:26.243028: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-06 17:24:26.355749: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-06 17:24:26.355807: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-06 17:24:26.355939: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to regi

In [2]:
# Besides the centralized server, how many independent clients are training their models
num_clients     = 3
image_shape     = (224, 224, 3)
num_categories  = 2
num_rounds      = 3
num_epochs      = 2
shuffle_buffer  = 30
batch_size      = 32
splitter        = 10

data_folder    = "data/PILL"

## Loading Data

In [3]:
def load_data(folder, num_clients):
    """Load a dataset from a directory and randomly divide it into clients."""
    # Read all data as a single data set
    full_dataset = tf.keras.preprocessing.image_dataset_from_directory(
        folder,
        labels='inferred',
        label_mode='binary',
        class_names=None,
        color_mode='rgb',
        batch_size=32,
        image_size=image_shape[:2],
        shuffle=True,
        seed=None,
    )
    
    # Setup epochs and normalize data
    def preprocess(ds):
      return ds.map(lambda x,y : (x / 255.0, y))

    # Partition the data set into a chunk for each client
    per_client_data = [full_dataset.shard(num_clients, i) for i in range(num_clients)]
    
    # Build the ClientData structure needed by TFF
    federated_dataset = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
        [i for i in range(num_clients)], lambda i: per_client_data[i]
    )

    return federated_dataset.preprocess(preprocess)

## Model definitions

### Architecture

In [4]:
class PillModel:
    """TensorFlow Keras model for Pill image recognition."""
    def __init__(self, input_shape, num_classes):
        self.model = tf.keras.Sequential([
            # First convolutional block
            layers.Conv2D(32, (3, 3), padding='same', input_shape=input_shape),
            layers.ReLU(),
            layers.MaxPooling2D(pool_size=(2, 2), strides=2),
            
            # Second convolutional block
            layers.Conv2D(64, (3, 3), padding='same'),
            layers.ReLU(),
            layers.MaxPooling2D(pool_size=(2, 2), strides=2),
            
            # Third convolutional block
            layers.Conv2D(128, (3, 3), padding='same'),
            layers.ReLU(),
            layers.MaxPooling2D(pool_size=(2, 2), strides=2),
            
            # Flatten the output and add fully connected layers
            layers.Flatten(),
            layers.Dense(256),
            layers.ReLU(),
            layers.Dropout(0.5),
            
            # Output layer for classification
            layers.Dense(num_classes, activation='softmax')
        ])

    # The element_spec for the element`s model
    # Tensorflow requires this for typing
    input_spec = (tf.TensorSpec(shape=(None, *image_shape), dtype=tf.float32, name=None), tf.TensorSpec(shape=(None, 1), dtype=tf.float32, name=None))

In [5]:
def model_fn():
    """Turn this model into something tff can run in a federated setting."""
    model = PillModel(image_shape, num_categories)
    
    return tff.learning.models.from_keras_model(
        keras_model = model.model,
        input_spec  = model.input_spec,
        loss        = tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics     = [tf.keras.metrics.CategoricalAccuracy()]  # CategoricalAccuracy needed because of softmax activation
    )

In [6]:
def optimizer_fn():
    return tf.keras.optimizers.Adam(learning_rate=0.003, epsilon=1e-8)

### Steps

In [18]:
def federated_evaluate(test_dataset, training_process, training_state):
    """
    Evaluate a federated model in a testing dataset.
    Since we evaluate with all clients in a single round, this should be equivalent to a centralized evaluation.
    """
    evaluation_process = tff.learning.algorithms.build_fed_eval(model_fn)
    evaluation_state   = evaluation_process.initialize()

    # Copy weights from trained model to new testing model
    model_weights    = training_process.get_model_weights(training_state)
    evaluation_state = evaluation_process.set_model_weights(evaluation_state, model_weights)

    federated_test_data = [test_dataset.create_tf_dataset_for_client(x) for x in test_dataset.client_ids]

    # Evaluating amounts to a single forward step
    evaluation_state, metrics = evaluation_process.next(evaluation_state, federated_test_data)

    print(f'Evaluation Metrics: {metrics}')

    return evaluation_process, metrics

In [32]:
def federated_train(data_folder, num_rounds):
    """Train model using dataset on folder."""
    # Input datasets
    dataset_train = load_data(data_folder + "/Training", num_clients)
    dataset_test  = load_data(data_folder + "/Testing", num_clients)
    
    # Preprocesing: 
    #   Setup epochs and shuffle data
    dataset_train.preprocess(lambda ds: ds.shuffle(shuffle_buffer).repeat(num_epochs))
    
    # Build the federated averaging process
    process = tff.learning.algorithms.build_weighted_fed_avg(
        model_fn = model_fn,
        client_optimizer_fn = optimizer_fn,
        server_optimizer_fn = optimizer_fn,
    )
    # Initialize and run the federated learning process
    state = process.initialize()

    metrics_history = {
        "train": [],
        "test":  [],
    }

    num_clients_per_round = num_clients # Select all clients
    
    for round in range(0, num_rounds):
        # Clients participating on this round
        selected_clients     = np.random.choice(dataset_train.client_ids, size=num_clients_per_round, replace=False)
        federated_train_data = [dataset_train.create_tf_dataset_for_client(x) for x in selected_clients]
        
        state, metrics = process.next(state, federated_train_data)
        
        print(f'Round {round+1:2d}, Metrics: {metrics}')
        metrics_history["train"].append(metrics)

        # Evaluate data after each round
        evaluation_process, evaluation_metrics = federated_evaluate(dataset_test, process, state)
        metrics_history["test"].append(evaluation_metrics)

    # Outputs
    accuracies = [metrics['client_work']['train']['categorical_accuracy'] for metrics in metrics_history["train"]]
    final_accuracy = np.mean(accuracies)
    print(f"Final averaged accuracy over {len(metrics_history)} rounds is: {final_accuracy}")
        
    return process, state, metrics_history

## Train and evaluate

In [33]:
federated_train(data_folder, num_rounds)

Found 348 files belonging to 2 classes.
Found 86 files belonging to 2 classes.


2024-09-06 17:36:25.430969: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:25.431114: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:36:25.466021: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:25.466322: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:36:26.625422: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:26.625692: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:36:26.926261: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:26.926393: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session


Round  1, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.15229885), ('loss', 3.9961717), ('num_examples', 348), ('num_batches', 11)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])


2024-09-06 17:36:54.411547: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:54.411722: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:36:54.462168: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:36:54.464628: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:00.835028: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:00.835223: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:00.962688: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:00.962832: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session


Evaluation Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)])), ('total_rounds_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Round  2, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2337494), ('num_examples', 348), ('num_batches', 11)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])


2024-09-06 17:37:21.355237: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:21.355469: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:21.398345: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:21.398543: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:28.577289: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:28.577459: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:28.708043: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:28.708183: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session


Evaluation Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)])), ('total_rounds_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Round  3, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2337494), ('num_examples', 348), ('num_batches', 11)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])


2024-09-06 17:37:58.068405: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:58.068579: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:37:58.093099: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:37:58.093233: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:38:03.960525: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:38:03.960699: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-09-06 17:38:04.069455: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-09-06 17:38:04.069605: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session


Evaluation Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)])), ('total_rounds_metrics', OrderedDict([('categorical_accuracy', 0.0), ('loss', 5.2477517), ('num_examples', 86), ('num_batches', 3)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Final averaged accuracy over 2 rounds is: 0.05076628550887108


(<tensorflow_federated.python.learning.templates.learning_process.LearningProcess at 0x7f8f90f8dcf0>,
 LearningAlgorithmState(global_model_weights=ModelWeights(trainable=[array([[[[-1.17406212e-02, -3.86885703e-02,  6.79798424e-03,
            9.88665000e-02, -1.01429500e-01, -4.19395491e-02,
           -6.79604113e-02,  5.82015552e-02,  1.20196424e-01,
           -3.92708834e-03, -2.74675526e-02,  1.15980068e-02,
           -1.25729050e-02,  1.14516675e-01, -1.11606263e-01,
            6.06143326e-02,  6.84000999e-02, -4.14127260e-02,
           -2.77523771e-02,  2.19572335e-02, -9.82932299e-02,
            1.30695730e-01, -1.01742074e-01, -7.69657874e-03,
            3.78560871e-02,  7.20634833e-02,  9.44345817e-02,
            1.13030523e-01,  4.99533713e-02, -5.88241825e-03,
           -7.43643865e-02,  1.28434852e-01],
          [ 8.69062766e-02,  4.80129123e-02, -5.97994700e-02,
            1.30983189e-01,  8.58158320e-02, -8.80510062e-02,
            9.81462300e-02, -5.44715039e