In [1]:
import pandas as pd
import time
import socket
import threading
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from tensorflow import keras as K
from keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import mean_squared_error

In [2]:
# Shared state for the federated-averaging server.
#
# No socket is created here: the "Start the server" cell builds, binds and
# closes its own listening socket on ('localhost', 12345). Binding another one
# in this cell only left an unused listening socket holding the port until the
# name `server` was rebound.

# Number of client connections accepted, and number of client models that make
# up one averaging round.
num_of_clients = 3

# Client models received over the network and waiting to be averaged; appended
# to by the per-client handler threads and drained by the averaging thread.
models = []

In [3]:
# Read the table, using the `district_name` column as the row index.
data = pd.read_csv('dataset.csv', index_col='district_name')

# Split column-wise: the last five columns become the targets `y`, every
# column before them becomes the feature matrix `X` (both as NumPy arrays).
X = data.iloc[:, :-5].to_numpy()
y = data.iloc[:, -5:].to_numpy()

In [4]:
# Global model: a fully connected network whose input width is the number of
# feature columns and whose linear output layer has one unit per target column.
model = Sequential([
    Dense(2, input_dim=X.shape[1], activation='relu'),
    Dense(40, activation='relu'),
    Dense(20, activation='relu'),
    Dense(y.shape[1], activation='linear'),
])

# Mean-squared-error loss with Adam at a learning rate of 0.01.
model.compile(
    loss='mean_squared_error',
    optimizer=Adam(learning_rate=0.01),
)

In [5]:
# Replace every weight array of the freshly built model with samples drawn
# uniformly from [0, 1), keeping each array's shape.
weights = [np.random.random(tensor.shape) for tensor in model.get_weights()]
model.set_weights(weights)

# Write the initial global model to disk; this is the file that is later sent
# to connecting clients.
model.save('models/model.h5', save_format='h5')

In [6]:
def averaging_thread(poll_interval=0.1, stop_event=None):
    """Average queued client models into the global model, round after round.

    Waits until at least ``num_of_clients`` models are queued in the shared
    ``models`` list, averages their weights layer by layer, stores the result
    in the global ``model``, saves it to ``models/model.h5`` and removes the
    consumed models from the queue.

    Args:
        poll_interval: Seconds to sleep between checks while the queue is not
            yet full, so the loop does not spin a CPU core.
        stop_event: Optional ``threading.Event``-like object; the loop exits
            once ``stop_event.is_set()`` is true. With the default ``None``
            the loop runs forever.
    """
    while stop_event is None or not stop_event.is_set():
        if len(models) < num_of_clients:
            time.sleep(poll_interval)
            continue

        # Work on a fixed snapshot: handler threads may append to `models`
        # while this round is being computed, and every layer must be averaged
        # over the same set of models.
        round_models = models[:num_of_clients]

        # Fetch each model's weights once instead of once per layer.
        client_weights = [client_model.get_weights() for client_model in round_models]

        # zip(*...) groups the i-th weight array of every client together;
        # accumulate in float64 before the result is handed back to the model.
        new_weights = [
            np.mean(layer_group, axis=0, dtype=np.float64)
            for layer_group in zip(*client_weights)
        ]

        model.set_weights(new_weights)
        model.save('models/model.h5', save_format='h5')
        print("global model updated!")

        # Drop only the models that were averaged; anything appended in the
        # meantime stays queued for the next round.
        del models[:len(round_models)]

In [7]:
def send_model(client):
    """Send the saved global model file to ``client``, followed by ``b'END'``.

    The whole of ``models/model.h5`` is read into memory and sent first; the
    ``b'END'`` marker is sent afterwards in a separate ``sendall`` call.

    Args:
        client: Connected socket (anything providing ``sendall``).
    """
    with open('models/model.h5', 'rb') as model_file:
        payload = model_file.read()

    # Payload first, then the end-of-transfer marker.
    for chunk in (payload, b'END'):
        client.sendall(chunk)

