In [3]:
# Code adapted from the Keras example https://keras.io/examples/vision/mnist_convnet/ by https://twitter.com/fchollet, the creator of Keras.

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from time import perf_counter 

In [8]:
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
# Index into the training set where the held-out "tune" split begins;
# everything from this index to the end is used for tuning on the server.
TUNE_START = 50000

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_tune = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1).
# The channel axis is added to the already-scaled x_tune: expanding x_train here
# would silently throw away the scaling done just above.
x_tune = np.expand_dims(x_tune, -1)
x_test = np.expand_dims(x_test, -1)

x_tune = x_tune[TUNE_START:]

print(x_test.shape[0], "test samples")
print(x_tune.shape[0], "tune samples")

y_tune = y_train[TUNE_START:]

# convert class vectors to binary class matrices
# (the model is compiled with categorical_crossentropy, so every label array
# passed to model.evaluate -- including y_test -- must be one-hot encoded)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_tune = keras.utils.to_categorical(y_tune, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

10000 test samples
6000 tune samples


In [5]:
# Small MNIST convnet: two conv/max-pool stages, then dropout and a softmax
# classifier over num_classes outputs. Layers are stacked one at a time.
model = keras.Sequential()
model.add(keras.Input(shape=input_shape))
# Feature extractor
model.add(layers.Conv2D(32, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
# Classifier head
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(num_classes, activation="softmax"))

model.summary()

# One-hot labels are expected because of the categorical loss/metric.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["categorical_accuracy"],
)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1600)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1600)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                1

In [None]:
# Server setup: networking / serialization imports and the wire-format constant.
import socket
import time
import pickle

# Width, in bytes, of the length header that prefixes every pickled message:
# get_msg parses the first HEADERSIZE bytes of a message as the payload length.
HEADERSIZE = 10

# Networking code modified from https://pythonprogramming.net/pickle-objects-sockets-tutorial-python-3/

In [None]:
def get_msg(connection, addrs, header_size=None):
    """Receive exactly one length-prefixed, pickled message from a socket.

    Wire format: the first ``header_size`` bytes hold the payload length as
    ASCII decimal digits (surrounding whitespace is ignored), followed by
    exactly that many bytes of pickle data.

    Args:
        connection: Connected socket-like object exposing ``recv(n)``.
        addrs: Peer address; only used in the log line printed on success.
        header_size: Width of the length header in bytes. Defaults to the
            module-level ``HEADERSIZE`` when omitted.

    Returns:
        The object obtained by unpickling the payload.

    Raises:
        ConnectionError: If the peer closes the connection before the whole
            message has arrived.
        ValueError: If the header does not contain a valid integer.
    """
    if header_size is None:
        header_size = HEADERSIZE

    def _recv_exact(num_bytes):
        # recv() may return fewer bytes than requested, so keep reading until
        # exactly num_bytes have arrived. Never asking for more than what is
        # still missing guarantees we do not consume bytes of the next message.
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = connection.recv(min(remaining, 1024))
            if not chunk:
                # An empty read means the peer closed the socket; without this
                # check the loop would spin forever on a dropped client.
                raise ConnectionError(
                    f"connection to {addrs} closed with {remaining} bytes still expected"
                )
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    msglen = int(_recv_exact(header_size))
    payload = _recv_exact(msglen)
    print("full msg recvd from ", addrs)
    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload.
    # Only use this with trusted peers on a trusted network.
    return pickle.loads(payload)


In [None]:
def calculate_weighted_avg(weights_m, loss_m):
    """Return the element-wise mean of the given weight sets.

    Despite the name, every entry of ``weights_m`` currently contributes
    equally: ``loss_m`` is accepted for interface compatibility but is not
    used in the computation.

    Args:
        weights_m: Non-empty sequence of array-likes that can be added
            together element-wise (one entry per participating model).
        loss_m: Sequence of losses, one per entry of ``weights_m``; unused.

    Returns:
        The sum of all entries of ``weights_m`` divided by their count.

    Raises:
        ValueError: If ``weights_m`` is empty.
    """
    if len(weights_m) == 0:
        raise ValueError("weights_m must contain at least one set of weights")

    # Seed the accumulator from the first entry rather than a hard-coded index,
    # so the function works for any number of participants (not just three).
    weighted_total = np.zeros_like(weights_m[0])

    for weight in weights_m:
        weighted_total = np.add(weighted_total, weight)

    return weighted_total / len(weights_m)


