In [1]:
from __future__ import absolute_import, print_function
import pandas as pd
import boto3
from botocore.client import Config
from flytekit.common import utils as _utils
from flytekit.sdk import test_utils as _test_utils
from flytekit.configuration import set_flyte_config_file
from flytekit.common.tasks.presto_task import SdkPrestoTask
from flytekit.sdk.tasks import inputs
from flytekit.sdk.types import Types
from flytekit.configuration import platform
from os import environ

# Version tags read later in this notebook: "version" when registering tasks,
# "spec_version" when registering the workflow and its launch plan.
environ.update(version="25", spec_version="25-1")

# The container image tag is derived from the task version set just above.
environ.update(
    FLYTE_INTERNAL_IMAGE="docker.io/lyft/flytesnacks:datacouncil-{}".format(environ["version"])
)

# Point flytekit at the staging configuration, then set the configuration path
# variable explicitly. The relative order of these two statements is preserved.
set_flyte_config_file("staging.config")
environ.update(FLYTE_INTERNAL_CONFIGURATION_PATH="/root/staging.config")

def print_console_url(exc):
    """Print the console URL of a Flyte execution.

    The URL is assembled from the configured platform host
    (``platform.URL.get()``) and the ``project``, ``domain`` and ``name``
    fields of ``exc.id``.
    """
    host = platform.URL.get()
    execution_id = exc.id
    url_fields = (host, execution_id.project, execution_id.domain, execution_id.name)
    print("http://{}/console/projects/{}/domains/{}/executions/{}".format(*url_fields))

# S3-compatible resource handle bound to an endpoint on localhost; it is the
# client used by upload_file() in this cell.
s3 = boto3.resource(
    's3',
    region_name='us-east-1',
    endpoint_url='http://localhost:30084',
    config=Config(signature_version='s3v4'),
    # NOTE(review): access keys are hard-coded in the notebook — confirm these
    # are local sandbox defaults and never real credentials.
    aws_access_key_id='minio',
    aws_secret_access_key='miniostorage',
)

def upload_file(f, ref):
    """Upload a local file to the object-store location named by ``ref``.

    Args:
        f: Path of the local file to upload; it is handed to
            ``Bucket.upload_file`` as the source filename.
        ref: Destination of the form ``s3://<bucket>/<key>``. The ``s3://``
            scheme prefix is optional.

    Raises:
        ValueError: If ``ref`` does not contain both a bucket and a key.
    """
    scheme = "s3://"
    # Remove the scheme as a *prefix*. ``str.lstrip("s3://")`` would instead
    # strip any leading run of the characters 's', '3', ':' and '/', which
    # mangles bucket names such as "s3-data" or "staging".
    location = ref[len(scheme):] if ref.startswith(scheme) else ref
    bucket, sep, key = location.partition("/")
    if not (bucket and sep and key):
        raise ValueError("expected a reference like s3://<bucket>/<key>, got {!r}".format(ref))
    # Upload the file the caller asked for rather than a fixed filename.
    s3.Bucket(bucket).upload_file(f, key)
    
from IPython.display import Image, display
def display_images(paths):
    """Show each image in ``paths`` inline, in order.

    Every entry is wrapped in ``IPython.display.Image`` and passed to
    ``display`` one at a time.
    """
    remaining = iter(paths)
    for image_path in remaining:
        rendered = Image(image_path)
        display(rendered)

def print_schema(schema):
    """Download a schema output and print its contents as a DataFrame.

    The download and the Parquet read both run inside a
    ``LocalTestFileSystem`` context; the DataFrame is printed after that
    context has exited.
    """
    local_fs = _test_utils.LocalTestFileSystem()
    with local_fs:
        # Fetch the schema data locally, then load it from its local path.
        schema.download()
        frame = pd.read_parquet(schema.local_path)
    print(frame)

# Sample ML Model Pipeline

Let's start developing a more serious model. This involves pulling data using Presto, transforming Parquet to CSV, and finally training an XGBoost model on SageMaker.

### 1) Query the data
#### a. Get Training Data

In [2]:
# Presto query task that selects every row of the training table and returns
# it as a schema output. It is marked discoverable with discovery version "3".
get_train_data2 = SdkPrestoTask(
    statement="""
    SELECT * 
    FROM hive.flyte.datacouncildemo_train
    """,
    task_inputs=inputs(),
    output_schema=Types.Schema(),
    discovery_version="3",
    discoverable=True,
)

# Register the task under the name "get_train_data" at the notebook's task version.
get_train_data2.register(
    name="get_train_data",
    version=environ["version"],
    project="flytesnacks",
    domain="development",
)

# Launch a single execution of the task and wait for it to finish.
# NOTE(review): this launch targets project "flytekit" while the registration
# above uses "flytesnacks", and it passes a "ds" input although the task is
# declared with empty task_inputs — confirm both are intentional.
task_exec = get_train_data2.register_and_launch(
    inputs={"ds": '2020-07-05'},
    project="flytekit",
    domain="development",
)
print("Created execution.")
print_console_url(task_exec)
print("Waiting for execution to complete...")
task_exec.wait_for_completion()
print("Done!")

