In [3]:
import pandas as pd
import polars as pl
import numpy as np

In [4]:
from pathlib import Path

# Chain-embeddings parquet fixture, given relative to the notebook's working directory.
CHAIN_EMBEDDINGS_FIXTURE = "../../tests/fixtures/vendor_ranking/two_towers/inference_model_artifacts/AE/model_artifacts_recall_4858/tt_chain_embeddings_recall@10_4858.parquet"

chains_embeddings_path = Path(CHAIN_EMBEDDINGS_FIXTURE)
# Fail fast if the notebook was started from a directory where the relative path does not resolve.
assert chains_embeddings_path.exists()

In [5]:
from vendor_ranking.two_tower.model_artifacts import SourceCloudType, ArtifactsManager
from vendor_ranking.two_tower.artifacts_service import ArtifactsService
from vendor_ranking.data import refresh_vendors, get_all_vendors

country = "AE"
# Per-country configuration; `recall` and `version` are read from it below.
config = ArtifactsService.configs[country]

# `ArtifactsManager` is the class imported at the top of this cell. The name
# `VendorArtifactsManager` is not imported anywhere in this notebook, so using it
# raises NameError on a fresh kernel (Restart & Run All).
# NOTE(review): confirm `ArtifactsManager` accepts these keyword arguments.
artifacts_manager = ArtifactsManager(
    recall=config.recall,
    base_dir="",
    country=country,
    version=config.version,
    source=SourceCloudType.AWS,
)

# Override the manager's local embeddings path with the fixture checked above
# (presumably so the loader reads the fixture file - TODO confirm).
artifacts_manager.chain_embeddings_local_path = chains_embeddings_path
chain_embeddings: "pl.DataFrame" = ArtifactsService._load_chains_embeddings(artifacts_manager)

In [8]:
# Inspect the frame returned by `ArtifactsService._load_chains_embeddings`.
chain_embeddings

chain_embeddings,chain_id
list[f32],i64
"[0.029733, 0.193347, … 0.03776]",28525
"[0.060471, 0.267316, … 0.175063]",2000
"[-0.011328, 0.055134, … -0.036882]",6544
"[-0.096245, 0.211634, … 0.024418]",10581
"[0.045585, 0.2912, … 0.012362]",13187
"[-0.255341, 0.147362, … 0.075167]",620392
"[0.083082, 0.147104, … 0.156425]",634349
"[0.122256, 0.261384, … 0.114238]",655701
"[0.044029, 0.341454, … -0.122712]",10241
"[-0.083498, 0.18776, … 0.022631]",28634


In [35]:
from vendor_ranking.two_tower.artifacts_service import CHAIN_NAME, CHAIN_ID
from vendor_ranking import SERVICE_NAME
from ace import db
import asyncio


async def load_vendors_from_db() -> "pl.DataFrame":
    """Await `refresh_vendors()` and then return whatever `get_all_vendors()` yields."""
    await refresh_vendors()
    vendors = get_all_vendors()
    return vendors


await db.init_connection_pool(service_name=SERVICE_NAME, query_timeout=150_000)
all_vendors: "pl.DataFrame" = await load_vendors_from_db()
all_vendors = all_vendors.select([CHAIN_ID, CHAIN_NAME]).unique(
    [
        CHAIN_ID,
    ],
    keep="first",
)

