# An end-to-end Vertex HPO Pipeline
Highlights of this work compared with the vanilla HPO solutions:

1.https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/cad623ef84882f410fcc0dc39527be25a5e5f584/notebooks/community/ml_ops/stage3/get_started_with_hpt_pipeline_components.ipynb

2.https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/hyperparameter_tuning/distributed-hyperparameter-tuning.ipynb

1. Supports pipeline input data: a worker pool spec generator receives the output of the preprocessing step and passes it to the HPO job as an argument.
2. Supports multiple subgroups (e.g. multiple warehouses) at the same time.
3. Logs the best HPO params to Firestore.
4. Kicks off training after HPO completes.

Validated via the sentiment pipeline.
Please note that some params flowing through the pipeline are not actually consumed by each step; instead, the aim is to validate the chaining of the individual components.
One example is 'warehouse': it is not used to slice any data, as it would be in a real setting, because the sentiment-analysis data is used as the example. However, we do validate (print/log) that this param is correctly passed into the argument list of the relevant steps.

In [1]:
#!pip install kfp==1.8.11

Collecting kfp==1.8.11
  Downloading kfp-1.8.11.tar.gz (298 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m298.6/298.6 kB[0m [31m20.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting PyYAML<6,>=5.3
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m636.6/636.6 kB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting google-cloud-storage<2,>=1.20.0
  Downloading google_cloud_storage-1.44.0-py2.py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.8/106.8 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kubernetes<19,>=8.0.0
  Downloading kubernetes-18.20.0-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting google-api-python-client<2,>=1.7.8
  Downloading google_api_python

In [2]:
#!python3 -c "import kfp; print('KFP SDK version: {}'.format(kfp.__version__))"

KFP SDK version: 1.8.11


## Prerequisites

In [31]:
import os
import json
from functools import partial

import kfp
import pprint
import yaml
from jinja2 import Template
from kfp.v2 import dsl
from kfp.v2.compiler import compiler
from kfp.v2.dsl import Dataset
from kfp.v2.google.client import AIPlatformClient

from google.cloud import aiplatform, firestore
from datetime import datetime

In [2]:
# GCP project identifier (id and number) used by every cell below.
project_id, project_number = 'petcircle-science-playground', '734227425472'

In [3]:
# Artifact Registry repository that hosts the component/container images.
af_registry_location, af_registry_name = 'australia-southeast1', 'mlops-vertex-kit'

In [4]:
# Relative path (resolved against the kernel's working directory) to the
# folder holding one sub-directory per custom component.
components_dir = '../components/'

In [5]:
def _load_custom_component(project_id: str,
                           af_registry_location: str,
                           af_registry_name: str,
                           components_dir: str,
                           component_name: str):
    """Render a Jinja-templated component spec and load it as a KFP component.

    Reads ``<components_dir>/<component_name>/component.yaml.jinja``, renders
    it with the project / Artifact Registry coordinates, and hands the
    resulting YAML text to ``kfp.components.load_component_from_text``.

    Args:
        project_id: GCP project id substituted into the template.
        af_registry_location: Artifact Registry region substituted into the template.
        af_registry_name: Artifact Registry repository name substituted into the template.
        components_dir: Directory containing one sub-directory per component.
        component_name: Name of the component sub-directory to load.

    Returns:
        The component factory returned by ``kfp.components.load_component_from_text``.
    """
    component_path = os.path.join(components_dir,
                                  component_name,
                                  'component.yaml.jinja')
    # Explicit encoding so rendering does not depend on the platform's
    # default locale encoding.
    with open(component_path, 'r', encoding='utf-8') as f:
        template_text = f.read()

    component_text = Template(template_text).render(
        project_id=project_id,
        af_registry_location=af_registry_location,
        af_registry_name=af_registry_name)

    return kfp.components.load_component_from_text(component_text)


# Bind the notebook-level configuration so callers only pass ``component_name``.
load_custom_component = partial(_load_custom_component,
                                project_id=project_id,
                                af_registry_location=af_registry_location,
                                af_registry_name=af_registry_name,
                                components_dir=components_dir)

In [6]:
# Load each custom component by rendering its component.yaml.jinja template
# under components_dir (see load_custom_component above).
# NOTE(review): only preprocess_op and train_op are referenced by the
# pipelines in this notebook; the remaining ops are loaded but unused here.
preprocess_op = load_custom_component(component_name='data_preprocess')
train_op = load_custom_component(component_name='train_model')
check_metrics_op = load_custom_component(component_name='check_model_metrics')
create_endpoint_op = load_custom_component(component_name='create_endpoint')
test_endpoint_op = load_custom_component(component_name='test_endpoint')
deploy_model_op = load_custom_component(component_name='deploy_model')
monitor_model_op = load_custom_component(component_name='monitor_model')

Next, configure the pipeline settings (regions, GCS paths, container images); the pipeline itself is defined in the HPO pipeline section below:

In [7]:
# Region where the Vertex AI pipeline run executes, and the GCS root for its artifacts.
pipeline_region, pipeline_root = (
    'australia-southeast1',
    'gs://vertex_pipeline_demo_root_hy_syd/pipeline_root',
)

In [8]:
# Source data: region passed to the pipeline steps as data_region, and the
# BigQuery table with the product reviews.
data_region = 'australia-southeast1'
input_dataset_uri = 'bq://petcircle-science-playground.datalake.review_product_2013_2022'
# Training data columns as ';'-separated name:type pairs.
training_data_schema = 'reviewtext:string;Class:int'

# GCS destinations: processed training data, and the root for job outputs.
gcs_data_output_folder = 'gs://vertex_pipeline_demo_root_hy_syd/datasets/training'
data_pipeline_root = 'gs://vertex_pipeline_demo_root_hy_syd/compute_root'

In [9]:
# All three images live in the same Artifact Registry repository, differ only
# by image name, and use the ':latest' tag.
training_container_image_uri, serving_container_image_uri, hpo_container_image_uri = (
    f'{af_registry_location}-docker.pkg.dev/{project_id}/{af_registry_name}/{image_name}:latest'
    for image_name in ('training', 'batch_prediction', 'hpo')
)
# The project's default Compute Engine service account
# (<PROJECT_NUMBER>-compute@developer.gserviceaccount.com).
custom_job_service_account = f'{project_number}-compute@developer.gserviceaccount.com'

In [15]:
# Echo the derived image URIs and service account for a quick visual check.
(
    training_container_image_uri,
    serving_container_image_uri,
    custom_job_service_account,
    hpo_container_image_uri,
)

('australia-southeast1-docker.pkg.dev/petcircle-science-playground/mlops-vertex-kit/training:latest',
 'australia-southeast1-docker.pkg.dev/petcircle-science-playground/mlops-vertex-kit/batch_prediction:latest',
 '734227425472-compute@developer.gserviceaccount.com',
 'australia-southeast1-docker.pkg.dev/petcircle-science-playground/mlops-vertex-kit/hpo:latest')

In [16]:
from google_cloud_pipeline_components.experimental import hyperparameter_tuning_job
from google_cloud_pipeline_components.v1.hyperparameter_tuning_job import HyperparameterTuningJobRunOp
from google_cloud_pipeline_components.v1.model import ModelUploadOp
from google_cloud_pipeline_components.types import artifact_types
from kfp.v2.components import importer_node

## HPO pipeline

In [134]:
from kfp.v2.dsl import component
from kfp.v2.dsl import Dataset, Input, Metrics, Model, Output

@component
def worker_pool_specs(project_id: str,
                      data_region: str,
                      data_pipeline_root: str,
                      hpo_container_image_uri: str,
                      custom_job_service_account: str,
                      warehouse: str,
                      input_dataset: Input[Dataset]
                      ) -> list:
    """
    Pass the preprocessed data uri to HPO as a worker pool argument. The
    vanilla HPO API doesn't accept 'input data', so it is threaded through
    the trial container's command-line arguments instead.

    data_preprocess -> dataset.uri -> cmd args -> worker_pool_specs -> HPO

    Only ``hpo_container_image_uri``, ``warehouse`` and ``input_dataset`` are
    read; the other parameters are accepted for interface compatibility and
    are currently unused.

    Returns:
        A one-element worker pool spec list: a single n1-standard-4 replica
        running ``hpo_container_image_uri`` with ``--training_data_uri`` and
        ``--warehouse`` arguments.
    """
    cmd_args = [
        "--training_data_uri=" + str(input_dataset.uri),
        "--warehouse=" + warehouse,
    ]

    # The spec of the worker pool: machine type, replica count and container.
    return [
        {
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
            "replica_count": 1,
            "container_spec": {"image_uri": hpo_container_image_uri, "args": cmd_args},
        }
    ]

@component(packages_to_install=['google-cloud-firestore==2.3'])
def best_hpo_to_args(hpo_best: str,
                     project_id: str,
                     solution_name: str,
                     as_at_date: str,
                     warehouse: str) -> str:
    """
    Write the best HPO params to Firestore and return them as a JSON string.

    The best trial's parameter values and final metric values are merged into
    the Firestore document
    ``models/<solution_name>/HPO/<as_at_date>/<warehouse>/params``.
    The same values plus ``warehouse`` are returned so this component can be
    chained to the hpo_completion step.

    Args:
        hpo_best: Serialized best trial. Must contain ``parameters`` (items
            with ``parameterId``/``value``) and ``finalMeasurement.metrics``
            (items with ``metricId``/``value``).
        project_id: GCP project hosting the Firestore database.
        solution_name: Document id under the ``models`` collection.
        as_at_date: Document id under the ``HPO`` collection.
        warehouse: Sub-collection name for this warehouse's params.

    Returns:
        JSON object string mapping parameter and metric ids to their values,
        plus a ``"warehouse"`` key.
    """
    # Imports are inside the function because a KFP lightweight component
    # runs only this function's source inside its container.
    import ast
    import json
    from google.cloud import firestore

    # Parse strictly first. A blanket "'" -> '"' rewrite corrupts any value
    # containing an apostrophe, so it is kept only as a last resort for
    # single-quoted input that literal_eval cannot read (e.g. JSON true/null).
    try:
        best_trial = json.loads(hpo_best)
    except ValueError:
        try:
            best_trial = ast.literal_eval(hpo_best)
        except (ValueError, SyntaxError):
            best_trial = json.loads(hpo_best.replace("'", '"'))

    hpo_best_dict = {p['parameterId']: p['value'] for p in best_trial['parameters']}
    hpo_best_dict.update(
        {m['metricId']: m['value'] for m in best_trial['finalMeasurement']['metrics']})

    db = firestore.Client(project=project_id)
    db.collection("models").document(solution_name).collection("HPO").document(
        as_at_date).collection(warehouse).document("params").set(hpo_best_dict, merge=True)

    # json.dumps produces valid JSON even for apostrophes, booleans or None,
    # unlike str(dict) with quote replacement.
    return json.dumps({**hpo_best_dict, 'warehouse': warehouse})

@component
def hpo_completion(hpo_flags: list) -> str:
    """
    Fan-in step for the asynchronous per-warehouse HPO branches.

    The flags are not inspected. The step exists so that the downstream
    training step runs only after every HPO branch has finished, and therefore
    reads the latest params from Firestore for all warehouses.

    Returns:
        The constant string "true".
    """
    done = "true"
    return done

def hpo_warehouse(project_id,
                  data_region,
                  data_pipeline_root,
                  preprocess_task,
                  display_name,
                  metric_spec,
                  parameter_spec,
                  warehouse,
                  max_trial_count=4,
                  parallel_trial_count=2,
                  as_at_date=None
                  ):
    """
    Build the HPO sub-graph for a single warehouse.

    This is not a component function. It is a plain Python helper called while
    the pipeline is being defined, and it adds these tasks to the graph:

        worker_pool_specs -> HyperparameterTuningJobRunOp -> GetTrialsOp
        -> GetBestTrialOp -> best_hpo_to_args

    Args:
        project_id: GCP project for the HPO job and Firestore.
        data_region: Location in which the HPO job runs.
        data_pipeline_root: Base output directory of the HPO job.
        preprocess_task: Task whose ``output_dataset`` output feeds the trials.
        display_name: Prefix of the HPO job display name; also used as the
            Firestore solution name.
        metric_spec: Serialized study spec metrics.
        parameter_spec: Serialized study spec parameters.
        warehouse: Warehouse code, appended to the job name and used as the
            Firestore sub-collection.
        max_trial_count: Total number of HPO trials (default 4).
        parallel_trial_count: Number of trials run in parallel (default 2).
        as_at_date: Firestore date key (``YYYY-MM-DD``). Defaults to the local
            date at the time this function runs, i.e. when the pipeline is
            compiled, not when it executes.

    Returns:
        The best_hpo_to_args task, so callers can chain it to the
        hpo_completion step.

    NOTE(review): ``hpo_container_image_uri`` and ``custom_job_service_account``
    are read from notebook globals, not from the pipeline's parameters.
    """
    if as_at_date is None:
        as_at_date = datetime.now().strftime('%Y-%m-%d')

    worker_pool_specs_op = worker_pool_specs(
        project_id=project_id,
        data_region=data_region,
        data_pipeline_root=data_pipeline_root,
        hpo_container_image_uri=hpo_container_image_uri,
        custom_job_service_account=custom_job_service_account,
        warehouse=warehouse,
        input_dataset=preprocess_task.outputs['output_dataset'],
    )

    tuning_op = HyperparameterTuningJobRunOp(
        display_name=display_name + '-' + warehouse,
        project=project_id,
        location=data_region,
        worker_pool_specs=worker_pool_specs_op.output,
        study_spec_metrics=metric_spec,
        study_spec_parameters=parameter_spec,
        max_trial_count=max_trial_count,
        parallel_trial_count=parallel_trial_count,
        base_output_directory=data_pipeline_root,
        study_spec_algorithm='GRID_SEARCH',
    )

    trials_op = hyperparameter_tuning_job.GetTrialsOp(
        gcp_resources=tuning_op.outputs["gcp_resources"]
    )

    best_trial_op = hyperparameter_tuning_job.GetBestTrialOp(
        trials=trials_op.output, study_spec_metrics=metric_spec
    )

    return best_hpo_to_args(best_trial_op.output,
                            project_id=project_id,
                            as_at_date=as_at_date,
                            warehouse=warehouse,
                            solution_name=display_name)
    

@dsl.pipeline(name='hpo-pipeline-template')
def pipeline(project_id: str,
             data_region: str,
             gcs_data_output_folder: str,
             input_dataset_uri: str,
             training_data_schema: str,
             data_pipeline_root: str,

             training_container_image_uri: str,
             serving_container_image_uri: str,
             custom_job_service_account: str,
             hptune_region: str,
             hp_config_suggestions_per_request: int,
             hp_config_max_trials: int,

             metrics_name: str,
             metrics_threshold: float,

             endpoint_machine_type: str,
             endpoint_min_replica_count: int,
             endpoint_max_replica_count: int,
             endpoint_test_instances: str,

             output_model_file_name: str = 'model.h5',
             machine_type: str = "n1-standard-8",
             accelerator_count: int = 0,
             accelerator_type: str = 'ACCELERATOR_TYPE_UNSPECIFIED',
             vpc_network: str = "",
             enable_model_monitoring: str = 'False',
             task_type: str = 'training'):
    """
    End-to-end HPO pipeline.

    preprocess -> one HPO sub-graph per warehouse (EC, MEL) -> hpo_completion
    -> train_model (inside a dsl.Condition).

    ``input_dataset_uri``, ``metrics_*``, ``endpoint_*`` and
    ``enable_model_monitoring`` are accepted but not consumed by any step here.
    """
    display_name = 'hpo-pipeline-template'
    metric_spec = hyperparameter_tuning_job.serialize_metrics({"val_balanced_acc": "maximize"})
    parameter_spec = hyperparameter_tuning_job.serialize_parameters(
        {
            "batch_size": aiplatform.hyperparameter_tuning.DiscreteParameterSpec(
                values=[32, 64], scale=None
            ),
            "lr": aiplatform.hyperparameter_tuning.DiscreteParameterSpec(
                values=[0.0001, 0.0002], scale=None
            ),
        }
    )

    # Forward the pipeline's own ``task_type`` (default 'training') instead of
    # overwriting it with a local constant, which silently ignored the parameter.
    # NOTE(review): ``input_dataset_uri`` is not passed to preprocess_op —
    # confirm the component gets its source table some other way.
    preprocess_task = preprocess_op(
        project_id=project_id,
        data_region=data_region,
        gcs_output_folder=gcs_data_output_folder,
        gcs_output_format="CSV",
        task_type=task_type)

    # One independent HPO sub-graph per warehouse.
    hpo_op_ec = hpo_warehouse(project_id,
                              data_region,
                              data_pipeline_root,
                              preprocess_task,
                              display_name,
                              metric_spec,
                              parameter_spec,
                              "EC")

    hpo_op_mel = hpo_warehouse(project_id,
                               data_region,
                               data_pipeline_root,
                               preprocess_task,
                               display_name,
                               metric_spec,
                               parameter_spec,
                               "MEL")

    # Fan-in step. str() of a task output is only a placeholder string, so the
    # explicit .after() guarantees this step waits for both HPO branches.
    hpo_completion_op = hpo_completion([str(hpo_op_ec.output),
                                        str(hpo_op_mel.output)])
    hpo_completion_op.after(hpo_op_ec, hpo_op_mel)

    # Gate training on the fan-in step so it starts only after the HPO jobs
    # for every warehouse have finished.
    with dsl.Condition(hpo_completion_op.output == "true", name="train_model"):
        train_op(
            project_id=project_id,
            data_region=data_region,
            data_pipeline_root=data_pipeline_root,
            input_data_schema=training_data_schema,
            training_container_image_uri=training_container_image_uri,
            serving_container_image_uri=serving_container_image_uri,
            custom_job_service_account=custom_job_service_account,
            input_dataset=preprocess_task.outputs['output_dataset'],
            output_model_file_name=output_model_file_name,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
            hptune_region=hptune_region,
            hp_config_max_trials=hp_config_max_trials,
            hp_config_suggestions_per_request=hp_config_suggestions_per_request,
            vpc_network=vpc_network
        )

### Compile and run the end-to-end HPO pipeline
With the full pipeline defined, compile it and submit a pipeline run:

In [135]:
# Compile the pipeline into a job spec file.
compiler.Compiler().compile(
    pipeline_func=pipeline,
    package_path="training_pipeline_job.json",
)

api_client = AIPlatformClient(project_id=project_id, region=pipeline_region)

# Sample payloads passed as the endpoint_test_instances parameter.
test_instances = json.dumps([
    {'reviewtext': 'pet circle is not recommended', 'Class': '0'},
    {'reviewtext': 'pet circle is highly recommended', 'Class': '1'},
    {'reviewtext': 'think twice before you buy', 'Class': '0'},
    {'reviewtext': 'great product. will buy again.', 'Class': '1'},
])

pipeline_params = dict(
    project_id=project_id,
    data_region=data_region,
    gcs_data_output_folder=gcs_data_output_folder,
    output_model_file_name='model.h5',
    input_dataset_uri=input_dataset_uri,
    training_data_schema=training_data_schema,
    data_pipeline_root=data_pipeline_root,

    training_container_image_uri=training_container_image_uri,
    serving_container_image_uri=serving_container_image_uri,
    custom_job_service_account=custom_job_service_account,
    hptune_region="asia-east1",
    hp_config_suggestions_per_request=5,
    hp_config_max_trials=30,

    metrics_name='au_prc',
    metrics_threshold=0.3,

    endpoint_machine_type='n1-standard-4',
    endpoint_min_replica_count=1,
    endpoint_max_replica_count=1,
    endpoint_test_instances=test_instances,
)

# Submit the compiled spec; caching is disabled so every step re-executes.
response = api_client.create_run_from_job_spec(
    job_spec_path="training_pipeline_job.json",
    pipeline_root=pipeline_root,
    parameter_values=pipeline_params,
    enable_caching=False,
)

## Troubleshooting

In [132]:
from kfp.v2.dsl import component
from kfp.v2.dsl import Dataset, Input, Metrics, Model, Output

@component
def worker_pool_specs(project_id: str,
                      data_region: str,
                      data_pipeline_root: str,
                      hpo_container_image_uri: str,
                      custom_job_service_account: str,
                      warehouse: str,
                      ) -> list:
    """
    Troubleshooting variant of worker_pool_specs with no dataset input.

    The training data URI is hard-coded to a previously processed CSV, so the
    HPO container can be exercised without the data_preprocess step. Only
    ``hpo_container_image_uri`` and ``warehouse`` are read; the other
    parameters are kept for interface parity and are unused.

    Returns:
        A one-element worker pool spec list: a single n1-standard-4 replica
        running ``hpo_container_image_uri`` with ``--warehouse`` and
        ``--training_data_uri`` arguments.
    """
    cmd_args = [
        "--warehouse=" + warehouse,
        "--training_data_uri=gs://vertex_pipeline_demo_root_hy_syd/datasets/training/processed_data-20230118012508.csv",
    ]

    # The spec of the worker pool: machine type, replica count and container.
    return [
        {
            "machine_spec": {
                "machine_type": "n1-standard-4",
            },
            "replica_count": 1,
            "container_spec": {"image_uri": hpo_container_image_uri, "args": cmd_args},
        }
    ]

@component(packages_to_install=['google-cloud-firestore==2.3'])
def best_hpo_to_args(hpo_best: str,
                     project_id: str,
                     solution_name: str,
                     as_at_date: str,
                     warehouse: str) -> str:
    """
    Write the best HPO params to Firestore and return them as a JSON string.

    The best trial's parameter values and final metric values are merged into
    ``models/<solution_name>/HPO/<as_at_date>/<warehouse>/params``. The same
    values plus ``warehouse`` are returned for chaining.

    Args:
        hpo_best: Serialized best trial with ``parameters`` (items with
            ``parameterId``/``value``) and ``finalMeasurement.metrics``
            (items with ``metricId``/``value``).
        project_id: GCP project hosting the Firestore database.
        solution_name: Document id under the ``models`` collection.
        as_at_date: Document id under the ``HPO`` collection.
        warehouse: Sub-collection name for this warehouse's params.

    Returns:
        JSON object string of parameter/metric values plus ``"warehouse"``.
    """
    # Imports are inside the function because a KFP lightweight component
    # runs only this function's source inside its container.
    import ast
    import json
    from google.cloud import firestore

    # Parse strictly first; the blanket quote replacement breaks values that
    # contain apostrophes, so it is only the last fallback.
    try:
        best_trial = json.loads(hpo_best)
    except ValueError:
        try:
            best_trial = ast.literal_eval(hpo_best)
        except (ValueError, SyntaxError):
            best_trial = json.loads(hpo_best.replace("'", '"'))

    hpo_best_dict = {p['parameterId']: p['value'] for p in best_trial['parameters']}
    hpo_best_dict.update(
        {m['metricId']: m['value'] for m in best_trial['finalMeasurement']['metrics']})

    db = firestore.Client(project=project_id)
    db.collection("models").document(solution_name).collection("HPO").document(
        as_at_date).collection(warehouse).document("params").set(hpo_best_dict, merge=True)

    # json.dumps always yields valid JSON, unlike str(dict) with quote replacement.
    return json.dumps({**hpo_best_dict, 'warehouse': warehouse})

@component
def hpo_completion(hpo_flags: list) -> str:
    """Fan-in step: accepts the per-warehouse HPO outputs and always returns "true"."""
    result = "true"
    return result

def hpo_warehouse(project_id,
                  data_region,
                  data_pipeline_root,
                  display_name,
                  metric_spec,
                  parameter_spec,
                  warehouse,
                  gcp_resources,
                  as_at_date=None
                  ):
    """
    Troubleshooting variant of the per-warehouse HPO sub-graph.

    Instead of launching a new HyperparameterTuningJob, it reads the trials of
    an already-finished job identified by ``gcp_resources``, picks the best
    trial, and writes it to Firestore via ``best_hpo_to_args``.

    ``data_region``, ``data_pipeline_root`` and ``parameter_spec`` are kept
    for signature parity with the full variant but are not used here.

    Args:
        project_id: GCP project for Firestore.
        display_name: Used as the Firestore solution name.
        metric_spec: Serialized study spec metrics used to pick the best trial.
        warehouse: Warehouse code, used as the Firestore sub-collection.
        gcp_resources: Serialized gcp_resources JSON string pointing at an
            existing HyperparameterTuningJob.
        as_at_date: Firestore date key (``YYYY-MM-DD``). Defaults to the local
            date when this function runs, i.e. at pipeline compile time.

    Returns:
        The best_hpo_to_args task, for chaining to hpo_completion.
    """
    if as_at_date is None:
        as_at_date = datetime.now().strftime('%Y-%m-%d')

    trials_op = hyperparameter_tuning_job.GetTrialsOp(
        gcp_resources=gcp_resources
    )

    best_trial_op = hyperparameter_tuning_job.GetBestTrialOp(
        trials=trials_op.output, study_spec_metrics=metric_spec
    )

    return best_hpo_to_args(best_trial_op.output,
                            project_id=project_id,
                            as_at_date=as_at_date,
                            warehouse=warehouse,
                            solution_name=display_name)

    
@dsl.pipeline(name='hpo-pipeline-template')
def pipeline(project_id: str,
             data_region: str,
             gcs_data_output_folder: str,
             input_dataset_uri: str,
             training_data_schema: str,
             data_pipeline_root: str,

             training_container_image_uri: str,
             serving_container_image_uri: str,
             custom_job_service_account: str,
             hptune_region: str,
             hp_config_suggestions_per_request: int,
             hp_config_max_trials: int,

             metrics_name: str,
             metrics_threshold: float,

             endpoint_machine_type: str,
             endpoint_min_replica_count: int,
             endpoint_max_replica_count: int,
             endpoint_test_instances: str,

             output_model_file_name: str = 'model.h5',
             machine_type: str = "n1-standard-8",
             accelerator_count: int = 0,
             accelerator_type: str = 'ACCELERATOR_TYPE_UNSPECIFIED',
             vpc_network: str = "",
             enable_model_monitoring: str = 'False',
             task_type: str = 'training'):
    """
    Troubleshooting variant of the HPO pipeline.

    Each warehouse branch reads the trials of a previously finished
    HyperparameterTuningJob (hard-coded below) instead of launching a new one,
    writes the best params to Firestore, and training then runs as in the
    full pipeline.
    """
    display_name = 'hpo-pipeline-template'
    metric_spec = hyperparameter_tuning_job.serialize_metrics({"val_balanced_acc": "maximize"})
    parameter_spec = hyperparameter_tuning_job.serialize_parameters(
        {
            "batch_size": aiplatform.hyperparameter_tuning.DiscreteParameterSpec(
                values=[32, 64], scale=None
            ),
            "lr": aiplatform.hyperparameter_tuning.DiscreteParameterSpec(
                values=[0.0001, 0.0002], scale=None
            ),
        }
    )

    # Previously finished HPO jobs used as fixtures for each warehouse branch.
    ec_gcp_resources = '{"resources":[{"resourceType":"HyperparameterTuningJob","resourceUri":"https://australia-southeast1-aiplatform.googleapis.com/v1/projects/734227425472/locations/australia-southeast1/hyperparameterTuningJobs/4881167522302263296"}]}'
    mel_gcp_resources = '{"resources":[{"resourceType":"HyperparameterTuningJob","resourceUri":"https://australia-southeast1-aiplatform.googleapis.com/v1/projects/734227425472/locations/australia-southeast1/hyperparameterTuningJobs/2070921354823073792"}]}'

    # Forward the pipeline's own ``task_type`` instead of shadowing it.
    preprocess_task = preprocess_op(
        project_id=project_id,
        data_region=data_region,
        gcs_output_folder=gcs_data_output_folder,
        gcs_output_format="CSV",
        task_type=task_type)

    hpo_op_ec = hpo_warehouse(project_id,
                              data_region,
                              data_pipeline_root,
                              display_name,
                              metric_spec,
                              parameter_spec,
                              "EC",
                              ec_gcp_resources)

    hpo_op_mel = hpo_warehouse(project_id,
                               data_region,
                               data_pipeline_root,
                               display_name,
                               metric_spec,
                               parameter_spec,
                               "MEL",
                               mel_gcp_resources)

    # Fan-in step; .after() makes the dependency on both branches explicit.
    hpo_completion_op = hpo_completion([str(hpo_op_ec.output),
                                        str(hpo_op_mel.output)])
    hpo_completion_op.after(hpo_op_ec, hpo_op_mel)

    with dsl.Condition(hpo_completion_op.output == "true", name="train_model"):
        train_op(
            project_id=project_id,
            data_region=data_region,
            data_pipeline_root=data_pipeline_root,
            input_data_schema=training_data_schema,
            training_container_image_uri=training_container_image_uri,
            serving_container_image_uri=serving_container_image_uri,
            custom_job_service_account=custom_job_service_account,
            input_dataset=preprocess_task.outputs['output_dataset'],
            output_model_file_name=output_model_file_name,
            machine_type=machine_type,
            accelerator_count=accelerator_count,
            accelerator_type=accelerator_type,
            hptune_region=hptune_region,
            hp_config_max_trials=hp_config_max_trials,
            hp_config_suggestions_per_request=hp_config_suggestions_per_request,
            vpc_network=vpc_network
        )