In [1]:
# !pip install xgboost
# !pip install matplotlib scikit-learn tqdm torch torchmetrics torchvision torchsummary pandas
# !pip install flwr
# !pip install -U flwr["simulation"]

In [2]:
import xgboost as xgb
from xgboost import XGBClassifier, XGBRegressor
from sklearn.metrics import mean_squared_error, accuracy_score

import numpy as np
from tqdm import trange, tqdm
import flwr as fl
from flwr.common.typing import Parameters
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from flwr.common import NDArray, NDArrays
from matplotlib import pyplot as plt 

import torch, torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchmetrics import Accuracy, MeanSquaredError
from tqdm import trange, tqdm
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset, random_split

  from .autonotebook import tqdm as notebook_tqdm
2023-08-25 11:10:45,154	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
def construct_tree(
    # NOTE: `dataset` was originally typed as a PyTorch Dataset; it is handed
    # straight to `fit`, so it must be something XGBoost accepts as features.
    dataset: Dataset,
    label: NDArray,
    n_estimators,
    tree_type: str,
    learning_rate=0.05,
    max_depth=30,
    booster='gbtree',
    random_state=0,
    subsample=0.9,
    colsample_bytree=0.9,
    alpha=5,
    gamma=5,
    min_child_weight=1,
    eval_metric='mlogloss',
    reg_lambda=1500,
    verbose=2020,
    max_leaves=30,
) -> Union[XGBClassifier, XGBRegressor]:
    """Build and fit an XGBoost ensemble on a tabular dataset.

    Args:
        dataset: Feature matrix passed unchanged to ``fit``.
        label: Targets passed unchanged to ``fit``.
        n_estimators: Number of boosting rounds.
        tree_type: ``"MULTICLASS"`` builds an ``XGBClassifier`` with the
            ``multi:softprob`` objective; ``"REG"`` builds an ``XGBRegressor``
            with ``reg:squarederror``.
        learning_rate, max_depth, booster, random_state, subsample,
        colsample_bytree, alpha, gamma, min_child_weight, eval_metric,
        reg_lambda, verbose, max_leaves: Forwarded as-is to ``XGBClassifier``.
            They are only used for ``"MULTICLASS"``; the ``"REG"`` branch uses
            its own fixed hyperparameters.

    Returns:
        The fitted estimator.

    Raises:
        ValueError: If ``tree_type`` is neither ``"MULTICLASS"`` nor ``"REG"``.
    """
    # Validate first: previously an unknown type fell through both branches
    # and crashed later with an UnboundLocalError on `tree`.
    if tree_type not in ("MULTICLASS", "REG"):
        raise ValueError(
            f"Unsupported tree_type {tree_type!r}; expected 'MULTICLASS' or 'REG'."
        )

    if tree_type == "MULTICLASS":
        tree = XGBClassifier(
            objective="multi:softprob",
            # num_class is left unset here (it was commented out in the original).
            learning_rate=learning_rate,
            max_depth=max_depth,
            n_estimators=n_estimators,
            subsample=subsample,
            colsample_bylevel=1,
            colsample_bynode=1,
            colsample_bytree=colsample_bytree,
            alpha=alpha,
            gamma=gamma,
            num_parallel_tree=1,
            min_child_weight=min_child_weight,
            random_state=random_state,
            booster=booster,
            eval_metric=eval_metric,
            reg_lambda=reg_lambda,
            # NOTE(review): forwarded under the name `verbose` exactly as before;
            # confirm XGBoost actually consumes a parameter with this name.
            verbose=verbose,
            max_leaves=max_leaves,
        )
    else:  # tree_type == "REG"
        tree = xgb.XGBRegressor(
            objective="reg:squarederror",
            learning_rate=0.1,
            max_depth=8,
            n_estimators=n_estimators,
            subsample=0.8,
            colsample_bylevel=1,
            colsample_bynode=1,
            colsample_bytree=1,
            alpha=5,
            gamma=5,
            num_parallel_tree=1,
            min_child_weight=1,
        )

    tree.fit(dataset, label)
    return tree

def construct_tree_from_loader(
    dataset_loader: DataLoader, n_estimators: int, tree_type: str
) -> Union[XGBClassifier, XGBRegressor]:
    """Fit an XGBoost ensemble on the data served by a loader.

    Only the LAST batch yielded by ``dataset_loader`` is used for fitting, so
    the loader is expected to serve the whole partition as a single batch.

    Args:
        dataset_loader: Iterable of batches indexable as ``batch[0]`` (features)
            and ``batch[1]`` (labels).
        n_estimators: Number of boosting rounds, forwarded to ``construct_tree``.
        tree_type: Task selector, forwarded to ``construct_tree``.

    Returns:
        The fitted estimator returned by ``construct_tree``.

    Raises:
        ValueError: If the loader yields no batch at all (previously this
            surfaced as an UnboundLocalError on ``data``).
    """
    data, label = None, None
    received_batch = False
    for batch in dataset_loader:
        data, label = batch[0], batch[1]
        received_batch = True

    if not received_batch:
        raise ValueError(
            "dataset_loader yielded no batches; cannot fit a tree on empty data."
        )

    print(f"Data: {data}")
    print(f"label: {label}")
    return construct_tree(data, label, n_estimators, tree_type)

def single_tree_prediction(
    tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray
) -> Optional[NDArray]:
    """Return the raw-margin prediction of one boosting round of an ensemble.

    The round is selected with ``iteration_range=(n_tree, n_tree + 1)`` and the
    prediction is requested with ``output_margin=True``. A 2-D result (one
    column per class) is reduced to the index of the largest margin per row;
    a 1-D result is returned unchanged.

    Args:
        tree: Fitted estimator exposing ``get_booster().get_dump()`` and
            ``predict``.
        n_tree: Zero-based index of the round to extract.
        dataset: Feature matrix to predict on.

    Returns:
        One value per sample, or ``None`` (after printing a message) when
        ``n_tree`` is not below the number of dumped trees.
    """
    # How to access a single tree
    # https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb
    num_t = len(tree.get_booster().get_dump())
    # `>=` rather than `>`: index `num_t` itself is already out of range.
    if n_tree >= num_t:
        print(
            "The tree index to be extracted is larger than the total number of trees."
        )
        return None
    prediction = np.asarray(
        tree.predict(  # type: ignore
            dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True
        )
    )
    # Decide by rank, not by sample count: the old `len(prediction) > 1` test
    # ran argmax(axis=1) on 1-D regression output and skipped the reduction
    # for a single multiclass row.
    if prediction.ndim > 1:
        prediction = np.argmax(prediction, axis=1)
    return prediction



