In [67]:
import os
import sys
import logging
import uuid
import kfp

import kfp.v2.dsl as dsl

from datetime import datetime
from google.cloud import aiplatform
from google_cloud_pipeline_components import aiplatform as gcc_aip
from kfp.v2 import compiler
from kfp.v2.google.client import AIPlatformClient

from typing import NamedTuple


In [2]:
kfp.__version__

'1.6.3'

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# Notebook-wide configuration for the pipeline run.
# GCP project passed to the pipeline client and to the pipeline's `project` parameter.
PROJECT = 'jk-mlops-dev'
# GCS bucket under which the pipeline root (`<bucket>/pipelines`) is created.
STAGING_BUCKET = 'gs://jk-vertex-workshop-bucket'
# Region passed to the pipeline client and to the pipeline's `region` parameter.
REGION = 'us-central1'
# Service account passed as `service_account` when the pipeline run is submitted.
PIPELINES_SA = 'pipelines-sa@jk-mlops-dev.iam.gserviceaccount.com'

## Define custom components

In [79]:
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath)


### Data preparation component

In [105]:
@dsl.component(base_image='gcr.io/ml-pipeline/google-cloud-pipeline-components:0.1.1')
def prepare_data_splits_op(
    project: str,
    bq_location: str,
    sample_size: int,
    year: int,
    name_prefix: str,
    dataset: Output[Dataset]
):
    """Materializes the train / validation / test tables in BigQuery.

    Runs a multi-statement BigQuery script that selects trips for `year` from
    the public Chicago taxi trips table, derives the model features and the
    binary `tip_bin` label (1 when tips/fare >= 0.2), and writes one table per
    split into the `<name_prefix>_dataset` dataset of `project`.

    Args:
        project: GCP project that runs the query and owns the output dataset.
        bq_location: Location assigned to the output dataset when it is created.
        sample_size: Value substituted for the script's LIMIT clause.
        year: Trip start year used to filter the source table.
        name_prefix: Prefix for the dataset name and the three split tables.
        dataset: Output artifact; its metadata receives the fully qualified
            table names under the keys 'training_split', 'validation_split'
            and 'testing_split'.
    """

    # Imports live inside the function because the component body is executed
    # on its own inside the component's container image.
    import logging
    from google.cloud import bigquery
    from google.cloud import exceptions

    # Resource names derived from the prefix.
    dataset_name = f'{name_prefix}_dataset'
    train_split = f'{name_prefix}_train'
    valid_split = f'{name_prefix}_valid'
    test_split = f'{name_prefix}_test'

    # '@NAME' tokens are filled in by plain text substitution further down.
    query_template = '''
    CREATE TEMP TABLE features 
    AS (
        WITH
        taxitrips AS (
        SELECT
            FORMAT_DATETIME('%Y-%d-%m', trip_start_timestamp) AS date,
            trip_start_timestamp,
            trip_seconds,
            trip_miles,
            payment_type,
            pickup_longitude,
            pickup_latitude,
            dropoff_longitude,
            dropoff_latitude,
            tips,
            fare
        FROM
            `bigquery-public-data.chicago_taxi_trips.taxi_trips`
        WHERE 1=1 
        AND pickup_longitude IS NOT NULL
        AND pickup_latitude IS NOT NULL
        AND dropoff_longitude IS NOT NULL
        AND dropoff_latitude IS NOT NULL
        AND trip_miles > 0
        AND trip_seconds > 0
        AND fare > 0
        AND EXTRACT(YEAR FROM trip_start_timestamp) = @YEAR
        )

        SELECT
        trip_start_timestamp,
        EXTRACT(MONTH from trip_start_timestamp) as trip_month,
        EXTRACT(DAY from trip_start_timestamp) as trip_day,
        EXTRACT(DAYOFWEEK from trip_start_timestamp) as trip_day_of_week,
        EXTRACT(HOUR from trip_start_timestamp) as trip_hour,
        trip_seconds,
        trip_miles,
        payment_type,
        ST_AsText(
            ST_SnapToGrid(ST_GeogPoint(pickup_longitude, pickup_latitude), 0.1)
        ) AS pickup_grid,
        ST_AsText(
            ST_SnapToGrid(ST_GeogPoint(dropoff_longitude, dropoff_latitude), 0.1)
        ) AS dropoff_grid,
        ST_Distance(
            ST_GeogPoint(pickup_longitude, pickup_latitude), 
            ST_GeogPoint(dropoff_longitude, dropoff_latitude)
        ) AS euclidean,
        IF((tips/fare >= 0.2), 1, 0) AS tip_bin,
        CASE (ABS(MOD(FARM_FINGERPRINT(date),10))) 
            WHEN 9 THEN 'TEST'
            WHEN 8 THEN 'VALIDATE'
            ELSE 'TRAIN' END AS data_split
        FROM
        taxitrips
        LIMIT @LIMIT
    );

    CREATE OR REPLACE TABLE `@PROJECT.@DATASET.@TRAIN_SPLIT`
    AS
    SELECT * EXCEPT (trip_start_timestamp, data_split)
    FROM features
    WHERE data_split='TRAIN';

    CREATE OR REPLACE TABLE `@PROJECT.@DATASET.@VALIDATE_SPLIT`
    AS
    SELECT * EXCEPT (trip_start_timestamp, data_split)
    FROM features
    WHERE data_split='VALIDATE';

    CREATE OR REPLACE TABLE `@PROJECT.@DATASET.@TEST_SPLIT`
    AS
    SELECT * EXCEPT (trip_start_timestamp, data_split)
    FROM features
    WHERE data_split='TEST';

    DROP TABLE features;
    '''

    # Create the destination dataset; an already existing dataset is fine.
    bq_client = bigquery.Client(project=project)
    target_dataset = bigquery.Dataset(f'{project}.{dataset_name}')
    target_dataset.location = bq_location
    try:
        bq_client.create_dataset(target_dataset, timeout=30)
    except exceptions.Conflict:
        logging.info(f'Dataset {project}.{dataset_name} already exists')
    else:
        logging.info(f'Created dataset: {project}.{dataset_name}')

    # Substitutions are applied one after another, in exactly this order.
    substitutions = (
        ('@PROJECT', project),
        ('@DATASET', dataset_name),
        ('@TRAIN_SPLIT', train_split),
        ('@VALIDATE_SPLIT', valid_split),
        ('@TEST_SPLIT', test_split),
        ('@YEAR', str(year)),
        ('@LIMIT', str(sample_size)),
    )
    sql_script = query_template
    for placeholder, value in substitutions:
        sql_script = sql_script.replace(placeholder, value)

    # Block until the whole script has finished.
    bq_client.query(sql_script).result()

    # Publish the fully qualified table names for downstream components.
    split_tables = {
        'training_split': train_split,
        'validation_split': valid_split,
        'testing_split': test_split,
    }
    for metadata_key, table_name in split_tables.items():
        dataset.metadata[metadata_key] = f'{project}.{dataset_name}.{table_name}'

