# Training pipeline for TFRS  2tower model 

When completed you should have a pipeline that looks like this:

![](img/train-pipeline-sp-e2e.png)

#### Setps performed
1. Create custom components for training and parallel vocabulary adapts
2. Save master vocabulary and add a managed tensorboard to monitor training
3. Create pipeline with blend of custom and built-in components
4. Export/Import models to registry, deploy to endpoints
5. Create Matching Engine Endpoint as well as ANN, Brute Force indexes to test recall/latency tradeoff
6. Perform final tests on entire deployment

### pip

In [1]:
# ! pip3 install --upgrade --user -q google-cloud-aiplatform
# ! pip3 install --upgrade --user -q google-cloud-storage
# ! pip3 install --upgrade --user -q kfp
# ! pip3 install --upgrade --user -q google-cloud-pipeline-components

In [2]:
! python3 -c "import kfp; print('KFP SDK version: {}'.format(kfp.__version__))"
! python3 -c "import google_cloud_pipeline_components; print('google_cloud_pipeline_components version: {}'.format(google_cloud_pipeline_components.__version__))"
! python3 -c "import google.cloud.aiplatform; print('aiplatform SDK version: {}'.format(google.cloud.aiplatform.__version__))"

KFP SDK version: 1.8.20
google_cloud_pipeline_components version: 1.0.42
aiplatform SDK version: 1.26.1


## Load env config

In [3]:
# naming convention for all cloud resources
VERSION        = "v1"                  # TODO
PREFIX         = f'ndr-{VERSION}'      # TODO

print(f"PREFIX = {PREFIX}")

PREFIX = ndr-v1


In [4]:
# staging GCS
GCP_PROJECTS             = !gcloud config get-value project
PROJECT_ID               = GCP_PROJECTS[0]

# GCS bucket and paths
BUCKET_NAME              = f'{PREFIX}-{PROJECT_ID}-bucket'
BUCKET_URI               = f'gs://{BUCKET_NAME}'

config = !gsutil cat {BUCKET_URI}/config/notebook_env.py
print(config.n)
exec(config.n)


PROJECT_ID               = "hybrid-vertex"
PROJECT_NUM              = "934903580331"
LOCATION                 = "us-central1"

REGION                   = "us-central1"
BQ_LOCATION              = "US"
VPC_NETWORK_NAME         = "ucaip-haystack-vpc-network"

VERTEX_SA                = "934903580331-compute@developer.gserviceaccount.com"

PREFIX                   = "ndr-v1"
VERSION                  = "v1"

APP                      = "sp"
MODEL_TYPE               = "2tower"
FRAMEWORK                = "tfrs"
DATA_VERSION             = "v1"
TRACK_HISTORY            = "5"

BUCKET_NAME              = "ndr-v1-hybrid-vertex-bucket"
BUCKET_URI               = "gs://ndr-v1-hybrid-vertex-bucket"
SOURCE_BUCKET            = "spotify-million-playlist-dataset"

DATA_GCS_PREFIX          = "data"
DATA_PATH                = "gs://ndr-v1-hybrid-vertex-bucket/data"
VOCAB_SUBDIR             = "vocabs"
VOCAB_FILENAME           = "vocab_dict.pkl"

CANDIDATE_PREFIX         = "candidates"
TRAIN_DIR_PREFIX      

In [5]:
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

In [6]:
import json
from datetime import datetime
from time import time
import pandas as pd
import time
from pprint import pprint
import pickle as pkl

import logging
logging.disable(logging.WARNING)

from google.cloud import aiplatform as vertex_ai
from google.cloud import storage

# Pipelines
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from google_cloud_pipeline_components import aiplatform as gcc_aip
from google_cloud_pipeline_components.types import artifact_types

# Kubeflow SDK
# TODO: fix these
from kfp.v2 import dsl
import kfp
import kfp.v2.dsl
from kfp.v2.google import client as pipelines_client
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component)

# import modules
from util import feature_set_utils as feature_utils
from util import test_instances

storage_client = storage.Client(project=PROJECT_ID)

vertex_ai.init(project=PROJECT_ID,location=REGION)

In [7]:
PIPELINE_VERSION = 'pipe-v2'       # TODO

In [8]:
MODEL_ROOT_NAME = f'{APP}-{MODEL_TYPE}-{FRAMEWORK}-{VERSION}-{PIPELINE_VERSION}'
print(f"MODEL_ROOT_NAME: {MODEL_ROOT_NAME}")

MODEL_ROOT_NAME: sp-2tower-tfrs-v1-pipe-v2


# Pipeline Components

In [9]:
# os.getcwd()

In [10]:
REPO_SRC = 'src'
PIPELINES_SUB_DIR = 'train_pipes'

In [11]:
! rm -rf {REPO_SRC}/{PIPELINES_SUB_DIR}
! mkdir {REPO_SRC}/{PIPELINES_SUB_DIR}

## Create Tensorboard

In [12]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/create_tensorboard.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image='python:3.9',
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'numpy',
        'google-cloud-storage',
    ],
    # output_component_file="./pipelines/train_custom_model.yaml",
)
def create_tensorboard(
    project: str,
    location: str,
    model_version: str,
    pipeline_version: str,
    model_name: str, 
    experiment_name: str,
    experiment_run: str,
) -> NamedTuple('Outputs', [
    ('tensorboard_resource_name', str),
    ('tensorboard_display_name', str),
]):
    
    import logging
    from google.cloud import aiplatform as vertex_ai
    from google.cloud import storage
    
    vertex_ai.init(
        project=project,
        location=location,
        # experiment=experiment_name,
    )
    
    logging.info(f'experiment_name: {experiment_name}')
    
    # # create new TB instance
    TENSORBOARD_DISPLAY_NAME=f"{experiment_name}-v1"
    tensorboard = vertex_ai.Tensorboard.create(display_name=TENSORBOARD_DISPLAY_NAME, project=project, location=location)
    TB_RESOURCE_NAME = tensorboard.resource_name
    
    logging.info(f'TENSORBOARD_DISPLAY_NAME: {TENSORBOARD_DISPLAY_NAME}')
    logging.info(f'TB_RESOURCE_NAME: {TB_RESOURCE_NAME}')
    
    return (
        f'{TB_RESOURCE_NAME}',
        f'{TENSORBOARD_DISPLAY_NAME}',
    )

Writing src/train_pipes/create_tensorboard.py


## Custom train job

In [13]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/train_custom_model.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image='python:3.9',
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        # 'tensorflow==2.9.2',
        # 'tensorflow-recommenders==0.7.0',
        'numpy',
        'google-cloud-storage',
    ],
    # output_component_file="./pipelines/train_custom_model.yaml",
)
def train_custom_model(
    project: str,
    location: str,
    model_version: str,
    pipeline_version: str,
    model_name: str, 
    worker_pool_specs: dict,
    # vocab_dict_uri: str, 
    train_output_gcs_bucket: str,                         # change to workdir?
    training_image_uri: str,
    tensorboard_resource_name: str,
    service_account: str,
    experiment_name: str,
    experiment_run: str,
    generate_new_vocab: bool,
) -> NamedTuple('Outputs', [
    ('job_dict_uri', str),
    ('query_tower_dir_uri', str),
    ('candidate_tower_dir_uri', str),
    ('experiment_run_dir', str),
]):
    
    import logging
    import numpy as np
    import pickle as pkl
    
    from google.cloud import aiplatform as vertex_ai
    # import google.cloud.aiplatform_v1beta1 as aip_beta
    from google.cloud import storage
    
    vertex_ai.init(
        project=project,
        location=location,
        experiment=experiment_name,
    )
    
    storage_client = storage.Client(project=project)
    
    JOB_NAME = f'train-{model_name}'
    logging.info(f'JOB_NAME: {JOB_NAME}')
    
    BASE_OUTPUT_DIR = f'gs://{train_output_gcs_bucket}/{experiment_name}/{experiment_run}'
    logging.info(f'BASE_OUTPUT_DIR: {BASE_OUTPUT_DIR}')
    
    # logging.info(f'vocab_dict_uri: {vocab_dict_uri}')
    
    logging.info(f'tensorboard_resource_name: {tensorboard_resource_name}')
    logging.info(f'service_account: {service_account}')
    logging.info(f'worker_pool_specs: {worker_pool_specs}')
    
    # ====================================================
    # Launch Vertex job
    # ====================================================
    
    worker_pool_specs[0]['container_spec']['args'].append(f'--tb_resource_name={tensorboard_resource_name}')
    
    if generate_new_vocab == 'True':
        worker_pool_specs[0]['container_spec']['args'].append(f'--new_vocab')
  
    job = vertex_ai.CustomJob(
        display_name=JOB_NAME,
        worker_pool_specs=worker_pool_specs,
        base_output_dir=BASE_OUTPUT_DIR,
        staging_bucket=f"{BASE_OUTPUT_DIR}/staging",
    )
    
    logging.info(f'Submitting train job to Vertex AI...')
    
    job.run(
        tensorboard=tensorboard_resource_name,
        service_account=f'{service_account}',
        restart_job_on_worker_restart=False,
        enable_web_access=True,
        sync=False,
    )
        
    # wait for job to complete
    job.wait()
    
    # ====================================================
    # Save job details
    # ====================================================
    
    train_job_dict = job.to_dict()
    logging.info(f'train_job_dict: {train_job_dict}')
    
    # pkl dict to GCS
    logging.info(f"Write pickled dict to GCS...")
    TRAIN_DICT_LOCAL = f'train_job_dict.pkl'
    TRAIN_DICT_GCS_OBJ = f'{experiment_name}/{experiment_run}/{TRAIN_DICT_LOCAL}' # destination folder prefix and blob name
    
    logging.info(f"TRAIN_DICT_LOCAL: {TRAIN_DICT_LOCAL}")
    logging.info(f"TRAIN_DICT_GCS_OBJ: {TRAIN_DICT_GCS_OBJ}")

    # pickle
    filehandler = open(f'{TRAIN_DICT_LOCAL}', 'wb')
    pkl.dump(train_job_dict, filehandler)
    filehandler.close()
    
    # upload to GCS
    bucket_client = storage_client.bucket(train_output_gcs_bucket)
    blob = bucket_client.blob(TRAIN_DICT_GCS_OBJ)
    blob.upload_from_filename(TRAIN_DICT_LOCAL)
    
    job_dict_uri = f'gs://{train_output_gcs_bucket}/{TRAIN_DICT_GCS_OBJ}'
    logging.info(f"{TRAIN_DICT_LOCAL} uploaded to {job_dict_uri}")
    
    # ====================================================
    # Model and index artifact uris
    # ====================================================
    EXPERIMENT_RUN_DIR = f"gs://{train_output_gcs_bucket}/{experiment_name}/{experiment_run}"
    query_tower_dir_uri = f"{EXPERIMENT_RUN_DIR}/model-dir/query_model" 
    candidate_tower_dir_uri = f"{EXPERIMENT_RUN_DIR}/model-dir/candidate_model"
    # candidate_index_dir_uri = f"gs://{output_dir_gcs_bucket_name}/{experiment_name}/{experiment_run}/candidate_model"
    
    logging.info(f'query_tower_dir_uri: {query_tower_dir_uri}')
    logging.info(f'candidate_tower_dir_uri: {candidate_tower_dir_uri}')
    # logging.info(f'candidate_index_dir_uri: {candidate_index_dir_uri}')
    
    return (
        f'{job_dict_uri}',
        f'{query_tower_dir_uri}',
        f'{candidate_tower_dir_uri}',
        f'{EXPERIMENT_RUN_DIR}',
    )