def tree_encoding(  # pylint: disable=R0914
    trainloader: DataLoader,
    client_trees: Union[
        Tuple[XGBClassifier, int],
        Tuple[XGBRegressor, int],
        List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
    ],
    client_tree_num: int,
    client_num: int,
) -> Optional[Tuple[NDArray, NDArray]]:
    """Encode a tabular batch as per-tree predictions of the clients' ensembles.

    Only the LAST batch yielded by ``trainloader`` is encoded. Column
    ``i * client_tree_num + j`` of the encoding holds
    ``single_tree_prediction(tree_i, j, x)``.

    ``client_trees`` is interpreted in three ways:
      * not a list -> ``client_trees[0]`` is replicated ``client_num`` times;
      * a list whose length differs from ``client_num`` -> the first entry's
        tree is replicated ``client_num`` times;
      * a list of ``client_num`` ``(tree, cid)`` entries -> trees are ordered
        by ascending ``cid``.

    Returns:
        ``None`` if ``trainloader`` is ``None`` or yields no batch. Otherwise a
        pair of float32 torch tensors built with ``torch.from_numpy``: the
        encoding with an extra axis inserted at position 1, and the labels
        with a trailing axis appended.

    Raises:
        ValueError: If a tree cannot provide the requested round (the helper
            returned ``None``); previously this failed while assigning ``None``
            into the float array.
    """
    if trainloader is None:
        print("Trainloader is None")
        return None

    x_train, y_train = None, None
    received_batch = False
    for local_dataset in trainloader:
        x_train, y_train = local_dataset[0], local_dataset[1]
        received_batch = True
    if not received_batch:
        # Same "print and return None" convention as the None case above;
        # previously this path crashed with an UnboundLocalError.
        print("Trainloader is empty")
        return None

    print(f"Train data: {x_train}")
    print(f"Train labels: {y_train}")
    x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))

    temp_trees: Any = None
    if not isinstance(client_trees, list):
        print("isintance if condition - 1")
        temp_trees = [client_trees[0]] * client_num
    elif len(client_trees) != client_num:
        print("isintance if condition - 2")
        temp_trees = [client_trees[0][0]] * client_num
    else:
        print("Else condition")
        trees = [entry[0] for entry in client_trees]  # type: ignore
        cids = [entry[1] for entry in client_trees]  # type: ignore
        sorted_index = np.argsort(np.asarray(cids))
        # Plain list indexing avoids coercing estimator objects into a NumPy array.
        temp_trees = [trees[int(pos)] for pos in sorted_index]
        print(cids)

    for i, client_tree in enumerate(temp_trees):
        for j in range(client_tree_num):
            column = single_tree_prediction(client_tree, j, x_train)
            if column is None:
                raise ValueError(
                    f"Tree {i} cannot provide boosting round {j}; "
                    "client_tree_num exceeds the size of the ensemble."
                )
            x_train_enc[:, i * client_tree_num + j] = column

    x_train_enc32: Any = np.float32(x_train_enc)
    y_train32: Any = np.float32(y_train)

    x_train_enc32, y_train32 = torch.from_numpy(
        np.expand_dims(x_train_enc32, axis=1)  # type: ignore
    ), torch.from_numpy(
        np.expand_dims(y_train32, axis=-1)  # type: ignore
    )
    return x_train_enc32, y_train32

In [4]:
# Learning task used throughout the notebook. The helpers in this notebook
# branch on the values "MULTICLASS" and "REG".
task_type = "MULTICLASS"

In [5]:
import pandas as pd
from preprocessing import preprocess

# Single place to repoint the notebook at another copy of the data.
DATA_DIR = '/media/Data-B/my_research/Geoscience_FL/data_well_log'
train_csv = f'{DATA_DIR}/train.csv'
test_csv = f'{DATA_DIR}/test_with_lables.csv'

LABEL_COLUMN = 'FORCE_2020_LITHOFACIES_LITHOLOGY'

# The training file is read with ';' as separator, the test file with the
# pandas default separator.
train_data = pd.read_csv(train_csv, sep=';')
test_data = pd.read_csv(test_csv)

# Raw lithology code -> contiguous class index (0..11).
lithology_numbers = {30000: 0,
                        65030: 1,
                        65000: 2,
                        80000: 3,
                        74000: 4,
                        70000: 5,
                        70032: 6,
                        88000: 7,
                        86000: 8,
                        99000: 9,
                        90000: 10,
                        93000: 11}

lithology_train = train_data[LABEL_COLUMN].map(lithology_numbers)
lithology_test = test_data[LABEL_COLUMN].map(lithology_numbers)

# `Series.map` yields NaN for any code missing from the mapping; fail loudly
# here instead of letting NaN labels flow into training.
for _split_name, _mapped in (("train", lithology_train), ("test", lithology_test)):
    if _mapped.isna().any():
        raise ValueError(
            f"{int(_mapped.isna().sum())} {_split_name} rows have a missing or "
            f"unmapped value in {LABEL_COLUMN!r}."
        )

# `preprocess` comes from the local `preprocessing` module (not shown here);
# the original note says it was modified for this notebook.
train_dataset = preprocess(train_data)
test_dataset = preprocess(test_data)

Shape of concatenated dataframe before dropping columns: (1170511, 29)
Shape of dataframe after dropping columns: (1170511, 23)
Shape of dataframe after label encoding columns: (1170511, 26)
Dataframe columns: Index(['DEPTH_MD', 'X_LOC', 'Y_LOC', 'Z_LOC', 'CALI', 'RSHA', 'RMED', 'RDEP',
       'RHOB', 'GR', 'NPHI', 'PEF', 'DTC', 'SP', 'BS', 'ROP', 'DCAL', 'DRHO',
       'MUDWEIGHT', 'RMIC', 'GROUP_encoded', 'FORMATION_encoded',
       'WELL_encoded'],
      dtype='object')