[2m2023-08-07 12:55:20[0m [[32m[1minfo     [0m] [1mCheck if latest dynamic ranking vendor data is refreshed after 2023-04-28...[0m
[2m2023-08-07 12:55:20[0m [[32m[1minfo     [0m] [1mSkipping dynamic ranking data refresh...[0m


In [36]:
# (rows, columns) of the vendors frame after selecting id/name and de-duplicating on chain id.
all_vendors.shape

(181, 2)

In [37]:
# Show both join inputs together: the embeddings frame and the vendors (id, name) frame.
chain_embeddings, all_vendors

(shape: (8_798, 2)
 ┌───────────────────────────────────┬──────────┐
 │ chain_embeddings                  ┆ chain_id │
 │ ---                               ┆ ---      │
 │ list[f32]                         ┆ i64      │
 ╞═══════════════════════════════════╪══════════╡
 │ [0.029733, 0.193347, … 0.03776]   ┆ 28525    │
 │ [0.060471, 0.267316, … 0.175063]  ┆ 2000     │
 │ [-0.011328, 0.055134, … -0.03688… ┆ 6544     │
 │ [-0.096245, 0.211634, … 0.024418… ┆ 10581    │
 │ …                                 ┆ …        │
 │ [-0.147472, 0.125612, … 0.101698… ┆ 665671   │
 │ [-0.084697, 0.441283, … 0.019612… ┆ 665406   │
 │ [-0.092685, 0.3433, … -0.069334]  ┆ 665735   │
 │ [-0.108035, 0.354954, … -0.17012… ┆ 657865   │
 └───────────────────────────────────┴──────────┘,
 shape: (181, 2)
 ┌──────────┬─────────────────────────┐
 │ chain_id ┆ chain_name              │
 │ ---      ┆ ---                     │
 │ i64      ┆ str                     │
 ╞══════════╪═════════════════════════╡
 │ 23520    ┆

In [38]:
# Attach the chain name to every embedding row. With a left join, chains that are
# missing from `all_vendors` keep their row and get a null name.
# The column-name constants are used instead of string literals, for consistency
# with the rest of the notebook.
joined = chain_embeddings.join(all_vendors, on=CHAIN_ID, how="left").select(
    chain_embeddings.columns + [CHAIN_NAME]
)

# A left join duplicates left-hand rows when the right-hand side has several rows
# for the same key. Guard against that so duplicate chain ids in `all_vendors`
# are caught here rather than downstream.
assert joined.shape[0] == chain_embeddings.shape[0], (
    f"join changed the row count: {chain_embeddings.shape[0]} -> {joined.shape[0]}"
)
joined.shape

(8798, 3)

In [23]:
# Sorted chain ids of the joined frame (xx) and of the original embeddings (yy);
# both names are read again by the next cell.
xx = joined[CHAIN_ID].sort()
yy = chain_embeddings[CHAIN_ID].sort()

# Walk both sorted series in lockstep and stop at the first position that differs.
# NOTE(review): zip() stops at the shorter series, so a pure length mismatch is not reported here.
for idx, id_pair in enumerate(zip(xx, yy)):
    joined_id, embedding_id = id_pair
    assert joined_id == embedding_id, idx

AssertionError: 647

In [25]:
# Look at the two ids at position 647, the index reported by the failing assertion in the previous cell.
xx[647], yy[647]

(10181, 10184)

In [34]:
# One predicate for the two chain ids under investigation, applied to all three frames.
is_suspect_chain = pl.col(CHAIN_ID).is_in([10181, 10184])
(
    joined.filter(is_suspect_chain),
    chain_embeddings.filter(is_suspect_chain),
    all_vendors.filter(is_suspect_chain),
)

(shape: (3, 3)
 ┌───────────────────────────────────┬──────────┬─────────────────────────┐
 │ chain_embeddings                  ┆ chain_id ┆ chain_name              │
 │ ---                               ┆ ---      ┆ ---                     │
 │ list[f32]                         ┆ i64      ┆ str                     │
 ╞═══════════════════════════════════╪══════════╪═════════════════════════╡
 │ [-0.120137, 0.279757, … -0.10439… ┆ 10181    ┆ Nom Nom Asia Restaurant │
 │ [-0.120137, 0.279757, … -0.10439… ┆ 10181    ┆ NomNom Asia             │
 │ [-0.121796, 0.276462, … 0.104862… ┆ 10184    ┆ null                    │
 └───────────────────────────────────┴──────────┴─────────────────────────┘,
 shape: (2, 2)
 ┌───────────────────────────────────┬──────────┐
 │ chain_embeddings                  ┆ chain_id │
 │ ---                               ┆ ---      │
 │ list[f32]                         ┆ i64      │
 ╞═══════════════════════════════════╪══════════╡
 │ [-0.120137, 0.279757, … -0.10439

In [39]:
# Display the joined frame: the embeddings columns plus the chain name column.
joined

chain_embeddings,chain_id,chain_name
list[f32],i64,str
"[0.029733, 0.193347, … 0.03776]",28525,
"[0.060471, 0.267316, … 0.175063]",2000,
"[-0.011328, 0.055134, … -0.036882]",6544,
"[-0.096245, 0.211634, … 0.024418]",10581,
"[0.045585, 0.2912, … 0.012362]",13187,
"[-0.255341, 0.147362, … 0.075167]",620392,
"[0.083082, 0.147104, … 0.156425]",634349,
"[0.122256, 0.261384, … 0.114238]",655701,
"[0.044029, 0.341454, … -0.122712]",10241,
"[-0.083498, 0.18776, … 0.022631]",28634,"""Karachi Darbar…"


In [53]:
# Chain ids whose name is null in the joined frame.
name_is_missing = pl.col(CHAIN_NAME).is_null()
chains_without_names = joined.filter(name_is_missing).get_column(CHAIN_ID)

In [54]:
# Render the ids as a plain Python list literal in a single string.
str(list(chains_without_names))

'[28525, 2000, 6544, 10581, 13187, 620392, 634349, 655701, 10241, 6501, 10501, 8516, 18574, 619777, 661685, 711, 632368, 656721, 1991, 1968, 641136, 1106, 12838, 629162, 633557, 7540, 618873, 663020, 660015, 3219, 653862, 18876, 632024, 3006, 659005, 11198, 11179, 658439, 652384, 601419, 660505, 661295, 636322, 632016, 14354, 14542, 9804, 2599, 10636, 630435, 659745, 27682, 662552, 650136, 12802, 633299, 648114, 627543, 22849, 14420, 658475, 7700, 602102, 19122, 639298, 648670, 629725, 660567, 656411, 23254, 650453, 18741, 630398, 660851, 10782, 629666, 653734, 10065, 642835, 655491, 9694, 16810, 2699, 18897, 5699, 655086, 645526, 16311, 10329, 27575, 659652, 28448, 28514, 640347, 27962, 661724, 619156, 642713, 658301, 612648, 611005, 618704, 8722, 650182, 620064, 663892, 663177, 626084, 15456, 654979, 19056, 24702, 637893, 11076, 645659, 620116, 17676, 658394, 10388, 623304, 9275, 11698, 11277, 654974, 12045, 9869, 604911, 931, 647885, 25053, 663596, 626896, 21974, 618254, 9759, 61343