In [1]:
import warnings
import flwr as fl
import numpy as np
import utils
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

if __name__ == "__main__":
    import os  # local import: only needed for the optional PARTITION_ID override

    # Number of partitions the training set is split into. Used both to pick the
    # partition index and to perform the split, so the two can never disagree.
    NUM_PARTITIONS = 10

    # Load MNIST dataset from https://www.openml.org/d/554
    (X_train, y_train), (X_test, y_test) = utils.load_mnist()

    # Choose which training partition this client uses. By default one is picked
    # at random; independently started clients may then collide on the same
    # partition, so PARTITION_ID may be set in the environment to pin one.
    partition_env = os.environ.get("PARTITION_ID")
    if partition_env is None:
        partition_id = np.random.choice(NUM_PARTITIONS)
    else:
        partition_id = int(partition_env)
        if not 0 <= partition_id < NUM_PARTITIONS:
            raise ValueError(
                f"PARTITION_ID must be in [0, {NUM_PARTITIONS}), got {partition_id}"
            )
    (X_train, y_train) = utils.partition(X_train, y_train, NUM_PARTITIONS)[partition_id]

    # Create LogisticRegression Model
    model = LogisticRegression(
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # prevent refreshing weights when fitting
    )

    # Setting initial parameters, akin to model.compile for keras models
    utils.set_initial_params(model)

    # Values to remember: a space-separated log of this client's per-round
    # accuracies, prefixed with the client tag "c2".
    arr = 'c2 '

    # Define Flower client
    class MnistClient(fl.client.NumPyClient):
        """Flower NumPy client wrapping the script-level logistic-regression ``model``.

        Trains on this client's partition (``X_train``/``y_train``), evaluates on
        ``X_test``/``y_test``, and appends each round's post-fit accuracy to the
        global string ``arr``.
        """

        def get_parameters(self):  # type: ignore
            """Return the current model parameters via ``utils.get_model_parameters``."""
            return utils.get_model_parameters(model)

        def fit(self, parameters, config):  # type: ignore
            """Run one local training round.

            Args:
                parameters: model parameters sent by the server; loaded into
                    ``model`` before fitting.
                config: per-round configuration dict from the server. The round
                    number is read from its ``"rnd"`` key when present.

            Returns:
                Tuple of (updated parameters, number of local training examples,
                empty metrics dict).
            """
            global arr

            utils.set_model_params(model, parameters)
            # Ignore convergence failure due to low local epochs
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model.fit(X_train, y_train)
            # Whether "rnd" is present depends on the server's strategy config;
            # use .get so a config without it does not crash the client.
            print(f"Training finished for round {config.get('rnd', '?')}")

            # Record this round's accuracy (after local fitting) on the test set.
            accuracy = model.score(X_test, y_test)
            arr = arr + str(accuracy) + ' '

            return utils.get_model_parameters(model), len(X_train), {}

        def evaluate(self, parameters, config):  # type: ignore
            """Evaluate server-provided ``parameters`` on the test set.

            Returns:
                Tuple of (log loss, number of test examples, {"accuracy": accuracy}).
            """
            utils.set_model_params(model, parameters)
            loss = log_loss(y_test, model.predict_proba(X_test))
            accuracy = model.score(X_test, y_test)
            return loss, len(X_test), {"accuracy": accuracy}

    # Start Flower client; presumably blocks until the server closes the session,
    # so the code below runs only after all rounds have finished.
    fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient())

    # Append this client's accuracy log as one line of output.txt. The `with`
    # statement closes the file even if the write raises.
    with open("output.txt", "a", encoding="utf-8") as f:
        f.write(arr + "\n")

INFO flower 2022-04-14 10:59:32,291 | connection.py:102 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flower 2022-04-14 10:59:32,295 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flower 2022-04-14 10:59:32,298 | connection.py:39 | ChannelConnectivity.CONNECTING
DEBUG flower 2022-04-14 10:59:32,304 | connection.py:39 | ChannelConnectivity.READY


[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[7. 2. 1. ... 4. 5. 6.]
Training finished for round 1
Training finished for round 2
Training finished for round 3
Training finished for round 4
Training finished for round 5


DEBUG flower 2022-04-14 10:59:50,070 | connection.py:121 | gRPC channel closed
INFO flower 2022-04-14 10:59:50,073 | app.py:101 | Disconnect and shut down