Shape of the dataset BEFORE augmentation: (1170511, 23)
Shape of the dataset AFTER augmentation: (1170511, 92)
Shape of concatenated dataframe before dropping columns: (136786, 28)
Shape of dataframe after dropping columns: (136786, 23)
Shape of dataframe after label encoding columns: (136786, 26)
Dataframe columns: Index(['DEPTH_MD', 'X_LOC', 'Y_LOC', 'Z_LOC', 'CALI', 'RSHA', 'RMED', 'RDEP',
       'RHOB', 'GR', 'NPHI', 'PEF', 'DTC', 'SP', 'BS', 'ROP', 'DCAL', 'DRHO',
       'MUDWEIGHT', 'RMIC', 'GROUP_encoded', 'FO

In [6]:
# Show the container types of the preprocessed features and the mapped labels.
print(type(train_dataset), type(lithology_train), sep="\n")

<class 'numpy.ndarray'>
<class 'pandas.core.series.Series'>


In [7]:
# Pull the underlying arrays out of the label Series and show their types.
train_labels, test_labels = (
    series.values for series in (lithology_train, lithology_test)
)
print(type(train_labels), type(test_labels), sep="\n")

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


In [8]:
# Report the feature dimension and the number of rows in each split.
print(
    f"Feature dimension of the dataset: {train_dataset.shape[1]}",
    f"Size of the trainset: {train_dataset.shape[0]}",
    f"Size of the testset: {test_dataset.shape[0]}",
    sep="\n",
)

Feature dimension of the dataset: 92
Size of the trainset: 1170511
Size of the testset: 136786


In [9]:
# Number of distinct label values present in each split.
class_num_train, class_num_test = (
    len(set(split_labels)) for split_labels in (train_labels, test_labels)
)
print(class_num_train, class_num_test, sep="\n")

12
10


In [10]:
class TreeDataset(Dataset):
    """Map-style dataset pairing row ``idx`` of ``data`` with ``labels[idx]``.

    Each sample is a dict with key ``0`` for the feature row and key ``1``
    for the label.
    """

    def __init__(self, data: NDArray, labels: NDArray) -> None:
        self.data, self.labels = data, labels

    def __len__(self) -> int:
        # The dataset length follows the labels, not the feature matrix.
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[int, NDArray]:
        target, features = self.labels[idx], self.data[idx, :]
        return {0: features, 1: target}

In [11]:
# Wrap independent copies of the feature/label arrays in TreeDataset objects.
trainset, testset = (
    TreeDataset(np.array(features, copy=True), np.array(targets, copy=True))
    for features, targets in (
        (train_dataset, train_labels),
        (test_dataset, test_labels),
    )
)

In [12]:
# global_tree = construct_tree(train_dataset, train_labels, n_estimators=100, tree_type='MULTICLASS')
# preds_train = global_tree.predict(train_dataset)
# preds_test = global_tree.predict(test_dataset)

# result_train = accuracy_score(train_labels, preds_train)
# result_test = accuracy_score(test_labels, preds_test)
# print("Global XGBoost Training Accuracy: %f" % (result_train))
# print("Global XGBoost Testing Accuracy: %f" % (result_test))


# print(global_tree)

In [13]:
def get_dataloader(
    dataset: Dataset, partition: str, batch_size: Union[int, str]
) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader.

    ``batch_size == "whole"`` serves the entire dataset as one batch. Batches
    are shuffled only when ``partition == "train"``; memory pinning is always
    requested.
    """
    effective_batch_size = len(dataset) if batch_size == "whole" else batch_size
    shuffle_batches = partition == "train"
    return DataLoader(
        dataset,
        batch_size=effective_batch_size,
        shuffle=shuffle_batches,
        pin_memory=True,
    )


# https://github.com/adap/flower
def do_fl_partitioning(
    trainset: Dataset,
    testset: Dataset,
    pool_size: int,
    batch_size: Union[int, str],
    val_ratio: float = 0.0,
) -> Tuple[List[DataLoader], Optional[List[Optional[DataLoader]]], DataLoader]:
    """Split the training set into ``pool_size`` client partitions.

    The training set is split randomly (generator seeded with 0) into
    ``pool_size`` partitions of equal size; the last one also receives the
    remainder. Each partition is split again (seed 0) into train/validation
    parts according to ``val_ratio``.

    Returns:
        ``(trainloaders, valloaders, testloader)`` where ``trainloaders`` holds
        one loader per client, ``valloaders`` is ``None`` when no partition has
        a validation part, otherwise a list aligned with ``trainloaders``
        (``None`` at positions without a validation part), and ``testloader``
        serves the whole ``testset``.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """
    if pool_size < 1:
        raise ValueError(f"pool_size must be at least 1, got {pool_size}.")

    # Split training set into `pool_size` partitions to simulate different local datasets
    partition_size = len(trainset) // pool_size
    lengths = [partition_size] * pool_size
    # The last partition absorbs whatever the integer division left over.
    lengths[-1] += len(trainset) - partition_size * pool_size
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(0))

    # Split each partition into train/val and create DataLoader
    trainloaders: List[DataLoader] = []
    valloaders: Optional[List[Optional[DataLoader]]] = []
    for ds in datasets:
        len_val = int(len(ds) * val_ratio)
        len_train = len(ds) - len_val
        ds_train, ds_val = random_split(
            ds, [len_train, len_val], torch.Generator().manual_seed(0)
        )
        trainloaders.append(get_dataloader(ds_train, "train", batch_size))
        # Keep one entry per partition so indices stay aligned with trainloaders.
        # Previously `valloaders` was reset to None inside the loop, so a later
        # partition with a validation part crashed on `None.append`.
        valloaders.append(
            get_dataloader(ds_val, "val", batch_size) if len_val != 0 else None
        )
    if all(loader is None for loader in valloaders):
        valloaders = None

    testloader = get_dataloader(testset, "test", batch_size)
    if trainloaders:
        for local_dataset in trainloaders[0]:
             print(f"Trainloaders data shape: {local_dataset[0].shape}, Label shape: {local_dataset[1].shape}")
    if testloader:
        for local_dataset in testloader:
             print(f"Testloaders data shape: {local_dataset[0].shape}, Label shape: {local_dataset[1].shape}")
    return trainloaders, valloaders, testloader

In [14]:
# Number of clients participating in the federated learning simulation.
client_num = 3

# Number of XGBoost trees in the ensemble built for each client: a total
# budget of 600 trees split evenly across clients.
# (A previous run used a budget of 300: client_tree_num = 300 // client_num)
client_tree_num = 600 // client_num

client_trees_comparison = []
# NOTE: despite the singular name, `trainloader` receives the LIST of
# per-client loaders returned by do_fl_partitioning. With val_ratio=0.0 the
# validation entry (discarded as `_`) is None.
trainloader, _, testloader = do_fl_partitioning(
    trainset, testset, pool_size=client_num, batch_size="whole", val_ratio=0.0
)

# Author's note: do_fl_partitioning works as expected here; multitasking
# still needs to be checked.

Trainloaders data shape: torch.Size([390170, 92]), Label shape: torch.Size([390170])
Testloaders data shape: torch.Size([136786, 92]), Label shape: torch.Size([136786])


In [15]:
# client_trees_comparison = []
# trainloaders, _, testloader = do_fl_partitioning(
#     trainset, testset, pool_size=client_num, batch_size="whole", val_ratio=0.0
# )

# for i, trainloader in tqdm(enumerate(trainloaders), desc="Client Progress", total=len(trainloaders)):
#     for local_dataset in trainloader:
#         local_X_train, local_y_train = local_dataset[0], local_dataset[1]
#         tree = construct_tree(train_dataset, train_labels, n_estimators=client_tree_num, tree_type='MULTICLASS')
#         client_trees_comparison.append(tree)

#         preds_train = client_trees_comparison[-1].predict(local_X_train)
#         preds_test = client_trees_comparison[-1].predict(test_dataset)
#         result_train = accuracy_score(local_y_train, preds_train)
#         result_test = accuracy_score(test_labels, preds_test)
#         print("Local Client %d XGBoost Training Accuracy: %f" % (i, result_train))
#         print("Local Client %d XGBoost Testing Accuracy: %f" % (i, result_test))     

In [16]:
import torch.nn.functional as F
# num_classes - None
class CNN(nn.Module):
    """1-D CNN that maps tree-encoded samples to per-class scores.

    The convolution uses ``kernel_size == stride == client_tree_num``, so for
    an input of length ``client_num * client_tree_num`` it produces
    ``client_num`` positions per channel; these are flattened and fed to one
    linear layer with ``num_classes`` outputs.

    ``forward`` returns the raw output of the linear layer. The training and
    evaluation loops in this notebook apply ``log_softmax`` followed by
    ``CrossEntropyLoss`` to the model output, so applying ``Softmax`` here as
    well (as was done before) squashed every score into [0, 1] before the loss
    normalised again. ``self.Softmax`` is kept as an attribute so callers can
    still turn the scores into probabilities explicitly.

    Args:
        n_channel: Number of convolution output channels.
        num_classes: Size of the output layer.
        task_type: Stored on the instance (``"MULTICLASS"`` or ``"REG"``).
        client_tree_num: Trees per client (kernel size and stride). Defaults to
            the notebook-level ``client_tree_num``.
        client_num: Number of clients. Defaults to the notebook-level
            ``client_num``.
    """

    def __init__(
        self,
        n_channel: int = 64,
        num_classes: int = 12,
        task_type: str = "MULTICLASS",
        client_tree_num: Optional[int] = None,
        client_num: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Fall back to the notebook globals so existing `CNN()` calls keep
        # working; explicit arguments make the geometry independent of them.
        if client_tree_num is None:
            client_tree_num = globals()["client_tree_num"]
        if client_num is None:
            client_num = globals()["client_num"]

        n_out = num_classes
        self.task_type = task_type
        self.conv1d = nn.Conv1d(
            1, n_channel, kernel_size=client_tree_num, stride=client_tree_num, padding=0
        )
        self.layer_direct = nn.Linear(n_channel * client_num, n_out)
        self.ReLU = nn.ReLU()
        self.Softmax = nn.Softmax(dim=1)
        self.Identity = nn.Identity()

        # Kaiming initialisation for the linear layer(s) only.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(
                    layer.weight, mode="fan_in", nonlinearity="relu"
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unnormalised output of the linear layer for ``x``."""
        x = self.ReLU(self.conv1d(x))
        # A second ReLU after flattening was a no-op and has been dropped.
        x = x.flatten(start_dim=1)
        return self.layer_direct(x)

    def get_weights(self) -> List[np.ndarray]:
        """Get model weights as a list of NumPy ndarrays."""
        return [
            np.array(val.cpu().numpy(), copy=True)
            for _, val in self.state_dict().items()
        ]

    def set_weights(self, weights: List[np.ndarray]) -> None:
        """Set model weights from a list of NumPy ndarrays.

        Arrays are matched to ``state_dict`` keys by position; zero-dimensional
        entries are skipped. Loading is strict.
        """
        state_dict = OrderedDict(
            (key, torch.Tensor(np.array(value, copy=True)))
            for key, value in zip(self.state_dict().keys(), weights)
            if value.ndim != 0
        )
        self.load_state_dict(state_dict, strict=True)


def train(
    task_type: str,
    net: CNN,
    trainloader: DataLoader,
    device: torch.device,
    num_iterations: int,
    log_progress: bool = True,
) -> Tuple[float, float, int]:
    """Run ``num_iterations`` optimisation steps of ``net`` with Adam.

    Batches are drawn from ``trainloader`` and the loader is restarted whenever
    it is exhausted. For ``"MULTICLASS"`` the network output goes through
    ``log_softmax`` and then ``CrossEntropyLoss``; the reported result is the
    accuracy.

    Args:
        task_type: ``"MULTICLASS"`` or ``"REG"``.
        net: Model to train (switched to training mode).
        trainloader: Iterable of batches indexable as ``batch[0]`` (inputs) and
            ``batch[1]`` (labels).
        device: Device the batches are moved to.
        num_iterations: Number of optimisation steps.
        log_progress: Show a tqdm progress bar with running statistics.

    Returns:
        ``(mean loss per sample, mean result per sample, samples processed)``
        as plain Python numbers; ``(0.0, 0.0, 0)`` if nothing was processed.

    Raises:
        ValueError: If ``task_type`` is not supported.
    """
    # Define loss and optimizer
    if task_type == "MULTICLASS":
        criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for multiclass classification
    elif task_type == "REG":
        criterion = nn.MSELoss()
    else:
        raise ValueError(
            f"Unsupported task_type {task_type!r}; expected 'MULTICLASS' or 'REG'."
        )
    is_classification = task_type == "MULTICLASS"

    optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999))

    def cycle(iterable):
        """Repeats the contents of the train loader, in case it gets exhausted in 'num_iterations'."""
        while True:
            produced_any = False
            for x in iterable:
                produced_any = True
                yield x
            if not produced_any:
                # An empty loader would otherwise spin here forever.
                return

    # Train the network
    net.train()
    total_loss, total_result, n_samples = 0.0, 0.0, 0
    batches = cycle(trainloader)
    pbar = (
        tqdm(batches, total=num_iterations, desc="TRAIN")
        if log_progress
        else batches
    )

    for _, data in zip(range(num_iterations), pbar):
        tree_outputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(tree_outputs)
        if is_classification:
            # reshape(-1) keeps a batch axis even for a single-sample batch,
            # where squeeze() would produce a 0-dim label tensor.
            targets = labels.reshape(-1).to(torch.long)
            predictions = F.log_softmax(outputs, dim=1)
        else:
            # NOTE(review): the regression path previously fed log_softmax
            # output and integer labels to MSELoss; this assumes the network
            # emits one value per sample -- confirm before relying on it.
            targets = labels.reshape(-1, 1).to(outputs.dtype)
            predictions = outputs
        batch_size = targets.size(0)

        optimizer.zero_grad()
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # Collected training loss and accuracy statistics. The criterion
        # returns a batch MEAN, so weight it by the batch size before dividing
        # by the number of samples.
        batch_loss = loss.item()
        total_loss += batch_loss * batch_size
        n_samples += batch_size

        if is_classification:
            # Count of correct predictions; divided by n_samples this is the
            # overall accuracy, independent of any fixed number of classes.
            total_result += (predictions.argmax(dim=1) == targets).sum().item()
        else:
            total_result += batch_loss * batch_size

        if log_progress and is_classification:
            pbar.set_postfix(
                {
                    "train_loss": total_loss / n_samples,
                    "train_acc": total_result / n_samples,
                }
            )

    if log_progress:
        print("\n")

    if n_samples == 0:
        return 0.0, 0.0, 0
    return total_loss / n_samples, total_result / n_samples, n_samples

