In [11]:
import flwr as fl
import tensorflow as tf
from tensorflow import keras
from keras import layers
import utils as ut
import numpy as np
from typing import Dict, Optional, Tuple

In [12]:
def get_model(input_dim=8, hidden_units=64):
    """Build and compile a small fully connected regression network.

    Architecture: ``input_dim`` features -> Dense(``hidden_units``, ReLU)
    -> Dense(``hidden_units``, ReLU) -> Dense(1) (linear output). The model is
    compiled with the Adam optimizer, mean-squared-error loss and the
    ``mae``/``mse`` metrics, whose names the client's ``fit`` reads back from
    the Keras training history.

    Args:
        input_dim: Number of input features per sample (default 8, the width
            this notebook was written for).
        hidden_units: Width of each of the two hidden Dense layers (default 64).

    Returns:
        A compiled ``tf.keras.Sequential`` model.
    """
    # Take every layer from tf.keras so the layers and the Sequential container
    # come from the same Keras installation. Mixing the standalone `keras`
    # package's layers with a tf.keras model can fail when the two resolve to
    # different installs (e.g. Keras 3 next to an older TensorFlow).
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Input(shape=(input_dim,)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dense(1),  # single linear unit: scalar regression output
        ]
    )
    model.compile(optimizer="adam", loss="mean_squared_error", metrics=["mae", "mse"])
    return model

