# Candidate Generation

After Two-Tower training, the `candidate_tower` is used to convert all candidate items into embeddings.

The embeddings are indexed and deployed to an endpoint for serving.

Steps performed in this notebook:

1. Create a test dataset to send to the query endpoint 
2. Submit the `endpoint.predict()` calls to get the query vector representation
3. Inspect the records and familiarize, check for data quality


In [1]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1


In [3]:
import json
import numpy as np
import pickle as pkl
from pprint import pprint
import time

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

import tensorflow as tf
import tensorflow_recommenders as tfrs
import tensorflow_io as tfio

from google.cloud import storage
from google.cloud.storage.bucket import Bucket
from google.cloud.storage.blob import Blob

import google.cloud.aiplatform as vertex_ai

from src.two_tower_jt import two_tower as tt
from src.two_tower_jt import train_utils
from src.two_tower_jt import feature_sets

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
    
gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
# jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425

BUCKET = 'jt-tfrs-central-v3' # -v2 jt-tfrs-central-v2
BUCKET_URI = f'gs://{BUCKET}'

# PATH_TO_INDEX_DIR = 'customtest5-trainerv3/run-20230321-005425' # 50e-8m-128d
# EXPERIMENT_TAG = 'customtest5-trainerv3'

# PATH_TO_INDEX_DIR = 'test-e2e-v1-trainerv4/run-20230410-232135' #/model-dir/candidate_model'
# EXPERIMENT_TAG = 'test-e2e-v1-trainerv4'

# PATH_TO_INDEX_DIR = '8m-tfrs-v100-jtv15/run-20230125-205451' # 50e-8m-128d
# EXPERIMENT_TAG = '8m-tfrs-v100-jtv15'

PATH_TO_INDEX_DIR = 'customtest5-trainerv3/run-20230321-005425' # 50e-8m-128d
EXPERIMENT_TAG = 'customtest5-trainerv3'

# full gcs path
CANDIDATE_MODEL_GCS_PATH = f'{PATH_TO_INDEX_DIR}/model-dir/candidate_model'
CANDIDATE_MODEL_DIR = f'{BUCKET_URI}/{CANDIDATE_MODEL_GCS_PATH}'

print(f"CANDIDATE_MODEL_DIR: {CANDIDATE_MODEL_DIR}")

CANDIDATE_MODEL_DIR: gs://jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425/model-dir/candidate_model


In [5]:
! gsutil ls $CANDIDATE_MODEL_DIR

gs://jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425/model-dir/candidate_model/
gs://jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425/model-dir/candidate_model/saved_model.pb
gs://jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425/model-dir/candidate_model/assets/
gs://jt-tfrs-central-v3/customtest5-trainerv3/run-20230321-005425/model-dir/candidate_model/variables/


## Load Candidate `SavedModel`

In [6]:
candidate_tower_uri = f'{CANDIDATE_MODEL_DIR}' # vertex trained

loaded_candidate_model = tf.saved_model.load(candidate_tower_uri)

loaded_candidate_model.signatures

_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, track_mode_can, track_tempo_can, artist_genres_can, track_instrumentalness_can, track_valence_can, album_uri_can, track_name_can, album_name_can, track_acousticness_can, track_liveness_can, track_speechiness_can, track_pop_can, track_danceability_can, track_energy_can, artist_followers_can, track_loudness_can, duration_ms_can, track_key_can, track_uri_can, track_time_signature_can, artist_uri_can, artist_pop_can, artist_name_can) at 0x7F2E497FCFD0>})

In [7]:
print(list(loaded_candidate_model.signatures.keys()))

['serving_default']


In [8]:
candidate_predictor = loaded_candidate_model.signatures["serving_default"]
print(candidate_predictor.structured_outputs)

{'output_1': TensorSpec(shape=(None, 128), dtype=tf.float32, name='output_1')}


In [9]:
candidate_predictor.output_shapes

{'output_1': TensorShape([None, 128])}

