In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D
from tensorflow.keras.utils import to_categorical


import openfl.native as fx
from openfl.federated import FederatedModel,FederatedDataSet
# Run tf.function-decorated code eagerly rather than as traced graphs.
tf.config.run_functions_eagerly(True)

# Seed TensorFlow's and NumPy's global random generators with one fixed value.
SEED = 0
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [None]:
def test_intel_tensorflow():
    """
    Print the TensorFlow version and whether Intel optimizations are enabled.

    Two lines are written to stdout: the installed TensorFlow version, and the
    value reported by TensorFlow's ``IsMklEnabled()`` hook. The hook lives in a
    different place depending on the TensorFlow release:

    * TensorFlow 2.x: ``tensorflow.python.util._pywrap_util_port``, falling
      back to ``tensorflow.python._pywrap_util_port`` when the first location
      cannot be imported.
    * TensorFlow 1.x: ``tf.pywrap_tensorflow``.

    Returns:
        None
    """
    import tensorflow as tf

    print("We are using Tensorflow version {}".format(tf.__version__))

    major_version = int(tf.__version__.split(".")[0])
    if major_version >= 2:
        try:
            from tensorflow.python.util import _pywrap_util_port
        except ImportError:
            # Older 2.x releases expose the module one package level higher.
            from tensorflow.python import _pywrap_util_port
        mkl_enabled = _pywrap_util_port.IsMklEnabled()
    else:
        # The 1.x branch previously printed the label without any value.
        mkl_enabled = tf.pywrap_tensorflow.IsMklEnabled()

    print("Intel-optimizations (DNNL) enabled:", mkl_enabled)

# Run the check once so the TensorFlow version and optimization flag are printed here.
test_intel_tensorflow()

In [None]:
from openfl.utilities.optimizers.keras import FedProxOptimizer

In [None]:
def build_model(input_shape,
                num_classes,
                **kwargs):
    """
    Build and compile a single-layer softmax classifier.

    The network is an input layer followed by one dense softmax layer with
    ``num_classes`` outputs. It is compiled with categorical cross-entropy
    loss, a ``FedProxOptimizer`` constructed with ``mu=1``, and the
    ``accuracy`` metric.

    Args:
        input_shape: Shape of one input sample, forwarded to ``tf.keras.Input``
        num_classes (int): Number of output classes (units of the dense layer)
        **kwargs: Accepted for interface compatibility; not used here

    Returns:
        tensorflow.python.keras.engine.sequential.Sequential: The compiled Keras model

    """
    classifier = Sequential()

    # Layers are added in order: input specification first, then the output layer.
    classifier.add(tf.keras.Input(shape=input_shape))
    classifier.add(Dense(num_classes, activation='softmax'))

    classifier.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer=FedProxOptimizer(mu=1),
        metrics=['accuracy'],
    )

    return classifier

In [None]:
#Create a federated model using the build model function and dataset
# NOTE(review): `fl_data` is not defined in any cell visible here — confirm that
# a data-loading cell creating it runs before this one.
fl_model = FederatedModel(build_model, data_loader=fl_data)

In [None]:
# Request 1000 collaborator models from the federated model.
collaborator_models = fl_model.setup(num_collaborators=1000)

# Key each collaborator model by a generated name: 'col0', 'col1', ...
collaborators = {
    f'col{index}': collab_model
    for index, collab_model in enumerate(collaborator_models)
}

In [None]:
# Sizes of the original training / validation arrays
# NOTE(review): `train_images` / `valid_images` are not defined in any cell
# visible here — confirm they are created earlier in the notebook.
print(f"Original training data size: {len(train_images)}")
print(f"Original validation data size: {len(valid_images)}\n")

# Sizes of the data held by the first two collaborators; add more ordinals to
# the tuple to report on further collaborators.
for position, ordinal in enumerate(('one', 'two')):
    shard = collaborator_models[position].data_loader
    print(f"Collaborator {ordinal}'s training data size: {len(shard.X_train)}")
    print(f"Collaborator {ordinal}'s validation data size: {len(shard.X_valid)}\n")

We can inspect the current plan values by calling the `fx.get_plan()` function.

In [None]:
# Dump the current plan as sorted, indented JSON. Each value can be overridden.
import json

current_plan = fx.get_plan()
print(json.dumps(current_plan, sort_keys=True, indent=4))

Now we are ready to run our experiment. If we want to pass in custom plan settings, we can easily do that with the `override_config` parameter.

In [None]:
# Plan settings overridden for this run, keyed by their dotted plan path.
plan_overrides = {
    'aggregator.settings.rounds_to_train': 5,
    'collaborator.settings.opt_treatment': 'CONTINUE_GLOBAL',
}

# Run the experiment; the returned object is the trained FederatedModel.
final_fl_model = fx.run_experiment(collaborators, override_config=plan_overrides)

In [None]:
#Save final model and load into keras
# The model is saved under the relative name 'final_model' and then read back
# from './final_model', so both calls depend on the current working directory.
final_fl_model.save_native('final_model')
model = tf.keras.models.load_model('./final_model')

In [None]:
#Test the final model on our test set
# NOTE(review): `test_images` / `test_labels` are not defined in any cell
# visible here — confirm they are created earlier in the notebook.
model.evaluate(test_images,test_labels)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Hard-coded result series, one per configuration: (legend label, values, extra
# line-style keywords). Presumably one value per training round — TODO confirm
# which metric these numbers represent.
result_curves = [
    ('FedAvg',
     [0.07627802075538784, 0.07518334008473902, 0.09541350667830556, 0.13141966053564103, 0.15887578643299638],
     {}),
    ('FedProx (mu=1e-2)',
     [0.07627802075538784, 0.07518334008473902, 0.09541350667830556, 0.1314459763141349, 0.15887578643299638],
     {'linestyle': '--'}),
    ('FedProx (mu=1e-1)',
     [0.07627802075538784, 0.0751056043850258, 0.09555227747093886, 0.131649036151357, 0.15966261748969554],
     {'linestyle': '--'}),
    ('FedProx (mu=1e1)',
     [0.07627802075538784, 0.07517912408802659, 0.09641592293512076, 0.13676991989742965, 0.1684917744528502],
     {'linestyle': '--'}),
]

# Draw every series on the same axes, in the order listed above.
for curve_label, curve_values, curve_style in result_curves:
    plt.plot(curve_values, label=curve_label, **curve_style)

plt.legend()
plt.xticks(range(5))
plt.show()