Writing src/train_pipes/train_custom_model.py


## Generate Candidates

In [14]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/generate_candidates.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'tensorflow==2.11.0',
        'tensorflow-recommenders==0.7.2',
        'numpy',
        # 'google-cloud-storage',
    ],
)
def generate_candidates(
    project: str,
    location: str,
    version: str, 
    # emb_index_gcs_uri: str,
    candidate_tower_dir_uri: str,
    candidate_file_dir_bucket: str,
    candidate_file_dir_prefix: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    experiment_run_dir: str,
) -> NamedTuple('Outputs', [
    ('emb_index_gcs_uri', str),
    # ('emb_index_artifact', Artifact),
]):
    import logging
    import json
    import pickle as pkl
    from pprint import pprint
    import time
    import numpy as np

    import os

    # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob

    import google.cloud.aiplatform as vertex_ai
    
    # set clients
    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)

    # tf.Data confg
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    
    # ====================================================
    # Load trained candidate tower
    # ====================================================
    logging.info(f"candidate_tower_dir_uri: {candidate_tower_dir_uri}")
    
    loaded_candidate_model = tf.saved_model.load(candidate_tower_dir_uri)
    logging.info(f"loaded_candidate_model.signatures: {loaded_candidate_model.signatures}")
    
    candidate_predictor = loaded_candidate_model.signatures["serving_default"]
    logging.info(f"structured_outputs: {candidate_predictor.structured_outputs}")
    
    # ===================================================
    # set feature vars
    # ===================================================
    FEATURES_PREFIX = f'{experiment_name}/{experiment_run}/features'
    logging.info(f"FEATURES_PREFIX: {FEATURES_PREFIX}")
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ===================================================
    # load pickled Candidate features
    # ===================================================
    
    # candidate features
    CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
    CAND_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{CAND_FEAT_FILENAME}'
    LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'
    
    loaded_candidate_features_dict = download_blob(
        train_output_gcs_bucket,
        CAND_FEAT_GCS_OBJ,
        LOADED_CANDIDATE_DICT
    )
    
    # ====================================================
    # Features and Helper Functions
    # ====================================================
    
    def parse_candidate_tfrecord_fn(example):
        """
        Reads candidate serialized examples from gcs and converts to tfrecord
        """
        # example = tf.io.parse_single_example(
        example = tf.io.parse_example(
            example, 
            features=loaded_candidate_features_dict
        )
        return example

    def full_parse(data):
        # used for interleave - takes tensors and returns a tf.dataset
        data = tf.data.TFRecordDataset(data)
        return data
    
    # ====================================================
    # Create Candidate Dataset
    # ====================================================

    candidate_files = []
    for blob in storage_client.list_blobs(f"{candidate_file_dir_bucket}", prefix=f'{candidate_file_dir_prefix}/'):
        if '.tfrecords' in blob.name:
            candidate_files.append(blob.public_url.replace("https://storage.googleapis.com/", "gs://"))

    candidate_dataset = tf.data.Dataset.from_tensor_slices(candidate_files)

    parsed_candidate_dataset = candidate_dataset.interleave(
        # lambda x: tf.data.TFRecordDataset(x),
        full_parse,
        cycle_length=tf.data.AUTOTUNE, 
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False
    ).map(parse_candidate_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE).with_options(options)

    parsed_candidate_dataset = parsed_candidate_dataset.cache() #400 MB on machine mem
    
    # ====================================================
    # Generate embedding vectors for each candidate
    # ====================================================
    logging.info("Starting candidate dataset mapping...")
    
    start_time = time.time()
    
    embs_iter = parsed_candidate_dataset.batch(10000).map(
        lambda data: (
            data["track_uri_can"],
            loaded_candidate_model(data)
        )
    )
    
    embs = []
    for emb in embs_iter:
        embs.append(emb)

    end_time = time.time()
    elapsed_time = int((end_time - start_time) / 60)
    logging.info(f"elapsed_time   : {elapsed_time}")
    logging.info(f"Length of embs : {len(embs)}")
    logging.info(f"embeddings[0]  : {embs[0]}")
    
    # ====================================================
    # prep Track IDs and Vectors for JSON
    # ====================================================
    logging.info("Cleaning embeddings and track IDs...")
    start_time = time.time()
    
    # cleaned_embs = [x['output_1'].numpy()[0] for x in embs] #clean up the output
    
    cleaned_embs = []
    track_uris = []
    
    for ids , embedding in embs:
        cleaned_embs.extend(embedding.numpy())
        track_uris.extend(ids.numpy())
    
    end_time = time.time()
    elapsed_time = int((end_time - start_time) / 60)
    logging.info(f"elapsed_time           : {elapsed_time}")
    logging.info(f"Length of cleaned_embs : {len(cleaned_embs)}")
    logging.info(f"Length of track_uris: {len(track_uris)}")
    
    track_uris_decoded = [z.decode("utf-8") for z in track_uris]
    logging.info(f"Length of track_uris decoded: {len(track_uris_decoded)}")
    logging.info(f"track_uris_decoded[0]       : {track_uris_decoded[0]}")
    
    # check for bad records
    bad_records = []

    for i, emb in enumerate(cleaned_embs):
        bool_emb = np.isnan(emb)
        for val in bool_emb:
            if val:
                bad_records.append(i)

    bad_record_filter = np.unique(bad_records)

    logging.info(f"bad_records: {len(bad_records)}")
    logging.info(f"bad_record_filter: {len(bad_record_filter)}")
    
    # ZIP together
    logging.info("Zipping IDs and vectors ...")
    
    track_uris_valid = []
    emb_valid = []

    for i, pair in enumerate(zip(track_uris_decoded, cleaned_embs)):
        if i in bad_record_filter:
            pass
        else:
            t_uri, embed = pair
            track_uris_valid.append(t_uri)
            emb_valid.append(embed)
            
    logging.info(f"track_uris_valid[0]: {track_uris_valid[0]}")
    logging.info(f"bad_records: {len(bad_records)}")
            
    # ====================================================
    # writting JSON file to GCS
    # ====================================================
    TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
    embeddings_index_filename = f'candidate_embs.json'

    with open(f'{embeddings_index_filename}', 'w') as f:
        for prod, emb in zip(track_uris_valid, emb_valid):
            f.write('{"id":"' + str(prod) + '",')
            f.write('"embedding":[' + ",".join(str(x) for x in list(emb)) + "]}")
            f.write("\n")
            
    # write to GCS
    INDEX_GCS_URI = f'{experiment_run_dir}/candidate-embeddings-{TIMESTAMP}'
    logging.info(f"INDEX_GCS_URI: {INDEX_GCS_URI}")

    DESTINATION_BLOB_NAME = embeddings_index_filename
    SOURCE_FILE_NAME = embeddings_index_filename

    logging.info(f"DESTINATION_BLOB_NAME: {DESTINATION_BLOB_NAME}")
    logging.info(f"SOURCE_FILE_NAME: {SOURCE_FILE_NAME}")
    
    blob = Blob.from_string(os.path.join(INDEX_GCS_URI, DESTINATION_BLOB_NAME))
    blob.bucket._client = storage_client
    blob.upload_from_filename(SOURCE_FILE_NAME)
    
    return (
        f'{INDEX_GCS_URI}',
        # f'{INDEX_GCS_URI}',
    )

Writing src/train_pipes/generate_candidates.py


## Create ANN Index

