# Training pipeline for TFRS  2tower model 

```
tensorflow==2.10.1
tensorflow-cloud==0.1.16
tensorflow-datasets==4.6.0
tensorflow-estimator==2.10.0
tensorflow-hub==0.12.0
tensorflow-io==0.27.0
tensorflow-io-gcs-filesystem==0.27.0
tensorflow-metadata==1.8.0
tensorflow-probability==0.18.0
tensorflow-recommenders==0.7.2
tensorflow-serving-api==2.8.3
tensorflow-transform==1.8.0
```

In [None]:
# !pip install kfp==1.8.18 --user
# !pip install google-cloud-pipeline-components==1.0.32

In [None]:
# ! python3 -c "import kfp; print('KFP SDK version: {}'.format(kfp.__version__))"
# ! python3 -c "import google_cloud_pipeline_components; print('google_cloud_pipeline_components version: {}'.format(google_cloud_pipeline_components.__version__))"
# ! python3 -c "import google.cloud.aiplatform; print('aiplatform SDK version: {}'.format(google.cloud.aiplatform.__version__))"

In [1]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
REGION = 'us-central1'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"REGION: {REGION}")

VERTEX_SA = '934903580331-compute@developer.gserviceaccount.com'

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
REGION: us-central1


In [2]:
import os
import json
from datetime import datetime
from time import time
import pandas as pd
# disable INFO and DEBUG logging everywhere
import logging
import time
from pprint import pprint
import pickle as pkl

logging.disable(logging.WARNING)

from google.cloud import aiplatform as vertex_ai
from google.cloud import storage

# Pipelines
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from google_cloud_pipeline_components import aiplatform as gcc_aip
from google_cloud_pipeline_components.types import artifact_types

# Kubeflow SDK
# TODO: fix these
from kfp.v2 import dsl
import kfp
import kfp.v2.dsl
from kfp.v2.google import client as pipelines_client
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component)

storage_client = storage.Client(project=PROJECT_ID)

vertex_ai.init(project=PROJECT_ID,location=REGION)

In [3]:
VERSION= "jtv5"

# TODO: update
# BUCKET_DATA_DIR = 'spotify-data-regimes'
# TRAIN_DIR_PREFIX = f'{VERSION}/train'                 # subset: valid_v9 | train_v9
# VALID_DIR_PREFIX = f'{VERSION}/valid'                 # valid_v9 | train_v9
# CANDIDATE_PREFIX = f'{VERSION}/candidates' 

In [4]:
# PREFIX = 'spotify-2tower'
APP='sp'
MODEL_TYPE='2tower'
FRAMEWORK = 'tfrs'
PIPELINE_VERSION = 'pipev3'
MODEL_ROOT_NAME = f'{APP}-{MODEL_TYPE}-{FRAMEWORK}-{VERSION}-{PIPELINE_VERSION}'

print(f"MODEL_ROOT_NAME: {MODEL_ROOT_NAME}")

MODEL_ROOT_NAME: sp-2tower-tfrs-jtv5-pipev3


## Write Train files

In [5]:
REPO_DOCKER_PATH_PREFIX = 'src'

In [6]:
# Docker definitions for training
IMAGE_NAME = f'{MODEL_ROOT_NAME}-training'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'

DOCKERNAME = 'tfrs'
REPO_DOCKER_PATH_PREFIX = 'src'
MACHINE_TYPE ='e2-highcpu-32'
FILE_LOCATION = './src'

print(f"IMAGE_URI: {IMAGE_URI}")

IMAGE_URI: gcr.io/hybrid-vertex/sp-2tower-tfrs-jtv5-pipev3-training


In [7]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/cloudbuild.yaml

steps:
- name: 'gcr.io/cloud-builders/docker'
  args: ['build', '-t', '$_IMAGE_URI', '$_FILE_LOCATION', '-f', '$_FILE_LOCATION/Dockerfile.$_DOCKERNAME']
images:
- '$_IMAGE_URI'

Overwriting src/cloudbuild.yaml


In [8]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/two_tower_jt/train_config.py

# PROJECT_ID = 'hybrid-vertex'
# NEW_ADAPTS = 'True'
# USE_CROSS_LAYER = True
# USE_DROPOUT = 'True'
# SEED = 1234
MAX_PLAYLIST_LENGTH = 15
# EMBEDDING_DIM = 128   
# PROJECTION_DIM = 25  
# SEED = 1234
# DROPOUT_RATE = 0.33
# MAX_TOKENS = 20000

Overwriting src/two_tower_jt/train_config.py


In [9]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/two_tower_jt/requirements.txt
google-cloud-aiplatform>=1.21.0
tensorflow-recommenders==0.7.2
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-profile==2.11.1
tensorflow-io==0.27.0
google-cloud-aiplatform[cloud_profiler]>=1.20.0

Overwriting src/two_tower_jt/requirements.txt


In [10]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/Dockerfile.{DOCKERNAME}

# FROM tensorflow/tensorflow:2.10.1-gpu
FROM gcr.io/deeplearning-platform-release/tf-gpu.2-10

WORKDIR /src