In [13]:
class CifarClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a Keras model on local data.

    Despite the name, nothing here is CIFAR-specific. This notebook uses it with
    the single-output regression model from ``get_model``. The reported metrics
    (``mse``/``mae``) assume the model was compiled with those metric names.

    Args:
        model: Compiled Keras model. Its ``fit`` history must contain ``mse``,
            ``mae``, ``val_mse`` and ``val_mae`` entries.
        x_train, y_train: Local training features and targets.
        x_test, y_test: Local held-out features and targets used by ``evaluate``.
    """

    def __init__(self, model, x_train, y_train, x_test, y_test):
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_test, self.y_test = x_test, y_test

    def fit(self, parameters, config):
        """Train the received global parameters on the locally held training set.

        Args:
            parameters: Model weights sent by the server (list of NumPy arrays).
            config: Server-provided dict; must contain ``batch_size`` and
                ``local_epochs``.

        Returns:
            Tuple of (updated weights, number of local training examples,
            dict with the last epoch's ``mse``, ``mae``, ``val_mse``, ``val_mae``).
        """
        # Start from the server's current global model.
        self.model.set_weights(parameters)

        # Hyperparameters for this round come from the server.
        batch_size = int(config["batch_size"])
        epochs = int(config["local_epochs"])

        # validation_split=0.2 holds out the last 20% of x_train for the val_*
        # metrics, so only ~80% of the samples are actually trained on.
        history = self.model.fit(
            self.x_train,
            self.y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.2,
        )

        # Report the final epoch's metrics as plain Python floats, since Flower
        # metric dicts must hold scalars.
        results = {
            key: float(history.history[key][-1])
            for key in ("mse", "mae", "val_mse", "val_mae")
        }
        return self.model.get_weights(), len(self.x_train), results

    def get_parameters(self, config=None):
        """Return the current local model weights as a list of NumPy arrays.

        ``config`` is accepted and ignored. Flower 1.x calls
        ``get_parameters(config)``; the default keeps zero-argument calls working.
        """
        return self.model.get_weights()

    def evaluate(self, parameters, config):
        """Evaluate the received parameters on the locally held test set.

        Args:
            parameters: Model weights sent by the server.
            config: Server-provided dict (unused).

        Returns:
            Tuple of (MSE loss, number of local *test* examples, {"mae": MAE}).
        """
        self.model.set_weights(parameters)

        # Predict once and flatten both sides before comparing. With get_model()'s
        # Dense(1) head the predictions have shape (n, 1). If y_test is 1-D
        # (its shape comes from utils and isn't visible here), comparing it
        # against the (n, 1) predictions would broadcast to an (n, n) matrix of
        # all pairwise errors instead of per-sample errors.
        y_pred = np.asarray(self.model.predict(self.x_test)).reshape(-1)
        y_true = np.asarray(self.y_test, dtype=y_pred.dtype).reshape(-1)
        errors = y_true - y_pred
        loss = float(np.mean(np.square(errors)))
        mae = float(np.mean(np.abs(errors)))

        print("*************LOSS******************", loss)
        print("************MAE********************", mae)
        # Weight this client's result by the number of examples actually
        # evaluated (the test set), not the training-set size.
        return loss, len(self.x_test), {"mae": mae}

    def get_weights(self):
        """Get the current weights of the local model."""
        return self.model.get_weights()

    def set_weights(self, weights):
        """Set the weights of the local model."""
        self.model.set_weights(weights)

In [14]:
# Load this client's local data partition.
# NOTE(review): the meaning of the positional arguments (0, 4, 2) is defined in
# the project-local `utils` module and is not visible here; consider naming them.
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(0, 4, 2)

# Wrap a freshly compiled model and the local data in a Flower client.
client = CifarClient(
    model=get_model(),
    x_train=x_train,
    y_train=y_train,
    x_test=x_test,
    y_test=y_test,
)

# Connect to the Flower server on port 8080 and serve fit/evaluate requests
# until the server ends the session.
# NOTE(review): start_numpy_client appears to return None, so `history` does not
# hold a training history; confirm against the installed flwr version.
history = fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",
    client=client,
)

INFO flwr 2023-11-23 01:42:53,621 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:42:53,642 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:42:53,644 | connection.py:42 | ChannelConnectivity.READY


*************LOSS****************** 2.959141492843628
************MAE******************** 1.311035394668579
*************LOSS****************** 1.3971108198165894
************MAE******************** 0.9196226596832275
Epoch 1/2
Epoch 2/2
*************LOSS****************** 1.5016642808914185
************MAE******************** 0.964270293712616
Epoch 1/2
Epoch 2/2
*************LOSS****************** 1.821413516998291
************MAE******************** 1.047902226448059
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.0762386322021484
************MAE******************** 1.1131839752197266
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.0832161903381348
************MAE******************** 1.1128684282302856
*************LOSS****************** 2.1087265014648438
************MAE******************** 1.1279268264770508
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.0910847187042236
************MAE******************** 1.1183719635009766
Epoch 1/2
Epoch 2/2
*****

DEBUG flwr 2023-11-23 01:43:14,808 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:43:14,810 | app.py:215 | Disconnect and shut down


In [16]:
# Load the same local partition as the earlier run (arguments 0, 4, 2).
# NOTE(review): argument meanings live in the project-local `utils` module;
# confirm that reusing partition 0 against this second server is intended.
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(0, 4, 2)

# A fresh, untrained model for this session, wrapped with the local data.
client = CifarClient(
    model=get_model(),
    x_train=x_train,
    y_train=y_train,
    x_test=x_test,
    y_test=y_test,
)

# Join the Flower server listening on port 8082 and serve requests until
# the server ends the session.
# NOTE(review): the return value is expected to be None; confirm for the
# installed flwr version before relying on `history`.
history = fl.client.start_numpy_client(
    server_address="127.0.0.1:8082",
    client=client,
)

INFO flwr 2023-11-23 01:46:09,824 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-23 01:46:09,845 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-23 01:46:09,849 | connection.py:42 | ChannelConnectivity.CONNECTING
DEBUG flwr 2023-11-23 01:46:09,851 | connection.py:42 | ChannelConnectivity.READY


*************LOSS****************** 2.0978851318359375
************MAE******************** 1.115322470664978
*************LOSS****************** 2.1095409393310547
************MAE******************** 1.1189020872116089
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.103264808654785
************MAE******************** 1.1168915033340454
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.119640588760376
************MAE******************** 1.1234972476959229
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1284210681915283
************MAE******************** 1.124629259109497
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1420881748199463
************MAE******************** 1.1313738822937012
Epoch 1/2
Epoch 2/2
*************LOSS****************** 2.1244852542877197
************MAE******************** 1.1255003213882446
*************LOSS****************** 2.13070011138916
************MAE******************** 1.1263831853866577
Epoch 1/2
Epoch 2/2
******

DEBUG flwr 2023-11-23 01:46:33,099 | connection.py:139 | gRPC channel closed
INFO flwr 2023-11-23 01:46:33,100 | app.py:215 | Disconnect and shut down


In [None]:
# Load a different local partition for this client.
# NOTE(review): the first argument is -2 (negative). Whether utils treats it as
# Python-style indexing from the end or as a bug is not visible here; confirm.
(x_train, y_train), (x_test, y_test) = ut.partition_dataset(-2, 4, 2)

# Fresh model plus this partition's data, exposed as a Flower client.
client = CifarClient(
    model=get_model(),
    x_train=x_train,
    y_train=y_train,
    x_test=x_test,
    y_test=y_test,
)

# Join the Flower server on port 8082 and serve requests until it ends
# the session.
# NOTE(review): the return value is expected to be None; confirm for the
# installed flwr version.
history = fl.client.start_numpy_client(
    server_address="127.0.0.1:8082",
    client=client,
)

In [None]:
# Show whatever the last start_numpy_client call returned (expected: None).
# NOTE(review): the saved NameError output mentions `histories`, a name no
# cell defines any more; it is stale and should be refreshed by re-running.
print(str(history))

NameError: name 'histories' is not defined