# Purpose
This notebook demonstrates how to use the ready-made client for the Federated Model Aggregation (FMA) Service.
It does so through the `WebClient` class of the `fma_connect` library.

The `WebClient` class provides methods that abstract interaction with the service and keep it consistent across different implementations.

In this tutorial you will learn:
* The basic usage of the `WebClient` for the FMA service
* The process of federated learning outlined by the FMA
* Pulling the latest model weights from a specific experiment
* Training a model on the client
* Sending updates to the FMA service with correctly formatted API calls
* Pulling a model aggregate after it has been created for an experiment and repeating the training process

The FMA service's API can be used with any model and any client, as long as the client is registered and can reach the service with an API call.
This tutorial uses the DataLabeler model from Capital One's DataProfiler library.

# Overview of FMA Client Components
The service does not initialize clients itself; it assumes that the clients taking part in your experiments are set up separately.
The client needs certain components to run model training and send model updates to the service.
These components include:
* Client Initialization
  * Connect to Server
  * Client Validation Dataset Generation
* Model Initialization
  * Setup Architecture
  * Pull Model Weights
* Model Training
  * Training loop

# Imports

In [None]:
import os
import sys
import time
import sklearn.metrics  # import the submodule explicitly; a bare "import sklearn" may not load it
import dataprofiler as dp

sys.path.insert(0, os.path.abspath('../../../../../clients/python_client'))
import fma_connect  # noqa: E402

sys.path.insert(0, os.path.abspath('../..'))
from utils import intialize_logger, numpify_array, jsonify_array, print_results, color_text

sys.path.insert(0, os.path.abspath('../'))
from dataprofiler_utils.generation_scripts import generate_sample

# Client Initialization
Before we can do any federated learning, we must first initialize our client communication by registering with the FMA service for an experiment.
This will allow us to push and pull model data to the FMA service itself.

## Connect to server
The three variables needed for registration are:
* `uuid_storage_path`: Path to the file that stores the client's unique ID
  * Note that the ID can be a pre-existing unique ID or one generated by the service
* `federated_model_id`: The unique ID of the model experiment in which the client will participate
* `url`: The base URL for the service
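
The registration cell in this section reuses an ID stored at `uuid_storage_path` when one exists and otherwise lets the service generate one. A minimal sketch of that persistence pattern on its own, using only the standard library, is shown here. `load_stored_uuid` and `store_uuid` are hypothetical helper names for illustration and are not part of `fma_connect`.

```python
from pathlib import Path
from typing import Optional


def load_stored_uuid(path: str) -> Optional[str]:
    """Return the UUID saved at `path`, or None if no usable ID is stored."""
    try:
        stored = Path(path).read_text().strip()
    except FileNotFoundError:
        return None
    # Treat an empty file the same as a missing one.
    return stored or None


def store_uuid(path: str, uuid: str) -> None:
    """Persist the UUID so that later runs can reuse it."""
    Path(path).write_text(uuid)
```

Returning `None` for a missing or empty file lets the caller decide in one place whether to register with or without a pre-existing ID.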

In [None]:
uuid_storage_path = "./uuid_temp_notebook.txt"
federated_model_id = 2
url = "http://localhost:8000/"

In [None]:
# Try registering with a pre-existing ID stored in a file
try:
    with open(uuid_storage_path) as f:
        uuid = f.read().strip()
    print(f"Connecting with {uuid} as UUID")
    client = fma_connect.WebClient(
        federated_model_id=federated_model_id, url=url, uuid=uuid
    )
except FileNotFoundError:
    print("Connecting without UUID")
    client = fma_connect.WebClient(federated_model_id=federated_model_id, url=url)

# Register outside of a `finally` block so that an unexpected error above
# is not masked by a NameError on an undefined `client`
client.register()
uuid = client.uuid
with open(uuid_storage_path, "w") as f:
    f.write(uuid)
print(f"{uuid} stored as UUID")



## Client Validation Dataset Generation
Now that we have our ID, we need to prepare our client device to train the model.
To do this, we create a validation set that is identical across all clients (we do this by passing the same `validation_seed`).
This dataset mimics the centralized "golden" validation set that, in a federated learning use case, would be a static collection of data on the server side.

For this example we are going to use the following labels:
* COOKIE
* MAC_ADDRESS
* SSN
* DATETIME
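
Passing the same `validation_seed` on every client is what makes the validation set identical everywhere. The sketch below illustrates the idea with a hypothetical stand-in generator, `fake_sample`. It is not the `generate_sample` helper used in this notebook, whose internals are not shown here; it only demonstrates that a seeded `random.Random` instance produces the same sequence on every run.

```python
import random


def fake_sample(entity: str, count: int, seed=None) -> list:
    """Hypothetical stand-in: build `count` pseudo-random strings tagged with `entity`."""
    rng = random.Random(seed)  # seed=None draws fresh entropy on each call
    return [f"{entity}:{rng.randrange(10**6):06d}" for _ in range(count)]


client_a = fake_sample("SSN", 3, seed=0)
client_b = fake_sample("SSN", 3, seed=0)
print(client_a == client_b)  # True: same seed, same data on both clients
```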

In [None]:
val_entity_names = ["COOKIE", "MAC_ADDRESS", "SSN", "DATETIME"]
num_of_val_entities = 10
validation_seed = 0

In [None]:
# Create validation dataset
val_dataset_split = [
    generate_sample(entity, num_of_val_entities, validation_seed)
    for entity in val_entity_names
]

val_dataset = {
    k: [item for dic in val_dataset_split for item in dic[k]]
    for k in val_dataset_split[0]
}

# Get length of dataset for metrics generation
length_of_dataset = 0
for text in val_dataset["text"]:
    length_of_dataset += len(text)

def create_label_char_array_from_entity_array(entity_array, _entities_dict, data_raw):
    """
    Converts list with ranges of entities to a character level label list.

    :param entity_array: List that holds entity ranges
    :type entity_array: EntityArrayType
    :param _entities_dict: Dict of entities and their indices
    :type _entities_dict: EntitiesDictType
    :param data_raw: the raw data used for creation of entity_array
    :type data_raw: List[str]

    :return: A character-level label array (one label per character)
    :rtype: List[int]
    """
    char_label_array = []
    len_of_text = sum([len(text) for text in data_raw])
    for text_section_index in range(len(entity_array)):
        index = 0
        for entity in entity_array[text_section_index]:
            while index < entity[0]:
                char_label_array.append(1)
                index += 1
            while index < entity[1]:
                char_label_array.append(_entities_dict[str(entity[2])])
                index += 1
        # After the last entity, label the rest of this text section as background
        while index < len(data_raw[text_section_index]):
            char_label_array.append(1)
            index += 1
    # Final length check
    while len(char_label_array) < len_of_text:
        char_label_array.append(1)
    return char_label_array
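
To make the conversion concrete, here is a small self-contained sketch of the same idea applied to a single string. `ranges_to_char_labels` is a hypothetical, simplified helper written for this illustration. The `(start, end, label)` tuple layout with an exclusive `end` is an assumption inferred from how `entity[0]`, `entity[1]`, and `entity[2]` are used in the function above.

```python
def ranges_to_char_labels(text, entities, label_ids, background=1):
    """Expand (start, end, label) ranges into one label ID per character."""
    labels = [background] * len(text)
    for start, end, label in entities:
        # `end` is treated as exclusive and clipped to the text length.
        for i in range(start, min(end, len(text))):
            labels[i] = label_ids[label]
    return labels


label_ids = {"PAD": 0, "UNKNOWN": 1, "SSN": 5}
text = "id 123-45-6789"
# Characters 3..13 hold the SSN; everything else is background (UNKNOWN = 1).
char_labels = ranges_to_char_labels(text, [(3, 14, "SSN")], label_ids)
print(char_labels)
# [1, 1, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
```

Because both the ground truth and the predictions are expanded this way, they can be compared position by position when computing metrics.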

# Model Initialization

## Setup Model Architecture
Now that we have the validation set created, we need to set up our client's model architecture.
* **Pull the model**: Pull the DataLabeler model from the DataProfiler's library
* **Set labels**: Set the labels we want to use for the model
* **Set params**: Set the necessary params to ready the architecture for training

In [None]:
model_arch = dp.DataLabeler(labeler_type="unstructured", trainable=True)
model_arch.set_labels(
    {
        "PAD": 0,
        "UNKNOWN": 1,
        "DATETIME": 2,
        "COOKIE": 3,
        "MAC_ADDRESS": 4,
        "SSN": 5,
    }
)
model_arch.model._reconstruct_model()

# Setting post process params for human-readable format
model_arch.set_params(
    {"postprocessor": {"output_format": "ner", "use_word_level_argmax": True}}
)

# Convert the validation entities into a character-level label array.
# It serves as the ground truth when scoring the model's predictions later.
val_dataset_array = create_label_char_array_from_entity_array(
    val_dataset["entities"], model_arch.label_mapping, val_dataset["text"]
)

## Pull Model Weights
With the client's setup out of the way, we are finally ready to start our training process.

* **Pull the latest model weights**: We pull the model weights for our experiment using the `check_for_latest_model` function. This function returns the most recent weights, which are either the latest model aggregate or the latest artifact, whichever is newer.
* **Set weights**: Load the newly obtained weights into our model architecture

In [None]:
model_obj = client.check_for_latest_model()
numpify_array(model_obj["values"])
new_weights, base_aggregate = model_obj["values"], model_obj["aggregate"]
model_arch.model._model.set_weights(new_weights)

# Model Training

## Training Loop
Finally, we are able to run our training loop.
1. **Set training parameters**:
    * Choose a single entity to train on (for this example we are using `SSN`)
    * Decide how many epochs to run before pushing the model to the service (1)
    * Decide how many federated training iterations to participate in (10)
2. **Gather training data**: For each iteration, generate new random text injected with 200 SSN examples
3. **Run training of the model**:
    * Run training with the model weights provided by the service
    * Send model updates to the service
    * Pull the model aggregate generated by the service (an aggregate of all involved clients)
    * Evaluate the newly aggregated model on the validation dataset
4. **Repeat**: Repeat steps 2 and 3 for each remaining iteration
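
Pulling the model aggregate may involve waiting until the service has produced a new one. A minimal sketch of a bounded polling helper follows. `poll_until_ready` and its parameters are hypothetical and not part of `fma_connect`; `fetch` stands in for any zero-argument callable, such as `client.check_for_latest_model`.

```python
import time


def poll_until_ready(fetch, interval_s=5.0, max_attempts=12, sleep=time.sleep):
    """Call `fetch` until it returns a truthy value or the attempts run out."""
    for attempt in range(1, max_attempts + 1):
        result = fetch()
        if result:
            return result
        if attempt < max_attempts:
            sleep(interval_s)  # wait before asking the service again
    raise TimeoutError(f"no result after {max_attempts} attempts")


# Usage with a fake fetcher that succeeds on the third call.
responses = iter([None, None, {"values": [], "aggregate": 1}])
model_obj = poll_until_ready(lambda: next(responses), interval_s=0)
print(model_obj["aggregate"])  # 1
```

Bounding the number of attempts is a design choice for this sketch: it surfaces a stalled service as an error instead of waiting indefinitely.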

Because the service aggregates the weight updates from all clients, the resulting weights generalize across all the data involved in the experiment rather than only the data on which a single client has trained.
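
As an illustration of how aggregation blends the contributions of several clients, the sketch below averages per-layer weights element-wise. This is plain unweighted averaging chosen for illustration only; the aggregation strategy the FMA service actually applies is decided on the service side and is not shown in this notebook. `average_weights` is a hypothetical helper, and each layer is assumed to be a flat list of floats.

```python
def average_weights(client_weights):
    """Element-wise mean of flat per-layer weight lists across clients."""
    n_clients = len(client_weights)
    averaged = []
    # zip(*client_weights) groups the same layer from every client.
    for layers in zip(*client_weights):
        averaged.append([sum(vals) / n_clients for vals in zip(*layers)])
    return averaged


client_a = [[0.0, 2.0], [4.0]]  # two layers, flattened
client_b = [[2.0, 4.0], [0.0]]
print(average_weights([client_a, client_b]))  # [[1.0, 3.0], [2.0]]
```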

In [None]:
train_entity_names = "SSN"
num_epochs = 1
train_iterations = 10

In [None]:
for _ in range(train_iterations): 
    # Client data collection
    train_dataset = generate_sample(train_entity_names, 200)

    # Client Train
    model_arch.fit(x=train_dataset["text"], y=train_dataset["entities"], epochs=num_epochs)

    # Inference for results
    predictions = model_arch.predict(val_dataset["text"])
    color_text(val_dataset["text"], predictions, model_arch.label_mapping)

    # Send weights to service for aggregation
    weights = model_arch.model._model.get_weights()
    jsonify_array(weights)

    client.send_update(weights, base_aggregate)

    # Get new weights from service after aggregation
    model_obj = client.check_for_latest_model()
    while not model_obj:
        print("Received None response...")
        time.sleep(5)
        model_obj = client.check_for_latest_model()
    print("Received valid response...")

    # Load aggregated weights to model
    numpify_array(model_obj["values"])
    weights_updates = model_obj["values"]
    base_aggregate = model_obj["aggregate"]
    model_arch.model._model.set_weights(weights_updates)


    # Run validation of new weights on client collected data
    predictions = model_arch.predict(val_dataset["text"])
    pred_array = create_label_char_array_from_entity_array(
        predictions["pred"], model_arch.label_mapping, val_dataset["text"]
    )
    val_results = {
        "f1_score": sklearn.metrics.f1_score(
            val_dataset_array, pred_array, average=None
        ),
        "description": sklearn.metrics.classification_report(
            val_dataset_array, pred_array
        ),
    }
    color_text(val_dataset["text"], predictions, model_arch.label_mapping)

    # Send validation results to service 
    resp = client.send_val_results(val_results, base_aggregate)
    print(f"Validation results sent! {resp}")


# Conclusion
This was a very basic example of federated learning with the FMA service, but it shows the core idea behind the service as a whole: making federated learning easily accessible to any training pipeline. The pieces required for a federated learning process should be simple additions to a pre-existing pipeline that do not demand subject matter expertise in federated learning.