# Copies the trainer code to the docker image.
COPY two_tower_jt/* two_tower_jt/ 

RUN pip install -r two_tower_jt/requirements.txt

RUN apt update && apt -y install nvtop

# # Sets up the entry point to invoke the trainer.
# # ENTRYPOINT ["python", "-m", "two_tower_jt.task"]

Overwriting src/Dockerfile.tfrs


## Build Custom Train Image

In [11]:
print(f"DOCKERNAME: {DOCKERNAME}")
print(f"IMAGE_URI: {IMAGE_URI}")
print(f"FILE_LOCATION: {FILE_LOCATION}")
print(f"MACHINE_TYPE: {MACHINE_TYPE}")

DOCKERNAME: tfrs
IMAGE_URI: gcr.io/hybrid-vertex/sp-2tower-tfrs-jtv5-pipev3-training
FILE_LOCATION: ./src
MACHINE_TYPE: e2-highcpu-32


In [12]:
!pwd

/home/jupyter/jw-repo/spotify_mpd_two_tower


In [13]:
!tree /home/jupyter/jw-repo/spotify_mpd_two_tower/src

[01;34m/home/jupyter/jw-repo/spotify_mpd_two_tower/src[00m
├── Dockerfile.tfrs
├── cloudbuild.yaml
├── [01;34mtrain_pipes[00m
│   ├── [01;34m__pycache__[00m
│   │   ├── adapt_fixed_text_layer_vocab.cpython-37.pyc
│   │   ├── adapt_ragged_text_layer_vocab.cpython-37.pyc
│   │   ├── build_custom_image.cpython-37.pyc
│   │   ├── create_ann_index.cpython-37.pyc
│   │   ├── create_ann_index_endpoint_vpc.cpython-37.pyc
│   │   ├── create_brute_force_index.cpython-37.pyc
│   │   ├── create_brute_index_endpoint_vpc.cpython-37.pyc
│   │   ├── create_master_vocab.cpython-37.pyc
│   │   ├── create_tensorboard.cpython-37.pyc
│   │   ├── deploy_ann_index.cpython-37.pyc
│   │   ├── deploy_brute_index.cpython-37.pyc
│   │   ├── generate_candidates.cpython-37.pyc
│   │   ├── pipeline_config.cpython-37.pyc
│   │   ├── test_index_recall.cpython-37.pyc
│   │   └── train_custom_model.cpython-37.pyc
│   ├── adapt_fixed_text_layer_vocab.py
│   ├── adapt_ragged_text_layer_vocab.py
│   ├── build_custom_

### Optionally include a `.gcloudignore` file 

* limits the files submitted to Cloud Build
* see [gcloudignore](https://cloud.google.com/sdk/gcloud/reference/topic/gcloudignore) for details

In [14]:
! gcloud config set gcloudignore/enabled true

Updated property [gcloudignore/enabled].


In [15]:
%%writefile .gcloudignore
.gcloudignore
/local_files/
/img/
*.pkl
*.png
.git
.github
.ipynb_checkpoints/*
*__pycache__
*cpython-37.pyc
spotipy_secret_creds.py
candidate_embs_local_20230130-180710.json
vocab_dict.pkl

Overwriting .gcloudignore


In [16]:
# !gcloud meta list-files-for-upload
# !ls

In [21]:
! gcloud builds submit --config src/cloudbuild.yaml \
    --substitutions _DOCKERNAME=$DOCKERNAME,_IMAGE_URI=$IMAGE_URI,_FILE_LOCATION=$FILE_LOCATION \
    --timeout=2h \
    --machine-type=$MACHINE_TYPE

Creating temporary tarball archive of 77 file(s) totalling 1.8 MiB before compression.
Uploading tarball of [.] to [gs://hybrid-vertex_cloudbuild/source/1676353668.064571-9739bcfbd4ff4ace83214bef350f82a7.tgz]
Created [https://cloudbuild.googleapis.com/v1/projects/hybrid-vertex/locations/global/builds/6d6b4738-4edf-4fcf-aff8-5309a9860b0e].
Logs are available at [ https://console.cloud.google.com/cloud-build/builds/6d6b4738-4edf-4fcf-aff8-5309a9860b0e?project=934903580331 ].
----------------------------- REMOTE BUILD OUTPUT ------------------------------
starting build "6d6b4738-4edf-4fcf-aff8-5309a9860b0e"

FETCHSOURCE
Fetching storage object: gs://hybrid-vertex_cloudbuild/source/1676353668.064571-9739bcfbd4ff4ace83214bef350f82a7.tgz#1676353668665894
Copying gs://hybrid-vertex_cloudbuild/source/1676353668.064571-9739bcfbd4ff4ace83214bef350f82a7.tgz#1676353668665894...
/ [1 files][295.2 KiB/295.2 KiB]                                                
Operation completed over 1 objects/295.

# Pipeline Components

In [17]:
os.getcwd()

'/home/jupyter/jw-repo/spotify_mpd_two_tower'

In [18]:
REPO_DOCKER_PATH_PREFIX = 'src'
PIPELINES_SUB_DIR = 'train_pipes'

In [19]:
! rm -rf {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}
! mkdir {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}

## Build Custom Image

In [20]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/build_custom_image.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image="gcr.io/google.com/cloudsdktool/cloud-sdk:latest",
    packages_to_install=[
        "google-cloud-build"
    ],
)
def build_custom_image(
    project: str,
    artifact_gcs_path: str,
    docker_name: str,
    app_dir_name: str,
    custom_image_uri: str,
) -> NamedTuple('Outputs', [
    ('custom_image_uri', str),
]):
    # TODO: make output Artifact for image_uri
    """
    custom pipeline component to build custom image using
    Cloud Build, the training/serving application code, and dependencies
    defined in the Dockerfile
    """
    
    import logging
    import os

    from google.cloud.devtools import cloudbuild_v1 as cloudbuild
    from google.protobuf.duration_pb2 import Duration

    # initialize client for cloud build
    logging.getLogger().setLevel(logging.INFO)
    build_client = cloudbuild.services.cloud_build.CloudBuildClient()
    
    # parse step inputs to get path to Dockerfile and training application code
    _gcs_dockerfile_path = os.path.join(artifact_gcs_path, f"{docker_name}") # Dockerfile.XXXXX
    _gcs_script_dir_path = os.path.join(artifact_gcs_path, f"{app_dir_name}/") # "trainer/"
    
    logging.info(f"_gcs_dockerfile_path: {_gcs_dockerfile_path}")
    logging.info(f"_gcs_script_dir_path: {_gcs_script_dir_path}")
    
    # define build steps to pull the training code and Dockerfile
    # and build/push the custom training container image
    build = cloudbuild.Build()
    build.steps = [
        {
            "name": "gcr.io/cloud-builders/gsutil",
            "args": ["cp", "-r", _gcs_script_dir_path, "."],
        },
        {
            "name": "gcr.io/cloud-builders/gsutil",
            "args": ["cp", _gcs_dockerfile_path, "Dockerfile"],
        },
        # enabling Kaniko cache in a Docker build that caches intermediate
        # layers and pushes image automatically to Container Registry
        # https://cloud.google.com/build/docs/kaniko-cache
        # {
        #     "name": "gcr.io/kaniko-project/executor:latest",
        #     # "name": "gcr.io/kaniko-project/executor:v1.8.0",        # TODO; downgraded to avoid error in build
        #     # "args": [f"--destination={training_image_uri}", "--cache=true"],
        #     "args": [f"--destination={training_image_uri}", "--cache=false"],
        # },
        {
            "name": "gcr.io/cloud-builders/docker",
            "args": ['build','-t', f'{custom_image_uri}', '.'],
        },
        {
            "name": "gcr.io/cloud-builders/docker",
            "args": ['push', f'{custom_image_uri}'], 
        },
    ]
    # override default timeout of 10min
    timeout = Duration()
    timeout.seconds = 7200
    build.timeout = timeout

    # create build
    operation = build_client.create_build(project_id=project, build=build)
    logging.info("IN PROGRESS:")
    logging.info(operation.metadata)

    # get build status
    result = operation.result()
    logging.info("RESULT:", result.status)

    # return step outputs
    return (
        custom_image_uri,
    )

Writing src/train_pipes/build_custom_image.py


## Create Tensorboard

In [21]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_tensorboard.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image='python:3.9',
    packages_to_install=[
        'google-cloud-aiplatform==1.20.0',
        'numpy',
        'google-cloud-storage',
    ],
    # output_component_file="./pipelines/train_custom_model.yaml",
)
def create_tensorboard(
    project: str,
    location: str,
    model_version: str,
    pipeline_version: str,
    model_name: str, 
    experiment_name: str,
    experiment_run: str,
) -> NamedTuple('Outputs', [
    ('tensorboard_resource_name', str),
    ('tensorboard_display_name', str),
]):
    
    import logging
    from google.cloud import aiplatform as vertex_ai
    from google.cloud import storage
    
    vertex_ai.init(
        project=project,
        location=location,
        # experiment=experiment_name,
    )
    
    logging.info(f'experiment_name: {experiment_name}')
    
    # # create new TB instance
    TENSORBOARD_DISPLAY_NAME=f"{experiment_name}-v1"
    tensorboard = vertex_ai.Tensorboard.create(display_name=TENSORBOARD_DISPLAY_NAME, project=project, location=location)
    TB_RESOURCE_NAME = tensorboard.resource_name
    
    logging.info(f'TENSORBOARD_DISPLAY_NAME: {TENSORBOARD_DISPLAY_NAME}')
    logging.info(f'TB_RESOURCE_NAME: {TB_RESOURCE_NAME}')
    
    return (
        f'{TB_RESOURCE_NAME}',
        f'{TENSORBOARD_DISPLAY_NAME}',
    )

Writing src/train_pipes/create_tensorboard.py


## Generate new train vocab

### fixed_text_layer adapts vocab

In [22]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/adapt_fixed_text_layer_vocab.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        # 'google-cloud-aiplatform==1.18.1',
        'google-cloud-storage',
        'tensorflow==2.10.1',
    ],
)
def adapt_fixed_text_layer_vocab(
    project: str,
    location: str,
    version: str,
    data_dir_bucket_name: str,
    data_dir_path_prefix: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    max_playlist_length: int,
    max_tokens: int,
    ngrams: int,
    feature_name: str,
    generate_new_vocab: bool,
    # feat_type: str,
) -> NamedTuple('Outputs', [
    ('vocab_gcs_uri', str),
    # ('feature_name', str),
]):

    """
    custom pipeline component to adapt the `pl_name_src` layer
    writes vocab to pickled dict in GCS
    dict combined with other layer vocabs and used in Two Tower training
    """
    
    # import packages
    import os
    import logging
    import pickle as pkl
    import time
    
    from google.cloud import storage
    
    import tensorflow as tf
    
    storage_client = storage.Client(project=project)
    
    logging.info(f"feature_name: {feature_name}")
    
    # ===================================================
    # helper function
    # ===================================================
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)

        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()

        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")

        return loaded_dict
    
    # ===================================================
    # set feature vars
    # ===================================================
    MAX_PLAYLIST_LENGTH = max_playlist_length
    logging.info(f"MAX_PLAYLIST_LENGTH: {MAX_PLAYLIST_LENGTH}")
    
    FEATURES_PREFIX = f'{experiment_name}/{experiment_run}/features'
    logging.info(f"FEATURES_PREFIX: {FEATURES_PREFIX}")
    
    all_features_dict = {}
    
    # ===================================================
    # load pickled Candidate features
    # ===================================================
    
    # candidate features
    CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
    CAND_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{CAND_FEAT_FILENAME}'
    LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'
    
    loaded_candidate_features_dict = download_blob(
        train_output_gcs_bucket,
        CAND_FEAT_GCS_OBJ,
        LOADED_CANDIDATE_DICT
    )
    # logging.info(f"CAND_FEAT_FILENAME: {CAND_FEAT_FILENAME}; CAND_FEAT_GCS_OBJ:{CAND_FEAT_GCS_OBJ}; LOADED_CANDIDATE_DICT: {LOADED_CANDIDATE_DICT}")
    
#     # os.system(f'gsutil cp gs://{train_output_gcs_bucket}/{CAND_FEAT_GCS_OBJ} {LOADED_CANDIDATE_DICT}')
#     bucket = storage_client.bucket(train_output_gcs_bucket)
#     blob = bucket.blob(CAND_FEAT_GCS_OBJ)
#     blob.download_to_filename(LOADED_CANDIDATE_DICT)
    
#     filehandler = open(f'{LOADED_CANDIDATE_DICT}', 'rb')
#     loaded_candidate_features_dict = pkl.load(filehandler)
#     filehandler.close()
#     logging.info(f"loaded_candidate_features_dict: {loaded_candidate_features_dict}")
    
    all_features_dict.update(loaded_candidate_features_dict)
    logging.info(f"all_features_dict: {all_features_dict}")

    # ===================================================
    # load pickled Query features
    # ===================================================

    # query features
    QUERY_FEAT_FILENAME = 'query_feats_dict.pkl'
    QUERY_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{QUERY_FEAT_FILENAME}'
    LOADED_QUERY_DICT = f'loaded_{QUERY_FEAT_FILENAME}'
    
    loaded_query_features_dict = download_blob(
        train_output_gcs_bucket,
        QUERY_FEAT_GCS_OBJ,
        LOADED_QUERY_DICT
    )
#     logging.info(f"QUERY_FEAT_FILENAME: {QUERY_FEAT_FILENAME}; QUERY_FEAT_GCS_OBJ:{QUERY_FEAT_GCS_OBJ}; LOADED_QUERY_DICT: {LOADED_QUERY_DICT}")
    
#     # os.system(f'gsutil cp gs://{train_output_gcs_bucket}/{QUERY_FEATURES_GCS_OBJ} {LOADED_QUERY_DICT}')
#     bucket = storage_client.bucket(train_output_gcs_bucket)
#     blob = bucket.blob(QUERY_FEAT_GCS_OBJ)
#     blob.download_to_filename(LOADED_QUERY_DICT)
    
#     filehandler = open(f'{LOADED_QUERY_DICT}', 'rb')
#     loaded_query_features_dict = pkl.load(filehandler)
#     filehandler.close()
#     logging.info(f"loaded_query_features_dict: {loaded_query_features_dict}")
    
    all_features_dict.update(loaded_query_features_dict)
    logging.info(f"all_features_dict: {all_features_dict}")
    
    # ===================================================
    # tfrecord parser
    # ===================================================
    
#     candidate_features = {
#         "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
#         "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
#         "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         # new
#         # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "time_signature_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#     }

#     feats = {
#         # ===================================================
#         # candidate track features
#         # ===================================================
#         "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
#         "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
#         "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "time_signature_can": tf.io.FixedLenFeature(dtype=tf.string, shape=()), # track_time_signature_can

#         # ===================================================
#         # summary playlist features
#         # ===================================================
#         "pl_name_src" : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         'pl_collaborative_src' : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         # 'num_pl_followers_src' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         'pl_duration_ms_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         'num_pl_songs_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), # n_songs_pl_new | num_pl_songs_new
#         'num_pl_artists_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         'num_pl_albums_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_track_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_artist_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_art_followers_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 

#         # ===================================================
#         # ragged playlist features
#         # ===================================================
#         # bytes / string
#         "track_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "album_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "album_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_genres_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         # "tracks_playlist_titles_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_key_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_mode_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "time_signature_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)), 

#         # Float List
#         "duration_ms_songs_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artists_followers_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_danceability_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_energy_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_loudness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_speechiness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_acousticness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_instrumentalness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_liveness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_valence_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_tempo_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#     }
    
    # parsing function
    def parse_tfrecord(example):
        """
        Reads a serialized example from GCS and converts to tfrecord
        """
        # example = tf.io.parse_single_example(
        example = tf.io.parse_example(
            example,
            # feats
            features=all_features_dict
        )
        return example
    
    if generate_new_vocab:
        logging.info(f"Generating new vocab file...")
        
        # list blobs (tfrecords)
        train_files = []
        for blob in storage_client.list_blobs(f'{data_dir_bucket_name}', prefix=f'{data_dir_path_prefix}'):
            if '.tfrecords' in blob.name:
                train_files.append(blob.public_url.replace("https://storage.googleapis.com/", "gs://"))

        logging.info(f"TFRecord file count: {len(train_files)}")

        # ===================================================
        # create TF dataset
        # ===================================================
        logging.info(f"Creating TFRecordDataset...")
        train_dataset = tf.data.TFRecordDataset(train_files)
        train_parsed = train_dataset.map(parse_tfrecord)

        # ===================================================
        # adapt layer for feature
        # ===================================================
        start = time.time()
        text_layer = tf.keras.layers.TextVectorization(
            max_tokens=max_tokens,
            ngrams=ngrams
        )
        text_layer.adapt(train_parsed.map(lambda x: x[f'{feature_name}']))
        end = time.time()

        logging.info(f'Layer adapt elapsed time: {round((end - start), 2)} seconds')

        # ===================================================
        # write vocab to pickled dict --> gcs
        # ===================================================
        logging.info(f"Writting pickled dict to GCS...")

        VOCAB_LOCAL_FILE = f'{feature_name}_vocab_dict.pkl'
        VOCAB_GCS_OBJ = f'{experiment_name}/{experiment_run}/vocab-staging/{VOCAB_LOCAL_FILE}' # destination folder prefix and blob name
        VOCAB_DICT = {f'{feature_name}' : text_layer.get_vocabulary(),}

        logging.info(f"VOCAB_LOCAL_FILE: {VOCAB_LOCAL_FILE}")
        logging.info(f"VOCAB_GCS_OBJ: {VOCAB_GCS_OBJ}")

        # pickle
        filehandler = open(f'{VOCAB_LOCAL_FILE}', 'wb')
        pkl.dump(VOCAB_DICT, filehandler)
        filehandler.close()

        # upload to GCS
        bucket_client = storage_client.bucket(train_output_gcs_bucket)
        blob = bucket_client.blob(VOCAB_GCS_OBJ)
        blob.upload_from_filename(VOCAB_LOCAL_FILE)

        vocab_uri = f'gs://{train_output_gcs_bucket}/{VOCAB_GCS_OBJ}'

        logging.info(f"File {VOCAB_LOCAL_FILE} uploaded to {vocab_uri}")
        
    else:
        logging.info(f"Using existing vocab file...")
        
        vocab_uri = 'gs://two-tower-models/vocabs/vocab_dict.pkl'
        logging.info(f"Using vocab file: {vocab_uri}")
    
    return(
        vocab_uri,
        # feature_name,
    )

Writing src/train_pipes/adapt_fixed_text_layer_vocab.py


### ragged adapts vocab

In [23]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/adapt_ragged_text_layer_vocab.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        # 'google-cloud-aiplatform==1.18.1',
        'google-cloud-storage',
        'tensorflow==2.10.1',
    ],
)
def adapt_ragged_text_layer_vocab(
    project: str,
    location: str,
    version: str,
    data_dir_bucket_name: str,
    data_dir_path_prefix: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    max_playlist_length: int,
    max_tokens: int,
    ngrams: int,
    feature_name: str,
    generate_new_vocab: bool,
    # feat_type: str,
) -> NamedTuple('Outputs', [
    ('vocab_gcs_uri', str),
    # ('feature_name', str),
]):

    """
    custom pipeline component to adapt the `pl_name_src` layer
    writes vocab to pickled dict in GCS
    dict combined with other layer vocabs and used in Two Tower training
    """
    
    # import packages
    import os
    import logging
    import pickle as pkl
    import time
    
    from google.cloud import storage
    
    import tensorflow as tf
    
    storage_client = storage.Client(project=project)
    
    logging.info(f"feature_name: {feature_name}")
    # logging.info(f"feat_type: {feat_type}")
    
    # ===================================================
    # set feature vars
    # ===================================================
    MAX_PLAYLIST_LENGTH = max_playlist_length
    logging.info(f"MAX_PLAYLIST_LENGTH: {MAX_PLAYLIST_LENGTH}")
    
    FEATURES_PREFIX = f'{experiment_name}/{experiment_run}/features'
    logging.info(f"FEATURES_PREFIX: {FEATURES_PREFIX}")
    
    all_features_dict = {}
    
    # ===================================================
    # load pickled Candidate features
    # ===================================================
    
    # candidate features
    CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
    CAND_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{CAND_FEAT_FILENAME}'
    LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'
    logging.info(f"CAND_FEAT_FILENAME: {CAND_FEAT_FILENAME}; CAND_FEAT_GCS_OBJ:{CAND_FEAT_GCS_OBJ}; LOADED_CANDIDATE_DICT: {LOADED_CANDIDATE_DICT}")
    
    # os.system(f'gsutil cp gs://{train_output_gcs_bucket}/{CAND_FEAT_GCS_OBJ} {LOADED_CANDIDATE_DICT}')
    bucket = storage_client.bucket(train_output_gcs_bucket)
    blob = bucket.blob(CAND_FEAT_GCS_OBJ)
    blob.download_to_filename(LOADED_CANDIDATE_DICT)
    
    filehandler = open(f'{LOADED_CANDIDATE_DICT}', 'rb')
    loaded_candidate_features_dict = pkl.load(filehandler)
    filehandler.close()
    logging.info(f"loaded_candidate_features_dict: {loaded_candidate_features_dict}")
    
    all_features_dict.update(loaded_candidate_features_dict)
    logging.info(f"all_features_dict: {all_features_dict}")

    # ===================================================
    # load pickled Query features
    # ===================================================

    # query features
    QUERY_FEAT_FILENAME = 'query_feats_dict.pkl'
    QUERY_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{QUERY_FEAT_FILENAME}'
    LOADED_QUERY_DICT = f'loaded_{QUERY_FEAT_FILENAME}'
    logging.info(f"QUERY_FEAT_FILENAME: {QUERY_FEAT_FILENAME}; QUERY_FEAT_GCS_OBJ:{QUERY_FEAT_GCS_OBJ}; LOADED_QUERY_DICT: {LOADED_QUERY_DICT}")
    
    # os.system(f'gsutil cp gs://{train_output_gcs_bucket}/{QUERY_FEATURES_GCS_OBJ} {LOADED_QUERY_DICT}')
    bucket = storage_client.bucket(train_output_gcs_bucket)
    blob = bucket.blob(QUERY_FEAT_GCS_OBJ)
    blob.download_to_filename(LOADED_QUERY_DICT)
    
    filehandler = open(f'{LOADED_QUERY_DICT}', 'rb')
    loaded_query_features_dict = pkl.load(filehandler)
    filehandler.close()
    logging.info(f"loaded_query_features_dict: {loaded_query_features_dict}")
    
    all_features_dict.update(loaded_query_features_dict)
    logging.info(f"all_features_dict: {all_features_dict}")
    
    # ===================================================
    # tfrecord parser
    # ===================================================
    
    # candidate_features = {
    #     "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
    #     "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
    #     "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    #     "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    #     "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    #     "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     # new
    #     # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    #     "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    #     "time_signature_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # }

#     feats = {
#         # ===================================================
#         # candidate track features
#         # ===================================================
#         "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
#         "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
#         "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
#         "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
#         "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         "time_signature_can": tf.io.FixedLenFeature(dtype=tf.string, shape=()), # track_time_signature_can

#         # ===================================================
#         # summary playlist features
#         # ===================================================
#         "pl_name_src" : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         'pl_collaborative_src' : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
#         # 'num_pl_followers_src' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         'pl_duration_ms_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         'num_pl_songs_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), # n_songs_pl_new | num_pl_songs_new
#         'num_pl_artists_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
#         'num_pl_albums_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_track_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_artist_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
#         # 'avg_art_followers_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 

#         # ===================================================
#         # ragged playlist features
#         # ===================================================
#         # bytes / string
#         "track_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "album_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "album_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_genres_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         # "tracks_playlist_titles_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_key_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_mode_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
#         "time_signature_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)), 

#         # Float List
#         "duration_ms_songs_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artist_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "artists_followers_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_danceability_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_energy_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_loudness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_speechiness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_acousticness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_instrumentalness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_liveness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_valence_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#         "track_tempo_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
#     }
    
    # parsing function
    def parse_tfrecord(example):
        """
        Reads a serialized example from GCS and converts to tfrecord
        """
        # example = tf.io.parse_single_example(
        example = tf.io.parse_example(
            example,
            # feats
            features=all_features_dict
        )
        return example
    
    
    if generate_new_vocab:
        logging.info(f"Generating new vocab file...")
    
        # list blobs (tfrecords)
        train_files = []
        for blob in storage_client.list_blobs(f'{data_dir_bucket_name}', prefix=f'{data_dir_path_prefix}'):
            if '.tfrecords' in blob.name:
                train_files.append(blob.public_url.replace("https://storage.googleapis.com/", "gs://"))

        logging.info(f"TFRecord file count: {len(train_files)}")

        # ===================================================
        # create TF dataset
        # ===================================================
        logging.info(f"Creating TFRecordDataset...")
        train_dataset = tf.data.TFRecordDataset(train_files)
        train_parsed = train_dataset.map(parse_tfrecord)

        # ===================================================
        # adapt layer for feature
        # ===================================================

        start = time.time()
        text_layer = tf.keras.layers.TextVectorization(
            max_tokens=max_tokens,
            ngrams=ngrams
        )
        text_layer.adapt(train_parsed.map(lambda x: tf.reshape(x[f'{feature_name}'], [-1, MAX_PLAYLIST_LENGTH, 1])))
        end = time.time()

        logging.info(f'Layer adapt elapsed time: {round((end - start), 2)} seconds')

        # ===================================================
        # write vocab to pickled dict --> gcs
        # ===================================================
        logging.info(f"Writting pickled dict to GCS...")

        VOCAB_LOCAL_FILE = f'{feature_name}_vocab_dict.pkl'
        VOCAB_GCS_OBJ = f'{experiment_name}/{experiment_run}/vocab-staging/{VOCAB_LOCAL_FILE}' # destination folder prefix and blob name
        VOCAB_DICT = {f'{feature_name}' : text_layer.get_vocabulary(),}

        logging.info(f"VOCAB_LOCAL_FILE: {VOCAB_LOCAL_FILE}")
        logging.info(f"VOCAB_GCS_OBJ: {VOCAB_GCS_OBJ}")

        # pickle
        filehandler = open(f'{VOCAB_LOCAL_FILE}', 'wb')
        pkl.dump(VOCAB_DICT, filehandler)
        filehandler.close()

        # upload to GCS
        bucket_client = storage_client.bucket(train_output_gcs_bucket)
        blob = bucket_client.blob(VOCAB_GCS_OBJ)
        blob.upload_from_filename(VOCAB_LOCAL_FILE)

        vocab_uri = f'gs://{train_output_gcs_bucket}/{VOCAB_GCS_OBJ}'

        logging.info(f"File {VOCAB_LOCAL_FILE} uploaded to {vocab_uri}")
        
    else:
        logging.info(f"Using existing vocab files...")
        vocab_uri = 'gs://two-tower-models/vocabs/vocab_dict.pkl'
        logging.info(f"Using vocab file: {vocab_uri}")
    
    return(
        vocab_uri,
        # feature_name,
    )
    

Writing src/train_pipes/adapt_ragged_text_layer_vocab.py


### create master vocab

In [24]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_master_vocab.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        # 'google-cloud-aiplatform==1.18.1',
        'google-cloud-storage',
        'numpy',
        # 'tensorflow==2.8.3',
    ],
)
def create_master_vocab(
    project: str,
    location: str,
    version: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    vocab_uri_1: str,
    vocab_uri_2: str,
    vocab_uri_3: str,
    vocab_uri_4: str,
    vocab_uri_5: str,
    vocab_uri_6: str,
    vocab_uri_7: str,
    vocab_uri_8: str,
    vocab_uri_9: str,
    generate_new_vocab: bool,
) -> NamedTuple('Outputs', [
    ('master_vocab_gcs_uri', str),
    ('experiment_name', str),
    ('experiment_run', str),
]):
    
    """
    combine layer dictionaires to master dictionary
    master dictionary passed to train job for layer vocabs
    """
    
    # import packages
    import os
    import logging
    import pickle as pkl
    import time
    import numpy as np
    
    from google.cloud import storage
    
    # setup clients
    storage_client = storage.Client()
    
    if generate_new_vocab:
        
        logging.info(f"Generating new vocab master file...")
        # ===================================================
        # Create list of all layer vocab dict uris
        # ===================================================

        vocab_dict_uris = [
            vocab_uri_1, vocab_uri_2, 
            vocab_uri_3, vocab_uri_4, 
            vocab_uri_5, vocab_uri_6, 
            vocab_uri_7, vocab_uri_8, 
            vocab_uri_9, 
        ]
        logging.info(f"count of vocab_dict_uris: {len(vocab_dict_uris)}")
        logging.info(f"vocab_dict_uris: {vocab_dict_uris}")

        # ===================================================
        # load pickled dicts
        # ===================================================

        loaded_pickle_list = []
        for i, pickled_dict in enumerate(vocab_dict_uris):

            with open(f"v{i}_vocab_pre_load", 'wb') as local_vocab_file:
                storage_client.download_blob_to_file(pickled_dict, local_vocab_file)

            with open(f"v{i}_vocab_pre_load", 'rb') as pickle_file:
                loaded_vocab_dict = pkl.load(pickle_file)

            loaded_pickle_list.append(loaded_vocab_dict)

        # ===================================================
        # create master vocab dict
        # ===================================================
        master_dict = {}
        for loaded_dict in loaded_pickle_list:
            master_dict.update(loaded_dict)

        # ===================================================
        # Upload master to GCS
        # ===================================================
        MASTER_VOCAB_LOCAL_FILE = f'vocab_dict.pkl'
        MASTER_VOCAB_GCS_OBJ = f'{experiment_name}/{experiment_run}/{MASTER_VOCAB_LOCAL_FILE}' # destination folder prefix and blob name

        # pickle
        filehandler = open(f'{MASTER_VOCAB_LOCAL_FILE}', 'wb')
        pkl.dump(master_dict, filehandler)
        filehandler.close()

        # upload to GCS
        bucket_client = storage_client.bucket(train_output_gcs_bucket)
        blob = bucket_client.blob(MASTER_VOCAB_GCS_OBJ)
        blob.upload_from_filename(MASTER_VOCAB_LOCAL_FILE)

        master_vocab_uri = f'gs://{train_output_gcs_bucket}/{MASTER_VOCAB_GCS_OBJ}'

        logging.info(f"File {MASTER_VOCAB_LOCAL_FILE} uploaded to {master_vocab_uri}")
        
    else:
        logging.info(f"Using existing vocab file...")
        master_vocab_uri = 'gs://two-tower-models/vocabs/vocab_dict.pkl'
        logging.info(f"Using vocab file: {master_vocab_uri}")
    
    return(
        master_vocab_uri,
        experiment_name,
        experiment_run
    )

Writing src/train_pipes/create_master_vocab.py


## Custom train job

In [25]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/train_custom_model.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)
@kfp.v2.dsl.component(
    base_image='python:3.9',
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        # 'tensorflow==2.9.2',
        # 'tensorflow-recommenders==0.7.0',
        'numpy',
        'google-cloud-storage',
    ],
    # output_component_file="./pipelines/train_custom_model.yaml",
)
def train_custom_model(
    project: str,
    location: str,
    model_version: str,
    pipeline_version: str,
    model_name: str, 
    worker_pool_specs: dict,
    # vocab_dict_uri: str, 
    train_output_gcs_bucket: str,                         # change to workdir?
    training_image_uri: str,
    tensorboard_resource_name: str,
    service_account: str,
    experiment_name: str,
    experiment_run: str,
    generate_new_vocab: bool,
) -> NamedTuple('Outputs', [
    ('job_dict_uri', str),
    ('query_tower_dir_uri', str),
    ('candidate_tower_dir_uri', str),
    ('experiment_run_dir', str),
]):
    
    import logging
    import numpy as np
    import pickle as pkl
    
    from google.cloud import aiplatform as vertex_ai
    # import google.cloud.aiplatform_v1beta1 as aip_beta
    from google.cloud import storage
    
    vertex_ai.init(
        project=project,
        location=location,
        experiment=experiment_name,
    )
    
    storage_client = storage.Client(project=project)
    
    JOB_NAME = f'train-{model_name}'
    logging.info(f'JOB_NAME: {JOB_NAME}')
    
    BASE_OUTPUT_DIR = f'gs://{train_output_gcs_bucket}/{experiment_name}/{experiment_run}'
    logging.info(f'BASE_OUTPUT_DIR: {BASE_OUTPUT_DIR}')
    
    # logging.info(f'vocab_dict_uri: {vocab_dict_uri}')
    
    logging.info(f'tensorboard_resource_name: {tensorboard_resource_name}')
    logging.info(f'service_account: {service_account}')
    logging.info(f'worker_pool_specs: {worker_pool_specs}')
    
    # ====================================================
    # Launch Vertex job
    # ====================================================
    
    worker_pool_specs[0]['container_spec']['args'].append(f'--tb_resource_name={tensorboard_resource_name}')
    
    if generate_new_vocab == 'True':
        worker_pool_specs[0]['container_spec']['args'].append(f'--new_vocab')
  
    job = vertex_ai.CustomJob(
        display_name=JOB_NAME,
        worker_pool_specs=worker_pool_specs,
        base_output_dir=BASE_OUTPUT_DIR,
        staging_bucket=f"{BASE_OUTPUT_DIR}/staging",
    )
    
    logging.info(f'Submitting train job to Vertex AI...')
    
    job.run(
        tensorboard=tensorboard_resource_name,
        service_account=f'{service_account}',
        restart_job_on_worker_restart=False,
        enable_web_access=True,
        sync=False,
    )
        
    # wait for job to complete
    job.wait()
    
    # ====================================================
    # Save job details
    # ====================================================
    
    train_job_dict = job.to_dict()
    logging.info(f'train_job_dict: {train_job_dict}')
    
    # pkl dict to GCS
    logging.info(f"Write pickled dict to GCS...")
    TRAIN_DICT_LOCAL = f'train_job_dict.pkl'
    TRAIN_DICT_GCS_OBJ = f'{experiment_name}/{experiment_run}/{TRAIN_DICT_LOCAL}' # destination folder prefix and blob name
    
    logging.info(f"TRAIN_DICT_LOCAL: {TRAIN_DICT_LOCAL}")
    logging.info(f"TRAIN_DICT_GCS_OBJ: {TRAIN_DICT_GCS_OBJ}")

    # pickle
    filehandler = open(f'{TRAIN_DICT_LOCAL}', 'wb')
    pkl.dump(train_job_dict, filehandler)
    filehandler.close()
    
    # upload to GCS
    bucket_client = storage_client.bucket(train_output_gcs_bucket)
    blob = bucket_client.blob(TRAIN_DICT_GCS_OBJ)
    blob.upload_from_filename(TRAIN_DICT_LOCAL)
    
    job_dict_uri = f'gs://{train_output_gcs_bucket}/{TRAIN_DICT_GCS_OBJ}'
    logging.info(f"{TRAIN_DICT_LOCAL} uploaded to {job_dict_uri}")
    
    # ====================================================
    # Model and index artifact uris
    # ====================================================
    EXPERIMENT_RUN_DIR = f"gs://{train_output_gcs_bucket}/{experiment_name}/{experiment_run}"
    query_tower_dir_uri = f"{EXPERIMENT_RUN_DIR}/model-dir/query_model" 
    candidate_tower_dir_uri = f"{EXPERIMENT_RUN_DIR}/model-dir/candidate_model"
    # candidate_index_dir_uri = f"gs://{output_dir_gcs_bucket_name}/{experiment_name}/{experiment_run}/candidate_model"
    
    logging.info(f'query_tower_dir_uri: {query_tower_dir_uri}')
    logging.info(f'candidate_tower_dir_uri: {candidate_tower_dir_uri}')
    # logging.info(f'candidate_index_dir_uri: {candidate_index_dir_uri}')
    
    return (
        f'{job_dict_uri}',
        f'{query_tower_dir_uri}',
        f'{candidate_tower_dir_uri}',
        f'{EXPERIMENT_RUN_DIR}',
    )

Writing src/train_pipes/train_custom_model.py


## Generate Candidates

In [26]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/generate_candidates.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        'tensorflow==2.10.1',
        'tensorflow-recommenders==0.7.2',
        'numpy',
        # 'google-cloud-storage',
    ],
)
def generate_candidates(
    project: str,
    location: str,
    version: str, 
    # emb_index_gcs_uri: str,
    candidate_tower_dir_uri: str,
    candidate_file_dir_bucket: str,
    candidate_file_dir_prefix: str,
    train_output_gcs_bucket: str,
    experiment_name: str,
    experiment_run: str,
    experiment_run_dir: str,
) -> NamedTuple('Outputs', [
    ('emb_index_gcs_uri', str),
    # ('emb_index_artifact', Artifact),
]):
    import logging
    import json
    import pickle as pkl
    from pprint import pprint
    import time
    import numpy as np

    import os

    # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

    import tensorflow as tf
    import tensorflow_recommenders as tfrs

    from google.cloud import storage
    from google.cloud.storage.bucket import Bucket
    from google.cloud.storage.blob import Blob

    import google.cloud.aiplatform as vertex_ai
    
    # set clients
    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)

    # tf.Data confg
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    
    # ====================================================
    # Load trained candidate tower
    # ====================================================
    logging.info(f"candidate_tower_dir_uri: {candidate_tower_dir_uri}")
    
    loaded_candidate_model = tf.saved_model.load(candidate_tower_dir_uri)
    logging.info(f"loaded_candidate_model.signatures: {loaded_candidate_model.signatures}")
    
    candidate_predictor = loaded_candidate_model.signatures["serving_default"]
    logging.info(f"structured_outputs: {candidate_predictor.structured_outputs}")
    
    # ===================================================
    # set feature vars
    # ===================================================
    FEATURES_PREFIX = f'{experiment_name}/{experiment_run}/features'
    logging.info(f"FEATURES_PREFIX: {FEATURES_PREFIX}")
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ===================================================
    # load pickled Candidate features
    # ===================================================
    
    # candidate features
    CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
    CAND_FEAT_GCS_OBJ = f'{FEATURES_PREFIX}/{CAND_FEAT_FILENAME}'
    LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'
    
    loaded_candidate_features_dict = download_blob(
        train_output_gcs_bucket,
        CAND_FEAT_GCS_OBJ,
        LOADED_CANDIDATE_DICT
    )
    
    # ====================================================
    # Features and Helper Functions
    # ====================================================
    
    def parse_candidate_tfrecord_fn(example):
        """
        Reads candidate serialized examples from gcs and converts to tfrecord
        """
        # example = tf.io.parse_single_example(
        example = tf.io.parse_example(
            example, 
            features=loaded_candidate_features_dict
        )
        return example

    def full_parse(data):
        # used for interleave - takes tensors and returns a tf.dataset
        data = tf.data.TFRecordDataset(data)
        return data
    
    # ====================================================
    # Create Candidate Dataset
    # ====================================================

    candidate_files = []
    for blob in storage_client.list_blobs(f"{candidate_file_dir_bucket}", prefix=f'{candidate_file_dir_prefix}/'):
        if '.tfrecords' in blob.name:
            candidate_files.append(blob.public_url.replace("https://storage.googleapis.com/", "gs://"))

    candidate_dataset = tf.data.Dataset.from_tensor_slices(candidate_files)

    parsed_candidate_dataset = candidate_dataset.interleave(
        # lambda x: tf.data.TFRecordDataset(x),
        full_parse,
        cycle_length=tf.data.AUTOTUNE, 
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False
    ).map(parse_candidate_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE).with_options(options)

    parsed_candidate_dataset = parsed_candidate_dataset.cache() #400 MB on machine mem
    
    # ====================================================
    # Generate embedding vectors for each candidate
    # ====================================================
    logging.info("Starting candidate dataset mapping...")
    
    start_time = time.time()

    embs_iter = parsed_candidate_dataset.batch(1000).map(
        lambda data: candidate_predictor(
            track_uri_can = data["track_uri_can"],
            track_name_can = data['track_name_can'],
            artist_uri_can = data['artist_uri_can'],
            artist_name_can = data['artist_name_can'],
            album_uri_can = data['album_uri_can'],
            album_name_can = data['album_name_can'],
            duration_ms_can = data['duration_ms_can'],
            track_pop_can = data['track_pop_can'],
            artist_pop_can = data['artist_pop_can'],
            artist_genres_can = data['artist_genres_can'],
            artist_followers_can = data['artist_followers_can'],
            track_danceability_can = data['track_danceability_can'],
            track_energy_can = data['track_energy_can'],
            track_key_can = data['track_key_can'],
            track_loudness_can = data['track_loudness_can'],
            track_mode_can = data['track_mode_can'],
            track_speechiness_can = data['track_speechiness_can'],
            track_acousticness_can = data['track_acousticness_can'],
            track_instrumentalness_can = data['track_instrumentalness_can'],
            track_liveness_can = data['track_liveness_can'],
            track_valence_can = data['track_valence_can'],
            track_tempo_can = data['track_tempo_can'],
            time_signature_can = data['time_signature_can']
        )
    )

    embs = []
    for emb in embs_iter:
        embs.append(emb)

    end_time = time.time()
    elapsed_time = int((end_time - start_time) / 60)
    logging.info(f"elapsed_time: {elapsed_time}")
    logging.info(f"Length of embs: {len(embs)}")
    logging.info(f"embeddings[0]: {embs[0]}")
    
    # ====================================================
    # prep Track IDs and Vectors for JSON
    # ====================================================
    logging.info("Cleaning embeddings and track IDs...")
    start_time = time.time()
    
    cleaned_embs = [x['output_1'].numpy()[0] for x in embs] #clean up the output
    
    end_time = time.time()
    elapsed_time = int((end_time - start_time) / 60)
    logging.info(f"elapsed_time: {elapsed_time}")
    logging.info(f"Length of cleaned_embs: {len(cleaned_embs)}")
    
    # clean track IDs
    track_uris = [x['track_uri_can'].numpy() for x in parsed_candidate_dataset]
    logging.info(f"Length of track_uris: {len(track_uris)}")
    
    track_uris_decoded = [z.decode("utf-8") for z in track_uris]
    logging.info(f"Length of track_uris_decoded: {len(track_uris_decoded)}")
    
    # check for bad records
    bad_records = []

    for i, emb in enumerate(cleaned_embs):
        bool_emb = np.isnan(emb)
        for val in bool_emb:
            if val:
                bad_records.append(i)

    bad_record_filter = np.unique(bad_records)

    logging.info(f"bad_records: {len(bad_records)}")
    logging.info(f"bad_record_filter: {len(bad_record_filter)}")
    
    # ZIP together
    logging.info("Zipping IDs and vectors ...")
    
    track_uris_valid = []
    emb_valid = []

    for i, pair in enumerate(zip(track_uris_decoded, cleaned_embs)):
        if i in bad_record_filter:
            pass
        else:
            t_uri, embed = pair
            track_uris_valid.append(t_uri)
            emb_valid.append(embed)
            
    logging.info(f"track_uris_valid[0]: {track_uris_valid[0]}")
    logging.info(f"bad_records: {len(bad_records)}")
            
    # ====================================================
    # writting JSON file to GCS
    # ====================================================
    TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
    embeddings_index_filename = f'candidate_embs_{version}_{TIMESTAMP}.json'

    with open(f'{embeddings_index_filename}', 'w') as f:
        for prod, emb in zip(track_uris_valid, emb_valid):
            f.write('{"id":"' + str(prod) + '",')
            f.write('"embedding":[' + ",".join(str(x) for x in list(emb)) + "]}")
            f.write("\n")
            
    # write to GCS
    INDEX_GCS_URI = f'{experiment_run_dir}/candidate-embeddings-{version}'
    logging.info(f"INDEX_GCS_URI: {INDEX_GCS_URI}")

    DESTINATION_BLOB_NAME = embeddings_index_filename
    SOURCE_FILE_NAME = embeddings_index_filename

    logging.info(f"DESTINATION_BLOB_NAME: {DESTINATION_BLOB_NAME}")
    logging.info(f"SOURCE_FILE_NAME: {SOURCE_FILE_NAME}")
    
    blob = Blob.from_string(os.path.join(INDEX_GCS_URI, DESTINATION_BLOB_NAME))
    blob.bucket._client = storage_client
    blob.upload_from_filename(SOURCE_FILE_NAME)
    
    return (
        f'{INDEX_GCS_URI}',
        # f'{INDEX_GCS_URI}',
    )

Writing src/train_pipes/generate_candidates.py


## Create ANN Index

In [92]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        'google-api-core==2.10.0'
        # 'google-cloud-storage',
    ],
)
def create_ann_index(
    project: str,
    location: str,
    version: str, 
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    ann_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    leaf_node_embedding_count: int,
    leaf_nodes_to_search_percent: int, 
    ann_index_description: str,
    # ann_index_labels: Dict, 
) -> NamedTuple('Outputs', [
    ('ann_index_resource_uri', str),
    ('ann_index', Artifact),
]):
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    VERSION = version.replace('_', '-')
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    
    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    NETWORK_NAME = vpc_network_name
    INDEX_DIR_GCS = emb_index_gcs_uri
    PARENT = "projects/{}/locations/{}".format(project, location)

    logging.info(f"ENDPOINT: {ENDPOINT}")
    logging.info(f"project: {project}")
    logging.info(f"location: {location}")
    logging.info(f"INDEX_DIR_GCS: {INDEX_DIR_GCS}")
    
    display_name = f'{ann_index_display_name}-{VERSION}'
    
    logging.info(f"display_name: {display_name}")
    
    # ==============================================================================
    # Create Index 
    # ==============================================================================

    start = time.time()
        
    tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
        display_name=display_name,
        contents_delta_uri=f'{emb_index_gcs_uri}', # emb_index_gcs_uri,
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        leaf_node_embedding_count=leaf_node_embedding_count,
        leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
        description=ann_index_description,
        # labels=ann_index_labels,
        sync=False,
    )

    end = time.time()
    elapsed_time = round((end - start), 2)
    logging.info(f'Elapsed time creating index: {elapsed_time} seconds\n')
    
    ann_index_resource_uri = tree_ah_index.resource_name
    logging.info("ann_index_resource_uri:", ann_index_resource_uri) 

    return (
      f'{ann_index_resource_uri}',
      tree_ah_index,
    )

Overwriting src/train_pipes/create_ann_index.py


## Create brute force index

In [93]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_force_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        'google-api-core==2.10.0',
        # 'google-cloud-storage',
    ],
)
def create_brute_force_index(
    project: str,
    location: str,
    version: str,
    vpc_network_name: str,
    emb_index_gcs_uri: str,
    dimensions: int,
    brute_force_index_display_name: str,
    approximate_neighbors_count: int,
    distance_measure_type: str,
    brute_force_index_description: str,
    # brute_force_index_labels: Dict,
) -> NamedTuple('Outputs', [
    ('brute_force_index_resource_uri', str),
    ('brute_force_index', Artifact),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    VERSION = version.replace('_', '-')
    
    ENDPOINT = "{}-aiplatform.googleapis.com".format(location)
    NETWORK_NAME = vpc_network_name
    INDEX_DIR_GCS = emb_index_gcs_uri
    PARENT = "projects/{}/locations/{}".format(project, location)

    logging.info("ENDPOINT: {}".format(ENDPOINT))
    logging.info("PROJECT_ID: {}".format(project))
    logging.info("REGION: {}".format(location))
    
    display_name = f'{brute_force_index_display_name}-{VERSION}'
    
    logging.info(f"display_name: {display_name}")
    
    # ==============================================================================
    # Create Index 
    # ==============================================================================

    start = time.time()
    
    brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
        display_name=display_name,
        contents_delta_uri=f'{emb_index_gcs_uri}', # emb_index_gcs_uri,
        dimensions=dimensions,
        # approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,
        description=brute_force_index_description,
        # labels=brute_force_index_labels,
        sync=False,
    )
    brute_force_index_resource_uri = brute_force_index.resource_name
    print("brute_force_index_resource_uri:",brute_force_index_resource_uri) 

    return (
      f'{brute_force_index_resource_uri}',
      brute_force_index,
    )

Overwriting src/train_pipes/create_brute_force_index.py


## Create ANN index endpoint

In [94]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_ann_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        # 'google-cloud-storage',
    ],
)
def create_ann_index_endpoint_vpc(
    ann_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    ann_index_endpoint_display_name: str,
    ann_index_endpoint_description: str,
    ann_index_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('ann_index_endpoint_resource_uri', str),
    ('ann_index_endpoint', Artifact),
    ('ann_index_endpoint_display_name', str),
    ('ann_index_resource_uri', str),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{ann_index_endpoint_display_name}',
        description=ann_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    ann_index_endpoint_resource_uri = ann_index_endpoint.resource_name
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")

    return (
        f'{vpc_network_resource_uri}',
        f'{ann_index_endpoint_resource_uri}',
        ann_index_endpoint,
        f'{ann_index_endpoint_display_name}',
        f'{ann_index_resource_uri}',
    )

Overwriting src/train_pipes/create_ann_index_endpoint_vpc.py


## Create brute force index endpoint

In [95]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/create_brute_index_endpoint_vpc.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        # 'google-cloud-storage',
    ],
)
def create_brute_index_endpoint_vpc(
    bf_index_artifact: Input[Artifact],
    project: str,
    project_number: str,
    location: str,
    version: str,
    vpc_network_name: str,
    brute_index_endpoint_display_name: str,
    brute_index_endpoint_description: str,
    brute_force_index_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('vpc_network_resource_uri', str),
    ('brute_index_endpoint_resource_uri', str),
    ('brute_index_endpoint', Artifact),
    ('brute_index_endpoint_display_name', str),
    ('brute_force_index_resource_uri', str),
]):

    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

    vpc_network_resource_uri = f'projects/{project_number}/global/networks/{vpc_network_name}'
    logging.info(f"vpc_network_resource_uri: {vpc_network_resource_uri}")

    brute_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{brute_index_endpoint_display_name}',
        description=brute_index_endpoint_description,
        network=vpc_network_resource_uri,
    )
    brute_index_endpoint_resource_uri = brute_index_endpoint.resource_name
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    return (
      f'{vpc_network_resource_uri}',
      f'{brute_index_endpoint_resource_uri}',
      brute_index_endpoint,
      f'{brute_index_endpoint_display_name}',
      f'{brute_force_index_resource_uri}',
    )

Overwriting src/train_pipes/create_brute_index_endpoint_vpc.py


## Deploy ANN Index

In [96]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_ann_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0'
    ]
)
def deploy_ann_index(
    project: str,
    location: str,
    version: str,
    deployed_ann_index_name: str,
    ann_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('ann_index_resource_uri', str),
    ('deployed_ann_index_name', str),
    ('deployed_ann_index', Artifact),
]):
  
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    deployed_ann_index_name = deployed_ann_index_name.replace('_', '-')
    logging.info(f"deployed_ann_index_name: {deployed_ann_index_name}")
    
    ann_index = vertex_ai.MatchingEngineIndex(
      index_name=ann_index_resource_uri
    )
    ann_index_resource_uri = ann_index.resource_name

    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(
      index_endpoint_resource_uri
    )

    index_endpoint = index_endpoint.deploy_index(
      index=ann_index, 
      deployed_index_id=f'{deployed_ann_index_name}' #-{TIMESTAMP}'
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")

    return (
      f'{index_endpoint_resource_uri}',
      f'{ann_index_resource_uri}',
      f'{deployed_ann_index_name}',
      ann_index,
    )

Overwriting src/train_pipes/deploy_ann_index.py


## Deploy brute force Index

In [97]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/deploy_brute_index.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        # 'google-cloud-storage',
    ],
)
def deploy_brute_index(
    project: str,
    location: str,
    version: str,
    deployed_brute_force_index_name: str,
    brute_force_index_resource_uri: str,
    index_endpoint_resource_uri: str,
) -> NamedTuple('Outputs', [
    ('index_endpoint_resource_uri', str),
    ('brute_force_index_resource_uri', str),
    ('deployed_brute_force_index_name', str),
    ('deployed_brute_force_index', Artifact),
]):
  
    import logging
    from google.cloud import aiplatform as vertex_ai
    from datetime import datetime
    import time

    vertex_ai.init(
        project=project,
        location=location,
    )
    TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
    deployed_brute_force_index_name = deployed_brute_force_index_name.replace('_', '-')
    logging.info(f"deployed_brute_force_index_name: {deployed_brute_force_index_name}")

    brute_index = vertex_ai.MatchingEngineIndex(
        index_name=brute_force_index_resource_uri
    )
    brute_force_index_resource_uri = brute_index.resource_name

    index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(index_endpoint_resource_uri)

    index_endpoint = index_endpoint.deploy_index(
        index=brute_index, 
        deployed_index_id=f'{deployed_brute_force_index_name}', #-{TIMESTAMP}'
    )

    logging.info(f"index_endpoint.deployed_indexes: {index_endpoint.deployed_indexes}")

    return (
      f'{index_endpoint_resource_uri}',
      f'{brute_force_index_resource_uri}',
      f'{deployed_brute_force_index_name}', #-{TIMESTAMP}',
      brute_index,
    )

Overwriting src/train_pipes/deploy_brute_index.py


## Test index recall

In [98]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/test_index_recall.py

import kfp
from typing import Any, Callable, Dict, NamedTuple, Optional, List
from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
                        OutputPath, component, Metrics)

@kfp.v2.dsl.component(
    base_image="python:3.9",
    packages_to_install=[
        'google-cloud-aiplatform==1.21.0',
        'google-cloud-storage',
        'tensorflow==2.10.1',
        'numpy'
    ],
)
def test_index_recall(
    project: str,
    location: str,
    version: str,
    # train_dir: str,
    # train_dir_prefix: str,
    # ann_index_resource_uri: str,
    ann_index_endpoint_resource_uri: str,
    brute_index_endpoint_resource_uri: str,
    gcs_train_script_path: str,
    endpoint: str, # Input[Artifact],
    metrics: Output[Metrics],
):
    
    import logging
    from datetime import datetime
    import time
    import numpy as np
    
    import base64

    from typing import Dict, List, Union

    from google.cloud import aiplatform as vertex_ai
    from google.protobuf import json_format
    from google.protobuf.json_format import Parse
    from google.protobuf.struct_pb2 import Value

    from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

    import tensorflow as tf

    logging.getLogger().setLevel(logging.INFO)

    vertex_ai.init(
        project=project,
        location=location,
    )
    storage_client = storage.Client(project=project)
    
    # ====================================================
    # helper functions
    # ====================================================
    
    def download_blob(bucket_name, source_gcs_obj, local_filename):
        """Uploads a file to the bucket."""
        # storage_client = storage.Client(project=project_number)
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(source_gcs_obj)
        blob.download_to_filename(local_filename)
        
        filehandler = open(f'{local_filename}', 'rb')
        loaded_dict = pkl.load(filehandler)
        filehandler.close()
        
        logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")
        
        return loaded_dict
    
    # ====================================================
    # get deployed model endpoint
    # ====================================================
    logging.info(f"Endpoint = {endpoint}")
    gcp_resources = Parse(endpoint, GcpResources())
    logging.info(f"gcp_resources = {gcp_resources}")
    
    _endpoint_resource = gcp_resources.resources[0].resource_uri
    logging.info(f"_endpoint_resource = {_endpoint_resource}")
    
    _endpoint_uri = "/".join(_endpoint_resource.split("/")[-8:-2])
    logging.info(f"_endpoint_uri = {_endpoint_uri}")
    
    # define endpoint resource in component
    _endpoint = vertex_ai.Endpoint(_endpoint_uri)
    logging.info(f"_endpoint defined")
    
    # ==============================================================
    # helper function for returning endpoint predictions via json
    # ==============================================================
    
    def predict_custom_trained_model_sample(
        project: str,
        endpoint_id: str,
        instances: Dict,
        location: str = "us-central1",
        api_endpoint: str = "us-central1-aiplatform.googleapis.com",):
        """
        `instances` can be either single instance of type dict or a list
        of instances.
        """

        # The AI Platform services require regional API endpoints.
        client_options = {"api_endpoint": api_endpoint}
        
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for multiple requests.
        client = vertex_ai.gapic.PredictionServiceClient(client_options=client_options)
        
        # The format of each instance should conform to the deployed model's prediction input schema.
        instances = instances if type(instances) == list else [instances]
        instances = [
            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
        ]
        
        parameters_dict = {}
        parameters = json_format.ParseDict(parameters_dict, Value())
        
        endpoint = client.endpoint_path(
            project=project, location=location, endpoint=endpoint_id
        )
        
        response = client.predict(
            endpoint=endpoint, instances=instances, parameters=parameters
        )
        logging.info(f'Response: {response}')
        
        logging.info(f'Deployed Model ID(s): {response.deployed_model_id}')

        # The predictions are a google.protobuf.Value representation of the model's predictions.
        _predictions = response.predictions
        logging.info(f'Response Predictions: {_predictions}')
        
        return _predictions
    
    # ===================================================
    # load test instance
    # ===================================================
    BUCKET_TEST = 'spotify-data-regimes'
    LOCAL_TEST_INSTANCE = 'test_instance_15_dict.pkl'
    PREFIX = 'jtv15-8m'
    TEST_GCS_OBJ = f'{PREFIX}/{LOCAL_TEST_INSTANCE}'
    LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'
    
    loaded_test_instance = download_blob(
        BUCKET_TEST,
        TEST_GCS_OBJ,
        LOADED_CANDIDATE_DICT
    )
    logging.info(f'loaded_test_instance: {loaded_test_instance}')
    
    # make prediction request
    _endpoint_id = _endpoint_uri.split('/')[-1]    # "633325234048",
    logging.info(f"_endpoint_id created = {_endpoint_id}")
    prediction_test = predict_custom_trained_model_sample(
        project=project,                     
        endpoint_id=_endpoint_id,
        location="us-central1",
        instances=loaded_test_instance
    )
    
    ## Indexes
    logging.info(f"ann_index_endpoint_resource_uri: {ann_index_endpoint_resource_uri}")
    logging.info(f"brute_index_endpoint_resource_uri: {brute_index_endpoint_resource_uri}")

    deployed_ann_index = vertex_ai.MatchingEngineIndexEndpoint(ann_index_endpoint_resource_uri)
    deployed_bf_index = vertex_ai.MatchingEngineIndexEndpoint(brute_index_endpoint_resource_uri)

    DEPLOYED_ANN_ID = deployed_ann_index.deployed_indexes[0].id
    DEPLOYED_BF_ID = deployed_bf_index.deployed_indexes[0].id
    logging.info(f"DEPLOYED_ANN_ID: {DEPLOYED_ANN_ID}")
    logging.info(f"DEPLOYED_BF_ID: {DEPLOYED_BF_ID}")
    
    logging.info('Retreiving neighbors from ANN index...')
    
    ANN_response = deployed_ann_index.match(
        deployed_index_id=DEPLOYED_ANN_ID,
        queries=prediction_test.predictions,
        num_neighbors=10
    )
    
    logging.info('Retreiving neighbors from BF index...')
    BF_response = deployed_bf_index.match(
        deployed_index_id=DEPLOYED_BF_ID,
        queries=prediction_test.predictions,
        num_neighbors=10
    )
    
    # Calculate recall by determining how many neighbors were correctly retrieved as compared to the brute-force option.
    recalled_neighbors = 0
    for tree_ah_neighbors, brute_force_neighbors in zip(
        ANN_response, BF_response
    ):
        tree_ah_neighbor_ids = [neighbor.id for neighbor in tree_ah_neighbors]
        brute_force_neighbor_ids = [neighbor.id for neighbor in brute_force_neighbors]

        recalled_neighbors += len(
            set(tree_ah_neighbor_ids).intersection(brute_force_neighbor_ids)
        )

    recall = recalled_neighbors / len(
        [neighbor for neighbors in BF_response for neighbor in neighbors]
    )

    logging.info("Recall: {}".format(recall))
    
    metrics.log_metric("Recall", (recall * 100.0))

Overwriting src/train_pipes/test_index_recall.py


In [34]:
# %%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/test_index_recall.py

# import kfp
# from typing import Any, Callable, Dict, NamedTuple, Optional, List
# from kfp.v2.dsl import (Artifact, Dataset, Input, InputPath, Model, Output,
#                         OutputPath, component, Metrics)

# @kfp.v2.dsl.component(
#     base_image="python:3.9",
#     packages_to_install=[
#         'google-cloud-aiplatform==1.20.0',
#         # 'google-cloud-storage',
#     ],
# )
# def test_index_recall(
#     project: str,
#     location: str,
#     version: str,
#     ann_index_resource_uri: str,
#     brute_force_index_resource_uri: str,
#     gcs_train_script_path: str,
#     endpoint: Input[Artifact],
#     metrics: Output[Metrics],
# ):
#     # here
    
#     import base64
#     import logging

#     from typing import Dict, List, Union

#     from google.cloud import aiplatform as vertex_ai
#     from google.protobuf import json_format
#     from google.protobuf.json_format import Parse
#     from google.protobuf.struct_pb2 import Value
    
#     from google_cloud_pipeline_components.proto.gcp_resources_pb2 import GcpResources

#     logging.getLogger().setLevel(logging.INFO)
#     vertex_ai.init(
#         project=project,
#         location=location,
#     )
    
#     endpoint_resource_path = endpoint.metadata["resourceName"]

#     # define endpoint resource in component
#     logging.info(f"endpoint_resource_path = {endpoint_resource_path}")
#     _endpoint = vertex_ai.Endpoint(endpoint_resource_path)
    
#     ################################################################################
#     # Helper function for returning endpoint predictions via required json format
#     ################################################################################

#     def predict_custom_trained_model_sample(
#         project: str,
#         endpoint_id: str,
#         instances: Dict,
#         location: str = "us-central1",
#         api_endpoint: str = "us-central1-aiplatform.googleapis.com",
# ):
#         """
#         `instances` can be either single instance of type dict or a list
#         of instances.
#         """

#         ########################################################################
#         # Initialize Vertex Endpoint
#         ########################################################################

#         # The AI Platform services require regional API endpoints.
#         client_options = {"api_endpoint": api_endpoint}
        
#         # Initialize client that will be used to create and send requests.
#         # This client only needs to be created once, and can be reused for multiple requests.
#         client = vertex_ai.gapic.PredictionServiceClient(client_options=client_options)
        
#         # The format of each instance should conform to the deployed model's prediction input schema.
#         instances = instances if type(instances) == list else [instances]
#         instances = [
#             json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
#         ]
        
#         parameters_dict = {}
#         parameters = json_format.ParseDict(parameters_dict, Value())
        
#         endpoint = client.endpoint_path(
#             project=project, location=location, endpoint=endpoint_id
#         )
        
#         response = client.predict(
#             endpoint=endpoint, instances=instances, parameters=parameters
#         )
#         logging.info(f'Response: {response}')
        
#         logging.info(f'Deployed Model ID(s): {response.deployed_model_id}')

#         # The predictions are a google.protobuf.Value representation of the model's predictions.
#         _predictions = response.predictions
#         logging.info(f'Response Predictions: {_predictions}')
        
#         return _predictions
    
#     ################################################################################
#     # Request Prediction
#     ################################################################################
#     logging.info(f'Response gcs_train_script_path: {gcs_train_script_path}')
    
#     # jt-tfrs-central-v2/tfrs-e2e-pipe-test-v15-jtv4/run-20230214-021008/pipeline_root/trainer
#     # gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v15-jtv4/run-20230214-021008/pipeline_root/trainer
#     # gcs_train_script_path

#     TEST_INSTANCE_15 = {
#         'album_name_can': 'Capoeira Electronica',
#         'album_name_pl': [
#             'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica',
#             'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica',
#             'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica'
#         ],
#         'album_uri_can': 'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#         'album_uri_pl': [
#             'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#             'spotify:album:55HHBqZ2SefPeaENOgWxYK',
#             'spotify:album:150L1V6UUT7fGUI3PbxpkE',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#             'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#             'spotify:album:55HHBqZ2SefPeaENOgWxYK',
#             'spotify:album:150L1V6UUT7fGUI3PbxpkE',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#             'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
#             'spotify:album:55HHBqZ2SefPeaENOgWxYK',
#             'spotify:album:150L1V6UUT7fGUI3PbxpkE',
#             'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR'
#         ],
#         'artist_followers_can': 5170.0,
#         'artist_genres_can': 'capoeira',
#         'artist_genres_pl': [
#             'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira',
#             'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira',
#             'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira'
#         ],
#         'artist_name_can': 'Capoeira Experience',
#         'artist_name_pl': [
#             'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience',
#             'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience',
#             'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience'
#         ],
#         'artist_pop_can': 24.0,
#         'artist_pop_pl':[
#             4., 24.,  2.,  0., 24.,
#             4., 24.,  2.,  0., 24.,
#             4., 24.,  2.,  0., 24.
#         ],
#         'artist_uri_can': 'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#         'artist_uri_pl': [
#             'spotify:artist:72oameojLOPWYB7nB8rl6c',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#             'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
#             'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#             'spotify:artist:72oameojLOPWYB7nB8rl6c',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#             'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
#             'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#             'spotify:artist:72oameojLOPWYB7nB8rl6c',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
#             'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
#             'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
#             'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP'
#         ],
#         'artists_followers_pl': [ 
#             316., 5170.,  448.,   19., 5170.,
#             316., 5170.,  448.,   19., 5170.,
#             316., 5170.,  448.,   19., 5170.
#         ],
#         'duration_ms_can': 192640.0,
#         'duration_ms_songs_pl': [234612., 226826., 203480., 287946., 271920., 234612., 226826., 203480., 287946., 271920., 234612., 226826., 203480., 287946., 271920.],
#         'num_pl_albums_new': 9.0,
#         'num_pl_artists_new': 5.0,
#         'num_pl_songs_new': 85.0,
#         'pl_collaborative_src': 'false',
#         'pl_duration_ms_new': 17971314.0,
#         'pl_name_src': 'Capoeira',
#         'time_signature_can': '4',
#         'time_signature_pl': ['4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4'],
#         'track_acousticness_can': 0.478,
#         'track_acousticness_pl': [0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304 ],
#         'track_danceability_can': 0.709,
#         'track_danceability_pl': [0.703, 0.712, 0.806, 0.529, 0.821, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304],
#         'track_energy_can': 0.742,
#         'track_energy_pl': [0.743, 0.41 , 0.794, 0.776, 0.947, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304],
#         'track_instrumentalness_can': 0.00297,
#         'track_instrumentalness_pl': [4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03, 4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03, 4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03],
#         'track_key_can': '0',
#         'track_key_pl': ['5', '0', '1', '10', '10', '5', '0', '1', '10', '10', '5', '0', '1', '10', '10'],
#         'track_liveness_can': 0.0346,
#         'track_liveness_pl': [0.128 , 0.0725, 0.191 , 0.105 , 0.0552,0.128 , 0.0725, 0.191 , 0.105 , 0.0552, 0.128 , 0.0725, 0.191 , 0.105 , 0.0552],
#         'track_loudness_can': -7.295,
#         'track_loudness_pl': [-8.638, -8.754, -9.084, -7.04 , -6.694, -8.638, -8.754, -9.084, -7.04 , -6.694, -8.638, -8.754, -9.084, -7.04 , -6.694],
#         'track_mode_can': '1',
#         'track_mode_pl': ['0', '1', '1', '0', '1', '0', '1', '1', '0', '1', '0', '1', '1', '0', '1'],
#         'track_name_can': 'Bezouro Preto - Studio',
#         'track_name_pl': [
#             'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio',
#             'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio',
#             'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio'
#         ],
#         'track_pop_can': 3.0,
#         'track_pop_pl': [5., 1., 0., 0., 1., 5., 1., 0., 0., 1., 5., 1., 0., 0., 1.],
#         'track_speechiness_can': 0.0802,
#         'track_speechiness_pl':[0.0367, 0.0272, 0.0407, 0.132 , 0.0734, 0.0367, 0.0272, 0.0407, 0.132 , 0.0734, 0.0367, 0.0272, 0.0407, 0.132 , 0.0734],
#         'track_tempo_can': 172.238,
#         'track_tempo_pl': [100.039,  89.089, 123.999, 119.963, 119.214, 100.039,  89.089, 123.999, 119.963, 119.214, 100.039,  89.089, 123.999, 119.963, 119.214],
#         'track_uri_can': 'spotify:track:0tlhK4OvpHCYpReTABvKFb',
#         'track_uri_pl': [
#             'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
#             'spotify:track:39grEDsAHAjmo2QFo4G8D9',
#             'spotify:track:5vxSLdJXqbKYH487YO8LSL',
#             'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
#             'spotify:track:7ELt9eslVvWo276pX2garN',
#             'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
#             'spotify:track:39grEDsAHAjmo2QFo4G8D9',
#             'spotify:track:5vxSLdJXqbKYH487YO8LSL',
#             'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
#             'spotify:track:7ELt9eslVvWo276pX2garN',
#             'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
#             'spotify:track:39grEDsAHAjmo2QFo4G8D9',
#             'spotify:track:5vxSLdJXqbKYH487YO8LSL',
#             'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
#             'spotify:track:7ELt9eslVvWo276pX2garN'
#         ],
#         'track_valence_can': 0.844,
#         'track_valence_pl': [
#             0.966, 0.667, 0.696, 0.876, 0.655,
#             0.966, 0.667, 0.696, 0.876, 0.655,
#             0.966, 0.667, 0.696, 0.876, 0.655
#         ],
#     }

#     prediction_test = predict_custom_trained_model_sample(
#         project=project,                     
#         endpoint_id=endpoint_resource_path,
#         location="us-central1",
#         instances=TEST_INSTANCE_15
#     )
#     logging.info(f"prediction_test: {prediction_test}")
    
#     ################################################################################
#     # Init deployed indexes
#     ################################################################################
    
#     logging.info(f"ann_index_resource_uri: {ann_index_resource_uri}")
#     logging.info(f"brute_force_index_resource_uri: {brute_force_index_resource_uri}")

#     tree_ah_index = vertex_ai.MatchingEngineIndexEndpoint(index_name=ann_index_resource_uri)
#     brute_force_index = vertex_ai.MatchingEngineIndexEndpoint(index_name=brute_force_index_resource_uri)
    
#     DEPLOYED_ANN_INDEX_ID = tree_ah_index.deployed_indexes[0]
#     DEPLOYED_BF_INDEX_ID = brute_force_index.deployed_indexes[0]
    
#     logging.info(f"DEPLOYED_ANN_INDEX_ID: {DEPLOYED_ANN_INDEX_ID}")
#     logging.info(f"DEPLOYED_BF_INDEX_ID: {DEPLOYED_BF_INDEX_ID}")
    
#     ANN_response = deployed_ann_index.match(
#         deployed_index_id=DEPLOYED_ANN_INDEX_ID,
#         queries=prediction_test.predictions,
#         num_neighbors=10
#     )
    
#     BF_response = deployed_bf_index.match(
#         deployed_index_id=DEPLOYED_BF_INDEX_ID,
#         queries=prediction_test.predictions,
#         num_neighbors=10
#     )
    
#     # Calculate recall by determining how many neighbors were correctly retrieved as compared to the brute-force option.
#     recalled_neighbors = 0
#     for tree_ah_neighbors, brute_force_neighbors in zip(
#         ANN_response, BF_response
#     ):
#         tree_ah_neighbor_ids = [neighbor.id for neighbor in tree_ah_neighbors]
#         brute_force_neighbor_ids = [neighbor.id for neighbor in brute_force_neighbors]

#         recalled_neighbors += len(
#             set(tree_ah_neighbor_ids).intersection(brute_force_neighbor_ids)
#         )

#     recall = recalled_neighbors / len(
#         [neighbor for neighbors in BF_response for neighbor in neighbors]
#     )

#     logging.info("Recall: {}".format(recall))
    
#     metrics.log_metric("Recall", (recall * 100.0))

## compute config for pipeline steps

In [99]:
%%writefile {REPO_DOCKER_PATH_PREFIX}/{PIPELINES_SUB_DIR}/pipeline_config.py

CPU_LIMIT='96'
MEMORY_LIMIT='624G'

Overwriting src/train_pipes/pipeline_config.py


# Prepare Job Specs

## Vertex Train: workerpool specs

In [36]:
def prepare_worker_pool_specs(
    image_uri,
    args,
    cmd,
    replica_count=1,
    machine_type="n1-standard-16",
    accelerator_count=1,
    accelerator_type="ACCELERATOR_TYPE_UNSPECIFIED",
    reduction_server_count=0,
    reduction_server_machine_type="n1-highcpu-16",
    reduction_server_image_uri="us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest",
):

    if accelerator_count > 0:
        machine_spec = {
            "machine_type": machine_type,
            "accelerator_type": accelerator_type,
            "accelerator_count": accelerator_count,
        }
    else:
        machine_spec = {"machine_type": machine_type}

    container_spec = {
        "image_uri": image_uri,
        "args": args,
        "command": cmd,
    }

    chief_spec = {
        "replica_count": 1,
        "machine_spec": machine_spec,
        "container_spec": container_spec,
    }

    worker_pool_specs = [chief_spec]
    if replica_count > 1:
        workers_spec = {
            "replica_count": replica_count - 1,
            "machine_spec": machine_spec,
            "container_spec": container_spec,
        }
        worker_pool_specs.append(workers_spec)
    if reduction_server_count > 1:
        workers_spec = {
            "replica_count": reduction_server_count,
            "machine_spec": {
                "machine_type": reduction_server_machine_type,
            },
            "container_spec": {"image_uri": reduction_server_image_uri},
        }
        worker_pool_specs.append(workers_spec)

    return worker_pool_specs

## Accelerators and Device Strategy

In [37]:
# # # # Single machine, single GPU
WORKER_MACHINE_TYPE = 'a2-highgpu-1g'
REPLICA_COUNT = 1
ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
PER_MACHINE_ACCELERATOR_COUNT = 1
REDUCTION_SERVER_COUNT = 0                                                      
REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
DISTRIBUTE_STRATEGY = 'single'

# Single machine, single GPU, 80 GB 'NVIDIA_A100_80GB'
# WORKER_MACHINE_TYPE = 'a2-ultragpu-1g' # 80 GB
# REPLICA_COUNT = 1
# ACCELERATOR_TYPE = 'NVIDIA_A100_80GB'
# PER_MACHINE_ACCELERATOR_COUNT = 1
# REDUCTION_SERVER_COUNT = 0                                                      
# REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
# DISTRIBUTE_STRATEGY = 'single'

# # # Single Machine; multiple GPU
# WORKER_MACHINE_TYPE = 'a2-highgpu-4g' # a2-ultragpu-4g
# REPLICA_COUNT = 1
# ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
# PER_MACHINE_ACCELERATOR_COUNT = 4
# REDUCTION_SERVER_COUNT = 0                                                      
# REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
# DISTRIBUTE_STRATEGY = 'mirrored'

# # # # Multiple Machine; 1 GPU per machine
# WORKER_MACHINE_TYPE = 'a2-highgpu-2g' # a2-ultragpu-4g
# REPLICA_COUNT = 2
# ACCELERATOR_TYPE = 'NVIDIA_TESLA_A100'
# PER_MACHINE_ACCELERATOR_COUNT = 2
# REDUCTION_SERVER_COUNT = 4                                                      
# REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
# DISTRIBUTE_STRATEGY = 'multiworker'

# # # Multiple Machines, 1 GPU per Machine
# WORKER_MACHINE_TYPE = 'n1-standard-16'
# REPLICA_COUNT = 9
# ACCELERATOR_TYPE = 'NVIDIA_TESLA_T4'
# PER_MACHINE_ACCELERATOR_COUNT = 1
# REDUCTION_SERVER_COUNT = 10                                                      
# REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
# DISTRIBUTE_STRATEGY = 'multiworker'

## Vertex AI Experiments

In [38]:
EXPERIMENT_PREFIX = 'tfrs-e2e-pipe-test-v18'                     # custom identifier for organizing experiments
EXPERIMENT_NAME=f'{EXPERIMENT_PREFIX}-{VERSION}'
RUN_NAME = f'run-{time.strftime("%Y%m%d-%H%M%S")}'

print(f"EXPERIMENT_NAME: {EXPERIMENT_NAME}")
print(f"RUN_NAME: {RUN_NAME}")

EXPERIMENT_NAME: tfrs-e2e-pipe-test-v18-jtv5
RUN_NAME: run-20230214-144020


## Managed Tensorboard

In [39]:
# use existing TB instance
# TB_RESOURCE_NAME = 'projects/934903580331/locations/us-central1/tensorboards/6924469145035603968'

# # # create new TB instance
# TENSORBOARD_DISPLAY_NAME=f"{EXPERIMENT_PREFIX}-v1"
# tensorboard = vertex_ai.Tensorboard.create(display_name=TENSORBOARD_DISPLAY_NAME, project=PROJECT_ID, location=REGION)
# TB_RESOURCE_NAME = tensorboard.resource_name


# print(f"TB_RESOURCE_NAME: {TB_RESOURCE_NAME}")
# print(f"TB display name: {tensorboard.display_name}")

## Training Config

* see [src code](https://github.com/googleapis/python-aiplatform/blob/e7bf0d83d8bb0849a9bce886c958d13f5cbe5fab/google/cloud/aiplatform/utils/worker_spec_utils.py#L153) for worker_pool_spec

In [40]:
# =================================================
# trainconfig: gcs locations
# =================================================
OUTPUT_BUCKET = 'jt-tfrs-central-v2'
OUTPUT_GCS_URI =f'gs://{OUTPUT_BUCKET}'

# Stores pipeline executions for each run
PIPELINE_ROOT_PATH = f'gs://{OUTPUT_BUCKET}/{EXPERIMENT_NAME}/{RUN_NAME}/pipeline_root'
print('PIPELINE_ROOT_PATH: {}'.format(PIPELINE_ROOT_PATH))

PIPELINE_ROOT_PATH: gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root


### feature lists

In [41]:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
import tensorflow as tf

FEATURES_PREFIX = f'{EXPERIMENT_NAME}/{RUN_NAME}/features'
print(f"FEATURES_PREFIX: {FEATURES_PREFIX}")

FEATURES_PREFIX: tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/features


##### candidate features

In [42]:
# candidate features
CANDIDATE_FILENAME = 'candidate_feats_dict.pkl'
CANDIDATE_FEATURES_GCS_OBJ = f'{FEATURES_PREFIX}/{CANDIDATE_FILENAME}'

CANDIDATE_FEATURES_DICT = {
    "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
    "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
    "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # new
    # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "time_signature_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
}

# pickle
filehandler = open(f'{CANDIDATE_FILENAME}', 'wb')
pkl.dump(CANDIDATE_FEATURES_DICT, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(OUTPUT_BUCKET)
blob = bucket_client.blob(CANDIDATE_FEATURES_GCS_OBJ)
blob.upload_from_filename(CANDIDATE_FILENAME)

##### query features

In [43]:
MAX_PLAYLIST_LENGTH=15

# query features
QUERY_FILENAME = 'query_feats_dict.pkl'
QUERY_FEATURES_GCS_OBJ = f'{FEATURES_PREFIX}/{QUERY_FILENAME}'

QUERY_FEATURES_DICT = {
    # ===================================================
    # summary playlist features
    # ===================================================
    "pl_name_src" : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    'pl_collaborative_src' : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    # 'num_pl_followers_src' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    'pl_duration_ms_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    'num_pl_songs_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), # n_songs_pl_new | num_pl_songs_new
    'num_pl_artists_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    'num_pl_albums_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_track_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_artist_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_art_followers_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 

    # ===================================================
    # ragged playlist features
    # ===================================================
    # bytes / string
    "track_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "album_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "album_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_genres_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    # "tracks_playlist_titles_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_key_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_mode_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "time_signature_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)), 

    # Float List
    "duration_ms_songs_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "artists_followers_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_danceability_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_energy_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_loudness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_speechiness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_acousticness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_instrumentalness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_liveness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_valence_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_tempo_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
}

# pickle
filehandler = open(f'{QUERY_FILENAME}', 'wb')
pkl.dump(QUERY_FEATURES_DICT, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(OUTPUT_BUCKET)
blob = bucket_client.blob(QUERY_FEATURES_GCS_OBJ)
blob.upload_from_filename(QUERY_FILENAME)

### train image

In [44]:
# =================================================
# train image
# =================================================
# Existing image URI or name for image to create
IMAGE_URI = f"{IMAGE_URI}"
print(f"IMAGE_URI: {IMAGE_URI}")

IMAGE_URI: gcr.io/hybrid-vertex/sp-2tower-tfrs-jtv5-pipev3-training


### train params

In [45]:
SEED = 1234

# =================================================
# trainconfig: GPU related
# =================================================
TF_GPU_THREAD_COUNT='8'      # '1' | '4' | '8'

# =================================================
# trainconfig: data input pipeline
# =================================================
BLOCK_LENGTH = 64            # 1, 8, 16, 32, 64
NUM_DATA_SHARDS = 4          # 2, 4, 8, 16, 32, 64
# TRAIN_PREFETCH=3

# =================================================
# trainconfig: training hparams
# =================================================
NUM_EPOCHS = 3
LEARNING_RATE = 0.01
BATCH_SIZE = 8192           # 4096, 2048, 1024, 512 

# dropout
DROPOUT_RATE = 0.33

# model size
EMBEDDING_DIM = 128
PROJECTION_DIM = 50
LAYER_SIZES = '[512,256]'
MAX_TOKENS = 20000     # vocab

# =================================================
# trainconfig: tensorboard
# =================================================
EMBED_FREQUENCY=0
HISTOGRAM_FREQUENCY=0
CHECKPOINT_FREQ='epoch'

### train & valid epoch steps

In [46]:
# =================================================
# trainconfig: train & valid steps
# =================================================
train_sample_cnt = 82_959 # 8_205_265
valid_samples_cnt = 82_959

# validation & evaluation
VALID_FREQUENCY = 50
VALID_STEPS = valid_samples_cnt // BATCH_SIZE # 100
EPOCH_STEPS = train_sample_cnt // BATCH_SIZE

print(f"VALID_STEPS: {VALID_STEPS}")
print(f"EPOCH_STEPS: {EPOCH_STEPS}")

VALID_STEPS: 10
EPOCH_STEPS: 10


### data source

In [47]:
# =================================================
# trainconfig: Data sources
# =================================================
BUCKET_DATA_DIR = 'spotify-data-regimes' 

# data strategy: 65m
# CANDIDATE_PREFIX = 'jtv10/candidates'
# TRAIN_DIR_PREFIX = 'jtv10/train_v9'
# VALID_DIR_PREFIX = 'jtv10/valid_v9'

# data strategy: 08m
CANDIDATE_PREFIX = 'jtv15-8m/candidates'
TRAIN_DIR_PREFIX = 'jtv15-8m/valid'     # train | train_v14
VALID_DIR_PREFIX = 'jtv15-8m/valid'     # valid_v14

### train args

In [48]:
WORKER_CMD = ["python", "two_tower_jt/task.py"]
# WORKER_CMD ["python", "-m", "trainer.task"]

WORKER_ARGS = [
    f'--project={PROJECT_ID}',
    f'--train_output_gcs_bucket={OUTPUT_BUCKET}',
    f'--train_dir={BUCKET_DATA_DIR}',
    f'--train_dir_prefix={TRAIN_DIR_PREFIX}',
    f'--valid_dir={BUCKET_DATA_DIR}',
    f'--valid_dir_prefix={VALID_DIR_PREFIX}',
    f'--candidate_file_dir={BUCKET_DATA_DIR}',
    f'--candidate_files_prefix={CANDIDATE_PREFIX}',
    f'--experiment_name={EXPERIMENT_NAME}',
    f'--experiment_run={RUN_NAME}',
    f'--num_epochs={NUM_EPOCHS}',
    f'--batch_size={BATCH_SIZE}',
    f'--embedding_dim={EMBEDDING_DIM}',
    f'--projection_dim={PROJECTION_DIM}',
    f'--layer_sizes={LAYER_SIZES}',
    f'--learning_rate={LEARNING_RATE}',
    f'--valid_frequency={VALID_FREQUENCY}',
    f'--valid_steps={VALID_STEPS}',
    f'--epoch_steps={EPOCH_STEPS}',
    f'--distribute={DISTRIBUTE_STRATEGY}',
    f'--model_version={VERSION}',
    f'--pipeline_version={PIPELINE_VERSION}',
    f'--seed={SEED}',
    f'--max_tokens={MAX_TOKENS}',
    # f'--tb_resource_name={TB_RESOURCE_NAME}',
    f'--embed_frequency={EMBED_FREQUENCY}',
    f'--hist_frequency={HISTOGRAM_FREQUENCY}',
    f'--tf_gpu_thread_count={TF_GPU_THREAD_COUNT}',
    f'--block_length={BLOCK_LENGTH}',
    f'--num_data_shards={NUM_DATA_SHARDS}',
    f'--chkpt_freq={CHECKPOINT_FREQ}',
    f'--dropout_rate={DROPOUT_RATE}',
    # uncomment these to pass value of True (bool)
    f'--cache_train',                                # caches train_dataset
    # f'--evaluate_model',                           # runs model.eval()
    # f'--write_embeddings',                         # writes embeddings index in train job
    f'--profiler',                                   # runs TB profiler
    # f'--set_jit',                                  # enables XLA
    f'--compute_batch_metrics',
    f'--use_cross_layer',
    f'--use_dropout',
]

WORKER_POOL_SPECS = prepare_worker_pool_specs(
    image_uri=IMAGE_URI,
    args=WORKER_ARGS,
    cmd=WORKER_CMD,
    replica_count=REPLICA_COUNT,
    machine_type=WORKER_MACHINE_TYPE,
    accelerator_count=PER_MACHINE_ACCELERATOR_COUNT,
    accelerator_type=ACCELERATOR_TYPE,
    reduction_server_count=REDUCTION_SERVER_COUNT,
    reduction_server_machine_type=REDUCTION_SERVER_MACHINE_TYPE,
)

from pprint import pprint
pprint(WORKER_POOL_SPECS)

[{'container_spec': {'args': ['--project=hybrid-vertex',
                              '--train_output_gcs_bucket=jt-tfrs-central-v2',
                              '--train_dir=spotify-data-regimes',
                              '--train_dir_prefix=jtv15-8m/valid',
                              '--valid_dir=spotify-data-regimes',
                              '--valid_dir_prefix=jtv15-8m/valid',
                              '--candidate_file_dir=spotify-data-regimes',
                              '--candidate_files_prefix=jtv15-8m/candidates',
                              '--experiment_name=tfrs-e2e-pipe-test-v18-jtv5',
                              '--experiment_run=run-20230214-144020',
                              '--num_epochs=3',
                              '--batch_size=8192',
                              '--embedding_dim=128',
                              '--projection_dim=50',
                              '--layer_sizes=[512,256]',
                              '--le

In [49]:
# WORKER_POOL_SPECS_v2=WORKER_POOL_SPECS
# TEST_TB_NAME = 'this-is-a-test'
# WORKER_POOL_SPECS_v2[0]['container_spec']['args'].append(f'--tb_resource_name={TEST_TB_NAME}')
# WORKER_POOL_SPECS_v2[0]['container_spec']['args']

In [100]:
!export PWD=pwd
!export PIPELINE_ROOT_PATH=PIPELINE_ROOT_PATH
!export REPO_DOCKER_PATH_PREFIX=REPO_DOCKER_PATH_PREFIX

! echo $PWD
! echo $PIPELINE_ROOT_PATH
! echo $REPO_DOCKER_PATH_PREFIX

/home/jupyter/jw-repo/spotify_mpd_two_tower
gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root
src


### copy train package to GCS

In [101]:
# copy training Dockerfile
!gsutil cp $REPO_DOCKER_PATH_PREFIX/cloudbuild.yaml $PIPELINE_ROOT_PATH/cloudbuild.yaml
!gsutil cp $REPO_DOCKER_PATH_PREFIX/Dockerfile.tfrs $PIPELINE_ROOT_PATH/Dockerfile.tfrs

# # # copy training application code
! gsutil -m cp -r $REPO_DOCKER_PATH_PREFIX/two_tower_jt/* $PIPELINE_ROOT_PATH/trainer

print(f"\n Copied training package and Dockerfile to {PIPELINE_ROOT_PATH}\n")

Copying file://src/cloudbuild.yaml [Content-Type=application/octet-stream]...
/ [1 files][  178.0 B/  178.0 B]                                                
Operation completed over 1 objects/178.0 B.                                      
Copying file://src/Dockerfile.tfrs [Content-Type=application/octet-stream]...
/ [1 files][  387.0 B/  387.0 B]                                                
Operation completed over 1 objects/387.0 B.                                      
Copying file://src/two_tower_jt/__init__.py [Content-Type=text/x-python]...
Copying file://src/two_tower_jt/__pycache__/__init__.cpython-37.pyc [Content-Type=application/x-python-code]...
Copying file://src/two_tower_jt/__pycache__/two_tower.cpython-37.pyc [Content-Type=application/x-python-code]...
Copying file://src/two_tower_jt/__pycache__/test_instances.cpython-37.pyc [Content-Type=application/x-python-code]...
Copying file://src/two_tower_jt/__pycache__/train_config.cpython-37.pyc [Content-Type=application/x

In [102]:
! gsutil ls -Rl $PIPELINE_ROOT_PATH/trainer

gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/:
         0  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/__init__.py
       835  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/data-pipeline.py
        44  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/interactive_train.py
       219  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/requirements.txt
     27330  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/task.py
     10224  2023-02-14T18:21:12Z  gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/test_instances.py
       247  2023-02-14T18:21:12Z  gs://jt-t

# Build & Submit Pipeline

In [103]:
PIPELINE_TAG = f'2tower-{PIPELINE_VERSION}'
print("PIPELINE_TAG:", PIPELINE_TAG)

PIPELINE_NAME = f'tfrs-{VERSION}-{PIPELINE_TAG}'.replace('_', '-')
print("PIPELINE_NAME:", PIPELINE_NAME)

PIPELINE_TAG: 2tower-pipev3
PIPELINE_NAME: tfrs-jtv5-2tower-pipev3


## Create pipeline

In [104]:
from src.train_pipes import build_custom_image, train_custom_model, create_tensorboard, generate_candidates, \
                            create_ann_index, create_brute_force_index, create_ann_index_endpoint_vpc, \
                            create_brute_index_endpoint_vpc, deploy_ann_index, deploy_brute_index, \
                            adapt_ragged_text_layer_vocab, adapt_fixed_text_layer_vocab, create_master_vocab, \
                            test_index_recall

from src.train_pipes import pipeline_config as cfg

@kfp.v2.dsl.pipeline(
    name=f'{PIPELINE_NAME}'.replace('_', '-')
)
def pipeline(
    project: str,
    project_number: str,
    location: str,
    service_account: str,
    model_version: str,
    pipeline_version: str,
    train_image_uri: str,
    train_output_gcs_bucket: str,
    gcs_train_script_path: str,
    model_display_name: str,
    train_dockerfile_name: str,
    train_dir: str,
    train_dir_prefix: str,
    valid_dir: str,
    valid_dir_prefix: str,
    candidate_file_dir: str,
    candidate_files_prefix: str,
    # tensorboard_resource_name: str,
    experiment_name: str,
    experiment_run: str,
    register_model_flag: str,
    # deploy_indexes_flag: str,
    vpc_network_name: str,
    generate_new_vocab: bool,
    max_playlist_length: int,
    max_tokens: int,
    ngrams: int,
):
    
    from kfp.v2.components import importer_node
    from google_cloud_pipeline_components.types import artifact_types
    
    # ========================================================================
    # Build Custom Train Image
    # ========================================================================
    
    # build_custom_train_image_op = (
    #     build_custom_train_image.build_custom_train_image(
    #         project=project,
    #         gcs_train_script_path=gcs_train_script_path,
    #         training_image_uri=train_image_uri,
    #         train_dockerfile_name=train_dockerfile_name,
    #     )
    #     .set_display_name("Build custom train image")
    #     .set_caching_options(False)
    # )
    
    # ========================================================================
    # Conditional: Upload models to Vertex model registry
    # ========================================================================
    # with kfp.v2.dsl.Condition(generate_new_vocab == 'True', name="Generate New Vocab"):
        
        # here
    # pl_name_src
    adapt_pl_name_src_op = (
        adapt_fixed_text_layer_vocab.adapt_fixed_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='pl_name_src',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: pl_name_src")
        .set_caching_options(True)
        .set_cpu_limit('96')
        .set_memory_limit('624G')
    )
    # track_name_can
    adapt_track_name_can_op = (
        adapt_fixed_text_layer_vocab.adapt_fixed_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='track_name_can',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: track_name_can")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
    # artist_name_can
    adapt_artist_name_can_op = (
        adapt_fixed_text_layer_vocab.adapt_fixed_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='artist_name_can',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: artist_name_can")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )

    # album_name_can
    adapt_album_name_can_op = (
        adapt_fixed_text_layer_vocab.adapt_fixed_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='album_name_can',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: album_name_can")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
    # artist_genres_can
    adapt_artist_genres_can_op = (
        adapt_fixed_text_layer_vocab.adapt_fixed_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='artist_genres_can',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: artist_genres_can")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
    # raggeds

    # track_name_pl
    adapt_track_name_pl_features_op = (
        adapt_ragged_text_layer_vocab.adapt_ragged_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='track_name_pl',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: track_name_pl")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
    # artist_name_pl
    adapt_artist_name_pl_op = (
        adapt_ragged_text_layer_vocab.adapt_ragged_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='artist_name_pl',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: artist_name_pl")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )

    # album_name_pl
    adapt_album_name_pl_op = (
        adapt_ragged_text_layer_vocab.adapt_ragged_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='album_name_pl',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: album_name_pl")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )

    # artist_genres_pl
    adapt_artist_genres_pl_op = (
        adapt_ragged_text_layer_vocab.adapt_ragged_text_layer_vocab(
            project=project,
            location=location,
            version=model_version,
            data_dir_bucket_name=train_dir,
            data_dir_path_prefix=train_dir_prefix,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            max_playlist_length=max_playlist_length,
            max_tokens=max_tokens,
            ngrams=ngrams,
            feature_name='artist_genres_pl',
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name(f"adapt: artist_genres_pl")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )

    # ====================================================
    # Aggregate all Dicts
    # ====================================================

    create_master_vocab_op = (
        create_master_vocab.create_master_vocab(
            project=project,
            location=location,
            version=model_version,
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            vocab_uri_1=adapt_pl_name_src_op.outputs['vocab_gcs_uri'], 
            vocab_uri_2=adapt_track_name_can_op.outputs['vocab_gcs_uri'], 
            vocab_uri_3=adapt_artist_name_can_op.outputs['vocab_gcs_uri'], 
            vocab_uri_4=adapt_album_name_can_op.outputs['vocab_gcs_uri'], 
            vocab_uri_5=adapt_artist_genres_can_op.outputs['vocab_gcs_uri'], 
            vocab_uri_6=adapt_track_name_pl_features_op.outputs['vocab_gcs_uri'], 
            vocab_uri_7=adapt_artist_name_pl_op.outputs['vocab_gcs_uri'], 
            vocab_uri_8=adapt_album_name_pl_op.outputs['vocab_gcs_uri'], 
            vocab_uri_9=adapt_artist_genres_pl_op.outputs['vocab_gcs_uri'],
            generate_new_vocab=generate_new_vocab,
        )
        # .after(fixed_for_loop_op).after(ragged_for_loop_op)
        .set_display_name("create master vocab")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )
            
    # ========================================================================
    # Managed TB
    # ========================================================================
    
    create_managed_tensorboard_op = (
        create_tensorboard.create_tensorboard(
            # here
            project=project,
            location=location,
            model_version=model_version,
            pipeline_version=pipeline_version,
            model_name=model_display_name, 
            experiment_name=create_master_vocab_op.outputs['experiment_name'],
            experiment_run=create_master_vocab_op.outputs['experiment_run'],
        )
        .set_display_name("Managed TB")
        .set_caching_options(True)
    )


    run_train_task_op = (
        train_custom_model.train_custom_model(
            project=project,
            location=location,
            model_version=model_version,
            pipeline_version=pipeline_version,
            model_name=model_display_name,
            worker_pool_specs=WORKER_POOL_SPECS, 
            train_output_gcs_bucket=train_output_gcs_bucket,
            # vocab_dict_uri=create_master_vocab_op.outputs['master_vocab_gcs_uri'],
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            training_image_uri=train_image_uri,
            tensorboard_resource_name=create_managed_tensorboard_op.outputs['tensorboard_resource_name'], #tensorboard_resource_name, 
            service_account=service_account,
            generate_new_vocab=generate_new_vocab,
        )
        .set_display_name("2Tower Training")
        .set_caching_options(True)
        # .after(build_custom_train_image_op)
    )
    
    # ========================================================================
    # Import trained Query and Candidate Towers to this DAG (metadata)
    # ========================================================================
    
    import_unmanaged_query_model_task = (
        importer_node.importer(
            artifact_uri=run_train_task_op.outputs['query_tower_dir_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
            metadata={
                'containerSpec': {
                    'imageUri': 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest',
                },
            },
        )
        .set_display_name("Import Query Tower")
        .after(run_train_task_op)
        .set_caching_options(True)
    )
    
    import_unmanaged_candidate_model_task = (
        importer_node.importer(
            artifact_uri=run_train_task_op.outputs['candidate_tower_dir_uri'],
            artifact_class=artifact_types.UnmanagedContainerModel,
            metadata={
                'containerSpec': {
                    'imageUri': 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest',
                },
            },
        )
        .set_display_name("Import Candidate Tower")
        .after(run_train_task_op)
        .set_caching_options(True)
    )
    
    # ========================================================================
    # Conditional: Upload models to Vertex model registry
    # ========================================================================
    # with kfp.v2.dsl.Condition(register_model_flag == "True", name="Register towers"):
        
    # here

    query_model_upload_op = (
        gcc_aip.ModelUploadOp(
            project=project,
            location=location,
            display_name=f'query-tower-{model_display_name}',
            unmanaged_container_model=import_unmanaged_query_model_task.outputs["artifact"],
            labels={"tower": "query"},
        )
        .set_display_name("Upload Query Tower")
        .set_caching_options(True)
    )

    candidate_model_upload_op = (
        gcc_aip.ModelUploadOp(
            project=project,
            location=location,
            display_name=f'candidate-tower-{model_display_name}',
            unmanaged_container_model=import_unmanaged_candidate_model_task.outputs["artifact"],
            labels={"tower": "candidate"},
        )
        .set_display_name("Upload Query Tower to Vertex")
        .set_caching_options(True)
    )

    # ========================================================================
    # Conditional: Precompute Candidate embeddings, Create & Deploy ME indexes
    # ========================================================================
    # with kfp.v2.dsl.Condition(deploy_indexes_flag == "True", name="Create and Deploy Indexes"):

    # ========================================================================
    # Deploy Query Tower to Endpoint
    # ========================================================================
    endpoint_create_op = (
        gcc_aip.EndpointCreateOp(
            project=project,
            display_name=f'query-tower-endpoint-{model_version}'
        )
        .after(query_model_upload_op)
        .set_display_name("Create Query Endpoint")
        .set_caching_options(True)
    )

    model_deploy_op = (
        gcc_aip.ModelDeployOp(
            endpoint=endpoint_create_op.outputs['endpoint'],
            model=query_model_upload_op.outputs['model'],
            deployed_model_display_name=f'deployed-qtower-{model_version}',
            # dedicated_resources_accelerator_type="NVIDIA_TESLA_T4",
            # dedicated_resources_accelerator_count=1,
            # dedicated_resources_max_replica_count=1,
            # dedicated_resources_min_replica_count=1,
            dedicated_resources_machine_type="n1-standard-16",
            dedicated_resources_min_replica_count=1,
            dedicated_resources_max_replica_count=1,
            service_account=service_account,
        )
        .set_display_name("Deploy Query Tower")
        .set_caching_options(True)
    )

    generate_candidates_op = (
        generate_candidates.generate_candidates(
            project=project,
            location=location,
            version=model_version,
            candidate_tower_dir_uri=run_train_task_op.outputs['candidate_tower_dir_uri'],
            train_output_gcs_bucket=train_output_gcs_bucket,
            experiment_name=experiment_name,
            experiment_run=experiment_run,
            candidate_file_dir_bucket=candidate_file_dir,
            candidate_file_dir_prefix=candidate_files_prefix,
            experiment_run_dir=run_train_task_op.outputs['experiment_run_dir']
        )
        .set_display_name("Generate Candidate emb vectors")
        .set_caching_options(True)
        .set_cpu_limit(cfg.CPU_LIMIT)
        .set_memory_limit(cfg.MEMORY_LIMIT)
    )

    # ========================================================================
    # Create ME indexes
    # ========================================================================

    create_ann_index_op = (
        create_ann_index.create_ann_index(
            project=project,
            location=location,
            version=model_version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=generate_candidates_op.outputs['emb_index_gcs_uri'],
            dimensions=256,
            ann_index_display_name=f'ann_index_pipe_test',
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            leaf_node_embedding_count=500,
            leaf_nodes_to_search_percent=7, 
            ann_index_description="testing ann index for Merlin deployment",
            # ann_index_labels=ann_index_labels,
        )
        .set_display_name("Create ANN Index")
        # .after(XXXX)
        .set_caching_options(True)
    )

    create_brute_force_index_op = (
        create_brute_force_index.create_brute_force_index(
            project=project,
            location=location,
            version=model_version,
            vpc_network_name=vpc_network_name,
            emb_index_gcs_uri=generate_candidates_op.outputs['emb_index_gcs_uri'],
            dimensions=256,
            brute_force_index_display_name=f'bf_index_pipe_test',
            approximate_neighbors_count=50,
            distance_measure_type="DOT_PRODUCT_DISTANCE",
            brute_force_index_description="testing bf index for Merlin deployment",
            # brute_force_index_labels=brute_force_index_labels,
        )
        .set_display_name("Create BF Index")
        # .after(XXX)
        .set_caching_options(True)
    )

    # ========================================================================
    # Create ME index endpoints
    # ========================================================================

    create_ann_index_endpoint_vpc_op = (
        create_ann_index_endpoint_vpc.create_ann_index_endpoint_vpc(
            ann_index_artifact=create_ann_index_op.outputs['ann_index'],
            project=project,
            project_number=project_number,
            version=model_version,
            location=location,
            vpc_network_name=vpc_network_name,
            ann_index_endpoint_display_name=f'ann-index-endpoint',
            ann_index_endpoint_description='endpoint for ann index',
            ann_index_resource_uri=create_ann_index_op.outputs['ann_index_resource_uri'],
        )
        .set_display_name("Create ANN Index Endpoint")
        # .after(XXX)
    )

    create_brute_index_endpoint_vpc_op = (
        create_brute_index_endpoint_vpc.create_brute_index_endpoint_vpc(
            bf_index_artifact=create_brute_force_index_op.outputs['brute_force_index'],
            project=project,
            project_number=project_number,
            version=model_version,
            location=location,
            vpc_network_name=vpc_network_name,
            brute_index_endpoint_display_name=f'bf-index-endpoint',
            brute_index_endpoint_description='endpoint for brute force index',
            brute_force_index_resource_uri=create_brute_force_index_op.outputs['brute_force_index_resource_uri'],
        )
        .set_display_name("Create BF Index Endpoint")
        # .after(XXX)
    )

    # ========================================================================
    # Deploy Indexes
    # ========================================================================

    deploy_ann_index_op = (
        deploy_ann_index.deploy_ann_index(
            project=project,
            location=location,
            version=model_version,
            deployed_ann_index_name=f'deployed_ann_index',
            ann_index_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_resource_uri'],
            index_endpoint_resource_uri=create_ann_index_endpoint_vpc_op.outputs['ann_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy ANN Index")
        .set_caching_options(True)
    )

    deploy_brute_index_op = (
        deploy_brute_index.deploy_brute_index(
            project=project,
            location=location,
            version=model_version,
            deployed_brute_force_index_name=f'deployed_bf_index',
            brute_force_index_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_force_index_resource_uri'],
            index_endpoint_resource_uri=create_brute_index_endpoint_vpc_op.outputs['brute_index_endpoint_resource_uri'],
        )
        .set_display_name("Deploy BF Index")
        .set_caching_options(True)
    )

    test_index_recall_op = (
        test_index_recall.test_index_recall(
            project=project,
            location=location,
            version=model_version,
            gcs_train_script_path=gcs_train_script_path,
            ann_index_endpoint_resource_uri=deploy_ann_index_op.outputs['index_endpoint_resource_uri'],
            brute_index_endpoint_resource_uri=deploy_brute_index_op.outputs['index_endpoint_resource_uri'],
            endpoint=model_deploy_op.outputs['gcp_resources']
        )
    )

In [105]:
# ! rm -f custom_container_pipeline_spec.json

PIPELINE_JSON_SPEC_LOCAL = "custom_pipeline_spec.json"

! rm -f $PIPELINE_JSON_SPEC_LOCAL

kfp.v2.compiler.Compiler().compile(
    pipeline_func=pipeline, package_path=PIPELINE_JSON_SPEC_LOCAL,
)

### save pipeline spec json

In [106]:
# !gsutil cp custom_container_pipeline_spec.json $PIPELINE_ROOT_PATH/pipeline_spec.json

PIPELINES_FILEPATH = f'{PIPELINE_ROOT_PATH}/pipeline_spec.json'
print("PIPELINES_FILEPATH:", PIPELINES_FILEPATH)

!gsutil cp $PIPELINE_JSON_SPEC_LOCAL $PIPELINES_FILEPATH

PIPELINES_FILEPATH: gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/pipeline_spec.json
Copying file://custom_pipeline_spec.json [Content-Type=application/json]...
/ [1 files][306.7 KiB/306.7 KiB]                                                
Operation completed over 1 objects/306.7 KiB.                                    


In [107]:
!gsutil ls $PIPELINE_ROOT_PATH

gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/Dockerfile.tfrs
gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/cloudbuild.yaml
gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/pipeline_spec.json
gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/934903580331/
gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v18-jtv5/run-20230214-144020/pipeline_root/trainer/


## Submit pipeline to Vertex

In [108]:
PIPELINE_NAME

'tfrs-jtv5-2tower-pipev3'

In [109]:
PROJECT_NUMBER='934903580331'
vpc_network_name = 'ucaip-haystack-vpc-network'
# SERVICE_ACCOUNT = '934903580331-compute@developer.gserviceaccount.com'
SERVICE_ACCOUNT = 'notebooksa@hybrid-vertex.iam.gserviceaccount.com'

TRAIN_APP_CODE_PATH = f'{PIPELINE_ROOT_PATH}/trainer'

job = vertex_ai.PipelineJob(
    display_name=PIPELINE_NAME,
    template_path=PIPELINES_FILEPATH,
    pipeline_root=f'{PIPELINE_ROOT_PATH}',
    failure_policy='fast', # slow | fast
    # enable_caching=False,
    parameter_values={
        'project': PROJECT_ID,
        'project_number': PROJECT_NUMBER,
        'location': REGION,
        'model_version': VERSION,
        'pipeline_version': PIPELINE_VERSION,
        'model_display_name': MODEL_ROOT_NAME,
        'vpc_network_name':vpc_network_name,
        # 'pipeline_tag': PIPELINE_TAG,
        'gcs_train_script_path': TRAIN_APP_CODE_PATH,
        'train_image_uri': f"{IMAGE_URI}",
        'train_output_gcs_bucket': OUTPUT_BUCKET,
        'train_dir': BUCKET_DATA_DIR,
        'train_dir_prefix': TRAIN_DIR_PREFIX,
        'valid_dir': BUCKET_DATA_DIR,
        'valid_dir_prefix': VALID_DIR_PREFIX,
        'candidate_file_dir': BUCKET_DATA_DIR,
        'candidate_files_prefix': CANDIDATE_PREFIX,
        # 'tensorboard_resource_name': TB_RESOURCE_NAME,
        'train_dockerfile_name': DOCKERNAME,
        'experiment_name': EXPERIMENT_NAME,
        'experiment_run': RUN_NAME,
        'service_account': SERVICE_ACCOUNT,
        'register_model_flag': 'True',
        # 'deploy_indexes_flag': 'True',
        'generate_new_vocab': False,
        'max_playlist_length': 15,
        'max_tokens': 20000,
        'ngrams': 2,
    },
)

job.run(
    sync=False,
    service_account=SERVICE_ACCOUNT,
    network=f'projects/{PROJECT_NUMBER}/global/networks/{vpc_network_name}'
)

#### clean up

In [74]:
# ! rm -rf custom_pipeline_spec.json

## local test

In [29]:
# !ls
from src.two_tower_jt import test_instances as test_instances

In [183]:
# test_instances.TEST_INSTANCE_5

In [184]:
# test_instances.TEST_INSTANCE_15

### Feature defs

In [None]:
vocab_dict_uris = [
    vocab_uri_1, vocab_uri_2, 
    vocab_uri_3, vocab_uri_4, 
    vocab_uri_5, vocab_uri_6, 
    vocab_uri_7, vocab_uri_8, 
    vocab_uri_9, 
]
logging.info(f"count of vocab_dict_uris: {len(vocab_dict_uris)}")
logging.info(f"vocab_dict_uris: {vocab_dict_uris}")

# ===================================================
# load pickled dicts
# ===================================================

loaded_pickle_list = []
for i, pickled_dict in enumerate(vocab_dict_uris):

    with open(f"v{i}_vocab_pre_load", 'wb') as local_vocab_file:
        storage_client.download_blob_to_file(pickled_dict, local_vocab_file)

    with open(f"v{i}_vocab_pre_load", 'rb') as pickle_file:
        loaded_vocab_dict = pkl.load(pickle_file)

    loaded_pickle_list.append(loaded_vocab_dict)

# ===================================================
# create master vocab dict
# ===================================================
master_dict = {}
for loaded_dict in loaded_pickle_list:
    master_dict.update(loaded_dict)

# ===================================================
# Upload master to GCS
# ===================================================
MASTER_VOCAB_LOCAL_FILE = f'vocab_dict.pkl'
MASTER_VOCAB_GCS_OBJ = f'{experiment_name}/{experiment_run}/{MASTER_VOCAB_LOCAL_FILE}' # destination folder prefix and blob name

# pickle
filehandler = open(f'{MASTER_VOCAB_LOCAL_FILE}', 'wb')
pkl.dump(master_dict, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(train_output_gcs_bucket)
blob = bucket_client.blob(MASTER_VOCAB_GCS_OBJ)
blob.upload_from_filename(MASTER_VOCAB_LOCAL_FILE)

In [188]:
import tensorflow as tf

candidate_features_dict = {
    "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
    "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
    "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # new
    # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    "time_signature_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
}

2023-02-14 08:19:42.940902: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-14 08:19:45.429956: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-14 08:19:53.453171: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64
2023-02-1

### dump to pickle

In [190]:
import pickle as pkl

candidate_pickle = 'candidate_features.pkl'

filehandler = open(f'{candidate_pickle}', 'wb')
pkl.dump(candidate_features_dict, filehandler)
filehandler.close()

### upload to GCS

In [193]:
DESTINATION_PATH_PREFIX = f'a-local-test/feature-dicts'

storage_client = storage.Client(project=PROJECT_ID)
bucket = storage_client.bucket(OUTPUT_BUCKET)
blob = bucket.blob(f'{DESTINATION_PATH_PREFIX}/{candidate_pickle}')

blob.upload_from_filename(candidate_pickle)

### load pickled dict & use

In [194]:
import os
os.system(f'gsutil cp gs://{OUTPUT_BUCKET}/{DESTINATION_PATH_PREFIX}/{candidate_pickle} loaded_cand_feats.pkl')  # jw-repo/spotify_mpd_two_tower/loaded_cand_feats.pkl

Copying gs://jt-tfrs-central-v2/a-local-test/feature-dicts/candidate_features.pkl...
/ [1 files][  965.0 B/  965.0 B]                                                
Operation completed over 1 objects/965.0 B.                                      


0

In [195]:
import os
os.system(f'gsutil cp gs://{OUTPUT_BUCKET}/{DESTINATION_PATH_PREFIX}/{candidate_pickle} loaded_cand_feats.pkl')

filehandler = open('loaded_cand_feats.pkl', 'rb')
loaded_candidate_feature_dict = pkl.load(filehandler)
filehandler.close()

loaded_candidate_feature_dict

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

In [None]:
for i in range(0, count):
    instance = {}
    for key, dict in loaded_candidate_feature_dict.items():
        new_dict = dict
        if key in candidate_features_dict.keys():
            new_dict = dict + perturb_cat[key]
        instance[key] = random.choices(new_dict)[0]


  for i in range(0, count):
    instance = {}
    for key, dict in CATEGORICAL_FEATURES.items():
      new_dict = dict
      if key in perturb_cat.keys():
        new_dict = dict + perturb_cat[key]
      instance[key] = random.choices(new_dict)[0]

In [None]:
for feat_name, fixed_length_def in loaded_candidate_feature_dict:
    
    names = [x['output_1'].numpy()[0] for x in embs] #clean up the output
    
embs_iter = parsed_candidate_dataset.batch(1).map(
    lambda data: candidate_predictor(
        track_uri_can = data["track_uri_can"],
        track_name_can = data['track_name_can'],
        artist_uri_can = data['artist_uri_can'],
        artist_name_can = data['artist_name_can'],
        album_uri_can = data['album_uri_can'],
        album_name_can = data['album_name_can'],
        duration_ms_can = data['duration_ms_can'],
        track_pop_can = data['track_pop_can'],
        artist_pop_can = data['artist_pop_can'],
        artist_genres_can = data['artist_genres_can'],
        artist_followers_can = data['artist_followers_can'],
        track_danceability_can = data['track_danceability_can'],
        track_energy_can = data['track_energy_can'],
        track_key_can = data['track_key_can'],
        track_loudness_can = data['track_loudness_can'],
        track_mode_can = data['track_mode_can'],
        track_speechiness_can = data['track_speechiness_can'],
        track_acousticness_can = data['track_acousticness_can'],
        track_instrumentalness_can = data['track_instrumentalness_can'],
        track_liveness_can = data['track_liveness_can'],
        track_valence_can = data['track_valence_can'],
        track_tempo_can = data['track_tempo_can'],
        time_signature_can = data['time_signature_can']
    )
)

In [67]:
GCS_URI = 'gs://jt-tfrs-central-v2/tfrs-e2e-pipe-test-v16-jtv5/run-20230214-110916/features/candidate_feats_dict.pkl'
LOCAL_FILENAME = 'v2-loaded-candidate-gcs.pkl'
with open(f"{LOCAL_FILENAME}", 'wb') as loaded:
        storage_client.download_blob_to_file(
            f"{GCS_URI}", loaded
        )
LOCAL_FILENAME

'v2-loaded-candidate-gcs.pkl'

In [68]:
TEST_BUCKET = 'jt-tfrs-central-v2'
TEST_BLOB_NAME = 'tfrs-e2e-pipe-test-v16-jtv5/run-20230214-110916/features/candidate_feats_dict.pkl'
TEST_LOCAL_FILENAME = 'v3-loaded_candidate_feats_dict.pkl'

# blob = bucket.get_blob(source_blob_name)

bucket = storage_client.bucket(TEST_BUCKET)
blob = bucket.blob(TEST_BLOB_NAME)
blob.download_to_filename(TEST_LOCAL_FILENAME)

filehandler = open(f'{TEST_LOCAL_FILENAME}', 'rb')
v3_loaded_query_features_dict = pkl.load(filehandler)
filehandler.close()

v3_loaded_query_features_dict

In [69]:
filehandler = open(f'{TEST_LOCAL_FILENAME}', 'rb')
v3_loaded_query_features_dict = pkl.load(filehandler)
filehandler.close()

v3_loaded_query_features_dict

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

### query features

In [None]:
pl_query_features = {
    # ===================================================
    # summary playlist features
    # ===================================================
    "pl_name_src" : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    'pl_collaborative_src' : tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    # 'num_pl_followers_src' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    'pl_duration_ms_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    'num_pl_songs_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), # n_songs_pl_new | num_pl_songs_new
    'num_pl_artists_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    'num_pl_albums_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_track_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_artist_pop_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 
    # 'avg_art_followers_pl_new' : tf.io.FixedLenFeature(dtype=tf.float32, shape=()), 

    # ===================================================
    # ragged playlist features
    # ===================================================
    # bytes / string
    "track_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "album_uri_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "album_name_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_genres_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    # "tracks_playlist_titles_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_key_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_mode_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)),
    "time_signature_pl": tf.io.FixedLenFeature(dtype=tf.string, shape=(MAX_PLAYLIST_LENGTH,)), 

    # Float List
    "duration_ms_songs_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "artist_pop_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "artists_followers_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_danceability_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_energy_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_loudness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_speechiness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_acousticness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_instrumentalness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_liveness_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_valence_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
    "track_tempo_pl": tf.io.FixedLenFeature(dtype=tf.float32, shape=(MAX_PLAYLIST_LENGTH,)),
}

# feats = {
    # # ===================================================
    # # candidate track features
    # # ===================================================
    # "track_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),            
    # "track_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "artist_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "artist_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "album_uri_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),           
    # "album_name_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()), 
    # "duration_ms_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    # "track_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),      
    # "artist_pop_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "artist_genres_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "artist_followers_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # # "track_pl_titles_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "track_danceability_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_energy_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_key_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "track_loudness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_mode_can":tf.io.FixedLenFeature(dtype=tf.string, shape=()),
    # "track_speechiness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_acousticness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_instrumentalness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_liveness_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_valence_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "track_tempo_can":tf.io.FixedLenFeature(dtype=tf.float32, shape=()),
    # "time_signature_can": tf.io.FixedLenFeature(dtype=tf.string, shape=()), # track_time_signature_can

In [59]:
def download_blob(bucket_name, source_gcs_obj, local_filename):
    """Uploads a file to the bucket."""
    # storage_client = storage.Client(project=project_number)
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_gcs_obj)
    blob.download_to_filename(local_filename)

    filehandler = open(f'{local_filename}', 'rb')
    loaded_dict = pkl.load(filehandler)
    filehandler.close()

    logging.info(f"File {local_filename} downloaded from gs://{bucket_name}/{source_gcs_obj}")

    return loaded_dict

# candidate features
train_output_gcs_bucket = 'jt-tfrs-central-v2'
CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
CAND_FEAT_GCS_OBJ = f'tfrs-e2e-pipe-test-v16-jtv5/run-20230214-110916/features/{CAND_FEAT_FILENAME}'
LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'

loaded_candidate_features_dict = download_blob(
    train_output_gcs_bucket,
    CAND_FEAT_GCS_OBJ,
    LOADED_CANDIDATE_DICT
)

In [60]:
# candidate features
train_output_gcs_bucket = 'jt-tfrs-central-v2'
CAND_FEAT_FILENAME = 'candidate_feats_dict.pkl'
CAND_FEAT_GCS_OBJ = f'tfrs-e2e-pipe-test-v16-jtv5/run-20230214-110916/features/{CAND_FEAT_FILENAME}'
LOADED_CANDIDATE_DICT = f'loaded_{CAND_FEAT_FILENAME}'

loaded_candidate_features_dict = download_blob(
    train_output_gcs_bucket,
    CAND_FEAT_GCS_OBJ,
    LOADED_CANDIDATE_DICT
)
# logg
loaded_candidate_features_dict

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

In [60]:
TEST_INSTANCE_15 = {
        'album_name_can': 'Capoeira Electronica',
        'album_name_pl': [
            'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica',
            'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica',
            'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica'
        ],
        'album_uri_can': 'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
        'album_uri_pl': [
            'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
            'spotify:album:55HHBqZ2SefPeaENOgWxYK',
            'spotify:album:150L1V6UUT7fGUI3PbxpkE',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
            'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
            'spotify:album:55HHBqZ2SefPeaENOgWxYK',
            'spotify:album:150L1V6UUT7fGUI3PbxpkE',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
            'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
            'spotify:album:55HHBqZ2SefPeaENOgWxYK',
            'spotify:album:150L1V6UUT7fGUI3PbxpkE',
            'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR'
        ],
        'artist_followers_can': 5170.0,
        'artist_genres_can': 'capoeira',
        'artist_genres_pl': [
            'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira',
            'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira',
            'samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira'
        ],
        'artist_name_can': 'Capoeira Experience',
        'artist_name_pl': [
            'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience',
            'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience',
            'Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience'
        ],
        'artist_pop_can': 24.0,
        'artist_pop_pl':[
            4., 24.,  2.,  0., 24.,
            4., 24.,  2.,  0., 24.,
            4., 24.,  2.,  0., 24.
        ],
        'artist_uri_can': 'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
        'artist_uri_pl': [
            'spotify:artist:72oameojLOPWYB7nB8rl6c',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
            'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
            'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
            'spotify:artist:72oameojLOPWYB7nB8rl6c',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
            'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
            'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
            'spotify:artist:72oameojLOPWYB7nB8rl6c',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
            'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
            'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
            'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP'
        ],
        'artists_followers_pl': [ 
            316., 5170.,  448.,   19., 5170.,
            316., 5170.,  448.,   19., 5170.,
            316., 5170.,  448.,   19., 5170.
        ],
        'duration_ms_can': 192640.0,
        'duration_ms_songs_pl': [234612., 226826., 203480., 287946., 271920., 234612., 226826., 203480., 287946., 271920., 234612., 226826., 203480., 287946., 271920.],
        'num_pl_albums_new': 9.0,
        'num_pl_artists_new': 5.0,
        'num_pl_songs_new': 85.0,
        'pl_collaborative_src': 'false',
        'pl_duration_ms_new': 17971314.0,
        'pl_name_src': 'Capoeira',
        'time_signature_can': '4',
        'time_signature_pl': ['4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4', '4'],
        'track_acousticness_can': 0.478,
        'track_acousticness_pl': [0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304 ],
        'track_danceability_can': 0.709,
        'track_danceability_pl': [0.703, 0.712, 0.806, 0.529, 0.821, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304],
        'track_energy_can': 0.742,
        'track_energy_pl': [0.743, 0.41 , 0.794, 0.776, 0.947, 0.238 , 0.105 , 0.0242, 0.125 , 0.304, 0.238 , 0.105 , 0.0242, 0.125 , 0.304],
        'track_instrumentalness_can': 0.00297,
        'track_instrumentalness_pl': [4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03, 4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03, 4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03],
        'track_key_can': '0',
        'track_key_pl': ['5', '0', '1', '10', '10', '5', '0', '1', '10', '10', '5', '0', '1', '10', '10'],
        'track_liveness_can': 0.0346,
        'track_liveness_pl': [0.128 , 0.0725, 0.191 , 0.105 , 0.0552,0.128 , 0.0725, 0.191 , 0.105 , 0.0552, 0.128 , 0.0725, 0.191 , 0.105 , 0.0552],
        'track_loudness_can': -7.295,
        'track_loudness_pl': [-8.638, -8.754, -9.084, -7.04 , -6.694, -8.638, -8.754, -9.084, -7.04 , -6.694, -8.638, -8.754, -9.084, -7.04 , -6.694],
        'track_mode_can': '1',
        'track_mode_pl': ['0', '1', '1', '0', '1', '0', '1', '1', '0', '1', '0', '1', '1', '0', '1'],
        'track_name_can': 'Bezouro Preto - Studio',
        'track_name_pl': [
            'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio',
            'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio',
            'O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio'
        ],
        'track_pop_can': 3.0,
        'track_pop_pl': [5., 1., 0., 0., 1., 5., 1., 0., 0., 1., 5., 1., 0., 0., 1.],
        'track_speechiness_can': 0.0802,
        'track_speechiness_pl':[0.0367, 0.0272, 0.0407, 0.132 , 0.0734, 0.0367, 0.0272, 0.0407, 0.132 , 0.0734, 0.0367, 0.0272, 0.0407, 0.132 , 0.0734],
        'track_tempo_can': 172.238,
        'track_tempo_pl': [100.039,  89.089, 123.999, 119.963, 119.214, 100.039,  89.089, 123.999, 119.963, 119.214, 100.039,  89.089, 123.999, 119.963, 119.214],
        'track_uri_can': 'spotify:track:0tlhK4OvpHCYpReTABvKFb',
        'track_uri_pl': [
            'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
            'spotify:track:39grEDsAHAjmo2QFo4G8D9',
            'spotify:track:5vxSLdJXqbKYH487YO8LSL',
            'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
            'spotify:track:7ELt9eslVvWo276pX2garN',
            'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
            'spotify:track:39grEDsAHAjmo2QFo4G8D9',
            'spotify:track:5vxSLdJXqbKYH487YO8LSL',
            'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
            'spotify:track:7ELt9eslVvWo276pX2garN',
            'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
            'spotify:track:39grEDsAHAjmo2QFo4G8D9',
            'spotify:track:5vxSLdJXqbKYH487YO8LSL',
            'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
            'spotify:track:7ELt9eslVvWo276pX2garN'
        ],
        'track_valence_can': 0.844,
        'track_valence_pl': [
            0.966, 0.667, 0.696, 0.876, 0.655,
            0.966, 0.667, 0.696, 0.876, 0.655,
            0.966, 0.667, 0.696, 0.876, 0.655
        ],
    }

In [61]:
# query features
LOCAL_TEST_INSTANCE = 'test_instance_15_dict.pkl'

BUCKET_TEST = 'spotify-data-regimes'
PREFIX = 'jtv15-8m'
TEST_GCS_OBJ = f'{PREFIX}/{LOCAL_TEST_INSTANCE}'

# pickle
filehandler = open(f'{LOCAL_TEST_INSTANCE}', 'wb')
pkl.dump(TEST_INSTANCE_15, filehandler)
filehandler.close()

# upload to GCS
bucket_client = storage_client.bucket(BUCKET_TEST)
blob = bucket_client.blob(TEST_GCS_OBJ)
blob.upload_from_filename(LOCAL_TEST_INSTANCE)

In [62]:
local_name = "local-jt-test-instance.pkl"
BUCKET_TEST = 'spotify-data-regimes'
PREFIX = 'jtv15-8m'
TEST_GCS_OBJ = f'{PREFIX}/{LOCAL_TEST_INSTANCE}'

bucket = storage_client.bucket(BUCKET_TEST)
blob = bucket.blob(TEST_GCS_OBJ)
blob.download_to_filename(local_name)

filehandler = open(f'{local_name}', 'rb')
loaded_test_dict = pkl.load(filehandler)
filehandler.close()

loaded_test_dict

{'album_name_can': 'Capoeira Electronica',
 'album_name_pl': ['Odilara',
  'Capoeira Electronica',
  'Capoeira Ultimate',
  'Festa Popular',
  'Capoeira Electronica',
  'Odilara',
  'Capoeira Electronica',
  'Capoeira Ultimate',
  'Festa Popular',
  'Capoeira Electronica',
  'Odilara',
  'Capoeira Electronica',
  'Capoeira Ultimate',
  'Festa Popular',
  'Capoeira Electronica'],
 'album_uri_can': 'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
 'album_uri_pl': ['spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
  'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
  'spotify:album:55HHBqZ2SefPeaENOgWxYK',
  'spotify:album:150L1V6UUT7fGUI3PbxpkE',
  'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
  'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
  'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
  'spotify:album:55HHBqZ2SefPeaENOgWxYK',
  'spotify:album:150L1V6UUT7fGUI3PbxpkE',
  'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
  'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
  'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
  'spotify:album:55HHBqZ2SefPeaENOgWxYK'