# Federated ID3 using the FLEX library


In this notebook we show how to use the *Federated ID3* model described in this [paper](https://arxiv.org/pdf/2007.10987.pdf).

First, we import everything we need.

In [1]:
import numpy as np

from flex.data import FedDataDistribution, FedDatasetConfig
from flex.pool import FlexPool

from flextrees.datasets.tabular_datasets import nursery
from flextrees.pool.primitives_fedid3 import (
    init_server_model_id3, 
    deploy_server_config_id3,
    deploy_server_model_id3,
    build_id3,
    set_aggregated_id3,
    evaluate_id3_model,
    evaluate_global_model_clients,
)

## Loading the data using FLEX.

In this tutorial we use the **nursery** dataset, which can be imported directly from the flextrees library. In this model the server needs to know the unique values of every feature in the dataset being used, so after loading the data we collect them.

In [2]:
train_data, test_data, features_names = nursery(ret_feature_names=True, categorical=True)

# Collect the distinct values of every feature; the last entry of
# features_names is skipped, as in the rest of the notebook.
X = train_data.X_data.to_numpy()
unique_values = [list(set(X[:, i])) for i in range(len(features_names) - 1)]
n_clients = 2
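
To make the step above concrete, here is a minimal sketch on a hand-made toy table. The rows, the feature names and the helper `unique_values_per_feature` are hypothetical and only illustrate the idea; they are not part of flextrees.

```python
# Toy categorical table: every row is (parents, health, label).
# The last column plays the role of the label, so it is skipped.
feature_names = ["parents", "health", "label"]
rows = [
    ("usual", "recommended", "priority"),
    ("pretentious", "priority", "priority"),
    ("usual", "not_recom", "not_recom"),
]


def unique_values_per_feature(rows, feature_names):
    """Map every non-label feature to the sorted list of its distinct values."""
    return {
        name: sorted({row[i] for row in rows})
        for i, name in enumerate(feature_names[:-1])
    }


value_features = unique_values_per_feature(rows, feature_names)
print(value_features)
```

Sorting is optional. It is used here only because the iteration order of a set of strings can change between Python runs, and a sorted list keeps the result reproducible.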

## Federating the data using FLEX

Once the data is loaded, we have to federate it. To do so we use the FLEX library. We show two ways of federating the data: an IID distribution and a non-IID distribution. For the IID distribution we can simply use the `iid_distribution` function from `FedDataDistribution`. For a non-IID distribution we have to use a custom configuration; in this case we set the seed, the number of clients, and the weights, which can be generated randomly or chosen in any other way the user wants. For more information, see the FLEX library notebooks, in particular *Federating data with FLEXible*.

In [3]:
dist = 'iid'

if dist == 'iid':
    federated_data = FedDataDistribution.iid_distribution(centralized_data=train_data,
                                                        n_nodes=n_clients)
else:
    weights = np.random.dirichlet(np.repeat(1, n_clients), 1)[0] # To generate random weights (Full Non-IID)
    config_nidd = FedDatasetConfig(seed=0, n_nodes=n_clients, 
                                replacement=False, weights=weights)

    federated_data = FedDataDistribution.from_config(centralized_data=train_data,
                                                        config=config_nidd)
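
As a rough illustration of what randomly generated weights express, the sketch below samples Dirichlet weights and uses them to cut a list of sample indices into uneven parts. It assumes that the weights act as each client's share of the data, and it uses only the standard library (normalized Gamma draws instead of `np.random.dirichlet`). The helpers `dirichlet_weights` and `split_by_weights` are hypothetical and are not how FLEX implements the split.

```python
import random


def dirichlet_weights(n_clients, alpha=1.0, seed=0):
    """Sample Dirichlet(alpha, ..., alpha) weights by normalizing Gamma draws."""
    rng = random.Random(seed)
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n_clients)]
    total = sum(draws)
    return [d / total for d in draws]


def split_by_weights(indices, weights, seed=0):
    """Shuffle the indices and cut them into consecutive chunks sized by weights."""
    rng = random.Random(seed)
    shuffled = list(indices)
    rng.shuffle(shuffled)
    chunks, start = [], 0
    for k, w in enumerate(weights):
        # The last client takes the remainder so no sample is lost to rounding.
        is_last = k == len(weights) - 1
        end = len(shuffled) if is_last else start + int(w * len(shuffled))
        chunks.append(shuffled[start:end])
        start = end
    return chunks


weights = dirichlet_weights(n_clients=3)
parts = split_by_weights(range(100), weights)
print([len(p) for p in parts])
```

The weights always sum to 1, and every sample ends up with exactly one client, which corresponds to `replacement=False` in the configuration above.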

## Creating the federated architecture

When creating the federated architecture, we use `FlexPool`. Since we are running a client-server architecture, we call `FlexPool.client_server_pool` with the federated data and the server initialization function `init_server_model_id3`. We also pass the feature names of the dataset through `dataset_features`, so that they are available when the server model is initialized.

In [4]:
pool = FlexPool.client_server_pool(federated_data, init_server_model_id3,
                                        dataset_features = features_names)

clients = pool.clients
aggregator = pool.aggregators
server = pool.servers

Lastly, we deploy the server configuration to all the clients and map each feature name to its unique values; `build_id3` needs this mapping to train the model.

In [5]:
# Deploy clients config
pool.servers.map(func=deploy_server_config_id3, dst_pool=pool.clients)
root_ = None
value_features = {
    feature: unique_values[i]
    for i, feature in enumerate(features_names[:-1])
}

## Training the model

As the model is built recursively, we provide a primitive function, `build_id3`, that builds the tree. This function needs to receive the root initialized to *None*, the maximum depth of the tree (defined as *n_features/2* in the paper and set here to `len(features_names) // 2`), and the rest of the parameters needed to build the tree. Note that we also have to pass the *pool* to the function, so that the tree can be built in a federated way.

In [6]:
# Build the ID3 tree
root_ = build_id3(
    node=root_,
    depth=1,
    available_features=features_names[:-1],
    pool=pool,
    max_depth=len(features_names) // 2,
    values_features=value_features,
)
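
ID3 picks, at every node, the feature with the highest information gain. The sketch below shows one way this choice could be made in a federated setting: each client only reports counts of (feature value, label) pairs, and the server sums them before computing the gain. This is an illustrative assumption, not the code inside `build_id3`; the helper names and the toy client data are hypothetical.

```python
from collections import Counter
from math import log2


def local_counts(rows, feature):
    """Client side: count (feature value, label) pairs; rows are (features, label)."""
    return Counter((x[feature], y) for x, y in rows)


def entropy(label_counts):
    """Shannon entropy, in bits, of a label histogram."""
    total = sum(label_counts.values())
    return -sum(c / total * log2(c / total) for c in label_counts.values() if c)


def information_gain(counts):
    """Server side: gain of a split computed from summed (value, label) counts."""
    labels, by_value = Counter(), {}
    for (value, label), c in counts.items():
        labels[label] += c
        by_value.setdefault(value, Counter())[label] += c
    total = sum(labels.values())
    remainder = sum(sum(h.values()) / total * entropy(h) for h in by_value.values())
    return entropy(labels) - remainder


# Two toy clients; the raw rows never leave them, only the counts do.
client_a = [({"health": "ok", "form": "x"}, "yes"), ({"health": "bad", "form": "x"}, "no")]
client_b = [({"health": "ok", "form": "y"}, "yes"), ({"health": "bad", "form": "y"}, "no")]

gains = {}
for feature in ("health", "form"):
    aggregated = local_counts(client_a, feature) + local_counts(client_b, feature)
    gains[feature] = information_gain(aggregated)

best = max(gains, key=gains.get)
print(gains, best)
```

In this toy data `health` separates the labels perfectly while `form` carries no information, so the server would split on `health`.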

### Deploying the model

Deploy the model to the clients so they can use it on their local data.

In [7]:
pool.aggregators._models['server']['aggregated_weights'] = root_
pool.aggregators.map(set_aggregated_id3, pool.servers)
pool.servers.map(deploy_server_model_id3, pool.clients)

Aggregated weights: <flextrees.utils.utils_trees.Node object at 0x7f5a5440cdf0>


## Evaluating the model

Evaluate the model on the clients' side.

In [8]:
import os
from datetime import datetime

run_id = 0  # assumed run index; the original f-string referenced the builtin `exec` here
timestamp = datetime.now().strftime('%Y_%m_%d-%I_%M_%S_%p')
results_dir = os.path.join(os.path.abspath('../'), 'results_fedid3')
os.makedirs(results_dir, exist_ok=True)  # make sure the output folder exists
filename = os.path.join(
    results_dir,
    f"resultados_{nursery.__name__}_clients_{n_clients}_{dist}_exec_{run_id}_{timestamp}.csv",
)
pool.clients.map(evaluate_global_model_clients, filename=filename)

Results on test data at client level.
Accuracy: 0.8890817901234568
Macro F1: 0.5399280046144798
Classificarion report: 
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      1708
           1       0.85      0.80      0.83      1703
           2       0.00      0.00      0.00         1
           3       0.82      0.94      0.87      1642
           4       0.00      0.00      0.00       130

    accuracy                           0.89      5184
   macro avg       0.53      0.55      0.54      5184
weighted avg       0.87      0.89      0.88      5184

Results on test data at client level.
Accuracy: 0.8939043209876543
Macro F1: 0.5427441555312512
Classificarion report: 
               precision    recall  f1-score   support

           0       1.00      1.00      1.00      1740
           1       0.86      0.81      0.83      1694
           2       0.00      0.00      0.00         1
           3       0.83      0.94      0.88      161



Evaluate the model on the global test data.

In [None]:
pool.servers.map(evaluate_id3_model, test_data=test_data, filename=filename, etime=0)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       872
           1       0.86      0.82      0.84       869
           3       0.82      0.93      0.87       789
           4       0.00      0.00      0.00        62

    accuracy                           0.89      2592
   macro avg       0.67      0.69      0.68      2592
weighted avg       0.88      0.89      0.88      2592

Accuracy: 0.8946759259259259
F1-Macro: 0.6783916731079667
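
The gap between the accuracy and the macro F1 in the global report above comes from how the averages are defined: the macro average gives every class the same weight, so the class with an F1 of 0.00 pulls it down, while the weighted average scales each class by its support. The following sketch recomputes both from the rounded per-class values printed in the report, so it only matches the reported numbers up to rounding.

```python
# Per-class F1 and support copied from the global classification report
# above (values are rounded to two decimals in the report).
per_class_f1 = {0: 1.00, 1: 0.84, 3: 0.87, 4: 0.00}
support = {0: 872, 1: 869, 3: 789, 4: 62}

# Macro average: plain mean over classes, regardless of class size.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)

# Weighted average: each class contributes in proportion to its support.
weighted_f1 = sum(per_class_f1[c] * support[c] for c in support) / sum(support.values())

print(f"macro F1 ~ {macro_f1:.2f}, weighted F1 ~ {weighted_f1:.2f}")
```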




# End of Notebook