# Print the rows returned in the execution's "results" output.
print_schema(task_exec.outputs['results'])

127.0.0.1 - - [13/Jul/2020 10:58:55] "GET /callback?code=gE0h7dQVrtW8ACEiQHW4&state=05PoGI4gZ51L9shCDflg75hG_iVtRjxABaH0n82-O8qOAMVtg-OD8g HTTP/1.1" 200 -


Created execution.
http://flyte-staging.lyft.net/console/projects/flytekit/domains/development/executions/mggnq77a7x
Waiting for execution to complete...
Done!
    col0 col1 col2 col3 col4  col5   col6 col7 col8
0      1   85   66   29    0  26.6  0.351   31    0
1      8  183   64    0    0  23.3  0.672   32    1
2      1   89   66   23   94  28.1  0.167   21    0
3      0  137   40   35  168  43.1  2.288   33    1
4      5  116   74    0    0  25.6  0.201   30    0
..   ...  ...  ...  ...  ...   ...    ...  ...  ...
762   10  101   76   48  180  32.9  0.171   63    0
763    2  122   70   27    0  36.8  0.340   27    0
764    5  121   72   23  112  26.2  0.245   30    0
765    1  126   60    0    0  30.1  0.349   47    1
766    1   93   70   31    0  30.4  0.315   23    0

[767 rows x 9 columns]


#### b. Get Validation Data

In [3]:
# Presto query task that selects every row of the validation table and returns
# it as a schema output. It is marked discoverable with discovery version "2".
get_validation_data = SdkPrestoTask(
    statement="""
    SELECT * 
    FROM hive.flyte.datacouncildemo_validation
    """,
    task_inputs=inputs(),
    output_schema=Types.Schema(),
    discovery_version="2",
    discoverable=True,
)

# This task does not need to be run on its own here: it is just a copy of the
# training-data task. In a real scenario we would probably query one large
# data set and then apply a common algorithm to split it (e.g. 20-80).
get_validation_data.register(
    name="get_validation_data",
    version=environ["version"],
    project="flytesnacks",
    domain="development",
)

'tsk:flytesnacks:development:get_validation_data:24'

#### c. Transform Parquet to CSV (SageMaker requires that)
Somebody has already written a common Python task that transforms Parquet to CSV. Let's just import that task and use it.

In [4]:
from flytekit.common.tasks.task import SdkTask
# Fetch the existing Parquet-to-CSV task by its identifier; the version is
# pinned to "24" rather than following this notebook's own version tag.
transform_parquet_to_csv = SdkTask.fetch(
    name="transform_parquet_to_csv",
    version="24",
    project="flytesnacks",
    domain="development",
)

#### d. Let's write the Training Step!

In [5]:
from flytekit.sdk.tasks import inputs
from flytekit.sdk.types import Types
from flytekit.sdk.workflow import workflow_class, Input, Output
from flytekit.common.tasks.sagemaker import training_job_task, hpo_job_task
from flytekit.models.sagemaker import training_job as training_job_models, hpo_job as hpo_job_models
from flytekit.sdk.sagemaker import types as _sdk_sagemaker_types
# XGBoost hyperparameters, all given as strings. The workflow cell later passes
# this mapping to the HPO task as its static_hyperparameters.
xgboost_hyperparameters = dict(
    base_score="0.5",
    booster="gbtree",
    csv_weights="0",
    dsplit="row",
    grow_policy="depthwise",
    lambda_bias="0.0",
    max_bin="256",
    max_leaves="0",
    normalize_type="tree",
    objective="reg:linear",
    one_drop="0",
    prob_buffer_row="1.0",
    process_type="default",
    rate_drop="0.0",
    refresh_leaf="1",
    sample_type="uniform",
    scale_pos_weight="1.0",
    silent="0",
    skip_drop="0.0",
    tree_method="auto",
    tweedie_variance_power="1.5",
    updater="grow_colmaker,prune",
)

# Algorithm specification: XGBoost 0.72 in FILE input mode.
# NOTE(review): the metric definition is named "Minimize" while the HPO
# objective in the workflow cell refers to "validation:error" — confirm the
# name is intended.
alg_spec = training_job_models.AlgorithmSpecification(
    algorithm_name=_sdk_sagemaker_types.AlgorithmName.XGBOOST,
    algorithm_version="0.72",
    input_mode=_sdk_sagemaker_types.InputMode.FILE,
    metric_definitions=[
        training_job_models.MetricDefinition(name="Minimize", regex="validation:error"),
    ],
)

# Single-instance training job task built on the algorithm specification above.
xgboost_train_task2 = training_job_task.SdkSimpleTrainingJobTask(
    algorithm_specification=alg_spec,
    training_job_config=training_job_models.TrainingJobConfig(
        instance_count=1,
        instance_type="ml.m4.xlarge",
        volume_size_in_gb=25,
    ),
    cacheable=True,
    cache_version='2',
)

