In [1]:
#Simulation settings
num_rounds = 50 # Number of communication rounds
num_devices = 2 # Number of devices
bs = 64 # Batch size for local training at devices
ep = 2 # Number of local epochs before communication round

In [2]:
#Load MNIST dataset
import tensorflow as tf
import time

mnist = tf.keras.datasets.mnist

#Load dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32")/255, y_train)
)
test_dataset = tf.data.Dataset.from_tensor_slices(
    (x_test.reshape(10000, 784).astype("float32")/255, y_test)
)
#Create batches
dataset = dataset.shuffle(buffer_size=1024).batch(bs, drop_remainder=True)
test_dataset = test_dataset.shuffle(buffer_size=1024).batch(bs, drop_remainder=True)

half_dataset1 = dataset.shard(2, 0)
half_dataset2 = dataset.shard(2, 1)

#Split dataset into shards
train_shards = []
for i in range(num_devices):
    train_shards.append(dataset.shard(num_devices, i))

print("Dataset loaded.")

Dataset loaded.


In [3]:
#Set up DNN models
model_template = tf.keras.models.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

model_list = []

for i in range(num_devices):
    model_list.append(tf.keras.models.clone_model(model_template))

global_model = tf.keras.models.clone_model(model_template)
    
#Define loss function and optimizer, required for "train_on_batch" f
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

#Define training function (x = batch of 'bs' training samples, y = corresponding labels, m = model)
#There is a wrapper here because of the following error: 
#'tf.function-decorated function tried to create variables on non-first call.'
#This error exists because I'm passing a model to the @tf.function annotated train_on_batch function.
#When this function runs for the first time, it will take the model passed to it and use that to 
#instantiate a computational graph. However, when new models are passed to the same function, it 
#crashes the program instead. By using this wrapper, it forces the tensorflow backend to create a new 
#computational graph for the new model.

def train_on_batch(x, y, m):
    with tf.GradientTape() as tape:
        #Forward pass
        logits = m(x)
        #Loss value for this batch
        loss_val = loss_fn(y, logits)
        gradients = tape.gradient(loss_val, m.trainable_weights)
    opt.apply_gradients(zip(gradients, m.trainable_weights))
    return loss_val, gradients

for model in model_list:
    model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
    
global_model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

In [4]:
#Train DNN
acc_history = []
for r in range(num_rounds):
    print("Communication round " + str(r+1) + "/" + str(num_rounds))
    start = time.time()
    #Local training
    gradients = []
    for d in range(num_devices): #TODO: Parallelize
        acc_list = []
        loss_val = 0
        cum_gradient = 0
        for epoch in range(ep):
            for step, (x, y) in enumerate(train_shards[d]):
                loss_val, gradient = train_on_batch(x, y, model_list[d])
                #Calculate cumulative gradient
                for l in range(len(gradient)):
                    if cum_gradient == 0:
                        cum_gradient = gradient
                        continue
                    cum_gradient[l] = cum_gradient[l]+gradient[l]
            #Divide cumulative gradient by number of batches
            for l in range(len(gradient)):
                cum_gradient[l] = cum_gradient[l]/len(list(train_shards[d]))
        for l in range(len(gradient)):
            cum_gradient[l] = cum_gradient[l]/ep
        print("Device: ", d, "Loss: ", float(loss_val))
            
    #Update global model
    global_weights = global_model.get_weights()
    for l in range(len(global_weights)):
        global_weights[l] = global_weights[l] - cum_gradient[l]
    
    #Set model of all devices to the global_model
    for model in model_list:
        model.set_weights(global_model.get_weights())
    acc_history.append(model_list[0].evaluate(test_dataset, verbose=0)[1])
    print("Accuracy on test dataset: ", acc_history[-1])
    print(str(int(time.time()-start)) + " seconds elapsed\n")

Communication round 1/50
Device:  0 Loss:  0.05585908889770508
Device:  1 Loss:  0.18742266297340393


ValueError: Weights for model sequential have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

In [None]:
#Plot accuracy
import matplotlib.pyplot as plt
plt.plot(range(len(acc_history[1:])), acc_history[1:])