## Test the Retrieval Latency of Approximate vs Exact Matching 

In [1]:
import tensorflow as tf
import time

In [2]:
# Notebook-wide settings. DIMENSIONS is presumably the embedding size; it is
# not referenced by the cells shown here -- TODO confirm.
DIMENSIONS = 50
DISPLAY_NAME = "retail_demo_matching_engine"
DISPLAY_NAME_BRUTE_FORCE = f"{DISPLAY_NAME}_brute_force"

# Cloud project and storage bucket that hold the exported artifacts.
PROJECT_ID = 'rec-ai-demo-326116' # Change to your project.
BUCKET = 'rec_bq_jsw'

# Every artifact directory shares the same gs://<bucket>/bqml prefix.
_BQML_ROOT = f'gs://{BUCKET}/bqml'
INDEX_DIR = f'{_BQML_ROOT}/scann_index'
BQML_MODEL_DIR = f'{_BQML_ROOT}/item_matching_model'
LOOKUP_MODEL_DIR = f'{_BQML_ROOT}/embedding_lookup_model'

In [16]:
# Sample query items, keyed by id string. The benchmark cells iterate over the
# keys and pass them to the matchers; the values are human-readable titles.
products = dict([
    ('4096', "AX Paris Strapless Spot Print Orange Romper"),
    ('5120', "Lee Women's Plus-Size Comfort Fit Straight Leg Pant"),
    ('7424', "Allegra K Woman Plaid Elastic Waist Preppy Above Knee Skirt Gray Black S"),
    ('4352', "Silver Jeans Juniors Suki Surplus Mid Rise Bootcut Jean"),
])

## Exact Matching

In [17]:
class ExactMatcher:
  """Thin wrapper around a TensorFlow SavedModel queried via `serving_default`."""

  def __init__(self, model_dir):
    """Load the SavedModel found at `model_dir`, printing a message before and after."""
    print("Loading exact matchg model...")
    self.model = tf.saved_model.load(model_dir)
    print("Exact matchg model is loaded.")

  def match(self, instances):
    """Run the serving signature on `instances` and return its item predictions.

    Args:
      instances: Values convertible to an int64 tensor (the notebook passes a
        one-element list holding an integer item id).

    Returns:
      The signature's `predicted_item2_Id` output converted to a NumPy array.
    """
    serving_fn = self.model.signatures['serving_default']
    query = tf.constant(instances, dtype=tf.dtypes.int64)
    prediction = serving_fn(query)
    return prediction['predicted_item2_Id'].numpy()

In [18]:
# Load the exact-matching model once so the timing cells measure queries only.
exact_matcher = ExactMatcher(model_dir=BQML_MODEL_DIR)

Loading exact matchg model...
Exact matchg model is loaded.


In [19]:
# Time 100 passes over the sample products with the exact matcher, keeping the
# first row of each result keyed by product id.
exact_matches = dict()

start_time = time.time()
for _ in range(100):
  for product_id in products:
    first_row = exact_matcher.match([int(product_id)]).tolist()[0]
    exact_matches[product_id] = first_row
exact_elapsed_time = time.time() - start_time

print(f'Elapsed time: {round(exact_elapsed_time, 3)} seconds - average time: {exact_elapsed_time / (100 * len(products))} seconds')

Elapsed time: 0.349 seconds - average time: 0.0008731639385223389 seconds


## Approximate Matching (ScaNN)

In [20]:
from index_server.matching import ScaNNMatcher
# Model that maps an item id to its embedding vector (used to build ScaNN queries).
embedding_lookup = tf.saved_model.load(LOOKUP_MODEL_DIR)
# Approximate matcher backed by the index stored under INDEX_DIR.
scann_matcher = ScaNNMatcher(INDEX_DIR)

Loading ScaNN index...
ScaNN index is loadded.


In [22]:
# Time 100 passes over the sample products with ScaNN. Each query first looks
# up the item's embedding, then asks the index for 50 matches.
approx_matches = {}

