In [1]:
# conda activate py37
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

from funcx.sdk.client import FuncXClient
from funcx.sdk.executor import FuncXExecutor

def hello_world():
    """Print a fixed greeting to stdout; returns nothing."""
    greeting = "Hello world!"
    print(greeting)

def get_data():
    """Load the MNIST training split and return a small random sample of it.

    Returns:
        A tuple ``(x_train, y_train)`` holding ``num_samples`` images and
        their labels, both indexed by the same randomly drawn positions.
    """
    import numpy as np
    from tensorflow import keras

    num_samples = 10

    # Only the training split is used; the test split is discarded.
    (images, labels), _ = keras.datasets.mnist.load_data()

    # Positions are drawn with replacement, so the same image can be
    # selected more than once.
    candidates = np.arange(len(images))
    chosen = np.random.choice(candidates, num_samples, replace=True)

    return (images[chosen], labels[chosen])

def process_data(x_train, y_train):
    """Prepare raw images and integer labels for training.

    Args:
        x_train: Array of images; cast to float32, divided by 255 and given
            a trailing axis of length 1.
        y_train: Integer class labels; converted to a one-hot matrix with
            ``num_classes`` columns.

    Returns:
        A tuple ``(x_train, y_train)`` with the transformed images and labels.
    """
    import numpy as np
    from tensorflow import keras

    num_classes = 10

    # Rescale pixel values to [0, 1] and append the channel axis in one step.
    images = np.expand_dims(x_train.astype("float32") / 255, -1)
    print("x_train shape:", images.shape)
    print(images.shape[0], "train samples")

    # One-hot encode the class labels.
    labels = keras.utils.to_categorical(y_train, num_classes)

    return (images, labels)

def train_model(global_model_weights,
                x_train,
                y_train,
                batch_size=128,
                epochs=10,
                loss="categorical_crossentropy",
                optimizer="adam", 
                metrics=["accuracy"],
                input_shape=(28, 28, 1),
                num_classes=10):
    """Train a fresh copy of the CNN on local data, starting from the global weights.

    Args:
        global_model_weights: Weights to load into the model before training;
            passed unchanged to ``model.set_weights``.
        x_train: Training images, passed to ``model.fit``.
        y_train: Training targets, passed to ``model.fit``.
        batch_size: Batch size for ``model.fit``.
        epochs: Number of epochs for ``model.fit``.
        loss: Loss passed to ``model.compile``.
        optimizer: Optimizer passed to ``model.compile``.
        metrics: Metrics passed to ``model.compile`` (not modified here).
        input_shape: Shape of one input sample for the ``keras.Input`` layer.
        num_classes: Number of units in the final softmax layer.

    Returns:
        The list of weights returned by ``model.get_weights()`` after training.
    """
    # Import everything the body needs here instead of relying on
    # notebook-level names (`layers`, `input_shape`, `num_classes`), so the
    # function does not depend on the state of the calling namespace.
    from tensorflow import keras
    from tensorflow.keras import layers

    # create the model
    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

    # compile the model and set weights to the global model
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    model.set_weights(global_model_weights)

    # train the model on the local data and extract the weights;
    # 10% of the supplied data is held out for validation
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

    return model.get_weights()


def create_training_function(get_data = get_data, process_data=process_data, train_model=train_model):
    """Bundle the three pipeline steps into one single-argument function.

    Args:
        get_data: Zero-argument callable returning ``(x_train, y_train)``.
        process_data: Callable taking and returning ``(x_train, y_train)``.
        train_model: Callable invoked as
            ``train_model(global_model_weights, x_train, y_train)``.

    Returns:
        ``training_function(global_model_weights)``, which runs the three
        steps in order and returns a dict with the keys ``"model_weights"``
        and ``"samples_count"``.
    """

    def training_function(global_model_weights):
        # import all the dependencies (required for funcX functions)
        import numpy as np

        # fetch the local data, then preprocess it
        raw_images, raw_labels = get_data()
        x_train, y_train = process_data(raw_images, raw_labels)

        # train starting from the supplied global weights
        trained = train_model(global_model_weights, x_train, y_train)

        return {
            "model_weights": np.asarray(trained, dtype=object),
            "samples_count": x_train.shape[0],
        }

    return training_function

def get_edge_weights(sample_counts):
    '''
    Returns weights for each model to find the weighted average.

    Args:
        sample_counts: Per-model sample counts; any sequence or array of
            numbers (a plain list or tuple is accepted as well as an ndarray).

    Returns:
        A float ndarray with each count divided by the total of all counts.

    Raises:
        ValueError: If the counts sum to zero (including empty input), since
            the fractions would be undefined.
    '''
    # Coerce first so that list/tuple input supports elementwise division.
    counts = np.asarray(sample_counts, dtype=float)
    total = counts.sum()
    if total == 0:
        raise ValueError("sample_counts must sum to a non-zero value")
    return counts / total

