# Simple Federated Model

In [1]:
import numpy as np
import tensorflow      as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense

import fledge as fx
from fledge.federated import FederatedModel,FederatedDataSet
from fledge.native.native import run_experiment

import warnings
warnings.filterwarnings('ignore')

Instructions for updating:
non-resource variables are not supported in the long term


We start with a simple fully connected model that is trained to approximate the XOR (parity) function of a 7-bit input.

In [2]:
#XOR (parity) function of length 7
N_BITS = 7  # number of input bits; the dataset holds all 2**N_BITS patterns
SEED = 42   # fixed seed so the shuffle, and therefore the split, is reproducible

# Enumerate every bit pattern in lexicographic order (first column varies
# slowest), which is the same order nested loops over each bit would give.
X = np.array(list(np.ndindex(*([2] * N_BITS))))
# The label is the parity of each row, stored as a column vector (len(X), 1).
y = (X.sum(axis=1) % 2).reshape((len(X), 1))

#Shuffle training and validation data
# A local, seeded generator keeps the permutation identical on every
# Restart & Run All and leaves NumPy's global random state untouched.
rng = np.random.RandomState(SEED)
data_shuffle = rng.permutation(len(X))
X = X[data_shuffle]
y = y[data_shuffle]

#Split the data into training and validation sets
n_train = (3 * len(X)) // 4  # 96 of the 128 patterns when N_BITS == 7
X_train = X[:n_train]
y_train = y[:n_train]
X_val = X[n_train:]
y_val = y[n_train:]

#Wrap the four arrays in a single federated dataset object. It will be split later
fl_data = FederatedDataSet(X_train, y_train, X_val, y_val)

# Second-axis lengths of X and y. Both are plain ints, not shape tuples.
# NOTE(review): build_model's parameters below reuse these two names and
# shadow them inside the function.
feature_shape, classes = X.shape[1], y.shape[1]

def build_model(feature_shape,classes):
    """Build and compile the fully connected XOR classifier.

    Args:
        feature_shape: Shape of a single input sample. May be a sequence
            such as ``(7,)`` or a bare integer feature count such as ``7``.
        classes: Not used by this model; the output layer is always a
            single sigmoid unit.

    Returns:
        The compiled ``Sequential`` model (binary cross-entropy loss,
        SGD optimizer, accuracy metric).
    """
    # Keras expects input_shape to be a sequence. Wrap a bare feature count
    # (e.g. X.shape[1]) in a 1-tuple instead of letting the layer raise.
    try:
        input_shape = tuple(feature_shape)
    except TypeError:
        input_shape = (feature_shape,)

    #Define the XOR model
    model = Sequential()
    model.add(Dense(128, input_shape=input_shape, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model

Inferred 2 classes from the provided labels...


In [3]:

# Hand the model-building function (not a built model) and the dataset
# to FederatedModel to create the federated model.
fl_model = FederatedModel(
    build_model,
    data_loader=fl_data,
)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


The `FederatedModel` object is a wrapper around your Keras, TensorFlow or PyTorch model that makes it compatible with fledge. It provides built-in federated training and validation functions that we will see used below. Using its `setup` function, the model and its dataset are split into one copy per collaborator (two collaborators in the cell below).

In [4]:
# Request two collaborator models, then give each one a name for the experiment.
collaborator_models = fl_model.setup(2)
collaborators = {
    'one': collaborator_models[0],
    'two': collaborator_models[1],
}

In [5]:
# Print the training and validation sample counts for collaborator one,
# then for collaborator two.
for collaborator_index in (0, 1):
    loader = collaborator_models[collaborator_index].data_loader
    print(len(loader.X_train))
    print(len(loader.X_valid))

48
16
48
16


In [6]:
#Setup logging
from logging import basicConfig
from rich.console import Console
from rich.logging import RichHandler

# Send INFO-level log records through a rich handler on a 160-column console.
console = Console(width=160)
rich_handler = RichHandler(console=console)
basicConfig(
    level='INFO',
    format='%(message)s',
    datefmt='[%X]',
    handlers=[rich_handler],
)


#Run experiment, return trained FederatedModel
experiment_overrides = {'rounds_to_train': 10}
final_fl_model = run_experiment(
    collaborators,
    experiment_overrides,
    export_for_production=True,
)

In [8]:
# Save the model wrapped by the trained FederatedModel under 'final_model'.
trained_model = final_fl_model.model
trained_model.save('final_model')