In [1]:
# Util import 

from caller_2 import *
from tensorflow.keras import backend as K

In [2]:

# Per-client train/test splits for federated modelling. Each sample pairs a
# 256-step input window with a 365-step output window; min and max temperature
# are prepared as separate datasets for every client.
n_steps_in, n_steps_out = 256, 365

obj = Reader('Nizamabad')
clients = obj.create_clients()

clients_train_min, clients_test_min = {}, {}
clients_train_max, clients_test_max = {}, {}

for client_name, data in clients.items():
    # Min temperature first, then max (same call order as before, in case the
    # loader keeps per-call state).
    min_split = obj.data_loader(data, n_steps_in, n_steps_out, feaat='Min Temp (°C)')
    clients_train_min[client_name], clients_test_min[client_name] = min_split

    max_split = obj.data_loader(data, n_steps_in, n_steps_out, feaat='Max Temp (°C)')
    clients_train_max[client_name], clients_test_max[client_name] = max_split

In [3]:

# Weight scaling for global averaging
def scale_model_weights(weight, scalar):
    '''Return a new list with every per-layer weight divided by ``scalar``.

    Despite the parameter name, each element is multiplied by ``1 / scalar``.
    Passing the number of participating clients means that summing the scaled
    weights of all those clients yields their plain (unweighted) average.

    Args:
        weight: sequence of per-layer weights (e.g. the list returned by a
            Keras ``Model.get_weights()``); each element must support
            multiplication by a Python float.
        scalar: non-zero divisor, typically the number of clients.

    Returns:
        A new list whose i-th element is ``(1 / scalar) * weight[i]``.

    Raises:
        ZeroDivisionError: if ``scalar`` is 0.
    '''
    factor = 1 / scalar
    return [factor * layer for layer in weight]

# Layer-wise sum of the clients' pre-scaled weights (the new global weights)
def sum_scaled_weights(scaled_weight_list):
    '''Add up several models' pre-scaled weight lists, layer by layer.

    Args:
        scaled_weight_list: one entry per client; each entry is a list of
            per-layer weights in the same layer order and shapes.

    Returns:
        A list with one tensor per layer holding the sum of that layer's
        weights across all clients. When every client's weights were already
        divided by the number of clients, this sum is their average.
    '''
    # zip(*...) regroups the clients' weights layer by layer; reduce_sum over
    # axis 0 adds up that layer's tensors from every client.
    return [
        tf.math.reduce_sum(per_layer_weights, axis=0)
        for per_layer_weights in zip(*scaled_weight_list)
    ]


# Global Test Function
def test_model(X_test_max, y_test_max, X_test_min, y_test_min,  model, comm_round):
    cce = tf.keras.losses.MeanSquaredError()
    logits = model.predict([X_test_max[0].reshape(1,256),X_test_min[0].reshape(1,256)])
    mse_1 = cce(y_test_min[0], logits[0])
    mse_2 = cce(y_test_max[0], logits[1])
    print('comm_round: {} | global_loss_Temp_min: {} | global_loss_Temp_max: {}'.format(comm_round, mse_1, mse_2))
    return mse_1, mse_2

In [4]:
from tensorflow.keras.callbacks import EarlyStopping
import random
from tensorflow.keras.models import load_model


# Parameter Setting
nEpochs = 15      # local training epochs per client in each communication round
comms_round = 20  # number of global (communication) rounds
k = 5             # number of mandals (clients) that participate in training
earlystop = EarlyStopping(patience=10)

# Global Model Initialization
smlp_global = SmplMLP()
global_model = smlp_global.Conv1D_model(n_steps_in, n_steps_out)
# Alternative: resume from a saved model instead of a fresh one
# global_model = load_model('models/adila/')

client_names = list(clients_train_max.keys())

# Participating clients: the first k mandals, fixed for every round (no random
# sampling). Deriving the list from k keeps the client set and the averaging
# divisor in sync, and slicing avoids an IndexError with fewer than k clients.
opt = client_names[:k]
if not opt:
    raise ValueError('No clients available for federated training')

history_local_arr = []
global_min_arr, global_max_arr = [], []

for comm_round in range(comms_round):

    global_weights = global_model.get_weights()

    # Each participating client's weights, pre-divided by the number of participants
    scaled_local_weight_list = []
    for client in opt:
        # Fresh local model per client, initialised from the current global weights
        smlp_local = SmplMLP()
        local_model = smlp_local.Conv1D_model(n_steps_in, n_steps_out)
        local_model.compile(optimizer='adam', loss=['mse', 'mse'])
        local_model.set_weights(global_weights)

        # Each split is an (X, y) pair. The model takes [max, min] inputs and
        # is trained against [max, min] targets, in that order.
        (X_train_max, y_train_max), (X_test_max, y_test_max) = clients_train_max[client], clients_test_max[client]
        (X_train_min, y_train_min), (X_test_min, y_test_min) = clients_train_min[client], clients_test_min[client]

        history_local = local_model.fit([X_train_max, X_train_min], [y_train_max, y_train_min],
                                        validation_data=([X_test_max, X_test_min], [y_test_max, y_test_min]),
                                        callbacks=[earlystop], epochs=nEpochs)
        print('Epoch done for {} '.format(client))
        history_local_arr.append(history_local)

        # Divide by the number of participants so that the later sum is a mean
        scaled_weights = scale_model_weights(local_model.get_weights(), scalar=len(opt))
        scaled_local_weight_list.append(scaled_weights)

        # Reset Keras' global state after every local model (not once per
        # round) to limit memory growth from the many models built here.
        K.clear_session()

    # Sum of pre-scaled weights == unweighted average of the participants'
    # weights (every client counts equally, regardless of dataset size).
    average_weights = sum_scaled_weights(scaled_local_weight_list)

    print("\n")
    print("Round {} Done for all the cities".format(comm_round))

    # update global model
    global_model.set_weights(average_weights)

    # Evaluate the global model on the test split of the last participating
    # client, selected explicitly rather than relying on variables that leak
    # out of the client loop.
    eval_client = opt[-1]
    X_eval_max, y_eval_max = clients_test_max[eval_client]
    X_eval_min, y_eval_min = clients_test_min[eval_client]
    glob_min_temp, glob_max_temp = test_model(X_eval_max, y_eval_max, X_eval_min, y_eval_min, global_model, comm_round)
    global_min_arr.append(glob_min_temp)
    global_max_arr.append(glob_max_temp)
    print("\n")
    print("\n")

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Epoch done for Ranjal 
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Epoch done for Dharpalle 
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Epoch done for Dichpalle 
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Epoch done for Vailpur 
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
Epoch done for Sirkonda 


Round 0 Done for al

In [5]:
# Persist the trained models. `local_model` is the model of whichever client
# was trained last in the federated loop.
GLOBAL_MODEL_DIR = 'models/nizam/global'
LOCAL_MODEL_DIR = 'models/nizam/local'
global_model.save(GLOBAL_MODEL_DIR)
local_model.save(LOCAL_MODEL_DIR)





INFO:tensorflow:Assets written to: models/nizam/global\assets


INFO:tensorflow:Assets written to: models/nizam/global\assets


INFO:tensorflow:Assets written to: models/nizam/local\assets


INFO:tensorflow:Assets written to: models/nizam/local\assets
