In [3]:
import os
import sys
import time
import logging
import random
from pprint import pprint

import pandas as pd
import dataprofiler as dp

sys.path.insert(0, os.path.abspath( '../../../../../clients/python_client'))
import fma_connect  # noqa: E402

sys.path.insert(0, os.path.abspath('../..'))
from utils import intialize_logger, numpify_array, jsonify_array, print_results

sys.path.insert(0, os.path.abspath('../'))
from dataprofiler_utils.generation_scripts import generate_sample

In [None]:
# Global settings: widen pandas console output and never truncate displayed rows.
pd.set_option('display.width', 100, 'display.max_rows', None)

# Create and initialise the notebook logger only when it does not exist yet, so
# re-running this cell does not call `intialize_logger` on it a second time.
if "logger" not in globals():
    logger = logging.getLogger(__name__)
    intialize_logger(logger, "logfile.log")


In [None]:
# Service connection settings: which federated model to join, where the
# service listens, and where this client's UUID is cached between runs.
federated_model_id, url = 4, "http://127.0.0.1:8000"
uuid_storage_path = "./uuid_temp.txt"

# Local training settings: epochs passed to each local `fit` call.
num_epochs = 1


Set up the parameters for generating the training and validation datasets

In [None]:
# Entity types used by this notebook. Local training uses a single, randomly
# chosen type; the validation set is generated for every type in the list.
entity_names = ["COOKIE", "MAC_ADDRESS", "SSN", "DATETIME"]
entity_name = random.choice(entity_names)

# Sizes passed to `generate_sample` for training and validation respectively.
num_of_train_entities, num_of_val_entities = 1, 1
# Seed passed to `generate_sample` for validation data only
# (presumably to make the validation set reproducible — TODO confirm).
validation_seed = 0

Initialize the model to use with the service (DataProfiler in this case)

In [None]:
# Label name -> index mapping. NOTE(review): weights are exchanged with the
# service by position, so these indices presumably must match every other
# client — do not reorder without confirming.
label_mapping = {
    "PAD": 0,
    "UNKNOWN": 1,
    "DATETIME": 2,
    "COOKIE": 3,
    "MAC_ADDRESS": 4,
    "SSN": 5,
}

model_arch = dp.DataLabeler(labeler_type='unstructured', trainable=True)
model_arch.set_labels(label_mapping)
# Private DataProfiler API: rebuild the underlying model after changing labels
# (presumably so the output layer matches the new label count — TODO confirm).
model_arch.model._reconstruct_model()

# Post-processor settings for human-readable NER output with word-level argmax.
postprocessor_params = {'output_format': 'ner', 'use_word_level_argmax': True}
model_arch.set_params({'postprocessor': postprocessor_params})

Connect to the service

In [None]:
# Reuse the UUID cached by a previous run if there is one; otherwise let the
# service assign a fresh one. Registration and persistence run only after a
# client was created successfully. (A `finally:` clause would also run after an
# unexpected error, then fail on `client` or silently reuse a stale client left
# over from an earlier run of this notebook.)
try:
    with open(uuid_storage_path) as f:
        # Strip stray whitespace/newlines, e.g. from a hand-edited file.
        stored_uuid = f.read().strip()
except FileNotFoundError:
    stored_uuid = ""

if stored_uuid:
    logger.info(f"Connecting with {stored_uuid} as UUID")
    client = fma_connect.WebClient(federated_model_id=federated_model_id, url=url, uuid=stored_uuid)
else:
    # Missing or empty cache file: connect without a UUID.
    logger.info("Connecting without UUID")
    client = fma_connect.WebClient(federated_model_id=federated_model_id, url=url)

client.register()
uuid = client.uuid
# Cache the UUID the client ended up with so the next run reconnects with it.
with open(uuid_storage_path, "w") as f:
    f.write(uuid)
logger.info(f"{uuid} stored as UUID")

Pull the initial model weights from the service

In [None]:
# Start from the service's latest aggregate when one exists; otherwise fall
# back to the values of the service's current artifact.
model_agg = client.check_for_new_model_aggregate(update_after=0)
init_weights = model_agg['result'] if model_agg else client.get_current_artifact()['values']

# numpify_array is called for its side effect only (return value unused) —
# presumably it converts the JSON lists to numpy arrays in place; TODO confirm.
numpify_array(init_weights)
model_arch.model._model.set_weights(init_weights)

Generate validation data

In [None]:
def merge_sample_splits(splits):
    """Merge a list of dicts into one dict of concatenated values.

    Keys are taken, in order, from the first dict; every dict must contain
    them. For each key the values of all dicts are concatenated in list order.
    """
    merged = {}
    for key in splits[0]:
        values = []
        for split in splits:
            values.extend(split[key])
        merged[key] = values
    return merged


# Validation data: `generate_sample` is called once per entity type with the
# fixed validation seed, and the per-type results are merged key by key.
val_dataset_split = [
    generate_sample(entity_type, num_of_val_entities, validation_seed)
    for entity_type in entity_names
]
val_dataset = merge_sample_splits(val_dataset_split)
pprint(val_dataset)

Model training loop

In [None]:
# Federated training loop. Each round:
#   1. train locally on freshly generated data and evaluate;
#   2. send the local weights to the service;
#   3. block until the service returns an aggregate, load it, and evaluate again.
# `max_training_rounds = None` keeps the original run-forever behaviour; set it
# to an int to stop after that many rounds (e.g. so Restart & Run All finishes).
max_training_rounds = None
aggregate_poll_interval_s = 5  # seconds to wait between polls for an aggregate


def evaluate_on_validation(stage):
    """Predict on the validation texts, log the results, and return the predictions.

    Args:
        stage: Prefix for the log messages; "<stage> started..." and
            "<stage> complete!" are logged around the prediction.

    Returns:
        Whatever ``model_arch.predict`` returns for ``val_dataset["text"]``.
    """
    logger.info(f"{stage} started...")
    stage_predictions = model_arch.predict(val_dataset["text"])
    logger.info(f"{stage} complete!")
    print_results(val_dataset['text'], stage_predictions, logger)
    return stage_predictions


training_idx = 0
while max_training_rounds is None or training_idx < max_training_rounds:
    logger.info(f"Training loop: {training_idx}")

    logger.info("Generating training dataset")
    train_dataset = generate_sample(entity_name, num_of_train_entities)

    logger.info("Training initiated")
    model_arch.fit(x=train_dataset['text'], y=train_dataset['entities'], epochs=num_epochs)
    logger.info("Training complete")

    # Run validation with the locally trained weights.
    predictions = evaluate_on_validation("Eval after local training")

    # Prepare the weights for the service API. jsonify_array's return value is
    # unused, so it presumably converts the arrays in place — TODO confirm.
    weights = model_arch.model._model.get_weights()
    jsonify_array(weights)

    # Send the local update to the service.
    client.send_update(weights)
    logger.info("Weights updates sent")

    # Block until the service publishes a new aggregate.
    logger.info("Checking for service model update")
    model_agg = client.check_for_new_model_aggregate()
    while not model_agg:
        logger.info("Received None response...")
        time.sleep(aggregate_poll_interval_s)
        model_agg = client.check_for_new_model_aggregate()
    logger.info("Response received.")

    # Convert the service response to loadable weights and load them.
    numpify_array(model_agg['result'])
    model_arch.model._model.set_weights(model_agg['result'])
    logger.info("New model weights set")

    # Run validation with the aggregated weights.
    predictions = evaluate_on_validation("Aggregated model weights eval")

    training_idx += 1