In [15]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/create_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0'
        # 'google-cloud-storage',
    ],
)
def create_ann_index(
    project: str,
    location: str,
    version: str, 
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    ann_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    leaf_node_embedding_count: int,
    leaf_nodes_to_search_percent: int, 
    ann_index_description: str,
    # ann_index_labels: Dict, 
) -> NamedTuple('Outputs', [
    ('ann_index_resource_uri', str),
    ('ann_index', Artifact),
]):
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    VERSION = version.replace('_', '-')
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    
    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    NETWORK_NAME = vpc_network_name
    INDEX_DIR_GCS = emb_index_gcs_uri
    PARENT = "projects/{}/locations/{}".format(project, location)

    logging.info(f"ENDPOINT: {ENDPOINT}")
    logging.info(f"project: {project}")
    logging.info(f"location: {location}")
    logging.info(f"INDEX_DIR_GCS: {INDEX_DIR_GCS}")
    
    display_name = f'{ann_index_display_name}-{VERSION}'
    
    logging.info(f"display_name: {display_name}")
    
    # ==============================================================================
    # Create Index 
    # ==============================================================================

    start = time.time()
        
    tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
        display_name=display_name,
        contents_delta_uri=f'{emb_index_gcs_uri}', # emb_index_gcs_uri,
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        leaf_node_embedding_count=leaf_node_embedding_count,
        leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
        description=ann_index_description,
        # labels=ann_index_labels,
        sync=True,
    )

    end = time.time()
    elapsed_time = round((end - start), 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')
    
    ann_index_resource_uri = tree_ah_index.resource_name
    logging.info("ann_index_resource_uri:", ann_index_resource_uri) 

    return (
      f'{ann_index_resource_uri}',
      tree_ah_index,
    )

Writing src/train_pipes/create_ann_index.py


## Create brute force index

In [16]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/create_brute_force_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0',
        # 'google-cloud-storage',
    ],
)
def create_brute_force_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    brute_force_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    brute_force_index_description: str,
    # brute_force_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('brute_force_index_resource_uri', str),
    ('brute_force_index', Artifact),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    VERSION = version.replace('_', '-')
    
    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    NETWORK_NAME = vpc_network_name
    INDEX_DIR_GCS = emb_index_gcs_uri
    PARENT = "projects/{}/locations/{}".format(project, location)

    logging.info("ENDPOINT: {}".format(ENDPOINT))
    logging.info("PROJECT_ID: {}".format(project))
    logging.info("REGION: {}".format(location))
    
    display_name = f'{brute_force_index_display_name}_{VERSION}'
    
    logging.info(f"display_name: {display_name}")
    
    # ==============================================================================
    # Create Index 
    # ==============================================================================

    start = time.time()
    
    brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
        display_name=display_name,
        contents_delta_uri=f'{emb_index_gcs_uri}', # emb_index_gcs_uri,
        dimensions=dimensions,
        # approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        description=brute_force_index_description,
        # labels=brute_force_index_labels,
        sync=True,
    )
    brute_force_index_resource_uri = brute_force_index.resource_name
    print("brute_force_index_resource_uri:",brute_force_index_resource_uri) 

    return (
      f'{brute_force_index_resource_uri}',
      brute_force_index,
    )

Writing src/train_pipes/create_brute_force_index.py


## Create ANN index endpoint

In [17]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/create_ann_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0',
    ],
)
def create_ann_index_endpoint_vpc(
    ann_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    ann_index_endpoint_display_name: str,
    ann_index_endpoint_description: str,
    ann_index_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('ann_index_endpoint_resource_uri', str),
    ('ann_index_endpoint', Artifact),
    ('ann_index_endpoint_display_name', str),
    ('ann_index_resource_uri', str),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{ann_index_endpoint_display_name}',
        description=ann_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    ann_index_endpoint_resource_uri = ann_index_endpoint.resource_name
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")

    return (
        f'{vpc_network_resource_uri}',
        f'{ann_index_endpoint_resource_uri}',
        ann_index_endpoint,
        f'{ann_index_endpoint_display_name}',
        f'{ann_index_resource_uri}',
    )

Writing src/train_pipes/create_ann_index_endpoint_vpc.py


## Create brute force index endpoint

In [18]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/create_brute_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0',
    ],
)
def create_brute_index_endpoint_vpc(
    bf_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    brute_index_endpoint_display_name: str,
    brute_index_endpoint_description: str,
    brute_force_index_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('brute_index_endpoint_resource_uri', str),
    ('brute_index_endpoint', Artifact),
    ('brute_index_endpoint_display_name', str),
    ('brute_force_index_resource_uri', str),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    brute_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{brute_index_endpoint_display_name}',
        description=brute_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    brute_index_endpoint_resource_uri = brute_index_endpoint.resource_name
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    return (
      f'{vpc_network_resource_uri}',
      f'{brute_index_endpoint_resource_uri}',
      brute_index_endpoint,
      f'{brute_index_endpoint_display_name}',
      f'{brute_force_index_resource_uri}',
    )

Writing src/train_pipes/create_brute_index_endpoint_vpc.py


## Deploy ANN Index

In [19]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/deploy_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0',
    ]
)
def deploy_ann_index(
    project: str,
    location: str,
    version: str,
    deployed_ann_index_name: str,
    ann_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('ann_index_resource_uri', str),
    ('deployed_ann_index_name', str),
    ('deployed_ann_index', Artifact),
]):
  
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    # define vars
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    # deployed_ann_index_name = deployed_ann_index_name.replace('-', '_')
    # logging.info(f"deployed_ann_index_name: {deployed_ann_index_name}")
    
    DEPLOYED_INDEX_NAME = f'{deployed_ann_index_name}-{TIMESTAMP}'
    logging.info(f"DEPLOYED_INDEX_NAME: {DEPLOYED_INDEX_NAME}")
    
    # init index
    ann_index = vertex_ai.MatchingEngineIndex(
      index_name=ann_index_resource_uri
    )
    ann_index_resource_uri = ann_index.resource_name
    logging.info(f"ann_index_resource_uri: {ann_index_resource_uri}")

    # init index endpoint
    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(
      index_endpoint_resource_uri
    )
    logging.info(f"index_endpoint: {index_endpoint}")

    # deploy index to endpoint
    index_endpoint = index_endpoint.deploy_index(
      index=ann_index, 
      deployed_index_id=DEPLOYED_INDEX_NAME
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")
    INDEX_ID = index_endpoint.deployed_indexes[0].id
    logging.info(f"INDEX_ID: {INDEX_ID}")

    return (
      f'{index_endpoint_resource_uri}',
      f'{ann_index_resource_uri}',
      f'{deployed_ann_index_name}',
      ann_index,
    )

Writing src/train_pipes/deploy_ann_index.py


## Deploy brute force Index

In [20]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/deploy_brute_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-api-core==2.11.0',
    ],
)
def deploy_brute_index(
    project: str,
    location: str,
    version: str,
    deployed_brute_force_index_name: str,
    brute_force_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('brute_force_index_resource_uri', str),
    ('deployed_brute_force_index_name', str),
    ('deployed_brute_force_index', Artifact),
]):
  
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    # define vars
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    # deployed_brute_force_index_name = deployed_brute_force_index_name.replace('-', '_')
    # logging.info(f"deployed_brute_force_index_name: {deployed_brute_force_index_name}")
    
    DEPLOYED_INDEX_NAME = f'{deployed_brute_force_index_name}-{TIMESTAMP}'
    logging.info(f"DEPLOYED_INDEX_NAME: {DEPLOYED_INDEX_NAME}")

    # init index
    brute_index = vertex_ai.MatchingEngineIndex(
        index_name=brute_force_index_resource_uri
    )
    brute_force_index_resource_uri = brute_index.resource_name
    logging.info(f"brute_force_index_resource_uri: {brute_force_index_resource_uri}")

    # init index endpoint
    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(index_endpoint_resource_uri)
    logging.info(f"index_endpoint: {index_endpoint}")

    # deploy index to endpoint
    index_endpoint = index_endpoint.deploy_index(
        index=brute_index, 
        deployed_index_id=DEPLOYED_INDEX_NAME
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")
    INDEX_ID = index_endpoint.deployed_indexes[0].id
    logging.info(f"INDEX_ID: {INDEX_ID}")

    return (
      f'{index_endpoint_resource_uri}',
      f'{brute_force_index_resource_uri}',
      f'{deployed_brute_force_index_name}', #-{TIMESTAMP}',
      brute_index,
    )

Writing src/train_pipes/deploy_brute_index.py


## Model monitoring job

In [21]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/model_monitoring_config.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-cloud-pipeline-components',
        'google-cloud-storage',
        'tensorflow==2.11.0',
        'numpy'
    ],
)
def model_monitoring_config(
    project: str,
    location: str,
    version: str,
    prefix: str,
    emails: str,
    train_output_gcs_bucket: str,
    # feature_dict: dict, # TODO
    bq_dataset: str,
    bq_train_table: str,
    experiment_name: str,
    experiment_run: str,
    endpoint: str,
):
    # TODO - imports
    
    import logging
    from datetime import datetime
    import time
    import numpy as np
    import pickle as pkl
    
    logging.getLogger().setLevel(logging.INFO)
    # google cloud SDKs
    from google.cloud import storage
    from google.cloud import aiplatform as vertex_ai
    from google.cloud.aiplatform import model_monitoring
    
    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob
    
    from google.protobuf import json_format
    from google.protobuf.json_format import Parse
    from google.protobuf.struct_pb2 import Value
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources
    
    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)
    
    # ====================================================
    # helper functions
    # ====================================================
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")
    
    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")
    
    _endpoint_uri = "/".join(_endpoint_resource.split("/")[-8:-2])
    logging.info(f"_endpoint_uri = {_endpoint_uri}")
    
    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info(f"_endpoint defined")
    
    
    USER_EMAILS = [emails]
    alert_config = model_monitoring.EmailAlertConfig(USER_EMAILS, enable_logging=True)
    
    MONITOR_INTERVAL = 1
    schedule_config = model_monitoring.ScheduleConfig(monitor_interval=MONITOR_INTERVAL)
    
    SAMPLE_RATE = 0.8

    logging_sampling_strategy = model_monitoring.RandomSampleConfig(sample_rate=SAMPLE_RATE)
    
    # ===================================================
    # feature dict
    # ===================================================
    QUERY_FILENAME = 'query_feats_dict.pkl'
    # FEATURES_PREFIX = f'{experiment_name}/{experiment_run}/features'
    GCS_PATH_TO_BLOB = f'{experiment_name}/{experiment_run}/features/{QUERY_FILENAME}'
    
    loaded_feat_dict = download_blob(
        bucket_name=train_output_gcs_bucket,
        source_gcs_obj=GCS_PATH_TO_BLOB,
        local_filename=QUERY_FILENAME
    )
    logging.info(f'loaded_feat_dict: {loaded_feat_dict}')
    
    filehandler = open(QUERY_FILENAME, 'rb')
    FEAT_DICT = pkl.load(filehandler)
    filehandler.close()

    
    feature_names = list(FEAT_DICT.keys())

    # =========================== #
    ##   Feature value drift     ##
    # =========================== #
    DRIFT_THRESHOLD_VALUE = 0.05
    ATTRIBUTION_DRIFT_THRESHOLD_VALUE = 0.05
    
    drift_thresholds = dict()

    for feature in feature_names:
        if feature in drift_thresholds:
            print("feature name already in dict")
        else:
            drift_thresholds[feature] = DRIFT_THRESHOLD_VALUE

    logging.info(f"drift_thresholds      : {drift_thresholds}\n")
    
    drift_config = model_monitoring.DriftDetectionConfig(
        drift_thresholds=drift_thresholds,
        # attribute_drift_thresholds=attr_drift_thresholds,
    )

    # =========================== #
    ##   Feature value skew      ##
    # =========================== #
    TRAIN_DATA_SOURCE_URI = f"bq://{project}.{bq_dataset}.{bq_train_table}"
    logging.info(f"TRAIN_DATA_SOURCE_URI = {TRAIN_DATA_SOURCE_URI}")
    
    SKEW_THRESHOLD_VALUE = 0.05
    ATTRIBUTION_SKEW_THRESHOLD_VALUE = 0.05
    
    skew_thresholds = dict()

    for feature in feature_names:
        if feature in skew_thresholds:
            logging.info("feature name already in dict")
        else:
            skew_thresholds[feature] = SKEW_THRESHOLD_VALUE        
    logging.info(f"skew_thresholds      : {skew_thresholds}\n")
    
    # skew config
    skew_config = model_monitoring.SkewDetectionConfig(
        data_source=TRAIN_DATA_SOURCE_URI,
        # data_format = TRAIN_DATA_FORMAT, # only used if source in GCS
        skew_thresholds=skew_thresholds,
        # attribute_skew_thresholds=attribute_skew_thresholds,
        # target_field=TARGET, # no target; embedding model
    )
    
    # ====================================================
    # objective_config
    # ====================================================
    objective_config = model_monitoring.ObjectiveConfig(
        skew_detection_config=skew_config,
        drift_detection_config=drift_config,
        explanation_config=None,
    )
    
    # ====================================================
    # launch monitoring_job
    # ====================================================
    
    JOB_DISPLAY_NAME = f"mm_pipe_{experiment_run}_{prefix}"
    logging.info(f"JOB_DISPLAY_NAME: {JOB_DISPLAY_NAME}")

    monitoring_job = vertex_ai.ModelDeploymentMonitoringJob.create(
        display_name=JOB_DISPLAY_NAME,
        project=project,
        location=location,
        endpoint=_endpoint,
        logging_sampling_strategy=logging_sampling_strategy,
        schedule_config=schedule_config,
        alert_config=alert_config,
        objective_configs=objective_config,
    )
    
    logging.info(f"monitoring_job: {monitoring_job.resource_name}")

