# Candidate Generation

After Two-Tower training, the `candidate_tower` is used to convert all candidate items into embeddings.

The embeddings are indexed and deployed to an index endpoint for serving.

Steps performed in this notebook:

* Load trained candidate tower
* Generate candidate track embeddings
* Create brute-force (BF) and ANN indexes
* Deploy indexes to index endpoints
* Test model deployment endpoint (e.g., prediction requests)
* Test index deployment endpoint (e.g., recall of the ANN index measured against the BF index)

> * Note: a brute-force (BF) index always achieves 100% recall, but at the cost of higher latency and compute.
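To make that note concrete: a brute-force search scores the query against every stored embedding, so its top-k is exact by construction, but each query costs O(N·d) multiply-adds for N items of dimension d. The sketch below is a minimal pure-Python illustration using dot-product similarity (the same measure as the `DOT_PRODUCT_DISTANCE` setting used later, where a larger value means a closer match). `brute_force_top_k` is a hypothetical helper, not part of the Vertex AI SDK.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def brute_force_top_k(
    query: Sequence[float],
    index: Dict[str, Sequence[float]],
    k: int,
) -> List[Tuple[str, float]]:
    """Score every stored vector and return the k highest dot products.

    Every item is visited exactly once, so the result is exact
    (100% recall by definition) but costs O(N * d) per query.
    """
    scored = ((item_id, dot(query, vec)) for item_id, vec in index.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


if __name__ == "__main__":
    toy_index = {
        "track_a": [1.0, 0.0],
        "track_b": [0.5, 0.5],
        "track_c": [0.0, 1.0],
    }
    print(brute_force_top_k([1.0, 0.2], toy_index, k=2))
```

An ANN index such as Tree-AH trades away this exhaustive scan: it only scores a subset of the corpus, which is why its recall can fall below 100%.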

## Load env config

In [1]:
# naming convention for all cloud resources
VERSION        = "v1"                  # TODO
PREFIX         = f'ndr-{VERSION}'      # TODO

print(f"PREFIX = {PREFIX}")

PREFIX = ndr-v1


In [2]:
# staging GCS
GCP_PROJECTS             = !gcloud config get-value project
PROJECT_ID               = GCP_PROJECTS[0]

# GCS bucket and paths
BUCKET_NAME              = f'{PREFIX}-{PROJECT_ID}-bucket'
BUCKET_URI               = f'gs://{BUCKET_NAME}'

config = !gsutil cat {BUCKET_URI}/config/notebook_env.py
print(config.n)
exec(config.n)


PROJECT_ID               = "hybrid-vertex"
PROJECT_NUM              = "934903580331"
LOCATION                 = "us-central1"

REGION                   = "us-central1"
BQ_LOCATION              = "US"
VPC_NETWORK_NAME         = "ucaip-haystack-vpc-network"

VERTEX_SA                = "934903580331-compute@developer.gserviceaccount.com"

PREFIX                   = "ndr-v1"
VERSION                  = "v1"

APP                      = "sp"
MODEL_TYPE               = "2tower"
FRAMEWORK                = "tfrs"
DATA_VERSION             = "v1"
TRACK_HISTORY            = "5"

BUCKET_NAME              = "ndr-v1-hybrid-vertex-bucket"
BUCKET_URI               = "gs://ndr-v1-hybrid-vertex-bucket"
SOURCE_BUCKET            = "spotify-million-playlist-dataset"

DATA_GCS_PREFIX          = "data"
DATA_PATH                = "gs://ndr-v1-hybrid-vertex-bucket/data"
VOCAB_SUBDIR             = "vocabs"
VOCAB_FILENAME           = "vocab_dict.pkl"

CANDIDATE_PREFIX         = "candidates"
TRAIN_DIR_PREFIX      

### imports

In [3]:
import json
import numpy as np
import pickle as pkl
from pprint import pprint
import time

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# GPU
import gc
from numba import cuda

import tensorflow as tf
import tensorflow_recommenders as tfrs
import tensorflow_io as tfio

from google.cloud import storage
from google.cloud.storage.bucket import Bucket
from google.cloud.storage.blob import Blob

import google.cloud.aiplatform as vertex_ai

from src.two_tower_jt import two_tower as tt
from src.two_tower_jt import train_utils
from src.two_tower_jt import feature_sets

import warnings
warnings.filterwarnings('ignore')

In [4]:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(gpu, True)
    
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  2


In [6]:
device = cuda.get_current_device()
device.reset()
gc.collect()

37

## Trained Candidate model path

### Edit these

In [11]:
# TODO - grab from saved candidate_embeddings.json URI
# local-train-v1/run-20230919-135222/candidates/candidate_embeddings.json

EXPERIMENT_NAME   = "tfrs-pipe-v1"        # TODO
RUN_NAME          = "run-20230919-173845" # TODO

PATH_TO_ARTIFACT_DIR = f'{EXPERIMENT_NAME}/{RUN_NAME}'
print(f"PATH_TO_ARTIFACT_DIR: {PATH_TO_ARTIFACT_DIR}")

PATH_TO_ARTIFACT_DIR: tfrs-pipe-v1/run-20230919-173845


In [12]:
! gsutil ls $BUCKET_URI/$EXPERIMENT_NAME/$RUN_NAME

gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/test_instances_5.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/train_job_dict.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/vocab_dict.pkl
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/candidate-embeddings-v1/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/checkpoints/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/features/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/logs/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/pipeline_root/


In [13]:
# full gcs path
CANDIDATE_MODEL_GCS_PATH = f'{PATH_TO_ARTIFACT_DIR}/model-dir/candidate_model'
# CANDIDATE_MODEL_GCS_PATH = f'{PATH_TO_ARTIFACT_DIR}/candidate_model' # tmp - TODO

CANDIDATE_MODEL_DIR = f'{BUCKET_URI}/{CANDIDATE_MODEL_GCS_PATH}'
print(f"CANDIDATE_MODEL_DIR: {CANDIDATE_MODEL_DIR}")

CANDIDATE_MODEL_DIR: gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model


In [14]:
! gsutil ls $CANDIDATE_MODEL_DIR

gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model/fingerprint.pb
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model/saved_model.pb
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model/assets/
gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/model-dir/candidate_model/variables/


### Load Candidate `SavedModel`

In [15]:
candidate_tower_uri = f'{CANDIDATE_MODEL_DIR}' # vertex trained
loaded_candidate_model = tf.saved_model.load(candidate_tower_uri)

loaded_candidate_model.signatures

_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, track_speechiness_can, track_mode_can, track_energy_can, track_key_can, track_valence_can, album_name_can, album_uri_can, duration_ms_can, artist_name_can, artist_uri_can, track_liveness_can, track_uri_can, track_acousticness_can, artist_pop_can, track_instrumentalness_can, artist_genres_can, track_time_signature_can, track_tempo_can, track_pop_can, track_loudness_can, artist_followers_can, track_danceability_can, track_name_can) at 0x7F158C14CFD0>})

In [16]:
print(list(loaded_candidate_model.signatures.keys()))

['serving_default']


In [17]:
candidate_predictor = loaded_candidate_model.signatures["serving_default"]
print(candidate_predictor.structured_outputs)

{'output_1': TensorSpec(shape=(None, 128), dtype=tf.float32, name='output_1')}


In [18]:
candidate_predictor.output_shapes

{'output_1': TensorShape([None, 128])}

## Candidate Dataset

### helper functions

In [19]:
storage_client = storage.Client(project=PROJECT_ID)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

In [20]:
candidate_features = feature_sets.get_candidate_features()
candidate_features

{'track_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'track_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_uri_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'album_name_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'duration_ms_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_pop_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'artist_genres_can': FixedLenFeature(shape=(), dtype=tf.string, default_value=None),
 'artist_followers_can': FixedLenFeature(shape=(), dtype=tf.float32, default_value=None),
 'track_danceability_can': FixedLenFeature(shape=(), dtype=tf.float32, defa

### Candidate Records

In [21]:
# CANDIDATE_PREFIX = f'candidates' 

In [22]:
candidate_files = []

# build gs:// URIs from the raw object names (blob.public_url would percent-encode them)
for blob in storage_client.list_blobs(BUCKET_NAME, prefix=f'data/{DATA_VERSION}/{CANDIDATE_PREFIX}'):
    candidate_files.append(f"gs://{BUCKET_NAME}/{blob.name}")

candidate_dataset = tf.data.Dataset.from_tensor_slices(candidate_files)

parsed_candidate_dataset = candidate_dataset.interleave(
    # lambda x: tf.data.TFRecordDataset(x),
    train_utils.full_parse,
    cycle_length=tf.data.AUTOTUNE, 
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False
).map(
    feature_sets.parse_candidate_tfrecord_fn, 
    num_parallel_calls=tf.data.AUTOTUNE
).with_options(
    options
)

# parsed_candidate_dataset = parsed_candidate_dataset.cache() #400 MB on machine mem

In [23]:
for features in parsed_candidate_dataset.take(1):
    pprint(features)
    print("_______________")

{'album_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'Thanatophobia'>,
 'album_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:album:5GBUYg5EqeDI0CuszAvDzj'>,
 'artist_followers_can': <tf.Tensor: shape=(), dtype=float32, numpy=27438.0>,
 'artist_genres_can': <tf.Tensor: shape=(), dtype=string, numpy=b"'indie garage rock'">,
 'artist_name_can': <tf.Tensor: shape=(), dtype=string, numpy=b'Worn-Tin'>,
 'artist_pop_can': <tf.Tensor: shape=(), dtype=float32, numpy=40.0>,
 'artist_uri_can': <tf.Tensor: shape=(), dtype=string, numpy=b'spotify:artist:7j8ds7BnqaEKuz1a1GN0J9'>,
 'duration_ms_can': <tf.Tensor: shape=(), dtype=float32, numpy=216923.0>,
 'track_acousticness_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.655>,
 'track_danceability_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.321>,
 'track_energy_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.555>,
 'track_instrumentalness_can': <tf.Tensor: shape=(), dtype=float32, numpy=0.752>,
 'track_key_can': 

In [24]:
parsed_candidate_dataset

<_OptionsDataset element_spec={'album_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'album_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_followers_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_genres_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_name_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'artist_pop_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'artist_uri_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'duration_ms_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_acousticness_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_danceability_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_energy_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_instrumentalness_can': TensorSpec(shape=(), dtype=tf.float32, name=None), 'track_key_can': TensorSpec(shape=(), dtype=tf.string, name=None), 'track_liveness_can': TensorSpec(shape=

# Generate Candidate Track Embeddings

* Use the loaded candidate tower (`loaded_candidate_model`) to produce an embedding for each candidate item
* Collect the batched embeddings in a list
* Zip the candidate embeddings with their candidate IDs (`track_uri_can`)
* Write a newline-delimited JSON (or CSV) file for the ANN index
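The last two steps (pair IDs with embeddings, then write a file for the index) can be sketched in plain Python: each candidate becomes one JSON object per line with an `id` and an `embedding` field, the same layout this notebook writes further down. `to_index_records` is a hypothetical helper; in the notebook the IDs come from `track_uri_can` and the vectors from the candidate tower.

```python
import json
from typing import Iterator, Sequence


def to_index_records(
    ids: Sequence[str],
    embeddings: Sequence[Sequence[float]],
) -> Iterator[str]:
    """Yield one JSON string (without trailing newline) per (id, embedding) pair."""
    if len(ids) != len(embeddings):
        raise ValueError("ids and embeddings must be row-aligned (same length)")
    for item_id, emb in zip(ids, embeddings):
        # float() turns NumPy scalars (e.g. float32) into JSON-serializable floats
        yield json.dumps({"id": item_id, "embedding": [float(x) for x in emb]})


if __name__ == "__main__":
    for line in to_index_records(["spotify:track:abc"], [[0.25, -1.5]]):
        print(line)
```

Writing the file is then just `f.write(line + "\n")` for each yielded line.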

In [25]:
# previously created embedding output
# !gsutil cp gs://jt-tfrs-central/pipe-dev-2tower-tfrs-jtv10/run-20221228-210041/candidates/candidate_embeddings.json candidate_embs_20221228_210041.json

### candidate embedding vectors

In [26]:
start_time = time.time()

# batch candidates and map each batch to (track_uri IDs, 128-d embeddings)
embs_iter = parsed_candidate_dataset.batch(10000).map(
    lambda data: (
        data["track_uri_can"],
        loaded_candidate_model(data)
    )
)

# materialize all batches; each element is an (ids, embeddings) tuple
embs = []
for emb in embs_iter:
    embs.append(emb)

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)  # whole minutes
print(f"elapsed_time: {elapsed_time}")

print(f"Length of embs: {len(embs)}")

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
elapsed_time: 4
Length of embs: 225


### Clean embedding output

In [27]:
x,y = embs[0]

y

<tf.Tensor: shape=(10000, 128), dtype=float32, numpy=
array([[-0.2958832 ,  1.3680464 , -1.0969739 , ..., -0.4091584 ,
        -0.29286286,  0.05274126],
       [-0.1878705 ,  1.4427985 , -0.43994805, ...,  0.8956098 ,
        -0.64344513,  0.2103799 ],
       [-0.7545898 , -1.5391797 ,  1.3635055 , ...,  1.0590885 ,
        -0.15870447,  0.56919026],
       ...,
       [ 1.1976322 ,  0.9469893 ,  1.6760215 , ...,  1.0664071 ,
        -0.50861025, -1.01303   ],
       [-1.3623376 ,  0.17555963, -0.04761489, ...,  0.6235434 ,
         0.07053979,  0.2961999 ],
       [-1.0530623 ,  0.98027635,  1.1150482 , ...,  0.9307209 ,
         0.07073139,  0.06994925]], dtype=float32)>

In [28]:
start_time = time.time()

# flatten the per-batch tensors into row-aligned Python lists
cleaned_embs = []
track_uris = []
for ids, embedding in embs:
    cleaned_embs.extend(embedding.numpy())
    track_uris.extend(ids.numpy())

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)  # whole minutes
print(f"elapsed_time: {elapsed_time}")

elapsed_time: 0
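The cleaning step flattens the list of per-batch `(ids, embeddings)` tuples into two parallel, row-aligned lists. The sketch below mirrors that logic with plain Python lists standing in for the `tf.Tensor` batches (the notebook calls `.numpy()` on each first), and folds in the UTF-8 decoding of the byte-string track URIs that a later cell performs. `flatten_batches` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

# one batch = (byte-string IDs, matching embedding rows)
Batch = Tuple[Sequence[bytes], Sequence[Sequence[float]]]


def flatten_batches(
    batches: Sequence[Batch],
) -> Tuple[List[str], List[Sequence[float]]]:
    """Concatenate batches into (decoded_ids, embeddings), preserving row order."""
    track_uris: List[str] = []
    embeddings: List[Sequence[float]] = []
    for ids, embs in batches:
        if len(ids) != len(embs):
            raise ValueError("each batch must have exactly one embedding per ID")
        track_uris.extend(i.decode("utf-8") for i in ids)
        embeddings.extend(embs)
    return track_uris, embeddings


if __name__ == "__main__":
    uris, vecs = flatten_batches([([b"t1", b"t2"], [[0.1], [0.2]]), ([b"t3"], [[0.3]])])
    print(uris, vecs)
```

Keeping IDs and vectors in lockstep is the property that matters: any later filtering must drop the same row index from both lists.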


In [29]:
print(f"Length of cleaned_embs: {len(cleaned_embs)}")
cleaned_embs[0]

Length of cleaned_embs: 2243885


array([-0.2958832 ,  1.3680464 , -1.0969739 , -1.1368335 , -0.4342838 ,
        0.39283156,  0.27632928,  0.03127032, -0.58364165, -0.7053537 ,
        0.02434574,  0.44454443,  0.36360395,  0.5888878 ,  0.46440738,
        0.6275927 ,  1.6786953 ,  0.02449422, -0.6262758 ,  0.9112112 ,
       -0.9261704 , -1.2875332 ,  1.3164775 ,  0.98420537, -0.7581321 ,
       -1.1962838 ,  0.554428  , -0.24542667, -0.01585706,  2.6407895 ,
        0.92781043, -0.7784559 ,  2.3831542 , -0.09137036, -0.71217793,
        1.7602313 , -1.7687172 , -0.99954164, -1.0950288 ,  0.33157885,
        0.3068544 , -1.4212629 , -0.5321212 , -0.6750946 ,  0.43817887,
        0.4380906 ,  1.5712899 ,  0.6998104 ,  0.22836557,  0.07564595,
        0.77435833,  0.352889  ,  0.31121057, -0.6935091 , -0.7743054 ,
       -0.42792314, -0.11159866,  1.5658025 , -0.59310645, -0.44968712,
       -0.6262101 , -1.0050645 ,  0.13244593, -0.2033531 , -0.21466011,
       -0.32916766, -0.8308483 , -1.0437603 , -0.80158645, -0.34

#### candidate IDs

In [30]:
# clean product IDs
# track_uris = [ids.numpy() for ids , _ in embs]

print(f"Length of track_uris: {len(track_uris)}")

Length of track_uris: 2243885


In [31]:
# track_uris_cleaned = [str(z).replace("b'","").replace("'","") for z in track_uris]
track_uris_decoded = [z.decode("utf-8") for z in track_uris]

print(f"Length of track_uris_decoded: {len(track_uris_decoded)}")
track_uris_decoded[0]

Length of track_uris_decoded: 2243885


'spotify:track:2XZ3bL3ROk605SPpy6Dn9C'

In [32]:
print(f"Length of track_uris: {len(track_uris)}")
print(f"Length of track_uris_decoded: {len(track_uris_decoded)}")

Length of track_uris: 2243885
Length of track_uris_decoded: 2243885


#### Check for bad records

In [33]:
cleaned_embs[0]

array([-0.2958832 ,  1.3680464 , -1.0969739 , -1.1368335 , -0.4342838 ,
        0.39283156,  0.27632928,  0.03127032, -0.58364165, -0.7053537 ,
        0.02434574,  0.44454443,  0.36360395,  0.5888878 ,  0.46440738,
        0.6275927 ,  1.6786953 ,  0.02449422, -0.6262758 ,  0.9112112 ,
       -0.9261704 , -1.2875332 ,  1.3164775 ,  0.98420537, -0.7581321 ,
       -1.1962838 ,  0.554428  , -0.24542667, -0.01585706,  2.6407895 ,
        0.92781043, -0.7784559 ,  2.3831542 , -0.09137036, -0.71217793,
        1.7602313 , -1.7687172 , -0.99954164, -1.0950288 ,  0.33157885,
        0.3068544 , -1.4212629 , -0.5321212 , -0.6750946 ,  0.43817887,
        0.4380906 ,  1.5712899 ,  0.6998104 ,  0.22836557,  0.07564595,
        0.77435833,  0.352889  ,  0.31121057, -0.6935091 , -0.7743054 ,
       -0.42792314, -0.11159866,  1.5658025 , -0.59310645, -0.44968712,
       -0.6262101 , -1.0050645 ,  0.13244593, -0.2033531 , -0.21466011,
       -0.32916766, -0.8308483 , -1.0437603 , -0.80158645, -0.34

In [35]:
start_time = time.time()

# flag every row whose embedding contains a NaN
# (a row index is appended once per NaN value; np.unique dedupes it below)
bad_records = []

for i, emb in enumerate(cleaned_embs):
    bool_emb = np.isnan(emb)
    for val in bool_emb:
        if val:
            bad_records.append(i)

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)  # whole minutes
print(f"elapsed_time: {elapsed_time}")

bad_record_filter = np.unique(bad_records)

print(f"bad_records: {len(bad_records)}")
print(f"bad_record_filter: {len(bad_record_filter)}")

elapsed_time: 0
bad_records: 0
bad_record_filter: 0


In [89]:
# bad_record_filter[0]

In [36]:
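start_time = time.time()

# keep only (track_uri, embedding) pairs whose row index was not flagged above;
# a set gives O(1) membership checks instead of scanning the NumPy array each time
bad_record_set = set(bad_record_filter.tolist())

track_uris_valid = []
emb_valid = []

for i, (t_uri, embed) in enumerate(zip(track_uris_decoded, cleaned_embs)):
    if i in bad_record_set:
        continue
    track_uris_valid.append(t_uri)
    emb_valid.append(embed)

end_time = time.time()
elapsed_time = int((end_time - start_time) / 60)  # whole minutes
print(f"elapsed_time: {elapsed_time}")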

elapsed_time: 0
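The two cells above (flag rows containing NaN, then skip the flagged rows) can also be written as a single pass over the paired data. This is a minimal pure-Python sketch with a hypothetical helper name; for the full 2.2M x 128 matrix, a vectorized NumPy mask such as `np.isnan(matrix).any(axis=1)` is the usual alternative if the stacked array fits in memory.

```python
import math
from typing import List, Sequence, Tuple


def drop_nan_records(
    ids: Sequence[str],
    embeddings: Sequence[Sequence[float]],
) -> Tuple[List[str], List[Sequence[float]], List[int]]:
    """Return (valid_ids, valid_embeddings, dropped_row_indices).

    A row is dropped if any component of its embedding is NaN.
    """
    valid_ids: List[str] = []
    valid_embs: List[Sequence[float]] = []
    dropped: List[int] = []
    for row, (item_id, emb) in enumerate(zip(ids, embeddings)):
        if any(math.isnan(x) for x in emb):
            dropped.append(row)
        else:
            valid_ids.append(item_id)
            valid_embs.append(emb)
    return valid_ids, valid_embs, dropped


if __name__ == "__main__":
    print(drop_nan_records(["a", "b", "c"], [[1.0, 2.0], [float("nan"), 0.0], [3.0, 4.0]]))
```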


In [37]:
emb_valid[0]

array([-0.2958832 ,  1.3680464 , -1.0969739 , -1.1368335 , -0.4342838 ,
        0.39283156,  0.27632928,  0.03127032, -0.58364165, -0.7053537 ,
        0.02434574,  0.44454443,  0.36360395,  0.5888878 ,  0.46440738,
        0.6275927 ,  1.6786953 ,  0.02449422, -0.6262758 ,  0.9112112 ,
       -0.9261704 , -1.2875332 ,  1.3164775 ,  0.98420537, -0.7581321 ,
       -1.1962838 ,  0.554428  , -0.24542667, -0.01585706,  2.6407895 ,
        0.92781043, -0.7784559 ,  2.3831542 , -0.09137036, -0.71217793,
        1.7602313 , -1.7687172 , -0.99954164, -1.0950288 ,  0.33157885,
        0.3068544 , -1.4212629 , -0.5321212 , -0.6750946 ,  0.43817887,
        0.4380906 ,  1.5712899 ,  0.6998104 ,  0.22836557,  0.07564595,
        0.77435833,  0.352889  ,  0.31121057, -0.6935091 , -0.7743054 ,
       -0.42792314, -0.11159866,  1.5658025 , -0.59310645, -0.44968712,
       -0.6262101 , -1.0050645 ,  0.13244593, -0.2033531 , -0.21466011,
       -0.32916766, -0.8308483 , -1.0437603 , -0.80158645, -0.34

In [38]:
len(emb_valid)

2243885

In [39]:
track_uris_valid[0]

'spotify:track:2XZ3bL3ROk605SPpy6Dn9C'

In [40]:
len(track_uris_valid)

2243885

## Write embedding vectors to json file

In [41]:
# NOTE: this overrides the VERSION ("v1") loaded from the env config above
VERSION = 'local'
# TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")

embeddings_index_filename = f'candidate_embs_{VERSION}.json'

# newline-delimited JSON: one {"id": ..., "embedding": [...]} object per line
with open(embeddings_index_filename, 'w') as f:
    for track_uri, emb in zip(track_uris_valid, emb_valid):
        f.write(json.dumps({"id": track_uri, "embedding": emb.tolist()}) + "\n")
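Before uploading, it can be worth reading the file back to confirm that every line parses and every vector has the index dimensionality (128 here, matching the candidate tower's `output_1` shape). `validate_embeddings_file` is a hypothetical helper; the demo writes a small temporary file rather than touching the real `candidate_embs_local.json`.

```python
import json
import os
import tempfile


def validate_embeddings_file(path: str, dimensions: int) -> int:
    """Check each JSON line has a string `id` and a `dimensions`-length `embedding`.

    Returns the number of records; raises ValueError on the first bad line.
    """
    count = 0
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)
            if not isinstance(record.get("id"), str):
                raise ValueError(f"line {line_no}: missing or non-string id")
            if len(record.get("embedding", [])) != dimensions:
                raise ValueError(f"line {line_no}: expected {dimensions} values")
            count += 1
    return count


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        tmp.write('{"id": "t1", "embedding": [0.1, 0.2]}\n')
        tmp.write('{"id": "t2", "embedding": [0.3, 0.4]}\n')
    print(validate_embeddings_file(tmp.name, dimensions=2))
    os.remove(tmp.name)
```

For the real file, the expected count is `len(track_uris_valid)` and `dimensions` is `DIMENSIONS`.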

### Upload json to GCS

In [46]:
INDEX_GCS_URI = f'gs://{BUCKET_NAME}/{PATH_TO_ARTIFACT_DIR}/candidates-{VERSION}'

DESTINATION_BLOB_NAME = embeddings_index_filename
SOURCE_FILE_NAME = embeddings_index_filename

print(f"INDEX_GCS_URI         : {INDEX_GCS_URI}")
print(f"DESTINATION_BLOB_NAME : {DESTINATION_BLOB_NAME}")
print(f"SOURCE_FILE_NAME      : {SOURCE_FILE_NAME}")

INDEX_GCS_URI         : gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845/candidates-local
DESTINATION_BLOB_NAME : candidate_embs_local.json
SOURCE_FILE_NAME      : candidate_embs_local.json


In [47]:
blob = Blob.from_string(os.path.join(INDEX_GCS_URI, DESTINATION_BLOB_NAME))
blob.bucket._client = storage_client
blob.upload_from_filename(SOURCE_FILE_NAME)
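`Blob.from_string` takes a full `gs://bucket/object` URI and splits it into a bucket and an object name. The sketch below mirrors that split with the standard library, plus the inverse; building a URI directly from `blob.name` this way avoids the percent-encoding that `blob.public_url` applies to object names. Both helpers (`parse_gs_uri`, `to_gs_uri`) are hypothetical and not part of `google-cloud-storage`.

```python
from typing import Tuple


def parse_gs_uri(uri: str) -> Tuple[str, str]:
    """Split 'gs://bucket/path/to/object' into ('bucket', 'path/to/object')."""
    prefix = "gs://"
    if not uri.startswith(prefix):
        raise ValueError(f"not a gs:// URI: {uri!r}")
    bucket, _, blob_name = uri[len(prefix):].partition("/")
    if not bucket or not blob_name:
        raise ValueError(f"URI needs both a bucket and an object name: {uri!r}")
    return bucket, blob_name


def to_gs_uri(bucket: str, blob_name: str) -> str:
    """Build a gs:// URI from raw (unquoted) bucket and object names."""
    return f"gs://{bucket}/{blob_name}"


if __name__ == "__main__":
    print(parse_gs_uri(
        "gs://ndr-v1-hybrid-vertex-bucket/tfrs-pipe-v1/run-20230919-173845"
        "/candidates-local/candidate_embs_local.json"
    ))
```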

# Matching Engine Index
*Initialize an existing index or create a new one*

> Deploy the candidate index to a Matching Engine index endpoint

* **TODO**

#### Edit these

In [48]:
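CREATE_NEW_ASSETS          = False # True | False

# GCS directory holding the embeddings JSON uploaded above
EMBEDDINGS_INITIAL_URI     = INDEX_GCS_URI

# ANN index config
APPROX_NEIGHBORS           = 50
DISTANCE_MEASURE           = "DOT_PRODUCT_DISTANCE"
LEAF_NODE_EMB_COUNT        = 500
LEAF_NODES_SEARCH_PERCENT  = 7
DIMENSIONS                 = 128 # must match the candidate tower's output dimensions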

# matching engine (vector search)
ANN_INDEX_DISPLAY_NAME    = f"tfrs_{DIMENSIONS}dim_{VERSION}"
ANN_ENDPOINT_DISPLAY_NAME = f'{ANN_INDEX_DISPLAY_NAME}_endpoint'

BF_DISPLAY_NAME           = f"{ANN_INDEX_DISPLAY_NAME}_bf"
BF_ENDPOINT_DISPLAY_NAME  = f'{BF_DISPLAY_NAME}_endpoint'

# labels
DATA_REGIME               = 'full-65m'

print(f"ANN_INDEX_DISPLAY_NAME   : {ANN_INDEX_DISPLAY_NAME}")
print(f"BF_DISPLAY_NAME          : {BF_DISPLAY_NAME}")

ANN_INDEX_DISPLAY_NAME   : tfrs_128dim_local
BF_DISPLAY_NAME          : tfrs_128dim_local_bf


## Create a Matching Engine Index

Matching Engine builds the index from the embeddings file created earlier in this notebook.

Most of Matching Engine's tuning options live in the Tree-AH settings (e.g., `leaf_node_embedding_count`, `leaf_nodes_to_search_percent`, `approximate_neighbors_count`); test them against your own use case.

Recall that the candidate embeddings were saved as newline-delimited JSON in the `candidates-{VERSION}` folder under the run's artifact directory (`INDEX_GCS_URI`).
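A rough back-of-the-envelope for the Tree-AH settings in the config cell: if leaves hold about `LEAF_NODE_EMB_COUNT` embeddings each, then `LEAF_NODES_SEARCH_PERCENT` controls roughly what fraction of the corpus is scored per query. This is an approximation assumed purely for intuition (the service does its own partitioning and leaf sizes vary); `tree_ah_search_estimate` is a hypothetical helper, fed with this notebook's numbers.

```python
import math


def tree_ah_search_estimate(
    num_embeddings: int,
    leaf_node_emb_count: int,
    leaf_nodes_search_percent: float,
) -> dict:
    """Rough estimate of leaves and embeddings scanned per Tree-AH query.

    Assumes leaves are filled to about `leaf_node_emb_count` embeddings each.
    """
    num_leaves = math.ceil(num_embeddings / leaf_node_emb_count)
    leaves_searched = math.ceil(num_leaves * leaf_nodes_search_percent / 100)
    return {
        "num_leaves": num_leaves,
        "leaves_searched": leaves_searched,
        "approx_embeddings_scanned": leaves_searched * leaf_node_emb_count,
    }


if __name__ == "__main__":
    # 2,243,885 candidates, LEAF_NODE_EMB_COUNT=500, LEAF_NODES_SEARCH_PERCENT=7
    print(tree_ah_search_estimate(2_243_885, 500, 7))
```

Under this assumption, a query scores roughly 157,500 of about 2.24M vectors (about 7% of the corpus), whereas the brute-force index scores all of them.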

### Create ANN index

In [49]:
if CREATE_NEW_ASSETS == True:
    
    tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
        display_name=ANN_INDEX_DISPLAY_NAME,
        contents_delta_uri=EMBEDDINGS_INITIAL_URI,
        dimensions=DIMENSIONS,
        approximate_neighbors_count=APPROX_NEIGHBORS,
        distance_measure_type=DISTANCE_MEASURE,
        leaf_node_embedding_count=LEAF_NODE_EMB_COUNT,
        leaf_nodes_to_search_percent=LEAF_NODES_SEARCH_PERCENT,
        description="Songs embeddings from MPD",
        sync=False,
        labels={
            "experiment_name": f'{EXPERIMENT_NAME}',
            "experiment_run": f'{RUN_NAME}',
            "data_regime": f'{DATA_REGIME}',
        },
    )
else:
    
    EXISTING_INDEX_NAME = f'projects/{PROJECT_NUM}/locations/{REGION}/indexes/1713892337098162176'
    tree_ah_index = vertex_ai.MatchingEngineIndex(EXISTING_INDEX_NAME)
    

print(f"display_name  : {tree_ah_index.display_name}\n")
print(f"tree_ah_index : {tree_ah_index}")

display_name  : ann_index_v1-v1

tree_ah_index : <google.cloud.aiplatform.matching_engine.matching_engine_index.MatchingEngineIndex object at 0x7f14d448c110> 
resource name: projects/934903580331/locations/us-central1/indexes/1713892337098162176


### Create Brute Force index

Used as the exact ground truth when evaluating ANN retrieval (recall).

In [50]:
if CREATE_NEW_ASSETS == True:
    
    brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
        display_name=BF_DISPLAY_NAME,
        contents_delta_uri=EMBEDDINGS_INITIAL_URI,
        dimensions=DIMENSIONS,
        distance_measure_type=DISTANCE_MEASURE,
        sync=False,
        labels={
            "experiment_name": f'{EXPERIMENT_NAME}',
            "experiment_run": f'{RUN_NAME}',
            "data_regime": f'{DATA_REGIME}',
        },
    )
else: #6250846749208870912 | 5708585206575792128
    EXISTING_INDEX_NAME = f'projects/{PROJECT_NUM}/locations/{REGION}/indexes/1095773288241561600'
    brute_force_index = vertex_ai.MatchingEngineIndex(EXISTING_INDEX_NAME)
    
print(f"display_name      : {brute_force_index.display_name}\n")
print(f"brute_force_index : {brute_force_index}")

display_name      : bf_index_v1_v1

brute_force_index : <google.cloud.aiplatform.matching_engine.matching_engine_index.MatchingEngineIndex object at 0x7f1500710490> 
resource name: projects/934903580331/locations/us-central1/indexes/1095773288241561600


#### list all indexes

In [60]:
# list_of_indices = tree_ah_index.list()
# list_of_indices[:5]

## Create Matching Engine endpoint(s)

* Both the ANN and brute-force indexes can be deployed to a single endpoint
* Alternatively, we can create separate endpoints, one for each index (the approach used here)

**index endpoint config:** 

In [53]:
print(f"VPC_NETWORK_FULL          : {VPC_NETWORK_FULL}")
print(f"ANN_ENDPOINT_DISPLAY_NAME : {ANN_ENDPOINT_DISPLAY_NAME}")
print(f"BF_ENDPOINT_DISPLAY_NAME  : {BF_ENDPOINT_DISPLAY_NAME}")

VPC_NETWORK_FULL          : projects/934903580331/global/networks/ucaip-haystack-vpc-network
ANN_ENDPOINT_DISPLAY_NAME : tfrs_128dim_local_endpoint
BF_ENDPOINT_DISPLAY_NAME  : tfrs_128dim_local_bf_endpoint


### ANN index endpoint

In [54]:
if CREATE_NEW_ASSETS == True:
    
    my_ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{ANN_ENDPOINT_DISPLAY_NAME}',
        description="index endpoint for ANN index",
        network=VPC_NETWORK_FULL,
        sync=False,
    )
    
else:
    EXISTING_INDEX_ENDPOINT = f'projects/{PROJECT_NUM}/locations/{REGION}/indexEndpoints/7571386602446913536'
    my_ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(EXISTING_INDEX_ENDPOINT)
    
print(f"display_name         : {my_ann_index_endpoint.display_name}\n")
print(f"my_ann_index_endpoint: {my_ann_index_endpoint}")

display_name         : ann_index_endpoint_v1

my_ann_index_endpoint: <google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint.MatchingEngineIndexEndpoint object at 0x7f0eecc7f990> 
resource name: projects/934903580331/locations/us-central1/indexEndpoints/7571386602446913536


In [55]:
print(f"Deployed indexes on the index endpoint:")
for d in my_ann_index_endpoint.deployed_indexes:
    print(f"    {d.id}")

Deployed indexes on the index endpoint:
    deployedann_v1


### brute-force index endpoint

In [56]:
if CREATE_NEW_ASSETS == True:
    
    my_bf_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
        display_name=f'{BF_ENDPOINT_DISPLAY_NAME}',
        description="index endpoint for brute-force index",
        network=VPC_NETWORK_FULL,
        sync=False,
    )
    
else:
    EXISTING_INDEX_ENDPOINT = f'projects/{PROJECT_NUM}/locations/{REGION}/indexEndpoints/7850046829390462976'
    my_bf_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint(EXISTING_INDEX_ENDPOINT)
    
print(f"display_name         : {my_bf_index_endpoint.display_name}\n")
print(f"my_bf_index_endpoint : {my_bf_index_endpoint}")

display_name         : bf_index_endpoint_v1

my_bf_index_endpoint : <google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint.MatchingEngineIndexEndpoint object at 0x7f14d4314450> 
resource name: projects/934903580331/locations/us-central1/indexEndpoints/7850046829390462976


In [57]:
print(f"Deployed indexes on the index endpoint:")
for d in my_bf_index_endpoint.deployed_indexes:
    print(f"    {d.id}")

Deployed indexes on the index endpoint:
    deployedbf_v1


#### list all index endpoints

In [59]:
# list_of_index_endpoints = my_ann_index_endpoint.list()
# list_of_index_endpoints[:5]

## Deploy Indexes to endpoints

> *Note: wait for indexes to be created (~40 mins) before deploying to endpoint*

In [61]:
# !gcloud ai indexes list \
#   --project=$PROJECT_ID \
#   --region=$LOCATION

**Get resource names:**

In [62]:
ANN_INDEX_NAME = tree_ah_index.display_name
BF_INDEX_NAME = brute_force_index.display_name

ANN_INDEX_ENDPOINT_NAME = my_ann_index_endpoint.display_name
BF_INDEX_ENDPOINT_NAME = my_bf_index_endpoint.display_name

DEPLOYED_ANN_INDEX_ID = f"deployed_{ANN_INDEX_NAME}"
DEPLOYED_BF_INDEX_ID = f"deployed_{BF_INDEX_NAME}"

print(f"ANN_INDEX_NAME          : {ANN_INDEX_NAME}")
print(f"BF_INDEX_NAME           : {BF_INDEX_NAME}")
print(f"ANN_INDEX_ENDPOINT_NAME : {ANN_INDEX_ENDPOINT_NAME}")
print(f"BF_INDEX_ENDPOINT_NAME  : {BF_INDEX_ENDPOINT_NAME}")
print(f"DEPLOYED_ANN_INDEX_ID   : {DEPLOYED_ANN_INDEX_ID}")
print(f"DEPLOYED_BF_INDEX_ID    : {DEPLOYED_BF_INDEX_ID}")

ANN_INDEX_NAME          : ann_index_v1-v1
BF_INDEX_NAME           : bf_index_v1_v1
ANN_INDEX_ENDPOINT_NAME : ann_index_endpoint_v1
BF_INDEX_ENDPOINT_NAME  : bf_index_endpoint_v1
DEPLOYED_ANN_INDEX_ID   : deployed_ann_index_v1-v1
DEPLOYED_BF_INDEX_ID    : deployed_bf_index_v1_v1
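Vertex AI requires a `deployed_index_id` to start with a letter and contain only letters, numbers, and underscores, so an ID derived from a display name that contains a hyphen (such as `ann_index_v1-v1`) may be rejected when deploying a new index. Below is a minimal sanitizer sketch assuming that rule; `to_deployed_index_id` is a hypothetical helper, and length limits are not handled.

```python
import re

# anything other than ASCII letters, digits and underscore
_INVALID_CHARS = re.compile(r"[^A-Za-z0-9_]")


def to_deployed_index_id(name: str, prefix: str = "deployed_") -> str:
    """Map a name to letters, digits and underscores, starting with a letter."""
    candidate = _INVALID_CHARS.sub("_", f"{prefix}{name}")
    if not candidate or not candidate[0].isalpha():
        # prepend a letter so the ID never starts with a digit or underscore
        candidate = f"d{candidate}"
    return candidate


if __name__ == "__main__":
    print(to_deployed_index_id("ann_index_v1-v1"))
```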


### Deploy ANN index

In [63]:
if CREATE_NEW_ASSETS == True:
    
    deployed_ann_index = my_ann_index_endpoint.deploy_index(
        index=tree_ah_index, 
        deployed_index_id=DEPLOYED_ANN_INDEX_ID
    )
    
else: # 5661297410488401920 | 3370091100063662080
    EXISTING_DEPLOYED_ENDPOINT = f'projects/{PROJECT_NUM}/locations/{REGION}/indexEndpoints/7571386602446913536'
    deployed_ann_index = vertex_ai.MatchingEngineIndexEndpoint(EXISTING_DEPLOYED_ENDPOINT)

print(f"display_name       : {deployed_ann_index.display_name}\n")
print(f"deployed_ann_index : {deployed_ann_index}")

display_name       : ann_index_endpoint_v1

deployed_ann_index : <google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint.MatchingEngineIndexEndpoint object at 0x7f0eec411e10> 
resource name: projects/934903580331/locations/us-central1/indexEndpoints/7571386602446913536


In [64]:
print(f"Deployed indexes on the index endpoint:")
for d in deployed_ann_index.deployed_indexes:
    print(f"    {d.id}")

Deployed indexes on the index endpoint:
    deployedann_v1


### Deploy brute-force index

In [65]:
if CREATE_NEW_ASSETS == True:
    
    deployed_bf_index = my_bf_index_endpoint.deploy_index(
        index=brute_force_index, 
        deployed_index_id=DEPLOYED_BF_INDEX_ID
    )
else: # 4346246319296217088 | 1049611392061014016
    EXISTING_DEPLOYED_ENDPOINT = f'projects/{PROJECT_NUM}/locations/{REGION}/indexEndpoints/7850046829390462976'
    deployed_bf_index = vertex_ai.MatchingEngineIndexEndpoint(EXISTING_DEPLOYED_ENDPOINT)
    
print(f"display_name         : {my_bf_index_endpoint.display_name}\n")
print(f"my_bf_index_endpoint : {my_bf_index_endpoint}")

display_name         : bf_index_endpoint_v1

my_bf_index_endpoint : <google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint.MatchingEngineIndexEndpoint object at 0x7f14d4314450> 
resource name: projects/934903580331/locations/us-central1/indexEndpoints/7850046829390462976


In [66]:
print(f"Deployed indexes on the index endpoint:")
for d in deployed_bf_index.deployed_indexes:
    print(f"    {d.id}")

Deployed indexes on the index endpoint:
    deployedbf_v1


# Test Index Recall

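* In production, the query tower converts a test instance into a query embedding; for this quick check we reuse a candidate embedding (`emb_valid[0]`) as the query
* Search for its nearest neighbors in both the ANN and brute-force indexes, then compare the results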

In [67]:
DEPLOYED_ANN_INDEX_ID = deployed_ann_index.deployed_indexes[0].id
DEPLOYED_BF_INDEX_ID = deployed_bf_index.deployed_indexes[0].id

print(f"DEPLOYED_ANN_INDEX_ID: {DEPLOYED_ANN_INDEX_ID}")
print(f"DEPLOYED_BF_INDEX_ID: {DEPLOYED_BF_INDEX_ID}")

DEPLOYED_ANN_INDEX_ID: deployedann_v1
DEPLOYED_BF_INDEX_ID: deployedbf_v1


In [70]:
TEST_EMBEDDING = emb_valid[0]

print(f"Length of TEST_EMBEDDING: {len(TEST_EMBEDDING)}")
TEST_EMBEDDING

Length of TEST_EMBEDDING: 128


array([-0.2958832 ,  1.3680464 , -1.0969739 , -1.1368335 , -0.4342838 ,
        0.39283156,  0.27632928,  0.03127032, -0.58364165, -0.7053537 ,
        0.02434574,  0.44454443,  0.36360395,  0.5888878 ,  0.46440738,
        0.6275927 ,  1.6786953 ,  0.02449422, -0.6262758 ,  0.9112112 ,
       -0.9261704 , -1.2875332 ,  1.3164775 ,  0.98420537, -0.7581321 ,
       -1.1962838 ,  0.554428  , -0.24542667, -0.01585706,  2.6407895 ,
        0.92781043, -0.7784559 ,  2.3831542 , -0.09137036, -0.71217793,
        1.7602313 , -1.7687172 , -0.99954164, -1.0950288 ,  0.33157885,
        0.3068544 , -1.4212629 , -0.5321212 , -0.6750946 ,  0.43817887,
        0.4380906 ,  1.5712899 ,  0.6998104 ,  0.22836557,  0.07564595,
        0.77435833,  0.352889  ,  0.31121057, -0.6935091 , -0.7743054 ,
       -0.42792314, -0.11159866,  1.5658025 , -0.59310645, -0.44968712,
       -0.6262101 , -1.0050645 ,  0.13244593, -0.2033531 , -0.21466011,
       -0.32916766, -0.8308483 , -1.0437603 , -0.80158645, -0.34

### ANN search

In [72]:
# %%timeit 
start = time.time()

ANN_response = deployed_ann_index.match(
    deployed_index_id=DEPLOYED_ANN_INDEX_ID,
    queries=[TEST_EMBEDDING],
    num_neighbors=20
)

# wall-clock latency of a single request, in seconds
elapsed_ann_time = round(time.time() - start, 4)
print(f'ANN latency: {elapsed_ann_time} seconds')

ANN latency: 0.011 seconds


In [73]:
ANN_response

[[MatchNeighbor(id='spotify:track:4PqIj0WOfPAq4QAvisjgpd', distance=60.644474029541016),
  MatchNeighbor(id='spotify:track:46P6IXXFACigzFF4BLtrRS', distance=60.076210021972656),
  MatchNeighbor(id='spotify:track:3YfjqbdBNgvhO0FrLfe3r2', distance=59.810176849365234),
  MatchNeighbor(id='spotify:track:0xWuZbSZWD5NLczQlFAzw8', distance=59.684112548828125),
  MatchNeighbor(id='spotify:track:0xUgMJKVoTt1NgyCou7k6Q', distance=58.92349624633789),
  MatchNeighbor(id='spotify:track:2yxxAoyWkss8hz5OBsICRy', distance=58.070220947265625),
  MatchNeighbor(id='spotify:track:789fcu6AQuBkUEHMZtSE3l', distance=57.790321350097656),
  MatchNeighbor(id='spotify:track:64uzhdVg9iYlwHn14MqrSF', distance=57.411468505859375),
  MatchNeighbor(id='spotify:track:3fv313UzOGEVoSTJ6s06B6', distance=57.38988494873047),
  MatchNeighbor(id='spotify:track:415EcxgjNwCzaqWuBwc2EU', distance=57.115089416503906),
  MatchNeighbor(id='spotify:track:1cVSpj4cx1jltOCZdIKAzB', distance=56.294979095458984),
  MatchNeighbor(id='spo

### Brute-force search

In [74]:
# %%timeit 
start = time.time()

BF_response = deployed_bf_index.match(
    deployed_index_id=DEPLOYED_BF_INDEX_ID,
    queries=[TEST_EMBEDDING],
    num_neighbors=20
)

# wall-clock latency of a single request, in seconds
elapsed_bf_time = round(time.time() - start, 4)
print(f'Bruteforce latency: {elapsed_bf_time} seconds')

Bruteforce latency: 0.2811 seconds


In [75]:
BF_response

[[MatchNeighbor(id='spotify:track:4PqIj0WOfPAq4QAvisjgpd', distance=60.644474029541016),
  MatchNeighbor(id='spotify:track:46P6IXXFACigzFF4BLtrRS', distance=60.076210021972656),
  MatchNeighbor(id='spotify:track:3YfjqbdBNgvhO0FrLfe3r2', distance=59.810176849365234),
  MatchNeighbor(id='spotify:track:0xWuZbSZWD5NLczQlFAzw8', distance=59.684112548828125),
  MatchNeighbor(id='spotify:track:0xUgMJKVoTt1NgyCou7k6Q', distance=58.92349624633789),
  MatchNeighbor(id='spotify:track:1cadVY6jZWZeEAPumXsRut', distance=58.49407196044922),
  MatchNeighbor(id='spotify:track:2yxxAoyWkss8hz5OBsICRy', distance=58.070220947265625),
  MatchNeighbor(id='spotify:track:789fcu6AQuBkUEHMZtSE3l', distance=57.790321350097656),
  MatchNeighbor(id='spotify:track:64uzhdVg9iYlwHn14MqrSF', distance=57.411468505859375),
  MatchNeighbor(id='spotify:track:3fv313UzOGEVoSTJ6s06B6', distance=57.38988494873047),
  MatchNeighbor(id='spotify:track:415EcxgjNwCzaqWuBwc2EU', distance=57.115089416503906),
  MatchNeighbor(id='spot

## Compute Recall

In [76]:
# Calculate recall by determining how many neighbors were correctly retrieved as compared to the brute-force option.
recalled_neighbors = 0
for tree_ah_neighbors, brute_force_neighbors in zip(
    ANN_response, BF_response
):
    tree_ah_neighbor_ids = [neighbor.id for neighbor in tree_ah_neighbors]
    brute_force_neighbor_ids = [neighbor.id for neighbor in brute_force_neighbors]

    recalled_neighbors += len(
        set(tree_ah_neighbor_ids).intersection(brute_force_neighbor_ids)
    )

recall = recalled_neighbors / len(
    [neighbor for neighbors in BF_response for neighbor in neighbors]
)

print("Recall: {}".format(recall))

Recall: 0.75
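The recall loop above reduces to set overlap per query: recall = (brute-force neighbors also returned by ANN) / (total brute-force neighbors), pooled over all queries. With 20 neighbors, 0.75 means 15 of the 20 exact neighbors were recovered. The sketch below packages that as a hypothetical `recall_at_k` helper that also lists which brute-force neighbors the ANN search missed. It works on plain ID lists; in the notebook these would come from `neighbor.id`.

```python
from typing import List, Sequence, Tuple


def recall_at_k(
    ann_ids: Sequence[Sequence[str]],
    bf_ids: Sequence[Sequence[str]],
) -> Tuple[float, List[List[str]]]:
    """Fraction of brute-force neighbors also returned by ANN, pooled over queries.

    Returns (recall, missed), where missed[q] lists the brute-force IDs that
    the ANN search did not return for query q, in brute-force rank order.
    """
    if len(ann_ids) != len(bf_ids):
        raise ValueError("need one ANN result list per brute-force result list")
    hits, total = 0, 0
    missed: List[List[str]] = []
    for ann, bf in zip(ann_ids, bf_ids):
        ann_set = set(ann)
        hits += sum(1 for i in bf if i in ann_set)
        total += len(bf)
        missed.append([i for i in bf if i not in ann_set])
    return (hits / total if total else 0.0), missed


if __name__ == "__main__":
    print(recall_at_k([["a", "b", "c", "x"]], [["a", "b", "c", "d"]]))
```

Inspecting `missed` is a quick way to see whether the ANN misses are near-ties in score (as with `spotify:track:1cadVY6jZWZeEAPumXsRut`, which appears in the visible brute-force results but not in the visible ANN results) or far up the ranking.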


In [77]:
reduction = (elapsed_bf_time - elapsed_ann_time) / elapsed_bf_time*100.00
increase  = (elapsed_bf_time - elapsed_ann_time)/elapsed_ann_time*100.00
faster    = elapsed_bf_time / elapsed_ann_time

print(f"reduction in time         : {round(reduction, 3)}%")
print(f"% increase in performance : {round(increase, 3)}%")
print(f"how many times faster     : {round(faster, 3)}x faster")

reduction in time         : 96.087%
% increase in performance : 2455.455%
how many times faster     : 25.555x faster
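All three figures are views of one ratio r = t_BF / t_ANN: "times faster" is r, the percent increase is (r - 1) x 100, and the time reduction is (1 - 1/r) x 100. With 0.2811 s and 0.011 s, r is about 25.55. A single request's wall-clock time also includes the network round trip and can vary from call to call, so a median over repeated calls is a steadier basis for these numbers. The sketch below shows both; the helper names are hypothetical, and a local stand-in function replaces `deployed_ann_index.match(...)` in the demo.

```python
import statistics
import time
from typing import Callable, Dict


def median_latency(fn: Callable[[], object], repeats: int = 20, warmup: int = 2) -> float:
    """Median wall-clock seconds of fn() over `repeats` timed calls, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup_report(t_bf: float, t_ann: float) -> Dict[str, float]:
    """Express the BF vs ANN latency gap three ways, all derived from t_bf / t_ann."""
    ratio = t_bf / t_ann
    return {
        "times_faster": ratio,
        "pct_increase": (ratio - 1) * 100,
        "pct_time_reduction": (1 - 1 / ratio) * 100,
    }


if __name__ == "__main__":
    print({k: round(v, 3) for k, v in speedup_report(0.2811, 0.011).items()})
    # stand-in for e.g. lambda: deployed_ann_index.match(...)
    print(median_latency(lambda: sum(range(10_000)), repeats=5))
```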


**Finished**