def test(
    task_type: str,
    net: CNN,
    testloader: DataLoader,
    device: torch.device,
    log_progress: bool = True,
) -> Tuple[float, float, int]:
    """Evaluates the network on test data."""
    # Define loss function for multiclass classification
    if task_type == "MULTICLASS":
        criterion = nn.CrossEntropyLoss()
    
    elif task_type == "REG":
        criterion = nn.MSELoss()

    total_loss, total_result, n_samples = 0.0, 0.0, 0
    net.eval()
    with torch.no_grad():
        print(f"Testloader: {testloader}")
        pbar = tqdm(testloader, desc="TEST") if log_progress else testloader
        for data in pbar:
            tree_outputs, labels = data[0].to(device), data[1].to(device)
            labels = labels.squeeze()
            labels = labels.to(torch.long)
            outputs = net(tree_outputs)
            log_probs = F.log_softmax(outputs, dim=1)

            # Collected testing loss and accuracy statistics
            # Error with test
            total_loss += criterion(log_probs, labels).item()
            n_samples += labels.size(0)

            if task_type == "MULTICLASS":
                class_predictions = torch.argmax(log_probs, dim=1)
                acc = Accuracy(task="multiclass", num_classes=12)(
                    class_predictions.cpu().int(), labels.type(torch.int).cpu()
                )
                total_result += acc * labels.size(0)
                
    if log_progress:
        print("\n")

    return total_loss / n_samples, total_result / n_samples, n_samples

