# Implementing Recommendation Engines with Matching Engine

### VPC Network peering
Matching Engine is a high-performance vector matching service that requires a separate VPC to ensure performance.

Below are the one-time instructions to set up a peering network. 

**Once the network is created, make sure the notebook instance running this notebook is in the subnetwork. See https://cloud.google.com/vertex-ai/docs/matching-engine/match-eng-setup**

## Setup

In [29]:
GCP_PROJECTS = !gcloud config get-value project
PROJECT_ID = GCP_PROJECTS[0]
PROJECT_NUM = !gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)"
PROJECT_NUM = PROJECT_NUM[0]
LOCATION = 'us-central1'

BUCKET = 'jt-tfrs-central'
BUCKET_URI = f'gs://{BUCKET}'
RUN_DIR_PATH = 'a50-epoch/run-20221230-160518'
RUN_DIR_GCS_PATH = f'{BUCKET_URI}/{RUN_DIR_PATH}'

VERSION = 'v8'

print(f"PROJECT_ID: {PROJECT_ID}")
print(f"PROJECT_NUM: {PROJECT_NUM}")
print(f"LOCATION: {LOCATION}")
print(f"BUCKET_URI: {BUCKET_URI}")
print(f"RUN_DIR_GCS_PATH: {RUN_DIR_GCS_PATH}")

PROJECT_ID: hybrid-vertex
PROJECT_NUM: 934903580331
LOCATION: us-central1
BUCKET_URI: gs://jt-tfrs-central
RUN_DIR_GCS_PATH: gs://jt-tfrs-central/a50-epoch/run-20221230-160518


In [30]:
import os
import sys
from google.cloud import aiplatform as vertex_ai

vertex_ai.init(project=PROJECT_ID, location=LOCATION)

### Create a matching engine index

Matching Engine builds an index from a file of embeddings created in the previous notebook.

Many of the optimization options for Matching Engine are found in the tree-AH settings, and testing them is recommended for each use case.

Recall that we saved our two-tower models and the candidate embeddings (newline-delimited JSON) in a candidate folder.
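
As a rough illustration of the newline-delimited JSON layout that Matching Engine reads, the sketch below writes a few records with an `id` and an `embedding` field, one JSON object per line. The track IDs, the vector values, and the `write_embeddings_jsonl` helper are made up for this example; the real file comes from the previous notebook.

```python
import json
import os
import tempfile


def write_embeddings_jsonl(records, path):
    """Write one JSON object per line, each with `id` and `embedding` keys."""
    with open(path, "w") as f:
        for item_id, vector in records:
            f.write(json.dumps({"id": item_id, "embedding": list(vector)}) + "\n")


# Hypothetical 4-dimensional vectors; the index in this notebook uses 32 dimensions.
example_records = [
    ("spotify:track:example-a", [0.1, -0.2, 0.3, 0.4]),
    ("spotify:track:example-b", [0.5, 0.6, -0.7, 0.8]),
]

out_path = os.path.join(tempfile.mkdtemp(), "candidate_embs_example.json")
write_embeddings_jsonl(example_records, out_path)

# Read the file back to confirm that every line parses as a standalone JSON object.
with open(out_path) as f:
    lines = f.read().splitlines()
parsed = [json.loads(line) for line in lines]
print(parsed[0])
```

Keeping one record per line means the file can be streamed and validated line by line before it is copied to the bucket.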

In [31]:
EMBEDDINGS_INITIAL_URI = f'{BUCKET_URI}/{RUN_DIR_PATH}/candidates-v8/'

print(f"EMBEDDINGS_INITIAL_URI: {EMBEDDINGS_INITIAL_URI}")

EMBEDDINGS_INITIAL_URI: gs://jt-tfrs-central/a50-epoch/run-20221230-160518/candidates-v8/


`EMBEDDINGS_INITIAL_URI` should point to a folder that contains only the candidate JSON file:

In [32]:
! gsutil ls $EMBEDDINGS_INITIAL_URI

gs://jt-tfrs-central/a50-epoch/run-20221230-160518/candidates-v8/candidate_embs_local_v6_20230112-180944.json


### Create ANN index

In [33]:
# ANN index config
APPROX_NEIGHBORS=50
DISTANCE_MEASURE="DOT_PRODUCT_DISTANCE"
LEAF_NODE_EMB_COUNT=500
LEAF_NODES_SEARCH_PERCENT=7
DIMENSIONS = 32  # must match output dimensions

DISPLAY_NAME = f"tfrs_{DIMENSIONS}dim_50e_{VERSION}"
BF_DISPLAY_NAME=f"{DISPLAY_NAME}_bf"

EXPERIMENT_NAME='a50-epoch'
EXPERIMENT_RUN='run-20221230-160518'
DATA_REGIME='full'
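
To build intuition for how `LEAF_NODE_EMB_COUNT` and `LEAF_NODES_SEARCH_PERCENT` interact, here is a back-of-the-envelope estimate. It assumes embeddings are spread evenly across leaves, which a real index only approximates, and the corpus size of 1,000,000 is a made-up number rather than the size of this dataset. The `estimate_search_cost` helper is hypothetical.

```python
import math


def estimate_search_cost(num_embeddings, leaf_node_emb_count, leaf_nodes_search_percent):
    """Rough estimate of how many leaves and embeddings a single query touches."""
    num_leaves = math.ceil(num_embeddings / leaf_node_emb_count)
    leaves_searched = math.ceil(num_leaves * leaf_nodes_search_percent / 100)
    embeddings_scored = min(num_embeddings, leaves_searched * leaf_node_emb_count)
    return num_leaves, leaves_searched, embeddings_scored


# num_embeddings is hypothetical; 500 and 7 mirror the config values above.
num_leaves, leaves_searched, embeddings_scored = estimate_search_cost(
    num_embeddings=1_000_000,
    leaf_node_emb_count=500,
    leaf_nodes_search_percent=7,
)
print(num_leaves, leaves_searched, embeddings_scored)  # 2000 140 70000
```

Under these assumptions a query scores about 7% of the corpus. Raising the search percentage would be expected to improve recall at the cost of latency, which is why testing per use case is recommended.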

> *Note: setting `sync=False` will allow us to proceed with the notebook while these operations complete*

In [34]:
tree_ah_index = vertex_ai.MatchingEngineIndex.create_tree_ah_index(
    display_name=DISPLAY_NAME,
    contents_delta_uri=EMBEDDINGS_INITIAL_URI,
    dimensions=DIMENSIONS,
    approximate_neighbors_count=APPROX_NEIGHBORS,
    distance_measure_type=DISTANCE_MEASURE,
    leaf_node_embedding_count=LEAF_NODE_EMB_COUNT,
    leaf_nodes_to_search_percent=LEAF_NODES_SEARCH_PERCENT,
    description="Songs embeddings from the Spotify million playlist dataset",
    sync=False,
    labels={
        "experiment_name": f'{EXPERIMENT_NAME}',
        "experiment_run": f'{EXPERIMENT_RUN}',
        "data_regime": f'{DATA_REGIME}',
    },
)

Creating MatchingEngineIndex
Create MatchingEngineIndex backing LRO: projects/934903580331/locations/us-central1/indexes/5123345953436205056/operations/6355292451107766272


### Create Brute Force index

The brute-force index performs an exhaustive search, so its results serve as the ground truth for evaluating ANN retrieval.

In [35]:
brute_force_index = vertex_ai.MatchingEngineIndex.create_brute_force_index(
    display_name=BF_DISPLAY_NAME,
    contents_delta_uri=EMBEDDINGS_INITIAL_URI,
    dimensions=DIMENSIONS,
    distance_measure_type=DISTANCE_MEASURE,
    sync=False,
    labels={
        "experiment_name": f'{EXPERIMENT_NAME}',
        "experiment_run": f'{EXPERIMENT_RUN}',
        "data_regime": f'{DATA_REGIME}',
    },
)

Creating MatchingEngineIndex
Create MatchingEngineIndex backing LRO: projects/934903580331/locations/us-central1/indexes/4605994946242019328/operations/7893271723854790656
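
Conceptually, a brute-force search scores the query against every candidate and keeps the best matches, which is why it can serve as a reference for the ANN index. The pure-Python sketch below illustrates that idea with dot-product scoring. The random vectors, the `track-N` IDs, and the `exact_top_k` helper are illustrative only and do not reflect how the service is implemented.

```python
import heapq
import random


def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def exact_top_k(query, candidates, k):
    """Score every candidate against the query and keep the k largest dot products."""
    scored = ((dot(query, vec), item_id) for item_id, vec in candidates.items())
    return [(item_id, score) for score, item_id in heapq.nlargest(k, scored)]


random.seed(0)
dims = 32  # matches DIMENSIONS in this notebook; the vectors themselves are random
candidates = {
    f"track-{i}": [random.gauss(0, 1) for _ in range(dims)] for i in range(1000)
}
query = [random.gauss(0, 1) for _ in range(dims)]

top_matches = exact_top_k(query, candidates, k=10)
for item_id, score in top_matches[:3]:
    print(item_id, round(score, 3))
```

Because every candidate is scored, the cost grows linearly with the corpus size, so this approach is best kept for evaluation rather than serving.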


### Create Matching Engine endpoint(s)

* both the ANN and brute force indices can be deployed to a single endpoint
* alternatively, we can create separate endpoints, one for each index

#### Index endpoint config

In [39]:
VPC_NETWORK = "ucaip-haystack-vpc-network" # TODO: update this

VPC_NETWORK_FULL = f"projects/{PROJECT_NUM}/global/networks/{VPC_NETWORK}"

ANN_ENDPOINT_DISPLAY_NAME = f'{DISPLAY_NAME}_endpoint'

BF_ENDPOINT_DISPLAY_NAME = f'{BF_DISPLAY_NAME}_endpoint'

print(f"VPC_NETWORK_FULL: {VPC_NETWORK_FULL}")
print(f"ANN_ENDPOINT_DISPLAY_NAME: {ANN_ENDPOINT_DISPLAY_NAME}")
print(f"BF_ENDPOINT_DISPLAY_NAME: {BF_ENDPOINT_DISPLAY_NAME}")

VPC_NETWORK_FULL: projects/934903580331/global/networks/ucaip-haystack-vpc-network
ANN_ENDPOINT_DISPLAY_NAME: tfrs_32dim_50e_v8_endpoint
BF_ENDPOINT_DISPLAY_NAME: tfrs_32dim_50e_v8_bf_endpoint


Then create the index endpoints

In [37]:
my_ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
    display_name=f'{ANN_ENDPOINT_DISPLAY_NAME}',
    description="index endpoint for ANN index",
    network=VPC_NETWORK_FULL,
    sync=False,
)

# to use an existing endpoint (the SDK is imported as vertex_ai in this notebook)
# my_ann_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/8097410557360996352')

Creating MatchingEngineIndexEndpoint
Create MatchingEngineIndexEndpoint backing LRO: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080/operations/4548223100625354752
MatchingEngineIndexEndpoint created. Resource name: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080
To use this MatchingEngineIndexEndpoint in another session:
index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080')


In [38]:
my_bf_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint.create(
    display_name=f'{BF_ENDPOINT_DISPLAY_NAME}',
    description="index endpoint for brute force index",
    network=VPC_NETWORK_FULL,
    sync=False,
)
# to use an existing endpoint (the SDK is imported as vertex_ai in this notebook)
# my_bf_index_endpoint = vertex_ai.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/1972515064137121792')

Creating MatchingEngineIndexEndpoint
Create MatchingEngineIndexEndpoint backing LRO: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344/operations/1470012755317620736
MatchingEngineIndexEndpoint created. Resource name: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344
To use this MatchingEngineIndexEndpoint in another session:
index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344')


In [40]:
ANN_INDEX_ENDPOINT_NAME = my_ann_index_endpoint.resource_name
BF_INDEX_ENDPOINT_NAME = my_bf_index_endpoint.resource_name

print(f"ANN_INDEX_ENDPOINT_NAME: {ANN_INDEX_ENDPOINT_NAME}")
print(f"BF_INDEX_ENDPOINT_NAME: {BF_INDEX_ENDPOINT_NAME}")

ANN_INDEX_ENDPOINT_NAME: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080
BF_INDEX_ENDPOINT_NAME: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344


## Deploy Indexes to endpoints

> *Note: wait for indexes to be created (~40 mins) before deploying to endpoint*

In [47]:
# !gcloud ai indexes list \
#   --project=$PROJECT_ID \
#   --region=$LOCATION

In [46]:
# if you need to recreate the index objects in this session
tree_ah_resource_name = f'projects/{PROJECT_NUM}/locations/us-central1/indexes/5123345953436205056'
brute_force_index_resource_name = f'projects/{PROJECT_NUM}/locations/us-central1/indexes/4605994946242019328'

tree_ah_index = vertex_ai.MatchingEngineIndex(index_name=tree_ah_resource_name)
brute_force_index = vertex_ai.MatchingEngineIndex(index_name=brute_force_index_resource_name)

In [48]:
ANN_INDEX_NAME = tree_ah_index.resource_name
BF_INDEX_NAME = brute_force_index.resource_name

print(f"ANN_INDEX_NAME: {ANN_INDEX_NAME}")
print(f"BF_INDEX_NAME: {BF_INDEX_NAME}")

DEPLOYED_ANN_INDEX_ID = f"deployed_{DISPLAY_NAME}"
DEPLOYED_BF_INDEX_ID = f"deployed_{BF_DISPLAY_NAME}"

print(f"DEPLOYED_ANN_INDEX_ID: {DEPLOYED_ANN_INDEX_ID}")
print(f"DEPLOYED_BF_INDEX_ID: {DEPLOYED_BF_INDEX_ID}")

ANN_INDEX_NAME: projects/934903580331/locations/us-central1/indexes/5123345953436205056
BF_INDEX_NAME: projects/934903580331/locations/us-central1/indexes/4605994946242019328
DEPLOYED_ANN_INDEX_ID: deployed_tfrs_32dim_50e_v8
DEPLOYED_BF_INDEX_ID: deployed_tfrs_32dim_50e_v8_bf


#### Deploy ANN index

In [49]:
deployed_ann_index = my_ann_index_endpoint.deploy_index(
    index=tree_ah_index, 
    deployed_index_id=DEPLOYED_ANN_INDEX_ID
)
deployed_ann_index.deployed_indexes

Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080
Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080/operations/2757760773768871936
MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080


[id: "deployed_tfrs_32dim_50e_v8"
index: "projects/934903580331/locations/us-central1/indexes/5123345953436205056"
create_time {
  seconds: 1673552056
  nanos: 94509000
}
private_endpoints {
  match_grpc_address: "10.41.2.5"
}
index_sync_time {
  seconds: 1673552270
  nanos: 729035000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

#### Deploy Brute Force index

In [50]:
deployed_bf_index = my_bf_index_endpoint.deploy_index(
    index=brute_force_index, 
    deployed_index_id=DEPLOYED_BF_INDEX_ID
)
deployed_bf_index.deployed_indexes

Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344
Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344/operations/8436799903883067392
MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344


[id: "deployed_tfrs_32dim_50e_v8_bf"
index: "projects/934903580331/locations/us-central1/indexes/4605994946242019328"
create_time {
  seconds: 1673552276
  nanos: 416921000
}
private_endpoints {
  match_grpc_address: "10.41.2.5"
}
index_sync_time {
  seconds: 1673552492
  nanos: 130906000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

# Query Model

### Upload Query Model to Vertex Model Registry

In [51]:
QUERY_MODEL_DIR = f"{BUCKET_URI}/{RUN_DIR_PATH}/model-dir/query_model"

print(f"QUERY_MODEL_DIR: {QUERY_MODEL_DIR}")

QUERY_MODEL_DIR: gs://jt-tfrs-central/a50-epoch/run-20221230-160518/model-dir/query_model


In [52]:
! gsutil ls $QUERY_MODEL_DIR

gs://jt-tfrs-central/a50-epoch/run-20221230-160518/model-dir/query_model/saved_model.pb
gs://jt-tfrs-central/a50-epoch/run-20221230-160518/model-dir/query_model/variables/


In [11]:
SERVING_CONTAINER_IMAGE_URI = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-10:latest'

In [21]:
uploaded_query_model = vertex_ai.Model.upload(
    display_name=f'query_model_{DISPLAY_NAME}',
    artifact_uri=QUERY_MODEL_DIR,
    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,
    description="Top of the query tower, meant to return an embedding for each playlist instance",
    sync=True,
)

Creating Model
Create Model backing LRO: projects/934903580331/locations/us-central1/models/6791535990214230016/operations/914776975476785152
Model created. Resource name: projects/934903580331/locations/us-central1/models/6791535990214230016@1
To use this Model in another session:
model = aiplatform.Model('projects/934903580331/locations/us-central1/models/6791535990214230016@1')


#### Create model endpoint

In [22]:
endpoint = vertex_ai.Endpoint.create(
    display_name=f'endpoint_{DISPLAY_NAME}',
    project=PROJECT_ID,
    location=LOCATION,
    sync=True,
)

Creating Endpoint
Create Endpoint backing LRO: projects/934903580331/locations/us-central1/endpoints/4002948002778972160/operations/160424037892227072
Endpoint created. Resource name: projects/934903580331/locations/us-central1/endpoints/4002948002778972160
To use this Endpoint in another session:
endpoint = aiplatform.Endpoint('projects/934903580331/locations/us-central1/endpoints/4002948002778972160')
MatchingEngineIndex created. Resource name: projects/934903580331/locations/us-central1/indexes/2140274150256672768
To use this MatchingEngineIndex in another session:
index = aiplatform.MatchingEngineIndex('projects/934903580331/locations/us-central1/indexes/2140274150256672768')


#### Deploy uploaded model to model endpoint

In [23]:
deployed_query_model = uploaded_query_model.deploy(
    endpoint=endpoint,
    deployed_model_display_name=f'deployed_qmodel_{DISPLAY_NAME}',
    machine_type="n1-standard-4",
    min_replica_count=1,
    max_replica_count=2,
    accelerator_type=None,
    accelerator_count=0,
    sync=True,
)

Deploying model to Endpoint : projects/934903580331/locations/us-central1/endpoints/4002948002778972160
Deploy Endpoint model backing LRO: projects/934903580331/locations/us-central1/endpoints/4002948002778972160/operations/3390067920670294016
MatchingEngineIndex created. Resource name: projects/934903580331/locations/us-central1/indexes/7417929963581472768
To use this MatchingEngineIndex in another session:
index = aiplatform.MatchingEngineIndex('projects/934903580331/locations/us-central1/indexes/7417929963581472768')
Endpoint model deployed. Resource name: projects/934903580331/locations/us-central1/endpoints/4002948002778972160


# Retrieve nearest neighbors from index

* use query_model to convert test instance to embeddings
* use embeddings to search for NN in ANN index

In [53]:
uploaded_query_model = vertex_ai.Model('projects/934903580331/locations/us-central1/models/6791535990214230016@1')
deployed_query_model = vertex_ai.Endpoint('projects/934903580331/locations/us-central1/endpoints/4002948002778972160')
deployed_ann_index = vertex_ai.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/381618495768494080')
deployed_bf_index = vertex_ai.MatchingEngineIndexEndpoint('projects/934903580331/locations/us-central1/indexEndpoints/6417567896351801344')

## Create Test Instance(s)

* We can create a quick example from the train or valid dataset by returning a structured example like:

```
for tensor_dict in train_dataset.unbatch().skip(12905).take(1):
    td_keys = tensor_dict.keys()
    list_dict = {}
    for k in td_keys:
        list_dict.update({k: tensor_dict[k].numpy()})
    print(list_dict)
```
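
The loop above prints NumPy values (arrays, scalars, and byte strings), which are not directly JSON-serializable. One way to turn such a dict into a plain-Python instance is sketched below. The `to_jsonable` helper is hypothetical, and the stand-in values use `bytes` and `array.array` so the sketch runs without TensorFlow or NumPy; NumPy arrays and scalars expose the same `.tolist()` method.

```python
import array
import json


def to_jsonable(value):
    """Recursively convert array-like and bytes values into plain Python types."""
    if hasattr(value, "tolist"):      # NumPy arrays and scalars expose tolist()
        value = value.tolist()
    if isinstance(value, bytes):      # TensorFlow strings come back as bytes
        return value.decode("utf-8")
    if isinstance(value, (list, tuple)):
        return [to_jsonable(v) for v in value]
    if isinstance(value, dict):
        return {k: to_jsonable(v) for k, v in value.items()}
    return value


# Stand-in for the dict printed by the loop above (values are illustrative).
raw_example = {
    "pl_name_src": b"Capoeira",
    "artist_genres_pl": [b"samba moderno", b"capoeira"],
    "track_pop_pl": array.array("d", [5.0, 1.0]),
    "num_pl_songs_new": 85.0,
}

instance = to_jsonable(raw_example)
print(json.dumps(instance))
```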

In [54]:
TEST_INSTANCE = {
    'album_name_can': 'Capoeira Electronica',
    'album_name_pl': [
        'Odilara', 'Capoeira Electronica', 'Capoeira Ultimate','Festa Popular', 'Capoeira Electronica'
    ],
    'album_uri_can': 'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
    'album_uri_pl': [
        'spotify:album:4Y8RfvZzCiApBCIZswj9Ry',
        'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR',
        'spotify:album:55HHBqZ2SefPeaENOgWxYK',
        'spotify:album:150L1V6UUT7fGUI3PbxpkE',
        'spotify:album:2FsSSHGt8JM0JgRy6ZX3kR'
    ],
    'artist_followers_can': 5170.0,
    'artist_genres_can': 'capoeira',
    'artist_genres_pl': ['samba moderno', 'capoeira', 'capoeira', 'NONE','capoeira'],
    'artist_name_can': 'Capoeira Experience',
    'artist_name_pl': ['Odilara', 'Capoeira Experience', 'Denis Porto', 'Zambe','Capoeira Experience'],
    'artist_pop_can': 24.0,
    'artist_pop_pl':[ 4., 24.,  2.,  0., 24.],
    'artist_uri_can': 'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
    'artist_uri_pl': [
        'spotify:artist:72oameojLOPWYB7nB8rl6c',
        'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP',
        'spotify:artist:67p5GMYQZOgaAfx1YyttQk',
        'spotify:artist:4fH3OXCRcPsaHFE5KhgqZS',
        'spotify:artist:5SKEXbgzIdRl3gQJ23CnUP'
    ],
    'artists_followers_pl': [ 316., 5170.,  448.,   19., 5170.],
    'duration_ms_can': 192640.0,
    'duration_ms_songs_pl': [234612., 226826., 203480., 287946., 271920.],
    'num_pl_albums_new': 9.0,
    'num_pl_artists_new': 5.0,
    'num_pl_songs_new': 85.0,
    'pl_collaborative_src': 'false',
    'pl_duration_ms_new': 17971314.0,
    'pl_name_src': 'Capoeira',
    'time_signature_can': '4',
    'time_signature_pl': ['4', '4', '4', '4', '4'],
    'track_acousticness_can': 0.478,
    'track_acousticness_pl': [0.238 , 0.105 , 0.0242, 0.125 , 0.304 ],
    'track_danceability_can': 0.709,
    'track_danceability_pl': [0.703, 0.712, 0.806, 0.529, 0.821],
    'track_energy_can': 0.742,
    'track_energy_pl': [0.743, 0.41 , 0.794, 0.776, 0.947],
    'track_instrumentalness_can': 0.00297,
    'track_instrumentalness_pl': [4.84e-06, 4.30e-01, 7.42e-04, 4.01e-01, 5.07e-03],
    'track_key_can': '0',
    'track_key_pl': ['5', '0', '1', '10', '10'],
    'track_liveness_can': 0.0346,
    'track_liveness_pl': [0.128 , 0.0725, 0.191 , 0.105 , 0.0552],
    'track_loudness_can': -7.295,
    'track_loudness_pl': [-8.638, -8.754, -9.084, -7.04 , -6.694],
    'track_mode_can': '1',
    'track_mode_pl': ['0', '1', '1', '0', '1'],
    'track_name_can': 'Bezouro Preto - Studio',
    'track_name_pl': ['O Telefone Tocou Novamente', 'Bem Devagar - Studio','Angola Dream', 'Janaina', 'Louco Berimbau - Studio'],
    'track_pop_can': 3.0,
    'track_pop_pl': [5., 1., 0., 0., 1.],
    'track_speechiness_can': 0.0802,
    'track_speechiness_pl':[0.0367, 0.0272, 0.0407, 0.132 , 0.0734],
    'track_tempo_can': 172.238,
    'track_tempo_pl': [100.039,  89.089, 123.999, 119.963, 119.214],
    'track_uri_can': 'spotify:track:0tlhK4OvpHCYpReTABvKFb',
    'track_uri_pl': [
        'spotify:track:1pQkOdcTDfLr84TDCrmGy7',
        'spotify:track:39grEDsAHAjmo2QFo4G8D9',
        'spotify:track:5vxSLdJXqbKYH487YO8LSL',
        'spotify:track:6T9GbmZ6voDM4aTBsG5VDh',
        'spotify:track:7ELt9eslVvWo276pX2garN'
    ],
    'track_valence_can': 0.844,
    'track_valence_pl': [0.966, 0.667, 0.696, 0.876, 0.655],
}

#### Convert the test instance to a playlist embedding

In [55]:
playlist_emb = deployed_query_model.predict([TEST_INSTANCE])
playlist_emb

Prediction(predictions=[[0.059902925, -0.492682517, 0.146539509, -0.296207726, -2.75934458, 3.86213779, 0.680902481, 1.17066145, -1.72475672, 1.10044408, 0.496523142, -1.07292652, 0.763490498, -1.37628, -0.630429149, 1.04659331, 0.444013417, -0.623084664, -1.07019091, -0.380872279, -1.43623781, 2.77222896, -0.356521338, 0.925348163, -2.93748331, 0.515905, 0.777000666, 1.78092742, 0.343145162, 2.04922, -2.74038124, -0.692501187]], deployed_model_id='520565979193802752', model_version_id='1', model_resource_name='projects/934903580331/locations/us-central1/models/6791535990214230016', explanations=None)
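
Because the indexes were created with `DIMENSIONS = 32`, a quick check that the query model returns vectors of that length can catch a mismatch before querying. A minimal sketch, where `check_embedding_dims` is a hypothetical helper and the zero-filled vector stands in for `playlist_emb.predictions`:

```python
def check_embedding_dims(predictions, expected_dims):
    """Raise ValueError if any embedding has a length other than expected_dims."""
    for i, emb in enumerate(predictions):
        if len(emb) != expected_dims:
            raise ValueError(
                f"embedding {i} has {len(emb)} dims, expected {expected_dims}"
            )
    return True


# Placeholder for playlist_emb.predictions: one 32-dimensional vector.
example_predictions = [[0.0] * 32]
print(check_embedding_dims(example_predictions, expected_dims=32))
```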

In [56]:
DEPLOYED_ANN_INDEX_ID = 'deployed_tfrs_32dim_50e_v8'
DEPLOYED_BF_INDEX_ID = 'deployed_tfrs_32dim_50e_v8_bf'

In [57]:
%%timeit 
ANN_response = deployed_ann_index.match(
    deployed_index_id=DEPLOYED_ANN_INDEX_ID,
    queries=playlist_emb.predictions,
    num_neighbors=10
)
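# Note: names assigned inside a %%timeit cell are not kept in the session;
# rerun the call without %%timeit before inspecting ANN_response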

5.66 ms ± 66.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [58]:
ANN_response

[[MatchNeighbor(id='spotify:track:5EOCVGLC87l0ltBHPF1PLZ', distance=277.0671691894531),
  MatchNeighbor(id='spotify:track:2nRu8GJ4YOOTko2AL6JFHv', distance=276.6213684082031),
  MatchNeighbor(id='spotify:track:2y6E5270GWuFOkFT9IDKzp', distance=276.147705078125),
  MatchNeighbor(id='spotify:track:6YpVOTOw4tS3Rn3uFOprZW', distance=275.94091796875),
  MatchNeighbor(id='spotify:track:1EuEdRsSKW2KfxgjQHKOjB', distance=275.8877868652344),
  MatchNeighbor(id='spotify:track:1ylbo5s6dZEUC3Bi57Hxy0', distance=275.6029052734375),
  MatchNeighbor(id='spotify:track:3Ouo7onYOiX2R1xTG8jwMV', distance=275.3026428222656),
  MatchNeighbor(id='spotify:track:1JznSrEEOdCmVbPDKr1BmN', distance=275.12353515625),
  MatchNeighbor(id='spotify:track:4a8o7xDrEB2Emc1pCTBIDl', distance=274.9769592285156),
  MatchNeighbor(id='spotify:track:0BEfyNWHdUbzoZ81fOEQGB', distance=274.9180908203125)]]

In [59]:
%%timeit 
BF_response = deployed_bf_index.match(
    deployed_index_id=DEPLOYED_BF_INDEX_ID,
    queries=playlist_emb.predictions,
    num_neighbors=10
)
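# Note: names assigned inside a %%timeit cell are not kept in the session;
# rerun the call without %%timeit before inspecting BF_response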

36.1 ms ± 2.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
BF_response

[[MatchNeighbor(id='spotify:track:5EOCVGLC87l0ltBHPF1PLZ', distance=277.0671691894531),
  MatchNeighbor(id='spotify:track:2nRu8GJ4YOOTko2AL6JFHv', distance=276.6213684082031),
  MatchNeighbor(id='spotify:track:7by9lz7HaaYM5ZMa7gWMhP', distance=276.21881103515625),
  MatchNeighbor(id='spotify:track:2y6E5270GWuFOkFT9IDKzp', distance=276.147705078125),
  MatchNeighbor(id='spotify:track:6YpVOTOw4tS3Rn3uFOprZW', distance=275.94091796875),
  MatchNeighbor(id='spotify:track:1EuEdRsSKW2KfxgjQHKOjB', distance=275.8877868652344),
  MatchNeighbor(id='spotify:track:5LzFJYbcTGLZ6c1NsjtAlj', distance=275.7001037597656),
  MatchNeighbor(id='spotify:track:12YahMP74oGi4dT8EO2uNu', distance=275.6883239746094),
  MatchNeighbor(id='spotify:track:1ylbo5s6dZEUC3Bi57Hxy0', distance=275.6029052734375),
  MatchNeighbor(id='spotify:track:3Ouo7onYOiX2R1xTG8jwMV', distance=275.3026428222656)]]

## Compute Recall

In [61]:
# Calculate recall by determining how many neighbors were correctly retrieved as compared to the brute-force option.
recalled_neighbors = 0
for tree_ah_neighbors, brute_force_neighbors in zip(
    ANN_response, BF_response
):
    tree_ah_neighbor_ids = [neighbor.id for neighbor in tree_ah_neighbors]
    brute_force_neighbor_ids = [neighbor.id for neighbor in brute_force_neighbors]

    recalled_neighbors += len(
        set(tree_ah_neighbor_ids).intersection(brute_force_neighbor_ids)
    )

recall = recalled_neighbors / len(
    [neighbor for neighbors in BF_response for neighbor in neighbors]
)

print("Recall: {}".format(recall))

Recall: 0.7
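
The same calculation can be wrapped in a small reusable function for a single query. The sketch below uses the track IDs from the ANN and brute-force responses shown earlier; the `recall_at_k` name is an assumption for this example, not part of the SDK.

```python
def recall_at_k(approx_ids, exact_ids):
    """Fraction of exact (brute-force) neighbors that also appear in the ANN result."""
    exact = set(exact_ids)
    if not exact:
        raise ValueError("exact_ids must not be empty")
    return len(exact.intersection(approx_ids)) / len(exact)


# Track IDs copied from the ANN response above.
ann_ids = [
    "spotify:track:5EOCVGLC87l0ltBHPF1PLZ", "spotify:track:2nRu8GJ4YOOTko2AL6JFHv",
    "spotify:track:2y6E5270GWuFOkFT9IDKzp", "spotify:track:6YpVOTOw4tS3Rn3uFOprZW",
    "spotify:track:1EuEdRsSKW2KfxgjQHKOjB", "spotify:track:1ylbo5s6dZEUC3Bi57Hxy0",
    "spotify:track:3Ouo7onYOiX2R1xTG8jwMV", "spotify:track:1JznSrEEOdCmVbPDKr1BmN",
    "spotify:track:4a8o7xDrEB2Emc1pCTBIDl", "spotify:track:0BEfyNWHdUbzoZ81fOEQGB",
]
# Track IDs copied from the brute-force response above.
bf_ids = [
    "spotify:track:5EOCVGLC87l0ltBHPF1PLZ", "spotify:track:2nRu8GJ4YOOTko2AL6JFHv",
    "spotify:track:7by9lz7HaaYM5ZMa7gWMhP", "spotify:track:2y6E5270GWuFOkFT9IDKzp",
    "spotify:track:6YpVOTOw4tS3Rn3uFOprZW", "spotify:track:1EuEdRsSKW2KfxgjQHKOjB",
    "spotify:track:5LzFJYbcTGLZ6c1NsjtAlj", "spotify:track:12YahMP74oGi4dT8EO2uNu",
    "spotify:track:1ylbo5s6dZEUC3Bi57Hxy0", "spotify:track:3Ouo7onYOiX2R1xTG8jwMV",
]

print(recall_at_k(ann_ids, bf_ids))  # 0.7
```

Seven of the ten exact neighbors appear in the ANN result, which matches the recall printed above. A single query gives a noisy estimate, so averaging over many test instances would give a more reliable picture.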