start_time = time.time()
for _ in range(100):
  for product_id in products:
    query_embedding = embedding_lookup([product_id]).numpy()[0]
    approx_matches[product_id] = scann_matcher.match(query_embedding, 50)
scann_elapsed_time = time.time() - start_time

print(f'Elapsed time: {round(scann_elapsed_time, 3)} seconds - average time: {scann_elapsed_time / (100 * len(products))} seconds')

Elapsed time: 0.402 seconds - average time: 0.001004766821861267 seconds


In [23]:
# Despite its name, speedup_percent is a ratio (exact time / ScaNN time), not a
# percentage; values above 1 mean the ScaNN loop finished faster.
speedup_ratio = exact_elapsed_time / scann_elapsed_time
speedup_percent = round(speedup_ratio, 1)
print(f'ScaNN speedup: {speedup_percent}x')

ScaNN speedup: 0.9x


In [31]:
# Another visualization

def get_stats(n_groups_of_four, matcher=scann_matcher):
    """Time repeated match queries over the sample `products`.

    Args:
        n_groups_of_four: Number of passes over `products`; every pass issues
            one query per product.
        matcher: Matcher to benchmark. `exact_matcher` is queried with the
            integer product id; any other matcher is queried with the
            product's embedding vector and asked for 50 matches.

    Returns:
        Elapsed wall-clock time, in seconds, for all queries.
    """
    use_exact = matcher is exact_matcher

    # perf_counter is monotonic and high resolution, so it suits interval timing.
    start_time = time.perf_counter()
    for _ in range(n_groups_of_four):
        for product_id in products:
            if use_exact:
                matcher.match([int(product_id)])
            else:
                # Only the vector-based matcher needs the embedding. Doing the
                # lookup on the exact path as well would charge brute force
                # for work it never uses.
                vector = embedding_lookup([product_id]).numpy()[0]
                matcher.match(vector, 50)
    return time.perf_counter() - start_time

In [None]:
# Generate timing data for retrieving matches with ScaNN vs. brute force (exact).
import pandas as pd
import seaborn as sns
import time


max_n_4 = 300_000 # upper bound on passes over the sample products in one run
data = {
    "n_4": [],
    "n_preds": [],
    "run_time": [],
    "matcher": []
}


def _run_sweep(label, matcher):
    """Time `matcher` at each sweep size, record rows in `data`, return total seconds."""
    total = 0
    # NOTE(review): the sweep value used to be squared before timing, giving up
    # to 299001**2 passes in a single run. That cannot finish in practical time
    # and contradicts the "total of {max_n_4}" message, so the sweep is linear
    # now. Lower max_n_4 for a quick run; the cost is the sum of all sweep sizes.
    for n_4 in range(1, max_n_4, 1000):
        run_time = get_stats(n_4, matcher=matcher)
        data["n_4"].append(n_4)
        # One prediction per sample product in every pass.
        data["n_preds"].append(n_4 * len(products))
        data["run_time"].append(run_time)
        data["matcher"].append(label)
        total += run_time
    return total


print(f"Running loop for scann - total of {max_n_4}")
tt1 = _run_sweep("scann", scann_matcher)
print(f"Total time: {tt1}")

print(f"Running loop for brute force")
tt = _run_sweep("brute_force", exact_matcher)
print(f"Total time: {tt}")

# Ratio of brute-force time to ScaNN time over the identical sweep.
print(f"Scann is {tt/tt1:.2f}x faster than Brute Force")


Running loop for scann - total of 300000


In [None]:
%%markdown
### Performance differences scale non-linearly
The run-time difference between brute force and ScaNN changes as the number of predictions increases.


In [None]:
# One line per matcher: total run time against the number of predictions made.
ax = sns.lineplot(data=data, x="n_preds", y="run_time", hue="matcher")
# Label the figure so it can be read on its own; the trailing semicolon
# suppresses the repr of the returned value in the cell output.
ax.set(
    xlabel="Number of predictions",
    ylabel="Run time (seconds)",
    title="Retrieval run time: ScaNN vs. brute force",
);

## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product but sample code provided for an educational purpose**