Trainloaders data shape: torch.Size([390170, 92]), Label shape: torch.Size([390170])

Testloaders data shape: torch.Size([136786, 92]), Label shape: torch.Size([136786])

In [17]:
# Flower client
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetPropertiesIns,
    GetPropertiesRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    Code,
    parameters_to_ndarrays,
    ndarrays_to_parameters,
)

In [18]:
def tree_encoding_loader(
    dataloader: DataLoader,
    batch_size: int,
    client_trees: Union[
        Tuple[XGBClassifier, int],
        Tuple[XGBRegressor, int],
        List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
    ],
    client_tree_num: int,
    client_num: int,
) -> DataLoader:
    """Tree-encode the data behind ``dataloader`` and serve it in batches.

    Delegates the encoding to ``tree_encoding``; when that returns ``None`` a
    message is printed and ``None`` is returned instead of a loader. Otherwise
    the encoded data is wrapped in a ``TreeDataset`` and handed to
    ``get_dataloader`` with partition name ``"tree"``.
    """
    encoded = tree_encoding(dataloader, client_trees, client_tree_num, client_num)
    if encoded is not None:
        features, targets = encoded
        tree_dataset = TreeDataset(features, targets)
        # Diagnostic output kept from the original implementation.
        print(f"client trees: {client_trees}")
        print(f"tree dataset: {tree_dataset}")
        return get_dataloader(tree_dataset, "tree", batch_size)

    print("Encoding is None")
    return None

class FL_Client(fl.client.Client):
    """Flower client holding a local XGBoost ensemble and a CNN.

    The ensemble is fitted once at construction time. In every ``fit`` call
    the client re-encodes its data with the trees received from the server and
    trains the CNN on that encoding.
    """

    # Name of the task-specific metric reported next to the loss.
    _METRIC_NAMES = {"MULTICLASS": "accuracy", "REG": "mse"}

    def __init__(
        self,
        task_type: str,
        trainloader: DataLoader,
        valloader: DataLoader,
        client_tree_num: int,
        client_num: int,
        num_classes: int,
        cid: str,
        log_progress: bool = False,
    ):
        """
        Creates a client for training `network.Net` on tabular dataset.
        """
        self.task_type, self.num_classes, self.cid = task_type, num_classes, cid
        self.client_tree_num, self.client_num = client_tree_num, client_num
        self.log_progress = log_progress
        self.properties = {"tensor_type": "numpy.ndarray"}

        # Raw loaders as received; the tree-encoded ones are created in `fit`.
        self.trainloader_original, self.valloader_original = trainloader, valloader
        self.trainloader, self.valloader = None, None

        # Local ensemble, fitted on this client's own training data.
        self.tree = construct_tree_from_loader(trainloader, client_tree_num, task_type)

        # NOTE: the network is built with CNN's defaults; `num_classes` and
        # `task_type` are stored on the client but not forwarded to it.
        self.net = CNN()

        # Prefer the first CUDA device when one is available.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _encode_loader(self, loader, batch_size, trees):
        """Tree-encode ``loader`` with this client's ensemble geometry."""
        return tree_encoding_loader(
            loader, batch_size, trees, self.client_tree_num, self.client_num
        )

    def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
        """Return the static client properties."""
        return GetPropertiesRes(properties=self.properties)

    def get_parameters(
        self, ins: GetParametersIns
    ) -> Tuple[
        GetParametersRes, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]
    ]:
        """Return ``[CNN weights response, (local tree, int(cid))]``."""
        nn_response = GetParametersRes(
            status=Status(Code.OK, ""),
            parameters=ndarrays_to_parameters(self.net.get_weights()),
        )
        return [nn_response, (self.tree, int(self.cid))]

    def set_parameters(
        self,
        parameters: Tuple[
            Parameters,
            Union[
                Tuple[XGBClassifier, int],
                Tuple[XGBRegressor, int],
                List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
            ],
        ],
    ) -> Union[
        Tuple[XGBClassifier, int],
        Tuple[XGBRegressor, int],
        List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
    ]:
        """Load the CNN weights from ``parameters[0]`` and return ``parameters[1]``."""
        nn_weights = parameters_to_ndarrays(parameters[0])
        self.net.set_weights(nn_weights)
        return parameters[1]

    def fit(self, fit_params: FitIns) -> FitRes:
        """Encode the local data with the received trees and train the CNN."""
        # Process incoming request to train
        config = fit_params.config
        num_iterations = config["num_iterations"]  # read but overridden below
        batch_size = config["batch_size"]
        aggregated_trees = self.set_parameters(fit_params.parameters)

        if type(aggregated_trees) is list:
            print("Client " + self.cid + ": recieved", len(aggregated_trees), "trees")
        else:
            print("Client " + self.cid + ": only had its own tree")

        self.trainloader = self._encode_loader(
            self.trainloader_original, batch_size, aggregated_trees
        )
        self.valloader = self._encode_loader(
            self.valloader_original, batch_size, aggregated_trees
        )

        # Always train for one pass over the encoded loader, whatever the
        # configured number of iterations was.
        num_iterations = len(self.trainloader)

        # Train the model
        print(f"Client {self.cid}: training for {num_iterations} iterations/updates")
        self.net.to(self.device)
        train_loss, train_result, num_examples = train(
            self.task_type,
            self.net,
            self.trainloader,
            device=self.device,
            num_iterations=num_iterations,
            log_progress=self.log_progress,
        )
        print(
            f"Client {self.cid}: training round complete, {num_examples} examples processed"
        )

        metric_name = self._METRIC_NAMES.get(self.task_type)
        if metric_name is None:
            # Unknown task types produce no result, as before.
            return None
        return FitRes(
            status=Status(Code.OK, ""),
            parameters=self.get_parameters(fit_params.config),
            num_examples=num_examples,
            metrics={"loss": train_loss, metric_name: train_result},
        )

    def evaluate(self, eval_params: EvaluateIns) -> EvaluateRes:
        """Evaluate the received CNN weights on the encoded validation loader."""
        # Process incoming request to evaluate
        self.set_parameters(eval_params.parameters)

        # Evaluate the model
        self.net.to(self.device)
        loss, result, num_examples = test(
            self.task_type,
            self.net,
            self.valloader,
            device=self.device,
            log_progress=self.log_progress,
        )

        # Return evaluation information
        metric_name = self._METRIC_NAMES.get(self.task_type)
        if metric_name is None:
            return None
        print(
            f"Client {self.cid}: evaluation on {num_examples} examples: loss={loss:.4f}, {metric_name}={result:.4f}"
        )
        return EvaluateRes(
            status=Status(Code.OK, ""),
            loss=loss,
            num_examples=num_examples,
            metrics={metric_name: result},
        )