## Candidate Dataset

### helper functions

In [10]:
storage_client = storage.Client(project=PROJECT_ID)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

In [11]:
candidate_features = feature_sets.get_candidate_features()
candidate_features

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

## Candidate Records

In [12]:
# CANDIDATE_FILE_DIR = 'spotify-data-regimes'
# CANDIDATE_PREFIX = 'jtv15-8m/candidates' # jtv10 | jtv14-8m

CANDIDATE_FILE_DIR = 'matching-engine-content'
DATA_VERSION = 'v1-0-0'
CANDIDATE_PREFIX = f'{DATA_VERSION}/candidates' 

In [13]:
candidate_files = []
for blob in storage_client.list_blobs(f"{CANDIDATE_FILE_DIR}", prefix=f'{CANDIDATE_PREFIX}/'):
    if '.tfrecords' in blob.name:
        candidate_files.append(blob.public_url.replace("https://storage.googleapis.com/", "gs://"))

candidate_dataset = tf.data.Dataset.from_tensor_slices(candidate_files)

parsed_candidate_dataset = candidate_dataset.interleave(
    # lambda x: tf.data.TFRecordDataset(x),
    train_utils.full_parse,
    cycle_length=tf.data.AUTOTUNE, 
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False
).map(
    feature_sets.parse_candidate_tfrecord_fn, 
    num_parallel_calls=tf.data.AUTOTUNE
).with_options(
    options
)

parsed_candidate_dataset = parsed_candidate_dataset.cache() #400 MB on machine mem

In [14]:
for features in parsed_candidate_dataset.take(1):
    pprint(features)
    print("_______________")

{'album_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'Festival Party Riddim'>,
 'album_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:album:6HRMv5gpkJDvfBhpBr1OVK'>,
 'artist_followers_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
 'artist_genres_can': <tf.Tensor: shape=(), dtype=string, numpy=b'NONE'>,
 'artist_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'The Winners Table Band'>,
 'artist_pop_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
 'artist_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:artist:2oy6bRhmrdW8M5IVCNpu1A'>,
 'duration_ms_can': <tf.Tensor: shape=(), dtype=float32, numpy=113554.0>,
 'track_acousticness_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.0542>,
 'track_danceability_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.455>,
 'track_energy_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.965>,
 'track_instrumentalness_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.959>,
 'track_key_can

In [15]:
parsed_candidate_dataset

<CacheDataset element_spec={'album_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'album_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_followers_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_genres_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_pop_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'duration_ms_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_acousticness_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_danceability_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_energy_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_instrumentalness_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_key_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'track_liveness_can': TensorSpec(shape=(),

## Generate Embeddings

* use candidate_predictor to produce embeddings for each candidate item
* store embeddings in list
* zip candidate embeddings and candidate IDs together
* write json or csv file for ANN Index

In [20]:
# previously created embedding output
# !gsutil cp gs://jt-tfrs-central/pipe-dev-2tower-tfrs-jtv10/run-20221228-210041/candidates/candidate_embeddings.json candidate_embs_20221228_210041.json

Copying gs://jt-tfrs-central/pipe-dev-2tower-tfrs-jtv10/run-20221228-210041/candidates/candidate_embeddings.json...
- [1 files][882.4 MiB/882.4 MiB]   68.6 MiB/s                                   
Operation completed over 1 objects/882.4 MiB.                                    


### candidate embedding vectors

In [17]:
start_time = time.time()

embs_iter = parsed_candidate_dataset.batch(1).map(
    lambda data: candidate_predictor(
        track_uri_can = data["track_uri_can"],
        track_name_can = data['track_name_can'],
        artist_uri_can = data['artist_uri_can'],
        artist_name_can = data['artist_name_can'],
        album_uri_can = data['album_uri_can'],
        album_name_can = data['album_name_can'],
        duration_ms_can = data['duration_ms_can'],
        track_pop_can = data['track_pop_can'],
        artist_pop_can = data['artist_pop_can'],
        artist_genres_can = data['artist_genres_can'],
        artist_followers_can = data['artist_followers_can'],
        track_danceability_can = data['track_danceability_can'],
        track_energy_can = data['track_energy_can'],
        track_key_can = data['track_key_can'],
        track_loudness_can = data['track_loudness_can'],
        track_mode_can = data['track_mode_can'],
        track_speechiness_can = data['track_speechiness_can'],
        track_acousticness_can = data['track_acousticness_can'],
        track_instrumentalness_can = data['track_instrumentalness_can'],
        track_liveness_can = data['track_liveness_can'],
        track_valence_can = data['track_valence_can'],
        track_tempo_can = data['track_tempo_can'],
        track_time_signature_can = data['track_time_signature_can']
    )
)

embs = []
for emb in embs_iter:
    embs.append(emb)

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)
print(f"elapsed_time: {elapsed_time}")

print(f"Length of embs: {len(embs)}")
embs[0]

elapsed_time: 3
Length of embs: 22439


{'output_1': <tf.Tensor: shape=(100, 128), dtype=float32, numpy=
 array([[-3.9028122 ,  1.9181476 ,  1.9243963 , ...,  2.445345  ,
          0.06223404,  3.1558518 ],
        [-3.324296  ,  2.066658  ,  1.9385699 , ...,  2.8478448 ,
          0.03459817,  3.2164383 ],
        [-2.9631696 ,  1.9246789 ,  2.0482702 , ...,  2.7677617 ,
         -0.1826854 ,  3.6549532 ],
        ...,
        [-3.345099  ,  2.0698647 ,  2.0750623 , ...,  2.5321743 ,
          0.11192267,  2.81819   ],
        [-2.9702606 ,  2.065245  ,  1.9940509 , ...,  2.425759  ,
          0.01941515,  3.3566875 ],
        [-3.0347955 ,  2.3611856 ,  2.3625164 , ...,  2.4668849 ,
         -0.18338099,  2.8280168 ]], dtype=float32)>}

In [18]:
len(embs)

22439

Clean embedding output...

In [19]:
start_time = time.time()

cleaned_embs = [x['output_1'].numpy()[0] for x in embs] #clean up the output

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)
print(f"elapsed_time: {elapsed_time}")