In [8]:
def check_client_start(client):
    """Read one message from ``client`` and report whether it is the start signal.

    Args:
        client: Connected socket (anything providing ``recv``).

    Returns:
        bool: ``True`` when the decoded message read by a single
        ``recv(4096)`` equals ``"model_initialized"``, otherwise ``False``.
    """
    handshake = client.recv(4096).decode()
    return handshake == "model_initialized"

In [9]:
def receive_and_save(client, addr):
    """Receive one serialized model from ``client``, save it, load it and queue it.

    Bytes are read until the ``b'END'`` marker appears. Everything before the
    marker is written to ``models/client_model_{addr}.h5``, the file is loaded
    with ``K.models.load_model`` and the resulting model is appended to the
    shared ``models`` list.

    Args:
        client: Connected socket to read from.
        addr: Peer address; only used to build the output file name.

    Raises:
        ConnectionError: If the peer closes the connection before the marker
            has been received (previously this case looped forever).
    """
    sentinel = b'END'
    buffer = bytearray()  # mutable buffer: appending does not copy old data

    while True:
        packet = client.recv(4096)
        if not packet:
            # recv() returning no bytes means the peer closed the connection.
            raise ConnectionError(
                f'client {addr} closed the connection before sending END'
            )
        buffer += packet

        # Only the new bytes, plus a short overlap for a marker split across
        # two packets, can contain a first occurrence of the marker.
        search_from = max(0, len(buffer) - len(packet) - (len(sentinel) - 1))
        end = buffer.find(sentinel, search_from)
        if end != -1:
            break

    # NOTE(review): the payload is raw file bytes, so a b'END' sequence inside
    # the model file itself would cut the transfer short. Fixing that needs a
    # different framing (e.g. a length prefix) agreed with the client side.
    path = f'models/client_model_{addr}.h5'
    with open(path, 'wb') as f:
        f.write(bytes(buffer[:end]))

    # NOTE(review): this deserializes a model file received over the network;
    # confirm that only trusted clients can connect.
    client_model = K.models.load_model(path)
    models.append(client_model)

In [10]:
# Per-client handler; one thread runs this function for each accepted connection.
def client_thread(client, addr):
    """Serve one connected client until it disconnects.

    Sends the current global model, waits for the client's start message and
    then repeats the exchange: read a message (re-sending the global model if
    it is ``"request_global_model"``), ask for the client's updated model with
    ``"send_updated_model"`` and receive and queue it.

    The socket is always shut down and closed on exit, including when a call
    raises; exceptions still propagate to the caller.

    Args:
        client: Connected client socket.
        addr: Peer address, used for logging and passed to ``receive_and_save``.
    """
    print("Client connected from: ",addr)
    try:
        send_model(client)
        start = check_client_start(client)
        while start:
            message = client.recv(4096).decode()
            if not message:
                # An empty read means the peer closed the connection; stop
                # instead of sending requests to a dead socket.
                break
            if message == "request_global_model":
                send_model(client)
            time.sleep(1)
            client.sendall("send_updated_model".encode())
            receive_and_save(client, addr)
            time.sleep(5)
    finally:
        try:
            client.shutdown(socket.SHUT_WR)
        except OSError:
            # The connection may already be gone; closing is still required.
            pass
        client.close()

In [11]:
# Start the server
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    server.bind(('localhost', 12345))
    server.listen(5)
    print("Server listening...")

    # Keep the Thread object under its own name. Assigning it to
    # `averaging_thread` would replace the function with a Thread instance,
    # and running this cell again would then pass a non-callable target.
    averaging_worker = threading.Thread(target=averaging_thread)
    averaging_worker.start()

    # Accept exactly `num_of_clients` connections, one handler thread each.
    thread = []
    for i in range(num_of_clients):
        client, addr = server.accept()
        handler = threading.Thread(target=client_thread, args=(client, addr))
        thread.append(handler)
        handler.start()
finally:
    # Close the listening socket even if bind/accept fails or the cell is
    # interrupted, so the port is released; accepted client sockets are
    # separate objects and keep working.
    server.close()

Server listening...
Client connected from:  ('127.0.0.1', 60936)
Client connected from:  ('127.0.0.1', 60957)
Client connected from:  ('127.0.0.1', 60959)
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global model updated!
global mo