# Tensorflow Federated for faulty pill recognition

## Dependencies and parameters

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, regularizers
import threading
import tensorflow_federated as tff

import matplotlib.pyplot as plt
tf.get_logger().setLevel('ERROR')

2024-10-28 19:36:00.587020: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-28 19:36:00.587058: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-28 19:36:00.587107: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Besides the centralized server, how many independent clients are training their models
num_clients     = 10
image_shape     = (256, 256, 3)
num_categories  = 2
num_rounds      = 20
num_epochs      = 10
shuffle_buffer  = 30
batch_size      = 32
splitter        = 10

data_folder    = "data/PILL"

## Loading Data

In [3]:
def load_data(folder, num_clients):
    """Load a dataset from a directory and randomly divide it into clients."""
    # Read all data as a single data set
    full_dataset = tf.keras.preprocessing.image_dataset_from_directory(
        folder,
        labels='inferred',
        label_mode='categorical',
        class_names=None,
        color_mode='rgb',
        batch_size=batch_size,
        image_size=image_shape[:2],
        shuffle=True,
        seed=None,
    )
    
    # Setup epochs and normalize data
    def preprocess(ds):
      return ds.map(lambda x,y : (x / 255.0, y))

    # Partition the data set into a chunk for each client
    per_client_data = [full_dataset.shard(num_clients, i) for i in range(num_clients)]
    
    # Build the ClientData structure needed by TFF
    federated_dataset = tff.simulation.datasets.ClientData.from_clients_and_tf_fn(
        [i for i in range(num_clients)], lambda i: per_client_data[i]
    )

    return federated_dataset.preprocess(preprocess)

## Model definitions

### Architecture

