In [1]:
#Load MNIST dataset
import tensorflow as tf

mnist = tf.keras.datasets.mnist

#Load the train/test splits and divide by 255 to rescale the pixel values
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255, x_test/255

#Split dataset into shards
#The split points are derived from each array's own length. The test set is
#smaller than the training set (10000 vs 60000 samples), so reusing the
#training indices 30000/60000 for it would put every test sample in shard 1
#and leave shard 2 empty.
train_split = len(x_train)//2
test_split = len(x_test)//2

x_train1, x_train2 = x_train[:train_split], x_train[train_split:]
y_train1, y_train2 = y_train[:train_split], y_train[train_split:]
x_test1, x_test2 = x_test[:test_split], x_test[test_split:]
y_test1, y_test2 = y_test[:test_split], y_test[test_split:]

#Sanity checks: shard sizes and one sample label
print("Size of x_train1 = " + str(len(x_train1)))
print("Size of x_train2 = " + str(len(x_train2)))
print("Should be a label 0-9: " + str(y_train1[5]))

Size of x_train1 = 30000
Size of x_train2 = 30000
Should be a label 0-9: 2


In [2]:
#Set up DNN models
#Loss shared by both local models; from_logits=True because the final Dense
#layer has no activation
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

#Network: flatten 28x28 input -> Dense(128, relu) -> Dropout(0.2) -> Dense(10)
model1 = tf.keras.models.Sequential()
model1.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model1.add(tf.keras.layers.Dense(128, activation='relu'))
model1.add(tf.keras.layers.Dropout(0.2))
model1.add(tf.keras.layers.Dense(10))

#Second local model and the global model reuse model1's architecture
model2 = tf.keras.models.clone_model(model1)
global_model = tf.keras.models.clone_model(model1)

#Only the local models are compiled here; global_model is used as a weight source
for local_model in (model1, model2):
    local_model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

#Distribute global model to all devices so both start from the same weights
global_weights = global_model.get_weights()
for local_model in (model1, model2):
    local_model.set_weights(global_weights)

In [3]:
#Train DNN
#Each communication round: both local models train for one epoch on their own
#shard, their weight changes relative to the current global weights are
#averaged, and the averaged change is applied to the global weights, which are
#then sent back to both local models.
num_rounds = 5 # Number of communication rounds
for i in range(num_rounds):
    print("Communication round " + str(i) + ":")
    #Train using local dataset
    model1.fit(x_train1, y_train1, epochs=1)
    model2.fit(x_train2, y_train2, epochs=1)
    #Calculate weight update (one array per layer variable)
    #New lists are built on purpose: assigning e.g. weights1_update = weights1_new
    #only creates an alias, so writing into one silently overwrites the other.
    weights1_new = model1.get_weights()
    weights2_new = model2.get_weights()
    weights1_update = [new - old for new, old in zip(weights1_new, global_weights)]
    weights2_update = [new - old for new, old in zip(weights2_new, global_weights)]
    average_update = [(u1 + u2)/2 for u1, u2 in zip(weights1_update, weights2_update)]
    #Update global model
    new_global = [old + avg for old, avg in zip(global_weights, average_update)]
    #Keep global_model itself in sync with the aggregated weights
    global_model.set_weights(new_global)
    model1.set_weights(new_global)
    model2.set_weights(new_global)
    global_weights = new_global

Communication round 0:
Train on 30000 samples
Train on 30000 samples
Communication round 1:
Train on 30000 samples
Train on 30000 samples
Communication round 2:
Train on 30000 samples
Train on 30000 samples
Communication round 3:
Train on 30000 samples
Train on 30000 samples
Communication round 4:
Train on 30000 samples
Train on 30000 samples


In [4]:
#Evaluate performance
#Both local models were given the same global weights at the end of the last
#round, so they are evaluated on the same full test set for comparison.
for synced_model in (model1, model2):
    synced_model.evaluate(x_test, y_test, verbose=2)
print("Performance should be identical")

10000/10000 - 0s - loss: 0.0871 - accuracy: 0.9742
10000/10000 - 0s - loss: 0.0871 - accuracy: 0.9742
Performance should be identical
