In [21]:
import flwr as fl
import tensorflow as tf
from tensorflow import keras
from keras import layers
import utils as ut
import numpy as np
from typing import Dict, Optional, Tuple

In [22]:
def get_model():
    model = tf.keras.models.Sequential(
        [
            layers.Input(shape=(8,)),  
            layers.Dense(64, activation='relu'),
            layers.Dense(64, activation='relu'),  
            layers.Dense(1)  
        ]
    )
    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae', 'mse'])
    return model

In [23]:
class CifarClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test

    def fit(self, parameters, config):
        """Train parameters on the locally held training set."""

        # Update local model parameters
        self.model.set_weights(parameters)

        # Get hyperparameters for this round
        batch_size: int = config["batch_size"]
        epochs: int = config["local_epochs"]

        # Train the model using hyperparameters from config
        history = self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.2,
        )

        # Return updated model parameters, number of examples trained, and results
        parameters_prime = self.model.get_weights()
        num_examples_train = len(self.x_train)
        results = {
            "mse": history.history["mse"][-1],  # Use the last epoch's MSE
            "mae": history.history["mae"][-1],  # Use the last epoch's MAE
            "val_mse": history.history["val_mse"][-1],  # Use the last epoch's validation MSE
            "val_mae": history.history["val_mae"][-1],  # Use the last epoch's validation MAE
        }

        return parameters_prime, num_examples_train, results

    def get_parameters(self):
        """Get the current parameters of the local model."""
        return self.model.get_weights()

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""
        self.model.set_weights(parameters)
        loss = tf.keras.losses.mean_squared_error(self.y_test, self.model.predict(self.x_test)).numpy().mean().item()
        mae = tf.keras.metrics.mean_absolute_error(self.y_test, self.model.predict(self.x_test)).numpy().mean().item()
        print("*************LOSS******************",loss)
        print("************MAE********************",mae)
        return loss, len(self.x_train), {"mae": mae}

    def get_weights(self):
        """Get the current weights of the local model."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """Set the weights of the local model."""
        self.model.set_weights(weights)

In [24]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(2,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=client,
)

INFO flwr 2023-11-23 01:51:07,977 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:51:07,996 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:51:07,998 | connection.py:42 | ChannelConnectivity.READY


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
*************LOSS****************** 2.1078832149505615
************MAE******************** 1.1092127561569214
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/

DEBUG flwr 2023-11-23 01:53:46,384 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:53:46,386 | app.py:215 | Disconnect and shut down


In [25]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(0,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8081",
client=client,
)

INFO flwr 2023-11-23 01:54:53,705 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:54:53,724 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:54:53,727 | connection.py:42 | ChannelConnectivity.READY


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
*************LOSS****************** 2.2793779373168945
************MAE******************** 1.171873927116394
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/1

DEBUG flwr 2023-11-23 01:57:16,659 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:57:16,661 | app.py:215 | Disconnect and shut down


In [26]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(-2,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8082",
client=client,
)

INFO flwr 2023-11-23 01:59:51,626 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:59:51,646 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:59:51,648 | connection.py:42 | ChannelConnectivity.READY


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
*************LOSS****************** 2.3105416297912598
************MAE******************** 1.1772518157958984
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/

DEBUG flwr 2023-11-23 02:02:12,512 | connection.py:139 | gRPC channel closed


 2.4310128688812256
************MAE******************** 1.2073485851287842


INFO flwr 2023-11-23 02:02:12,513 | app.py:215 | Disconnect and shut down
