In [1]:
import os
import sys
import logging
import uuid
import kfp

from datetime import datetime
from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip
from kfp.v2 import compiler
from kfp.v2.dsl import component
from kfp.v2.google import experimental
from kfp.v2.google.client import AIPlatformClient

sys.path.append('pipelines')
from pipelines.pipeline import taxi_tip_predictor_pipeline

In [2]:
# Display the installed KFP SDK version so the notebook records which SDK
# produced the compiled pipeline (the saved output shows 1.6.3).
kfp.__version__

'1.6.3'

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# Environment settings shared by the rest of the notebook.
PROJECT = 'jk-mlops-dev'
STAGING_BUCKET = 'gs://jk-vertex-workshop-bucket'
REGION = 'us-central1'
# Built from PROJECT so that changing the project above cannot leave the
# pipeline service account pointing at a different project.
PIPELINES_SA = f'pipelines-sa@{PROJECT}.iam.gserviceaccount.com'


# Second-resolution timestamp (YYYYMMDDHHMMSS) captured when this cell runs.
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
TIMESTAMP

In [9]:
@component(base_image='gcr.io/ml-pipeline/google-cloud-pipeline-components:0.1.1')
def training_op(
    project: str,
    region: str,
    staging_bucket: str,
    training_table: str,
    validation_table: str,
    epochs: int,
    per_replica_batch_size: int =128,
    machine_type: str = 'n1-standard-4',
    accelerator_type: str = 'NVIDIA_TESLA_T4',
    accelerator_count: int = 1,
    container_image_uri: str = 'gcr.io/jk-mlops-dev/taxi_classifier_trainer'
):
    """Run the trainer container as a Vertex AI custom job and wait for it.

    A single-replica worker pool is created that executes ``python train.py``
    inside ``container_image_uri`` with the training arguments below passed
    as command-line flags.

    Args:
        project: Project passed to the Vertex AI SDK initialisation.
        region: Location passed to the Vertex AI SDK initialisation.
        staging_bucket: Base GCS path; the job stages under
            ``<staging_bucket>/<job_name>``.
        training_table: Forwarded to the trainer as ``--training_table``.
        validation_table: Forwarded to the trainer as ``--validation_table``.
        epochs: Forwarded to the trainer as ``--epochs``.
        per_replica_batch_size: Forwarded to the trainer as
            ``--per_replica_batch_size``.
        machine_type: Machine type of the single worker replica.
        accelerator_type: Accelerator type attached to the worker.
        accelerator_count: Number of accelerators attached to the worker.
        container_image_uri: Image that contains ``train.py``.
    """

    # Imported inside the function so the names resolve when the component
    # body executes in its own container rather than in the notebook kernel.
    import time
    from google.cloud import aiplatform as vertex_ai

    vertex_ai.init(
        project=project,
        location=region
    )

    worker_pool_specs =  [
        {
            "machine_spec": {
                "machine_type": machine_type,
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            },
            "replica_count": 1,
            "container_spec": {
                "image_uri": container_image_uri,
                "command": ["python", "train.py"],
                "args": [
                    '--epochs=' + str(epochs),
                    '--per_replica_batch_size=' + str(per_replica_batch_size),
                    '--training_table=' + training_table,
                    '--validation_table=' + validation_table,
                ],
            },
        }
    ]

    # The timestamped name doubles as the per-job staging sub-folder.
    job_name = "JOB_{}".format(time.strftime("%Y%m%d_%H%M%S"))

    job = vertex_ai.CustomJob(
        display_name=job_name,
        worker_pool_specs=worker_pool_specs,
        staging_bucket=f'{staging_bucket}/{job_name}'
    )

    # Block until the custom job finishes. With sync=False run() returns
    # before the job completes, so this pipeline step could not reflect
    # whether training actually succeeded.
    job.run(sync=True)





In [14]:
VERTEX_TRAINING_JOB_NAME = 'taxi-tip-predictor-training-job'
PIPELINE_NAME = 'taxi-tip-predictor-continuous-training'


# NOTE(review): this definition rebinds the `taxi_tip_predictor_pipeline`
# name that the first cell imports from `pipelines.pipeline`; from here on
# the notebook uses the inline version below.
@kfp.dsl.pipeline(name=PIPELINE_NAME)
def taxi_tip_predictor_pipeline(
    project: str,
    region: str,
    staging_bucket: str,
    epochs: int,
    per_replica_batch_size: int,
    training_table: str,
    validation_table: str
):
    """Single-step pipeline that launches the training component.

    Every pipeline parameter is forwarded unchanged to ``training_op``; the
    component's remaining arguments keep their defaults.
    """
    training_task = training_op(
        project=project,
        region=region,
        staging_bucket=staging_bucket,
        epochs=epochs,
        per_replica_batch_size=per_replica_batch_size,
        training_table=training_table,
        validation_table=validation_table,
    )

    
    

### Compile the pipeline

In [15]:
# Compile the pipeline function into a JSON job spec written to package_path;
# the submission cell later passes this file to the API client.
package_path = 'taxi_tip_predictor_pipeline.json'
compiler.Compiler().compile(
    package_path=package_path,
    pipeline_func=taxi_tip_predictor_pipeline,
)

### Submit a pipeline run

In [16]:
# Client bound to the project and region defined in the config cell.
api_client = AIPlatformClient(region=REGION, project_id=PROJECT)

In [17]:
# Run settings: where pipeline artifacts are written and the training inputs.
pipeline_root = f'{STAGING_BUCKET}/pipelines'
model_display_name = 'Taxi tip predictor'
training_container_image = 'gcr.io/jk-mlops-dev/taxi_classifier_trainer'
epochs = 3
per_replica_batch_size = 128
training_table = 'jk-mlops-dev.chicago_taxi_training.training_split'
validation_table = 'jk-mlops-dev.chicago_taxi_training.validation_split'

# Keys must match the parameter names of the pipeline function.
parameter_values = dict(
    project=PROJECT,
    region=REGION,
    staging_bucket=STAGING_BUCKET,
    epochs=epochs,
    per_replica_batch_size=per_replica_batch_size,
    training_table=training_table,
    validation_table=validation_table,
)

# Submit the compiled job spec; the run executes as the pipelines service account.
response = api_client.create_run_from_job_spec(
    package_path,
    service_account=PIPELINES_SA,
    parameter_values=parameter_values,
    pipeline_root=pipeline_root,
)