Writing src/train_pipes/model_monitoring_config.py


## Test query model endpoint

In [22]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/test_model_endpoint.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-cloud-pipeline-components',
        'google-cloud-storage',
        'tensorflow==2.11.0',
        'numpy'
    ],
)
def test_model_endpoint(
    project: str,
    location: str,
    version: str,
    train_output_gcs_bucket: str,
    many_test_instances_gcs_filename: str,
    experiment_name: str,
    experiment_run: str,
    endpoint: str, # Input[Artifact],
    # feature_dict: dict,
    # metrics: Output[Metrics],
):
    
    import logging
    from datetime import datetime
    import time
    import numpy as np
    import pickle as pkl
    
    from google.cloud import aiplatform as vertex_ai
    
    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob
    
    from google.protobuf import json_format
    from google.protobuf.json_format import Parse
    from google.protobuf.struct_pb2 import Value
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

    import tensorflow as tf

    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)
    
    # ====================================================
    # helper functions
    # ====================================================
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ===================================================
    # load test instance
    # ===================================================
    LOCAL_INSTANCE_FILE = 'test_instance_list.pkl'
    GCS_PATH_TO_BLOB = f'{experiment_name}/{experiment_run}/{many_test_instances_gcs_filename}'
    LOADED_TEST_LIST = f'loaded_{LOCAL_INSTANCE_FILE}'
    
    loaded_test_instance = download_blob(
        bucket_name=train_output_gcs_bucket,
        source_gcs_obj=GCS_PATH_TO_BLOB,
        local_filename=LOADED_TEST_LIST
    )
    logging.info(f'loaded_test_instance: {loaded_test_instance}')
    
    filehandler = open(LOADED_TEST_LIST, 'rb')
    LIST_OF_DICTS = pkl.load(filehandler)
    filehandler.close()
    
    logging.info(f'len(LIST_OF_DICTS): {len(LIST_OF_DICTS)}')
    
    # LIST_OF_DICTS[200]
    
    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")
    
    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")
    
    _endpoint_uri = "/".join(_endpoint_resource.split("/")[-8:-2])
    logging.info(f"_endpoint_uri = {_endpoint_uri}")
    
    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info(f"_endpoint defined")
    
    
    # ====================================================
    # Send predictions
    # ====================================================
    # TOTAL_ROUNDS = 4
    SLEEP_SECONDS = 2 
    START=1
    END=4

    logging.info(f"testing online endpoint for {END} rounds")
    
    for i in range(START, END+1):
        
        count = 0

        for test in LIST_OF_DICTS:
            response = _endpoint.predict(instances=[test])

            if count > 0 and count % 250 == 0:
                logging.info(f"{count} prediciton requests..")

            count += 1
            
        logging.info(f"finsihed round {i} of {END}")
        time.sleep(SLEEP_SECONDS)
        
    logging.info(f"endpoint test complete - {count} predictions sent")

Writing src/train_pipes/test_model_endpoint.py


## Send skewed traffic

In [23]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/send_skewed_traffic.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-cloud-pipeline-components',
        'google-cloud-storage',
        'tensorflow==2.11.0',
        'numpy'
    ],
)
def send_skewed_traffic(
    project: str,
    location: str,
    version: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    endpoint: str, # Input[Artifact],
    # feature_dict: dict,
    # metrics: Output[Metrics],
):
    
    import logging
    from datetime import datetime
    import time
    import numpy as np
    import pickle as pkl
    
    from google.cloud import aiplatform as vertex_ai
    
    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob
    
    from google.protobuf import json_format
    from google.protobuf.json_format import Parse
    from google.protobuf.struct_pb2 import Value
    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

    import tensorflow as tf

    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)
    
    # ====================================================
    # helper functions
    # ====================================================
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")
    
    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")
    
    _endpoint_uri = "/".join(_endpoint_resource.split("/")[-8:-2])
    logging.info(f"_endpoint_uri = {_endpoint_uri}")
    
    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info(f"_endpoint defined")
    
    # ===================================================
    # load test instance
    # ===================================================
    LOCAL_INSTANCE_FILE = 'test_instance_list.pkl'
    GCS_PATH_TO_BLOB = f'{experiment_name}/{experiment_run}/{many_test_instances_gcs_filename}'
    LOADED_TEST_LIST = f'loaded_{LOCAL_INSTANCE_FILE}'
    
    loaded_test_instance = download_blob(
        bucket_name=train_output_gcs_bucket,
        source_gcs_obj=GCS_PATH_TO_BLOB,
        local_filename=LOADED_TEST_LIST
    )
    logging.info(f'loaded_test_instance: {loaded_test_instance}')
    
    filehandler = open(LOADED_TEST_LIST, 'rb')
    LIST_OF_DICTS = pkl.load(filehandler)
    filehandler.close()
    
    logging.info(f'len(LIST_OF_DICTS): {len(LIST_OF_DICTS)}')
    
    # ====================================================
    # load skew features stats
    # ====================================================
    SKEW_FEATURES_STATS_FILE = 'skew_feat_stats.pkl'
    GCS_PATH_TO_BLOB = f'{experiment_name}/{experiment_run}/{SKEW_FEATURES_STATS_FILE}'
    LOADED_SKEW_FEATURES_STATS_FILE = f"loaded_{SKEW_FEATURES_STATS_FILE}"
    logging.info(f'loading: {LOADED_SKEW_FEATURES_STATS_FILE}')
    
    loaded_skew_test_instance = download_blob(
        bucket_name=train_output_gcs_bucket,
        source_gcs_obj=GCS_PATH_TO_BLOB,
        local_filename=LOADED_SKEW_FEATURES_STATS_FILE
    )
    logging.info(f'loaded_skew_test_instance: {loaded_skew_test_instance}')
    
    filehandler_v2 = open(LOADED_SKEW_FEATURES_STATS_FILE, 'rb')
    SKEW_FEATURES = pkl.load(filehandler_v2)
    filehandler_v2.close()
    
    mean_durations, std_durations = SKEW_FEATURES['pl_duration_ms_new']
    mean_num_songs, std_num_songs = SKEW_FEATURES['num_pl_songs_new']
    mean_num_artists, std_num_artists = SKEW_FEATURES['num_pl_artists_new']
    mean_num_albums, std_num_albums = SKEW_FEATURES['num_pl_albums_new']
    
    logging.info(f"std_durations   : {round(std_durations, 0)}")
    logging.info(f"std_num_songs   : {round(std_num_songs, 0)}")
    logging.info(f"std_num_artists : {round(std_num_artists, 0)}")
    logging.info(f"std_num_albums  : {round(std_num_albums, 0)}\n")
    
    def monitoring_test(endpoint, instances, skew_feat_stat, start=2, end=4):

        mean_durations, std_durations = skew_feat_stat['pl_duration_ms_new']
        mean_num_songs, std_num_songs = skew_feat_stat['num_pl_songs_new']
        mean_num_artists, std_num_artists = skew_feat_stat['num_pl_artists_new']
        mean_num_albums, std_num_albums = skew_feat_stat['num_pl_albums_new']
        
        logging.info(f"std_durations   : {round(std_durations, 0)}")
        logging.info(f"std_num_songs   : {round(std_num_songs, 0)}")
        logging.info(f"std_num_artists : {round(std_num_artists, 0)}")
        logging.info(f"std_num_albums  : {round(std_num_albums, 0)}\n")

        total_preds = 0

        for multiplier in range(start, end+1):

            print(f"multiplier: {multiplier}")

            pred_count = 0

            for example in instances:
                list_dict = {}

                example['pl_duration_ms_new'] = round(std_durations * multiplier, 0)
                example['num_pl_songs_new'] = round(std_num_songs * multiplier, 0)
                example['num_pl_artists_new'] = round(std_num_artists * multiplier, 0)
                example['num_pl_albums_new'] = round(std_num_albums * multiplier, 0)
                # list_of_skewed_instances.append(example)

                response = endpoint.predict(instances=[example])

                if pred_count > 0 and pred_count % 250 == 0:
                    print(f"pred_count: {pred_count}")

                pred_count += 1
                total_preds += 1

            logging.info(f"sent {pred_count} pred requests with {multiplier}X multiplier")

        logging.info(f"sent {total_preds} total pred requests")
        
    # send skewed traffic
    monitoring_test(
        endpoint=_endpoint, 
        instances=LIST_OF_DICTS,
        skew_feat_stat=SKEW_FEATURES,
        start=2, 
        end=8
    )