### Generate statistics

In [140]:
@dsl.component(base_image='tensorflow/tfx:latest')
def generate_stats_op(
    project: str,
    dataset: Input[Dataset],
    stats: Output[Artifact],
):
    """Generates TFDV statistics for the training split.

    Reads the table named by `dataset.metadata['training_split']` from
    BigQuery into a DataFrame, computes statistics with TensorFlow Data
    Validation and writes them in text format to `<stats.path>/train`.

    Args:
        project: GCP project used to run the BigQuery query.
        dataset: Input artifact whose metadata carries the 'training_split'
            table name.
        stats: Output artifact; the statistics file is written beneath its path.
    """

    # The component body runs on its own inside the container, so every module
    # it uses (including `os`) has to be imported here rather than at notebook
    # level.
    import os

    import tensorflow_data_validation as tfdv
    from google.cloud import bigquery

    training_split_name = dataset.metadata['training_split']

    # Backticks keep the fully qualified name valid even when the project id
    # contains dashes (same quoting the data preparation script uses).
    sql_script = f'''
    SELECT *
    FROM `{training_split_name}`
    '''

    client = bigquery.Client(project=project)
    df = client.query(sql_script).result().to_dataframe()

    # Use a distinct name for the result: rebinding `stats` would hide the
    # output artifact and break the `stats.path` lookup below.
    train_stats = tfdv.generate_statistics_from_dataframe(
        dataframe=df,
        stats_options=tfdv.StatsOptions(
            weight_feature=None,
            sample_rate=1,
            num_top_values=50
        )
    )

    # The artifact path is used as a directory; create it before writing into it.
    os.makedirs(stats.path, exist_ok=True)
    file_path = os.path.join(stats.path, 'train')
    tfdv.write_stats_text(train_stats, file_path)
    
    