# Server

In [19]:
# Flower server
import functools
from flwr.server.strategy import FedXgbNnAvg
from flwr.server.app import ServerConfig

import timeit
from logging import DEBUG, INFO
from typing import Dict, List, Optional, Tuple, Union

from flwr.common import DisconnectRes, Parameters, ReconnectIns, Scalar
from flwr.common.logger import log
from flwr.common.typing import GetParametersIns
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.history import History
from flwr.server.strategy import Strategy
from flwr.server.server import (
    reconnect_clients,
    reconnect_client,
    fit_clients,
    fit_client,
    _handle_finished_future_after_fit,
    evaluate_clients,
    evaluate_client,
    _handle_finished_future_after_evaluate,
)

# Result of one round of client `fit` calls: (successful results, failures).
# A failure is either a (client, result) pair or the raised exception.
FitResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, FitRes]],
    List[Union[Tuple[ClientProxy, FitRes], BaseException]],
]
# Same structure for one round of client `evaluate` calls.
EvaluateResultsAndFailures = Tuple[
    List[Tuple[ClientProxy, EvaluateRes]],
    List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
]

In [20]:
class FL_Server(fl.server.Server):
    """Flower server whose global state is ``[NN parameters, trees]``.

    Compared with a plain parameter-only server, ``fit_round`` and
    ``_get_initial_parameters`` carry the clients' XGBoost trees next to the
    neural-network parameters.
    """

    def __init__(
        self, *, client_manager: ClientManager, strategy: Optional[Strategy] = None
    ) -> None:
        # The base-class initializer is not called; the attributes used by
        # this class are set directly.
        # NOTE(review): `strategy=None` is stored as-is and every round calls
        # methods on it -- callers must pass a strategy.
        self._client_manager: ClientManager = client_manager
        self.parameters: Parameters = Parameters(
            tensors=[], tensor_type="numpy.ndarray"
        )
        self.strategy: Strategy = strategy
        self.max_workers: Optional[int] = None

    # pylint: disable=too-many-locals
    def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
        """Run federated averaging for a number of rounds.

        Each round trains on the clients, then runs the strategy's centralized
        evaluation and a federated evaluation, recording everything in the
        returned ``History``.
        """
        history = History()

        # Initialize parameters
        log(INFO, "Initializing global parameters")
        self.parameters = self._get_initial_parameters(timeout=timeout)

        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(0, parameters=self.parameters)
        if res is not None:
            log(
                INFO,
                "initial parameters (loss, other metrics): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(server_round=0, loss=res[0])
            history.add_metrics_centralized(server_round=0, metrics=res[1])

        # Run federated learning for num_rounds
        log(INFO, "FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            res_fit = self.fit_round(server_round=current_round, timeout=timeout)
            if res_fit:
                parameters_prime, _, _ = res_fit  # fit_metrics_aggregated
                # Keep the previous global state when the round produced none.
                if parameters_prime:
                    self.parameters = parameters_prime

            # Evaluate model using strategy implementation
            res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)
            if res_cen is not None:
                loss_cen, metrics_cen = res_cen
                log(
                    INFO,
                    "fit progress: (%s, %s, %s, %s)",
                    current_round,
                    loss_cen,
                    metrics_cen,
                    timeit.default_timer() - start_time,
                )
                history.add_loss_centralized(server_round=current_round, loss=loss_cen)
                history.add_metrics_centralized(
                    server_round=current_round, metrics=metrics_cen
                )

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)
            if res_fed:
                loss_fed, evaluate_metrics_fed, _ = res_fed
                # Compare against None: a loss of exactly 0.0 is a valid value
                # and was previously dropped by the truthiness check.
                if loss_fed is not None:
                    history.add_loss_distributed(
                        server_round=current_round, loss=loss_fed
                    )
                    history.add_metrics_distributed(
                        server_round=current_round, metrics=evaluate_metrics_fed
                    )

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "FL finished in %s", elapsed)
        return history

    def evaluate_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]
    ]:
        """Validate current global model on a number of clients.

        Returns ``None`` when the strategy selects no client, otherwise
        ``(aggregated loss, aggregated metrics, (results, failures))``.
        """

        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_evaluate(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )
        if not client_instructions:
            log(INFO, "evaluate_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "evaluate_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `evaluate` results from all clients participating in this round
        results, failures = evaluate_clients(
            client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
        )
        log(
            DEBUG,
            "evaluate_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        # Aggregate the evaluation results
        aggregated_result: Tuple[
            Optional[float],
            Dict[str, Scalar],
        ] = self.strategy.aggregate_evaluate(server_round, results, failures)

        loss_aggregated, metrics_aggregated = aggregated_result
        return loss_aggregated, metrics_aggregated, (results, failures)

    def fit_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[
            Optional[
                Tuple[
                    Parameters,
                    Union[
                        Tuple[XGBClassifier, int],
                        Tuple[XGBRegressor, int],
                        List[
                            Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]
                        ],
                    ],
                ]
            ],
            Dict[str, Scalar],
            FitResultsAndFailures,
        ]
    ]:
        """Perform a single round of federated averaging.

        Returns ``None`` when the strategy selects no client. Otherwise returns
        ``(state, aggregated metrics, (results, failures))`` where ``state`` is
        ``[NN parameters, trees]``, or ``None`` if the strategy produced no
        aggregate for this round.
        """

        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )

        if not client_instructions:
            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "fit_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
        )

        log(
            DEBUG,
            "fit_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        # Aggregate training results
        NN_aggregated: Parameters
        trees_aggregated: Union[
            Tuple[XGBClassifier, int],
            Tuple[XGBRegressor, int],
            List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
        ]
        metrics_aggregated: Dict[str, Scalar]
        aggregated, metrics_aggregated = self.strategy.aggregate_fit(
            server_round, results, failures
        )
        if aggregated is None:
            # The strategy may return no aggregate; previously this crashed on
            # `aggregated[0]`. `fit` keeps the old state when it sees None.
            log(
                INFO,
                "fit_round %s: strategy returned no aggregated parameters",
                server_round,
            )
            return None, metrics_aggregated, (results, failures)
        NN_aggregated, trees_aggregated = aggregated[0], aggregated[1]

        if type(trees_aggregated) is list:
            print("Server side aggregated", len(trees_aggregated), "trees.")
        else:
            print("Server side did not aggregate trees.")

        return (
            [NN_aggregated, trees_aggregated],
            metrics_aggregated,
            (results, failures),
        )

    def _get_initial_parameters(
        self, timeout: Optional[float]
    ) -> Tuple[Parameters, Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]]:
        """Get initial parameters from one of the available clients.

        If the strategy provides initial parameters they are returned
        unchanged; otherwise one sampled client is asked and the result is
        ``[NN parameters, (tree, cid)]``.
        """

        # Server-side parameter initialization
        parameters: Optional[Parameters] = self.strategy.initialize_parameters(
            client_manager=self._client_manager
        )
        if parameters is not None:
            # NOTE(review): this value is whatever the strategy returns, not
            # the [parameters, trees] pair built below -- confirm consumers
            # handle both shapes.
            log(INFO, "Using initial parameters provided by strategy")
            return parameters

        # Get initial parameters from one of the clients
        log(INFO, "Requesting initial parameters from one random client")
        random_client = self._client_manager.sample(1)[0]
        ins = GetParametersIns(config={})
        get_parameters_res_tree = random_client.get_parameters(ins=ins, timeout=timeout)
        parameters = [get_parameters_res_tree[0].parameters, get_parameters_res_tree[1]]
        log(INFO, "Received initial parameters from one random client")

        return parameters