Writing src/train_pipes/send_skewed_traffic.py


## Test index recall

In [24]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/test_model_index_endpoint.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (
    Artifact, Dataset, Input, InputPath, 
    Model, Output, OutputPath, component, Metrics
)
@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.26.1',
        'google-cloud-pipeline-components',
        'google-cloud-storage',
        'tensorflow==2.11.0',
        'numpy'
    ],
)
def test_model_index_endpoint(
    project: str,
    location: str,
    version: str,
    train_output_gcs_bucket: str,
    test_instances_gcs_filename: str,
    experiment_name: str,
    experiment_run: str,
    # train_dir: str,
    # train_dir_prefix: str,
    # ann_index_resource_uri: str,
    ann_index_endpoint_resource_uri: str,
    brute_index_endpoint_resource_uri: str,
    gcs_train_script_path: str,
    endpoint: str, # Input[Artifact],
    metrics: Output[Metrics],
):
    
    import logging
    from datetime import datetime
    import time
    import numpy as np
    import pickle as pkl
    
    import base64

    from typing import Dict, List, Union

    from google.cloud import aiplatform as vertex_ai
    
    from google.protobuf import json_format
    from google.protobuf.json_format import Parse
    from google.protobuf.struct_pb2 import Value

    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources
    
    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob

    import tensorflow as tf

    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)
    
    # ====================================================
    # helper functions
    # ====================================================
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")
    
    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")
    
    _endpoint_uri = "/".join(_endpoint_resource.split("/")[-8:-2])
    logging.info(f"_endpoint_uri = {_endpoint_uri}")
    
    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info(f"_endpoint defined")
    
    # ==============================================================
    # helper function for returning endpoint predictions via json
    # ==============================================================
    
    def predict_custom_trained_model_sample(
        project: str,
        endpoint_id: str,
        instances: Dict,
        location: str = "us-central1",
        api_endpoint: str = "us-central1-aiplatform.googleapis.com",
    ):
        """
        either single instance of type dict or a list of instances.
        This client only needs to be created once, and can be reused for multiple requests.
        """

        # The AI Platform services require regional API endpoints.
        client_options = {"api_endpoint": api_endpoint}
        
        # Initialize client that will be used to create and send requests.
        client = vertex_ai.gapic.PredictionServiceClient(client_options=client_options)
        
        # The format of each instance should conform to the deployed model's prediction input schema.
        instances = instances if type(instances) == list else [instances]
        instances = [
            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
        ]
        
        parameters_dict = {}
        parameters = json_format.ParseDict(parameters_dict, Value())
        
        endpoint = client.endpoint_path(
            project=project, location=location, endpoint=endpoint_id
        )
        
        response = client.predict(
            endpoint=endpoint, instances=instances, parameters=parameters
        )
        logging.info(f'Response: {response}')
        logging.info(f'Deployed Model ID(s): {response.deployed_model_id}')
        # The predictions are a google.protobuf.Value representation of the model's predictions.
        _predictions = response.predictions
        logging.info(f'Response Predictions: {_predictions}')
        
        return _predictions
    
    # ===================================================
    # load test instance
    # ===================================================
    LOCAL_TEST_INSTANCE = 'test_instances_dict.pkl'
    GCS_PATH_TO_BLOB = f'{experiment_name}/{experiment_run}/{test_instances_gcs_filename}'
    LOADED_CANDIDATE_DICT = f'loaded_{LOCAL_TEST_INSTANCE}'
    
    loaded_test_instance = download_blob(
        bucket_name=train_output_gcs_bucket,
        source_gcs_obj=GCS_PATH_TO_BLOB,
        local_filename=LOADED_CANDIDATE_DICT
    )
    logging.info(f'loaded_test_instance: {loaded_test_instance}')
    
    # make prediction request
    _endpoint_id = _endpoint_uri.split('/')[-1]
    logging.info(f"_endpoint_id created = {_endpoint_id}")
    prediction_test = predict_custom_trained_model_sample(
        project=project,                     
        endpoint_id=_endpoint_id,
        location="us-central1",
        instances=loaded_test_instance
    )
    
    # ===================================================
    # Matching Engine
    # ===================================================
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    deployed_ann_index = vertex_ai.MatchingEngineIndexEndpoint(ann_index_endpoint_resource_uri)
    deployed_bf_index = vertex_ai.MatchingEngineIndexEndpoint(brute_index_endpoint_resource_uri)

    DEPLOYED_ANN_ID = deployed_ann_index.deployed_indexes[0].id
    DEPLOYED_BF_ID = deployed_bf_index.deployed_indexes[0].id
    logging.info(f"DEPLOYED_ANN_ID: {DEPLOYED_ANN_ID}")
    logging.info(f"DEPLOYED_BF_ID: {DEPLOYED_BF_ID}")
    
    logging.info('Retreiving neighbors from ANN index...')
    
    start = time.time()
    ANN_response = deployed_ann_index.match(
        deployed_index_id=DEPLOYED_ANN_ID,
        queries=prediction_test,
        num_neighbors=10
    )
    elapsed_ann_time = time.time() - start
    elapsed_ann_time = round(elapsed_ann_time, 4)
    logging.info(f'ANN latency: {elapsed_ann_time} seconds')
    
    logging.info('Retreiving neighbors from BF index...')
    
    start = time.time()
    BF_response = deployed_bf_index.match(
        deployed_index_id=DEPLOYED_BF_ID,
        queries=prediction_test,
        num_neighbors=10
    )
    
    elapsed_bf_time = time.time() - start
    elapsed_bf_time = round(elapsed_bf_time, 4)
    logging.info(f'Bruteforce latency: {elapsed_bf_time} seconds')
    
    # =========================================================
    # Calculate recall by determining how many neighbors 
    # correctly retrieved as compared to the brute-force option
    # =========================================================
    recalled_neighbors = 0
    for tree_ah_neighbors, brute_force_neighbors in zip(
        ANN_response, BF_response
    ):
        tree_ah_neighbor_ids = [neighbor.id for neighbor in tree_ah_neighbors]
        brute_force_neighbor_ids = [neighbor.id for neighbor in brute_force_neighbors]

        recalled_neighbors += len(
            set(tree_ah_neighbor_ids).intersection(brute_force_neighbor_ids)
        )

    recall = recalled_neighbors / len(
        [neighbor for neighbors in BF_response for neighbor in neighbors]
    )
    
    # =========================================================
    # Metrics
    # =========================================================
    reduction = (elapsed_bf_time - elapsed_ann_time) / elapsed_bf_time*100.00
    increase  = (elapsed_bf_time - elapsed_ann_time)/elapsed_ann_time*100.00
    faster    = elapsed_bf_time / elapsed_ann_time

    logging.info(f"reduction in time         : {round(reduction, 3)}%")
    logging.info(f"% increase in performance : {round(increase, 3)}%")
    logging.info(f"how many times faster     : {round(faster, 3)}x faster")

    logging.info("Recall: {}".format(recall * 100.0))
    
    metrics.log_metric("Recall", (recall * 100.0))
    # metrics.log_metric("elapsed_query_time", elapsed_query_time)
    metrics.log_metric("elapsed_ann_time", elapsed_ann_time)
    metrics.log_metric("elapsed_bf_time", elapsed_bf_time)
    metrics.log_metric("latency_reduction", reduction)
    metrics.log_metric("perf_increase", increase)
    metrics.log_metric("x_faster", faster)

Writing src/train_pipes/test_model_index_endpoint.py


## Compute config for pipeline steps

In [25]:
%%writefile {REPO_SRC}/{PIPELINES_SUB_DIR}/pipeline_config.py

CPU_LIMIT='96'
MEMORY_LIMIT='624G'

Writing src/train_pipes/pipeline_config.py


# Prepare Job Specs

## Accelerators and Device Strategy

In [26]:
from src.two_tower_jt import train_utils

gpu_dict = train_utils.get_accelerator_config(
    key='t4', 
    worker_machine_type = 'n1-highmem-16',
    reduction_n=0
)

WORKER_MACHINE_TYPE            = gpu_dict['WORKER_MACHINE_TYPE']
REPLICA_COUNT                  = gpu_dict['REPLICA_COUNT']
ACCELERATOR_TYPE               = gpu_dict['ACCELERATOR_TYPE']
PER_MACHINE_ACCELERATOR_COUNT  = gpu_dict['PER_MACHINE_ACCELERATOR_COUNT']
DISTRIBUTE_STRATEGY            = gpu_dict['DISTRIBUTE_STRATEGY']
REDUCTION_SERVER_COUNT         = gpu_dict['REDUCTION_SERVER_COUNT']
REDUCTION_SERVER_MACHINE_TYPE  = gpu_dict['REDUCTION_SERVER_MACHINE_TYPE'] 

WORKER_MACHINE_TYPE            : n1-highmem-16
REPLICA_COUNT                  : 1
ACCELERATOR_TYPE               : NVIDIA_TESLA_T4
PER_MACHINE_ACCELERATOR_COUNT  : 1
DISTRIBUTE_STRATEGY            : single
REDUCTION_SERVER_COUNT         : 0
REDUCTION_SERVER_MACHINE_TYPE  : n1-highcpu-16


## Vertex AI Experiments

