In [1]:
import os

# Google Cloud client configuration for this notebook.
# The key path below is a machine-specific fallback: setdefault() keeps any
# GOOGLE_APPLICATION_CREDENTIALS already exported in the environment, so the
# notebook also runs on machines where this absolute path does not exist.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    "/Users/andreiromanov/DEV/talabat/Ace/omar_service.json",
)
# The project is deliberately forced (not defaulted) so the notebook always
# targets the dev project regardless of what the shell exports.
os.environ["GCP_PROJECT"] = "tlb-data-dev"

In [2]:
from model_artifacts import ArtifactsManager
import os
import pandas as pd
import numpy as np

# TODO 1: script to load user model
# TODO 2: script to load chain model

# loading time
# 1. you define an instance of Artifacts Manager (country, version, recall)
# 2. download the model artifacts
# 3.a load user model
# 3.b update ese embeddings (retrieve the ese embeddings for the country from PostgreSQL and update them)
# 4. loop for all countries

# Inference
# given a user, time, location
# 1. retrieve static features
# 2. append order_hour, order_weekday, geohash6
# 3. route these features to the right model depending on request country
# 4. compute distance with available chains --> intersection(country chains, available chains)
# 5. replicating previous logic
# 6. penalizing previous orders with k positions


# Selects which model artifacts to work with; these three values are passed to
# the artifacts manager further down in this cell.
config = dict(country="AE", recall=3235, version=2)


def get_inference_path(country):
    """Return the local artifacts directory for ``country``, creating it if needed.

    The directory is ``<current working directory>/inference_model_artifacts/<country>``.

    Args:
        country: Country code used as the leaf directory name (e.g. ``"AE"``).

    Returns:
        str: Path of the directory, which exists when the function returns.

    Raises:
        FileExistsError: If the target path exists but is not a directory.
    """
    base_path = os.path.join(os.getcwd(), f"inference_model_artifacts/{country}")
    # exist_ok=True already makes this a no-op for an existing directory, so a
    # separate os.path.exists() check (and its check-then-create race) is not
    # needed. It also surfaces the case where the path is a regular file.
    os.makedirs(base_path, exist_ok=True)
    return base_path


# --- Download artifacts and load both models ---------------------------------
base_path = get_inference_path(config["country"])
# Use the class this cell actually imports. The previous name,
# VendorArtifactsManager, is not imported anywhere in the notebook and raises
# NameError on a fresh kernel (Restart & Run All).
artifacts_manager = ArtifactsManager(
    recall=config["recall"],
    base_dir=base_path,
    country=config["country"],
    version=config["version"],
)
artifacts_manager.download_model_artifacts()
user_model = artifacts_manager.load_user_model()
chain_model = artifacts_manager.load_chain_model()

# --- User side: run the user model on the test sample ------------------------
query_features = artifacts_manager.get_query_features()
test_df = artifacts_manager.get_test_sample()
inferred_user_embedding = user_model(test_df[query_features]).numpy()

# --- Chain side: run the chain model and read the stored chain embeddings ----
candidates_features = artifacts_manager.get_candidates_features()
chain_features = artifacts_manager.get_chain_features()
chain_embeddings = artifacts_manager.get_embeddings()
# Turn the column of per-row embeddings into one array for comparison.
loaded_chain_embeddings = np.array(list(chain_embeddings.chain_embeddings.values))
inferred_chain_embeddings = chain_model(chain_features[candidates_features]).numpy()

# --- Stored user embeddings from the downloaded parquet file -----------------
path = os.path.join(artifacts_manager.base_local_dir, artifacts_manager.user_embeddings_file)
loaded_user_embedding = np.array(list(pd.read_parquet(path).user_embeddings.values))

# Sanity check: embeddings recomputed by the loaded models should match the
# stored ones (both lines print True when they agree within np.allclose tolerance).
print(np.allclose(loaded_user_embedding, inferred_user_embedding))
print(np.allclose(inferred_chain_embeddings, loaded_chain_embeddings))



Downloading model weights
Downloaded storage object twotower_v2/AE/model_artifacts_recall_3235/tt_user_model_weights_recall@10_3235.data-00000-of-00001 from bucket tlb-data-dev-data-algorithms-content-optimization to local file /Users/andreiromanov/DEV/talabat/Ace/vendor_ranking/two_tower/inference_model_artifacts/AE/tt_user_model_weights_recall@10_3235.data-00000-of-00001.
Downloaded storage object twotower_v2/AE/model_artifacts_recall_3235/tt_user_model_weights_recall@10_3235.index from bucket tlb-data-dev-data-algorithms-content-optimization to local file /Users/andreiromanov/DEV/talabat/Ace/vendor_ranking/two_tower/inference_model_artifacts/AE/tt_user_model_weights_recall@10_3235.index.
Downloaded storage object twotower_v2/AE/model_artifacts_recall_3235/tt_user_model_params_recall@10_3235.pkl from bucket tlb-data-dev-data-algorithms-content-optimization to local file /Users/andreiromanov/DEV/talabat/Ace/vendor_ranking/two_tower/inference_model_artifacts/AE/tt_user_model_params_rec



created layer


0     9876 1106 617960 25817 636322 7043 7614 619528...   
1     26225 11877 28504 636829 28504 24210 629162 25...   
2     10756 18574 10465 655413 661432 647662 13265 1...   
3     1991 1991 655089 2000 655087 10694 620198 1069...   
4     638192 608113 6501 621527 600608 12709 621527 ...   
...                                                 ...   
9995                                      655703 655703   
9996  931 3288 10756 931 27063 621006 931 931 10406 ...   
9997  10756 649695 649695 637922 637922 10756 6315 1...   
9998  10756 1991 1991 658439 630524 635285 1991 3846...   
9999  2000 12802 603921 642835 636322 2000 642713 63...   

                                            prev_clicks  \
0     628813 609367 633423 660704 641747 648224 2757...   
1     645617 24786 21372 626581 660737 630524 648974...   
2     22736 28733 641642 615051 23821 10406 638013 2...   
3     17157 616602 17138 618758 7540 1410 620125 650...   
4     16732 2870 657351 652216 659131 641740 641739 ...

True
True


In [3]:
# Time a single-row call of user_model: only the first row of test_df
# (restricted to the query_features columns) is passed in, and the trailing
# .numpy() call is included in the measured time.
%timeit user_model(test_df[query_features].iloc[:1]).numpy()

0  9876 1106 617960 25817 636322 7043 7614 619528...   

                                         prev_clicks  \
0  628813 609367 633423 660704 641747 648224 2757...   

                                         freq_clicks  \
0  630197 9780 628813 605768 602533 17190 9766 65...   

                                         freq_chains  order_hour  \
0  7614 2241 18430 630134 648549 633575 10406 121...          21   

   order_weekday geohash6                                      prev_searches  
0              5   thqehh  ramadan aqua mc donalds burger chicano vapiano...  . Consider rewriting this model with the Functional API.
Name: user_prev_chains, dtype: object. Consider rewriting this model with the Functional API.
Name: prev_searches, dtype: object. Consider rewriting this model with the Functional API.
Name: prev_clicks, dtype: object. Consider rewriting this model with the Functional API.
Name: order_hour, dtype: int64. Consider rewriting this model with the Functional API.
Name:

69.7 ms ± 6.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