In [21]:
def print_model_layers(model: nn.Module) -> None:
    """Print a model's repr followed by the size of every state-dict entry.

    Args:
        model: The module to describe.

    Each state-dict entry is printed on its own line as its key, a tab
    separator, and the ``size()`` of the stored tensor.
    """
    print(model)
    # Build the state dict a single time and walk its items, rather than
    # rebuilding it for every entry just to look one tensor up.
    for entry_name, tensor in model.state_dict().items():
        print(entry_name, "\t", tensor.size())


def serverside_eval(
    server_round: int,
    parameters: Tuple[
        Parameters,
        Union[
            Tuple[XGBClassifier, int],
            Tuple[XGBRegressor, int],
            List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
        ],
    ],
    config: Dict[str, Scalar],
    task_type: str,
    testloader: DataLoader,
    batch_size: int,
    client_tree_num: int,
    client_num: int,
) -> Tuple[float, Dict[str, float]]:
    """Centralized (server-side) evaluation of the aggregated model.

    A fresh ``CNN`` is loaded with the weights in ``parameters[0]``, the test
    loader is re-encoded through ``tree_encoding_loader`` using the trees in
    ``parameters[1]``, and the result is scored with ``test`` on the CPU.

    Args:
        server_round: Accepted for the evaluation-callback signature; not used here.
        parameters: Two-element sequence: network parameters first, aggregated
            trees second.
        config: Accepted for the evaluation-callback signature; not used here.
        task_type: ``"MULTICLASS"`` or ``"REG"``; selects the reported metric name.
        testloader: Loader over the test set, before tree encoding.
        batch_size: Forwarded to ``tree_encoding_loader``.
        client_tree_num: Forwarded to ``tree_encoding_loader``.
        client_num: Forwarded to ``tree_encoding_loader``.

    Returns:
        ``(loss, {"accuracy": result})`` for ``"MULTICLASS"``,
        ``(loss, {"mse": result})`` for ``"REG"``; any other ``task_type`` falls
        through and yields ``None``.
    """
    cpu_device = "cpu"  # evaluation is pinned to the CPU

    network = CNN()
    network.set_weights(parameters_to_ndarrays(parameters[0]))
    network.to(cpu_device)

    # Re-encode the raw test batches with the aggregated trees before scoring.
    encoded_loader = tree_encoding_loader(
        testloader, batch_size, parameters[1], client_tree_num, client_num
    )
    loss, result, _ = test(
        task_type, network, encoded_loader, device=cpu_device, log_progress=False
    )

    if task_type == "MULTICLASS":
        print(
            f"Evaluation on the server: test_loss={loss:.4f}, test_accuracy={result:.4f}"
        )
        return loss, {"accuracy": result}

    if task_type == "REG":
        print(f"Evaluation on the server: test_loss={loss:.4f}, test_mse={result:.4f}")
        return loss, {"mse": result}