# Hyperparameter tuning task wrapping the training task: at most 10 training
# jobs in total, at most 5 of them in parallel, with 2 retries.
xgboost_hpo_task2 = hpo_job_task.SdkSimpleHPOJobTask(
    training_job=xgboost_train_task2,
    max_parallel_training_jobs=5,
    max_number_of_training_jobs=10,
    retries=2,
    cacheable=True,
    cache_version='2',
)

# Register the HPO task at the notebook's task version.
xgboost_hpo_task2.register(
    name="xgboost_hpo_task2",
    version=environ["version"],
    project="flytesnacks",
    domain="development",
)

'tsk:flytesnacks:development:xgboost_hpo_task2:24'

In [11]:
from flytekit.sdk.workflow import workflow_class, Input, Output
from flytekit.models.sagemaker.training_job import StoppingCondition
from flytekit.models.sagemaker.hpo_job import HPOJobConfig, HyperparameterTuningObjective
from flytekit.models.sagemaker.parameter_ranges import ParameterRanges, CategoricalParameterRange, ContinuousParameterRange, IntegerParameterRange

# End-to-end training workflow: query training and validation data, convert
# both result sets to CSV, run the XGBoost hyperparameter tuning task on them,
# and expose the resulting model as the workflow output "model".
# NOTE(review): the attribute names below are presumably picked up by
# @workflow_class as node / output names — verify before renaming any of them.
@workflow_class()
class TrainingWorkflow(object):    
    # Retrieve data: one node per Presto query task.
    train_data = get_train_data2()
    validation_data = get_validation_data()
    
    # Transform data: feed each query's `results` output into the shared
    # Parquet-to-CSV task.
    train_csv = transform_parquet_to_csv(input_parquet=train_data.outputs.results)
    validation_csv = transform_parquet_to_csv(input_parquet=validation_data.outputs.results)
    
    # Train with HPO: the CSV outputs are the train/validation channels and the
    # module-level hyperparameter mapping supplies the static hyperparameters.
    train = xgboost_hpo_task2(train=train_csv.outputs.output_csv,
                             validation=validation_csv.outputs.output_csv,
                             static_hyperparameters=xgboost_hyperparameters,
                             # Maximum runtime: 43200 seconds = 12 hours.
                             stopping_condition=StoppingCondition(
                                max_runtime_in_seconds=43200,
                             ).to_flyte_idl(),
                             # Tuning configuration: only "num_round" is tuned, as an
                             # integer in [1, 100] with logarithmic scaling, using the
                             # Bayesian strategy to minimize "validation:error", with
                             # automatic early stopping of training jobs.
                             hpo_job_config=HPOJobConfig(
                                hyperparameter_ranges=ParameterRanges(
                                    parameter_range_map={
                                        "num_round": IntegerParameterRange(min_value=1, max_value=100, scaling_type=_sdk_sagemaker_types.HyperparameterScalingType.LOGARITHMIC),
                                    }
                                ),
                                tuning_strategy=_sdk_sagemaker_types.HyperparameterTuningStrategy.BAYESIAN,
                                tuning_objective=HyperparameterTuningObjective(
                                    objective_type=_sdk_sagemaker_types.HyperparameterTuningObjectiveType.MINIMIZE,
                                    metric_name="validation:error",
                                ),
                                training_job_early_stopping_type=_sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO
                            ).to_flyte_idl())
    
    # Workflow output: the `model` output of the HPO node, typed as a Blob.
    model = Output(train.outputs.model, sdk_type=Types.Blob)
    
# Register the workflow itself at the notebook's spec version.
TrainingWorkflow.register(
    name="TrainingWorkflow",
    version=environ["spec_version"],
    project="flytesnacks",
    domain="development",
)

# Create a launch plan from the workflow and register it under the same
# project, domain, name and version as the workflow.
TrainingWorkflow_lp = TrainingWorkflow.create_launch_plan()
TrainingWorkflow_lp.register(
    name="TrainingWorkflow",
    version=environ["spec_version"],
    project="flytesnacks",
    domain="development",
)

'lp:flytesnacks:development:TrainingWorkflow:24-1'

In [12]:
# Launch the registered launch plan with no inputs and print a link to the
# execution in the console.
# NOTE: the name `exec` shadows the Python builtin; it is kept because a later
# cell in this notebook reads it.
exec = TrainingWorkflow_lp.launch(
    inputs={},
    project="flytesnacks",
    domain="development",
)
print_console_url(exec)

http://flyte-staging.lyft.net/console/projects/flytesnacks/domains/development/executions/f7807b50d9ced48a9a15


In [19]:
# Wait for the workflow execution stored in `exec` by an earlier cell.
print("Waiting for execution to complete...")
exec.wait_for_completion()
# "Done!" followed by one blank line.
print("Done!", end="\n\n")

# Report where the workflow's "model" output was written.
print("Generated Model: {}".format(exec.outputs["model"].uri))

Waiting for execution to complete...
Done!

Generated Model: s3://lyft-modelbuilder/metadata/propeller/staging/flytesnacks-development-f593fba5e141d4fea829/train/data/0/hpo_outputs/f593fba5e141d4fea829-008-f2657b81/output/model.tar.gz