elapsed_time: 0


In [20]:
print(f"Length of cleaned_embs: {len(cleaned_embs)}")
cleaned_embs[0]

Length of cleaned_embs: 22439


array([-3.9028122 ,  1.9181476 ,  1.9243963 ,  2.9426317 ,  1.0984983 ,
       -1.6019711 ,  1.8864298 , -1.8084981 , -1.3143858 , -0.90205085,
        1.2199564 , -0.2638863 ,  1.1716616 ,  1.0106726 , -2.5441952 ,
        0.09541601,  3.948473  , -1.9981352 ,  3.3929877 , -2.525893  ,
        2.6863084 , -0.41343668, -2.2722197 , -1.634059  , -1.1666859 ,
       -0.1565485 , -0.83394927, -2.0283368 ,  2.9164486 ,  2.6524365 ,
        2.6827588 , -2.8102674 , -1.2629974 , -1.4486234 ,  0.3337267 ,
       -0.48273367,  1.0239648 ,  1.824542  ,  0.67638993,  0.3069105 ,
       -0.07360408, -3.9400222 , -2.5889518 , -2.4068131 , -2.4706016 ,
        3.552813  ,  1.1203581 ,  1.776078  ,  2.845992  , -2.5003004 ,
       -3.3703184 ,  1.3982787 ,  1.7066772 ,  3.3853254 , -1.1815181 ,
       -3.0160904 ,  3.0032191 ,  2.1965175 ,  0.38071454, -1.4573872 ,
       -2.5126364 , -2.5437741 , -2.1613216 , -2.3365564 , -0.29526812,
       -2.5087624 , -0.6401027 , -3.3160193 , -1.2088568 ,  0.82

### candidate IDs

In [None]:
# clean product IDs
track_uris = [x['track_uri_can'].numpy() for x in parsed_candidate_dataset]

print(f"Length of track_uris: {len(track_uris)}")

track_uris[0]

In [23]:
# track_uris_cleaned = [str(z).replace("b'","").replace("'","") for z in track_uris]
track_uris_decoded = [z.decode("utf-8") for z in track_uris]

print(f"Length of track_uris_decoded: {len(track_uris_decoded)}")

track_uris_decoded[0]

Length of track_uris_decoded: 2262292


'spotify:track:6Nx4UYbpHuU4x5mozUDaQQ'

In [24]:
print(f"Length of track_uris: {len(track_uris)}")
print(f"Length of track_uris_cleaned: {len(track_uris_decoded)}")

Length of track_uris: 2262292
Length of track_uris_cleaned: 2262292


### Check for bad records

In [25]:
cleaned_embs[0]

array([-0.4443766 , -2.9618957 , -1.6283664 ,  3.8264632 , -1.1912026 ,
       -2.0853355 ,  1.4044247 , -3.3095267 , -2.1489854 , -1.8763275 ,
        2.509281  , -2.967039  , -2.9918072 ,  4.151778  ,  2.2562225 ,
       -2.998717  ,  2.729978  , -0.5180515 ,  0.20021534, -5.291335  ,
        2.078448  , -0.43222326,  2.6752548 ,  1.6742964 ,  3.208346  ,
       -2.3055122 ,  1.1319474 , -1.9209781 ,  1.4808187 ,  2.0377028 ,
        2.4703784 , -3.1968822 ,  3.2770886 , -0.78494287, -1.2607541 ,
       -4.416418  ,  2.2357628 ,  2.8092124 , -2.0393417 ,  1.6224779 ,
       -1.8457481 ,  0.3112812 , -1.6624097 ,  1.9860845 , -3.3870966 ,
        2.5134103 , -5.2180195 , -2.193615  ,  2.3267477 , -2.2696674 ,
       -2.087286  ,  3.44712   ,  2.760911  ,  0.84106797,  0.9023529 ,
       -0.24549764,  0.30034775, -0.00608075, -3.5383887 , -0.17121124,
       -2.4288304 , -1.5979798 , -0.31901774, -4.903339  , -1.9036381 ,
        0.30285656, -4.713433  , -2.3357277 ,  1.9909285 ,  3.19

In [26]:
bad_records = []

for i, emb in enumerate(cleaned_embs):
    bool_emb = np.isnan(emb)
    for val in bool_emb:
        if val:
            bad_records.append(i)
            
bad_record_filter = np.unique(bad_records)

print(f"bad_records: {len(bad_records)}")
print(f"bad_record_filter: {len(bad_record_filter)}")

bad_records: 0
bad_record_filter: 0


In [27]:
# bad_record_filter[0]

In [28]:
track_uris_valid = []
emb_valid = []

for i, pair in enumerate(zip(track_uris_decoded, cleaned_embs)):
    if i in bad_record_filter:
        pass
    else:
        t_uri, embed = pair
        track_uris_valid.append(t_uri)
        emb_valid.append(embed)

In [29]:
emb_valid[0]

array([-0.4443766 , -2.9618957 , -1.6283664 ,  3.8264632 , -1.1912026 ,
       -2.0853355 ,  1.4044247 , -3.3095267 , -2.1489854 , -1.8763275 ,
        2.509281  , -2.967039  , -2.9918072 ,  4.151778  ,  2.2562225 ,
       -2.998717  ,  2.729978  , -0.5180515 ,  0.20021534, -5.291335  ,
        2.078448  , -0.43222326,  2.6752548 ,  1.6742964 ,  3.208346  ,
       -2.3055122 ,  1.1319474 , -1.9209781 ,  1.4808187 ,  2.0377028 ,
        2.4703784 , -3.1968822 ,  3.2770886 , -0.78494287, -1.2607541 ,
       -4.416418  ,  2.2357628 ,  2.8092124 , -2.0393417 ,  1.6224779 ,
       -1.8457481 ,  0.3112812 , -1.6624097 ,  1.9860845 , -3.3870966 ,
        2.5134103 , -5.2180195 , -2.193615  ,  2.3267477 , -2.2696674 ,
       -2.087286  ,  3.44712   ,  2.760911  ,  0.84106797,  0.9023529 ,
       -0.24549764,  0.30034775, -0.00608075, -3.5383887 , -0.17121124,
       -2.4288304 , -1.5979798 , -0.31901774, -4.903339  , -1.9036381 ,
        0.30285656, -4.713433  , -2.3357277 ,  1.9909285 ,  3.19

In [30]:
len(emb_valid)

2262292

In [31]:
track_uris_valid[0]

'spotify:track:6Nx4UYbpHuU4x5mozUDaQQ'

In [32]:
len(track_uris_valid)

2262292

### tmp - dealing with bad track uris

## Write embedding vectors to json file

In [33]:
VERSION = 'local'
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")

embeddings_index_filename = f'candidate_embs_{VERSION}_{TIMESTAMP}.json'

with open(f'{embeddings_index_filename}', 'w') as f:
    for prod, emb in zip(track_uris_valid, emb_valid):
        f.write('{"id":"' + str(prod) + '",')
        f.write('"embedding":[' + ",".join(str(x) for x in list(emb)) + "]}")
        f.write("\n")

## Upload json to GCS

In [34]:
# jt-tfrs-central/pipe-dev-2tower-tfrs-jtv10/run-20221228-210041

# BUCKET = 'jt-tfrs-central'
# PATH_TO_INDEX_DIR = 'a50-epoch/run-20221230-160518'
INDEX_GCS_URI = f'gs://{BUCKET}/{PATH_TO_INDEX_DIR}/candidates-index-{VERSION}'

print(f"INDEX_GCS_URI: {INDEX_GCS_URI}")

DESTINATION_BLOB_NAME = embeddings_index_filename
SOURCE_FILE_NAME = embeddings_index_filename

print(f"DESTINATION_BLOB_NAME: {DESTINATION_BLOB_NAME}")
print(f"SOURCE_FILE_NAME: {SOURCE_FILE_NAME}")

INDEX_GCS_URI: gs://jt-tfrs-central-v2/8m-tfrs-v100-jtv15/run-20230125-205451/candidates-index-local
DESTINATION_BLOB_NAME: candidate_embs_local_20230130-180710.json
SOURCE_FILE_NAME: candidate_embs_local_20230130-180710.json


In [35]:
blob = Blob.from_string(os.path.join(INDEX_GCS_URI, DESTINATION_BLOB_NAME))
blob.bucket._client = storage_client
blob.upload_from_filename(SOURCE_FILE_NAME)

# Inspect track_uris

* id in the track_uri should be 22 characters (total of 36 characters including `spotify:track:`)
* some track_uris have an id that is 21 characters long
* these are not present in the source data (BigQuery)

In [36]:
len(track_uris_valid)

print(f"count of track_uris_valid: {len(track_uris_valid)}\n")
print(f"ex: track_uris_valid[0]: {track_uris_valid[0]}\n")
print(f"length of a track_uris_valid: {len(track_uris_valid[0])}\n")

count of track_uris_valid: 2262292

ex: track_uris_valid[0]: spotify:track:6Nx4UYbpHuU4x5mozUDaQQ

length of a track_uris_valid: 36



In [39]:
short = []
normal = []
long = []

for track_id in track_uris_valid:
    if len(track_id)==36:
        normal.append(track_id)
    if len(track_id)<36:
        short.append(track_id)
    if len(track_id)>36:
        long.append(track_id)
        
print(f"short: {len(short)}")
print(f"normal: {len(normal)}")
print(f"long: {len(long)}")

short: 0
normal: 2262292
long: 0