def federated_average(global_model, endpoint_ids, get_data = get_data, process_data=process_data, train_model=train_model, weighted=False):
    """Run one round of federated averaging across funcX endpoints.

    The current weights of ``global_model`` are sent to every endpoint, each
    endpoint runs the training function, and the returned weights are
    averaged and written back into ``global_model``.

    Args:
        global_model: Model exposing ``get_weights`` / ``set_weights``.
        endpoint_ids: Iterable of funcX endpoint ids to train on.
        get_data: Data-loading step forwarded to ``create_training_function``.
        process_data: Preprocessing step forwarded to ``create_training_function``.
        train_model: Training step forwarded to ``create_training_function``.
        weighted: If True, weight each endpoint's model by its share of the
            total sample count; otherwise take a plain mean.

    Returns:
        ``global_model`` (the same object) with the averaged weights set.

    Raises:
        ValueError: If ``endpoint_ids`` is empty.
    """
    endpoint_ids = list(endpoint_ids)
    if not endpoint_ids:
        # Averaging over zero models is undefined; fail before contacting funcX.
        raise ValueError("endpoint_ids must contain at least one endpoint")

    fx = FuncXExecutor(FuncXClient())

    gm_weights = global_model.get_weights()
    gm_weights_np = np.asarray(gm_weights, dtype=object)

    # compile the training function from the steps supplied by the caller
    # (previously these arguments were accepted but never forwarded)
    training_function = create_training_function(get_data=get_data,
                                                 process_data=process_data,
                                                 train_model=train_model)

    # train the MNIST model on each of the endpoints and return the result, sending the global weights to each edge
    tasks = [
        fx.submit(training_function,
                  global_model_weights=gm_weights_np,
                  endpoint_id=e)
        for e in endpoint_ids
    ]

    # wait for every task once and reuse the returned dicts below
    results = [t.result() for t in tasks]

    # extract weights from each edge model
    model_weights = [r["model_weights"] for r in results]

    if weighted:
        # get the weights
        sample_counts = np.array([r["samples_count"] for r in results])
        edge_weights = get_edge_weights(sample_counts)

        print(f"Model Weights: {edge_weights}")
        # find weighted average
        average_weights = np.average(model_weights, weights=edge_weights, axis=0)

    else:
        # simple average of the weights
        average_weights = np.mean(model_weights, axis=0)

    # assign the weights to the global_model
    global_model.set_weights(average_weights)

    print('Trained Federated Model')

    return global_model

In [3]:
# Draw a local sample and preprocess it in one step.
x_train, y_train = process_data(*get_data())

x_train shape: (10, 28, 28, 1)
10 train samples


In [4]:
# funcX endpoints that take part in training
endpoint_ids = [
    '00929e1a-ccc5-40be-8b04-c171f132f7b2',
    '11983ca1-2d45-40d1-b5a2-8736b3544dea',
]

# training and model hyper-parameters
batch_size, epochs = 128, 5
input_shape, num_classes = (28, 28, 1), 10

# Two conv/pool stages followed by dropout and a softmax classifier.
global_model = keras.Sequential([
    keras.Input(shape=input_shape),
    layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation="softmax"),
])

# Fit the global model on the local sample; 10% is held out for validation.
global_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
global_model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=0.1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1e5545a4308>

In [None]:
# One round of unweighted federated averaging over all endpoints.
federated_average(global_model, endpoint_ids, weighted=False)

Caught unexpected while setting results
Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\sdk\asynchronous\ws_polling_task.py", line 125, in handle_incoming
    data["exception"]
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\serialize\facade.py", line 156, in deserialize
    result = self.methods_for_data[header].deserialize(payload)
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\serialize\concretes.py", line 27, in deserialize
    data = pickle.loads(codecs.decode(chomped.encode(), "base64"))
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\__init__.py", line 22, in <module>
    from parsl.app.app import bash_app, join_app, python_app
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\app\app.py", line 12, in <module>
    from parsl.dataflow.dflow import DataFlowKernel
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\dataflow\dflow.py", line 23,

In [6]:
# Snapshot the global model: its JSON architecture and its current weights.
json_config = global_model.to_json()
gm_weights = global_model.get_weights()
gm_weights_np = np.asarray(gm_weights, dtype=object)

# train_model's signature is (global_model_weights, x_train, y_train, ...);
# the JSON config is not one of its parameters, so it must not be passed.
new_model_weights = train_model(gm_weights_np, x_train, y_train)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [12]:
# Number of weight arrays returned by the local training run.
len(new_model_weights)

6

In [9]:
# Create one client and hand it to the executor instead of constructing a
# second, separate FuncXClient.
fxc = FuncXClient()
fx = FuncXExecutor(fxc)
training_function = create_training_function()

In [None]:

# train the MNIST model on each of the endpoints and return the result, sending the global weights to each edge
# (training_function only accepts `global_model_weights`, so no model config is passed)
tasks = []
for e in endpoint_ids:
    tasks.append(fx.submit(training_function, 
                            global_model_weights=gm_weights_np, 
                            endpoint_id=e))
    
# extract weights from each edge model
model_weights = [t.result()["model_weights"] for t in tasks]

Caught unexpected while setting results
Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\sdk\asynchronous\ws_polling_task.py", line 125, in handle_incoming
    data["exception"]
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\serialize\facade.py", line 156, in deserialize
    result = self.methods_for_data[header].deserialize(payload)
  File "C:\ProgramData\Anaconda3\lib\site-packages\funcx\serialize\concretes.py", line 27, in deserialize
    data = pickle.loads(codecs.decode(chomped.encode(), "base64"))
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\__init__.py", line 22, in <module>
    from parsl.app.app import bash_app, join_app, python_app
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\app\app.py", line 12, in <module>
    from parsl.dataflow.dflow import DataFlowKernel
  File "C:\Users\Nikita\AppData\Roaming\Python\Python37\site-packages\parsl\dataflow\dflow.py", line 27,