In [None]:
# Federated-averaging server loop: accept two clients, then repeatedly receive
# each client's (weights, loss), average them with the server model's weights,
# and send the averaged weights back until both clients report "DONE".
HEADERSIZE = 10   # width in bytes of the ASCII length prefix on every message
NUM_CLIENTS = 2
DONE = "DONE"     # sentinel a client sends in place of weights when finished


def _as_object_array(layer_weights):
    """Pack per-layer weight arrays into a 1-D object array, one slot per layer."""
    # Layers have different shapes, so np.array(...) on the raw list would be a
    # ragged conversion; filling an object array slot by slot is unambiguous.
    packed = np.empty(len(layer_weights), dtype=object)
    for index, layer in enumerate(layer_weights):
        packed[index] = layer
    return packed


def _is_done(payload):
    """Return True when a client's payload is the DONE sentinel, not weights."""
    if isinstance(payload, np.ndarray) and payload.ndim == 0:
        payload = payload.item()  # tolerate the sentinel wrapped in a 0-d array
    return isinstance(payload, str) and payload == DONE


s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind((socket.gethostname(), 4321))
s.listen(NUM_CLIENTS)

connections = []
addresses = []
conn = 0

try:
    while conn < NUM_CLIENTS:
        c, addr = s.accept()     # Establish connection with client.
        connections.append(c)
        addresses.append(addr)
        conn = conn + 1

    stop_msg = [False, False]

    # Slots 0 and 1 hold the clients' values; slot 2 holds the server model's.
    weights = [None, None, _as_object_array(model.get_weights())]
    losses = [0.0, 0.0, 0.0]

    sync_num = 0

    # Start the stopwatch / counter
    t1_start = perf_counter()

    while not stop_msg[0] and not stop_msg[1]:
        response = get_msg(connections[0], addresses[0])
        weights[0] = response[0]
        losses[0] = response[1]
        response = get_msg(connections[1], addresses[1])
        weights[1] = response[0]
        losses[1] = response[1]

        # Check for the sentinel on the raw payloads, before any array conversion.
        finished = [_is_done(weights[0]), _is_done(weights[1])]

        if finished[0] and finished[1]:
            stop_msg[0] = True
            stop_msg[1] = True
        elif finished[0] or finished[1]:
            print("mismatch ending terminating loop early")
            break
        else:
            weights[0] = _as_object_array(weights[0])
            weights[1] = _as_object_array(weights[1])

            # Score the server model (before averaging) on the tune split.
            score = model.evaluate(x_tune, y_tune)
            losses[2] = score[0]
            weights[2] = _as_object_array(model.get_weights())
            new_weights = calculate_weighted_avg(weights, losses)
            model.set_weights(list(new_weights))
            print("Tune loss:", score[0])
            print("Tune accuracy:", score[1])

            # Header carries the pickled payload's length, left-justified and
            # space-padded to HEADERSIZE characters.
            p_w = pickle.dumps(_as_object_array(model.get_weights()))
            msg = bytes(f"{len(p_w):<{HEADERSIZE}}", 'utf-8') + p_w
            # sendall keeps writing until the whole message is out; plain send
            # may transmit only part of a large weight payload.
            connections[0].sendall(msg)
            connections[1].sendall(msg)
            print("Sync_Iteration:", sync_num)
            sync_num += 1

    # Stop the stopwatch / counter
    t1_stop = perf_counter()
finally:
    # Release every socket even if a client drops or an error is raised.
    for client in connections:
        client.close()
    s.close()

print("Elapsed time during the whole program in seconds:",
                                        t1_stop-t1_start)



In [None]:
# Evaluate the final model on the held-out test set and time just that call.
# Start the stopwatch / counter
t2_start = perf_counter()

score = model.evaluate(x_test, y_test, verbose=0)
# Stop the stopwatch / counter
t2_stop = perf_counter()

# The timer brackets only model.evaluate, so report it as evaluation time
# rather than whole-program time.
print("Elapsed time during test evaluation in seconds:",
                                        t2_stop-t2_start)

# score[0] is the loss, score[1] the metric configured at compile time.
print("Test loss:", score[0])
print("Test accuracy:", score[1])