In [13]:
import flwr as fl
import tensorflow as tf
from tensorflow import keras
from keras import layers
import utils as ut
import numpy as np
from typing import Dict, Optional, Tuple

In [14]:
def get_model():
    model = tf.keras.models.Sequential(
        [
            layers.Input(shape=(8,)),  
            layers.Dense(64, activation='relu'),
            layers.Dense(64, activation='relu'),  
            layers.Dense(1)  
        ]
    )
    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae', 'mse'])
    return model

In [15]:
class CifarClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test

    def fit(self, parameters, config):
        """Train parameters on the locally held training set."""

        # Update local model parameters
        self.model.set_weights(parameters)

        # Get hyperparameters for this round
        batch_size: int = config["batch_size"]
        epochs: int = config["local_epochs"]

        # Train the model using hyperparameters from config
        history = self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.2,
        )

        # Return updated model parameters, number of examples trained, and results
        parameters_prime = self.model.get_weights()
        num_examples_train = len(self.x_train)
        results = {
            "mse": history.history["mse"][-1],  # Use the last epoch's MSE
            "mae": history.history["mae"][-1],  # Use the last epoch's MAE
            "val_mse": history.history["val_mse"][-1],  # Use the last epoch's validation MSE
            "val_mae": history.history["val_mae"][-1],  # Use the last epoch's validation MAE
        }

        return parameters_prime, num_examples_train, results

    def get_parameters(self):
        """Get the current parameters of the local model."""
        return self.model.get_weights()

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set."""
        self.model.set_weights(parameters)
        loss = tf.keras.losses.mean_squared_error(self.y_test, self.model.predict(self.x_test)).numpy().mean().item()
        mae = tf.keras.metrics.mean_absolute_error(self.y_test, self.model.predict(self.x_test)).numpy().mean().item()
        print("*************LOSS******************",loss)
        print("************MAE********************",mae)
        return loss, len(self.x_train), {"mae": mae}

    def get_weights(self):
        """Get the current weights of the local model."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """Set the weights of the local model."""
        self.model.set_weights(weights)

In [16]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(2,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=client,
)

INFO flwr 2023-11-23 01:42:36,479 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:42:36,498 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:42:36,502 | connection.py:42 | ChannelConnectivity.READY


*************LOSS****************** 2.9103686809539795
************MAE******************** 1.2954483032226562
Epoch 1/2
Epoch 2/2
*************LOSS****************** 1.3876734972000122
************MAE******************** 0.9221519827842712
*************LOSS****************** 1.494423508644104
************MAE******************** 0.9639512300491333
*************LOSS****************** 1.8109973669052124
************MAE******************** 1.0409014225006104
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.0660860538482666
************MAE******************** 1.1044645309448242
*************LOSS****************** 2.0793609619140625
************MAE******************** 1.1045624017715454
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.0867295265197754
************MAE******************** 1.1128555536270142
*************LOSS****************** 2.0804426670074463
************MAE******************** 1.1065306663513184
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.106

DEBUG flwr 2023-11-23 01:43:14,809 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:43:14,811 | app.py:215 | Disconnect and shut down


In [20]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(0,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8082",
client=client,
)

INFO flwr 2023-11-23 01:45:52,471 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:45:52,490 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:45:52,493 | connection.py:42 | ChannelConnectivity.READY


*************LOSS****************** 2.0978851318359375
************MAE******************** 1.115322470664978
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1095409393310547
************MAE******************** 1.1189020872116089
*************LOSS****************** 2.103264808654785
************MAE******************** 1.1168915033340454
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.119640588760376
************MAE******************** 1.1234972476959229
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1284210681915283
************MAE******************** 1.124629259109497
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1420881748199463
************MAE******************** 1.1313738822937012
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1244852542877197
************MAE******************** 1.1255003213882446
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.13070011138916
************MAE******************** 1.1263831853866577
******

DEBUG flwr 2023-11-23 01:46:33,097 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:46:33,099 | app.py:215 | Disconnect and shut down


In [None]:
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(-2,4,0)

client = CifarClient(get_model(), x_train, y_train, x_test, y_test)

history = fl.client.start_numpy_client(
server_address="127.0.0.1:8082",
client=client,
)

None
