In [5]:
#Code taken and modified from https://keras.io/examples/vision/mnist_convnet/ from user https://twitter.com/fchollet creator of Keras

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from time import perf_counter 

In [6]:
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Local data split (indices into the MNIST training set):
#   [0, TRAIN_SAMPLES)  -> used for local training
#   [TUNE_START, end)   -> held out; its loss is reported to the server
# Images in [TRAIN_SAMPLES, TUNE_START) are not used by this notebook.
TRAIN_SAMPLES = 25000
TUNE_START = 50000
assert TRAIN_SAMPLES <= TUNE_START, "train and tune slices would overlap"

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Slice images and labels with the same bounds so they cannot drift apart.
# The tune slice is taken first because x_train/y_train are then overwritten.
x_tune, y_tune = x_train[TUNE_START:], y_train[TUNE_START:]
x_train, y_train = x_train[:TRAIN_SAMPLES], y_train[:TRAIN_SAMPLES]
assert len(x_train) == len(y_train) and len(x_tune) == len(y_tune)

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
y_tune = keras.utils.to_categorical(y_tune, num_classes)

x_train shape: (25000, 28, 28, 1)
25000 train samples
10000 test samples


In [7]:
# Small CNN classifier: two conv + max-pool stages, then dropout and a
# softmax head over `num_classes` outputs.
model = keras.Sequential()
model.add(keras.Input(shape=input_shape))

# Feature extractor: (conv -> pool) twice.
model.add(layers.Conv2D(32, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Classifier head.
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(num_classes, activation="softmax"))

model.summary()

model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["categorical_accuracy"],
)


Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1600)              0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 1600)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)               

In [None]:
import socket
import pickle

# Wire framing: each message is a pickled payload preceded by its byte length,
# written as ASCII decimal and left-justified in a header of this many bytes.
HEADERSIZE = 10

# Number of train-locally / receive-aggregated-weights rounds in the loop below.
TOTAL_EPOCHs = 15

# Aggregation server address. The host defaults to this machine's own
# hostname, so the server is expected to be running on the same machine.
SERVER_HOST = socket.gethostname()
SERVER_PORT = 4321

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s.connect((SERVER_HOST, SERVER_PORT))
except OSError:
    # Don't leak the socket's file descriptor if the server is unreachable.
    s.close()
    raise

In [8]:
def _recv_exactly(connection, num_bytes, bufsize=1024):
    """Read and return exactly `num_bytes` bytes from `connection`.

    Raises:
        ConnectionError: if the peer closes the connection (``recv`` returns
            ``b''``) before `num_bytes` bytes have arrived.
    """
    chunks = []
    remaining = num_bytes
    while remaining > 0:
        # Never request more than this message still needs, so bytes that
        # belong to the next message stay in the socket buffer.
        chunk = connection.recv(min(bufsize, remaining))
        if not chunk:
            raise ConnectionError(
                f"connection closed with {remaining} of {num_bytes} bytes outstanding"
            )
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def get_msg(connection, header_size=None):
    """Receive one length-prefixed pickled message and return the unpickled object.

    Wire format: a `header_size`-byte ASCII decimal length (left-justified,
    space-padded), followed by exactly that many bytes of pickle data.

    Args:
        connection: connected socket-like object exposing ``recv(bufsize)``.
        header_size: width of the length header in bytes; defaults to the
            notebook-level ``HEADERSIZE``.

    Returns:
        The object produced by ``pickle.loads`` on the payload.

    Raises:
        ConnectionError: if the connection closes before the full message arrives.
        ValueError: if the header is not a decimal integer.
    """
    if header_size is None:
        header_size = HEADERSIZE
    # Read the whole header first: a single recv() may return fewer bytes.
    msglen = int(_recv_exactly(connection, header_size))
    payload = _recv_exactly(connection, msglen)
    print("full msg recvd")
    # SECURITY: pickle.loads can execute arbitrary code embedded in the payload;
    # only use this with a trusted server.
    return pickle.loads(payload)

In [9]:
def run_epoch_get_wnl(batch_size=128, epochs=1, net=None, train_data=None, tune_data=None):
    """Train locally, then return the model weights and the held-out loss.

    Args:
        batch_size: mini-batch size passed to ``fit`` (default 128).
        epochs: number of local epochs per call (default 1).
        net: Keras model to train; defaults to the notebook-level ``model``.
        train_data: ``(x, y)`` training arrays; defaults to ``(x_train, y_train)``.
        tune_data: ``(x, y)`` held-out arrays; defaults to ``(x_tune, y_tune)``.

    Returns:
        ``[weights, loss]``. ``weights`` is a 1-D object ndarray holding one
        array per entry of ``net.get_weights()``. ``loss`` is the first value
        returned by ``evaluate``, or the value itself if it returned a scalar.
    """
    if net is None:
        net = model
    if train_data is None:
        train_data = (x_train, y_train)
    if tune_data is None:
        tune_data = (x_tune, y_tune)

    net.fit(train_data[0], train_data[1], batch_size=batch_size, epochs=epochs)
    score = net.evaluate(tune_data[0], tune_data[1])
    # evaluate() returns [loss, *metrics] when metrics are compiled, else a scalar.
    loss = score[0] if isinstance(score, (list, tuple)) else score

    # Weight tensors have different shapes. np.array(list_of_them) raises on
    # NumPy >= 1.24, and can mis-broadcast when leading dims coincide, so fill an
    # object array element by element instead.
    weights = net.get_weights()
    weights_arr = np.empty(len(weights), dtype=object)
    for i, w in enumerate(weights):
        weights_arr[i] = w
    return [weights_arr, loss]


In [None]:
def send_msg(connection, obj, header_size=None):
    """Pickle `obj` and send it with a fixed-width ASCII length header.

    Uses the same framing that ``get_msg`` reads: ``len(payload)`` is
    left-justified in a `header_size`-byte header, followed by the payload.

    Args:
        connection: connected socket exposing ``sendall``.
        obj: any picklable object.
        header_size: header width in bytes; defaults to the notebook-level
            ``HEADERSIZE``.
    """
    if header_size is None:
        header_size = HEADERSIZE
    payload = pickle.dumps(obj)
    # The header must carry the length of *this* payload.
    header = bytes(f"{len(payload):<{header_size}}", "utf-8")
    # sendall() loops until every byte is written; send() may write only part.
    connection.sendall(header + payload)


try:
    for epoch in range(TOTAL_EPOCHs):
        # Send [weights, tuning loss] after one round of local training ...
        send_msg(s, run_epoch_get_wnl())
        # ... and load the weights the server sends back. set_weights accepts
        # a list, which avoids building a ragged ndarray.
        model.set_weights(list(get_msg(s)))
    send_msg(s, "DONE")
finally:
    # Release the connection even if training or networking fails mid-loop.
    s.close()