### Validate statistics

In [142]:
@dsl.component(base_image='tensorflow/tfx:latest')
def validate_stats_op(
    project: str,
    stats: Input[Artifact],
    schema: Input[Artifact],
    anomalies: Output[Artifact],
   
):
    """Placeholder for validating statistics against a schema.

    The current body only prints the path of the incoming statistics
    artifact. `project` and `schema` are not read and nothing is written
    to `anomalies`.

    Args:
        project: GCP project id (currently unused).
        stats: Input artifact holding the generated statistics.
        schema: Input artifact holding the schema (currently unused).
        anomalies: Output artifact for detected anomalies (currently not
            written).
    """
    
    # Imported but not used yet by this stub.
    import tensorflow_data_validation as tfdv
    
    from google.cloud import bigquery
    
    # Only observable effect of the component at the moment.
    print(stats.path)

### Trainer component

In [143]:
@dsl.component(base_image='gcr.io/ml-pipeline/google-cloud-pipeline-components:0.1.1')
def train_op(
    project: str,
    region: str,
    epochs: int,
    per_replica_batch_size: int,
    machine_type: str,
    accelerator_type: str,
    accelerator_count: int,
    dataset: Input[Dataset],
    model: Output[Model],
) -> NamedTuple(
  'TrainOutputs',
  [
    ('artifacts_uri', str)
  ]):
    """Prepares and submits a Vertex AI Training custom container job.

    Args:
        project: GCP project in which the training job is created.
        region: Vertex AI location of the training job.
        epochs: Forwarded to the trainer as `--epochs`.
        per_replica_batch_size: Forwarded to the trainer as
            `--per_replica_batch_size`.
        machine_type: Machine type of the single training replica.
        accelerator_type: Accelerator type attached to the replica.
        accelerator_count: Number of accelerators attached to the replica.
        dataset: Input artifact whose metadata provides the 'training_split'
            and 'validation_split' table names.
        model: Output artifact; the job's output directory is derived from
            its path.

    Returns:
        A `TrainOutputs` namedtuple whose `artifacts_uri` field is
        `<base_output_dir>/model`.

    Raises:
        RuntimeError: If `model.path` does not start with '/gcs/'.
    """

    import time

    from collections import namedtuple
    from google.cloud import aiplatform as vertex_ai

    CONTAINER_IMAGE_URI = 'gcr.io/jk-mlops-dev/taxi_classifier_trainer'
    GCS_FUSE_PREFIX = '/gcs/'

    output = namedtuple('TrainOutputs', ['artifacts_uri'])

    # Set base_output_dir: turn '/gcs/<bucket>/<dirs>/<name>' into
    # 'gs://<bucket>/<dirs>' by swapping the prefix and dropping the last
    # path segment.
    if not model.path.startswith(GCS_FUSE_PREFIX):
        raise RuntimeError('Model dir must be a GCS location.')

    base_output_dir = 'gs://' + model.path[len(GCS_FUSE_PREFIX):].rsplit('/', 1)[0]

    # Prepare worker pool specification: one replica running the trainer image.
    worker_pool_specs = [
        {
            "machine_spec": {
                "machine_type": machine_type,
                "accelerator_type": accelerator_type,
                "accelerator_count": accelerator_count,
            },
            "replica_count": 1,
            "container_spec": {
                "image_uri": CONTAINER_IMAGE_URI,
                "command": ["python", "train.py"],
                "args": [
                    '--epochs=' + str(epochs),
                    '--per_replica_batch_size=' + str(per_replica_batch_size),
                    '--training_table=' + dataset.metadata['training_split'],
                    '--validation_table=' + dataset.metadata['validation_split'],
                ],
            },
        }
    ]

    # Submit the job and wait for it to finish.
    vertex_ai.init(
        project=project,
        location=region
    )

    job_name = "JOB_{}".format(time.strftime("%Y%m%d_%H%M%S"))
    job = vertex_ai.CustomJob(
        display_name=job_name,
        worker_pool_specs=worker_pool_specs,
        staging_bucket=base_output_dir
    )
    job.run(sync=True)

    # NOTE(review): assumes the trainer writes the model to
    # '<base_output_dir>/model' - confirm against the training container and
    # the SDK version in the base image.
    # The declared return type is a one-field namedtuple, so wrap the URI
    # instead of returning a bare string.
    return output(f'{base_output_dir}/model')