In [27]:
EXPERIMENT_PREFIX = 'tfrs-pipe'                     # custom identifier for organizing experiments
EXPERIMENT_NAME=f'{EXPERIMENT_PREFIX}-{VERSION}'
# RUN_NAME = f'run-{time.strftime("%Y%m%d-%H%M%S")}'

RUN_NAME = "run-20230926-120445"

print(f"EXPERIMENT_NAME: {EXPERIMENT_NAME}")
print(f"RUN_NAME: {RUN_NAME}")

EXPERIMENT_NAME: tfrs-pipe-v1
RUN_NAME: run-20230926-120445


## Training Config

* see [src code](https://github.com/googleapis/python-aiplatform/blob/e7bf0d83d8bb0849a9bce886c958d13f5cbe5fab/google/cloud/aiplatform/utils/worker_spec_utils.py#L153) for worker_pool_spec

In [28]:
# =================================================
# trainconfig: gcs locations
# =================================================

# Stores pipeline executions for each run
PIPELINE_ROOT_PATH = f'gs://{BUCKET_NAME}/{EXPERIMENT_NAME}/{RUN_NAME}/pipeline_root'
print('PIPELINE_ROOT_PATH: {}'.format(PIPELINE_ROOT_PATH))

PIPELINE_ROOT_PATH: gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root


### Feature lists

In [29]:
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
import tensorflow as tf

FEATURES_PREFIX = f'{EXPERIMENT_NAME}/{RUN_NAME}/features'
print(f"FEATURES_PREFIX: {FEATURES_PREFIX}")

FEATURES_PREFIX: tfrs-pipe-v1/run-20230926-120445/features


#### candidate features

In [30]:
CANDIDATE_FEATURES_DICT = feature_utils.get_candidate_features()
CANDIDATE_FEATURES_DICT

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

In [31]:
# candidate features
CANDIDATE_FILENAME = 'candidate_feats_dict.pkl'
CANDIDATE_FEATURES_GCS_OBJ = f'{FEATURES_PREFIX}/{CANDIDATE_FILENAME}'

# pickle
filehandler = open(f'{CANDIDATE_FILENAME}', 'wb')
pkl.dump(CANDIDATE_FEATURES_DICT, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(BUCKET_NAME)
blob = bucket_client.blob(CANDIDATE_FEATURES_GCS_OBJ)
blob.upload_from_filename(CANDIDATE_FILENAME)

#### query features

In [32]:
# MAX_PLAYLIST_LENGTH=5 # TODO- make consistent with previous notebooks e.g., 5

QUERY_FEATURES_DICT = feature_utils.get_all_features(TRACK_HISTORY, ranker=False)
# QUERY_FEATURES_DICT

In [33]:
# query features
QUERY_FILENAME = 'query_feats_dict.pkl'
QUERY_FEATURES_GCS_OBJ = f'{FEATURES_PREFIX}/{QUERY_FILENAME}'

# pickle
filehandler = open(f'{QUERY_FILENAME}', 'wb')
pkl.dump(QUERY_FEATURES_DICT, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(BUCKET_NAME)
blob = bucket_client.blob(QUERY_FEATURES_GCS_OBJ)
blob.upload_from_filename(QUERY_FILENAME)

### test instances

* create test instances pkl
* will copy to pipeline root later

In [34]:
instances = test_instances.TEST_INSTANCE_5
print(f"length of instances: {len(instances)}")

LOCAL_INSTANCES_PKL = "test_instances_5.pkl"

filehandler = open(f'{LOCAL_INSTANCES_PKL}', 'wb')
pkl.dump(instances, filehandler)
filehandler.close()

length of instances: 52


In [35]:
# FEATURE_DICT = feature_utils.get_all_features(TRACK_HISTORY, ranker=False)
# FEATURE_DICT

### train image

In [36]:
# =================================================
# train image
# =================================================
# Existing image URI or name for image to create
# IMAGE_URI = f'gcr.io/hybrid-vertex/sp-2tower-tfrs-trainerv6-tr'
# DOCKERNAME = 'tfrs'

print(f"REMOTE_IMAGE_NAME : {REMOTE_IMAGE_NAME}")
print(f"DOCKERNAME        : {DOCKERNAME}")

REMOTE_IMAGE_NAME : us-central1-docker.pkg.dev/hybrid-vertex/ndr-v1-spotify/train-v1
DOCKERNAME        : tfrs


### train params

In [37]:
SEED = 1234

# =================================================
# trainconfig: GPU related
# =================================================
TF_GPU_THREAD_COUNT  = '8'      # '1' | '4' | '8'

# =================================================
# trainconfig: data input pipeline
# =================================================
BLOCK_LENGTH         = 64            # 1, 8, 16, 32, 64
NUM_DATA_SHARDS      = 4          # 2, 4, 8, 16, 32, 64
# TRAIN_PREFETCH=3

# =================================================
# trainconfig: training hparams
# =================================================
NUM_EPOCHS           = 10
LEARNING_RATE        = 0.01
BATCH_SIZE           = 4096           # 8192, 4096, 2048, 1024, 512 

# dropout
DROPOUT_RATE         = 0.33

# model size
EMBEDDING_DIM        = 128
PROJECTION_DIM       = int(EMBEDDING_DIM / 4) # 50  
LAYER_SIZES          = '[512,256,128]'
MAX_TOKENS           = 20000     # vocab

# =================================================
# trainconfig: train & valid steps
# =================================================
train_sample_cnt     = 8_205_265 # 8_205_265
valid_samples_cnt    = 82_959

# validation & evaluation
VALID_FREQUENCY      = NUM_EPOCHS + 1 #// 3 # 20
VALID_STEPS          = valid_samples_cnt // BATCH_SIZE # 100
EPOCH_STEPS          = train_sample_cnt // BATCH_SIZE

# =================================================
# trainconfig: tensorboard
# =================================================
EMBED_FREQUENCY      = 1
HIST_FREQUENCY       = 0
CHECKPOINT_FREQ      = EPOCH_STEPS // 4 # 'epoch'
UPDATE_FREQ          = EPOCH_STEPS // 4 # 'epoch'

print(f"VALID_FREQUENCY : {VALID_FREQUENCY}")
print(f"VALID_STEPS     : {VALID_STEPS}")
print(f"EPOCH_STEPS     : {EPOCH_STEPS}")
print(f"EMBED_FREQUENCY : {EMBED_FREQUENCY}")
print(f"HIST_FREQUENCY  : {HIST_FREQUENCY}")
print(f"CHECKPOINT_FREQ : {CHECKPOINT_FREQ}")
print(f"UPDATE_FREQ     : {UPDATE_FREQ}")

VALID_FREQUENCY : 11
VALID_STEPS     : 20
EPOCH_STEPS     : 2003
EMBED_FREQUENCY : 1
HIST_FREQUENCY  : 0
CHECKPOINT_FREQ : 500
UPDATE_FREQ     : 500


### data source

**TODO:** update these variables to point to the GCS location where the processed training data is stored 

In [38]:
# =================================================
# trainconfig: Data sources
# =================================================
TRAIN_DIR_PREFIX = f'data/{DATA_VERSION}/valid' # train | valid
VALID_DIR_PREFIX = f'data/{DATA_VERSION}/valid' 
CANDIDATE_PREFIX = f'data/{DATA_VERSION}/candidates' 

# print(f"BUCKET_DATA_DIR: {BUCKET_DATA_DIR}")
print(f"CANDIDATE_PREFIX: {CANDIDATE_PREFIX}")
print(f"TRAIN_DIR_PREFIX: {TRAIN_DIR_PREFIX}")
print(f"VALID_DIR_PREFIX: {VALID_DIR_PREFIX}")

CANDIDATE_PREFIX: data/v1/candidates
TRAIN_DIR_PREFIX: data/v1/valid
VALID_DIR_PREFIX: data/v1/valid


## Gather train args

In [39]:
from util import workerpool_specs

WORKER_CMD = ["python", "-m", "src.two_tower_jt.task"]
# WORKER_CMD = ["python", "-m", "task"]

WORKER_ARGS = [
    f'--project={PROJECT_ID}',
    f'--train_output_gcs_bucket={BUCKET_NAME}',
    f'--train_dir={BUCKET_NAME}',
    f'--train_dir_prefix={TRAIN_DIR_PREFIX}',
    f'--valid_dir={BUCKET_NAME}',
    f'--valid_dir_prefix={VALID_DIR_PREFIX}',
    f'--candidate_file_dir={BUCKET_NAME}',
    f'--candidate_files_prefix={CANDIDATE_PREFIX}',
    f'--experiment_name={EXPERIMENT_NAME}',
    f'--experiment_run={RUN_NAME}',
    f'--num_epochs={NUM_EPOCHS}',
    f'--batch_size={BATCH_SIZE}',
    f'--embedding_dim={EMBEDDING_DIM}',
    f'--projection_dim={PROJECTION_DIM}',
    f'--layer_sizes={LAYER_SIZES}',
    f'--learning_rate={LEARNING_RATE}',
    f'--valid_frequency={VALID_FREQUENCY}',
    f'--valid_steps={VALID_STEPS}',
    f'--epoch_steps={EPOCH_STEPS}',
    f'--distribute={DISTRIBUTE_STRATEGY}',
    f'--model_version={VERSION}',
    f'--pipeline_version={PIPELINE_VERSION}',
    f'--seed={SEED}',
    f'--max_tokens={MAX_TOKENS}',
    # f'--tb_resource_name={TB_RESOURCE_NAME}',
    f'--embed_frequency={EMBED_FREQUENCY}',
    f'--update_frequency={UPDATE_FREQ}',      # TODO - turn on
    f'--hist_frequency={HIST_FREQUENCY}',
    f'--tf_gpu_thread_count={TF_GPU_THREAD_COUNT}',
    f'--block_length={BLOCK_LENGTH}',
    f'--num_data_shards={NUM_DATA_SHARDS}',
    f'--chkpt_freq={CHECKPOINT_FREQ}',
    f'--dropout_rate={DROPOUT_RATE}',
    # uncomment these to pass value of True (bool)
    # f'--cache_train',                              # caches train_dataset
    # f'--evaluate_model',                             # runs model.eval()
    # f'--write_embeddings',                         # writes embeddings index in train job
    f'--profiler',                                   # runs TB profiler
    # f'--set_jit',                                  # enables XLA
    f'--compute_batch_metrics',
    f'--use_cross_layer',
    f'--use_dropout',
]

WORKER_POOL_SPECS = workerpool_specs.prepare_worker_pool_specs(
    image_uri=f"{REMOTE_IMAGE_NAME}:latest",
    args=WORKER_ARGS,
    cmd=WORKER_CMD,
    replica_count=REPLICA_COUNT,
    machine_type=WORKER_MACHINE_TYPE,
    accelerator_count=PER_MACHINE_ACCELERATOR_COUNT,
    accelerator_type=ACCELERATOR_TYPE,
    reduction_server_count=REDUCTION_SERVER_COUNT,
    reduction_server_machine_type=REDUCTION_SERVER_MACHINE_TYPE,
)

from pprint import pprint
pprint(WORKER_POOL_SPECS)

[{'container_spec': {'args': ['--project=hybrid-vertex',
                              '--train_output_gcs_bucket=ndr-v1-hybrid-vertex-bucket',
                              '--train_dir=ndr-v1-hybrid-vertex-bucket',
                              '--train_dir_prefix=data/v1/valid',
                              '--valid_dir=ndr-v1-hybrid-vertex-bucket',
                              '--valid_dir_prefix=data/v1/valid',
                              '--candidate_file_dir=ndr-v1-hybrid-vertex-bucket',
                              '--candidate_files_prefix=data/v1/candidates',
                              '--experiment_name=tfrs-pipe-v1',
                              '--experiment_run=run-20230926-120445',
                              '--num_epochs=10',
                              '--batch_size=4096',
                              '--embedding_dim=128',
                              '--projection_dim=32',
                              '--layer_sizes=[512,256,128]',
                  

In [40]:
!export PWD=pwd
!export PIPELINE_ROOT_PATH=PIPELINE_ROOT_PATH
!export REPO_SRC=REPO_SRC

! echo $PWD
! echo $PIPELINE_ROOT_PATH
! echo $REPO_SRC

/home/jupyter/jw-repo2/spotify_mpd_two_tower
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root
src


### Copy train and deployment files to GCS

In [41]:
# from notebook: 06-deploy-query-tower-monitoring
MANY_TESTS_FILE          = 'test_instance_list.pkl' 
SKEW_FEATURES_STATS_FILE = 'skew_feat_stats.pkl'

BASE_OUTPUT_DIR = f'gs://{BUCKET_NAME}/{EXPERIMENT_NAME}/{RUN_NAME}'
BASE_OUTPUT_DIR

'gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445'

#### copy training files for tracking

> copy train files to pipeline root

* Cloud Build `yaml`
* Dockerfile
* trainer code

In [42]:
# !gsutil -q cp $REPO_SRC/cloudbuild.yaml $PIPELINE_ROOT_PATH/cloudbuild.yaml
! gsutil -q cp $REPO_SRC/Dockerfile_tfrs $PIPELINE_ROOT_PATH/Dockerfile_tfrs
! gsutil -q -m cp -r $REPO_SRC/two_tower_jt/* $PIPELINE_ROOT_PATH/trainer

# print(f"Copied files to {PIPELINE_ROOT_PATH}")
! gsutil ls $PIPELINE_ROOT_PATH

gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/Dockerfile_tfrs
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/pipeline_spec.json
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/934903580331/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/trainer/


#### copy deployment files for pipeline use

> copy deployment files to pipeline run base output directory

* vocabulary (vocab)
* test instances
* trainer code

In [43]:
!gsutil -q cp vocab_dict.pkl $BASE_OUTPUT_DIR/vocab_dict.pkl
!gsutil -q cp $LOCAL_INSTANCES_PKL $BASE_OUTPUT_DIR/$LOCAL_INSTANCES_PKL
!gsutil -q cp $BUCKET_URI/endpoint-tests/$MANY_TESTS_FILE $BASE_OUTPUT_DIR/$MANY_TESTS_FILE
!gsutil -q cp $BUCKET_URI/endpoint-tests/$SKEW_FEATURES_STATS_FILE $BASE_OUTPUT_DIR/$SKEW_FEATURES_STATS_FILE

# print(f"Copied files to {BASE_OUTPUT_DIR}")
! gsutil ls $BASE_OUTPUT_DIR

gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/skew_feat_stats.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/test_instance_list.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/test_instances_5.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/train_job_dict.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/vocab_dict.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/candidate-embeddings-20230926-123713/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/candidate-embeddings-20230927-113357/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/features/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/logs/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/model-dir/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/


# Build & Submit Pipeline

In [44]:
PIPELINE_TAG = f'2tower-{PIPELINE_VERSION}'
print("PIPELINE_TAG:", PIPELINE_TAG)

PIPELINE_NAME = f'tfrs-{VERSION}-{PIPELINE_TAG}'.replace('_', '-')
print("PIPELINE_NAME:", PIPELINE_NAME)

PIPELINE_TAG: 2tower-pipe-v2
PIPELINE_NAME: tfrs-v1-2tower-pipe-v2


## Create pipeline

In [45]:
from src.train_pipes import train_custom_model, create_tensorboard, generate_candidates, \
                            create_ann_index, create_brute_force_index, create_ann_index_endpoint_vpc, \
                            create_brute_index_endpoint_vpc, deploy_ann_index, deploy_brute_index, \
                            test_model_index_endpoint, model_monitoring_config, test_model_endpoint, \
                            send_skewed_traffic

from src.train_pipes import pipeline_config as cfg

@kfp.v2.dsl.pipeline(
    name=f'{PIPELINE_NAME}'.replace('_', '-')
)
def pipeline(
    project: str,
    project_number: str,
    location: str,
    service_account: str,
    model_version: str,
    pipeline_version: str,
    train_image_uri: str,
    train_output_gcs_bucket: str,
    gcs_train_script_path: str,
    model_display_name: str,
    train_dockerfile_name: str,
    train_dir: str,
    train_dir_prefix: str,
    valid_dir: str,
    valid_dir_prefix: str,
    candidate_file_dir: str,
    candidate_files_prefix: str,
    test_instances_gcs_filename: str,
    many_test_instances_gcs_filename: str,
    # tensorboard_resource_name: str,
    experiment_name: str,
    experiment_run: str,
    register_model_flag: str,
    vpc_network_name: str,
    generate_new_vocab: bool,
    max_playlist_length: int,
    max_tokens: int,
    ngrams: int,
    # new
    # feature_dict: dict,
    prefix: str,
    emails: str,
    bq_dataset: str,
    bq_train_table: str,
):
    
    from kfp.v2.components import importer_node
    from google_cloud_pipeline_components.types import artifact_types
            
    # ========================================================================
    # Managed TB
    # ========================================================================
    
    create_managed_tensorboard_op = (
        create_tensorboard.create_tensorboard(
            # here
            project=project,
            location=location,
            model_version=model_version,
            pipeline_version=pipeline_version,
            model_name=model_display_name, 
            experiment_name=experiment_name,
            experiment_run=experiment_run,
        )
        .set_display_name("Managed TB")
        .set_caching_options(True)
    )


    run_train_task_op = (
        train_custom_model.train_custom_model(
            project=project,
            location=location,
            model_version=model_version,
            pipeline_version=pipeline_version,
            model_name=model_display_name,
            worker_pool_specs=WORKER_POOL_SPECS, 
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            training_image_uri=train_image_uri,
            tensorboard_resource_name=create_managed_tensorboard_op.outputs['tensorboard_resource_name'], #tensorboard_resource_name, 
            service_account=service_account,
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name("2Tower Training")
        .set_caching_options(True)
        # .after(build_custom_train_image_op)
    )
    
    # ========================================================================
    # Import trained Query and Candidate Towers to this DAG (metadata)
    # ========================================================================
    
    import_unmanaged_query_model_task = (
        importer_node.importer(
            artifact_uri=run_train_task_op.outputs['query_tower_dir_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
            metadata={
                'containerSpec': {
                    'imageUri': 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest',
                },
            },
        )
        .set_display_name("Import Query Tower")
        .after(run_train_task_op)
        .set_caching_options(True)
    )
    
    import_unmanaged_candidate_model_task = (
        importer_node.importer(
            artifact_uri=run_train_task_op.outputs['candidate_tower_dir_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
            metadata={
                'containerSpec': {
                    'imageUri': 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest',
                },
            },
        )
        .set_display_name("Import Candidate Tower")
        .after(run_train_task_op)
        .set_caching_options(True)
    )
    
    # ========================================================================
    # Conditional: Upload models to Vertex model registry
    # ========================================================================
    # with kfp.v2.dsl.Condition(register_model_flag == "True", name="Register towers"):
        
    # here

    query_model_upload_op = (
        gcc_aip.ModelUploadOp(
            project=project,
            location=location,
            display_name=f'query-tower-{model_display_name}',
            unmanaged_container_model=import_unmanaged_query_model_task.outputs["artifact"],
            labels={"tower": "query"},
        )
        .set_display_name("Upload Query Tower")
        .set_caching_options(True)
    )

    candidate_model_upload_op = (
        gcc_aip.ModelUploadOp(
            project=project,
            location=location,
            display_name=f'candidate-tower-{model_display_name}',
            unmanaged_container_model=import_unmanaged_candidate_model_task.outputs["artifact"],
            labels={"tower": "candidate"},
        )
        .set_display_name("Upload Query Tower to Vertex")
        .set_caching_options(True)
    )

    # ========================================================================
    # Deploy Query Tower to Endpoint
    # ========================================================================
    endpoint_create_op = (
        gcc_aip.EndpointCreateOp(
            project=project,
            display_name=f'query-tower-endpoint-{pipeline_version}'
        )
        .after(query_model_upload_op)
        .set_display_name("Create Query Endpoint")
        .set_caching_options(True)
    )

    model_deploy_op = (
        gcc_aip.ModelDeployOp(
            endpoint=endpoint_create_op.outputs['endpoint'],
            model=query_model_upload_op.outputs['model'],
            deployed_model_display_name=f'deployed-qtower-{pipeline_version}',
            # dedicated_resources_accelerator_type="NVIDIA_TESLA_T4",
            # dedicated_resources_accelerator_count=1,
            # dedicated_resources_max_replica_count=1,
            # dedicated_resources_min_replica_count=1,
            dedicated_resources_machine_type="n1-standard-16",
            dedicated_resources_min_replica_count=1,
            dedicated_resources_max_replica_count=1,
            service_account=service_account,
        )
        .set_display_name("Deploy Query Tower")
        .set_caching_options(True)
    )

    generate_candidates_op = (
        generate_candidates.generate_candidates(
            project=project,
            location=location,
            version=model_version,
            candidate_tower_dir_uri=run_train_task_op.outputs['candidate_tower_dir_uri'],
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            candidate_file_dir_bucket=candidate_file_dir,
            candidate_file_dir_prefix=candidate_files_prefix,
            experiment_run_dir=run_train_task_op.outputs['experiment_run_dir']
        )
        .set_display_name("Generate Candidate emb vectors")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
    
    model_monitoring_config_op = (
        model_monitoring_config.model_monitoring_config(
            project=project,
            location=location,
            version=model_version,
            prefix=prefix,
            emails=emails,
            train_output_gcs_bucket=train_output_gcs_bucket,
            # feature_dict=feature_dict, # TODO
            bq_dataset= bq_dataset,
            bq_train_table=bq_train_table,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
        .set_display_name("set Model Monitoring")
        # .after(XXXX)
        .set_caching_options(True)
    )
    
    test_model_endpoint_op = (
        test_model_endpoint.test_model_endpoint(
            project=project,
            location=location,
            version=model_version,
            train_output_gcs_bucket=train_output_gcs_bucket,
            many_test_instances_gcs_filename=many_test_instances_gcs_filename,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
        .set_display_name("Test endpoint deployment")
        .after(model_monitoring_config_op)
        .set_caching_options(True)
    )
    
    send_skewed_traffic_op = (
        send_skewed_traffic.send_skewed_traffic(
            project=project,
            location=location,
            version=model_version,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
        .set_display_name("Send skewed traffic")
        .after(model_monitoring_config_op)
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME indexes
    # ========================================================================

    create_ann_index_op = (
        create_ann_index.create_ann_index(
            project=project,
            location=location,
            version=model_version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=generate_candidates_op.outputs['emb_index_gcs_uri'],
            dimensions=128, #TODO: parameterize
            ann_index_display_name=f'ann_index_{pipeline_version}'.replace('-', '_'),
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            leaf_node_embedding_count=500,
            leaf_nodes_to_search_percent=7, 
            ann_index_description="testing ann index for TFRS deployment",
            # ann_index_labels=ann_index_labels,
        )
        .set_display_name("Create ANN Index")
        # .after(XXXX)
        .set_caching_options(True)
    )

    create_brute_force_index_op = (
        create_brute_force_index.create_brute_force_index(
            project=project,
            location=location,
            version=model_version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=generate_candidates_op.outputs['emb_index_gcs_uri'],
            dimensions=128, #TODO: parameterize
            brute_force_index_display_name=f'bf_index_{pipeline_version}'.replace('-', '_'),
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            brute_force_index_description="testing bf index for TFRS deployment",
            # brute_force_index_labels=brute_force_index_labels,
        )
        .set_display_name("Create BF Index")
        # .after(XXX)
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME index endpoints
    # ========================================================================

    create_ann_index_endpoint_vpc_op = (
        create_ann_index_endpoint_vpc.create_ann_index_endpoint_vpc(
            ann_index_artifact=create_ann_index_op.outputs['ann_index'],
            project=project,
            project_number=project_number,
            version=model_version,
            location=location,
            vpc_network_name=vpc_network_name,
            ann_index_endpoint_display_name=f'ann-index-endpoint_{pipeline_version}'.replace('-', '_'),
            ann_index_endpoint_description='endpoint for ann index',
            ann_index_resource_uri=create_ann_index_op.outputs['ann_index_resource_uri'],
        )
        .set_display_name("Create ANN Index Endpoint")
        # .after(XXX)
    )

    create_brute_index_endpoint_vpc_op = (
        create_brute_index_endpoint_vpc.create_brute_index_endpoint_vpc(
            bf_index_artifact=create_brute_force_index_op.outputs['brute_force_index'],
            project=project,
            project_number=project_number,
            version=model_version,
            location=location,
            vpc_network_name=vpc_network_name,
            brute_index_endpoint_display_name=f'bf-index-endpoint_{pipeline_version}'.replace('-', '_'),
            brute_index_endpoint_description='endpoint for brute force index',
            brute_force_index_resource_uri=create_brute_force_index_op.outputs['brute_force_index_resource_uri'],
        )
        .set_display_name("Create BF Index Endpoint")
        # .after(XXX)
    )

    # ========================================================================
    # Deploy Indexes
    # ========================================================================

    deploy_ann_index_op = (
        deploy_ann_index.deploy_ann_index(
            project=project,
            location=location,
            version=model_version,
            deployed_ann_index_name=f'deployedann_{model_version}'.replace('-', '_'), #todo update to letters, numbers, and underscores only
            ann_index_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_resource_uri'],
            index_endpoint_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy ANN Index")
        .set_caching_options(True)
    )

    deploy_brute_index_op = (
        deploy_brute_index.deploy_brute_index(
            project=project,
            location=location,
            version=model_version,
            deployed_brute_force_index_name=f'deployedbf_{model_version}'.replace('-', '_'),
            brute_force_index_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_force_index_resource_uri'],
            index_endpoint_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy BF Index")
        .set_caching_options(True)
    )

    test_model_index_endpoint_op = (
        test_model_index_endpoint.test_model_index_endpoint(
            project=project,
            location=location,
            version=model_version,
            test_instances_gcs_filename=test_instances_gcs_filename,
            gcs_train_script_path=gcs_train_script_path,
            train_output_gcs_bucket=train_output_gcs_bucket, 
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            ann_index_endpoint_resource_uri=deploy_ann_index_op.outputs['index_endpoint_resource_uri'],
            brute_index_endpoint_resource_uri=deploy_brute_index_op.outputs['index_endpoint_resource_uri'],
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
    )

In [46]:
# ! rm -f custom_container_pipeline_spec.json

PIPELINE_JSON_SPEC_LOCAL = "custom_pipeline_spec.json"

! rm -f $PIPELINE_JSON_SPEC_LOCAL

kfp.v2.compiler.Compiler().compile(
    pipeline_func=pipeline, package_path=PIPELINE_JSON_SPEC_LOCAL,
)



### save pipeline spec json

In [47]:
# !gsutil cp custom_container_pipeline_spec.json $PIPELINE_ROOT_PATH/pipeline_spec.json

PIPELINES_FILEPATH = f'{PIPELINE_ROOT_PATH}/pipeline_spec.json'
print("PIPELINES_FILEPATH:", PIPELINES_FILEPATH)

!gsutil -q cp $PIPELINE_JSON_SPEC_LOCAL $PIPELINES_FILEPATH

PIPELINES_FILEPATH: gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/pipeline_spec.json


In [48]:
!gsutil ls $PIPELINE_ROOT_PATH

gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/Dockerfile_tfrs
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/pipeline_spec.json
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/934903580331/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230926-120445/pipeline_root/trainer/


## Submit pipeline to Vertex

In [99]:
PROJECT_NUMBER=PROJECT_NUM
vpc_network_name = VPC_NETWORK_NAME
# VERTEX_SA = 'notebooksa@hybrid-vertex.iam.gserviceaccount.com'

TRAIN_APP_CODE_PATH = f'{PIPELINE_ROOT_PATH}/trainer'

job = vertex_ai.PipelineJob(
    display_name=PIPELINE_NAME,
    template_path=PIPELINES_FILEPATH,
    pipeline_root=f'{PIPELINE_ROOT_PATH}',
    failure_policy='fast', # slow | fast
    # enable_caching=False,
    parameter_values={
        'project': PROJECT_ID,
        'project_number': PROJECT_NUM,
        'location': REGION,
        'model_version': VERSION,
        'pipeline_version': PIPELINE_VERSION,
        'model_display_name': MODEL_ROOT_NAME,
        'vpc_network_name':vpc_network_name,
        # 'pipeline_tag': PIPELINE_TAG,
        'gcs_train_script_path': TRAIN_APP_CODE_PATH,
        'train_image_uri': f"{REMOTE_IMAGE_NAME}",
        'train_output_gcs_bucket': BUCKET_NAME,
        'train_dir': BUCKET_NAME,
        'train_dir_prefix': TRAIN_DIR_PREFIX,
        'valid_dir': BUCKET_NAME,
        'valid_dir_prefix': VALID_DIR_PREFIX,
        'candidate_file_dir': BUCKET_NAME,
        'candidate_files_prefix': CANDIDATE_PREFIX,
        'test_instances_gcs_filename': LOCAL_INSTANCES_PKL,
        'many_test_instances_gcs_filename': MANY_TESTS_FILE,
        # 'tensorboard_resource_name': TB_RESOURCE_NAME,
        'train_dockerfile_name': DOCKERNAME,
        'experiment_name': EXPERIMENT_NAME,
        'experiment_run': RUN_NAME,
        'service_account': VERTEX_SA,
        'register_model_flag': 'True',
        'generate_new_vocab': False,
        'max_playlist_length': TRACK_HISTORY,
        'max_tokens': 20000,
        'ngrams': 2,
        # 'feature_dict': FEATURE_DICT,
        'prefix': PREFIX,
        'emails': "jordantotten@google.com",
        'bq_dataset': BQ_DATASET,
        'bq_train_table': BQ_TABLE_TRAIN,
    },
)
    
job.run(
    sync=False,
    service_account=VERTEX_SA,
    network=f'projects/{PROJECT_NUM}/global/networks/{VPC_NETWORK_NAME}'
)

#### clean up

In [248]:
# ! rm -rf custom_pipeline_spec.json

**Finished**