def start_experiment(
    task_type: str,
    trainset: Dataset,
    testset: Dataset,
    num_rounds: int = 5,
    client_tree_num: int = 50,
    client_pool_size: int = 5,
    num_iterations: int = 100,
    fraction_fit: float = 1.0,
    min_fit_clients: int = 2,
    batch_size: int = 32,
    val_ratio: float = 0.1,
    num_classes: int = 12,
) -> History:
    """Partition the data, configure the strategy and run the FL simulation.

    Args:
        task_type: Task identifier forwarded to the clients and to the
            server-side evaluation (e.g. ``"MULTICLASS"`` or ``"REG"``).
        trainset: Training dataset, split across ``client_pool_size`` clients.
        testset: Test dataset used for server-side evaluation.
        num_rounds: Number of federated rounds.
        client_tree_num: Forwarded to every client and to the server-side
            evaluation.
        client_pool_size: Number of simulated clients. Also passed to the
            server-side evaluation as its client count.
        num_iterations: Sent to clients in the per-round fit config.
        fraction_fit: Fraction of clients sampled for fitting; reused as the
            evaluation fraction when local validation is enabled.
        min_fit_clients: Minimum number of clients for fit and for evaluate.
        batch_size: Sent to clients in the fit/evaluate configs and used by the
            server-side evaluation. Partitioning itself is called with
            ``batch_size="whole"``.
        val_ratio: Share of each client's local data reserved for validation.
            Validation is enabled only when ``0.0 < val_ratio <= 1.0``.
        num_classes: Number of classes passed to each client.

    Returns:
        The ``History`` returned by the simulation.
    """
    client_resources = {"num_cpus": 0.5}  # 2 clients per CPU

    # Single source of truth for "do clients hold a local validation split?",
    # shared by the strategy's evaluation fraction and by client construction.
    has_validation = val_ratio is not None and 0.0 < val_ratio <= 1.0

    # Partition the dataset into subsets reserved for each client.
    # 'val_ratio' controls the proportion of each client's local data that is
    # held out as a local test set.
    trainloaders, valloaders, testloader = do_fl_partitioning(
        trainset,
        testset,
        batch_size="whole",
        pool_size=client_pool_size,
        val_ratio=val_ratio,
    )
    print(
        f"Data partitioned across {client_pool_size} clients"
        f" and {val_ratio} of local dataset reserved for validation."
    )

    # Configure the strategy
    def fit_config(server_round: int) -> Dict[str, Scalar]:
        """Per-round configuration sent to clients before fitting."""
        print(f"Configuring round {server_round}")
        return {
            "num_iterations": num_iterations,
            "batch_size": batch_size,
        }

    # FedXgbNnAvg
    strategy = FedXgbNnAvg(
        fraction_fit=fraction_fit,
        fraction_evaluate=fraction_fit if has_validation else 0.0,
        min_fit_clients=min_fit_clients,
        min_evaluate_clients=min_fit_clients,
        min_available_clients=client_pool_size,  # all clients should be available
        on_fit_config_fn=fit_config,
        on_evaluate_config_fn=(lambda r: {"batch_size": batch_size}),
        evaluate_fn=functools.partial(
            serverside_eval,
            task_type=task_type,
            testloader=testloader,
            batch_size=batch_size,
            client_tree_num=client_tree_num,
            # Use this function's own argument rather than a notebook-level
            # global, so the experiment does not depend on hidden kernel state.
            client_num=client_pool_size,
        ),
        accept_failures=False,
    )

    print(
        f"FL experiment configured for {num_rounds} rounds with {client_pool_size} client in the pool."
    )
    print(
        f"FL round will proceed with {fraction_fit * 100}% of clients sampled, at least {min_fit_clients}."
    )

    def client_fn(cid) -> fl.client.Client:
        """Creates a federated learning client"""
        if has_validation:
            print("Fl with validation")
            valloader = valloaders[int(cid)]
        else:
            print("Fl without validation")
            valloader = None
        return FL_Client(
            task_type,
            trainloaders[int(cid)],
            valloader,
            client_tree_num,
            client_pool_size,
            num_classes=num_classes,
            cid=cid,
            log_progress=False,
        )

    # Start the simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        server=FL_Server(client_manager=SimpleClientManager(), strategy=strategy),
        num_clients=client_pool_size,
        client_resources=client_resources,
        config=ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )

    print(history)

    return history

# Experiment

In [22]:
# Launch the multiclass run. The call is left as the cell's final expression so
# the returned History is also shown as the cell output.
start_experiment(
    "MULTICLASS",
    trainset,
    testset,
    # federation layout
    client_pool_size=client_num,
    client_tree_num=client_tree_num,
    fraction_fit=1.0,
    min_fit_clients=1,
    # training schedule
    num_rounds=10,
    num_iterations=150,
    batch_size=64,
)

Trainloaders data shape: torch.Size([351153, 92]), Label shape: torch.Size([351153])
Testloaders data shape: torch.Size([136786, 92]), Label shape: torch.Size([136786])
Data partitioned across 3 clients and 0.1 of local dataset reserved for validation.
FL experiment configured for 10 rounds with 3 client in the pool.
FL round will proceed with 100.0% of clients sampled, at least 1.


2023-08-25 11:11:29,192	INFO worker.py:1621 -- Started a local Ray instance.


[2m[36m(launch_and_get_parameters pid=3998775)[0m Fl with validation


[2m[36m(launch_and_get_parameters pid=3998775)[0m   return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


[2m[36m(launch_and_get_parameters pid=3998775)[0m Data: tensor([[1.0441e+03, 5.7263e+05, 6.7386e+06,  ..., 0.0000e+00, 0.0000e+00,
[2m[36m(launch_and_get_parameters pid=3998775)[0m          0.0000e+00],
[2m[36m(launch_and_get_parameters pid=3998775)[0m         [2.3925e+03, 4.7589e+05, 6.6216e+06,  ..., 0.0000e+00, 0.0000e+00,
[2m[36m(launch_and_get_parameters pid=3998775)[0m          0.0000e+00],
[2m[36m(launch_and_get_parameters pid=3998775)[0m         [2.3008e+03, 4.7590e+05, 6.5931e+06,  ..., 0.0000e+00, 0.0000e+00,
[2m[36m(launch_and_get_parameters pid=3998775)[0m          0.0000e+00],
[2m[36m(launch_and_get_parameters pid=3998775)[0m         ...,
[2m[36m(launch_and_get_parameters pid=3998775)[0m         [2.9869e+03, 4.3778e+05, 6.7916e+06,  ..., 0.0000e+00, 0.0000e+00,
[2m[36m(launch_and_get_parameters pid=3998775)[0m          0.0000e+00],
[2m[36m(launch_and_get_parameters pid=3998775)[0m         [3.6342e+03, 4.5786e+05, 6.7772e+06,  ..., 0.0000e+00, 