In [4]:
class PillModel:
    """TensorFlow Keras model for Pill image recognition."""
    def __init__(self, input_shape, num_classes):
        base_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
        for layer in base_model.layers[:23]:  # Freeze first 23 layers
            layer.trainable = False
            
        self.model = tf.keras.Sequential([
            base_model,
            # NOT NEEDED BECAUSE TENSORFLOW VGG16 INCLUDES ONE EXTRA MAXPOOL2D layers.MaxPooling2D(pool_size=2, strides=2),
            layers.AveragePooling2D(pool_size=(224 // 2 ** 5, 224 // 2 ** 5)),
            layers.Flatten(),
            
            # Output layer for classification
            layers.Dense(num_classes, activation='softmax')
        ])
    
    # The element_spec for the element`s model
    # Tensorflow requires this for typing
    input_spec = (tf.TensorSpec(shape=(None, *image_shape), dtype=tf.float32, name=None), 
                  tf.TensorSpec(shape=(None, num_categories), dtype=tf.float32, name=None))

In [5]:
def model_fn():
    """Turn this model into something tff can run in a federated setting."""
    model = PillModel(image_shape, num_categories)
    
    return tff.learning.models.from_keras_model(
        keras_model = model.model,
        input_spec  = model.input_spec,
        loss        = tf.keras.losses.CategoricalCrossentropy(),
        metrics     = [tf.keras.metrics.CategoricalAccuracy()]  # CategoricalAccuracy needed because of softmax activation
    )

In [6]:
def optimizer_fn():
    return tf.keras.optimizers.Adam(learning_rate=0.003, epsilon=1e-8)

### Steps

In [7]:
def federated_evaluate(test_dataset, training_process, training_state):
    """
    Evaluate a federated model in a testing dataset.
    Since we evaluate with all clients in a single round, this should be equivalent to a centralized evaluation.
    """
    evaluation_process = tff.learning.algorithms.build_fed_eval(model_fn)
    evaluation_state   = evaluation_process.initialize()

    # Copy weights from trained model to new testing model
    model_weights    = training_process.get_model_weights(training_state)
    evaluation_state = evaluation_process.set_model_weights(evaluation_state, model_weights)

    federated_test_data = [test_dataset.create_tf_dataset_for_client(x) for x in test_dataset.client_ids]

    # Evaluating amounts to a single forward step
    evaluation_state, metrics = evaluation_process.next(evaluation_state, federated_test_data)

    print(f'Evaluation Metrics: {metrics}')

    return evaluation_process, metrics

In [8]:
def federated_train(data_folder, num_rounds):
    """Train model using dataset on folder."""
    # Input datasets
    dataset_train = load_data(data_folder + "/Training", num_clients)
    dataset_test  = load_data(data_folder + "/Testing", num_clients)
    
    # Preprocesing: 
    #   Setup epochs and shuffle data
    dataset_train.preprocess(lambda ds: ds.shuffle(shuffle_buffer).repeat(num_epochs))
    
    # Build the federated averaging process
    process = tff.learning.algorithms.build_weighted_fed_avg(
        model_fn = model_fn,
        client_optimizer_fn = optimizer_fn,
        server_optimizer_fn = optimizer_fn,
    )
    # Initialize and run the federated learning process
    state = process.initialize()

    metrics_history = {
        "train": [],
        "test":  [],
    }

    num_clients_per_round = num_clients # Select all clients
    
    for round in range(0, num_rounds):
        # Clients participating on this round
        selected_clients     = np.random.choice(dataset_train.client_ids, size=num_clients_per_round, replace=False)
        federated_train_data = [dataset_train.create_tf_dataset_for_client(x) for x in selected_clients]
        
        state, metrics = process.next(state, federated_train_data)
        
        print(f'Round {round+1:2d}, Metrics: {metrics}')
        metrics_history["train"].append(metrics)

        # Evaluate data after each round
        evaluation_process, evaluation_metrics = federated_evaluate(dataset_test, process, state)
        metrics_history["test"].append(evaluation_metrics)

    # Outputs
    accuracies = [metrics['client_work']['train']['categorical_accuracy'] for metrics in metrics_history["train"]]
    final_accuracy = np.mean(accuracies)
    print(f"Final averaged accuracy over {num_rounds} rounds is: {final_accuracy}")
        
    return process, state, metrics_history

## Train and evaluate

In [9]:
federated_train(data_folder, num_rounds)

Found 349 files belonging to 2 classes.
Found 86 files belonging to 2 classes.
Round  1, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.6733524), ('loss', 0.69248664), ('num_examples', 349), ('num_batches', 11)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Evaluation Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('categorical_accuracy', 0.6744186), ('loss', 0.6912025), ('num_examples', 86), ('num_batches', 3)])), ('total_rounds_metrics', OrderedDict([('categorical_accuracy', 0.6744186), ('loss', 0.6912025), ('num_examples', 86), ('num_batches', 3)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Round  2, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', 

(<tensorflow_federated.python.learning.templates.learning_process.LearningProcess at 0x7fc17c4b97b0>,
 LearningAlgorithmState(global_model_weights=ModelWeights(trainable=[array([[ 0.02062182,  0.10203122],
        [ 0.03286504, -0.0059608 ],
        [-0.04988245, -0.01699325],
        ...,
        [ 0.02982369,  0.01699643],
        [-0.03605681,  0.13545357],
        [ 0.04336181,  0.10659578]], dtype=float32), array([-0.05999757,  0.05999757], dtype=float32)], non_trainable=[array([[[[-0.02113511, -0.05094716,  0.0984482 , ...,  0.08248278,
            0.038821  , -0.09808293],
          [-0.04166424, -0.02198938, -0.01333573, ..., -0.06589898,
            0.02098585,  0.00496716],
          [ 0.04085758,  0.08558862,  0.06042169, ..., -0.02287044,
            0.09403396,  0.00566111]],
 
         [[-0.02407274, -0.07164977, -0.0934803 , ..., -0.0285603 ,
           -0.02830876, -0.09843919],
          [ 0.02528422,  0.02064074,  0.07415628, ...,  0.05273193,
           -0.07521991, 