In [8]:
# https://www.kaggle.com/code/kaanboke/beginner-friendly-end-to-end-ml-project-enjoy

import warnings

import flwr as fl

import logging

logger = logging.getLogger('flwr')

NUM_CLIENT = 1
CLIENT_INDEX = 5

In [9]:
from datasets import load_dataset
from flwr_datasets.partitioner import IidPartitioner
from sklearn.compose import ColumnTransformer

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder, PowerTransformer
import pandas as pd
from sklearn.model_selection import train_test_split

def load_data ():
    data_files = ["./stroke-prediction-dataset/healthcare-dataset-stroke-data.csv"]
    dataset = load_dataset("csv", data_files=data_files)

    partitioner = IidPartitioner(num_partitions=NUM_CLIENT)
    partitioner.dataset = dataset['train']

    partition = partitioner.load_partition(partition_id=CLIENT_INDEX)
    partition = partition.to_pandas()
    
    partition = partition.drop('id', axis=1)
    
    categorical = [ 'hypertension', 'heart_disease', 'ever_married','work_type', 'Residence_type', 'smoking_status']
    numerical = ['avg_glucose_level', 'bmi','age']
    
    y= partition['stroke']
    X = partition.drop('stroke', axis=1)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

    transformer = ColumnTransformer(
        transformers=[
            ('imp', SimpleImputer(strategy='median'), numerical),
            ('ohe', OneHotEncoder(), categorical)
        ]
    )


    # Apply the ColumnTransformer to the training data
    X_train_transformed = transformer.fit_transform(X_train)
    X_test_transformed = transformer.transform(X_test)

    # Apply the PowerTransformer to the numerical features in the training data
    power_transformer = PowerTransformer(method='yeo-johnson', standardize=True)
    X_train_transformed[:, :len(numerical)] = power_transformer.fit_transform(X_train_transformed[:, :len(numerical)])
    X_test_transformed[:, :len(numerical)] = power_transformer.transform(X_test_transformed[:, :len(numerical)])

    return X_train_transformed, y_train, X_test_transformed, y_test

In [10]:
import numpy as np
from sklearn.linear_model import LogisticRegression

from flwr.common import NDArrays


def get_model_parameters(model: LogisticRegression) -> NDArrays:
    """Returns the parameters of a sklearn LogisticRegression model."""
    if model.fit_intercept:
        params = [
            model.coef_,
            model.intercept_,
        ]
    else:
        params = [
            model.coef_,
        ]
    return params


def set_model_params(model: LogisticRegression, params: NDArrays) -> LogisticRegression:
    """Sets the parameters of a sklean LogisticRegression model."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: LogisticRegression):
    """Sets initial parameters as zeros Required since model params are uninitialized
    until model.fit is called.

    But server asks for initial parameters from clients at launch. Refer to
    sklearn.linear_model.LogisticRegression documentation for more information.
    """
    n_classes = 2  # Number of classes in dataset
    n_features = 20  # Number of features in dataset
    model.classes_ = np.array([i for i in range(n_classes)])

    model.coef_ = np.zeros((n_classes, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((n_classes,))

In [11]:
model = LogisticRegression(
    # penalty="l2",
    solver="liblinear"
    # max_iter=1,  # local epoch
    # warm_start=True,  # prevent refreshing weights when fitting
)

set_initial_params(model)
X_train, y_train, X_test, y_test = load_data()

In [11]:
from sklearn.metrics import log_loss

class SklearnClient(fl.client.NumPyClient):
    def get_parameters(self, config):  # type: ignore
        return get_model_parameters(model)

    def fit(self, parameters, config):  # type: ignore
        set_model_params(model, parameters)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)
        return get_model_parameters(model), len(X_train), {}

    def evaluate(self, parameters, config):  # type: ignore
        set_model_params(model, parameters)
        loss = log_loss(y_test, model.predict_proba(X_test))
        accuracy = model.score(X_test, y_test)
        return loss, len(X_test), {"accuracy": accuracy}

In [12]:
# fl.client.start_client(server_address="20.198.223.216:8000", client=SklearnClient().to_client())
fl.client.start_client(server_address="20.198.223.216:8000", client=SklearnClient().to_client())

[92mINFO [0m:      
[92mINFO [0m:      Received: get_parameters message 3bd8852a-c910-4327-bce8-3140949f9a4e
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 59f1b00f-dd6f-4f38-b7fa-3493ceb8b057
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 3f5666f4-bd6d-40a7-9297-de4808ded61f
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message f043ac79-65c0-4968-9782-ee7d8c469a55
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message fb8dd54b-ad04-41e0-803a-a59bd944b071
[92mINFO [0m:      Sent reply


Loss: 0.1765543232216287, Accuracy: 0.9452054794520548
Loss: 0.1765543232216287, Accuracy: 0.9452054794520548


[92mINFO [0m:      
[92mINFO [0m:      Received: train message f195a493-0efa-4a36-afac-9af351d7c6e4
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message c0c705ce-8e3c-458e-a2f5-a4515691cc70
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message b699350a-cdeb-41cb-8b23-80171f0bd419
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 25e20799-7735-4647-bb89-cc0f046e7016
[92mINFO [0m:      Sent reply


Loss: 0.1765543232216287, Accuracy: 0.9452054794520548
Loss: 0.1765543232216287, Accuracy: 0.9452054794520548


[92mINFO [0m:      
[92mINFO [0m:      Received: train message 82cdaec3-ab3c-4cdb-8226-71054790750e
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 790e6e3c-ec7d-48b8-abff-46be60970194
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message cdd0fe58-2848-4188-bb6a-a6aeba92efd5
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message fe22d854-bb54-4080-9b47-92bcc6eb53d3
[92mINFO [0m:      Sent reply


Loss: 0.1765543232216287, Accuracy: 0.9452054794520548
Loss: 0.1765543232216287, Accuracy: 0.9452054794520548


[92mINFO [0m:      
[92mINFO [0m:      Received: train message 185ccec9-adc5-4f85-b4a3-a7dca1008d5c
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 7b070855-6569-4d9c-b424-a737a4325b2c
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message fcd573a8-144f-4069-80e9-f6318adb5537
[92mINFO [0m:      Sent reply


Loss: 0.1765543232216287, Accuracy: 0.9452054794520548
Loss: 0.1765543232216287, Accuracy: 0.9452054794520548


[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 8572c3a5-a3c9-4723-aaf8-0917d771fea4
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 9ec1d03d-9093-4d28-b688-488f77ba755f
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message ff51505c-c0da-46ae-8d2e-b9c2bbecded5
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: train message 461cc422-a0de-4838-b77d-9829cdf3730a
[92mINFO [0m:      Sent reply
[92mINFO [0m:      
[92mINFO [0m:      Received: evaluate message 34dda1e2-c9dc-4ef6-bf93-0088a6fc9c08
[92mINFO [0m:      Sent reply


Loss: 0.1765543232216287, Accuracy: 0.9452054794520548
Loss: 0.1765543232216287, Accuracy: 0.9452054794520548


[92mINFO [0m:      
[92mINFO [0m:      Received: reconnect message 2b66cff9-6523-48bb-98c3-3a9ad28bfdce
[92mINFO [0m:      Disconnect and shut down