In [144]:
VERTEX_TRAINING_JOB_NAME = 'taxi-tip-predictor-training-job'
PIPELINE_NAME = 'taxi-tip-predictor-continuous-training'


@dsl.pipeline(name=PIPELINE_NAME)
def taxi_tip_predictor_pipeline(
    project: str,
    region: str,
    model_display_name: str,
    epochs: int,
    per_replica_batch_size: int,
    schema: str,
    machine_type: str = 'n1-standard-4',
    accelerator_type: str = 'NVIDIA_TESLA_T4',
    accelerator_count: int = 1,
    bq_location: str = 'US',
    year: int = 2020,
    sample_size: int = 1000000,
    name_prefix: str = 'chicago_taxi_tips',
    serving_container_image: str = "us-docker.pkg.dev/cloud-aiplatform/prediction/tf2-cpu.2-4:latest"
):
    """Continuous training pipeline for the taxi tip predictor."""

    # Bring the schema file at URI `schema` into the pipeline as an artifact.
    schema_importer = kfp.dsl.importer(
        artifact_uri=schema,
        artifact_class=Artifact,
        reimport=False,
    )

    # Create the BigQuery data splits.
    data_splits_task = prepare_data_splits_op(
        project=project,
        bq_location=bq_location,
        sample_size=sample_size,
        year=year,
        name_prefix=name_prefix,
    )
    splits_artifact = data_splits_task.outputs['dataset']

    # Compute statistics on the splits and check them against the schema.
    stats_task = generate_stats_op(
        project=project,
        dataset=splits_artifact,
    )

    validate_stats_op(
        project=project,
        schema=schema_importer.output,
        stats=stats_task.outputs['stats'],
    )

    # Train on the prepared splits.
    training_task = train_op(
        project=project,
        region=region,
        dataset=splits_artifact,
        epochs=epochs,
        per_replica_batch_size=per_replica_batch_size,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerator_count,
    )

    # Upload the trained model artifacts with the serving container image.
    gcc_aip.ModelUploadOp(
        project=project,
        display_name=model_display_name,
        artifact_uri=training_task.outputs['artifacts_uri'],
        serving_container_image_uri=serving_container_image,
    )
    

### Compile the pipeline

In [145]:
# Compile the pipeline function into a JSON job spec written to `package_path`;
# the same path is handed to the client when the run is submitted.
package_path = 'taxi_tip_predictor_pipeline.json'
compiler.Compiler().compile(
    pipeline_func=taxi_tip_predictor_pipeline,
    package_path=package_path
)

### Submit a pipeline run

In [146]:
# Client used to submit the compiled pipeline job spec, bound to the
# project and region set in the configuration cell.
api_client = AIPlatformClient(
    project_id=PROJECT,
    region=REGION,
)

In [147]:
# Root GCS location handed to the run as `pipeline_root`.
pipeline_root = f'{STAGING_BUCKET}/pipelines'
model_display_name = 'Taxi tip predictor'
# NOTE(review): `training_container_image`, `training_table` and
# `validation_table` are defined here but are not included in
# `parameter_values` below - confirm whether they are still needed.
training_container_image = 'gcr.io/jk-mlops-dev/taxi_classifier_trainer'
epochs = 3
per_replica_batch_size = 128
training_table = 'jk-mlops-dev.chicago_taxi_training.training_split'
validation_table = 'jk-mlops-dev.chicago_taxi_training.validation_split'
# GCS URI of the schema file, passed as the pipeline's `schema` parameter.
schema = 'gs://jk-vertex-workshop-bucket/schema/schema.pbtxt'


# Values for the pipeline parameters that have no default in the pipeline
# signature; the remaining parameters fall back to their declared defaults.
parameter_values = {
    'project': PROJECT,
    'region': REGION,
    'model_display_name': model_display_name,
    'epochs': epochs,
    'schema': schema,
    'per_replica_batch_size': per_replica_batch_size,
}

# Submit the compiled job spec, running as the PIPELINES_SA service account.
response = api_client.create_run_from_job_spec(
    package_path,
    pipeline_root=pipeline_root,
    parameter_values=parameter_values,
    service_account=PIPELINES_SA
)