In [1]:
from flytekit.configuration import set_flyte_config_file, platform
# Point flytekit at the staging config file before reading platform.URL.
set_flyte_config_file("notebook-staging.config")

# Echo the endpoint read back from the loaded configuration.
flyte_endpoint = platform.URL.get()
print("Connected to {}".format(flyte_endpoint))

def print_console_url(exc):
    """Print the console link for a launched execution.

    The link is assembled from the configured platform URL and the
    project, domain and name found on ``exc.id``.

    :param exc: object exposing ``id.project``, ``id.domain`` and ``id.name``
        (here, the value returned by ``register_and_launch``).
    """
    # Read the host first, matching the original argument evaluation order.
    host = platform.URL.get()
    execution_id = exc.id
    console_url = "http://{}/console/projects/{}/domains/{}/executions/{}".format(
        host,
        execution_id.project,
        execution_id.domain,
        execution_id.name,
    )
    print(console_url)

Connected to localhost:30081


In [2]:
from flytekit.sdk.tasks import inputs
from flytekit.sdk.types import Types
from flytekit.sdk.workflow import workflow_class, Input, Output
from flytekit.common.tasks.sagemaker import training_job_task, hpo_job_task
from flytekit.models.sagemaker import training_job as training_job_models, hpo_job as hpo_job_models
from flytekit.sdk.sagemaker import types as _sdk_sagemaker_types
# Static XGBoost hyperparameters, every value encoded as a string.
# This mapping is later supplied as the "static_hyperparameters" launch input.
# Insertion order is kept identical to the original literal.
xgboost_hyperparameters = dict(
    base_score="0.5",
    booster="gbtree",
    csv_weights="0",
    dsplit="row",
    grow_policy="depthwise",
    lambda_bias="0.0",
    max_bin="256",
    max_leaves="0",
    normalize_type="tree",
    objective="reg:linear",
    one_drop="0",
    prob_buffer_row="1.0",
    process_type="default",
    rate_drop="0.0",
    refresh_leaf="1",
    sample_type="uniform",
    scale_pos_weight="1.0",
    silent="0",
    sketch_eps="0.03",
    skip_drop="0.0",
    tree_method="auto",
    tweedie_variance_power="1.5",
    updater="grow_colmaker,prune",
)
# Metric definition named "Minimize" whose regex is "validation:error".
validation_error_metric = training_job_models.MetricDefinition(
    name="Minimize",
    regex="validation:error",
)

# Algorithm specification: XGBOOST version "0.72", FILE input mode,
# carrying the single metric definition above.
alg_spec = training_job_models.AlgorithmSpecification(
    input_mode=_sdk_sagemaker_types.InputMode.FILE,
    algorithm_name=_sdk_sagemaker_types.AlgorithmName.XGBOOST,
    algorithm_version="0.72",
    metric_definitions=[validation_error_metric],
)
# Resources for each training job: one ml.m4.xlarge instance with a 25 GB volume.
training_resources = training_job_models.TrainingJobConfig(
    instance_type="ml.m4.xlarge",
    instance_count=1,
    volume_size_in_gb=25,
)

# Training-job task combining the resource config with `alg_spec`;
# declared cacheable under cache version "1".
simple_xgboost_trainingjob_task = training_job_task.SdkSimpleTrainingJobTask(
    training_job_config=training_resources,
    algorithm_specification=alg_spec,
    cache_version="1",
    cacheable=True,
)
# Options for the hyperparameter-tuning task that wraps the training task:
# at most 10 training jobs in total, at most 5 in parallel, 2 retries,
# cacheable under cache version "1".
hpo_task_options = {
    "training_job": simple_xgboost_trainingjob_task,
    "max_number_of_training_jobs": 10,
    "max_parallel_training_jobs": 5,
    "cache_version": "1",
    "retries": 2,
    "cacheable": True,
}
simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHPOJobTask(**hpo_task_options)

In [3]:
from flytekit.models.sagemaker.training_job import StoppingCondition
from flytekit.models.sagemaker.hpo_job import HPOJobConfig, HyperparameterTuningObjective
from flytekit.models.sagemaker.parameter_ranges import ParameterRanges, CategoricalParameterRange, ContinuousParameterRange, IntegerParameterRange
# Launch inputs for the HPO task.
#
# The dict is named `hpo_inputs` rather than `inputs`: an earlier cell imports
# the `inputs` decorator from flytekit.sdk.tasks, and rebinding that name to a
# dict would break any later `@inputs(...)` use in the same kernel.
#
# NOTE(review): "train" and "validation" point at the same CSV, so the
# validation metric is computed on the training data — confirm this is intended.
# NOTE(review): "max_leaves" and "sketch_eps" are tuned below but are also
# present in `xgboost_hyperparameters` (passed as static hyperparameters);
# confirm the backend accepts a parameter appearing in both sets.
hpo_inputs = {
    "train": "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv",
    "validation": "https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv",
    "static_hyperparameters": xgboost_hyperparameters,
    "stopping_condition": StoppingCondition(
        max_runtime_in_seconds=43200,  # 12 hours
        max_wait_time_in_seconds=43200,  # 12 hours
    ).to_flyte_idl(),
    "hpo_job_config": HPOJobConfig(
        hyperparameter_ranges=ParameterRanges(
            parameter_range_map={
                # SageMaker's XGBoost hyperparameter for the number of boosting
                # rounds is spelled "num_round" (singular); the previous key
                # "num_rounds" is not a documented XGBoost hyperparameter.
                "num_round": IntegerParameterRange(min_value=1, max_value=100, scaling_type=_sdk_sagemaker_types.HyperparameterScalingType.LOGARITHMIC),
                "max_leaves": IntegerParameterRange(min_value=0, max_value=5, scaling_type=_sdk_sagemaker_types.HyperparameterScalingType.LINEAR),
                "sketch_eps": ContinuousParameterRange(min_value=0.01, max_value=0.05, scaling_type=_sdk_sagemaker_types.HyperparameterScalingType.AUTO),
            }
        ),
        tuning_strategy=_sdk_sagemaker_types.HyperparameterTuningStrategy.BAYESIAN,
        tuning_objective=HyperparameterTuningObjective(
            objective_type=_sdk_sagemaker_types.HyperparameterTuningObjectiveType.MINIMIZE,
            metric_name="validation:error",
        ),
        training_job_early_stopping_type=_sdk_sagemaker_types.TrainingJobEarlyStoppingType.AUTO
    ).to_flyte_idl(),
}

# Register the HPO task under project "flyteexamples", domain "development",
# launch it with the inputs above, then print the console link for the execution.
exc = simple_xgboost_hpo_job_task.register_and_launch("flyteexamples", "development", inputs=hpo_inputs)
print_console_url(exc)

http://localhost:30081/console/projects/flyteexamples/domains/development/executions/fr6nyvb19q
